In [1]:
# Detect Google Colab automatically: its kernel imports `google.colab` at start-up,
# so the module is already in sys.modules there and absent on a local Jupyter server.
# If the detection is wrong for your environment, assign colab = True / False here.
import sys

colab = 'google.colab' in sys.modules

if colab:
  from google.colab import drive
  drive.mount('/content/drive/')

Mounted at /content/drive/


In [2]:
if colab:
  import subprocess
  import sys
  # NOTE(review): `files` is not used in the visible cells; presumably kept for
  # files.upload()/files.download() elsewhere — TODO confirm or remove.
  from google.colab import files

  # Install with the kernel's own interpreter (`sys.executable -m pip`) so the
  # packages land in the environment this notebook imports from, and fail loudly
  # instead of silently discarding pip's result.
  pip_packages = ["torch-summary", "dival"]
  process = subprocess.run([sys.executable, "-m", "pip", "install", *pip_packages],
                           capture_output=True, text=True)
  output, error = process.stdout, process.stderr
  if process.returncode != 0:
    raise RuntimeError("pip install failed:\n" + error)

import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision.transforms as transforms

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from dival.measure import PSNR, SSIM
from torchsummary import summary

In [3]:
# Use the first CUDA device when one is visible; otherwise run on the CPU.
_gpu_available = torch.cuda.is_available()
device = torch.device("cuda:0") if _gpu_available else torch.device("cpu")

if not _gpu_available:
    print("Running on CPU")
else:
    print("Running on GPU", torch.cuda.get_device_name(device))

Running on GPU Tesla T4


In [4]:
# Dataset

class LIDCDataset(Dataset):
  """Paired (FBP input, ground truth) samples stored as one ``.npy`` file each.

  Samples are matched by filename: every regular file listed in ``root_dir_fbp``
  is loaded together with the file of the same name in ``root_dir_gr``.

  Args:
    root_dir_fbp: directory holding the FBP (input) arrays.
    root_dir_gr: directory holding the ground-truth (target) arrays.
    transform: optional callable applied to each tensor after loading.
  """

  def __init__(self, root_dir_fbp, root_dir_gr, transform = None) :

    self.root_dir_fbp = root_dir_fbp
    self.root_dir_gr = root_dir_gr
    self.file_list = self.create_dataset()
    self.transform = transform

  def __len__(self):
    return len(self.file_list)

  def __getitem__(self, index):
    """Return ``(fbp_tensor, gr_tensor)``, each with a leading channel axis of size 1."""
    filename = self.file_list[index]
    fbp_tensor = self._load_tensor(os.path.join(self.root_dir_fbp, filename))
    gr_tensor = self._load_tensor(os.path.join(self.root_dir_gr, filename))

    if self.transform:
      # NOTE: the transform is called separately on input and target; a random
      # transform may therefore modify the two tensors differently.
      fbp_tensor = self.transform(fbp_tensor)
      gr_tensor = self.transform(gr_tensor)

    return (fbp_tensor, gr_tensor)

  @staticmethod
  def _load_tensor(path):
    """Load one ``.npy`` array and return it as a tensor of the default float dtype with a new axis 0."""
    array = np.load(path)
    return torch.Tensor(np.expand_dims(array, axis=0))

  def create_dataset(self):
    """Return the sorted names of the regular files in ``root_dir_fbp``.

    Sub-directories (e.g. the ``.ipynb_checkpoints`` folder Jupyter creates) are
    skipped because ``np.load`` cannot read them and ``__getitem__`` would fail.
    """
    return sorted(
        name for name in os.listdir(self.root_dir_fbp)
        if os.path.isfile(os.path.join(self.root_dir_fbp, name)))


In [5]:
class Unet3D(nn.Module):
  """3D U-Net whose output is the network input plus a learned correction.

  Args:
    in_ch: channels of the input volume.
    out_ch: channels produced by the final 1x1x1 projection; OutBlock adds the
      input to this prediction, so the two channel counts must be addable
      (the notebook uses 1 and 1).
    channels: feature widths per resolution level, shallowest first. There are
      len(channels) levels, i.e. len(channels) - 1 max-pool / upsample steps.
    skip_channels: width each skip connection is projected to, indexed by level
      (shallowest first, len(channels) - 1 entries).
  """
  def __init__(self, in_ch=1, out_ch=1, channels=None, skip_channels=None):
    super().__init__()

    self.inb = DownBlock(in_ch, channels[0], max_pool=False)
    self.down_list = nn.ModuleList()
    self.up_list = nn.ModuleList()

    self.len_channel = len(channels)
    deepest = self.len_channel - 1

    # Each encoder stage is created together with its mirrored decoder stage,
    # which fixes the parameter-initialisation order and state_dict layout.
    for level in range(deepest):
      mirror = deepest - level
      self.down_list.append(DownBlock(channels[level], channels[level + 1], max_pool=True))
      self.up_list.append(UpBlock(channels[mirror], channels[mirror - 1], skip_channels[mirror - 1]))

    self.outb = OutBlock(channels[0], out_ch)

  def forward(self, x):
    # Encoder: keep every level's features for the skip connections.
    features = [self.inb(x)]
    for down in self.down_list:
      features.append(down(features[-1]))

    # Decoder: start at the bottleneck and consume the skips deepest-first.
    decoded = features[-1]
    for up, skip in zip(self.up_list, reversed(features[:-1])):
      decoded = up(decoded, skip)

    # OutBlock projects the features and adds the original input back.
    return self.outb(decoded, x)

