In [17]:
import numpy as np
import pandas as pd
import os, math, sys
import glob, itertools
import argparse, random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision.models import vgg19
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image, make_grid

import plotly
import plotly.express as px
import plotly.graph_objects as go
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm_notebook as tqdm
from sklearn.model_selection import train_test_split

# Seed every RNG this notebook draws from (Python `random`, NumPy, PyTorch) so
# DataLoader shuffling and torch-side randomness are reproducible across runs.
# The original seeded only `random`, leaving NumPy and PyTorch unseeded.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included
import warnings
# NOTE(review): a blanket "ignore" also hides deprecation warnings coming from
# torchvision/PIL; consider narrowing this filter to specific categories.
warnings.filterwarnings("ignore")


In [20]:
# Create the output folders used later in the notebook (by name, presumably for
# generated image samples and saved checkpoints); exist_ok makes re-runs a no-op.
for output_dir in ("images", "saved_models"):
    os.makedirs(output_dir, exist_ok=True)

In [26]:
# Record once, as the global `cuda`, whether PyTorch can see a usable CUDA device.
cuda = torch.cuda.is_available()
print(f"Is CUDA available? {cuda}")

Is CUDA available? True


In [22]:
# --- Training schedule -------------------------------------------------------
load_pretrained_models = True
n_epochs = 2
decay_epochs = 100
batch_size = 16
n_cpu = 8  # passed to DataLoader as num_workers

# --- Optimiser hyper-parameters ---------------------------------------------
# Presumably learning rate and Adam betas; the consumer is not in this chunk.
lr = 8e-5
# NOTE(review): beta1 = 0.05 is unusually low for GAN training (0.5 or 0.9 are
# the usual choices) — confirm this is intentional and not a typo for 0.5.
b1 = 0.05
b2 = 0.999

# --- Data ---------------------------------------------------------------------
dataset_path = "/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba"
hr_height = hr_width = 256
channel = 3
hr_shape = (hr_height, hr_width)  # (height, width) of the high-resolution target
hr_shape


(256, 256)

In [23]:
# Per-channel (R, G, B) normalisation statistics: the standard ImageNet values
# expected by torchvision's ImageNet-pretrained models.
mean, std = (
    np.array([0.485, 0.456, 0.406]),
    np.array([0.229, 0.224, 0.225]),
)

In [28]:
class ImageDataset(Dataset):
    """Map-style dataset yielding low-res / high-res tensor pairs from image files.

    Each sample is one image file, opened with PIL, converted to RGB and run
    through two pipelines:

    * ``lr``: bicubic resize to ``(height // 4, width // 4)``, to tensor, normalize.
    * ``hr``: bicubic resize to ``(height, width)``, to tensor, normalize.

    Normalisation uses the notebook's module-level ``mean`` / ``std`` arrays.

    Args:
        files: Sequence of image file paths.
        hr_shape: ``(height, width)`` of the high-resolution target.
    """

    def __init__(self, files, hr_shape):
        hr_height, hr_width = hr_shape
        bicubic = transforms.InterpolationMode.BICUBIC
        self.lr_transform = transforms.Compose([
            # The width uses hr_width; the original passed hr_height twice, which
            # silently produced square outputs for a non-square hr_shape.
            transforms.Resize((hr_height // 4, hr_width // 4), bicubic),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        self.hr_transform = transforms.Compose([
            transforms.Resize((hr_height, hr_width), bicubic),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])
        self.files = files

    def __getitem__(self, index):
        """Return ``{"lr": tensor, "hr": tensor}`` for one integer ``index``.

        ``index`` wraps modulo ``len(self.files)``.
        """
        path = self.files[index % len(self.files)]
        # The context manager closes the file handle. convert("RGB") guarantees
        # 3 channels, so the 3-element Normalize works for grayscale/RGBA/palette
        # inputs too.
        with Image.open(path) as raw:
            img = raw.convert("RGB")
        return {"lr": self.lr_transform(img), "hr": self.hr_transform(img)}

    def __getitems__(self, indices):
        """Batched fetch hook: DataLoader passes a list of indices and collates
        the returned list of sample dicts.

        The original defined only this method, with single-index logic. That left
        ``__getitem__`` unimplemented and crashed on the list DataLoader passes in.
        """
        return [self[i] for i in indices]

    def __len__(self):
        return len(self.files)



In [29]:
# Deterministic 98% / 2% split of the sorted file list, then one loader per split.
# Both loaders shuffle; the test loader uses 3/4 of the training batch size.
all_image_paths = sorted(glob.glob(f"{dataset_path}/*.*"))
train_paths, test_paths = train_test_split(all_image_paths, test_size=0.02, random_state=42)

loader_kwargs = dict(shuffle=True, num_workers=n_cpu)
train_dataloader = DataLoader(
    ImageDataset(train_paths, hr_shape=hr_shape),
    batch_size=batch_size,
    **loader_kwargs,
)
test_dataloader = DataLoader(
    ImageDataset(test_paths, hr_shape=hr_shape),
    batch_size=int(batch_size * 0.75),
    **loader_kwargs,
)

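## SRGAN architecture

Generator and discriminator of SRGAN (Ledig et al., *Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network*, CVPR 2017):

![SRGAN architecture diagram](https://miro.medium.com/max/1000/1*zsiBj3IL4ALeLgsCeQ3lyA.png)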

In [33]:
class FeatureExtractor(nn.Module):
    """Truncated ImageNet-pretrained VGG19 used to compute feature maps.

    Keeps the first ``n_layers`` modules of torchvision's ``vgg19().features``.
    With the default of 18, this runs in torchvision's layer ordering up to and
    including the ReLU after ``conv3_4``, stopping just before the third
    max-pool.

    Args:
        n_layers: Number of leading ``features`` modules to keep. The default
            of 18 reproduces the original hard-coded slice.
    """

    def __init__(self, n_layers=18):
        super().__init__()
        try:
            # torchvision >= 0.13 exposes an explicit weights enum; there,
            # `pretrained=True` is deprecated and maps to IMAGENET1K_V1.
            from torchvision.models import VGG19_Weights
            vgg19_model = vgg19(weights=VGG19_Weights.IMAGENET1K_V1)
        except ImportError:
            # Older torchvision without the multi-weight API.
            vgg19_model = vgg19(pretrained=True)
        self.feature_extractor = nn.Sequential(
            *list(vgg19_model.features.children())[:n_layers]
        )

    def forward(self, img):
        """Run ``img`` through the truncated VGG19 stack and return the features."""
        return self.feature_extractor(img)


In [None]:
class 