### Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.distributions as dist
from torchsummary import summary
import math
import os
import numpy as np
import time
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from pathlib import Path
import re
from skimage.metrics  import structural_similarity as ssim
import cv2
import pickle

from importlib import reload
import visualization

# locals
import model_architectures

reload(model_architectures)
from model_architectures import VAE, Data3D

reload(visualization)
from visualization import brain_diff

  from .autonotebook import tqdm as notebook_tqdm


### Define Paths

In [2]:
# Root of the research tree. Defaults to the original workstation location; set the
# RESEARCH_DIR environment variable to run this notebook on another machine without
# editing code. Every other path below is derived from this root.
research_dir = os.environ.get("RESEARCH_DIR", r"D:/school/research")

# Code tree: trained checkpoints and the directory metric files are written to.
code_dir = os.path.join(research_dir, "code")
p2_dir = os.path.join(code_dir, "paper_two_code")
model_dir = os.path.join(p2_dir, "models")
metrics_dir = os.path.join(p2_dir, "metrics")
mmmetrics_dir = os.path.join(metrics_dir, "multimodal")

# Data tree: processed dHCP release-2 volumes, split into the l1 and l5 subdirectories
# (l1 feeds train/val and l5 feeds test in the cells below).
data_dir = os.path.join(research_dir, "data")
dhcp_rel2 = os.path.join(data_dir, "dhcp_rel2")
processed_dir = os.path.join(dhcp_rel2, "processed")
volume_dir = os.path.join(processed_dir, "volumes")
l1_dir = os.path.join(volume_dir, "l1")
l5_dir = os.path.join(volume_dir, "l5")

### Get Data

In [3]:
# Fixed seed so that the train/validation shuffle below is reproducible between runs.
np.random.seed(42)

# The file count is halved to get the sample count. Presumably each sample is stored as
# two files (e.g. one T1 and one T2 volume) — confirm against Data3D.
num_samples = len(os.listdir(l1_dir)) // 2
samples = np.array(list(range(num_samples)))
np.random.shuffle(samples)

# 80 / 20 train / validation split of the shuffled l1 sample indices.
split_val = int(0.8 * num_samples)
train_indices, val_indices = samples[:split_val], samples[split_val:]

# All l5 samples, in order, form the test set.
num_test = len(os.listdir(l5_dir)) // 2
test_indices = np.array(list(range(num_test)))

train = Data3D(l1_dir, train_indices, t2_only=False)
val = Data3D(l1_dir, val_indices, t2_only=False)
test = Data3D(l5_dir, test_indices, t2_only=False)

batch_size = 1
train_loader, val_loader = (
    DataLoader(dataset, batch_size=batch_size) for dataset in (train, val)
)

### Helper Functions

In [4]:
# Signed bands for the normalised reconstruction residual (the metric functions below divide
# the residual by 0.1 before binning). The central (-2, 2) region is intentionally absent,
# so only larger deviations are binned. Each band is later applied as the open interval
# (low, high), so a value lying exactly on a bound is counted in no band.
# Each entry: "range" is a human-readable label (also used as a metric-dict key prefix),
# and "low" / "high" are the interval bounds.
std_bands = [
    {"range": f"{lower} to {upper}", "low": lower, "high": upper}
    for lower, upper in (
        (-np.inf, -4),
        (-4, -3),
        (-3, -2),
        (2, 3),
        (3, 4),
        (4, np.inf),
    )
]

def connected_components_2d(diff):
    """Label the connected non-zero regions of one 2-D slice.

    Thin wrapper around ``cv2.connectedComponents`` with OpenCV's default connectivity.
    The component count that OpenCV also returns is discarded.

    Args:
        diff: 2-D single-channel image; the callers in this notebook pass 0/1 ``uint8`` masks.

    Returns:
        Integer label image with the same height and width as ``diff``. Label 0 is the
        background; labels 1..n-1 are the individual components.
    """
    label_image = cv2.connectedComponents(diff)[1]
    return label_image