class DownBlock(nn.Module):
  """Encoder stage: optional 2x2x2 max-pool (stride 2), then two
  Conv3d(3x3x3, padding 1) -> BatchNorm3d -> LeakyReLU(0.2) layers.

  Args:
    in_ch: channels of the incoming tensor.
    out_ch: channels produced by both convolutions.
    max_pool: halve the spatial size before convolving (False for the input stage).
  """
  def __init__(self, in_ch, out_ch, max_pool = True):
    super().__init__()
    self.max_pool = max_pool

    layers = []
    for conv_in in (in_ch, out_ch):
      layers += [
        nn.Conv3d(conv_in, out_ch, kernel_size=3, stride=1, padding=1, padding_mode='zeros'),
        nn.BatchNorm3d(out_ch),
        nn.LeakyReLU(0.2, inplace=True),
      ]
    self.conv = nn.Sequential(*layers)

  def forward(self, x):
    pooled = nn.functional.max_pool3d(x, kernel_size=2, stride=2) if self.max_pool else x
    return self.conv(pooled)

class UpBlock(nn.Module):
  """Decoder stage: upsample x1, fuse a 1x1x1-projected skip x2, then two
  Conv3d(3x3x3) -> BatchNorm3d -> LeakyReLU(0.2) layers.

  Args:
    in_ch: channels of the coarser decoder input x1.
    out_ch: channels of the skip input x2 and of the block output.
    skip_ch: channels x2 is projected to before concatenation.
  """
  def __init__(self, in_ch, out_ch, skip_ch):
    super(UpBlock,self).__init__()

    # Used whenever x2 is exactly twice x1's spatial size (every even-sized level).
    self.up = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)

    self.skip_conv = nn.Sequential(
      nn.Conv3d(out_ch, skip_ch, kernel_size=1, stride=1),
      nn.BatchNorm3d(skip_ch),
      nn.LeakyReLU(0.2, inplace=True))

    self.conv = nn.Sequential(
      nn.Conv3d(in_ch+skip_ch, out_ch, kernel_size=3, stride=1, padding = 1, padding_mode='zeros'),
      nn.BatchNorm3d(out_ch),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=1, padding = 1, padding_mode='zeros'),
      nn.BatchNorm3d(out_ch),
      nn.LeakyReLU(0.2, inplace=True))

  def forward(self, x1, x2):
    """x1: 5-D (N, in_ch, ...) decoder features; x2: 5-D (N, out_ch, ...) skip features."""
    target_size = tuple(x2.shape[2:])
    if tuple(2 * s for s in x1.shape[2:]) == target_size:
      x1 = self.up(x1) # Upsample
    else:
      # Max-pooling floors odd sizes (e.g. 5 -> 2), so doubling cannot restore the
      # skip's size; interpolate straight to it so the concatenation lines up.
      x1 = nn.functional.interpolate(x1, size=target_size, mode='trilinear', align_corners=True)
    x2 = self.skip_conv(x2) # Skip connections
    x1 = torch.cat((x1, x2), dim=1)
    x1 = self.conv(x1)
    return x1


class OutBlock(nn.Module):
  """Output head: Conv3d(1x1x1) -> BatchNorm3d -> LeakyReLU(0.2) applied to x1,
  with x2 added to the result (global residual connection).

  Args:
    in_ch: channels of the decoder features x1.
    out_ch: channels of the projection; must be addable to x2's channels.
  """
  def __init__(self, in_ch, out_ch):
    super().__init__()

    projection = nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1)
    self.conv = nn.Sequential(projection, nn.BatchNorm3d(out_ch), nn.LeakyReLU(0.2, inplace=True))

  def forward(self, x1, x2):
    correction = self.conv(x1)
    return correction + x2


In [6]:
# Feature widths per resolution level, shallowest first (4 pooling steps).
channels = [16, 32, 64, 128, 256]
# One skip projection per decoder level; each keeps its encoder width.
skip_channels = channels[:-1]

unet3d = Unet3D(channels=channels, skip_channels=skip_channels)
unet3d = unet3d.to(device)

In [8]:
# Create the training dataset and its loader.

# Please give your path for dataset
train_root_dir_fbp = '.../data/Train/FBP'
train_root_dir_gr = '.../data/Train/Ground_Truth'
train_dataset = LIDCDataset(train_root_dir_fbp, train_root_dir_gr)

batch_size = 4

# Reshuffle the training volumes every epoch.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [9]:

# Load a 3D U-Net model
import re

