In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
from matplotlib import pyplot as plt
import pydicom
import os
import time
import glob
import numpy as np
from numpy import savez_compressed
from numpy import load
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from skimage.metrics import structural_similarity as ssim
import pydicom
from pydicom.pixel_data_handlers.util import apply_voi_lut
import pandas as pd
import gc
gc.collect()
torch.cuda.empty_cache()
import albumentations as A
# Prefer the second GPU (index 1) when the machine has one, otherwise the first GPU, otherwise
# the CPU. A hard-coded 'cuda:1' fails on hosts that expose only a single GPU.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

### Preprocessing

In [None]:
class ThresholdTransform:
    """Binarise a tensor: 1 where ``x`` exceeds a threshold, 0 elsewhere.

    The threshold is given on a 0-255 scale and stored rescaled to 0-1 in ``self.thr``.
    The result has the same dtype as the input.
    Only referenced from a commented-out line of the ``transform`` pipeline.
    """

    def __init__(self, thr_255):
        self.thr = thr_255 / 255.

    def __call__(self, x):
        above = torch.gt(x, self.thr)
        # cast the boolean mask back to the input dtype
        return above.to(dtype=x.dtype)

In [None]:
# Per-slice image pipeline: array -> tensor, then resize to 256x256.
# NOTE(review): per the torchvision docs ToTensor only rescales uint8 input, so the float32,
# NORM-scaled arrays passed in by Dataloader should keep their range — confirm for the
# installed torchvision version.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Resize((256, 256)),
    ]
)

# Plain array -> tensor conversion without resizing (not referenced elsewhere in this notebook).
transform_info = transforms.Compose([transforms.ToTensor()])

In [None]:
def NORM(array, max_value=65535):
    """Linearly map pixel values from [0, max_value] to [-1, 1].

    Args:
        array: numpy array of raw pixel values (float dtype is preserved).
        max_value: value that maps to +1. The default 65535 is the full 16-bit range; pass
            the series' real maximum (e.g. 4095 for 12-bit data) to use the whole [-1, 1] span.

    Returns:
        ``(array / max_value) * 2 - 1``.

    Raises:
        ValueError: if ``max_value`` is not positive.
    """
    if max_value <= 0:
        raise ValueError(f'max_value must be positive, got {max_value}')
    return ((array / max_value) * 2) - 1

