In [22]:
import torch.nn as nn
from torchvision.models import vgg19

In [23]:
class discriminatorLoss(nn.Module):
    """Binary cross-entropy loss for the discriminator of a super-resolution GAN.

    The discriminator is scored on two batches: the real high-resolution
    images (target 1) and the images the generator produces from the
    low-resolution inputs (target 0). The two BCE terms are summed.

    Args:
        generator: module mapping a low-resolution batch to a super-resolved batch.
        discriminator: module mapping an image batch to probabilities in [0, 1]
            (``nn.BCELoss`` requires probabilities, i.e. a sigmoid output).
        device: device the BCE criterion is moved to; kept on ``self.device``.
    """

    def __init__(self, generator, discriminator, device):
        super().__init__()
        self.device = device
        self.generator = generator
        self.discriminator = discriminator
        self.bceloss = nn.BCELoss().to(device)

    def forward(self, LR_image, HR_image):
        """Return ``BCE(D(HR), 1) + BCE(D(G(LR)), 0)`` as a scalar tensor.

        Args:
            LR_image: low-resolution batch fed to the generator.
            HR_image: ground-truth high-resolution batch.
        """
        HR_pred = self.discriminator(HR_image)

        # The generator is only a fixed source of fake samples here: building
        # its autograd graph would waste memory and leave gradients on the
        # generator's parameters after ``loss.backward()``.
        with torch.no_grad():
            SR_image = self.generator(LR_image)
        SR_pred = self.discriminator(SR_image)

        # ``*_like`` keeps the targets on the same device/dtype as the predictions.
        real_ = torch.ones_like(HR_pred)
        fake_ = torch.zeros_like(SR_pred)

        HR_loss = self.bceloss(HR_pred, real_)
        SR_loss = self.bceloss(SR_pred, fake_)

        return HR_loss + SR_loss
        

In [24]:
class generatorLoss(nn.Module):
    """Combined content, adversarial and perceptual loss for the generator.

    ``loss = MSE(HR, SR) + adversarial_weight * BCE(D(SR), 1)
             + perceptual_weight * MSE(VGG(HR), VGG(SR))``

    where ``SR = G(LR)`` and ``VGG`` is a frozen prefix of a pretrained VGG-19.

    Args:
        generator: module mapping a low-resolution batch to a super-resolved batch.
        discriminator: module mapping an image batch to probabilities in [0, 1]
            (``nn.BCELoss`` requires probabilities, i.e. a sigmoid output).
        device: device the VGG feature extractor and criteria are moved to.
        adversarial_weight: multiplier of the adversarial term (default 0.001).
        perceptual_weight: multiplier of the VGG feature term (default 0.006).
    """

    def __init__(self, generator, discriminator, device,
                 adversarial_weight=0.001, perceptual_weight=0.006):
        super().__init__()

        self.device = device
        self.generator = generator
        self.discriminator = discriminator
        self.adversarial_weight = adversarial_weight
        self.perceptual_weight = perceptual_weight

        # The local name must differ from the imported ``vgg19`` factory,
        # otherwise the call raises UnboundLocalError.
        # NOTE(review): newer torchvision releases prefer ``weights=...`` over
        # ``pretrained=True`` - confirm against the installed version.
        vgg_model = vgg19(pretrained=True, progress=True)

        # Keep only the first nine layers of the convolutional stack
        # (``.features``); the VGG module itself is not iterable.
        vgg_loss = nn.Sequential(*list(vgg_model.features)[:9]).eval()
        # The extractor is a fixed feature map, never trained.
        for param in vgg_loss.parameters():
            param.requires_grad = False

        self.vgg_features = vgg_loss.to(device)
        self.mseloss = nn.MSELoss().to(device)
        self.bceloss = nn.BCELoss().to(device)

    def forward(self, LR_image, HR_image):
        """Return the weighted generator loss as a scalar tensor.

        Args:
            LR_image: low-resolution batch fed to the generator.
            HR_image: ground-truth high-resolution batch; must have the same
                shape as the generator output for the MSE terms.
        """
        SR_image = self.generator(LR_image)
        SR_pred = self.discriminator(SR_image)

        # The generator is rewarded when the discriminator labels SR as real.
        real_ = torch.ones_like(SR_pred)
        adversarial_loss = self.bceloss(SR_pred, real_)

        perceptual_loss = self.mseloss(
            self.vgg_features(HR_image), self.vgg_features(SR_image)
        )
        content_loss = self.mseloss(HR_image, SR_image)

        return (
            content_loss
            + self.adversarial_weight * adversarial_loss
            + self.perceptual_weight * perceptual_loss
        )
    