In [None]:
from google.colab import drive
drive.mount('/content/drive')
drive_folder = '/content/drive/My Drive/682_final_proj'
%cd {drive_folder}

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/My Drive/682_final_proj


In [None]:
import matplotlib.pylab as plt
import nibabel as nib
import random
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.model_selection import KFold
from torch import optim
from torchvision.transforms import CenterCrop

In [None]:
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

In [None]:
def resize_tensor(img_tensor, crop_size, target_size, is_seg=False):
  depth, height, width = img_tensor.shape
  crop_depth, crop_height, crop_width = crop_size
  depth_start = (depth - crop_depth) // 2
  height_start = (height - crop_height) // 2
  width_start = (width - crop_width) // 2

  cropped_img = img_tensor[depth_start:(depth_start + crop_depth)]

  center_crop = CenterCrop((crop_height, crop_width))
  cropped_img = center_crop(cropped_img)

  cropped_img = cropped_img.unsqueeze(0).unsqueeze(0)

  mode = 'nearest' if is_seg else 'trilinear'
  resized_tensor = F.interpolate(cropped_img, size=target_size, mode=mode, align_corners=None if is_seg else True)

  return resized_tensor

In [None]:
def weights_init_kaiming(m):
  classname = m.__class__.__name__
  if classname.find('Conv') != -1:
    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
  elif classname.find('Linear') != -1:
    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
  elif classname.find('BatchNorm') != -1:
    nn.init.normal_(m.weight.data, 1.0, 0.02)
    nn.init.constant_(m.bias.data, 0.0)

def init_weights(net, init_type='kaiming'):
  net.apply(weights_init_kaiming)

In [None]:
class UnetConv3(nn.Module):
  def __init__(self, in_size, out_size, is_batchnorm, kernel_size=(3,3,1), padding_size=(1,1,0), init_stride=(1,1,1)):
    super(UnetConv3, self).__init__()

    if is_batchnorm:
      self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size),
                                  nn.InstanceNorm3d(out_size),
                                  nn.ReLU(inplace=True),)
      self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),
                                  nn.InstanceNorm3d(out_size),
                                  nn.ReLU(inplace=True),)
    else:
      self.conv1 = nn.Sequential(nn.Conv3d(in_size, out_size, kernel_size, init_stride, padding_size),
                                  nn.ReLU(inplace=True),)
      self.conv2 = nn.Sequential(nn.Conv3d(out_size, out_size, kernel_size, 1, padding_size),
                                  nn.ReLU(inplace=True),)

    for m in self.children():
      init_weights(m, init_type='kaiming')

  def forward(self, inputs):
    outputs = self.conv1(inputs)
    outputs = self.conv2(outputs)
    return outputs

class UnetUp3_CT(nn.Module):
  def __init__(self, in_size, out_size, is_batchnorm=True):
    super(UnetUp3_CT, self).__init__()
    self.conv = UnetConv3(in_size + out_size, out_size, is_batchnorm, kernel_size=(3,3,3), padding_size=(1,1,1))
    self.up = nn.Upsample(scale_factor=(2, 2, 2), mode='trilinear')

    for m in self.children():
      if m.__class__.__name__.find('UnetConv3') != -1: continue
      init_weights(m, init_type='kaiming')

  def forward(self, inputs1, inputs2):
    outputs2 = self.up(inputs2)
    offset = outputs2.size()[2] - inputs1.size()[2]
    padding = 2 * [offset // 2, offset // 2, 0]
    outputs1 = F.pad(inputs1, padding)
    return self.conv(torch.cat([outputs1, outputs2], 1))

In [None]:
class unet_3D(nn.Module):
  def __init__(self, feature_scale=8, n_classes=3, is_deconv=True, in_channels=2, is_batchnorm=True):
    super(unet_3D, self).__init__()
    self.is_deconv = is_deconv
    self.in_channels = in_channels
    self.is_batchnorm = is_batchnorm
    self.feature_scale = feature_scale

    filters = [16, 32, 64]
    filters = [int(x / self.feature_scale) for x in filters]

    self.conv1 = UnetConv3(self.in_channels, filters[0], self.is_batchnorm)
    self.maxpool1 = nn.MaxPool3d(kernel_size=(2, 2, 2))

    self.conv2 = UnetConv3(filters[0], filters[1], self.is_batchnorm)
    self.maxpool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))

    self.center = UnetConv3(filters[1], filters[2], self.is_batchnorm)

    self.up_concat2 = UnetUp3_CT(filters[2], filters[1], is_batchnorm)
    self.up_concat1 = UnetUp3_CT(filters[1], filters[0], is_batchnorm)

    self.final = nn.Conv3d(filters[0], n_classes, 1)

    self.dropout1 = nn.Dropout(p=0.2)
    self.dropout2 = nn.Dropout(p=0.2)

    # Initialise weights
    for m in self.modules():
      if isinstance(m, nn.Conv3d) or isinstance(m, nn.BatchNorm3d):
          init_weights(m, init_type='kaiming')

  def forward(self, inputs):
    conv1 = self.conv1(inputs)
    maxpool1 = self.maxpool1(conv1)

    conv2 = self.conv2(maxpool1)
    maxpool2 = self.maxpool2(conv2)

    center = self.center(maxpool2)
    center = self.dropout1(center)

    up2 = self.up_concat2(conv2, center)
    up1 = self.up_concat1(conv1, up2)
    up1 = self.dropout2(up1)

    final = self.final(up1)

    return final

