In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn as nn
from torchvision.models import vgg19


In [2]:
class ConvBlock(nn.Module):
    """Conv2d -> (BatchNorm2d | Identity) -> optional activation.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the convolution.
        generator: selects the activation; True gives a per-channel PReLU,
            False gives LeakyReLU(0.2) (the critic variant).
        use_act: when False, ``forward`` returns the conv/BN output without
            applying the activation (the activation module is still created).
        use_bn: insert BatchNorm2d after the conv; the conv bias is disabled
            in that case since BN's own shift makes it redundant.
        **kwargs: passed unchanged to ``nn.Conv2d`` (kernel_size, stride, ...).
    """

    def __init__(self, in_channels, out_channels, generator=True, use_act=True, use_bn=True, **kwargs):
        super().__init__()
        self.use_act = use_act
        self.cnn = nn.Conv2d(in_channels, out_channels, bias=not use_bn, **kwargs)
        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()
        if generator:
            self.act = nn.PReLU(num_parameters=out_channels)
        else:
            self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        features = self.bn(self.cnn(x))
        if not self.use_act:
            return features
        return self.act(features)


class ResBlock(nn.Module):
    """Residual block: two 3x3 same-padding ConvBlocks plus an identity skip.

    The second ConvBlock is built with ``use_act=False``, so the sum
    ``x + body(x)`` is returned without a trailing activation. Stride 1 and
    padding 1 keep channels and spatial size unchanged, which is what allows
    the identity skip.
    """

    def __init__(self, in_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(in_channels, in_channels, **conv_kwargs)
        self.block2 = ConvBlock(in_channels, in_channels, use_act=False, **conv_kwargs)

    def forward(self, x):
        residual = self.block2(self.block1(x))
        return residual + x

class Upsample(nn.Module):
    """Sub-pixel upsampling stage: 3x3 conv -> PixelShuffle -> PReLU.

    The conv expands channels by ``scale_factor ** 2`` so that
    ``nn.PixelShuffle`` can fold them back into space. The result has
    ``in_channels`` channels and H, W multiplied by ``scale_factor``.
    """

    def __init__(self, in_channels, scale_factor=2):
        super().__init__()
        expanded_channels = in_channels * scale_factor ** 2
        self.conv = nn.Conv2d(in_channels, expanded_channels, 3, 1, 1)
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_channels)

    def forward(self, x):
        expanded = self.conv(x)
        shuffled = self.ps(expanded)
        return self.act(shuffled)

