In [1]:
import os
import sys
import tqdm 
import time
import math
import torch
import datetime
import itertools
import torchvision

import numpy as np
import torch.nn as nn
import skimage.io as io
import SimpleITK as sitk
import torch.optim as optim
import torch.nn.functional as F

from tensorboardX import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

from all_models import *

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed both RNGs the notebook depends on. numpy's legacy global seed alone does
# not cover torch, so network weight initialisation (and any other torch
# randomness) would otherwise differ between runs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

custom_name = 'masked_cryo_final_gamma_zoom_fcn3d_resumed'

# Models

### Generator

In [None]:
# Optional generator checkpoint (a saved state_dict) to resume from; None trains from scratch.
# Final model saved previously at:
# './bce_fcn_gamma_zoom/2020-03-11 04:22:17.616635_masked_cryo_final_gamma_zoom_fcn3d_resumed/models/50_0_g_.pt'
GENERATOR_CKPT = None

generator = FCN3D().to(device)
if GENERATOR_CKPT is not None:
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
    generator.load_state_dict(torch.load(GENERATOR_CKPT, map_location=device))

### Discriminator

In [None]:
# Optional discriminator checkpoint (a saved state_dict) to resume from; None trains from scratch.
# Previously used:
# './bce_fcn_gamma_zoom/2020-03-11 04:22:17.616635_masked_cryo_final_gamma_zoom_fcn3d_resumed/models/50_0_d_.pt'
DISCRIMINATOR_CKPT = None

discriminator = Discriminator_old(dim=32, sig=True).to(device)
if DISCRIMINATOR_CKPT is not None:
    # map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
    discriminator.load_state_dict(torch.load(DISCRIMINATOR_CKPT, map_location=device))

# Dataloader

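The data directory is organised as follows:
- ROOT DIR
    - MRI DIR: contains all 3D MRI sub-volumes
    - CRYO DIR: contains all 3D Poisson cryo sub-volumes (the Poisson generation code is available in MATLAB)

MRI and cryo volumes are paired by position: both directory listings are sorted by file name and every second file is kept. The n-th remaining MRI file must therefore correspond to the n-th remaining cryo file.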

In [None]:
DATA_DIR = './volume_generator/'
MRI_DIR = 'random_gamma_zoom_volume_mri_16'
CRYO_DIR = 'random_gamma_zoom_volume_poisson_16'

# Volumes are paired purely by position: each directory listing is sorted by
# file name and only every second entry is kept.
mri_list = sorted(os.listdir(DATA_DIR + MRI_DIR))[::2]
cryo_list = sorted(os.listdir(DATA_DIR + CRYO_DIR))[::2]

# Hold out 10% of the pairs for testing; train_test_split splits both lists
# with the same shuffled partition, so the pairs stay aligned.
train_mr, test_mr, train_cryo, test_cryo = train_test_split(
    mri_list, cryo_list, test_size=0.1, random_state=42)

batch = 8


class VolumeDataset(Dataset):
    """Paired (MRI, cryo) 3D volumes read from disk with SimpleITK.

    Item ``k`` is the pair of files ``X[k]`` (under ``root/mri_dir``) and
    ``Y[k]`` (under ``root/cryo_dir``); the pairing is purely positional.
    Each volume is returned as a torch tensor with NaNs replaced by 0 (and
    infinities by large finite values) via ``np.nan_to_num``.

    Args:
        X: MRI file names, relative to the MRI directory.
        Y: cryo file names, relative to the cryo directory; same length as X.
        root: dataset root directory.
        size: expected volume size. Stored as ``imgsize`` but not used for
            resizing or validation by this class.
        mri_dir: MRI sub-directory name; defaults to the notebook's MRI_DIR.
        cryo_dir: cryo sub-directory name; defaults to the notebook's CRYO_DIR.

    Raises:
        ValueError: if X and Y have different lengths (pairs would misalign).
    """

    def __init__(self, X, Y, root, size=(32, 32, 32), mri_dir=None, cryo_dir=None):
        if len(X) != len(Y):
            raise ValueError(
                'MRI and cryo file lists differ in length: {} vs {}'.format(len(X), len(Y)))
        # Defaults are resolved at call time so they follow the notebook constants.
        if mri_dir is None:
            mri_dir = MRI_DIR
        if cryo_dir is None:
            cryo_dir = CRYO_DIR
        self.greypath = os.path.join(root, mri_dir)
        self.colorpath = os.path.join(root, cryo_dir)

        self.greyimg = X
        self.colorimg = Y
        self.imgsize = size

    def __len__(self):
        return len(self.greyimg)

    @staticmethod
    def _read_volume(path):
        """Read one image file with SimpleITK and return it as a NaN-free torch tensor."""
        array = np.nan_to_num(sitk.GetArrayFromImage(sitk.ReadImage(path)))
        return torch.from_numpy(array)

    def __getitem__(self, index):
        mri = self._read_volume(os.path.join(self.greypath, self.greyimg[index]))
        cryo = self._read_volume(os.path.join(self.colorpath, self.colorimg[index]))
        return mri, cryo


train_dataset = VolumeDataset(train_mr, train_cryo, DATA_DIR)
test_dataset = VolumeDataset(test_mr, test_cryo, DATA_DIR)

# Both loaders share one configuration and iterate in a fixed order (shuffle=False).
loader_kwargs = dict(num_workers=2, shuffle=False, batch_size=batch)
train_dataloader = DataLoader(train_dataset, **loader_kwargs)
test_dataloader = DataLoader(test_dataset, **loader_kwargs)

# Sanity check: print the shape of the first training batch of cryo volumes.
data = next(iter(train_dataloader), None)
if data is not None:
    print(data[1].shape)

# Optimizers

In [None]:
# Shared Adam settings for generator and discriminator; `momentum` is passed
# to Adam as beta1 in the optimizer cell.
lr, momentum = 1.0e-4, 0.95

In [None]:
# One Adam optimizer per network with identical hyper-parameters:
# beta1 = momentum, beta2 = 0.999.
adam_betas = (momentum, 0.999)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)

In [None]:
# Adversarial criterion: BCELoss expects probabilities as input. The
# discriminator is built with sig=True (presumably a final sigmoid; confirm in all_models).
criterion_gan = nn.BCELoss().to(device=device)

# Content criteria: project-defined SSIM (constructed with the target device)
# and voxel-wise L1. The training loop weights them with lambda_content.
criterion_content = SSIM(device).to(device=device)
criterion_l1 = nn.L1Loss().to(device=device)

# Training

In [None]:
# Run length and logging / saving cadence.
max_epoch = 1002   # epochs are numbered 0 .. 1001
print_freq = 10    # intended console-logging interval, in iterations
save_freq = 1000   # outputs are saved at batch indices that are multiples of this

In [None]:
ROOT_DIR = './bce_fcn_gamma_zoom/'
# Every run gets its own directory named "<timestamp>_<custom_name>".
# NOTE(review): str(datetime.now()) contains spaces and colons (invalid in Windows
# paths); kept unchanged so run directories stay named like existing ones.
now = str(datetime.datetime.now()) + '_' + custom_name

# Trailing slashes are kept: later cells build file paths by string concatenation.
LOG_DIR = ROOT_DIR + now + '/logs/'
OUTPUTS_DIR = ROOT_DIR + now + '/outputs/'
MODEL_DIR = ROOT_DIR + now + '/models/'

# makedirs also creates ROOT_DIR and the run directory as intermediates;
# exist_ok=True replaces the racy os.path.exists()-then-create pattern.
for run_subdir in (LOG_DIR, OUTPUTS_DIR, MODEL_DIR):
    os.makedirs(run_subdir, exist_ok=True)

summary_writer = SummaryWriter(LOG_DIR)


In [None]:
# Loss weights. lambda_content is negative because the training loop multiplies
# the SSIM term by it and the L1 term by its absolute value.
lambda_gan, lambda_content = 1.0, -1.0

In [None]:
def _save_batch_volumes(epoch, i, gen_cryo, mri, cryo):
    """Write every sample of the current batch to OUTPUTS_DIR as .mhd files.

    For sample j, three volumes are written:
    - the generated cryo volume;
    - the input MRI (channel 0);
    - the ground-truth cryo volume.
    Both cryo tensors are permuted to channel-last order ([1, 2, 3, 0]) before
    conversion. File names match the original notebook.
    """
    for j in range(gen_cryo.shape[0]):
        gen_np = gen_cryo[j, :, :, :, :].permute(1, 2, 3, 0).cpu().detach().numpy()
        sitk.WriteImage(sitk.GetImageFromArray(gen_np),
                        OUTPUTS_DIR + '{}_{}_{}_cryo_gen.mhd'.format(epoch, i, j))

        mri_np = mri[j, 0, :, :, :].cpu().detach().numpy()
        sitk.WriteImage(sitk.GetImageFromArray(mri_np),
                        OUTPUTS_DIR + '{}_{}_{}mri_gt.mhd'.format(epoch, i, j))

        real_np = cryo[j, :, :, :, :].permute([1, 2, 3, 0]).cpu().detach().numpy()
        sitk.WriteImage(sitk.GetImageFromArray(real_np),
                        OUTPUTS_DIR + '{}_{}_{}cryo_gt.mhd'.format(epoch, i, j))


# Each batch gets one generator update, then (conditionally) one discriminator update.
global_step = 0  # TensorBoard x-axis: total iterations across all epochs
for epoch in range(max_epoch):
    for i, data in enumerate(train_dataloader):
        # unsqueeze(1) inserts a channel axis after the batch axis
        # (presumably (B, D, H, W) -> (B, 1, D, H, W); confirm against the data).
        mri = data[0].unsqueeze(1).float().to(device)
        cryo = data[1].unsqueeze(1).float().to(device)

        truth = torch.ones(mri.shape[0], 1, device=device)
        fake = torch.zeros(mri.shape[0], 1, device=device)

        gen_cryo = generator(mri)
        real_cryo_prob = discriminator(cryo)
        fake_cryo_prob = discriminator(gen_cryo)

        # ---- Generator step ----
        # Adversarial term (push D towards "real" on generated volumes) plus the
        # content term lambda_content * SSIM + |lambda_content| * L1.
        optimizer_G.zero_grad()
        loss_gan = criterion_gan(fake_cryo_prob, truth) * lambda_gan
        loss_content = (criterion_content(gen_cryo, cryo) * lambda_content
                        + criterion_l1(gen_cryo, cryo) * math.fabs(lambda_content))
        loss_gen = loss_gan + loss_content
        loss_gen.backward()
        optimizer_G.step()

        # ---- Discriminator step ----
        # Score a *detached* generated volume. Back-propagating the D loss
        # through `gen_cryo` would traverse the generator graph whose weights
        # optimizer_G.step() just modified in place. Current PyTorch rejects
        # that with an in-place-modification RuntimeError, and the generator
        # gradients would be discarded anyway.
        fake_cryo_prob_d = discriminator(gen_cryo.detach())
        optimizer_D.zero_grad()  # also clears D grads accumulated by loss_gen.backward()
        loss_real = criterion_gan(real_cryo_prob, truth)
        loss_fake = criterion_gan(fake_cryo_prob_d, fake)
        loss_dis = (loss_real + loss_fake) * lambda_gan / 2
        # Warm-up gate: in epochs 0-10, D is only updated while its loss exceeds 0.5.
        if loss_dis.item() > 0.5 or epoch > 10:
            loss_dis.backward()
            optimizer_D.step()

        summary_writer.add_scalar('Discriminator Loss', loss_dis.item(), global_step)
        summary_writer.add_scalar('Content Loss', loss_content.item(), global_step)
        summary_writer.add_scalar('Generator Loss', loss_gen.item(), global_step)

        if i % print_freq == 0:
            print('Epoch: {}, Iteration: {}, Content Loss: {}, Adversarial Loss: {}, '
                  'Generator Loss: {}, Discriminator Loss: {}'.format(
                      epoch, i, loss_content.item(), loss_gan.item(),
                      loss_gen.item(), loss_dis.item()))

        if i % save_freq == 0:
            # Checkpoints (weights + optimizer state) are only written every 50 epochs.
            save_models = epoch % 50 == 0
            print('\n\n Saving model and output \n\n' if save_models else '\n\n Saving output \n\n')
            if save_models:
                torch.save(generator.state_dict(), MODEL_DIR + '{}_{}_g_.pt'.format(epoch, i))
                torch.save(discriminator.state_dict(), MODEL_DIR + '{}_{}_d_.pt'.format(epoch, i))

                torch.save(optimizer_G.state_dict(), MODEL_DIR + '{}_{}_g_optim_mr2gr.pt'.format(epoch, i))
                torch.save(optimizer_D.state_dict(), MODEL_DIR + '{}_{}_d_optim_mr2gr.pt'.format(epoch, i))

            _save_batch_volumes(epoch, i, gen_cryo, mri, cryo)

        global_step += 1

