In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.distributions as dist
from torchsummary import summary
import math
import os
import numpy as np
import time
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from pathlib import Path
import re
from skimage.metrics  import structural_similarity as ssim
import plotly.io as pio
import plotly.express as px
import pandas as pd
import pickle
import cv2

pio.renderers.default = 'iframe'

from importlib import reload
import visualization

# locals
import model_architectures

reload(model_architectures)
from model_architectures import VAESegment, Data3DSegToSeg, SegMaskData, Data3DSingleSegToSingleSeg, Data3DSegToSegT1, Data3DSingleSegToSingleSegT1

reload(visualization)
from visualization import brain_diff, viz_slices

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root of the research tree. The default is the original workstation layout;
# set the DHCP_RESEARCH_DIR environment variable to point the notebook at a
# different checkout without editing this cell.
research_dir = os.environ.get("DHCP_RESEARCH_DIR", r"D:/school/research")

# Code-side locations (saved models and metric dumps).
code_dir = os.path.join(research_dir, "code")
p2_dir = os.path.join(code_dir, "paper_two_code")
model_dir = os.path.join(p2_dir, "models")
metrics_dir = os.path.join(p2_dir, "metrics")
seg_metrics_dir = os.path.join(metrics_dir, "seg_to_seg")

# Data-side locations.
data_dir = os.path.join(research_dir, "data")
dhcp_rel2 = os.path.join(data_dir, "dhcp_rel2")
processed_dir = os.path.join(dhcp_rel2, "processed")
volume_dir = os.path.join(processed_dir, "volumes")
seg_dir = os.path.join(processed_dir, "segments")
seg_vol_dir = os.path.join(processed_dir, "volume_segments")
pred_dir = os.path.join(dhcp_rel2, "predictions")
seg_pred_dir = os.path.join(pred_dir, "vae_9seg")

# Every per-kind directory has an "l1" and an "l5" sub-directory.
l1_dir = os.path.join(volume_dir, "l1")
l5_dir = os.path.join(volume_dir, "l5")

l1_seg_dir = os.path.join(seg_dir, "l1")
l5_seg_dir = os.path.join(seg_dir, "l5")

l1_seg_vol_dir = os.path.join(seg_vol_dir, "l1")
l5_seg_vol_dir = os.path.join(seg_vol_dir, "l5")

l1_seg_pred_dir = os.path.join(seg_pred_dir, "l1")
l5_seg_pred_dir = os.path.join(seg_pred_dir, "l5")

In [3]:
# Segment labels. The metric functions in this notebook pair each label with
# the channel at the same position (via enumerate), so the order matters.
segments = [
    "Cerebrospinal Fluid", "Cortical Grey Matter", "White Matter",
    "Background", "Ventricle", "Cerebelum",
    "Deep Grey Matter", "Brainstem", "Hippocampus",
]

# Bands of the scaled difference that the metrics report on. Each entry keeps
# a printable "range" label plus its numeric "low"/"high" limits; the interval
# between -2 and 2 is deliberately not listed.
std_bands = [
    {"range": f"{lower} to {upper}", "low": lower, "high": upper}
    for lower, upper in (
        (-np.inf, -4),
        (-4, -3),
        (-3, -2),
        (2, 3),
        (3, 4),
        (4, np.inf),
    )
]

def connected_components_2d(diff):
    """Label the connected regions of a single 2-D image.

    ``diff`` is passed unchanged to ``cv2.connectedComponents``. That call
    returns a (count, labels) pair; the count is dropped and only the label
    image is returned.
    """
    return cv2.connectedComponents(diff)[1]


def calculate_clusters(diff, threshold=2):
    """Return the sizes of all 2-D clusters of out-of-range voxels in ``diff``.

    A voxel is "out of range" when it does not lie strictly between
    ``-threshold`` and ``threshold`` (and is non-zero). The resulting binary
    volume is cut into 2-D slices along its last axis, each slice is labelled
    with ``connected_components_2d``, and the pixel count of every labelled
    component is collected.

    Args:
        diff: 3-D NumPy array of (scaled) differences.
        threshold: half-width of the interval that is treated as "no
            difference".

    Returns:
        A flat list with one entry per connected component (over all slices).
        The list is empty when no voxel is out of range.
    """
    in_range = (diff > -threshold) & (diff < threshold)
    # `diff != 0` keeps exact zeros out of the mask even for threshold <= 0.
    prepared_diff = (~in_range & (diff != 0)).astype('uint8')
    sizes = []

    # Walk the real depth of the volume instead of assuming 256 slices.
    for idx in range(prepared_diff.shape[2]):
        labels = connected_components_2d(np.ascontiguousarray(prepared_diff[:, :, idx]))
        # bincount always reserves index 0 for the background label, so [1:]
        # is exactly the component sizes even when a slice has no background
        # pixel at all (np.unique would drop a real component in that case).
        sizes.extend(np.bincount(labels.ravel())[1:])

    return sizes