class Generator(nn.Module):
    """SRGAN-style generator performing 4x super-resolution.

    Pipeline: 9x9 conv stem (PReLU, no BN) -> ``num_blocks`` ResBlocks ->
    3x3 conv + BN (no activation) -> long skip from the stem -> two x2
    sub-pixel Upsample stages -> 9x9 conv back to image channels -> tanh.

    Args:
        in_channels: channels of the input image; the output image has the
            same number of channels.
        out_channels: feature width of the internal trunk.
        num_blocks: number of residual blocks in the trunk.

    ``forward`` returns a tensor of shape (N, in_channels, 4*H, 4*W) with
    values squashed into (-1, 1) by tanh.
    """

    def __init__(self, in_channels=3, out_channels=64, num_blocks=16):
        super(Generator, self).__init__()
        self.initial = ConvBlock(in_channels, out_channels, kernel_size=9, stride=1, padding=4, use_bn=False)
        self.residuals = nn.Sequential(*[ResBlock(out_channels) for _ in range(num_blocks)])
        self.convblock = ConvBlock(out_channels, out_channels, kernel_size=3, stride=1, padding=1, use_act=False)
        self.upsamples = nn.Sequential(Upsample(out_channels), Upsample(out_channels))
        # Map trunk features back to image channels. This layer used to be
        # hard-coded as 64 -> 3, which broke any non-default `out_channels`;
        # with the default arguments it is still the same 64 -> 3 conv.
        self.final = nn.Conv2d(in_channels=out_channels, out_channels=in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        initial = self.initial(x)
        x = self.residuals(initial)
        x = self.convblock(x)
        x = x + initial  # long skip around the whole residual trunk
        x = self.upsamples(x)
        return torch.tanh(self.final(x))
    

class Critic(nn.Module):
    """SRGAN-style discriminator producing one raw score per image.

    Args:
        in_channels: channels of the input image.
        featrues: output channels of each 3x3 ConvBlock, in order. Blocks at
            odd positions use stride 2 and even positions stride 1; the first
            block has no BatchNorm. (The misspelt name is kept so existing
            keyword callers keep working.)

    ``forward`` returns a tensor of shape (N, 1). No sigmoid is applied here.
    Adaptive pooling to 6x6 makes the classifier independent of the input
    resolution.

    Raises:
        ValueError: if ``featrues`` is empty.
    """

    def __init__(self, in_channels=3, featrues=(64, 64, 128, 128, 256, 256, 512, 512)):
        super(Critic, self).__init__()
        # Tuple default avoids a shared mutable default; copy so callers'
        # sequences are never touched.
        featrues = list(featrues)
        if not featrues:
            raise ValueError("featrues must contain at least one channel count")

        block_list = []
        for idx, feature in enumerate(featrues):
            block_list.append(ConvBlock(in_channels=in_channels,
                                        out_channels=feature,
                                        kernel_size=3,
                                        stride=1 + idx % 2,
                                        padding=1,
                                        generator=False,
                                        use_act=True,
                                        use_bn=idx != 0))
            in_channels = feature

        self.blocks = nn.Sequential(*block_list)
        pooled = 6
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((pooled, pooled)),
            nn.Flatten(),
            # Size taken from the last conv width instead of a hard-coded 512,
            # so non-default `featrues` lists produce a matching Linear layer.
            nn.Linear(featrues[-1] * pooled * pooled, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)





In [6]:
def test():
    """Smoke-test Generator and Critic output shapes on a random batch.

    Checks that the generator upsamples 4x while preserving the channel
    count, and that the critic returns one score per image. The forward
    passes run under ``torch.no_grad()`` because only shapes are checked;
    keeping the autograd graph for a 4x3x512x512 output would waste memory.
    """
    low_res = 128
    disc = Critic()
    gen = Generator()
    x = torch.randn((4, 3, low_res, low_res))
    bz, c = x.shape[0], x.shape[1]
    with torch.no_grad():
        gen_out = gen(x)
        assert gen_out.shape == (bz, c, low_res * 4, low_res * 4)
        disc_out = disc(gen_out)
        assert disc_out.shape == (bz, 1)
    return 'Tests passed!'

test()

'Tests passed!'

In [7]:
# Hyperparameters
import torch
from PIL import Image
import albumentations as A
from albumentations.pytorch import ToTensorV2

LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_GEN = "../data/checkpoints/gen.pth.tar"
CHECKPOINT_DISC = "../data/checkpoints/disc.pth.tar"
# `device` is kept as a lowercase alias in case later cells use that name.
DEVICE = device = 'mps' if torch.backends.mps.is_available() else 'cpu'
LEARNING_RATE = 1e-4
NUM_EPOCHS = 100
BATCH_SIZE = 16
NUM_WORKERS = 4
HIGH_RES = 96
LOW_RES = HIGH_RES // 4  # 4x factor, matching the Generator's two x2 upsample stages
IMG_CHANNELS = 3

# albumentations' Resize expects an OpenCV interpolation flag, not a PIL one.
# PIL's Image.BICUBIC == 3, which OpenCV reads as INTER_AREA, so the previous
# code silently area-downsampled. OpenCV's INTER_CUBIC is 2 (spelled out here
# because cv2 is not imported in this notebook).
CV2_INTER_CUBIC = 2

# High-res targets: Normalize(mean=0.5, std=0.5) maps pixels to [-1, 1]
# (assuming uint8 input and albumentations' default max_pixel_value=255),
# matching the Generator's tanh output range.
highres_transform = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ]
)

# Low-res inputs: bicubic 4x downsample, then scale to [0, 1] (same uint8 assumption).
lowres_transform = A.Compose(
    [
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=CV2_INTER_CUBIC),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

# Augmentations intended to be shared by the HR/LR pair: random HIGH_RES crop, flip, 90-degree rotation.
both_transforms = A.Compose(
    [
        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

# Inference-time input transform: scale to [0, 1], no resize.
test_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

In [None]:
import torch.nn as nn
from torchvision.models import vgg19


# phi_5,4: output of the 4th conv layer in the 5th block, taken after its ReLU and before the 5th max-pooling layer (vgg19().features[:36])

class VGGLoss(nn.Module):
    """Perceptual (content) loss: MSE between frozen VGG19 feature maps.

    Uses ``vgg19().features[:36]``: everything up to and including the ReLU
    after the 4th conv of the 5th block, just before the last max-pool
    (phi_5,4). The feature extractor is put in eval mode, moved to ``DEVICE``
    and frozen, so gradients flow only to the inputs.
    """

    def __init__(self):
        super().__init__()
        # `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is
        # the checkpoint `pretrained=True` resolved to, so the weights are unchanged.
        self.vgg = vgg19(weights="IMAGENET1K_V1").features[:36].eval().to(DEVICE)
        self.loss = nn.MSELoss()
        # Freeze every VGG parameter; the network is a fixed feature extractor.
        self.vgg.requires_grad_(False)

    def forward(self, input, target):
        """Return the MSE between VGG features of ``input`` and ``target``.

        NOTE(review): images are fed to VGG as-is. VGG19 was trained on
        ImageNet-normalized inputs, whereas the Generator emits tanh values in
        (-1, 1). Confirm whether a rescale/normalization step is intended.
        """
        vgg_input_features = self.vgg(input)
        vgg_target_features = self.vgg(target)
        return self.loss(vgg_input_features, vgg_target_features)

In [None]:
# Training loop

