<a href="https://colab.research.google.com/github/bansaldolly527-lab/deeplearningproj/blob/main/esrgan.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from torch import nn



![image.png](attachment:bcc81a09-bdd0-4e6b-beb1-8e708bfc604f.png)



In [None]:
class Conv(nn.Module):
    """2-D convolution followed by an optional LeakyReLU.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Kernel size forwarded to ``nn.Conv2d``.
        stride: Stride forwarded to ``nn.Conv2d``.
        padding: Padding forwarded to ``nn.Conv2d``.
        use_activation: When True (default) the convolution output is passed
            through ``LeakyReLU`` with negative slope 0.2; when False the
            output is returned unchanged (``nn.Identity``).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_activation=True):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )
        # Use the same 0.2 negative slope as the other LeakyReLU layers in
        # this file (Upsample, discriminator classifier) instead of silently
        # falling back to PyTorch's 0.01 default.
        self.activation = nn.LeakyReLU(0.2, inplace=True) if use_activation else nn.Identity()

    def forward(self, x):
        """Apply the convolution and then the (optional) activation to ``x``."""
        return self.activation(self.conv1(x))

In [None]:
class DenseBlock(nn.Module):
    """Five densely connected ``Conv`` layers with a scaled residual connection.

    Layer ``i`` (0-based) receives the block input concatenated along dim 1
    with the outputs of all previous layers, i.e. ``in_channels + channels * i``
    channels. The first four layers emit ``channels`` feature maps and are
    activated; the fifth maps back to ``in_channels`` with no activation so
    its output can be added to the block input.

    Args:
        in_channels: Channel count of the block input (and output).
        channels: Number of feature maps each of the first four layers adds.
        beta: Factor applied to the last layer's output before the residual add.
    """

    def __init__(self, in_channels, channels=64, beta=0.2):
        super().__init__()
        # ModuleList rather than Sequential: the layers are iterated by hand in
        # forward() and never chained. Child names ("0".."4") are the same in
        # both containers.
        self.conv_layers = nn.ModuleList(
            [
                Conv(
                    in_channels=in_channels + channels * i,
                    out_channels=channels if i < 4 else in_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    use_activation=i < 4,
                )
                for i in range(5)
            ]
        )
        self.beta = beta

    def forward(self, x):
        """Return ``beta * last_layer_output + x``."""
        features = x
        last_idx = len(self.conv_layers) - 1
        for idx, layer in enumerate(self.conv_layers):
            out = layer(features)
            # Nothing consumes the concatenation after the final layer, so
            # skip that allocation.
            if idx < last_idx:
                features = torch.cat([features, out], dim=1)
        return out * self.beta + x  # scale the dense-block output before the skip add







In [None]:
class RRDB(nn.Module):
    """Three ``DenseBlock`` modules in series wrapped in a scaled skip connection.

    Args:
        in_channels: Channel count handed to every ``DenseBlock``.
    """

    def __init__(self, in_channels):
        super().__init__()
        dense_blocks = [DenseBlock(in_channels) for _ in range(3)]
        self.layers = nn.Sequential(*dense_blocks)

    def forward(self, x):
        """Return ``0.2 * layers(x) + x``."""
        # nn.Sequential runs the dense blocks one after another.
        stacked = self.layers(x)
        return 0.2 * stacked + x

tensor([[[[0.4554, 0.9994, 0.4231,  ..., 1.1095, 0.1548, 1.0674],
          [0.7988, 0.5274, 0.3004,  ..., 0.8801, 0.8844, 0.1626],
          [0.5289, 0.6769, 0.1701,  ..., 0.6001, 0.8835, 0.2978],
          ...,
          [0.1087, 0.3849, 0.5856,  ..., 0.6478, 0.4746, 0.9044],
          [0.7340, 0.5345, 0.6894,  ..., 0.5598, 1.0363, 1.0023],
          [0.4683, 0.5732, 0.8683,  ..., 1.0621, 0.0164, 0.2339]],

         [[1.0692, 0.8053, 0.0503,  ..., 0.1192, 0.7475, 0.0942],
          [0.1851, 1.1323, 0.9045,  ..., 0.2940, 1.1735, 0.4155],
          [0.8681, 0.3606, 1.1128,  ..., 0.1313, 0.8321, 0.2090],
          ...,
          [1.1821, 0.2077, 0.4883,  ..., 1.0299, 0.4221, 0.5283],
          [0.3568, 0.3638, 0.1070,  ..., 0.7025, 0.1931, 0.8730],
          [0.0996, 0.9771, 0.1487,  ..., 0.0621, 1.1503, 0.6464]],

         [[0.8035, 0.8663, 0.3375,  ..., 1.0179, 0.0452, 1.1124],
          [0.6998, 0.2042, 0.3232,  ..., 0.1930, 0.2389, 0.6999],
          [1.0749, 0.9602, 1.0164,  ..., 0

In [None]:
# Smoke test: push one random 3-channel 28x28 image through a single RRDB.
# `x` is also reused by the Upsample and Generator smoke tests further down.
x= torch.rand(1,3,28,28)
x  # bare expression: only echoes the tensor in an interactive session
rrdb=RRDB(3)
rrdb(x).shape  # recorded cell output: torch.Size([1, 3, 28, 28]), same shape as x

torch.Size([1, 3, 28, 28])

In [None]:
class Upsample(nn.Module):
    """Bilinear upsampling followed by a 3x3 convolution and LeakyReLU(0.2).

    Args:
        in_channels: Channel count of the input; the convolution keeps it.
        scale_factor: Spatial scale factor forwarded to ``nn.Upsample``.
    """

    def __init__(self, in_channels, scale_factor):
        super().__init__()
        # Channel-preserving 3x3 convolution applied after the resize.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.upsample = nn.Upsample(
            scale_factor=scale_factor,
            mode='bilinear',
            align_corners=False,
        )
        self.activation = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Resize ``x``, convolve it, then activate it."""
        resized = self.upsample(x)
        return self.activation(self.conv1(resized))
