This baseline SRGAN model was implemented for benchmarking purposes and is based on the following reference implementation:
https://github.com/aladdinpersson/Machine-Learning-Collection/tree/master/ML/Pytorch/GANs/SRGAN 

In [None]:
import os
import torch
import torch.nn as n
import torch.nn.functional as f
import numpy as np
import os
from torchsummary import summary
import torch.optim as optim
from tqdm import tqdm
from torchvision import models
import cv2
from matplotlib import pyplot as plt
from PIL import Image

The baseline generator upscales any input image by exactly 4x (two 2x pixel-shuffle upsampling stages).

In [None]:
import torch
from torch import nn


class ConvBlock(nn.Module):
    """Conv2d -> (BatchNorm2d | Identity) -> optional activation.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        discriminator: if True the activation is LeakyReLU(0.2); otherwise a
            per-channel PReLU is used.
        use_act: when False the activation module is still created (so the
            parameter layout is unchanged) but it is skipped in ``forward``.
        use_bn: when True a BatchNorm2d follows the conv and the conv bias is
            disabled, since the BatchNorm shift makes it redundant.
        **kwargs: forwarded verbatim to ``nn.Conv2d`` (kernel_size, stride, ...).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        discriminator=False,
        use_act=True,
        use_bn=True,
        **kwargs,
    ):
        super().__init__()
        self.use_act = use_act
        # Attribute names (cnn/bn/act) and their creation order define the
        # state_dict keys of saved checkpoints -- keep them stable.
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
        if use_bn:
            self.bn = nn.BatchNorm2d(out_channels)
        else:
            self.bn = nn.Identity()
        if discriminator:
            self.act = nn.LeakyReLU(0.2, inplace=True)
        else:
            self.act = nn.PReLU(num_parameters=out_channels)

    def forward(self, x):
        features = self.bn(self.cnn(x))
        if not self.use_act:
            return features
        return self.act(features)


class UpsampleBlock(nn.Module):
    """Upsample spatially by ``scale_factor`` using conv -> PixelShuffle -> PReLU.

    The 3x3 conv expands channels to ``in_c * scale_factor**2``. PixelShuffle
    then rearranges them into a ``scale_factor``-times larger map with
    ``in_c`` channels, e.g. (N, in_c*4, H, W) -> (N, in_c, 2H, 2W) for 2x.
    """

    def __init__(self, in_c, scale_factor):
        super().__init__()
        expanded_channels = in_c * scale_factor ** 2
        self.conv = nn.Conv2d(in_c, expanded_channels, 3, 1, 1)
        self.ps = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU(num_parameters=in_c)

    def forward(self, x):
        expanded = self.conv(x)
        shuffled = self.ps(expanded)
        return self.act(shuffled)


class ResidualBlock(nn.Module):
    """Two 3x3 same-padding ConvBlocks with an identity skip connection.

    The first ConvBlock applies its activation and the second does not. The
    input is then added back, so channel count and spatial size are unchanged.
    """

    def __init__(self, in_channels):
        super().__init__()
        same_conv = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(in_channels, in_channels, **same_conv)
        self.block2 = ConvBlock(in_channels, in_channels, use_act=False, **same_conv)

    def forward(self, x):
        residual = self.block2(self.block1(x))
        return residual + x


class Generator(nn.Module):
    """SRGAN generator mapping (N, in_channels, H, W) to (N, in_channels, 4H, 4W).

    Pipeline:
      1. 9x9 conv stem.
      2. ``num_blocks`` residual blocks.
      3. 3x3 conv with a long skip back to the stem.
      4. Two 2x upsampling stages.
      5. 9x9 conv followed by tanh, so the output lies in [-1, 1].
    """

    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
        super().__init__()
        # Registration order below fixes both the RNG consumption order during
        # weight init and the state_dict keys of saved checkpoints.
        self.initial = ConvBlock(
            in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False
        )
        self.residuals = nn.Sequential(
            *(ResidualBlock(num_channels) for _ in range(num_blocks))
        )
        self.convblock = ConvBlock(
            num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False
        )
        self.upsamples = nn.Sequential(
            UpsampleBlock(num_channels, 2),
            UpsampleBlock(num_channels, 2),
        )
        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        stem = self.initial(x)
        body = self.convblock(self.residuals(stem))
        upscaled = self.upsamples(body + stem)
        return torch.tanh(self.final(upscaled))


class Discriminator(nn.Module):
    """SRGAN discriminator: a stack of ConvBlocks followed by an MLP giving one logit per image.

    Args:
        in_channels: channels of the input image.
        features: output channels of each ConvBlock. Blocks alternate stride
            1 / 2 (odd indices downsample), and the first block has no
            BatchNorm. Must be non-empty. A tuple is used as the default so
            the default can never be mutated between calls.

    Raises:
        ValueError: if ``features`` is empty.

    ``AdaptiveAvgPool2d((6, 6))`` makes the classifier independent of the
    input resolution. Output shape is (N, 1): raw logits with no sigmoid.
    """

    def __init__(self, in_channels=3, features=(64, 64, 128, 128, 256, 256, 512, 512)):
        super().__init__()
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel count")

        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=idx != 0,
                )
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            # Size from the last block's channel count (512 for the default),
            # so a custom ``features`` list no longer breaks the Linear layer.
            nn.Linear(features[-1] * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)

def test(low_resolution=24):
    """Smoke-test Generator and Discriminator output shapes on random input.

    A (5, 3, r, r) batch should be upscaled by the generator to
    (5, 3, 4r, 4r), and the discriminator should emit one logit per image.
    The default r=24 corresponds to 96x96 -> 24x24.
    """
    # Only shapes are inspected, so no autograd graph is needed.
    # torch.cuda.amp.autocast would be a no-op here because everything runs
    # on the CPU.
    with torch.no_grad():
        x = torch.randn((5, 3, low_resolution, low_resolution))
        gen = Generator()
        gen_out = gen(x)
        disc = Discriminator()
        disc_out = disc(gen_out)

    print(gen_out.shape)
    print(disc_out.shape)
    assert gen_out.shape == (5, 3, 4 * low_resolution, 4 * low_resolution)
    assert disc_out.shape == (5, 1)


if __name__ == "__main__":
    test()


torch.Size([5, 3, 96, 96])
torch.Size([5, 1])


In [None]:
!pip install albumentations==0.4.6
import albumentations as A
from albumentations.pytorch import ToTensorV2

Collecting albumentations==0.4.6
  Downloading albumentations-0.4.6.tar.gz (117 kB)
[K     |████████████████████████████████| 117 kB 7.0 MB/s 
Collecting imgaug>=0.4.0
  Downloading imgaug-0.4.0-py2.py3-none-any.whl (948 kB)
[K     |████████████████████████████████| 948 kB 41.5 MB/s 
Building wheels for collected packages: albumentations
  Building wheel for albumentations (setup.py) ... [?25l[?25hdone
  Created wheel for albumentations: filename=albumentations-0.4.6-py3-none-any.whl size=65174 sha256=214733884647ad782c55d44f46f3c39e1f9de82be86d84f4b9faa84ee7488c1b
  Stored in directory: /root/.cache/pip/wheels/cf/34/0f/cb2a5f93561a181a4bcc84847ad6aaceea8b5a3127469616cc
Successfully built albumentations
Installing collected packages: imgaug, albumentations
  Attempting uninstall: imgaug
    Found existing installation: imgaug 0.2.9
    Uninstalling imgaug-0.2.9:
      Successfully uninstalled imgaug-0.2.9
  Attempting uninstall: albumentations
    Found existing installation: album

In [None]:
from PIL import Image
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_GEN = "gen.pth.tar"
CHECKPOINT_DISC = "disc.pth.tar"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
NUM_EPOCHS = 50
BATCH_SIZE = 16
NUM_WORKERS = 4
HIGH_RES = 224
LOW_RES = HIGH_RES // 4  # the generator upscales by 4x
IMG_CHANNELS = 3

# High-res target: (x/255 - 0.5) / 0.5, i.e. [-1, 1] for 8-bit images,
# matching the generator's tanh output range.
highres_transform = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ToTensorV2(),
    ]
)

# Low-res generator input: 4x bicubic downsample of the high-res crop,
# scaled to [0, 1] for 8-bit images.
# Albumentations passes `interpolation` straight to cv2.resize, so it must be
# an OpenCV flag. PIL's Image.BICUBIC (== 3) means cv2.INTER_AREA in OpenCV's
# numbering and silently replaced bicubic with area downsampling.
# NOTE: a generator trained with the old setting saw area-downsampled inputs;
# retrain before comparing against it.
lowres_transform = A.Compose(
    [
        A.Resize(width=LOW_RES, height=LOW_RES, interpolation=cv2.INTER_CUBIC),
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

# Shared augmentation applied once per image before the two transforms above,
# so the low-res input and high-res target stay aligned.
both_transforms = A.Compose(
    [
        A.RandomCrop(width=HIGH_RES, height=HIGH_RES),
        A.HorizontalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
    ]
)

# Inference-time input preprocessing: same [0, 1] scaling as lowres_transform,
# without resizing.
test_transform = A.Compose(
    [
        A.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        ToTensorV2(),
    ]
)

In [None]:
import torch.nn as nn
from torchvision.models import vgg19

# phi_5,4 5th conv layer before maxpooling but after activation

class VGGLoss(nn.Module):
    """Perceptual loss: MSE between frozen VGG19 feature maps of two images.

    Uses ``vgg19(pretrained=True).features[:36]``. This is everything up to
    and including the activation after conv5_4 (phi_5,4), stopping before
    the last max-pool.

    The sub-network is put in eval mode, moved to ``DEVICE`` and has all its
    parameters frozen. Gradients still flow back to ``input``.

    NOTE(review): images are fed as-is, with no ImageNet mean/std
    normalization applied here.
    """

    def __init__(self):
        super().__init__()
        self.vgg = vgg19(pretrained=True).features[:36].eval().to(DEVICE)
        self.loss = nn.MSELoss()
        # Module-level equivalent of setting requires_grad=False on every parameter.
        self.vgg.requires_grad_(False)

    def forward(self, input, target):
        # Arguments evaluate left to right: input features first, then target.
        return self.loss(self.vgg(input), self.vgg(target))


In [None]:
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image


class MyImageFolder(Dataset):
    """Dataset over ``root_dir/<class_name>/<image file>`` yielding (low_res, high_res) pairs.

    Each image goes through the module-level ``both_transforms`` (shared
    random crop/flip/rotate). The high-res target then goes through
    ``highres_transform`` and the generator input through
    ``lowres_transform``. Class labels are only used internally to locate
    each file's sub-directory.

    Skipped entries:
      * hidden entries (names starting with '.', e.g. Jupyter/Colab's
        '.ipynb_checkpoints');
      * top-level entries that are not directories;
      * entries inside a class directory that are not regular files.

    Class and file names are sorted, so indexing is deterministic across runs.
    """

    def __init__(self, root_dir):
        super(MyImageFolder, self).__init__()
        self.data = []
        self.root_dir = root_dir
        self.class_names = sorted(
            name
            for name in os.listdir(root_dir)
            if not name.startswith(".") and os.path.isdir(os.path.join(root_dir, name))
        )

        for index, name in enumerate(self.class_names):
            class_dir = os.path.join(root_dir, name)
            files = sorted(
                fname
                for fname in os.listdir(class_dir)
                if not fname.startswith(".") and os.path.isfile(os.path.join(class_dir, fname))
            )
            self.data += [(fname, index) for fname in files]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        img_file, label = self.data[index]
        path = os.path.join(self.root_dir, self.class_names[label], img_file)

        # Close the file handle promptly. Force 3 channels so grayscale,
        # RGBA and palette images match the 3-channel Normalize transforms.
        with Image.open(path) as img:
            image = np.array(img.convert("RGB"))
        image = both_transforms(image=image)["image"]
        high_res = highres_transform(image=image)["image"]
        low_res = lowres_transform(image=image)["image"]
        return low_res, high_res


def test(root_dir="/content/drive/MyDrive/APS360/Bird_Data_HR/"):
    """Sanity-check MyImageFolder by printing how many images it indexes.

    Args:
        root_dir: dataset root laid out as ``root_dir/<class>/<image>``.
    """
    # Only the index is checked, so no DataLoader (and no worker processes)
    # is created.
    dataset = MyImageFolder(root_dir=root_dir)
    print(len(dataset))


if __name__ == "__main__":
    test()

2000


  cpuset_checked))


In [None]:
import numpy as np
from PIL import Image
from torchvision.utils import save_image


def gradient_penalty(critic, real, fake, device):
    """WGAN-GP gradient penalty evaluated on random real/fake interpolations.

    One alpha ~ U[0, 1) per sample mixes ``real`` and (detached) ``fake``.
    The critic is evaluated on the mix, and the penalty is
    ``mean((||d critic / d x||_2 - 1) ** 2)`` over the batch.

    Args:
        critic: callable mapping a batch of images to scores.
        real: 4-D tensor (N, C, H, W).
        fake: tensor with the same shape as ``real``.
        device: device the interpolation coefficients are moved to.

    Returns:
        Scalar tensor. The graph is kept (create_graph=True), so the penalty
        can be back-propagated into the critic.
    """
    batch_size, _, _, _ = real.shape
    # Drawn on the CPU and then moved. The (N, 1, 1, 1) shape broadcasts over
    # C, H and W.
    alpha = torch.rand((batch_size, 1, 1, 1)).to(device)
    mixed = real * alpha + fake.detach() * (1 - alpha)
    mixed.requires_grad_(True)

    scores = critic(mixed)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=mixed,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )
    grad_norms = grads.view(grads.shape[0], -1).norm(2, dim=1)
    return ((grad_norms - 1) ** 2).mean()


def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    """Persist the model's and optimizer's state dicts to ``filename`` with torch.save.

    The file holds a dict with keys "state_dict" (the model) and "optimizer".
    This is the layout load_checkpoint reads back.
    """
    print("=> Saving checkpoint")
    torch.save(
        {"state_dict": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )


def load_checkpoint(checkpoint_file, model, optimizer, lr):
    """Restore model and optimizer state from a file written by save_checkpoint.

    Tensors are mapped onto ``DEVICE``. After loading, every optimizer param
    group's learning rate is overwritten with ``lr``. Without this, the
    learning rate stored in the checkpoint would silently be reused.

    Args:
        checkpoint_file: path of the checkpoint.
        model: module receiving the "state_dict" entry.
        optimizer: optimizer receiving the "optimizer" entry.
        lr: learning rate to force onto all param groups.
    """
    print("=> Loading checkpoint")
    state = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(state["state_dict"])
    optimizer.load_state_dict(state["optimizer"])

    for group in optimizer.param_groups:
        group["lr"] = lr


def plot_examples(low_res_folder, gen, output_folder="saved"):
    """Upscale every image in ``low_res_folder`` with ``gen`` and save the results.

    Per image:
      1. Open it from ``low_res_folder`` and convert to RGB.
      2. Apply ``test_transform`` ([0, 1] scaling for 8-bit images,
         HWC -> CHW tensor).
      3. Upscale in a single no-grad forward pass.
      4. Map the tanh output from [-1, 1] back to [0, 1] and write it to
         ``output_folder/<same file name>``.

    Hidden entries (e.g. Jupyter's '.ipynb_checkpoints') and sub-directories
    are skipped. ``gen`` is switched to eval mode for inference. It is put
    back into train mode afterwards, even if an image fails to load.

    Args:
        low_res_folder: directory containing the images to upscale.
        gen: generator module; assumed to already live on ``DEVICE``.
        output_folder: destination directory, created if missing.
    """
    os.makedirs(output_folder, exist_ok=True)
    files = sorted(os.listdir(low_res_folder))

    gen.eval()
    try:
        for file in files:
            path = os.path.join(low_res_folder, file)
            if file.startswith(".") or not os.path.isfile(path):
                continue
            # convert("RGB") so RGBA/grayscale files match the 3-channel transform.
            with Image.open(path) as image:
                pixels = np.asarray(image.convert("RGB"))
            with torch.no_grad():
                upscaled_img = gen(
                    test_transform(image=pixels)["image"].unsqueeze(0).to(DEVICE)
                )
            save_image(upscaled_img * 0.5 + 0.5, os.path.join(output_folder, file))
    finally:
        gen.train()

In [None]:
import torch
from torch import optim
from torch.utils.data import DataLoader
from tqdm import tqdm

torch.backends.cudnn.benchmark = True


def test_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss):
    """Return the generator's mean pixel MSE over every batch in ``loader``.

    Only ``loader``, ``gen`` and ``mse`` are used. The other parameters are
    accepted so the signature mirrors ``train_fn``.

    The generator runs in eval mode, so BatchNorm uses its running statistics
    instead of updating them from test data. It also runs under
    ``torch.no_grad``, so no autograd graph is kept. Its previous train/eval
    mode is restored afterwards.

    Returns:
        0-dim detached tensor: per-batch MSE averaged over batches, weighted
        by batch size.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    was_training = gen.training
    gen.eval()
    total_loss = 0.0
    num_samples = 0
    try:
        with torch.no_grad():
            for low_res, high_res in tqdm(loader, leave=True):
                high_res = high_res.to(DEVICE)
                low_res = low_res.to(DEVICE)
                fake = gen(low_res)
                batch_size = low_res.shape[0]
                total_loss = total_loss + mse(fake, high_res) * batch_size
                num_samples += batch_size
    finally:
        gen.train(was_training)

    if num_samples == 0:
        raise ValueError("test loader yielded no batches")
    return total_loss / num_samples

def train_fn(loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss):
    """Run one SRGAN training epoch over ``loader``.

    Per batch:
      1. Discriminator step: BCE-with-logits on real images, with randomly
         smoothed targets in (0.9, 1], and on detached generator outputs,
         with targets 0.
      2. Generator step: pixel MSE + 0.006 * VGG perceptual loss
         + 1e-3 * adversarial BCE (the generator pushes D(G(z)) towards 1).

    The generator loss is printed every 200 batches.

    Returns:
        0-dim detached tensor: mean pixel MSE (the ``l2_loss`` term) over the
        epoch, weighted by batch size. It carries no autograd graph.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    total_l2 = 0.0
    num_samples = 0
    loop = tqdm(loader, leave=True)
    for idx, (low_res, high_res) in enumerate(loop):
        high_res = high_res.to(DEVICE)
        low_res = low_res.to(DEVICE)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        fake = gen(low_res)
        disc_real = disc(high_res)
        disc_fake = disc(fake.detach())
        disc_loss_real = bce(
            disc_real, torch.ones_like(disc_real) - 0.1 * torch.rand_like(disc_real)
        )
        disc_loss_fake = bce(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = disc_loss_fake + disc_loss_real

        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # Discriminator grads produced by this backward pass are cleared by
        # opt_disc.zero_grad() before the next discriminator step.
        disc_fake = disc(fake)
        l2_loss = mse(fake, high_res)
        adversarial_loss = 1e-3 * bce(disc_fake, torch.ones_like(disc_fake))
        loss_for_vgg = 0.006 * vgg_loss(fake, high_res)
        gen_loss = l2_loss + loss_for_vgg + adversarial_loss

        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()

        # Accumulate detached values on-device: no graph kept, no per-batch sync.
        batch_size = low_res.shape[0]
        total_l2 = total_l2 + l2_loss.detach() * batch_size
        num_samples += batch_size

        if idx % 200 == 0:
            print(f'gen loss = {gen_loss}')

    if num_samples == 0:
        raise ValueError("train loader yielded no batches")
    return total_l2 / num_samples

def main():
    """Train the baseline SRGAN and print per-epoch train/test pixel MSE.

    Steps:
      1. Build train/test loaders from the APS360 bird folders.
      2. Create the generator, the discriminator, their Adam optimizers and
         the three losses.
      3. Optionally resume from CHECKPOINT_GEN / CHECKPOINT_DISC.
      4. Run NUM_EPOCHS epochs of ``train_fn`` followed by ``test_fn``,
         saving both models after each epoch when SAVE_MODEL is set.
    """
    train_dataset = MyImageFolder(root_dir="/content/drive/MyDrive/APS360/Bird_Data_HR_Train")
    test_dataset = MyImageFolder(root_dir="/content/drive/MyDrive/APS360/Bird_Data_HR_Test")
    train_loader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        pin_memory=True,
        num_workers=NUM_WORKERS,
    )
    # Evaluation does not depend on order, so the test split is not shuffled.
    test_loader = DataLoader(
        test_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        pin_memory=True,
        num_workers=NUM_WORKERS,
    )
    gen = Generator(in_channels=3).to(DEVICE)
    disc = Discriminator(in_channels=3).to(DEVICE)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
    mse = nn.MSELoss()
    bce = nn.BCEWithLogitsLoss()
    vgg_loss = VGGLoss()

    if LOAD_MODEL:
        # Resume only when both checkpoints exist, so a fresh runtime starts
        # from scratch instead of crashing in torch.load.
        if os.path.exists(CHECKPOINT_GEN) and os.path.exists(CHECKPOINT_DISC):
            load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
            load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)
        else:
            print("=> No checkpoint found, training from scratch")

    train_loss = []
    test_loss = []
    for epoch in range(NUM_EPOCHS):
        print(f'epoch: {epoch}')
        tr_loss = train_fn(train_loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss)
        te_loss = test_fn(test_loader, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss)
        # Store plain floats: readable when printed, and no device tensors kept alive.
        train_loss.append(float(tr_loss))
        test_loss.append(float(te_loss))

        if SAVE_MODEL:
            save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)
    print(train_loss)
    print('test loss')
    print(test_loss)


if __name__ == "__main__":
    main()

  cpuset_checked))


epoch: 0


  1%|          | 1/88 [00:02<02:56,  2.02s/it]

gen loss = 0.21800506114959717


100%|██████████| 88/88 [01:20<00:00,  1.10it/s]
100%|██████████| 38/38 [00:03<00:00, 11.25it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 1


  1%|          | 1/88 [00:01<01:53,  1.30s/it]

gen loss = 0.03961951658129692


100%|██████████| 88/88 [01:19<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.22it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 2


  1%|          | 1/88 [00:01<02:09,  1.49s/it]

gen loss = 0.041331686079502106


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.04it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 3


  1%|          | 1/88 [00:01<02:04,  1.43s/it]

gen loss = 0.03974296152591705


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.31it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 4


  1%|          | 1/88 [00:01<02:04,  1.43s/it]

gen loss = 0.0457458533346653


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 10.97it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 5


  1%|          | 1/88 [00:01<01:59,  1.38s/it]

gen loss = 0.03304334729909897


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.36it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 6


  1%|          | 1/88 [00:01<02:08,  1.48s/it]

gen loss = 0.03339751064777374


100%|██████████| 88/88 [01:18<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.34it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 7


  1%|          | 1/88 [00:01<01:56,  1.34s/it]

gen loss = 0.029273053631186485


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.36it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 8


  1%|          | 1/88 [00:01<02:05,  1.45s/it]

gen loss = 0.033955760300159454


100%|██████████| 88/88 [01:18<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.00it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 9


  1%|          | 1/88 [00:01<02:04,  1.43s/it]

gen loss = 0.031550176441669464


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.32it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 10


  1%|          | 1/88 [00:01<02:06,  1.46s/it]

gen loss = 0.04448777809739113


100%|██████████| 88/88 [01:18<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.14it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 11


  1%|          | 1/88 [00:01<01:55,  1.33s/it]

gen loss = 0.04125586152076721


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.22it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 12


  1%|          | 1/88 [00:01<02:00,  1.38s/it]

gen loss = 0.02768169902265072


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.01it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 13


  1%|          | 1/88 [00:01<01:52,  1.30s/it]

gen loss = 0.04285845160484314


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.43it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 14


  1%|          | 1/88 [00:01<01:59,  1.38s/it]

gen loss = 0.037633560597896576


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.28it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 15


  1%|          | 1/88 [00:01<02:10,  1.50s/it]

gen loss = 0.04496350139379501


100%|██████████| 88/88 [01:19<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.03it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 16


  1%|          | 1/88 [00:01<02:03,  1.42s/it]

gen loss = 0.02991548366844654


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.24it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 17


  1%|          | 1/88 [00:01<01:59,  1.37s/it]

gen loss = 0.027228128165006638


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.23it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 18


  1%|          | 1/88 [00:01<02:08,  1.48s/it]

gen loss = 0.04060293734073639


100%|██████████| 88/88 [01:18<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.05it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 19


  1%|          | 1/88 [00:01<02:10,  1.50s/it]

gen loss = 0.024744028225541115


100%|██████████| 88/88 [01:19<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.15it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 20


  1%|          | 1/88 [00:01<02:06,  1.46s/it]

gen loss = 0.02770739048719406


100%|██████████| 88/88 [01:18<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.28it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 21


  1%|          | 1/88 [00:01<02:04,  1.43s/it]

gen loss = 0.030691787600517273


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.10it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 22


  1%|          | 1/88 [00:01<01:56,  1.34s/it]

gen loss = 0.03286115080118179


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.14it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 23


  1%|          | 1/88 [00:01<02:04,  1.43s/it]

gen loss = 0.034102510660886765


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.40it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 24


  1%|          | 1/88 [00:01<02:04,  1.43s/it]

gen loss = 0.024277620017528534


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.32it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 25


  1%|          | 1/88 [00:01<02:06,  1.46s/it]

gen loss = 0.028644364327192307


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.25it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 26


  1%|          | 1/88 [00:01<02:04,  1.44s/it]

gen loss = 0.02637122943997383


100%|██████████| 88/88 [01:18<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.43it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 27


  1%|          | 1/88 [00:01<02:03,  1.42s/it]

gen loss = 0.03189624100923538


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.53it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 28


  1%|          | 1/88 [00:01<01:57,  1.35s/it]

gen loss = 0.02959318459033966


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.76it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 29


  1%|          | 1/88 [00:01<01:55,  1.33s/it]

gen loss = 0.02816954255104065


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.13it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 30


  1%|          | 1/88 [00:01<02:10,  1.50s/it]

gen loss = 0.024743445217609406


100%|██████████| 88/88 [01:18<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.44it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 31


  1%|          | 1/88 [00:01<02:02,  1.41s/it]

gen loss = 0.026651067659258842


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.30it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 32


  1%|          | 1/88 [00:01<02:06,  1.45s/it]

gen loss = 0.024382276460528374


100%|██████████| 88/88 [01:18<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.09it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 33


  1%|          | 1/88 [00:01<02:06,  1.45s/it]

gen loss = 0.0259441826492548


100%|██████████| 88/88 [01:18<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.29it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 34


  1%|          | 1/88 [00:01<02:09,  1.49s/it]

gen loss = 0.03128160536289215


100%|██████████| 88/88 [01:18<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.34it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 35


  1%|          | 1/88 [00:01<02:07,  1.46s/it]

gen loss = 0.022672224789857864


100%|██████████| 88/88 [01:19<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 10.77it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 36


  1%|          | 1/88 [00:01<02:12,  1.53s/it]

gen loss = 0.03218594193458557


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.39it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 37


  1%|          | 1/88 [00:01<02:05,  1.44s/it]

gen loss = 0.025744719430804253


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 10.96it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 38


  1%|          | 1/88 [00:01<01:54,  1.32s/it]

gen loss = 0.028409013524651527


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 10.91it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 39


  1%|          | 1/88 [00:01<02:01,  1.40s/it]

gen loss = 0.02749508246779442


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.46it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 40


  1%|          | 1/88 [00:01<02:05,  1.44s/it]

gen loss = 0.03205931559205055


100%|██████████| 88/88 [01:18<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 11.47it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 41


  1%|          | 1/88 [00:01<02:05,  1.44s/it]

gen loss = 0.02288161963224411


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.15it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 42


  1%|          | 1/88 [00:01<02:05,  1.44s/it]

gen loss = 0.030010106042027473


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.37it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 43


  1%|          | 1/88 [00:01<02:07,  1.47s/it]

gen loss = 0.03372744098305702


100%|██████████| 88/88 [01:18<00:00,  1.11it/s]
100%|██████████| 38/38 [00:03<00:00, 10.90it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 44


  1%|          | 1/88 [00:01<01:52,  1.29s/it]

gen loss = 0.02419453300535679


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.12it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 45


  1%|          | 1/88 [00:01<01:55,  1.33s/it]

gen loss = 0.026921477168798447


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.37it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 46


  1%|          | 1/88 [00:01<02:02,  1.41s/it]

gen loss = 0.026375900954008102


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.36it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 47


  1%|          | 1/88 [00:01<02:10,  1.50s/it]

gen loss = 0.02664848044514656


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.08it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 48


  1%|          | 1/88 [00:01<02:06,  1.45s/it]

gen loss = 0.02368355542421341


100%|██████████| 88/88 [01:18<00:00,  1.12it/s]
100%|██████████| 38/38 [00:03<00:00, 11.39it/s]


=> Saving checkpoint
=> Saving checkpoint
epoch: 49


  1%|          | 1/88 [00:01<02:06,  1.45s/it]

gen loss = 0.02344530262053013


100%|██████████| 88/88 [01:17<00:00,  1.14it/s]
100%|██████████| 38/38 [00:03<00:00, 11.63it/s]


=> Saving checkpoint
=> Saving checkpoint
[tensor(0.0315, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0336, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0456, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0328, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0362, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0244, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0351, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0356, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0300, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0315, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0262, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0261, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0381, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0273, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0378, device='cuda:0', grad_fn=<MseLossBackward0>), tensor(0.0282, device='cuda:0', grad_fn=<M

In [None]:
# Rebuild models, optimizers and losses, then restore trained weights for inference.
# NOTE(review): `dataset`/`loader` and mse/bce/vgg_loss are not used by the
# plot_examples cells below. VGGLoss also downloads the ~548 MB VGG19 weights.
dataset = MyImageFolder(
    root_dir="/content/drive/MyDrive/APS360/personal_bird_test_images/bird_pic224"
)
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    num_workers=NUM_WORKERS,
)

gen = Generator(in_channels=3).to(DEVICE)
disc = Discriminator(in_channels=3).to(DEVICE)
opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))

mse = nn.MSELoss()
bce = nn.BCEWithLogitsLoss()
vgg_loss = VGGLoss()

if LOAD_MODEL:
    load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
    load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)




  cpuset_checked))
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


  0%|          | 0.00/548M [00:00<?, ?B/s]

=> Loading checkpoint
=> Loading checkpoint


In [None]:
# Upscale the images in the low-resolution bird test folder with the restored generator.
plot_examples("/content/drive/MyDrive/APS360/Bird_Data_LR_Test/", gen)

In [None]:
# Demo: upscale the personal bird photos with the restored generator.
plot_examples("/content/drive/MyDrive/APS360/personal_bird_test_images/bird_pic224/pics/", gen)

In [None]:
from albumentations.augmentations.transforms import Resize
import os
# Resize every upscaled image in /content/saved to 224x224, overwriting it in place.
# A descriptive name is used for the folder because `f` is the notebook-wide
# alias for torch.nn.functional (first cell) and must not be clobbered.
saved_dir = '/content/saved'
for file in os.listdir(saved_dir):
    f_img = os.path.join(saved_dir, file)
    # Skip Jupyter's checkpoint folder and any other non-file entries.
    if file == '.ipynb_checkpoints' or not os.path.isfile(f_img):
        continue
    # Close the source file before overwriting it with the resized copy.
    with Image.open(f_img) as img:
        resized = img.resize((224, 224))
    resized.save(f_img)