In [1]:
import os
import sys
import tqdm 
import time
import math
import torch
import datetime
import itertools
import torchvision

import numpy as np
import torch.nn as nn
import skimage.io as io
import SimpleITK as sitk
import torch.optim as optim
import torch.nn.functional as F

from tensorboardX import SummaryWriter
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

from all_models import *

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seeds NumPy's global RNG only; torch's RNG is not seeded by this cell.
np.random.seed(seed=42)

# Run identifier string for this experiment.
custom_name = 'masked_cryo_final_gamma_zoom_fcn3d_resumed'

# Models

### Generator

In [None]:
# Build the generator on the selected device and restore its saved weights.
generator = FCN3D().to(device)
# map_location=device remaps the stored tensors onto `device`, so a checkpoint
# written on a GPU machine still loads when this notebook falls back to the CPU.
generator.load_state_dict(torch.load('./bce_fcn_gamma_zoom/2020-03-11 04:22:17.616635_masked_cryo_final_gamma_zoom_fcn3d_resumed/models/50_0_g_.pt', map_location=device)) # final model saved

### Discriminator

In [None]:
# Build the discriminator on the selected device and restore its saved weights.
discriminator = Discriminator_old(dim=32, sig=True).to(device)
# map_location=device remaps the stored tensors onto `device`, so a checkpoint
# written on a GPU machine still loads when this notebook falls back to the CPU.
discriminator.load_state_dict(torch.load('./bce_fcn_gamma_zoom/2020-03-11 04:22:17.616635_masked_cryo_final_gamma_zoom_fcn3d_resumed/models/50_0_d_.pt', map_location=device))

# Dataloader

The data directory is laid out as follows:
- ROOT DIR (`DATA_DIR`)
    - MRI DIR (`MRI_DIR`): contains all 3D MRI sub-volumes
    - CRYO DIR (`CRYO_DIR`): contains all 3D Poisson cryo sub-volumes (the Poisson generation code is available in MATLAB)

In [None]:
# Root data folder and the two sub-folders holding the MRI and cryo volumes.
DATA_DIR = './volume_generator/'
MRI_DIR = 'random_gamma_zoom_volume_mri_16'
CRYO_DIR = 'random_gamma_zoom_volume_poisson_16'

# Sorted directory listings, keeping every second file name from each folder.
# NOTE(review): pairing by position assumes matching MRI / cryo files sort to
# the same index in both folders -- confirm against the volume generator.
mri_list = sorted(os.listdir(os.path.join(DATA_DIR, MRI_DIR)))[::2]
cryo_list = sorted(os.listdir(os.path.join(DATA_DIR, CRYO_DIR)))[::2]

# Both lists go through one call so the same indices land in train / test for
# each of them; 10% of the entries are held out, with a fixed random_state.
train_mr, test_mr, train_cryo, test_cryo = train_test_split(
    mri_list,
    cryo_list,
    test_size=0.1,
    random_state=42,
)

# Batch size used by the data loaders.
batch = 8


class VolumeDataset(Dataset):
    """Dataset of index-paired MRI / cryo volume files read with SimpleITK.

    Item ``index`` pairs ``X[index]`` (a file name inside ``root/MRI_DIR``)
    with ``Y[index]`` (a file name inside ``root/CRYO_DIR``). ``MRI_DIR`` and
    ``CRYO_DIR`` are module-level names read when the dataset is constructed.

    Args:
        X: sequence of MRI file names.
        Y: sequence of cryo file names, index-aligned with ``X``.
        root: directory that contains the ``MRI_DIR`` and ``CRYO_DIR`` folders.
        size: stored as ``self.imgsize``; not otherwise used by this class.

    Raises:
        ValueError: if ``X`` and ``Y`` do not have the same length.
    """

    def __init__(self, X, Y, root, size = (32,32,32)):
        # Fail fast on mis-paired lists: otherwise the mismatch only shows up
        # later as an IndexError inside a loader worker (Y shorter than X), or
        # the surplus entries of Y are silently never used (Y longer than X).
        if len(X) != len(Y):
            raise ValueError(
                'X and Y must have the same length, got {} and {}'.format(len(X), len(Y))
            )

        self.greypath = os.path.join(root, MRI_DIR)
        self.colorpath = os.path.join(root, CRYO_DIR)

        self.greyimg = X
        self.colorimg = Y
        self.imgsize = size

    def __len__(self):
        """Return the number of (MRI, cryo) pairs."""
        return len(self.greyimg)

    @staticmethod
    def _load_volume(directory, filename):
        """Read one image file and return it as a tensor with NaNs replaced by np.nan_to_num."""
        image = sitk.ReadImage(os.path.join(directory, filename))
        array = np.nan_to_num(sitk.GetArrayFromImage(image))
        return torch.from_numpy(array)

    def __getitem__(self, index):
        """Return the ``(mri, cryo)`` tensor pair stored at position ``index``."""
        mri = self._load_volume(self.greypath, self.greyimg[index])
        cryo = self._load_volume(self.colorpath, self.colorimg[index])

        return mri, cryo


# Wrap the train / test file-name splits in datasets and fixed-order batch loaders.
train_dataset = VolumeDataset(train_mr, train_cryo, root=DATA_DIR)
train_dataloader = DataLoader(train_dataset, batch_size=batch, shuffle=False, num_workers=2)

test_dataset = VolumeDataset(test_mr, test_cryo, root=DATA_DIR)
test_dataloader = DataLoader(test_dataset, batch_size=batch, shuffle=False, num_workers=2)

# Sanity check: fetch only the first training batch and print the shape of its
# second element (the cryo tensor returned by VolumeDataset.__getitem__).
for i, data in enumerate(itertools.islice(train_dataloader, 1)):
    print(data[1].shape)

# Optimizers

In [None]:
# Learning rate shared by both Adam optimizers.
lr = 1e-4
# Passed as the first Adam beta coefficient when the optimizers are built.
momentum = 0.95

In [None]:
# Generator and discriminator each get their own Adam optimizer with the same
# learning rate and the same (beta1, beta2) pair; `momentum` is used as beta1.
adam_betas = (momentum, 0.999)
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=adam_betas)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=adam_betas)

In [None]:
# Binary cross-entropy criterion, named for the GAN (adversarial) term.
criterion_gan = nn.BCELoss()
criterion_gan = criterion_gan.to(device)

# SSIM is provided by all_models; presumably a structural-similarity criterion
# for the content term -- confirm against its definition.
criterion_content = SSIM(device).to(device)

# L1 (mean absolute error) criterion.
criterion_l1 = nn.L1Loss()
criterion_l1 = criterion_l1.to(device)

# 