# Import Modules 

In [None]:
from model.gen import OrGAN
from model.dnet import CNN
from utils.dataset import XrayDataset, TwoStreamBatchSampler, TXDataset
from utils.evaluate import *
from torch.utils.data import DataLoader 
from torch.utils.data import random_split 
from tqdm import tqdm 
import torch
import torch.nn as nn
import torch.nn.functional as F 
import albumentations as A 
from albumentations.pytorch import ToTensorV2
from torch.autograd import Variable
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
import numpy as np
import math
import cv2
import matplotlib.pyplot as plt

from torchinfo import summary
import pandas as pd

import torch.backends.cudnn as cudnn
import random
import gc

# Reproducibility: disable cuDNN autotuning, request deterministic kernels,
# and seed every RNG used in this notebook (torch CPU/CUDA, Python, NumPy).
cudnn.benchmark = False
cudnn.deterministic = True

SEED = 0
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, random.seed, np.random.seed):
    _seed_fn(SEED)

# PSNR function

In [None]:
def PSNR(original, pred, pixel_max=1.0):
    """Return the peak signal-to-noise ratio (in dB) between two arrays.

    Args:
        original: reference array.
        pred: predicted array, same shape as ``original``.
        pixel_max: maximum possible pixel value (1.0 for data scaled to [0, 1]).

    Returns:
        PSNR in decibels. Identical inputs (MSE == 0) return 100 instead of +inf.
    """
    # Cast to float64 so integer inputs (e.g. uint8) cannot wrap around on subtraction.
    diff = np.asarray(original, dtype=np.float64) - np.asarray(pred, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    return 20 * math.log10(pixel_max / math.sqrt(mse))

# Hyperparameters

In [None]:
# Each training batch holds BATCH_SIZE_TRAIN samples; the first
# L_BATCH_SIZE_TRAIN are the labeled ones, the remainder are unlabeled.
BATCH_SIZE_TRAIN = 12
L_BATCH_SIZE_TRAIN = 6
BATCH_SIZE_TEST = 6

learning_rate = 0.0003   # generator (OrGAN) AdamW learning rate
dlearning_rate = 0.001   # discriminator (CNN) AdamW learning rate
epochs = 600

# Run identifier: checkpoints for this run are written under model_weights/<model_iter>.
model_iter = '1'
os.makedirs(f'model_weights/{model_iter}', exist_ok=True)

# Dataset Definition

In [None]:

def _shift_scale_rotate_choice():
    """Build a new OneOf that applies (with p=0.8) either a mild or a strong ShiftScaleRotate."""
    mild = A.ShiftScaleRotate(shift_limit=0.07, scale_limit=0.1, rotate_limit=20,
                              border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0), p=0.5)
    strong = A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30,
                                border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0), p=0.5)
    return A.OneOf([mild, strong], p=0.8)


# `transform` for XrayDataset: flip + geometric jitter + one intensity/texture
# perturbation, then resize to 512x512.
train_transforms = A.Compose([
    A.HorizontalFlip(p=0.5),
    _shift_scale_rotate_choice(),
    A.OneOf([
        A.RandomGamma(gamma_limit=(70.0, 160.0), p=0.6),
        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.3),
        A.Blur(blur_limit=(3, 7), p=0.1),
        A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.1),
    ], p=0.6),
    A.Resize(width=512, height=512, p=1.0),
])

# `r_transform` for XrayDataset: same flip + geometric jitter, no intensity changes.
train_transforms_realx = A.Compose([
    A.HorizontalFlip(p=0.5),
    _shift_scale_rotate_choice(),
    A.Resize(width=512, height=512, p=1.0),
])

In [None]:
train_xray_directory = 'data/Train/Xray'
train_xray_filenames = sorted(os.listdir(train_xray_directory))
# Indices [0, l_id) of the sorted file list form the labeled stream, the rest the
# unlabeled stream (l_id is also passed to XrayDataset, presumably for the same split).
l_id = 624
label_idx = list(range(l_id))
unlabel_idx = list(range(l_id, len(train_xray_filenames)))

test_xray_directory = 'data/Test/Xray'
# Sorted so the validation order does not depend on filesystem listing order.
test_xray_filenames = sorted(os.listdir(test_xray_directory))

# Validation must be deterministic: only resize, no random flips/rotations or
# intensity jitter (previously the random training augmentations were applied,
# so every evaluation scored differently perturbed images).
test_transforms = A.Compose([A.Resize(width=512, height=512, p=1.0)])

dataset_train = XrayDataset(train_xray_filenames, train_xray_directory, l_id, transform=train_transforms, r_transform=train_transforms_realx)
dataset_test = TXDataset(test_xray_filenames, test_xray_directory, transform=test_transforms)

