#### Segment Info
![image.png](attachment:c282dee1-3af5-42ae-b4d0-fe36060c77f5.png)

In [1]:
# Human-readable names of the tissue segments; the position in this list is
# the segment number passed to the per-segment datasets/models below.
# NOTE(review): "Cerebelum" looks like a misspelling of "Cerebellum" -- left
# unchanged because the string may be used as a label/key elsewhere; confirm.
segments = [
    "Cerebrospinal Fluid",   # 0
    "Cortical Grey Matter",  # 1
    "White Matter",          # 2
    "Background",            # 3
    "Ventricle",             # 4
    "Cerebelum",             # 5
    "Deep Grey Matter",      # 6
    "Brainstem",             # 7
    "Hippocampus",           # 8
]

### Import Libraries

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.distributions as dist
from torchsummary import summary
import math
import os
import numpy as np
import time
import matplotlib.pyplot as plt
from prettytable import PrettyTable
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from pathlib import Path
import re
from skimage.metrics  import structural_similarity as ssim
import plotly.io as pio
import plotly.express as px
import pandas as pd

pio.renderers.default = 'iframe'

from importlib import reload


# locals
import model_architectures
import visualization
import unet

reload(model_architectures)
from model_architectures import VAESegment, Data3DSingleSegT2, SegMaskData

reload(unet)
from unet import UNet

reload(visualization)
from visualization import brain_diff, viz_slices

  from .autonotebook import tqdm as notebook_tqdm


### Define Paths

In [2]:
# Root of the research tree; every other path below is derived from it.
research_dir = r"D:/school/research"

# Code checkout and the folder where trained model weights are written.
code_dir = os.path.join(research_dir, "code")
model_dir = os.path.join(code_dir, "explore_again", "models")

# Dataset root and its processed sub-trees.
data_dir = os.path.join(research_dir, "data")
dhcp_rel2 = os.path.join(data_dir, "dhcp_rel2")
processed_dir = os.path.join(dhcp_rel2, "processed")
volume_dir, seg_dir, seg_vol_dir = (
    os.path.join(processed_dir, leaf)
    for leaf in ("volumes", "segments", "volume_segments")
)

# Where model predictions are stored.
pred_dir = os.path.join(dhcp_rel2, "predictions")
seg_pred_dir = os.path.join(pred_dir, "vae_9seg")

# Each tree above has an "l1" and an "l5" sub-folder; later cells draw the
# train/validation samples from l1 and the test samples from l5.
l1_dir, l5_dir = (os.path.join(volume_dir, level) for level in ("l1", "l5"))
l1_seg_dir, l5_seg_dir = (os.path.join(seg_dir, level) for level in ("l1", "l5"))
l1_seg_vol_dir, l5_seg_vol_dir = (
    os.path.join(seg_vol_dir, level) for level in ("l1", "l5")
)
l1_seg_pred_dir, l5_seg_pred_dir = (
    os.path.join(seg_pred_dir, level) for level in ("l1", "l5")
)

In [3]:

# Load the saved per-sample training metrics of the segment-0 model.
# NOTE(review): `test` is later rebound to a dataset inside the training loop;
# the name is kept here because the following cell reads it.
metrics_dir = os.path.join(
    code_dir, "explore_again", "metrics", "individual_segments"
)
seg0_train_metrics_path = os.path.join(
    metrics_dir, "vae_rel2t2_seg0_train_metrics.npy"
)
test = np.load(seg0_train_metrics_path)

In [5]:
# Display the mean over every value in the array currently bound to `test`.
np.mean(test)

0.005799173

### Define Data Parameters

In [4]:
# Seed NumPy's global RNG so the shuffled train/validation split is repeatable.
np.random.seed(42)

# The sample count is half the number of directory entries
# (presumably two files are stored per sample -- TODO confirm).
num_samples = len(os.listdir(l1_dir)) // 2
samples = np.array(list(range(num_samples)))
np.random.shuffle(samples)

# First 80% of the shuffled indices train the model, the rest validate it.
split_val = int(num_samples * 0.8)
train_indices, val_indices = samples[:split_val], samples[split_val:]

# The test set uses every l5 sample, in order (no shuffling).
num_test = len(os.listdir(l5_dir)) // 2
test_indices = np.array(list(range(num_test)))

### Train Models for Each Segment