In [None]:
class Dataloader():
    """Load paired (original, target) DICOM series together with per-subject tabular info.

    Every directory matched by ``glob.glob(o_path)`` is one patient. The matching target
    series is read from ``os.path.join(t_path, <patient directory name>)``. Each slice is
    normalised with ``NORM`` and converted/resized by ``transform`` into a 256x256 tensor.

    Args:
        o_path: glob pattern matching one original-series directory per patient.
        t_path: directory holding target-series directories named like the originals.
        info_csv: CSV with a ``SubjectID`` column plus the per-subject feature columns
            (each row is reshaped to 5 values).
    """

    def __init__(self, o_path, t_path, info_csv='train_info.csv'):
        self.o_dataset = []   # per-patient (S, 1, 256, 256) original tensors
        self.t_dataset = []   # per-patient (S, 1, 256, 256) target tensors
        self.info_list = []   # per-patient (S, 1, 5) info tensors
        self.o_path = o_path
        self.t_path = t_path
        self.info_csv = info_csv

    @staticmethod
    def _load_series(patient_dir):
        """Read all DICOM files in `patient_dir` (sorted by name) into an (S, 1, 256, 256) tensor."""
        slices = []
        for dcm_file in sorted(glob.glob(os.path.join(patient_dir, '*'))):
            pixels = pydicom.dcmread(dcm_file).pixel_array.astype('float32')
            # NORM rescales to [-1, 1]; transform returns a (1, 256, 256) tensor
            image = transform(NORM(pixels))
            slices.append(image.reshape(1, 1, 256, 256))
        if not slices:
            raise ValueError(f'no DICOM files found in {patient_dir!r}')
        # one concatenation instead of growing the tensor slice by slice
        return torch.cat(slices, 0)

    @staticmethod
    def _subject_info(info_table, patient_name, n_slices):
        """Return the patient's info row as float32, repeated once per slice: (n_slices, 1, 5)."""
        # patient directory names carry a 5-character prefix in front of the SubjectID
        row = info_table.loc[patient_name[5:]].to_numpy()
        info = torch.tensor(row, dtype=torch.float32).reshape(1, 1, 5)
        return info.repeat(n_slices, 1, 1)

    def preprocess(self):
        """Load every patient and concatenate all of their slices.

        Returns:
            (orig, target, info): ``orig`` and ``target`` are (N, 1, 256, 256) float tensors
            and ``info`` is (N, 1, 5); index i of each refers to the same slice.

        Raises:
            ValueError: if no patient matches ``o_path``, a series directory is empty, or a
                patient's original and target series have different slice counts (which
                would silently misalign the pairs).
        """
        info_table = pd.read_csv(self.info_csv).set_index('SubjectID')

        for org_dir in glob.glob(self.o_path):
            # basename rather than a fixed-offset slice, so any o_path prefix works
            patient_name = os.path.basename(org_dir)
            org_series = self._load_series(org_dir)
            tar_series = self._load_series(os.path.join(self.t_path, patient_name))
            if org_series.shape[0] != tar_series.shape[0]:
                raise ValueError(
                    f'{patient_name}: {org_series.shape[0]} original slices but '
                    f'{tar_series.shape[0]} target slices'
                )
            self.o_dataset.append(org_series)
            self.t_dataset.append(tar_series)
            self.info_list.append(
                self._subject_info(info_table, patient_name, tar_series.shape[0])
            )

        if not self.o_dataset:
            raise ValueError(f'no patient directories match {self.o_path!r}')

        return (torch.cat(self.o_dataset, 0),
                torch.cat(self.t_dataset, 0),
                torch.cat(self.info_list, 0))

In [None]:
# Build the paired training set and a shuffled, batch-size-1 loader over it.
data = Dataloader(
    '/home/mri-any/GAN_Data/orignal/*',   # glob over per-patient original series
    '/home/mri-any/GAN_Data/target/',     # root of the matching target series
)

orig, target, info = data.preprocess()
print(orig.shape, info.shape, sep='\n')

ds = TensorDataset(orig, target, info)
train_dataloader = DataLoader(ds, shuffle=True, batch_size=1)

In [None]:
# Build the paired test set; shuffling only changes the order in which slices are visited.
test_data = Dataloader(
    '/home/mri-any/GAN_test/orignal/*',   # glob over per-patient original series
    '/home/mri-any/GAN_test/target/',     # root of the matching target series
)
orig, target, info = test_data.preprocess()
print(orig.shape, info.shape, sep='\n')

ds = TensorDataset(orig, target, info)
test_dataloader = DataLoader(ds, shuffle=True, batch_size=1)

### Generator Model

In [None]:
# UNet
class UNetDown(nn.Module):
    """U-Net encoder stage.

    The stage is: 4x4 stride-2 convolution (halves even H and W), then optional InstanceNorm2d,
    then LeakyReLU(0.2), then optional Dropout.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        normalize: append InstanceNorm2d after the convolution.
        dropout: dropout probability; 0 omits the Dropout layer.
    """

    def __init__(self, in_channels, out_channels, normalize=True, dropout=0.0):
        super().__init__()

        layers = [nn.Conv2d(in_channels, out_channels, 4, stride=2, padding=1, bias=False)]

        if normalize:
            layers.append(nn.InstanceNorm2d(out_channels))

        layers.append(nn.LeakyReLU(0.2))

        if dropout:
            layers.append(nn.Dropout(dropout))

        self.down = nn.Sequential(*layers)

    def forward(self, x):
        x = self.down(x)
        return x

