In [None]:
#============Make necessary imports==========#

import torch
import torch.nn as nn
from torch.nn import init
import functools
import torch.autograd as autograd
import numpy as np
import torchvision.models as models
from torch.autograd import Variable
import os
import matplotlib.pyplot as plt
import cv2
from PIL import Image
from torchvision import transforms
from torchvision.transforms import ToTensor, Resize, Compose
from torch.utils.data import DataLoader, TensorDataset,Dataset
import torch.nn.functional as F
from collections import OrderedDict
import math
import time
import random

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# `device` is a notebook-wide global read by the loading, training and prediction cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

In [None]:
#=========Find the correct resizing dimension=========#


"""
We want to shrink the images while giving all of them the same size.
One way to do so is to check each image's aspect ratio (width / height) and pick target dimensions that preserve it (we don't want any image distortion).
"""

def get_image_aspect_ratios(image_folder):
    """Map every image file name in ``image_folder`` to its width / height ratio.

    Only entries whose name ends with .png, .jpg, .jpeg, .bmp or .tiff
    (case-sensitive comparison) are opened; all other entries are ignored.
    """
    image_suffixes = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    ratios_by_name = {}
    for entry in os.listdir(image_folder):
        if not entry.endswith(image_suffixes):
            continue
        with Image.open(os.path.join(image_folder, entry)) as picture:
            w, h = picture.size
            ratios_by_name[entry] = w / h
    return ratios_by_name


# Placeholder path: point this at the folder whose images should be inspected.
image_folder = "/write down the path/"
aspect_ratios = get_image_aspect_ratios(image_folder)

# Print each file's width / height ratio rounded to two decimals.
for filename, aspect_ratio in aspect_ratios.items():
    print(f'{filename}: {aspect_ratio:.2f}')


"""
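In my case the ratio is 1.78 (width / height), and 224 x 128 (width x height) maintains that ratio (almost).
You don't have to match the ratio 100%, but it should be similar.
Note: torchvision's transforms.Resize reads a two-element size as (height, width).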
"""

In [None]:
#============Importing Images============#

"""
I used the GoPro dataset.
Link: https://www.kaggle.com/datasets/jishnuparayilshibu/a-curated-list-of-image-deblurring-datasets
"""

def load_images_to_tensor(folder_path, max_images=100, size=(224, 128)):
    """Load up to ``max_images`` images from ``folder_path`` as normalized tensors.

    Files are taken in case-insensitive alphabetical order, so calling this on
    the blurred folder and on the sharp folder yields index-aligned pairs only
    when both folders use matching file names.

    Args:
        folder_path: directory containing the image files.
        max_images: maximum number of images to load.
        size: target size handed to ``transforms.Resize``. torchvision reads a
            two-element size as (height, width), so the default produces
            tensors of shape [3, 224, 128].
            NOTE(review): the aspect-ratio cell computes width / height, so
            the default looks transposed for landscape images. It is kept
            unchanged so tensor shapes stay what the rest of the notebook
            expects -- confirm before changing it.

    Returns:
        A list of float tensors placed on ``device``, each scaled to [-1, 1].
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')

    # Keep only image files: stray entries such as .DS_Store or sub-folders
    # would otherwise make Image.open raise.
    image_files = [
        name for name in os.listdir(folder_path)
        if name.lower().endswith(image_extensions)
    ]
    image_files.sort(key=lambda s: s.lower())  # same ordering for the blurred and the sharp folder

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # [0, 1] -> [-1, 1] for stable training
    ])

    tensor_list = []
    for img_filename in image_files[:max_images]:
        img_path = os.path.join(folder_path, img_filename)
        # The context manager closes the file handle; convert("RGB") guarantees
        # the three channels Normalize expects (grayscale / RGBA inputs).
        with Image.open(img_path) as image:
            tensor = transform(image.convert("RGB")).to(device)
        tensor_list.append(tensor)

    return tensor_list

blur_path = r"/path to your blurred_data"
sharp_path = r"/path to your sharped_data"
max_images = 1000 # We will use (at most) 1000 images for our model

motion_blurred_tensors = load_images_to_tensor(blur_path, max_images)
sharp_tensors = load_images_to_tensor(sharp_path, max_images)

# Blurred and sharp images are paired purely by index, so the counts must agree.
if len(motion_blurred_tensors) != len(sharp_tensors):
    raise ValueError(
        "Blurred / sharp image counts differ: {0} vs {1}".format(
            len(motion_blurred_tensors), len(sharp_tensors)))

# Split on the number of pairs actually loaded: a folder holding fewer than
# max_images files would otherwise give a wrongly sized or empty test split.
num_pairs = len(motion_blurred_tensors)
if num_pairs < 2:
    raise ValueError("Need at least 2 image pairs for a train/test split, found {0}".format(num_pairs))
split_index = int(num_pairs * 0.8)

#Simple train test split (first 80% train, remaining 20% test)
motion_blurred_train = motion_blurred_tensors[:split_index]
motion_blurred_test = motion_blurred_tensors[split_index:]
sharp_tensors_train = sharp_tensors[:split_index]
sharp_tensors_test = sharp_tensors[split_index:]

#Stack the tensors
motion_blurred_stack = torch.stack(motion_blurred_train)
sharp_stack = torch.stack(sharp_tensors_train)
motion_blurred_stack_t = torch.stack(motion_blurred_test)
sharp_stack_t = torch.stack(sharp_tensors_test)

#Make dataset
train_dataset = TensorDataset(motion_blurred_stack, sharp_stack)
test_dataset = TensorDataset(motion_blurred_stack_t, sharp_stack_t)

#Loading the dataloader
# The files were loaded in sorted order, so the training loader shuffles to
# avoid always batching neighbouring files together; the test loader keeps its order.
train_data_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_data_loader = DataLoader(test_dataset, batch_size=2)

print(motion_blurred_stack.shape)  # Expected shape: [num_samples, 3, 224, 128]
print(sharp_stack.shape)           # Expected shape: [num_samples, 3, 224, 128]


In [None]:
#=============Helper Functions===========#

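# History buffer of generated images (originally noted as needed for Unet; it is not referenced elsewhere in this notebook)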
class ImagePool():
    """Fixed-size history buffer of previously seen image tensors.

    ``query`` receives a batch and returns a batch of the same size in which
    every image is either the incoming one or an older one taken from the
    buffer. With ``pool_size == 0`` the buffer is disabled and batches are
    returned untouched.
    """

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        # Always initialised, so the attributes exist even when the pool is disabled.
        self.num_imgs = 0
        self.images = []

    def query(self, images):
        """Return a batch mixing the given images with buffered ones.

        While the buffer is not full, each image is stored and returned as is.
        Once it is full, each image is, with probability 0.5, swapped with a
        randomly chosen buffered image (the buffered one is returned and the
        new one takes its slot); otherwise the new image is returned directly.
        The returned batch is detached from the autograd graph.
        """
        if self.pool_size == 0:
            return images
        return_images = []
        # detach() replaces the deprecated `.data` access.
        for image in images.detach():
            image = torch.unsqueeze(image, 0)
            if self.num_imgs < self.pool_size:
                self.num_imgs = self.num_imgs + 1
                self.images.append(image)
                return_images.append(image)
            elif random.uniform(0, 1) > 0.5:
                random_id = random.randint(0, self.pool_size - 1)
                tmp = self.images[random_id].clone()
                self.images[random_id] = image
                return_images.append(tmp)
            else:
                return_images.append(image)
        if not return_images:
            # Empty batch: torch.cat cannot concatenate an empty list.
            return images
        # Variable() is a deprecated no-op wrapper, so the tensor is returned directly.
        return torch.cat(return_images, 0)
    

#we need to apply initial weights to our model
def init_weights(m):
    """Initialise convolution and batch-norm parameters in place.

    Intended for ``module.apply(init_weights)``:
      * Conv2d / ConvTranspose2d: weights drawn from N(0, 0.02), bias zeroed.
      * BatchNorm2d: scale drawn from N(1, 0.02), shift zeroed.
    Every other module type is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, 0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight and bias set to None.
        if m.weight is not None:
            nn.init.normal_(m.weight, 1, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


def tensor2im(image_tensor, imtype=np.uint8):
    """Convert the first image of a batch from [-1, 1] CxHxW to [0, 255] HxWxC.

    Args:
        image_tensor: batch tensor indexed as [batch, channel, height, width];
            only element 0 of the batch is converted.
        imtype: numpy dtype of the returned array.

    Returns:
        A numpy array of shape (height, width, channel) and dtype ``imtype``.
    """
    image_numpy = image_tensor[0].detach().cpu().float().numpy()
    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
    # Clip before casting: values outside [-1, 1] would otherwise wrap around
    # when converted to an unsigned 8-bit integer.
    image_numpy = np.clip(image_numpy, 0.0, 255.0)
    return image_numpy.astype(imtype)


def get_visuals(real_A, Generator_out, blur_imgs):
    """Convert a sharp / restored / blurred batch triple into displayable images.

    Args:
        real_A: batch of ground-truth sharp images.
        Generator_out: batch of generator outputs (restored images).
        blur_imgs: batch of blurred input images.

    Returns:
        OrderedDict with the keys 'Blurred_Train', 'Restored_Train' and
        'Sharp_Train', each holding the array ``tensor2im`` produces for the
        first image of the matching batch.
    """
    # Every entry is converted, and each key now points at the batch it names.
    # Before, 'Blurred_Train' held the raw sharp tensor and 'Sharp_Train' the
    # blurred image, so a PSNR of restored vs 'Sharp_Train' used the wrong reference.
    sharp_np = tensor2im(real_A)
    restored_np = tensor2im(Generator_out)
    blurred_np = tensor2im(blur_imgs)
    return OrderedDict([
        ('Blurred_Train', blurred_np),
        ('Restored_Train', restored_np),
        ('Sharp_Train', sharp_np),
    ])


def parameter_count(net):
    """Return the total number of scalar entries over all parameters of ``net``."""
    return sum(tensor.numel() for tensor in net.parameters())

#Metrics
def mse(img1, img2):
    """Return the mean squared error between two equally shaped images.

    Both inputs are converted to float64 first, so unsigned 8-bit images do not
    wrap around on subtraction. The result is the mean of the squared
    differences (previously the un-normalised L2 norm of the difference was
    returned, which is not a mean squared error).
    """
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    return float(np.mean(diff ** 2))

def PSNR(img1, img2):
    """Peak signal-to-noise ratio (in dB) between two images on a 0-255 scale.

    Args:
        img1, img2: equally shaped arrays with pixel values in [0, 255].

    Returns:
        100 when the images are identical, otherwise
        20 * log10(255 / sqrt(mean squared error)).
    """
    # Cast to float64 first: uint8 arithmetic would wrap around both on the
    # subtraction and on the squaring.
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mean_sq_err = np.mean(diff ** 2)
    if mean_sq_err == 0:
        return 100
    pixel_max = 255.0
    return 20 * math.log10(pixel_max / math.sqrt(mean_sq_err))

In [None]:
#===============Discriminator==============#

"""
For discriminator we are implementing patchGAN
Link: https://paperswithcode.com/method/patchgan
"""

class Discriminator(nn.Module):
    """PatchGAN critic: a stack of 4x4 convolutions ending in a 1-channel score map.

    Args:
        input_nc: number of channels of the input image.
        ndf: number of filters of the first convolution; later layers use
            multiples of it, capped at 8 * ndf.
        n_layers: for n_layers >= 1, the number of stride-2 convolutions
            (the first one included).
        norm_layer: callable building a normalisation layer from a channel
            count (instance norm by default, chosen because of the small batch size).
        use_bias: whether the convolutions followed by a normalisation layer
            carry a bias term.
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, norm_layer=nn.InstanceNorm2d, use_bias=True):
        super(Discriminator, self).__init__()

        kernel = 4
        pad = int(np.ceil((kernel - 1) / 2))  # evaluates to 2 for the 4x4 kernel

        def conv_norm_act(in_mult, out_mult, stride):
            # One conv -> norm -> LeakyReLU stage; widths are multiples of ndf.
            return [
                nn.Conv2d(ndf * in_mult, ndf * out_mult,
                          kernel_size=kernel, stride=stride, padding=pad, bias=use_bias),
                norm_layer(ndf * out_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # First stage has no normalisation layer.
        stages = [
            nn.Conv2d(input_nc, ndf, kernel_size=kernel, stride=2, padding=pad),
            nn.LeakyReLU(0.2, True),
        ]

        # Channel multipliers of the stride-2 stages: 1, 2, 4, ... capped at 8.
        mults = [1] + [min(2 ** n, 8) for n in range(1, n_layers)]
        for prev_mult, cur_mult in zip(mults, mults[1:]):
            stages.extend(conv_norm_act(prev_mult, cur_mult, stride=2))

        # One more stage at stride 1, then the projection to a single channel.
        last_mult = min(2 ** n_layers, 8)
        stages.extend(conv_norm_act(mults[-1], last_mult, stride=1))
        stages.append(nn.Conv2d(ndf * last_mult, 1, kernel_size=kernel, stride=1, padding=pad))

        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Return the score map produced by the convolution stack for ``x``."""
        return self.model(x)


In [None]:
#==============ResNet Generator============#

class ResnetGenerator(nn.Module):
    """ResNet-style generator: 7x7 stem, two stride-2 downsamplings,
    ``n_blocks`` residual blocks, two stride-2 upsamplings and a 7x7 Tanh head.
    """

    def __init__(
            self, input_nc=3, output_nc=3, ngf=64, norm_layer=nn.InstanceNorm2d, use_bias=True, use_dropout=False,
            n_blocks=9, learn_residual=False, padding_type='reflect'):
        """
        Args:
            input_nc: input channels.
            output_nc: output channels.
            ngf: number of filters of the first convolution.
            norm_layer: callable building a normalisation layer from a channel count.
            use_bias: whether the normalised convolutions carry a bias term.
            use_dropout: forwarded to every ResnetBlock.
            n_blocks: number of residual blocks between down- and upsampling.
            learn_residual: when True, forward() adds the input to the network
                output and clamps the sum to [-1, 1] (needs input_nc == output_nc).
            padding_type: padding mode forwarded to every ResnetBlock.
        """
        super(ResnetGenerator, self).__init__()

        # forward() reads this flag, so it has to be stored on the instance.
        self.learn_residual = learn_residual

        #Initial Layer
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True)
        ]

        #Downsample (twice, halving the resolution each time)
        layers += [
            nn.Conv2d(ngf, ngf*2, kernel_size=3, stride=2, padding=1, bias=use_bias),
            norm_layer(ngf*2),
            nn.ReLU(True),

            nn.Conv2d(ngf*2, ngf*4, kernel_size=3, stride=2, padding=1, bias=use_bias),
            norm_layer(ngf*4),
            nn.ReLU(True)
        ]

        #Residual blocks
        for _ in range(n_blocks):
            layers += [
                ResnetBlock(ngf*4, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)
            ]

        #Upsample (twice, doubling the resolution each time)
        layers += [
            nn.ConvTranspose2d(ngf*4, ngf*2, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
            norm_layer(ngf*2),
            nn.ReLU(True),

            nn.ConvTranspose2d(ngf*2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1, bias=use_bias),
            norm_layer(ngf),
            nn.ReLU(True),
        ]

        #Final Layer
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh()
        ]

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the generator; with ``learn_residual`` the input is added to the output."""
        out = self.model(x)
        if self.learn_residual:
            # Out-of-place add: Tanh's backward pass needs its own output,
            # which an in-place `+=` would overwrite.
            out = torch.clamp(x + out, min=-1, max=1)
        return out


class ResnetBlock(nn.Module):
    """Residual block: ``x + F(x)`` with
    F = pad -> conv3x3 -> norm -> ReLU -> [Dropout(0.5)] -> pad -> conv3x3 -> norm.

    NOTE: the previous conditional expression bound more loosely than the list
    concatenations around it, so only ONE convolution was ever built (and the
    second half was dropped whenever dropout was enabled). Checkpoints saved
    with that layout have different ``conv_block`` keys and will not load into
    this corrected block.

    Args:
        dim: number of input and output channels.
        padding_type: 'reflect', 'replicate' or 'zero'.
        norm_layer: callable building a normalisation layer from a channel count.
        use_dropout: insert Dropout(0.5) between the two convolutions.
        use_bias: whether the convolutions carry a bias term.

    Raises:
        NotImplementedError: for an unknown ``padding_type``.
    """

    def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
        super(ResnetBlock, self).__init__()

        def pad_and_conv():
            # Built fresh on every call so the two convolutions are separate
            # modules and do not share weights.
            if padding_type == 'reflect':
                return [nn.ReflectionPad2d(1),
                        nn.Conv2d(dim, dim, kernel_size=3, bias=use_bias)]
            if padding_type == 'replicate':
                return [nn.ReplicationPad2d(1),
                        nn.Conv2d(dim, dim, kernel_size=3, bias=use_bias)]
            if padding_type == 'zero':
                return [nn.Conv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        blocks = pad_and_conv() + [norm_layer(dim), nn.ReLU(True)]
        if use_dropout:
            blocks.append(nn.Dropout(0.5))
        blocks += pad_and_conv() + [norm_layer(dim)]

        self.conv_block = nn.Sequential(*blocks)

    def forward(self, x):
        """Return the input plus the output of the convolutional branch."""
        return x + self.conv_block(x)



In [None]:
#==========debug code=========#

# Smoke test: push one random batch of shape (1, 3, 224, 128) through a
# default ResnetGenerator (built on the CPU) and print the output shape.
model= ResnetGenerator()
x= torch.randn(1,3,224,128)
out= model(x)
print(out.shape)

In [None]:
#=================Unet Generator=================#

"""
The original code has also implemented Unet Generator.
So choose based on your task and performance
"""

class UnetGenerator(nn.Module):
    """U-Net generator assembled from nested UnetSkipConnectionBlock modules.

    Args:
        input_nc: accepted for interface compatibility but not used by the
            construction below; the outermost block is built from
            ``output_nc`` for both its input convolution and its output.
        output_nc: channel count of the outermost block.
        num_downs: total number of stride-2 downsamplings. Each level halves
            the resolution, so the input height and width presumably need to
            be divisible by 2 ** num_downs for the skip connections to line
            up -- verify for your image size.
        ngf: number of filters of the outermost convolution.
        norm_layer: normalisation layer constructor passed to every block.
        use_dropout: enable dropout in the intermediate ngf * 8 blocks.
        learn_residual: when True, forward() adds the input to the network
            output and clamps the sum to [-1, 1].
    """

    def __init__(
            self, input_nc=3, output_nc=3, num_downs=8, ngf=64, norm_layer=nn.BatchNorm2d,
            use_dropout=False, learn_residual=False):
        super(UnetGenerator, self).__init__()

        self.learn_residual = learn_residual

        # construct unet structure from the innermost block outwards
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, norm_layer=norm_layer, innermost=True)
        for _ in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, unet_block, norm_layer=norm_layer,
                                                 use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, unet_block, outermost=True, norm_layer=norm_layer)

        self.model = unet_block

    def forward(self, input):
        """Run the U-Net; with ``learn_residual`` the input is added to the output."""
        # The former data-parallel branch read self.gpu_ids / self.use_parallel,
        # which this class never defines, so every call raised AttributeError.
        output = self.model(input)
        if self.learn_residual:
            output = input + output
            output = torch.clamp(output, min=-1, max=1)
        return output


class UnetSkipConnectionBlock(nn.Module):
    """One U-Net level: downsample -> submodule -> upsample, plus a skip connection.

    Except for the outermost level, forward() concatenates the block output
    with the block input along the channel axis, so it returns
    ``2 * outer_nc`` channels.

    Args:
        outer_nc: channels entering the block (and leaving its upsampling path).
        inner_nc: channels after the downsampling convolution.
        submodule: the nested block; required unless ``innermost`` is True.
        outermost: build the top level (no normalisation, Tanh output, no skip).
        innermost: build the bottom level (no submodule).
        norm_layer: normalisation layer constructor (or functools.partial of one).
        use_dropout: append Dropout(0.5) to an intermediate level.

    Raises:
        ValueError: if a non-innermost block is built without a submodule.
    """

    def __init__(
            self, outer_nc=3, inner_nc=3, submodule=None,
            outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        # A bias is used only with instance norm, not with batch norm.
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d

        if submodule is None and (outermost or not innermost):
            raise ValueError('a submodule is required unless innermost=True')

        dConv = nn.Conv2d(outer_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
        dRelu = nn.LeakyReLU(0.2, True)
        dNorm = norm_layer(inner_nc)
        uRelu = nn.ReLU(True)
        uNorm = norm_layer(outer_nc)

        # nn.Sequential needs a flat sequence of modules; the previous nested
        # lists made construction fail with a TypeError.
        if outermost:
            uConv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1)
            model = [dConv, submodule, uRelu, uConv, nn.Tanh()]
        elif innermost:
            uConv = nn.ConvTranspose2d(inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            model = [dRelu, dConv, uRelu, uConv, uNorm]
        else:
            uConv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias)
            model = [dRelu, dConv, dNorm, submodule, uRelu, uConv, uNorm]
            if use_dropout:
                model.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return the block output, concatenated with ``x`` unless outermost."""
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([self.model(x), x], 1)


In [None]:
#=====================Losses======================#

"""
Loss functions simplified:

Loss_generator = adversarial_loss_gen + content_loss * c_lambda 
Loss_discriminator = adversarial_loss_disc + gradient_penalty * g_lambda
We use the Wasserstein loss as the adversarial loss.

Typical adversarial loss:
adversarial_loss_gen = - Discriminator(generated_deblurred_imgs).mean
(It acts like a penalty for generator's loss)
adversarial_loss_disc = Discriminator(generated_deblurred_imgs).mean - Discriminator(True_sharp_imgs).mean

The content loss could be a plain L1 loss,
but here we use the perceptual loss instead.

c_lambda= 100
g_lambda= 10

"""


"""
Perception_loss paper: https://arxiv.org/abs/1603.08155
"""

@functools.lru_cache(maxsize=None)
def _vgg19_conv3_3_features(device_key):
    """Build, once per device, the frozen VGG19 feature slice up to layer index 14."""
    conv_3_3_layer = 14
    cnn = models.vgg19(pretrained=True).features
    extractor = nn.Sequential()
    for i, layer in enumerate(list(cnn)):
        extractor.add_module(str(i), layer)
        if i == conv_3_3_layer:
            break
    extractor = extractor.to(torch.device(device_key)).eval()
    # The VGG weights are a fixed feature extractor; they must not collect gradients.
    for param in extractor.parameters():
        param.requires_grad_(False)
    return extractor


def perceptual_loss(y_true, y_pred):
    """MSE between the VGG19 features (layers 0..14) of ``y_pred`` and ``y_true``.

    Args:
        y_true: reference images; their features are computed without a graph,
            so no gradient flows through this argument.
        y_pred: predicted images; gradients flow back through this argument
            only, so the generator output has to be passed here.

    Returns:
        Scalar loss tensor.
    """
    # The extractor is cached per device instead of being re-created (weights
    # re-loaded and moved to the GPU) on every single call.
    extractor = _vgg19_conv3_3_features(str(y_pred.device))
    criterion = nn.MSELoss()
    fake = extractor(y_pred)
    with torch.no_grad():
        real = extractor(y_true.to(y_pred.device))
    return criterion(fake, real)

"""
Gradient penalty for stable training
Link: https://paperswithcode.com/method/wgan-gp 
"""

def calc_gradient_penalty(netD, real_data, fake_data, lambda_gp=10):
    """WGAN-GP gradient penalty evaluated on random real/fake interpolates.

    Args:
        netD: the critic; called on the interpolated batch.
        real_data: batch of real samples.
        fake_data: batch of generated samples, same shape as ``real_data``.
        lambda_gp: penalty weight; the returned value is already multiplied by it.

    Returns:
        Scalar tensor ``lambda_gp * mean((||grad||_2 - 1) ** 2)``, where the
        norm is taken per sample over all non-batch dimensions.
    """
    batch_size = real_data.size(0)

    # One interpolation coefficient per sample, broadcast over the remaining
    # dimensions and created directly on the data's device (no hard-coded .cuda()).
    alpha_shape = (batch_size,) + (1,) * (real_data.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_data.device, dtype=real_data.dtype)

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)
    interpolates = interpolates.detach().requires_grad_(True)

    disc_interpolates = netD(interpolates)
    gradients = autograd.grad(
        outputs=disc_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(disc_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]

    # Flatten so the norm covers every input element of a sample rather than
    # only the channel axis; the epsilon keeps the norm differentiable at zero.
    gradients = gradients.reshape(batch_size, -1)
    gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - 1) ** 2).mean() * lambda_gp
    return gradient_penalty



In [None]:
#==================Train Function ==================#


def _plot_loss_curve(values, label, title, filename):
    """Plot one per-epoch loss curve, save it to ``filename`` and show it."""
    plt.figure()
    plt.plot(values, label=label)
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title(title)
    plt.legend()
    plt.savefig(filename)
    plt.show()


def train(dataset, D, G, optimizer_G, optimizer_D ,bs, n_epoch, d_training, g_training):
    """Train the generator ``G`` against the critic ``D`` for ``n_epoch`` epochs.

    Args:
        dataset: iterable of (blurred, sharp) batches with a length (a DataLoader).
        D: the critic / discriminator.
        G: the generator.
        optimizer_G, optimizer_D: optimizers of G and D.
        bs: batch size; only used to advance the step counter that decides
            when the PSNR is printed.
        n_epoch: number of epochs to run (epochs are numbered from 1).
        d_training: critic updates per batch.
        g_training: generator updates per batch. Running several generator
            steps can help when the critic is much stronger; lowering the
            critic's learning rate is another way to rebalance. More tricks:
            https://www.reddit.com/r/MachineLearning/comments/i085a8/d_best_gan_tricks/

    Side effects:
        Prints progress, saves both state dicts every 50 epochs and writes
        'Discriminator_loss.png' / 'Generator_loss.png' at the end.
    """
    d_loss_total = []
    g_loss_total = []
    steps = 0

    for epoch in range(1, n_epoch+1):
        print("Running epoch:", epoch)
        start_time_epoch = time.time()
        d_epoch = 0.0
        g_epoch = 0.0
        for blur, sharp in dataset:

            steps += bs
            blur_imgs = blur.to(device)
            sharp_imgs = sharp.to(device)

            # The critic only ever sees the generator output detached, so no
            # autograd graph is needed for this forward pass.
            with torch.no_grad():
                Generator_out = G(blur_imgs)

            # =======================Train the discriminator=======================#
            for d_iter in range(d_training):
                optimizer_D.zero_grad()

                # Discriminator outputs
                real_validity = D(sharp_imgs)
                fake_validity = D(Generator_out)

                # Gradient penalty (already multiplied by its weight)
                gradient_penalty = calc_gradient_penalty(D, sharp_imgs, Generator_out)

                d_loss = fake_validity.mean() - real_validity.mean() + gradient_penalty
                # Every iteration builds a fresh graph, so nothing has to be retained.
                d_loss.backward()

                optimizer_D.step()
                if d_iter == d_training - 1:
                    d_epoch += d_loss.item()  # log the loss of the last critic step

            #========================Train the generator===========================#
            for g_iter in range(g_training):
                optimizer_G.zero_grad()

                Generator_out = G(blur_imgs)
                fake_validity = D(Generator_out)
                g_adv_loss = -fake_validity.mean()
                # perceptual_loss(y_true, y_pred) detaches the y_true branch, so
                # the generator output must be the SECOND argument; with the
                # arguments swapped the content loss sent no gradient to G.
                g_contentloss = perceptual_loss(sharp_imgs, Generator_out) * 100
                g_total_loss = g_adv_loss + g_contentloss
                g_total_loss.backward() # no need to retain the graph

                # If training doesn't seem stable, add gradient clipping here:
                # torch.nn.utils.clip_grad_norm_(G.parameters(), max_norm=1.0)

                optimizer_G.step()
                if g_iter == g_training - 1:
                    # Counted once per batch (it used to be added a second time
                    # after the loop, doubling the reported generator loss).
                    g_epoch += g_total_loss.item()

            # PSNR metric: restored image against the sharp ground truth, computed
            # in float so the 8-bit images cannot wrap around on subtraction.
            if steps % 1000 == 4:
                restored_np = tensor2im(Generator_out).astype(np.float64)
                sharp_np = tensor2im(sharp_imgs).astype(np.float64)
                psnr = PSNR(restored_np, sharp_np)
                print('PSNR on Train (at epoch {0}) = {1}'.format(epoch, psnr))

        d_loss_total.append(d_epoch/len(dataset))
        g_loss_total.append(g_epoch/len(dataset))

        #saving model after every 50 epochs
        if epoch % 50 == 0:
            torch.save(G.state_dict(), 'Generator_' + str(epoch) + '.pt')
            torch.save(D.state_dict(), 'Discriminator' + str(epoch) + '.pt')

        end_time_epoch = time.time()

        print("Time for epoch {0}: {1} | Disc loss: {2}  | Gen loss: {3}".format(epoch, (end_time_epoch - start_time_epoch), d_loss_total[-1], g_loss_total[-1]))

    _plot_loss_curve(d_loss_total, 'Discriminator Loss', 'Discriminator Loss Over Epochs', 'Discriminator_loss.png')
    _plot_loss_curve(g_loss_total, 'Generator Loss', 'Generator Loss Over Epochs', 'Generator_loss.png')




In [None]:
#==============Define parameters=============#

class Parameters:
    """Hyper-parameters read by the training cells."""
    # Adam learning rates for the generator and the discriminator
    g_lr = 1e-4
    d_lr = 1e-4
    # optimisation steps per batch
    d_training = 5
    g_training = 1
    # Adam beta constants
    beta1 = 0.5
    beta2 = 0.999
    # batch size and number of epochs
    batchsize = 2
    n_epochs = 100

In [None]:
#==============Start Training==============#

# Entry cell for a fresh training run: builds both networks, their Adam
# optimizers, initialises the weights and calls train() on the training loader.
if __name__ == "__main__":
    
    # Build both networks on the selected device.
    D = Discriminator().to(device)
    G = ResnetGenerator().to(device)
    p= Parameters()
    optimizer_G = torch.optim.Adam(G.parameters(), lr=p.g_lr, betas=(p.beta1, p.beta2))
    optimizer_D = torch.optim.Adam(D.parameters(), lr=p.d_lr, betas=(p.beta1, p.beta2))
    # init_weights rewrites the parameters in place, so the optimizers created
    # above keep pointing at the same tensors.
    G.apply(init_weights)
    D.apply(init_weights)

    
    p_d= parameter_count(D)
    p_g= parameter_count(G)

    print('ㅎ ㅡ ㅎ Begin Training ㅎ ㅡ ㅎ')
    print("Parameters of Discriminator:",p_d)
    print("Parameters of Generattor", p_g)  
    print('~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~')

    # Wall-clock time of the whole training run.
    start_time = time.time()
    train(train_data_loader, D, G, optimizer_G, optimizer_D, p.batchsize, p.n_epochs, p.d_training, p.g_training)
    end_time= time.time()
    print("Total time for training:", end_time - start_time)
   

In [None]:
#==============Predicting Images===============#

#Load the trained model 
# 'saved_name' is a placeholder: replace it with the path of a saved generator
# state dict. map_location=device lets the file load on the current device.
generator = ResnetGenerator().to(device)
generator.load_state_dict(torch.load('saved_name', map_location=device))

def imshow(tensor, title=None):
    """Draw a normalised image on the current matplotlib axes.

    Args:
        tensor: CxHxW image (torch tensor or numpy array) normalised with
            mean 0.5 and std 0.5 per channel; a leading batch dimension of
            size 1 is squeezed away.
        title: optional title for the axes.
    """
    if isinstance(tensor, np.ndarray):
        tensor = torch.from_numpy(tensor)

    # Work on a detached CPU copy so CUDA tensors and tensors that require
    # grad can be displayed as well (mean / std below live on the CPU).
    tensor = tensor.detach().cpu().float()
    if tensor.dim() == 4 and tensor.shape[0] == 1:
        tensor = tensor.squeeze(0)

    mean = torch.tensor([0.5, 0.5, 0.5])
    std = torch.tensor([0.5, 0.5, 0.5])

    tensor = tensor * std[:, None, None] + mean[:, None, None] # Denormalize
    tensor = tensor.permute(1, 2, 0)  # CxHxW -> HxWxC
    tensor = torch.clamp(tensor, 0, 1)  # Clamp values to ensure they are between 0 and 1

    plt.imshow(tensor.numpy())
    if title:
        plt.title(title)
    plt.axis('off')


generator.eval()
indices = list(range(len(test_data_loader.dataset)))
# random.shuffle(indices)
with torch.no_grad():
    # First 10 test samples: five groups of two consecutive indices.
    for idx in indices[:5 * 2]:
        blurred_imgs, true_sharp_imgs = test_data_loader.dataset[idx]
        blurred_imgs = blurred_imgs.unsqueeze(0).to(device)
        true_sharp_imgs = true_sharp_imgs.unsqueeze(0).to(device)
        predicted_sharp_imgs = generator(blurred_imgs).to(device)

        plt.figure(figsize=(15, 5))

        # One row of three panels: input, ground truth, prediction.
        panels = (
            (blurred_imgs, "Blurred"),
            (true_sharp_imgs, "True Sharp"),
            (predicted_sharp_imgs, "Predicted Sharp"),
        )
        for position, (batch, caption) in enumerate(panels, start=1):
            plt.subplot(1, 3, position)
            imshow(batch[0].cpu().numpy(), title=caption)

        plt.show()


In [None]:
#=======continue training=====#

def load_model_checkpoint(generator_path, discriminator_path, G, D, map_location=None):
    """Load saved state dicts into ``G`` and ``D`` in place.

    Args:
        generator_path: file holding the generator's state dict.
        discriminator_path: file holding the discriminator's state dict.
        G, D: the modules receiving the weights.
        map_location: forwarded to ``torch.load``. When None, each file is
            mapped onto the device of the receiving model's parameters, so a
            checkpoint written on a GPU can also be loaded on a CPU-only machine.
    """
    def _load_into(model, path):
        target = map_location
        if target is None:
            first_param = next(model.parameters(), None)
            target = first_param.device if first_param is not None else None
        model.load_state_dict(torch.load(path, map_location=target))

    _load_into(G, generator_path)
    _load_into(D, discriminator_path)

def continue_training(dataset, D, G, bs, start_epoch, n_epoch, device,
                      optimizer_G=None, optimizer_D=None, d_training=5, g_training=1):
    """Resume adversarial training from ``start_epoch`` up to and including ``n_epoch``.

    Args:
        dataset: iterable of (blurred, sharp) batches with a length (a DataLoader).
        D: the critic / discriminator (already holding the loaded weights).
        G: the generator (already holding the loaded weights).
        bs: batch size; only used to advance the step counter that decides
            when the PSNR is printed.
        start_epoch: number of the first epoch to run.
        n_epoch: number of the last epoch to run.
        device: device the batches are moved to.
        optimizer_G, optimizer_D: optimizers of G and D. When omitted, the
            module-level variables of the same names are used, which is what
            this function relied on implicitly before.
        d_training: critic updates per batch.
        g_training: generator updates per batch.

    Raises:
        NameError: if an optimizer is neither passed nor defined at module level.
    """
    # Backward-compatible fallback to the notebook globals.
    if optimizer_G is None or optimizer_D is None:
        module_globals = globals()
        try:
            if optimizer_G is None:
                optimizer_G = module_globals['optimizer_G']
            if optimizer_D is None:
                optimizer_D = module_globals['optimizer_D']
        except KeyError as missing:
            raise NameError(
                'pass optimizer_G / optimizer_D or define them at module level') from missing

    d_loss_total = []
    g_loss_total = []
    steps = 0

    for epoch in range(start_epoch, n_epoch+1):
        print("Running epoch:", epoch)
        start_time_epoch = time.time()
        d_epoch = 0.0
        g_epoch = 0.0
        for blur, sharp in dataset:
            steps += bs
            blur_imgs = blur.to(device)
            sharp_imgs = sharp.to(device)

            # The critic only sees the generator output detached, so no
            # autograd graph is needed for this forward pass.
            with torch.no_grad():
                Generator_out = G(blur_imgs)

            # =======================Train the discriminator=======================#
            for d_iter in range(d_training):
                optimizer_D.zero_grad()

                # Discriminator outputs
                real_validity = D(sharp_imgs)
                fake_validity = D(Generator_out)

                # Gradient penalty (already multiplied by its weight)
                gradient_penalty = calc_gradient_penalty(D, sharp_imgs, Generator_out)

                d_loss = fake_validity.mean() - real_validity.mean() + gradient_penalty
                # Every iteration builds a fresh graph, so nothing has to be retained.
                d_loss.backward()

                optimizer_D.step()
                if d_iter == d_training - 1:
                    d_epoch += d_loss.item()

            # ========================Train the generator===========================#
            for g_iter in range(g_training):
                optimizer_G.zero_grad()

                Generator_out = G(blur_imgs)
                fake_validity = D(Generator_out)
                g_adv_loss = -fake_validity.mean()
                # perceptual_loss(y_true, y_pred) detaches the y_true branch, so
                # the generator output must be the SECOND argument; with the
                # arguments swapped the content loss sent no gradient to G.
                g_contentloss = perceptual_loss(sharp_imgs, Generator_out) * 100
                g_total_loss = g_adv_loss + g_contentloss
                g_total_loss.backward()
                optimizer_G.step()

                if g_iter == g_training - 1:
                    # Counted once per batch (it used to be added a second time
                    # after the loop, doubling the reported generator loss).
                    g_epoch += g_total_loss.item()

            # PSNR metric: restored image against the sharp ground truth, computed
            # in float so the 8-bit images cannot wrap around on subtraction.
            if steps % 1000 == 4:
                restored_np = tensor2im(Generator_out).astype(np.float64)
                sharp_np = tensor2im(sharp_imgs).astype(np.float64)
                psnr = PSNR(restored_np, sharp_np)
                print('PSNR on Train (at epoch {0}) = {1}'.format(epoch, psnr))

        d_loss_total.append(d_epoch/len(dataset))
        g_loss_total.append(g_epoch/len(dataset))

        # saving model after every 50 epochs (file name spelling corrected
        # from 'Discrminator_' to 'Discriminator_')
        if epoch % 50 == 0:
            torch.save(G.state_dict(), 'Generator_' + str(epoch) + '.pt')
            torch.save(D.state_dict(), 'Discriminator_' + str(epoch) + '.pt')

        end_time_epoch = time.time()

        # The loss lists start empty at start_epoch, so index them from the
        # end; indexing with epoch-1 raised IndexError whenever start_epoch > 1.
        print("Time for epoch {0}: {1} | Disc loss: {2}  | Gen loss: {3}".format(epoch, (end_time_epoch - start_time_epoch), d_loss_total[-1], g_loss_total[-1]))

    # One figure per network, plotted against the actual epoch numbers.
    curves = (
        (d_loss_total, 'Discriminator Loss', 'Discriminator Loss Over Epochs'),
        (g_loss_total, 'Generator Loss', 'Generator Loss Over Epochs'),
    )
    for values, label, title in curves:
        epochs_run = list(range(start_epoch, start_epoch + len(values)))
        plt.figure()
        plt.plot(epochs_run, values, label=label)
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.title(title)
        plt.legend()
        plt.show()


In [None]:
# Entry cell for resuming training: rebuilds both networks and their Adam
# optimizers, loads saved weights and calls continue_training().
if __name__ == "__main__":


    D = Discriminator().to(device)
    G = ResnetGenerator().to(device)
    p= Parameters()
    # optimizer_G / optimizer_D are not passed to continue_training() below;
    # it picks them up as module-level names, so they must be defined here.
    optimizer_G = torch.optim.Adam(G.parameters(), lr=p.g_lr, betas=(p.beta1, p.beta2))
    optimizer_D = torch.optim.Adam(D.parameters(), lr=p.d_lr, betas=(p.beta1, p.beta2))
    
    #load the pretrained models
    # NOTE(review): the generator path ends in '.pt' but the discriminator
    # path does not -- confirm both against the actual file names on disk.
    load_model_checkpoint('Generator_trained_100epochs.pt', 'Discriminator_trained_100epochs', G, D)
    
    #DO not apply init weights

    print('ㅠ ㅠ Continue Training ㅠ ㅠ ')
    # Wall-clock time of the resumed run (epochs 101 to 251 inclusive).
    start_time = time.time()
    continue_training(train_data_loader, D, G, p.batchsize, start_epoch=101, n_epoch=251, device=device)
    end_time= time.time()
    print("Total time for training:", end_time- start_time)