def get_overall_metrics(model, data):
    """Collect reconstruction losses and difference statistics for every sample.

    For each sample the model reconstruction is compared with the input in
    three ways: MSE loss, mean 2-D cluster size of the thresholded difference
    (see ``calculate_clusters``), and per-band pixel counts / mean cluster
    sizes for every entry of ``std_bands``. Each statistic is computed once
    for the channel-summed volume and once per segment channel.

    Args:
        model: callable that takes a CUDA tensor of shape ``(1,) + sample.shape``
            and returns a tensor that can be reshaped to ``sample.shape``.
        data: iterable of array samples whose first axis indexes the entries
            of ``segments`` (in that order).

    Returns:
        Tuple ``(losses, segment_losses, clusters, segment_clusters,
        banded_metrics, banded_segment_metrics)``. The plain lists hold one
        value per sample; the dicts are keyed by segment name and/or by
        ``"<band range> pixels"`` / ``"<band range> clusters"``.

    Notes:
        * Requires a CUDA device (inputs are moved with ``.cuda()``).
        * When a mask contains no cluster the mean size is reported as 0,
          matching what the per-segment band branch already did, rather than
          NaN.
        * Band limits are strict on both sides, so a value exactly equal to a
          band edge is not counted in any band.
    """
    # Differences are divided by this constant before thresholding/banding.
    # NOTE(review): presumably the standard deviation of the voxel
    # differences (the bands are called "std_bands") - confirm.
    diff_scale = 0.1

    def _mean_or_zero(sizes):
        """Mean of ``sizes``, or 0 for an empty collection (avoids NaN)."""
        return np.mean(sizes) if len(sizes) > 0 else 0.0

    def _band_stats(diff_norm, band):
        """Return (pixel count, mean 2-D cluster size) of voxels inside ``band``."""
        in_band = (diff_norm > band["low"]) & (diff_norm < band["high"])
        banded = (in_band & (diff_norm != 0)).astype('uint8')
        sizes = []
        for slice_idx in range(banded.shape[2]):
            labels = connected_components_2d(np.ascontiguousarray(banded[:, :, slice_idx]))
            # Index 0 of bincount is always the background label.
            sizes.extend(np.bincount(labels.ravel())[1:])
        return np.sum(banded), _mean_or_zero(sizes)

    criterion = nn.MSELoss()
    losses = []
    segment_losses = {s: [] for s in segments}
    clusters = []
    segment_clusters = {s: [] for s in segments}

    banded_metrics = {}
    banded_segment_metrics = {}
    for band in std_bands:
        for suffix in ("pixels", "clusters"):
            key = f"{band['range']} {suffix}"
            banded_metrics[key] = []
            banded_segment_metrics[key] = {s: [] for s in segments}

    for sample in data:
        x = torch.Tensor(sample).reshape((1,) + sample.shape).cuda()

        # Evaluation only: do not record an autograd graph for the forward pass.
        with torch.no_grad():
            pred = model(x)
            losses.append(float(criterion(x, pred).cpu()))
            for idx, segment in enumerate(segments):
                segment_losses[segment].append(
                    float(criterion(x[:, idx, :, :, :], pred[:, idx, :, :, :]).cpu())
                )

        og = sample
        pred = pred.reshape(og.shape).detach().cpu().numpy()

        # Overall (channel-summed) statistics; the scaled difference is
        # computed once and reused for the cluster and band metrics.
        overall_diff_norm = (np.sum(og, axis=0) - np.sum(pred, axis=0)) / diff_scale
        clusters.append(_mean_or_zero(calculate_clusters(overall_diff_norm)))
        for band in std_bands:
            pixels, mean_size = _band_stats(overall_diff_norm, band)
            banded_metrics[f"{band['range']} pixels"].append(pixels)
            banded_metrics[f"{band['range']} clusters"].append(mean_size)

        # Per-segment statistics, one channel at a time to keep memory flat.
        for idx, segment in enumerate(segments):
            segment_diff_norm = (og[idx, :, :, :] - pred[idx, :, :, :]) / diff_scale
            segment_clusters[segment].append(_mean_or_zero(calculate_clusters(segment_diff_norm)))
            for band in std_bands:
                pixels, mean_size = _band_stats(segment_diff_norm, band)
                banded_segment_metrics[f"{band['range']} pixels"][segment].append(pixels)
                banded_segment_metrics[f"{band['range']} clusters"][segment].append(mean_size)

    return losses, segment_losses, clusters, segment_clusters, banded_metrics, banded_segment_metrics