def calculate_clusters(diff, threshold=2):
    """Collect the sizes of 2-D clusters of large-magnitude values in a 3-D array.

    A voxel belongs to a cluster when it lies outside the open interval
    ``(-threshold, threshold)`` and is non-zero. Values exactly at ``±threshold`` and NaNs
    therefore count. Connected components are labelled slice by slice along axis 2 with
    ``connected_components_2d``. The pixel counts of all components from all slices are
    concatenated.

    Args:
        diff: 3-D array whose slices are indexed on axis 2. In this notebook it is a
            normalised reconstruction residual.
        threshold: Half-width of the "normal" band (default 2).

    Returns:
        List of per-component pixel counts (numpy integers), in slice order and then label
        order.
    """
    normal = (diff > -threshold) & (diff < threshold)
    # `!= 0` mirrors the original "zero out normal voxels, then binarise non-zeros" logic:
    # NaNs stay anomalous and genuine zeros never are.
    prepared_diff = (~normal & (diff != 0)).astype(np.uint8)

    sizes = []
    # Iterate over however many slices the volume has, instead of assuming 256.
    for idx in range(prepared_diff.shape[2]):
        labels = connected_components_2d(prepared_diff[:, :, idx])
        # Index 0 of bincount is the background. Unlike np.unique(...)[1:], slicing [1:]
        # here is correct even when a slice has no background pixels at all; np.unique
        # would drop the only real component in that case.
        sizes.extend(np.bincount(labels.ravel())[1:])

    return sizes

def get_overall_metrics(model, data, diff_scale=0.1):
    """Reconstruct every sample with ``model`` and collect error and anomaly statistics.

    For each sample, the channel-summed residual (input minus reconstruction) is divided by
    ``diff_scale``. Note that for multi-channel inputs the channels are summed before
    differencing, so per-channel errors of opposite sign can cancel. The scaled residual
    is then:
      * passed to ``calculate_clusters`` (default threshold 2) to get cluster sizes, and
      * split into the ``std_bands`` open intervals. For each band, the in-band voxel count
        and the per-slice connected-component sizes are recorded.

    Args:
        model: Reconstruction model, presumably already in ``eval()`` mode. Its output must
            be reshapeable to the sample's shape (i.e. the forward pass is assumed to return
            only the reconstruction — confirm for the VAE).
        data: Iterable of numpy samples with channels on axis 0. The summed residual is
            assumed to be 3-D with slices on its last axis (e.g. ``Data3D`` items).
        diff_scale: Divisor applied to the residual. The default 0.1 is the value this
            function previously hard-coded.

    Returns:
        ``(losses, clusters, banded_metrics)``:
          * losses: per-sample MSE as Python floats.
          * clusters: per-sample lists of cluster sizes from ``calculate_clusters``.
          * banded_metrics: dict whose ``"<range> pixels"`` key holds per-sample in-band
            voxel counts and whose ``"<range> clusters"`` key holds per-sample lists of
            component sizes.
    """
    criterion = nn.MSELoss()
    # Send inputs to wherever the model lives. Fall back to CUDA (the previous behaviour)
    # if the model has no parameters to inspect.
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else "cuda"

    def _component_sizes(mask):
        # Per-slice (axis 2) connected-component sizes of a 0/1 uint8 mask. bincount index 0
        # is the background, so [1:] stays correct when a slice has no background pixels.
        sizes = []
        for idx in range(mask.shape[2]):
            labels = connected_components_2d(mask[:, :, idx])
            sizes.extend(np.bincount(labels.ravel())[1:])
        return sizes

    losses = []
    clusters = []
    banded_metrics = {}
    for band in std_bands:
        banded_metrics[f"{band['range']} pixels"] = []
        banded_metrics[f"{band['range']} clusters"] = []

    # Inference only: skip autograd graph construction, which otherwise costs memory per sample.
    with torch.no_grad():
        for sample in data:
            x = torch.Tensor(sample).reshape((1,) + sample.shape).to(device)
            pred = model(x)

            losses.append(criterion(pred, x).item())

            pred = pred.reshape(sample.shape).cpu().numpy()
            diff_norm = (np.sum(sample, axis=0) - np.sum(pred, axis=0)) / diff_scale

            clusters.append(calculate_clusters(diff_norm))

            for band in std_bands:
                # Open interval (low, high). No std_band contains 0, so this equals the
                # previous "mask, then binarise non-zeros" construction.
                in_band = ((diff_norm > band["low"]) & (diff_norm < band["high"])).astype(np.uint8)
                banded_metrics[f"{band['range']} pixels"].append(np.sum(in_band))
                banded_metrics[f"{band['range']} clusters"].append(_component_sizes(in_band))

    return losses, clusters, banded_metrics