In [None]:
# Mixed batches: label_idx is the primary stream; presumably BATCH_SIZE_TRAIN
# minus the unlabeled share is drawn from it — confirm in TwoStreamBatchSampler.
_unlabeled_per_batch = BATCH_SIZE_TRAIN - L_BATCH_SIZE_TRAIN
batch_sampler = TwoStreamBatchSampler(label_idx, unlabel_idx, BATCH_SIZE_TRAIN, _unlabeled_per_batch)

In [None]:
Total_datalength = len(dataset_train)
train_loader = DataLoader(dataset_train, batch_sampler=batch_sampler)
# Validation order is fixed (shuffle=False) so every evaluation scores the same
# samples; with shuffle=True and drop_last=True each evaluation dropped a different
# random tail. drop_last is kept because evaluate() may assume full batches (not
# visible here — confirm before changing).
test_loader = DataLoader(dataset_test, BATCH_SIZE_TEST, shuffle=False, drop_last=True)

# Model Creation

In [None]:
# Generator (OrGAN) and discriminator (CNN), constructed in that order.
model, dnet = OrGAN(), CNN()

# Training Phase

In [None]:
# Separate AdamW optimizers for the generator and the discriminator.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
d_optimizer = torch.optim.AdamW(dnet.parameters(), lr=dlearning_rate)

# Discriminator LR is halved (down to 1e-7) once its monitored loss stops
# improving for more than 5 step() calls.
d_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    d_optimizer, 'min', factor=0.5, patience=5, verbose=True, min_lr=1e-7,
)
# Generator LR anneals toward 1e-7 over T_max = 4 * len(train_loader) scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, verbose=False, eta_min=1e-7, T_max=4 * len(train_loader),
)

# Loss functions
mae, mse, bce, nll = (
    torch.nn.L1Loss(),
    torch.nn.MSELoss(),
    torch.nn.BCELoss(),
    torch.nn.NLLLoss(),
)

# Multi-scale SSIM on single-channel images with values in [0, 1].
ssim = MS_SSIM(win_size=11, win_sigma=2, data_range=1, size_average=True, channel=1)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Wrap both networks for data-parallel execution. Leaving device_ids unset makes
# DataParallel use every visible GPU — the same as [0, 1, 2, 3] when all four
# devices from CUDA_VISIBLE_DEVICES exist — instead of failing when fewer are present.
model = nn.DataParallel(model)
model = model.to(device)

dnet = nn.DataParallel(dnet)
dnet = dnet.to(device)

In [None]:
# Training state shared across epochs.
global_step = 0
best_psnr_score = 0
best_dloss = 100.0
# Latest validation metrics. Initialised so the per-epoch row below can be built
# even before the first evaluation has run (e.g. when len(train_loader) < 2).
psnr_score = float('nan')
ssim_score = float('nan')

column_names = ["TrPSNR", "TePSNR", "TrSSIM", "TeSSIM", "GTrloss", "GTrSuploss", "GTrAdloss", "GTrcloss", "DTrloss"]
df = pd.DataFrame(columns=column_names)

# Validate twice per epoch; loop-invariant, so computed once.
division_step = len(train_loader) // 2