In [None]:
HC_END = 29
BPDWOP_END = 64
BPDWP_END = 83
SS_END = 103

N = SS_END

CROP_SIZE = (128, 160, 160)
TARGET_SIZE = (96, 120, 120)

def get_dir_str(i):
  if i <= HC_END:
    return "HC"
  elif i <= BPDWOP_END:
    return "BPDwoPsy"
  elif i <= BPDWP_END:
    return "BPDwPsy"
  else:
    return "SS"

def load_data(num_images):
  images = torch.zeros(num_images, 1, 96, 120, 120).type(dtype)
  segs = torch.zeros(num_images, 1, 96, 120, 120).type(dtype)
  slice_ind = 0

  for i in random.sample(range(1, 104), num_images):
    dir_str = get_dir_str(i)
    str_i = f'{i:03}'
    dir_path = f'./final_dataset/{dir_str}/{dir_str}_{str_i}'

    img_path = f'{dir_path}/{dir_str}_{str_i}_procimg.nii.gz'
    seg_path = f'{dir_path}/{dir_str}_{str_i}_seg.nii.gz' if i <= BPDWP_END else f'{dir_path}/{dir_str}_{str_i}.seg.nii.gz'

    img_tensor = torch.from_numpy(nib.load(img_path).get_fdata())
    seg_tensor = torch.from_numpy(nib.load(seg_path).get_fdata())

    img_tensor_re = resize_tensor(img_tensor, CROP_SIZE, TARGET_SIZE, is_seg=False).type(dtype)
    seg_tensor_re = resize_tensor(seg_tensor, CROP_SIZE, TARGET_SIZE, is_seg=True).type(dtype)

    images[slice_ind] = img_tensor_re
    segs[slice_ind] = seg_tensor_re
    slice_ind += 1

  return images, segs

train_images, train_segs = load_data(20)
test_images, test_segs = load_data(5)

atlas_dir_path = './final_dataset/HC/HC_013'
atlas_img_path, atlas_seg_path = f'{atlas_dir_path}/HC_013_procimg.nii.gz', f'{atlas_dir_path}/HC_013_seg.nii.gz'
orig_img, orig_seg = torch.from_numpy(nib.load(atlas_img_path).get_fdata()), torch.from_numpy(nib.load(atlas_seg_path).get_fdata())
atlas_img = resize_tensor(orig_img, CROP_SIZE, TARGET_SIZE, is_seg=False).type(dtype)
atlas_seg = resize_tensor(orig_seg, CROP_SIZE, TARGET_SIZE, is_seg=False).type(dtype)

In [None]:
def warp(img, trans_map, is_seg=False):
  grid = trans_map.clone().permute(0, 2, 3, 4, 1)

  mode = 'nearest' if is_seg else 'bilinear'
  warped_img = torch.empty_like(img)
  warped_img = F.grid_sample(img, grid, mode=mode, padding_mode='border', align_corners=True)

  return warped_img.reshape(img.shape)

In [None]:
def charbonnier(x, epsilon=0.001, gamma=0.45):
  return (x**2 + epsilon**2)**gamma

#l_sim
def l2_image_similarity_loss(u, u_hat):
  squared_diff = torch.square(u - u_hat)
  return torch.mean(squared_diff)

# l_cyc
def cycle_loss(l, l_tilde):
  return torch.mean(torch.abs(l_tilde - l))

# l_anatomy_cyc
def anatomy_cycle_loss(true_seg, reconstructed_seg):
  all_sum = 2 * torch.sum(torch.mul(true_seg, reconstructed_seg))
  true_seg_sum = torch.sum(torch.square(true_seg))
  reconstructed_seg_sum = torch.sum(torch.square(reconstructed_seg))
  return 1 - (all_sum / (true_seg_sum + reconstructed_seg_sum))

# l_diff_cyc
def cycle_transformation_loss(forward_map, backward_map):
  return charbonnier(torch.add(forward_map, backward_map)).sum()

# l_diff_cyc
def diff_cyc_loss(l_s, u_s_hat, l_s_hat):
  return torch.sum(charbonnier(torch.abs(l_s - u_s_hat) - torch.abs(u_s_hat - l_s)))