def save_values(losses, segment_losses, clusters, segment_clusters, banded_metrics, banded_segment_metrics, name, run):
    """Write the six metric collections to ``seg_metrics_dir``.

    Files are named ``"<name>_<run>_<kind>"``: the first four collections are
    written with ``np.save`` (``.npy``) and the two banded dicts with
    ``pickle`` (``.pkl``). When ``segment_losses`` / ``segment_clusters`` are
    dicts (as returned by ``get_overall_metrics``), ``np.save`` stores them as
    pickled object arrays, so reading them back needs
    ``np.load(path, allow_pickle=True).item()``.

    NOTE(review): an identical ``save_values`` is defined again further down
    in this cell and is the one actually bound at runtime; keep the two in
    sync or delete one of them.
    """
    # np.save / open do not create missing directories.
    os.makedirs(seg_metrics_dir, exist_ok=True)
    prefix = f"{name}_{run}"

    np.save(os.path.join(seg_metrics_dir, f"{prefix}_losses.npy"), losses)
    np.save(os.path.join(seg_metrics_dir, f"{prefix}_segment_losses.npy"), segment_losses)
    np.save(os.path.join(seg_metrics_dir, f"{prefix}_clusters.npy"), clusters)
    np.save(os.path.join(seg_metrics_dir, f"{prefix}_segment_clusters.npy"), segment_clusters)

    with open(os.path.join(seg_metrics_dir, f"{prefix}_banded_metrics.pkl"), "wb") as f:
        pickle.dump(banded_metrics, f)
    with open(os.path.join(seg_metrics_dir, f"{prefix}_banded_segment_metrics.pkl"), "wb") as f:
        pickle.dump(banded_segment_metrics, f)


def get_single_seg_metrics(model, data):
    """Collect losses and difference statistics for a single-segment model.

    For every sample the model reconstruction is compared with the input via
    MSE loss, the mean 2-D cluster size of the thresholded difference (see
    ``calculate_clusters``), and per-band pixel counts / mean cluster sizes
    for every entry of ``std_bands``. Differences are taken after summing
    both volumes over their first axis.

    Args:
        model: callable that takes a CUDA tensor of shape ``(1,) + sample.shape``
            and returns a tensor that can be reshaped to ``sample.shape``.
        data: iterable of array samples.

    Returns:
        Tuple ``(losses, clusters, banded_metrics)``: two lists with one value
        per sample and a dict keyed by ``"<band range> pixels"`` /
        ``"<band range> clusters"`` holding one value per sample.

    Notes:
        * Requires a CUDA device (inputs are moved with ``.cuda()``).
        * When a mask contains no cluster the mean size is reported as 0
          instead of NaN.
    """
    # Differences are divided by this constant before thresholding/banding.
    # NOTE(review): presumably the standard deviation of the voxel
    # differences (the bands are called "std_bands") - confirm.
    diff_scale = 0.1

    def _mean_or_zero(sizes):
        """Mean of ``sizes``, or 0 for an empty collection (avoids NaN)."""
        return np.mean(sizes) if len(sizes) > 0 else 0.0

    criterion = nn.MSELoss()
    losses = []
    clusters = []
    banded_metrics = {}
    for band in std_bands:
        banded_metrics[f"{band['range']} pixels"] = []
        banded_metrics[f"{band['range']} clusters"] = []

    for sample in data:
        x = torch.Tensor(sample).reshape((1,) + sample.shape).cuda()

        # Evaluation only: do not record an autograd graph for the forward pass.
        with torch.no_grad():
            pred = model(x)
            losses.append(float(criterion(x, pred).cpu()))

        og = sample
        pred = pred.reshape(og.shape).detach().cpu().numpy()

        diff_norm = (np.sum(og, axis=0) - np.sum(pred, axis=0)) / diff_scale
        clusters.append(_mean_or_zero(calculate_clusters(diff_norm)))

        for band in std_bands:
            in_band = (diff_norm > band["low"]) & (diff_norm < band["high"])
            banded_diff = (in_band & (diff_norm != 0)).astype('uint8')

            banded_metrics[f"{band['range']} pixels"].append(np.sum(banded_diff))

            sizes = []
            # Walk the real depth of the volume instead of assuming 256 slices.
            for slice_idx in range(banded_diff.shape[2]):
                labels = connected_components_2d(np.ascontiguousarray(banded_diff[:, :, slice_idx]))
                # Index 0 of bincount is always the background label.
                sizes.extend(np.bincount(labels.ravel())[1:])
            banded_metrics[f"{band['range']} clusters"].append(_mean_or_zero(sizes))

    return losses, clusters, banded_metrics