def save_values(loss, cluster, bands, name, out_dir=None):
    """Write one split's output from ``get_overall_metrics`` to disk.

    Three files are written into ``out_dir`` (default: ``mmmetrics_dir``):
    ``<name>_losses.npy``, ``<name>_clusters.npy`` and ``<name>_banded_metrics.pkl``.

    The per-sample cluster-size lists are usually ragged, i.e. a different number of
    clusters per sample. Handing such a list straight to ``np.save`` makes NumPy build the
    array implicitly. That emits a VisibleDeprecationWarning on older NumPy and raises
    ValueError from NumPy 1.24 onward. Ragged input is therefore packed explicitly into a
    1-D object array (the layout the implicit conversion used to produce); loading it
    requires ``np.load(..., allow_pickle=True)``. Rectangular input is saved as a plain
    numeric array, as before.

    Args:
        loss: Per-sample losses.
        cluster: Per-sample sequences of cluster sizes.
        bands: Banded-metrics dict; it is pickled as-is.
        name: File-name prefix.
        out_dir: Target directory. ``None`` uses ``mmmetrics_dir``.
    """
    if out_dir is None:
        out_dir = mmmetrics_dir

    def _as_array(rows):
        # Equal-length rows -> regular ndarray; ragged rows -> 1-D object array of rows.
        rows = list(rows)
        if len({len(row) for row in rows}) <= 1:
            return np.asarray(rows)
        packed = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            packed[i] = row
        return packed

    np.save(os.path.join(out_dir, f"{name}_losses.npy"), loss)
    np.save(os.path.join(out_dir, f"{name}_clusters.npy"), _as_array(cluster))
    # NOTE(review): pickle is fine for files this notebook produces; never unpickle
    # files from untrusted sources.
    with open(os.path.join(out_dir, f"{name}_banded_metrics.pkl"), "wb") as f:
        pickle.dump(bands, f)

### Get Metrics for Each Model

In [158]:
%%time

# One entry per trained VAE:
#   info              -> prefix for the saved metric files
#   model_name        -> checkpoint filename inside model_dir
#   t1_only / t2_only -> modality flags forwarded to Data3D
#   model             -> freshly built VAE whose weights are loaded below
#                        (VAE(2) for the T1+T2 model, VAE(1) for the single-modality ones;
#                         presumably the argument is the input channel count — confirm)
model_harness = [
    {"info": "t1_t2", "model_name": "vae_rel2_t1_t2_second_session.pt",
     "t1_only": False, "t2_only": False, "model": VAE(2)},
    {"info": "t2", "model_name": "vae_rel2_t2_second_session.pt",
     "t1_only": False, "t2_only": True, "model": VAE(1)},
    {"info": "t1", "model_name": "vae_rel2_t1_second_session.pt",
     "t1_only": True, "t2_only": False, "model": VAE(1)},
]

for model_info in model_harness:
    prefix = model_info["info"]
    print(prefix)
    modality_flags = {"t1_only": model_info["t1_only"], "t2_only": model_info["t2_only"]}

    # Rebuild the three datasets with this model's modality selection.
    train = Data3D(l1_dir, train_indices, **modality_flags)
    val = Data3D(l1_dir, val_indices, **modality_flags)
    test = Data3D(l5_dir, test_indices, **modality_flags)

    # Load the trained weights, move the model to the GPU and switch to evaluation mode.
    model_path = os.path.join(model_dir, model_info["model_name"])
    model = model_info["model"]
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()

    print("train")
    train_losses, train_clusters, train_bands = get_overall_metrics(model, train)
    save_values(train_losses, train_clusters, train_bands, f"{prefix}_train")

    print("val")
    val_losses, val_clusters, val_bands = get_overall_metrics(model, val)
    save_values(val_losses, val_clusters, val_bands, f"{prefix}_val")

    print("test")
    test_losses, test_clusters, test_bands = get_overall_metrics(model, test)
    save_values(test_losses, test_clusters, test_bands, f"{prefix}_test")
    
    

t1_t2
train


  arr = np.asanyarray(arr)


val
test
t2
train
val
test
t1
train
val
test
CPU times: total: 11h 56min 43s
Wall time: 53min 6s