def total_loss(weight1, weight2, img, f_warp_img, atlas, re_atlas, true_seg, f_warp_seg, re_seg, f_map, b_map):
  return l2_image_similarity_loss(img, f_warp_img) + weight1 * cycle_loss(atlas, re_atlas) + weight2 * (anatomy_cycle_loss(true_seg, re_seg) + cycle_transformation_loss(f_map, b_map) + diff_cyc_loss(true_seg, f_warp_seg, re_seg))

In [None]:
def dice_score(x, y):
  x_flat = x.view(-1)
  y_flat = y.view(-1)

  intersection = (x_flat == y_flat).sum()
  union = x_flat.size(dim=0) + y_flat.size(dim=0)

  dice = (2.0 * intersection) / union
  return dice

In [None]:
class Trainer():
  def __init__(self):
    self.forward_model = unet_3D().type(dtype)
    self.backward_model = unet_3D().type(dtype)
    self.optimizer_f = optim.Adam(self.forward_model.parameters(), lr=0.0002)
    self.optimizer_b = optim.Adam(self.backward_model.parameters(), lr=0.0002)

  def fgen_forward(self, x):
    return self.forward_model(x)

  def bgen_forward(self, x):
    return self.backward_model(x)

In [None]:
def train(model, img_batch, seg_batch, num_epochs, l_weight1, l_weight2):
  for i in range(num_epochs):
    model.optimizer_f.zero_grad()
    model.optimizer_b.zero_grad()

    atlas_expand = atlas_img.expand_as(img_batch)

    concat_f_input = torch.cat((img_batch, atlas_expand), dim=1)
    f_map = model.fgen_forward(concat_f_input)
    f_warp_img = warp(atlas_expand, f_map, is_seg=False)
    f_warp_seg = warp(seg_batch, f_map, is_seg=True)

    concat_b_input = torch.cat((f_warp_img, atlas_expand), dim=1)
    b_map = model.bgen_forward(concat_b_input)
    b_warp_img = warp(f_warp_img, b_map, is_seg=False)
    b_warp_seg = warp(f_warp_seg, b_map, is_seg=True)

    combined_loss = total_loss(l_weight1, l_weight2, img_batch, f_warp_img, atlas_expand, b_warp_img, seg_batch, f_warp_seg, b_warp_seg, f_map, b_map)

    combined_loss.backward()

    # Update model parameters
    model.optimizer_f.step()
    model.optimizer_b.step()

    print(f"Iteration {i + 1}: Combined Loss = {combined_loss}")

In [None]:
def validation(model, test_images, test_segs):
  atlas_img_expand = atlas_img.expand_as(test_images)
  atlas_seg_expand = atlas_seg.expand_as(test_segs)

  concat_f_input = torch.cat((test_images, atlas_img_expand), dim=1)
  f_map = model.fgen_forward(concat_f_input)
  # f_warp_img = warp(atlas_img_expand, f_map, is_seg=False)
  f_warp_seg = warp(atlas_seg_expand, f_map, is_seg=True)

  seg_dice = dice_score(f_warp_seg, test_segs)

  return seg_dice

In [None]:
def k_fold_training(train_images, train_segs, k=5, num_epochs=15, lr=0.0002, l_weight1=10, l_weight2=3):
  kf = KFold(n_splits=k, shuffle=True)

  for fold, (train_index, val_index) in enumerate(kf.split(train_images)):
    X_train, X_val = train_images[train_index], train_images[val_index]
    y_train, y_val = train_segs[train_index], train_segs[val_index]

    model = Trainer(lr)

    train(model, X_train, y_train, num_epochs)
    val_dice_score = validation(model, X_val, y_val)

    print(f"Fold {fold + 1} - DICE SCORE: {val_dice_score}")

# k_fold_training(train_images, train_segs, 5, 15, 0.0002, 5, 2)

In [None]:
model = Trainer()

train(model, train_images, train_segs, num_epochs=15, l_weight1=5, l_weight2=2)
test_dice_score = validation(model, test_images, test_segs)
print("TEST SET DICE SCORE:", test_dice_score)

Iteration 1: Combined Loss = 237143824.0
Iteration 2: Combined Loss = 235611376.0
Iteration 3: Combined Loss = 234480640.0
Iteration 4: Combined Loss = 233460592.0
Iteration 5: Combined Loss = 232449040.0
Iteration 6: Combined Loss = 231337472.0
Iteration 7: Combined Loss = 230258368.0
Iteration 8: Combined Loss = 229210384.0
Iteration 9: Combined Loss = 228159760.0
Iteration 10: Combined Loss = 227096640.0
Iteration 11: Combined Loss = 226300288.0
Iteration 12: Combined Loss = 225391648.0
Iteration 13: Combined Loss = 224591984.0
Iteration 14: Combined Loss = 223901104.0
Iteration 15: Combined Loss = 223207328.0
TEST SET DICE SCORE: tensor(0.6878, device='cuda:0')