def save_values(losses, segment_losses, clusters, segment_clusters, banded_metrics, banded_segment_metrics, name, run):
    """Write the six metric collections to ``seg_metrics_dir``.

    Files are named ``"<name>_<run>_<kind>"``: the first four collections are
    written with ``np.save`` (``.npy``) and the two banded dicts with
    ``pickle`` (``.pkl``). When ``segment_losses`` / ``segment_clusters`` are
    dicts (as returned by ``get_overall_metrics``), ``np.save`` stores them as
    pickled object arrays, so reading them back needs
    ``np.load(path, allow_pickle=True).item()``.

    NOTE(review): ``save_values`` is also defined earlier in this cell with
    the same signature; this later definition is the one bound at runtime.
    One of the two should be removed.
    """
    # np.save / open do not create missing directories.
    os.makedirs(seg_metrics_dir, exist_ok=True)
    prefix = f"{name}_{run}"

    np.save(os.path.join(seg_metrics_dir, f"{prefix}_losses.npy"), losses)
    np.save(os.path.join(seg_metrics_dir, f"{prefix}_segment_losses.npy"), segment_losses)
    np.save(os.path.join(seg_metrics_dir, f"{prefix}_clusters.npy"), clusters)
    np.save(os.path.join(seg_metrics_dir, f"{prefix}_segment_clusters.npy"), segment_clusters)

    with open(os.path.join(seg_metrics_dir, f"{prefix}_banded_metrics.pkl"), "wb") as f:
        pickle.dump(banded_metrics, f)
    with open(os.path.join(seg_metrics_dir, f"{prefix}_banded_segment_metrics.pkl"), "wb") as f:
        pickle.dump(banded_segment_metrics, f)

        
def save_seg_values(losses, clusters, banded_metrics, name, run):
    """Write single-segment metrics to ``seg_metrics_dir``.

    Files are named ``"<name>_<run>_<kind>"``: ``losses`` and ``clusters`` are
    written with ``np.save`` (``.npy``) and ``banded_metrics`` with ``pickle``
    (``.pkl``). The argument list matches the tuple returned by
    ``get_single_seg_metrics``.
    """
    # np.save / open do not create missing directories.
    os.makedirs(seg_metrics_dir, exist_ok=True)
    prefix = f"{name}_{run}"

    np.save(os.path.join(seg_metrics_dir, f"{prefix}_losses.npy"), losses)
    np.save(os.path.join(seg_metrics_dir, f"{prefix}_clusters.npy"), clusters)
    with open(os.path.join(seg_metrics_dir, f"{prefix}_banded_metrics.pkl"), "wb") as f:
        pickle.dump(banded_metrics, f)

In [4]:
# Fixed seed so the shuffled train/validation split is the same on every run.
np.random.seed(42)

# The sample count is half the number of directory entries.
# NOTE(review): presumably two files are stored per subject - confirm.
num_samples = len(os.listdir(l1_dir)) // 2
samples = np.array(list(range(num_samples)))
np.random.shuffle(samples)

# First 80% of the shuffled indices train, the remainder validates.
split_val = int(0.8 * num_samples)
train_indices, val_indices = samples[:split_val], samples[split_val:]

# The l5 directory is used in full, in index order, as the test set.
num_test = len(os.listdir(l5_dir)) // 2
test_indices = np.array(list(range(num_test)))