### Get Per-Modality (T1 / T2) Metrics for the Multimodal Model

In [88]:
def get_multimodal_metrics(model, data, diff_scale=0.1):
    """Per-modality version of ``get_overall_metrics`` for a two-channel (T1 + T2) model.

    Channel 0 of each sample and of the reconstruction is treated as T1, and channel 1 as
    T2 (this assumes the dataset stacks T1 first — confirm against ``Data3D``). For each
    channel, the residual (input minus reconstruction) is divided by ``diff_scale``. The
    result is then:
      * passed to ``calculate_clusters`` (default threshold 2), and
      * split into the ``std_bands`` open intervals, recording the in-band voxel count and
        the per-slice connected-component sizes.

    Args:
        model: Reconstruction model, presumably already in ``eval()`` mode. Its output must
            be reshapeable to the sample's shape.
        data: Iterable of numpy samples with at least two channels on axis 0. Each channel
            is assumed to be 3-D with slices on its last axis.
        diff_scale: Divisor applied to each residual. The default 0.1 is the value this
            function previously hard-coded.

    Returns:
        ``(t1_losses, t2_losses, t1_clusters, t2_clusters, t1_banded_metrics,
        t2_banded_metrics)``, with the same per-item layout as ``get_overall_metrics``
        but computed per modality.
    """
    criterion = nn.MSELoss()
    # Send inputs to wherever the model lives; fall back to CUDA (the previous behaviour).
    first_param = next(iter(model.parameters()), None)
    device = first_param.device if first_param is not None else "cuda"

    def _empty_bands():
        # Fresh "<range> pixels" / "<range> clusters" lists for every band.
        metrics = {}
        for band in std_bands:
            metrics[f"{band['range']} pixels"] = []
            metrics[f"{band['range']} clusters"] = []
        return metrics

    def _component_sizes(mask):
        # Per-slice (axis 2) component sizes. bincount index 0 is the background, so [1:]
        # stays correct even when a slice has no background pixels.
        sizes = []
        for idx in range(mask.shape[2]):
            labels = connected_components_2d(mask[:, :, idx])
            sizes.extend(np.bincount(labels.ravel())[1:])
        return sizes

    def _record_bands(diff_norm, metrics):
        # Append one sample's per-band voxel count and component sizes to `metrics`.
        for band in std_bands:
            in_band = ((diff_norm > band["low"]) & (diff_norm < band["high"])).astype(np.uint8)
            metrics[f"{band['range']} pixels"].append(np.sum(in_band))
            metrics[f"{band['range']} clusters"].append(_component_sizes(in_band))

    t1_losses, t2_losses = [], []
    t1_clusters, t2_clusters = [], []
    t1_banded_metrics = _empty_bands()
    t2_banded_metrics = _empty_bands()

    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        for sample in data:
            x = torch.Tensor(sample).reshape((1,) + sample.shape).to(device)
            pred = model(x)

            t1_losses.append(criterion(pred[0][0], x[0][0]).item())
            t2_losses.append(criterion(pred[0][1], x[0][1]).item())

            pred = pred.reshape(sample.shape).cpu().numpy()
            t1_diff_norm = (sample[0] - pred[0]) / diff_scale
            t2_diff_norm = (sample[1] - pred[1]) / diff_scale

            t1_clusters.append(calculate_clusters(t1_diff_norm))
            t2_clusters.append(calculate_clusters(t2_diff_norm))

            _record_bands(t1_diff_norm, t1_banded_metrics)
            _record_bands(t2_diff_norm, t2_banded_metrics)

    return t1_losses, t2_losses, t1_clusters, t2_clusters, t1_banded_metrics, t2_banded_metrics


