In [1]:
import numpy as np
from util.data_preparation import VariedSizedImagesCollate, LoadingImgs_to_ListOfListOfTensors, VariedSizedImagesDataset, data_aug_preprocessing_HR, patch_wise_predict
import time
import torch
import torchvision.transforms as T
from options.train_options import TrainOptions
from models import create_model
def linear_normalize(tmp):
    """Min-max rescale ``tmp`` linearly so its minimum maps to 0 and its maximum to 1.

    Works for any object exposing ``.min()``/``.max()`` and elementwise arithmetic
    (e.g. NumPy arrays or torch tensors); the result has the same container type.
    A constant input (max == min) is mapped to all zeros instead of the NaNs that
    the 0/0 division would otherwise produce.
    """
    value_range = tmp.max() - tmp.min()
    if value_range == 0:
        # Constant input: avoid 0/0 -> NaN; (tmp - min) is all zeros anyway.
        value_range = 1
    return (tmp - tmp.min()) / value_range
import matplotlib.pyplot as plt
import os
from PIL import Image
import imageio

# --- Load data ---
data_folder = 'dataset/example_testing_data/'
image_indices = [1, 2, 3]
list_of_list_of_tensors = LoadingImgs_to_ListOfListOfTensors(data_folder, image_indices)
dataset = VariedSizedImagesDataset(list_of_list_of_tensors)
# batch_size can only be 1 (presumably because the images differ in size — see
# VariedSizedImagesDataset/VariedSizedImagesCollate). With batch_size=1,
# drop_last=True never drops a sample.
dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=VariedSizedImagesCollate, shuffle=False,
                                         num_workers=0, drop_last=True)

# --- Create the model (network weights are loaded in the cells below) ---
# The option string gets its own name so that `model` only ever refers to the
# model object created from it.
model_args = '--model cutreg_twostage'
opt = TrainOptions(model_args).parse()   # get training options
opt.input_nc = 3  # NOTE(review): overrides the parsed input-channel count (presumably RGB) — confirm
model = create_model(opt)      # create a model given opt.model and other options

# --- Generator checkpoint selection ---
# Description -> checkpoint file stem under model_weights/.
# Labels are plain comments, not bare string literals: a bare string as the
# last statement of a cell is echoed by Jupyter as the cell's output.
G_WEIGHTS = {
    'CUT': 'CUT_net_G',
    'FastCUT': 'FastCUT_net_G',
    'Pseudo + CUT + REG': 'CUT+REG+Pseudo_net_G',
    # ablation study
    'Pseudo + CUT': 'CUT+Pseudo_net_G',
    'CUT + REG': 'CUT+REG_net_G',
    'Registration only': 'PreREG_net_G',
    'Pseudo label only': 'PseudoSupervised_net_G',
}
G_weights_name = G_WEIGHTS['Pseudo + CUT + REG']

In [None]:
# Generate translated (DS) images from the generator netG.
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# load_state_dict then copies the tensors into netG's existing parameters.
model.netG.load_state_dict(torch.load('model_weights/'+G_weights_name+'.pth', map_location='cpu'))

img_dir = 'test_results/' + G_weights_name + '/'
if os.path.exists(img_dir):
    print(f'Directory {img_dir} already exists')
else:
    os.makedirs(img_dir, exist_ok=True)
    print(f'Directory {img_dir} created')

# Inference only: no autograd graph is needed. This saves memory and ensures
# .numpy() can be called on the prediction without a prior .detach().
with torch.no_grad():
    for i, data in enumerate(dataloader):
        model.set_input(data)
        fake_B = model.patch_wise_predict(input_data=model.real_A_eq, bs=4, stride_ratio=4).cpu().numpy()
        # Clip before the uint8 cast so out-of-range floats saturate instead of
        # wrapping around (assumes the prediction is meant to lie in [0, 1]).
        Image.fromarray((np.clip(fake_B, 0, 1) * 255).astype(np.uint8)).save(img_dir + 'translated_' + str(i + 1) + '.png')
        # The raw, unclipped prediction is kept in the .npy file.
        np.save(img_dir + 'translated_' + str(i + 1) + '.npy', fake_B)

In [None]:
# Generate registration results from the registration network netR.
R_weights_name = 'registration_net_R'
# map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine;
# load_state_dict then copies the tensors into netR's existing parameters.
model.netR.load_state_dict(torch.load('model_weights/'+R_weights_name+'.pth', map_location='cpu'))
test_R_target_list_numpy = []
test_R_moving_OD_list_numpy = []
test_R_OD_list_numpy = []
test_R_GS_list_numpy = []
img_dir = 'test_results/' + R_weights_name + '/'
if os.path.exists(img_dir):
    print(f'Directory {img_dir} already exists')
else:
    os.makedirs(img_dir, exist_ok=True)
    print(f'Directory {img_dir} created')

# Inference only: skip building the autograd graph to save memory.
with torch.no_grad():
    for data in dataloader:
        model.set_input(data)
        model.forward_stage2()
        # Target (real_A_eq) and moving image (fake_A_eq): channel 0 only, inverted as 1 - x.
        test_R_target_list_numpy.append(1.0 - model.real_A_eq.permute(0,2,3,1)[:,:,:,0].squeeze().detach().cpu().numpy())
        test_R_moving_OD_list_numpy.append(1.0 - model.fake_A_eq.permute(0,2,3,1)[:,:,:,0].squeeze().detach().cpu().numpy())
        # NOTE(review): forward_stage2() is called a second time, as in the original
        # notebook. If it is deterministic this call is redundant; confirm whether a
        # different stage was intended here before removing it.
        model.forward_stage2()
        # Registered outputs keep all channels; fake_A_reg_eq is inverted, real_B_reg is not.
        test_R_OD_list_numpy.append(1.0 - model.fake_A_reg_eq.permute(0,2,3,1).squeeze().detach().cpu().numpy())
        test_R_GS_list_numpy.append(model.real_B_reg.permute(0,2,3,1).squeeze().detach().cpu().numpy())

# (file prefix, collected arrays), in the order the files are written per sample.
R_outputs = (
    ('OCT_eq', test_R_target_list_numpy),
    ('moving_OD', test_R_moving_OD_list_numpy),
    ('moved_OD', test_R_OD_list_numpy),
    ('moved_GS', test_R_GS_list_numpy),
)
for i in range(len(test_R_target_list_numpy)):
    for prefix, arrays in R_outputs:
        # Clip before the uint8 cast so out-of-range values saturate instead of
        # wrapping around (assumes the images are meant to lie in [0, 1]).
        imageio.imwrite(img_dir + prefix + str(i + 1) + '.tif', (np.clip(arrays[i], 0, 1) * 255).astype(np.uint8))