train = Data3DSegToSeg(l1_dir, l1_seg_vol_dir, train_indices)
val = Data3DSegToSeg(l1_dir, l1_seg_vol_dir, val_indices)
test = Data3DSegToSeg(l5_dir, l5_seg_vol_dir, test_indices)

In [26]:
# Scratch evaluation of the per-segment models on the FIRST training sample
# only (note the `break` at the end of the loop). Each segment channel is
# reconstructed by its own VAESegment checkpoint and the reconstructions are
# stacked back into one multi-channel prediction before computing metrics.
base_model_name = "vae_rel2t2_seg_to_seg"
criterion = nn.MSELoss()
losses = []
clusters = []
banded_metrics = {}
for band in std_bands:
    banded_metrics[f"{band['range']} pixels"] = []
    banded_metrics[f"{band['range']} clusters"] = []

for sample in train:
    pred_img = np.empty_like(sample)
    for segment_number in range(0, len(segments)):
        # One checkpoint per segment: "<base_model_name><segment index>.pt".
        model_path = os.path.join(model_dir, f"{base_model_name}{segment_number}.pt")
        model = VAESegment(1, 1)
        # torch.load unpickles the file; only load checkpoints you trust.
        model.load_state_dict(torch.load(model_path))
        model.cuda()
        model.eval()

        x = torch.Tensor(sample[segment_number]).reshape((1, 1, 256, 256, 256)).cuda()
        # Evaluation only: do not record an autograd graph for the forward pass.
        with torch.no_grad():
            pred = model(x)
        pred_img[segment_number] = pred.reshape((256, 256, 256)).detach().cpu().numpy()

    losses.append(float(criterion(torch.Tensor(sample), torch.Tensor(pred_img))))

    # Channel-summed difference, scaled by 0.1.
    # NOTE(review): 0.1 is presumably the standard deviation of the voxel
    # differences (the bands are called "std_bands") - confirm.
    diff = np.sum(sample, axis=0) - np.sum(pred_img, axis=0)
    diff_norm = diff / 0.1

    # Report 0 rather than NaN when no cluster is found.
    cluster_sizes = calculate_clusters(diff_norm)
    clusters.append(np.mean(cluster_sizes) if len(cluster_sizes) > 0 else 0.0)

    # For overall banded metrics (diff_norm does not depend on the band, so
    # it is computed once above instead of once per band).
    for band in std_bands:
        filter_mask = (diff_norm > band["low"]) & (diff_norm < band["high"])
        banded_diff = (filter_mask & (diff_norm != 0)).astype('uint8')

        banded_metrics[f"{band['range']} pixels"].append(np.sum(banded_diff))

        sizes = []
        for idx in range(banded_diff.shape[2]):
            cc = connected_components_2d(np.ascontiguousarray(banded_diff[:, :, idx]))
            # Index 0 of bincount is always the background label.
            sizes.extend(np.bincount(cc.ravel())[1:])
        banded_metrics[f"{band['range']} clusters"].append(np.mean(sizes) if len(sizes) > 0 else 0.0)
    break

In [29]:
# NOTE(review): get_overall_metrics_single_seg is defined in a later cell of
# this notebook (In [28]); on a fresh kernel that cell has to be run before
# this call, otherwise it raises NameError.
get_overall_metrics_single_seg("vae_rel2t2_seg_to_seg", train)

([0.001519416575320065],
 [19.46075384143939],
 {'-inf to -4 pixels': [10521],
  '-inf to -4 clusters': [7.563623292595255],
  '-4 to -3 pixels': [52021],
  '-4 to -3 clusters': [8.438118410381184],
  '-3 to -2 pixels': [242948],
  '-3 to -2 clusters': [13.180772569444445],
  '2 to 3 pixels': [85786],
  '2 to 3 clusters': [5.190343659244918],
  '3 to 4 pixels': [38731],
  '3 to 4 clusters': [5.11840888066605],
  '4 to inf pixels': [14534],
  '4 to inf clusters': [6.338421282163105]})

In [27]:
# NOTE(review): evaluating the bare name only echoes the function object. The
# dict recorded as this cell's output presumably came from an earlier version
# of the cell (e.g. `banded_metrics`) - confirm, then fix or delete the cell.
get_overall_metrics_single_seg