In [5]:
# Train one VAESegment model per entry in `segments` and save each model's
# weights to `model_dir` as vae_rel2t2_seg<N>.pt.
for segment_number in range(len(segments)):
    print(f"Train model for segment {segment_number}")

    # Load data for segment
    train = Data3DSingleSegT2(l1_dir, l1_seg_vol_dir, train_indices, segment=segment_number)
    val = Data3DSingleSegT2(l1_dir, l1_seg_vol_dir, val_indices, segment=segment_number)
    # NOTE(review): `test` is not used inside this loop; it is still built in
    # case later cells read it. It also rebinds the `test` name used earlier.
    test = Data3DSingleSegT2(l5_dir, l5_seg_vol_dir, test_indices, segment=segment_number)

    batch_size = 1
    train_loader = DataLoader(train, batch_size=batch_size)
    val_loader = DataLoader(val, batch_size=batch_size)

    # Where this segment's trained weights are written
    model_path = os.path.join(model_dir, f"vae_rel2t2_seg{segment_number}.pt")

    # Define the model
    model = VAESegment(1, 1)
    model.cuda()
    criterion = nn.MSELoss()
    learning_rate = 1e-3
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, eps=1e-07)

    # Train the model
    num_epochs = 10
    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        total_loss = 0.0
        data_counter = 0
        with tqdm(train_loader, unit="batch", ascii=' >=') as tepoch:
            for x, y in tepoch:
                # clear gradients
                optimizer.zero_grad()

                # forward
                x = x.cuda()
                y = y.cuda()
                output = model(x)
                # nn.MSELoss is documented as (input, target)
                loss = criterion(output, y)

                # backward
                loss.backward()
                optimizer.step()

                # .item() yields a Python float, so the running sum does not
                # keep GPU tensors alive and avoids the legacy `.data` attribute
                batch_loss = loss.item()
                total_loss += batch_loss
                data_counter += 1

                # update the bar while it is still iterating so the loss shows
                tepoch.set_postfix(loss=batch_loss)

        # guard against an empty loader (would otherwise divide by zero)
        total_loss /= max(data_counter, 1)

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        data_counter = 0
        with torch.no_grad():
            for x, y in val_loader:
                x = x.cuda()
                y = y.cuda()
                output = model(x)
                val_loss += criterion(output, y).item()
                data_counter += 1

        val_loss /= max(data_counter, 1)

        print('epoch [{}/{}], train_loss:{:.4f}, val_loss:{:.4f}'
              .format(epoch+1, num_epochs, total_loss, val_loss))

    torch.save(model.state_dict(), model_path)

Train model for segment 0




epoch [1/10], train_loss:0.0306, val_loss:0.0088




epoch [2/10], train_loss:0.0075, val_loss:0.0069




epoch [3/10], train_loss:0.0065, val_loss:0.0063




epoch [4/10], train_loss:0.0061, val_loss:0.0061




epoch [5/10], train_loss:0.0060, val_loss:0.0060




epoch [6/10], train_loss:0.0059, val_loss:0.0060




epoch [7/10], train_loss:0.0059, val_loss:0.0059




epoch [8/10], train_loss:0.0059, val_loss:0.0059




epoch [9/10], train_loss:0.0059, val_loss:0.0059




epoch [10/10], train_loss:0.0058, val_loss:0.0059
Train model for segment 1




epoch [1/10], train_loss:0.0453, val_loss:0.0039




epoch [2/10], train_loss:0.0035, val_loss:0.0033




epoch [3/10], train_loss:0.0032, val_loss:0.0032




epoch [4/10], train_loss:0.0031, val_loss:0.0031




epoch [5/10], train_loss:0.0030, val_loss:0.0030




epoch [6/10], train_loss:0.0029, val_loss:0.0030




epoch [7/10], train_loss:0.0029, val_loss:0.0029




epoch [8/10], train_loss:0.0029, val_loss:0.0029




epoch [9/10], train_loss:0.0029, val_loss:0.0029




epoch [10/10], train_loss:0.0029, val_loss:0.0029
Train model for segment 2




epoch [1/10], train_loss:0.0841, val_loss:0.0054




epoch [2/10], train_loss:0.0050, val_loss:0.0048




epoch [3/10], train_loss:0.0047, val_loss:0.0046




epoch [4/10], train_loss:0.0045, val_loss:0.0046




epoch [5/10], train_loss:0.0044, val_loss:0.0045




epoch [6/10], train_loss:0.0043, val_loss:0.0043




epoch [7/10], train_loss:0.0042, val_loss:0.0041




epoch [8/10], train_loss:0.0041, val_loss:0.0041




epoch [9/10], train_loss:0.0041, val_loss:0.0040




epoch [10/10], train_loss:0.0040, val_loss:0.0040
Train model for segment 3




epoch [1/10], train_loss:0.0450, val_loss:0.0012




epoch [2/10], train_loss:0.0009, val_loss:0.0007




epoch [3/10], train_loss:0.0006, val_loss:0.0006




epoch [4/10], train_loss:0.0005, val_loss:0.0005




epoch [5/10], train_loss:0.0005, val_loss:0.0005




epoch [6/10], train_loss:0.0004, val_loss:0.0004