In [None]:
class UNetUp(nn.Module):
    """U-Net decoder stage: 2x transposed-conv upsampling, then concatenation with a skip tensor.

    The layers are ConvTranspose2d(4, stride 2, padding 1, no bias), then InstanceNorm2d, then
    LeakyReLU, then optional Dropout. The LeakyReLU uses PyTorch's default negative slope, unlike
    the 0.2 in UNetDown. NOTE(review): confirm the difference is intended.
    ``forward`` returns ``cat(up(x), skip)`` along the channel axis.
    """

    def __init__(self, in_channels, out_channels, dropout=0.0):
        super().__init__()
        stages = [
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.InstanceNorm2d(out_channels),
            nn.LeakyReLU(),
        ]
        stages += [nn.Dropout(dropout)] if dropout else []
        self.up = nn.Sequential(*stages)

    def forward(self, x, skip):
        return torch.cat((self.up(x), skip), 1)

In [None]:
# generator: produces the fake (synthesized) image.
class GeneratorUNet(nn.Module):
    """U-Net generator with a conditioning tensor injected at the bottleneck.

    Eight ``UNetDown`` stages (each halves H and W) are mirrored by seven ``UNetUp`` stages
    with skip connections, followed by a final transposed convolution and Tanh.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the generated image.
        info_channels: channels of the conditioning tensor passed to ``forward``. ``down8``
            emits ``512 - info_channels`` channels so that bottleneck + info is 512 channels
            wide, which is what ``up1`` expects.

    Raises:
        ValueError: if ``info_channels`` is not in [0, 512).
    """

    def __init__(self, in_channels=1, out_channels=1, info_channels=3):
        super().__init__()
        if not 0 <= info_channels < 512:
            raise ValueError(f'info_channels must be in [0, 512), got {info_channels}')

        self.down1 = UNetDown(in_channels, 64, normalize=False)
        self.down2 = UNetDown(64, 128)
        self.down3 = UNetDown(128, 256)
        self.down4 = UNetDown(256, 512, dropout=0.5)
        self.down5 = UNetDown(512, 512, dropout=0.5)
        self.down6 = UNetDown(512, 512, dropout=0.5)
        self.down7 = UNetDown(512, 512, dropout=0.5)
        # leave room for the conditioning channels concatenated in forward()
        self.down8 = UNetDown(512, 512 - info_channels, normalize=False, dropout=0.5)

        self.up1 = UNetUp(512, 512, dropout=0.5)
        self.up2 = UNetUp(1024, 512, dropout=0.5)
        self.up3 = UNetUp(1024, 512, dropout=0.5)
        self.up4 = UNetUp(1024, 512, dropout=0.5)
        self.up5 = UNetUp(1024, 256)
        self.up6 = UNetUp(512, 128)
        self.up7 = UNetUp(256, 64)
        self.up8 = nn.Sequential(
            nn.ConvTranspose2d(128, out_channels, 4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x, info):
        """Generate an image from ``x`` conditioned on ``info``.

        Args:
            x: (N, in_channels, H, W) input; with the eight stride-2 stages a 256x256 input
                reaches a 1x1 bottleneck.
            info: (N, info_channels, h, w) tensor whose spatial size matches the bottleneck
                (1x1 for 256x256 inputs).

        Returns:
            (N, out_channels, H, W) tensor in [-1, 1] (Tanh output).
        """
        d1 = self.down1(x)
        d2 = self.down2(d1)
        d3 = self.down3(d2)
        d4 = self.down4(d3)
        d5 = self.down5(d4)
        d6 = self.down6(d5)
        d7 = self.down7(d6)
        d8 = self.down8(d7)

        # inject the conditioning vector at the bottleneck
        d8 = torch.cat([d8, info], 1)

        u1 = self.up1(d8, d7)
        u2 = self.up2(u1, d6)
        u3 = self.up3(u2, d5)
        u4 = self.up4(u3, d4)
        u5 = self.up5(u4, d3)
        u6 = self.up6(u5, d2)
        u7 = self.up7(u6, d1)
        u8 = self.up8(u7)

        return u8


# shape check:
# info = torch.randn(2, 3, 1, 1, device=device)
# x = torch.randn(2, 1, 256, 256, device=device)
# print(GeneratorUNet().to(device)(x, info).shape)  # -> (2, 1, 256, 256)


In [None]:
class encoder(nn.Module):
    """MLP compressing a 5-feature info vector to 3 features.

    The layer widths are 5 -> 128 -> 64 -> 12 -> 3, with in-place ReLU between Linear layers
    and no activation after the last one. Layers live in ``self.encoder`` at indices 0..6, so
    saved state_dict keys are ``encoder.{0,2,4,6}.{weight,bias}``.
    """

    def __init__(self):
        super().__init__()
        widths = (5, 128, 64, 12, 3)
        layers = []
        for idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if idx:
                layers.append(nn.ReLU(inplace=True))
            layers.append(nn.Linear(fan_in, fan_out))
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.encoder(x)


# Load pretrained encoder weights. map_location='cpu' lets a checkpoint saved on a GPU load on
# any machine; the encoder runs on CPU and its output is moved to `device` by the callers.
encoder_model = encoder()
weights = torch.load('./encoder_weight3.pth', map_location='cpu')
encoder_model.load_state_dict(weights)
# The encoder is pretrained and is not handed to any optimizer. Freeze it so that
# g_loss.backward() does not keep accumulating unused gradients into its parameters.
encoder_model.requires_grad_(False)
encoder_model.eval()


### Discriminator Model

In [None]:
class Dis_block(nn.Module):
    """Discriminator stage.

    The stage is: 3x3 stride-2 convolution (with bias), then optional InstanceNorm2d, then
    LeakyReLU(0.2).
    """

    def __init__(self, in_channels, out_channels, normalize=True):
        super().__init__()
        conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
        norm = [nn.InstanceNorm2d(out_channels)] if normalize else []
        self.block = nn.Sequential(conv, *norm, nn.LeakyReLU(0.2))

    def forward(self, x):
        return self.block(x)

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator that scores an image pair with two logits.

    ``forward(a, b)`` concatenates ``a`` and ``b`` on the channel axis. It then applies four
    ``Dis_block`` stages and a linear layer that produces two logits per sample. The training
    loop in this notebook pairs these with 2-element labels ([1, 0] real, [0, 1] fake).

    Args:
        in_channels: channels of each input image (stage_1 sees ``2 * in_channels``).
        img_size: height/width of the square inputs; it determines fc1's input size
            (256 -> 512 * 16 * 16 = 131072 features).
    """

    def __init__(self, in_channels=1, img_size=256):
        super().__init__()

        self.stage_1 = Dis_block(in_channels*2, 64, normalize=False)
        self.stage_2 = Dis_block(64, 128)
        self.stage_3 = Dis_block(128, 256)
        self.stage_4 = Dis_block(256, 512)

        # each 3x3 / stride-2 / padding-1 conv maps a side of length n to ceil(n / 2)
        side = img_size
        for _ in range(4):
            side = (side + 1) // 2
        self.fc1 = torch.nn.Linear(512 * side * side, 2)

    def forward(self, a, b):
        """Score the pair (a, b); both are (N, in_channels, img_size, img_size).

        Returns:
            (N, 2) logits, or shape (2,) when N == 1.
        """
        x = torch.cat((a, b), 1)
        x = self.stage_1(x)
        x = self.stage_2(x)
        x = self.stage_3(x)
        x = self.stage_4(x)
        # flatten per sample; flattening the batch axis as well only worked for N == 1
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        # drop the batch axis for N == 1 so the output keeps the (2,) shape that the
        # 2-element label vectors in the training loop expect
        return x.squeeze(0)

# shape check:
# x = torch.randn(4, 1, 256, 256, device=device)
# print(Discriminator().to(device)(x, x).shape)  # -> (4, 2)

In [None]:
# Instantiate both networks on the selected device (generator first, then discriminator).
model_gen, model_dis = GeneratorUNet().to(device), Discriminator().to(device)

In [None]:
# Weight initialisation
def initialize_weights(model):
    """Redraw the weights of any convolution-type module from N(0, 0.02).

    Meant for ``nn.Module.apply``, which calls it once per submodule. Only modules whose
    class name contains ``'Conv'`` (e.g. Conv2d, ConvTranspose2d) are modified. Their biases
    and every other module type keep their existing values.
    """
    if 'Conv' not in model.__class__.__name__:
        return
    nn.init.normal_(model.weight.data, mean=0.0, std=0.02)


# Apply the custom weight initialisation to every submodule of both networks.
for _net in (model_gen, model_dis):
    _net.apply(initialize_weights)

In [None]:
# Random horizontal and vertical flips, each applied with probability 0.5.
# NOTE: not used by any active code in this notebook.
augmentation = A.Compose([A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5)])

In [None]:
# Loss functions: adversarial (BCE on logits) and pixel-wise L1.
loss_func_gan = nn.BCEWithLogitsLoss()
loss_func_pix = nn.L1Loss()

# Weights of the pixel and adversarial terms in the generator loss.
lambda_pixel, lambda_gen = 100, 1

# Optimizer hyper-parameters.
from torch import optim
g_lr, d_lr = 2e-4, 2e-8  # NOTE(review): d_lr is 10^4x smaller than g_lr; confirm it is not a typo for 2e-4
beta1, beta2 = 0.5, 0.999

opt_gen = optim.Adam(model_gen.parameters(), lr=g_lr, betas=(beta1, beta2))
opt_dis = optim.Adam(model_dis.parameters(), lr=d_lr, betas=(beta1, beta2))

In [None]:
# Training
model_gen.train()
model_dis.train()

batch_count = 0
num_epochs = 100
# Every `log_interval` batches the loop does the following: record averaged losses/SSIM,
# print, plot, checkpoint, and evaluate on the test set.
# NOTE(review): 6516 presumably equals len(train_dataloader) (one epoch at batch_size=1);
# confirm, otherwise the "epoch" axis of the plots is misleading.
log_interval = 6516
# NORM scales images to [-1, 1], so SSIM's data range is 2. It is passed explicitly because
# newer scikit-image versions refuse to guess it for floating-point images.
ssim_data_range = 2.0
path2models = '/home/mri-any/GAN_weight/test/'
start_time = time.time()

loss_hist = {'gen': [],
             'dis': []}

epoch_g_loss = []
epoch_d_loss = []

train_ssim_FO = []  # fake vs original
train_ssim_FT = []  # fake vs target
train_ssim_OT = []  # original vs target
test_ssim_FO = []
test_ssim_FT = []
test_ssim_OT = []

# Two-logit labels: index 0 = "real", index 1 = "fake". They are constant, so build them once.
real_label = torch.tensor([1., 0.], device=device)
fake_label = torch.tensor([0., 1.], device=device)


def _encode_info(info_batch):
    """Encode a batch-of-one info tensor into the (1, 3, 1, 1) conditioning the generator takes."""
    # the encoder is pretrained and never optimized, so no autograd graph is needed
    with torch.no_grad():
        return encoder_model(info_batch).to(device).reshape(1, 3, 1, 1)


def _ssim_triplet(fake, orig_img, tgt_img):
    """SSIM of (fake, original), (fake, target), (original, target) for the first item of each batch."""
    fake_np = fake[0].squeeze().cpu().numpy()
    orig_np = orig_img[0].squeeze().cpu().numpy()
    tgt_np = tgt_img[0].squeeze().cpu().numpy()
    return (ssim(fake_np, orig_np, data_range=ssim_data_range),
            ssim(fake_np, tgt_np, data_range=ssim_data_range),
            ssim(orig_np, tgt_np, data_range=ssim_data_range))


def _mean(values):
    """Arithmetic mean; NaN for an empty list."""
    return sum(values) / len(values) if values else float('nan')


for epoch in range(num_epochs):
    train_ssim_1 = []  # fake, original
    train_ssim_2 = []  # fake, target
    train_ssim_3 = []  # original, target

    # (data augmentation via `augmentation` is currently not applied)
    for a, b, info in train_dataloader:
        real_a = a.to(device)  # input image
        real_b = b.to(device)  # target image
        encoded_info = _encode_info(info)

        ##### GENERATOR #####
        model_gen.train()
        model_dis.eval()
        model_gen.zero_grad()

        fake_b = model_gen(real_a, encoded_info)  # generate the fake image
        # condition D on the input image, exactly as in the discriminator step below
        out_dis = model_dis(fake_b, real_a)

        gen_loss = loss_func_gan(out_dis, real_label)
        pixel_loss = loss_func_pix(fake_b, real_b)

        g_loss = (lambda_gen * gen_loss) + (lambda_pixel * pixel_loss)
        g_loss.backward()
        opt_gen.step()

        ##### DISCRIMINATOR #####
        model_dis.train()
        model_dis.zero_grad()

        real_loss = loss_func_gan(model_dis(real_b, real_a), real_label)           # real pair
        fake_loss = loss_func_gan(model_dis(fake_b.detach(), real_a), fake_label)  # fake pair

        d_loss = (real_loss + fake_loss) / 2.
        d_loss.backward()
        opt_dis.step()

        epoch_g_loss.append(g_loss.item())
        epoch_d_loss.append(d_loss.item())

        #### TRAINING SET SSIM (generator after this step's update) ####
        with torch.no_grad():
            fake_imgs = model_gen(real_a, encoded_info)
        s1, s2, s3 = _ssim_triplet(fake_imgs, a, b)
        train_ssim_1.append(s1)
        train_ssim_2.append(s2)
        train_ssim_3.append(s3)

        batch_count += 1
        if batch_count % log_interval != 0:
            continue

        #### PERIODIC LOGGING ####
        loss_hist['gen'].append(_mean(epoch_g_loss))
        loss_hist['dis'].append(_mean(epoch_d_loss))
        epoch_g_loss = []
        epoch_d_loss = []
        train_ssim_FO.append(_mean(train_ssim_1))
        train_ssim_FT.append(_mean(train_ssim_2))
        train_ssim_OT.append(_mean(train_ssim_3))

        print('Epoch: %.0f, G_Loss: %.6f, D_Loss: %.6f, time: %.2f min' %(epoch, g_loss.item(), d_loss.item(), (time.time()-start_time)/60))

        fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
        ax1.imshow(a[0].squeeze().numpy())
        ax2.imshow(fake_imgs[0].squeeze().cpu().numpy())
        ax3.imshow(b[0].squeeze().numpy())
        plt.show()

        # Saving model weights
        os.makedirs(path2models, exist_ok=True)
        path2weights_gen = os.path.join(path2models, 'weights_gen_' + str(epoch) + '.pt')
        path2weights_dis = os.path.join(path2models, 'weights_dis_' + str(epoch) + '.pt')
        torch.save(model_gen.state_dict(), path2weights_gen)
        torch.save(model_dis.state_dict(), path2weights_dis)

        #### TESTING SET SSIM ####
        # (model_gen is switched back to train mode at the top of the next iteration)
        model_gen.eval()
        ssim_1 = []  # fake, original
        ssim_2 = []  # fake, target
        ssim_3 = []  # original, target

        with torch.no_grad():
            for test_a, test_b, test_info in test_dataloader:
                test_fake = model_gen(test_a.to(device), _encode_info(test_info))
                s1, s2, s3 = _ssim_triplet(test_fake, test_a, test_b)
                ssim_1.append(s1)
                ssim_2.append(s2)
                ssim_3.append(s3)

        test_ssim_FO.append(_mean(ssim_1))
        test_ssim_FT.append(_mean(ssim_2))
        test_ssim_OT.append(_mean(ssim_3))

        print('fake, original = ', _mean(ssim_1))
        print('fake, target = ', _mean(ssim_2))
        print('original, target = ', _mean(ssim_3))

In [None]:
# Plot loss history of generator | discriminator | training, testing SSIM


def _plot_history(curves, title, ylabel, ylim=None):
    """Draw each (values, label) pair in `curves` on one 10x5 figure with title, axis labels and legend."""
    fig, ax = plt.subplots(figsize=(10, 5))
    for values, label in curves:
        ax.plot(values, label=label)
    ax.set_title(title)
    ax.set_xlabel('epoch count')
    ax.set_ylabel(ylabel)
    if ylim is not None:
        ax.set_ylim(*ylim)
    ax.legend()
    plt.show()


_plot_history([(loss_hist['gen'], 'Gen. Loss')], 'Gen Loss', 'Loss')

_plot_history([(loss_hist['dis'], 'Dis. Loss')], 'Dis Loss', 'Loss', ylim=(0, 3))

_plot_history(
    [
        (train_ssim_FO, 'train_FO'),
        (train_ssim_FT, 'train_FT'),
        (train_ssim_OT, 'train_OT'),
        (test_ssim_FO, 'test_FO'),
        (test_ssim_FT, 'test_FT'),
        (test_ssim_OT, 'test_OT'),
    ],
    'SSIM',
    'SSIM',
    ylim=(0, 1),
)

## Test Model

In [None]:
path2models = '/home/mri-any/GAN_weight/final_migan/'
# Load saved generator weights. map_location='cpu' makes a checkpoint written on another GPU
# loadable here; load_state_dict then copies the values onto model_gen's own device.
path2weights_gen = os.path.join(path2models, 'weights_gen_20.pt')
weights = torch.load(path2weights_gen, map_location='cpu')
model_gen.load_state_dict(weights)

### SSIM evaluation on the test set

In [None]:
test_data = Dataloader(
    '/home/mri-any/GAN_test/orignal/*',   # glob over per-patient original series
    '/home/mri-any/GAN_test/target/',     # root of the matching target series
)
orig, target, info = test_data.preprocess()
ds = TensorDataset(orig, target, info)
test_dataloader = DataLoader(
    ds,
    batch_size=1,
    shuffle=False,  # fixed slice order, so the periodically plotted examples are reproducible
)

In [None]:
from skimage.metrics import structural_similarity as ssim

# Per-slice SSIM accumulators: (fake, original), (fake, target), (original, target).
ssim_1, ssim_2, ssim_3 = [], [], []

# FID accumulators with the same pairing. NOTE(review): never filled in this notebook.
fid_1, fid_2, fid_3 = [], [], []

In [None]:
# Evaluate the generator on the test set
model_gen.eval()

ssim_data_range = 2.0  # NORM scales images to [-1, 1]
plot_every = 300       # show one example figure every `plot_every` test slices

# Generate fake images without building an autograd graph
with torch.no_grad():
    for count, (a, b, info) in enumerate(test_dataloader):
        encoded_info = encoder_model(info).to(device)
        encoded_info = encoded_info.reshape(1, 3, 1, 1)

        fake_np = model_gen(a.to(device), encoded_info).cpu()[0].squeeze().numpy()
        real_a_np = a[0].squeeze().numpy()
        real_b_np = b[0].squeeze().numpy()

        ssim_1.append(ssim(fake_np, real_a_np, data_range=ssim_data_range))    # fake, original
        ssim_2.append(ssim(fake_np, real_b_np, data_range=ssim_data_range))    # fake, target
        ssim_3.append(ssim(real_a_np, real_b_np, data_range=ssim_data_range))  # original, target

        if count % plot_every == 0:
            fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
            ax1.set_title('original: year 1')
            ax2.set_title('fake')
            ax3.set_title('target: year 3')
            ax1.imshow(real_a_np, cmap='gray')
            ax2.imshow(fake_np, cmap='gray')
            ax3.imshow(real_b_np, cmap='gray')
            plt.show()

if ssim_1:
    print('fake, original = ', (sum(ssim_1) / len(ssim_1)))
    print('fake, target = ', (sum(ssim_2) / len(ssim_2)))
    print('original, target = ', (sum(ssim_3) / len(ssim_3)))
else:
    print('No test slices were evaluated; SSIM averages are undefined.')