{'-inf to -4 pixels': [10521],
 '-inf to -4 clusters': [7.563623292595255],
 '-4 to -3 pixels': [52021],
 '-4 to -3 clusters': [8.438118410381184],
 '-3 to -2 pixels': [242948],
 '-3 to -2 clusters': [13.180772569444445],
 '2 to 3 pixels': [85786],
 '2 to 3 clusters': [5.190343659244918],
 '3 to 4 pixels': [38731],
 '3 to 4 clusters': [5.11840888066605],
 '4 to inf pixels': [14534],
 '4 to inf clusters': [6.338421282163105]}

In [28]:
def get_overall_metrics_single_seg(base_model_name, data):
    """Evaluate an ensemble of per-segment models on every sample of ``data``.

    Each segment channel of a sample is reconstructed by its own
    ``VAESegment(1, 1)`` checkpoint, loaded from
    ``model_dir/<base_model_name><segment index>.pt``. The reconstructions are
    stacked into one multi-channel prediction, which is compared with the
    sample via MSE loss, mean 2-D cluster size of the thresholded
    channel-summed difference (see ``calculate_clusters``), and per-band pixel
    counts / mean cluster sizes for every entry of ``std_bands``.

    Args:
        base_model_name: checkpoint file-name prefix (segment index and
            ``.pt`` are appended).
        data: iterable of NumPy samples with one channel per entry of
            ``segments``; every channel is reshaped to
            ``(1, 1, 256, 256, 256)`` before being fed to its model.

    Returns:
        Tuple ``(losses, clusters, banded_metrics)``: two lists with one value
        per sample and a dict keyed by ``"<band range> pixels"`` /
        ``"<band range> clusters"`` holding one value per sample.

    Notes:
        * Requires a CUDA device.
        * Checkpoints are re-read from disk for every sample.
          NOTE(review): caching them would be faster but keeps all segment
          models in memory at once - decide based on available memory.
        * When a mask contains no cluster the mean size is reported as 0
          instead of NaN.
    """
    # Differences are divided by this constant before thresholding/banding.
    # NOTE(review): presumably the standard deviation of the voxel
    # differences (the bands are called "std_bands") - confirm.
    diff_scale = 0.1

    def _mean_or_zero(sizes):
        """Mean of ``sizes``, or 0 for an empty collection (avoids NaN)."""
        return np.mean(sizes) if len(sizes) > 0 else 0.0

    criterion = nn.MSELoss()
    losses = []
    clusters = []
    banded_metrics = {}
    for band in std_bands:
        banded_metrics[f"{band['range']} pixels"] = []
        banded_metrics[f"{band['range']} clusters"] = []

    for sample in data:
        pred_img = np.empty_like(sample)
        for segment_number in range(0, len(segments)):
            model_path = os.path.join(model_dir, f"{base_model_name}{segment_number}.pt")
            model = VAESegment(1, 1)
            # torch.load unpickles the file; only load checkpoints you trust.
            model.load_state_dict(torch.load(model_path))
            model.cuda()
            model.eval()

            x = torch.Tensor(sample[segment_number]).reshape((1, 1, 256, 256, 256)).cuda()
            # Evaluation only: do not record an autograd graph.
            with torch.no_grad():
                pred = model(x)
            pred_img[segment_number] = pred.reshape((256, 256, 256)).detach().cpu().numpy()

        losses.append(float(criterion(torch.Tensor(sample), torch.Tensor(pred_img))))

        # Channel-summed scaled difference; computed once and reused for the
        # cluster metric and for every band.
        diff_norm = (np.sum(sample, axis=0) - np.sum(pred_img, axis=0)) / diff_scale

        clusters.append(_mean_or_zero(calculate_clusters(diff_norm)))

        # For overall banded metrics
        for band in std_bands:
            in_band = (diff_norm > band["low"]) & (diff_norm < band["high"])
            banded_diff = (in_band & (diff_norm != 0)).astype('uint8')

            banded_metrics[f"{band['range']} pixels"].append(np.sum(banded_diff))

            sizes = []
            for slice_idx in range(banded_diff.shape[2]):
                labels = connected_components_2d(np.ascontiguousarray(banded_diff[:, :, slice_idx]))
                # Index 0 of bincount is always the background label.
                sizes.extend(np.bincount(labels.ravel())[1:])
            banded_metrics[f"{band['range']} clusters"].append(_mean_or_zero(sizes))

    return losses, clusters, banded_metrics