optimizer = torch.optim.Adam(unet3d.parameters(), lr = 1e-3)

load = True

# Please give your path for models
model_dir = '.../models/3D U-NetR Models'


def _checkpoint_epoch(filename):
  """Return the trailing integer of a checkpoint filename, or None if there is none.

  e.g. 'LIDC_3DU-Net_epoch_1108' -> 1108, 'LIDC_3DU-Net_epoch_1108.txt' -> None.
  """
  match = re.search(r'_(\d+)$', filename)
  return int(match.group(1)) if match else None


if load:
  # The training cell writes checkpoints in pairs: '<name>_<epoch>' (model and
  # optimizer state) and '<name>_<epoch>.txt' (loss history). Pick the pair with the
  # highest epoch *numerically*: format(n, '03d') only pads to a minimum width, so
  # a plain string sort places '..._999' after '..._1000'.
  model_list = sorted(
      (name for name in os.listdir(model_dir)
       if _checkpoint_epoch(name) is not None
       and os.path.isfile(os.path.join(model_dir, name + '.txt'))),
      key=_checkpoint_epoch)
  if not model_list:
    raise FileNotFoundError("No '<name>_<epoch>' + '<name>_<epoch>.txt' checkpoint pair found in " + model_dir)

  checkpoint_name = model_list[-1]
  loss_name = checkpoint_name + '.txt'

  # map_location lets a checkpoint saved on a GPU be restored on a CPU-only runtime.
  state = torch.load(os.path.join(model_dir, checkpoint_name), map_location=device)

  unet3d.load_state_dict(state['state_dict'])
  optimizer.load_state_dict(state['optimizer'])
  # ndmin=2 keeps one row per epoch even if the file holds a single epoch;
  # otherwise loadtxt returns a 1-D array and len() would count columns.
  loss_arr = np.loadtxt(os.path.join(model_dir, loss_name), ndmin=2)

  print("Loaded Model and Optimizer: ", checkpoint_name)
  print("Loaded Loss Array: ", loss_name)
  print("Total Epoch: ", len(loss_arr))
else:
  # Empty loss history: one row per epoch, 4 columns. Column 0 is the training
  # loss (written by the training cell); earlier notes name column 1 the validation
  # loss. The meaning of columns 2-3 is not shown in this notebook — TODO confirm.
  loss_arr = np.empty(shape=[0,4])

Loaded Model and Optimizer:  LIDC_3DU-Net_epoch_1108
Loaded Loss Array:  LIDC_3DU-Net_epoch_1108.txt
Total Epoch:  1108


In [None]:
# Training the 3D-UNet Model
#
# Each epoch makes one pass over train_loader with an L1 loss. It then appends the
# epoch's mean training loss to loss_arr and writes a checkpoint pair:
#   <model_dir>/LIDC_3DU-Net_epoch_<N>      model + optimizer state
#   <model_dir>/LIDC_3DU-Net_epoch_<N>.txt  full loss history, one row per epoch
# N is the total number of epochs trained so far (rows in loss_arr), so each
# epoch gets its own files and resuming continues the numbering.

c = 128  # NOTE(review): unused in this cell; kept in case other cells rely on it
EPOCHS = 50
criterion = torch.nn.L1Loss()  # stateless, so build it once rather than per batch


def _append_epoch_loss(loss_history, train_loss):
  """Return loss_history with one new row whose first entry is train_loss.

  A 2-D history gets a full row; columns this cell does not compute are NaN.
  A 1-D history (e.g. from np.loadtxt on a single-column file) gets a scalar.
  """
  if loss_history.ndim == 1:
    return np.append(loss_history, train_loss)
  row = np.full((1, loss_history.shape[1]), np.nan)
  row[0, 0] = train_loss
  return np.vstack([loss_history, row])


for epoch in range(EPOCHS):

  # Training
  train_epoch_loss = 0.0
  counter = 0

  unet3d.train()
  for x, y in tqdm(train_loader, position=0, leave=True):
    x, y = x.to(device), y.to(device)

    optimizer.zero_grad()
    output = unet3d(x)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()

    train_epoch_loss += loss.item()
    counter += 1

  if counter == 0:
    raise RuntimeError("train_loader yielded no batches; check the training data directories")
  train_epoch_loss /= counter

  # Record this epoch before naming the checkpoint so the file name advances.
  loss_arr = _append_epoch_loss(loss_arr, train_epoch_loss)
  total_epochs = loss_arr.shape[0]
  print("\nTraining Loss:",train_epoch_loss, "Epoch: ", total_epochs)

  print("Saving the model and loss")

  # Save model
  state = {
    'epoch': total_epochs,  # total epochs trained so far, not this run's loop index
    'state_dict': unet3d.state_dict(),
    'optimizer': optimizer.state_dict()
  }

  model_name = 'LIDC_3DU-Net_epoch_' + format(total_epochs, '03d')
  model_path = os.path.join(model_dir, model_name)
  torch.save(state, model_path)

  loss_dir = model_path + ".txt"
  np.savetxt(loss_dir, loss_arr)