def save_multimodal_values(t1_loss, t2_loss, t1_cluster, t2_cluster, t1_bands, t2_bands, name, out_dir=None):
    """Write one split's output from ``get_multimodal_metrics`` to disk.

    Six files are written into ``out_dir`` (default: ``mmmetrics_dir``): ``.npy`` files for
    the T1/T2 losses and cluster sizes, and ``.pkl`` files for the T1/T2 banded metrics.
    Each file name is prefixed with ``name``.

    Cluster-size lists are usually ragged (a different number of clusters per sample).
    Passing them straight to ``np.save`` makes NumPy build a ragged array implicitly, which
    warns on older NumPy and raises ValueError from NumPy 1.24 onward. Ragged input is
    therefore packed explicitly into a 1-D object array (load it with
    ``allow_pickle=True``). Rectangular input is saved as a plain numeric array, as before.

    Args:
        t1_loss, t2_loss: Per-sample losses per modality.
        t1_cluster, t2_cluster: Per-sample sequences of cluster sizes per modality.
        t1_bands, t2_bands: Banded-metrics dicts; they are pickled as-is.
        name: File-name prefix.
        out_dir: Target directory. ``None`` uses ``mmmetrics_dir``.
    """
    if out_dir is None:
        out_dir = mmmetrics_dir

    def _as_array(rows):
        # Equal-length rows -> regular ndarray; ragged rows -> 1-D object array of rows.
        rows = list(rows)
        if len({len(row) for row in rows}) <= 1:
            return np.asarray(rows)
        packed = np.empty(len(rows), dtype=object)
        for i, row in enumerate(rows):
            packed[i] = row
        return packed

    np.save(os.path.join(out_dir, f"{name}_t1_losses.npy"), t1_loss)
    np.save(os.path.join(out_dir, f"{name}_t2_losses.npy"), t2_loss)

    np.save(os.path.join(out_dir, f"{name}_t1_clusters.npy"), _as_array(t1_cluster))
    np.save(os.path.join(out_dir, f"{name}_t2_clusters.npy"), _as_array(t2_cluster))

    # NOTE(review): pickle is fine for files this notebook produces; never unpickle
    # files from untrusted sources.
    with open(os.path.join(out_dir, f"{name}_t1_banded_metrics.pkl"), "wb") as f:
        pickle.dump(t1_bands, f)
    with open(os.path.join(out_dir, f"{name}_t2_banded_metrics.pkl"), "wb") as f:
        pickle.dump(t2_bands, f)

In [89]:
# Only the two-channel T1+T2 model is evaluated here; its residuals are split per channel.
model_harness = [
    {"info": "t1_t2", "model_name": "vae_rel2_t1_t2_second_session.pt",
     "t1_only": False, "t2_only": False, "model": VAE(2)},
]

for model_info in model_harness:
    print(model_info["info"])
    modality_flags = {"t1_only": model_info["t1_only"], "t2_only": model_info["t2_only"]}

    # Rebuild the datasets with this model's modality selection.
    train = Data3D(l1_dir, train_indices, **modality_flags)
    val = Data3D(l1_dir, val_indices, **modality_flags)
    test = Data3D(l5_dir, test_indices, **modality_flags)

    # Load the trained weights, move the model to the GPU and switch to evaluation mode.
    model_path = os.path.join(model_dir, model_info["model_name"])
    model = model_info["model"]
    model.load_state_dict(torch.load(model_path))
    model.cuda()
    model.eval()

    print("train")
    (train_t1_losses, train_t2_losses,
     train_t1_clusters, train_t2_clusters,
     train_t1_banded_metrics, train_t2_banded_metrics) = get_multimodal_metrics(model, train)
    save_multimodal_values(
        train_t1_losses, train_t2_losses,
        train_t1_clusters, train_t2_clusters,
        train_t1_banded_metrics, train_t2_banded_metrics,
        "multimodal_train",
    )

    print("val")
    (val_t1_losses, val_t2_losses,
     val_t1_clusters, val_t2_clusters,
     val_t1_banded_metrics, val_t2_banded_metrics) = get_multimodal_metrics(model, val)
    save_multimodal_values(
        val_t1_losses, val_t2_losses,
        val_t1_clusters, val_t2_clusters,
        val_t1_banded_metrics, val_t2_banded_metrics,
        "multimodal_val",
    )

    print("test")
    (test_t1_losses, test_t2_losses,
     test_t1_clusters, test_t2_clusters,
     test_t1_banded_metrics, test_t2_banded_metrics) = get_multimodal_metrics(model, test)
    save_multimodal_values(
        test_t1_losses, test_t2_losses,
        test_t1_clusters, test_t2_clusters,
        test_t1_banded_metrics, test_t2_banded_metrics,
        "multimodal_test",
    )

t1_t2
train


  arr = np.asanyarray(arr)


val
test