epoch [7/10], train_loss:0.0004, val_loss:0.0004




epoch [8/10], train_loss:0.0004, val_loss:0.0004




epoch [9/10], train_loss:0.0004, val_loss:0.0004




epoch [10/10], train_loss:0.0004, val_loss:0.0004
Train model for segment 4




epoch [1/10], train_loss:0.0226, val_loss:0.0015




epoch [2/10], train_loss:0.0010, val_loss:0.0009




epoch [3/10], train_loss:0.0008, val_loss:0.0007




epoch [4/10], train_loss:0.0007, val_loss:0.0007




epoch [5/10], train_loss:0.0006, val_loss:0.0006




epoch [6/10], train_loss:0.0006, val_loss:0.0006




epoch [7/10], train_loss:0.0006, val_loss:0.0006




epoch [8/10], train_loss:0.0006, val_loss:0.0006




epoch [9/10], train_loss:0.0005, val_loss:0.0006




epoch [10/10], train_loss:0.0005, val_loss:0.0006
Train model for segment 5




epoch [1/10], train_loss:0.0308, val_loss:0.0007




epoch [2/10], train_loss:0.0005, val_loss:0.0004




epoch [3/10], train_loss:0.0004, val_loss:0.0004




epoch [4/10], train_loss:0.0003, val_loss:0.0003




epoch [5/10], train_loss:0.0003, val_loss:0.0003




epoch [6/10], train_loss:0.0003, val_loss:0.0003




epoch [7/10], train_loss:0.0003, val_loss:0.0003




epoch [8/10], train_loss:0.0003, val_loss:0.0003




epoch [9/10], train_loss:0.0003, val_loss:0.0003




epoch [10/10], train_loss:0.0003, val_loss:0.0003
Train model for segment 6




epoch [1/10], train_loss:0.0275, val_loss:0.0008




epoch [2/10], train_loss:0.0007, val_loss:0.0007




epoch [3/10], train_loss:0.0007, val_loss:0.0007




epoch [4/10], train_loss:0.0006, val_loss:0.0006




epoch [5/10], train_loss:0.0006, val_loss:0.0006




epoch [6/10], train_loss:0.0006, val_loss:0.0006




epoch [7/10], train_loss:0.0005, val_loss:0.0005




epoch [8/10], train_loss:0.0005, val_loss:0.0005




epoch [9/10], train_loss:0.0004, val_loss:0.0004




epoch [10/10], train_loss:0.0004, val_loss:0.0004
Train model for segment 7




epoch [1/10], train_loss:0.0423, val_loss:0.0004




epoch [2/10], train_loss:0.0003, val_loss:0.0002




epoch [3/10], train_loss:0.0002, val_loss:0.0001




epoch [4/10], train_loss:0.0001, val_loss:0.0001




epoch [5/10], train_loss:0.0001, val_loss:0.0001




epoch [6/10], train_loss:0.0001, val_loss:0.0001




epoch [7/10], train_loss:0.0001, val_loss:0.0001




epoch [8/10], train_loss:0.0001, val_loss:0.0001




epoch [9/10], train_loss:0.0001, val_loss:0.0001




epoch [10/10], train_loss:0.0001, val_loss:0.0001
Train model for segment 8




epoch [1/10], train_loss:0.0307, val_loss:0.0007




epoch [2/10], train_loss:0.0004, val_loss:0.0003




epoch [3/10], train_loss:0.0002, val_loss:0.0002




epoch [4/10], train_loss:0.0002, val_loss:0.0002




epoch [5/10], train_loss:0.0001, val_loss:0.0001




epoch [6/10], train_loss:0.0001, val_loss:0.0001




epoch [7/10], train_loss:0.0001, val_loss:0.0001




epoch [8/10], train_loss:0.0001, val_loss:0.0001




epoch [9/10], train_loss:0.0001, val_loss:0.0001




epoch [10/10], train_loss:0.0001, val_loss:0.0001


In [6]:
# Show the architecture of `model`, which after the training loop is the
# model built in its final iteration (the last segment).
model

VAESegment(
  (model): Sequential(
    (0): ResnetEncoder(
      (pass1): Sequential(
        (0): InstanceNorm3d(1, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (1): ReLU()
        (2): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      )
      (pass2): Sequential(
        (0): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (1): ReLU()
        (2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      )
      (conv_bypass): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(4, 4, 4))
      (activation_bypass): ReLU()
    )
    (1): ResnetEncoder(
      (pass1): Sequential(
        (0): InstanceNorm3d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (1): ReLU()
        (2): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(2, 2, 2), padding=(1, 1, 1))
      )
      (pass2): Sequential(
        (0): InstanceNorm3d(32, eps=1e-05, momentum=0