# Smoke test: upsample the 28x28 batch `x` (defined in an earlier cell) by a factor of 200.
ups=Upsample(3,200)
out=ups(x)
out.shape  # recorded cell output: torch.Size([1, 3, 5600, 5600])

torch.Size([1, 3, 5600, 5600])

In [None]:
class Generator(nn.Module):
    """Super-resolution generator built from a stack of RRDB blocks.

    Pipeline: input conv -> ``num_blocks`` RRDBs -> conv -> add the input-conv
    features back -> upsample twice -> two output convs (the last one without
    activation, mapping back to ``in_channels``).

    Args:
        in_channels: Channel count of the input (and output) image.
        channels: Width of the internal feature maps.
        scale_factor: Scale factor of the ``Upsample`` stage. ``forward`` applies
            that stage twice, so height and width grow by ``scale_factor ** 2``
            overall (28 -> 112 with the default of 2).
        num_blocks: Number of RRDB blocks in the trunk (default 23).
    """

    def __init__(self, in_channels, channels=64, scale_factor=2, num_blocks=23):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, padding=1, stride=1)
        self.rrdb = nn.Sequential(*[RRDB(channels) for _ in range(num_blocks)])
        self.conv2 = Conv(channels, channels, kernel_size=3, stride=1, padding=1)
        self.upsample = Upsample(channels, scale_factor)
        self.conv3 = Conv(channels, channels, kernel_size=3, padding=1, stride=1)
        self.conv4 = Conv(channels, in_channels, kernel_size=3, padding=1, stride=1, use_activation=False)

    def forward(self, x):
        """Map a low-resolution batch ``x`` to its upscaled counterpart."""
        x = self.conv1(x)
        residual = x
        x = self.conv2(self.rrdb(x))
        x = x + residual  # long skip connection around the RRDB trunk
        # NOTE(review): the same Upsample module (shared weights) is applied
        # twice. If two independent stages were intended, a second Upsample is
        # needed - but that changes the parameter set and any saved checkpoint
        # layout, so it is left as-is here.
        x = self.upsample(x)
        x = self.upsample(x)
        x = self.conv3(x)
        return self.conv4(x)
# Smoke test: run the 1x3x28x28 batch `x` (defined in an earlier cell) through a default Generator.
gen=Generator(3)
otp=gen(x)
otp.shape  # recorded cell output: torch.Size([1, 3, 112, 112]), i.e. 28 * 2 * 2 = 112

torch.Size([1, 3, 112, 112])