for epoch in range(epochs):
    model.train()
    dnet.train()

    epoch_loss = 0
    epoch_sloss = 0
    epoch_closs = 0
    epoch_aloss = 0
    epoch_dloss = 0
    epoch_psnr = 0
    epoch_ssim = 0

    count = 0
    pbar = tqdm(train_loader, dynamic_ncols=True)
    for images, masks in pbar:
        # The first L_BATCH_SIZE_TRAIN samples are the labeled ones (their masks
        # are the supervised targets); the remaining samples are unlabeled.
        lungs = masks[:L_BATCH_SIZE_TRAIN].unsqueeze(1).to(device).float()
        l_image = images[:L_BATCH_SIZE_TRAIN].to(device).float()
        ul_image = images[L_BATCH_SIZE_TRAIN:].to(device).float()

        # ---- Discriminator step: real = ground-truth masks (1), fake = generator output (0) ----
        # The generator output is only used detached here, so don't build its graph.
        with torch.no_grad():
            all_pred, _ = model(images.to(device).float())
        ul_d = dnet(all_pred)
        ul_zero = torch.zeros_like(ul_d).float()
        ul_dloss = bce(ul_d, ul_zero)

        l_d = dnet(lungs)
        l_one = torch.ones_like(l_d).float()
        l_dloss = bce(l_d, l_one)

        dloss = 0.5 * (ul_dloss + l_dloss)

        d_optimizer.zero_grad()
        dloss.backward()
        d_optimizer.step()

        del all_pred, ul_zero, ul_d, ul_dloss, l_d, l_one, l_dloss

        # ---- Generator step ----
        l_pred, l_log = model(l_image)
        ul_pred, ul_log = model(ul_image)
        all_pred, _ = model(images.to(device).float())

        # Domain targets for the second model output (scored with NLLLoss):
        # 0 = labeled, 1 = unlabeled. Sized from the actual tensors so a short
        # batch cannot cause a size mismatch.
        l_domain = torch.zeros(l_image.size(0), dtype=torch.long, device=device)
        ul_domain = torch.ones(ul_image.size(0), dtype=torch.long, device=device)

        l_nll_loss = nll(l_log, l_domain)
        ul_nll_loss = nll(ul_log, ul_domain)

        # Supervised loss on labeled samples: L1 + (1 - MS-SSIM).
        sloss = mae(l_pred, lungs) + (1 - ssim(l_pred, lungs))

        # Adversarial loss: push the discriminator to score generator output as real.
        ul_d = dnet(all_pred)
        ul_one = torch.ones_like(ul_d).float()
        aloss = bce(ul_d, ul_one)

        closs = 0.5 * (l_nll_loss + ul_nll_loss)

        loss = sloss + 0.01 * (aloss + closs)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        del ul_d, ul_one

        # Training metrics. Bound to a new name: assigning to `ssim` would replace
        # the MS_SSIM loss module with a tensor and break the next iteration.
        with torch.no_grad():
            ssim_val = ssim(l_pred, lungs)
        psnr = PSNR(lungs.detach().cpu().numpy(), l_pred.detach().cpu().numpy())

        epoch_psnr += psnr
        epoch_ssim += ssim_val.item()
        epoch_loss += loss.item()
        epoch_sloss += sloss.item()
        epoch_aloss += aloss.item()
        epoch_dloss += dloss.item()
        epoch_closs += closs.item()

        global_step += 1
        count += 1
        pbar.set_postfix({'epoch': epoch, 'bestVPSNR': best_psnr_score})

        del l_pred, lungs, images, loss, ssim_val, all_pred, l_nll_loss, ul_nll_loss, aloss, sloss, dloss, closs

        # Per-iteration cleanup kept from the original (slow, but limits peak memory).
        gc.collect()
        torch.cuda.empty_cache()

        # The cosine generator schedule is stepped once per batch.
        scheduler.step()

        if division_step > 0 and global_step % division_step == 0:
            psnr_score, ssim_score = evaluate(model, test_loader, device)
            print('OrGAN | VPSNR(dB): ' + str(round(psnr_score, 3)) + '| VSSIM: ' + str(round(ssim_score, 3)))

            if psnr_score > best_psnr_score:
                model.psnr_score = psnr_score
                torch.save(model, 'model_weights/' + model_iter + '/best.ckpt')
                best_psnr_score = psnr_score

            # evaluate() is defined elsewhere and may switch the model to eval
            # mode; restore training mode for the rest of the epoch.
            model.train()

    # Per-epoch averages (max() guards against an empty loader).
    n_batches = max(count, 1)
    tr_psnr = round(epoch_psnr / n_batches, 3)
    tr_ssim = round(epoch_ssim / n_batches, 3)
    tr_loss = round(epoch_loss / n_batches, 3)
    tr_sloss = round(epoch_sloss / n_batches, 3)
    tr_aloss = round(epoch_aloss / n_batches, 3)
    tr_closs = round(epoch_closs / n_batches, 3)
    tr_dloss = round(epoch_dloss / n_batches, 3)

    temp_dloss = tr_dloss

    if best_dloss > temp_dloss:
        best_dloss = temp_dloss

    # The discriminator plateau scheduler is fed the running best epoch D-loss.
    if d_scheduler is not None:
        d_scheduler.step(best_dloss)

    row = [tr_psnr, psnr_score, tr_ssim, ssim_score, tr_loss, tr_sloss, tr_aloss, tr_closs, tr_dloss]
    df.loc[len(df)] = row

    print('Epoch ' + str(epoch))
    print(f' TPsnr: {tr_psnr} TSsim: {tr_ssim} Tloss: {tr_loss} dloss: {tr_dloss} adv_loss: {tr_aloss} sloss: {tr_sloss} c_loss: {tr_closs}')

In [None]:
# Save the per-epoch metrics table inside this run's directory
# (model_weights/<model_iter>/, created in the hyperparameter cell). The original
# path lacked the '/' and wrote model_weights/<model_iter>epoch_data.csv instead.
df.to_csv('model_weights/' + model_iter + '/epoch_data.csv')