In [None]:
# discriminator
class Discriminator(nn.Module):
    """Convolutional critic that maps an image batch to one score per image.

    Args:
        in_channels: Channel count of the input images.
        features: Output channel count of each successive ``Conv`` layer.
            Layers at odd positions use stride 2 (``stride = 1 + idx % 2``),
            so every second layer halves the spatial resolution.
    """

    def __init__(self, in_channels=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                Conv(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,
                    padding=1,
                    use_activation=True,
                )
            )
            in_channels = feature
        self.blocks = nn.Sequential(*blocks)
        # After the loop `in_channels` is the channel count leaving the conv
        # stack (512 with the default `features`), so the first Linear layer
        # stays correct for custom `features` as well.
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(in_channels * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        """Return the classifier output for the conv features of ``x``."""
        x = self.blocks(x)
        return self.classifier(x)


In [None]:
def initialize_weights(model, scale=0.1):
    """Re-initialise conv and linear weights with scaled Kaiming-normal values.

    Every ``nn.Conv2d`` and ``nn.Linear`` found in ``model`` (including
    ``model`` itself) gets its weight drawn with ``nn.init.kaiming_normal_``
    and then multiplied by ``scale``. Biases are left untouched.

    Args:
        model: Module whose sub-modules are re-initialised in place.
        scale: Multiplier applied to the freshly drawn weights.
    """
    # `modules()` (plural) walks the module tree recursively.
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight.data)
            m.weight.data *= scale


In [None]:
def test():
    """Smoke-test Generator and Discriminator and print their output shapes."""
    gen = Generator(3)
    disc = Discriminator()
    low_res = 24
    # Conv2d expects a (batch, channels, height, width) batch. A 3-D
    # (5, 3, low_res) tensor would be read as one unbatched 5-channel image
    # and fail against the generator's 3 input channels.
    x = torch.randn((5, 3, low_res, low_res))
    gen_out = gen(x)
    disc_out = disc(gen_out)

    print(gen_out.shape)
    print(disc_out.shape)
test()

In [None]:
from torchvision.models import vgg19
# import config

class VGGLoss(nn.Module):
    """MSE between VGG-19 feature maps of two image batches.

    The feature extractor is the first 35 layers of ``vgg19(pretrained=True).features``,
    put in eval mode with all of its parameters frozen.
    """

    def __init__(self):
        super().__init__()
        backbone = vgg19(pretrained=True)
        self.vgg = backbone.features[:35].eval()

        # The extractor only provides features; it must never receive gradients.
        self.vgg.requires_grad_(False)
        self.loss = nn.MSELoss()

    def forward(self, input, target):
        """Return the MSE between the VGG features of ``input`` and ``target``."""
        features_in = self.vgg(input)
        features_tgt = self.vgg(target)
        return self.loss(features_in, features_tgt)


In [None]:
import torch
import config
from torch import nn
from torch import optim
from utils import gradient_penalty, load_checkpoint, save_checkpoint, plot_examples
from loss import VGGLoss
from torch.utils.data import DataLoader
from model import Generator, Discriminator, initialize_weights
from tqdm import tqdm
from dataset import MyImageFolder
from torch.utils.tensorboard import SummaryWriter

torch.backends.cudnn.benchmark = True

def train_fn(loader, disc, gen, opt_gen, opt_disc, l1, vgg_loss, g_scaler, d_scaler, writer, tb_step,):
    """Run one pass over ``loader``, alternating critic and generator updates.

    Args:
        loader: Iterable yielding ``(low_res, high_res)`` batches.
        disc: Critic network, called on real and generated images.
        gen: Generator network, called on the low-resolution batch.
        opt_gen: Optimizer stepping the generator.
        opt_disc: Optimizer stepping the critic.
        l1: Pixel loss callable, invoked as ``l1(fake, high_res)``.
        vgg_loss: Perceptual loss callable, invoked as ``vgg_loss(fake, high_res)``.
        g_scaler: AMP gradient scaler used for the generator step.
        d_scaler: AMP gradient scaler used for the critic step.
        writer: Object with ``add_scalar`` receiving the critic loss each step.
        tb_step: Global step counter used for logging.

    Returns:
        ``tb_step`` advanced by the number of batches processed.
    """
    loop = tqdm(loader, leave=True)

    for idx, (low_res, high_res) in enumerate(loop):
        high_res = high_res.to(config.DEVICE)
        low_res = low_res.to(config.DEVICE)

        # Train the critic: maximise D(real) - D(fake), regularised by the
        # gradient penalty.
        with torch.cuda.amp.autocast():
            fake = gen(low_res)
            critic_real = disc(high_res)
            critic_fake = disc(fake.detach())  # detach: no generator gradients in the critic step
            gp = gradient_penalty(disc, high_res, fake, device=config.DEVICE)
            # The penalty weight must multiply the penalty itself; adding the
            # bare constant LAMBDA_GP would contribute no gradient at all.
            loss_critic = (
                -(torch.mean(critic_real) - torch.mean(critic_fake))
                + config.LAMBDA_GP * gp
            )

        opt_disc.zero_grad()
        d_scaler.scale(loss_critic).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()

        # Train the generator: minimise a weighted sum of the pixel (L1) loss,
        # the VGG feature loss and the adversarial term -mean(D(G(low_res))).
        with torch.cuda.amp.autocast():
            l1_loss = 1e-2 * l1(fake, high_res)
            adversarial_loss = 5e-3 * -torch.mean(disc(fake))
            loss_for_vgg = vgg_loss(fake, high_res)
            gen_loss = l1_loss + loss_for_vgg + adversarial_loss

        opt_gen.zero_grad()
        g_scaler.scale(gen_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()

        writer.add_scalar("Critic loss", loss_critic.item(), global_step=tb_step)
        tb_step += 1

        # Periodically run the generator on the example images.
        if idx % 100 == 0 and idx > 0:
            plot_examples("test_images/", gen)

        loop.set_postfix(
            gp=gp.item(),
            critic=loss_critic.item(),
            l1=l1_loss.item(),
            vgg=loss_for_vgg.item(),
            adversarial=adversarial_loss.item(),
        )

    return tb_step


def main():
    """Build the data loader, models and optimizers, then train for ``config.NUM_EPOCHS`` epochs.

    Checkpoints are loaded first when ``config.LOAD_MODEL`` is set and saved
    after every epoch when ``config.SAVE_MODEL`` is set. Scalars are logged to
    the ``logs`` directory through a ``SummaryWriter``.
    """
    dataset = MyImageFolder(root_dir="data/")
    loader = DataLoader(
        dataset,
        batch_size=config.BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=config.NUM_WORKERS,
    )
    gen = Generator(in_channels=3).to(config.DEVICE)
    disc = Discriminator(in_channels=3).to(config.DEVICE)
    initialize_weights(gen)
    opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
    opt_disc = optim.Adam(disc.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
    writer = SummaryWriter("logs")

    # Close the writer even if training is interrupted or fails, so the
    # event file is flushed and released.
    try:
        tb_step = 0
        l1 = nn.L1Loss()
        gen.train()
        disc.train()
        # NOTE(review): the VGGLoss defined in this file does not move its VGG
        # backbone to any device, so place it explicitly next to the models
        # (a no-op if the imported class already does this) - confirm against
        # loss.py.
        vgg_loss = VGGLoss().to(config.DEVICE)

        g_scaler = torch.cuda.amp.GradScaler()
        d_scaler = torch.cuda.amp.GradScaler()

        if config.LOAD_MODEL:
            load_checkpoint(
                config.CHECKPOINT_GEN,
                gen,
                opt_gen,
                config.LEARNING_RATE,
            )
            load_checkpoint(
                config.CHECKPOINT_DISC,
                disc,
                opt_disc,
                config.LEARNING_RATE,
            )

        for _epoch in range(config.NUM_EPOCHS):
            tb_step = train_fn(
                loader,
                disc,
                gen,
                opt_gen,
                opt_disc,
                l1,
                vgg_loss,
                g_scaler,
                d_scaler,
                writer,
                tb_step,
            )

            if config.SAVE_MODEL:
                save_checkpoint(gen, opt_gen, filename=config.CHECKPOINT_GEN)
                save_checkpoint(disc, opt_disc, filename=config.CHECKPOINT_DISC)
    finally:
        writer.close()


if __name__ == "__main__":
    # Hard-wired switch: True runs a saved generator on example images,
    # False starts a training run via main().
    try_model = True

    if not try_model:
        # Training path.
        main()
    else:
        # Inference path: restore the generator from config.CHECKPOINT_GEN and
        # hand it to plot_examples together with the test_images/ folder.
        # (The original note says the super-resolved results are saved in
        # saved/; that happens inside plot_examples, which is not visible here.)
        gen = Generator(in_channels=3).to(config.DEVICE)
        opt_gen = optim.Adam(gen.parameters(), lr=config.LEARNING_RATE, betas=(0.0, 0.9))
        load_checkpoint(config.CHECKPOINT_GEN, gen, opt_gen, config.LEARNING_RATE)
        plot_examples("test_images/", gen)