In [None]:
%%bash
# Abort on the first failing command so a broken download/unzip is not silently
# followed by steps that depend on it. Every step is guarded so the cell can be
# re-run (Restart & Run All) without failing or clobbering earlier results.
set -euo pipefail

# SRGAN reference implementation; provides the `model` module imported later.
# Copy (instead of move) so the clone stays intact and re-runs still find its files.
if [ ! -d SRGAN-PyTorch ]; then
    git clone https://github.com/Lornatang/SRGAN-PyTorch
fi
cp -r SRGAN-PyTorch/* .

# Pretrained SRGAN weights (Google Drive file id); skip if already downloaded.
if [ ! -f SRGAN_x4-ImageNet-c71a4860.pth.tar ]; then
    gdown 1Mu4O06B_eFu8qkZhpVkr5CMQIZ6GZYYb
fi

# Kaggle API credentials: kaggle.json must have been uploaded to the working directory
# on the first run. The Kaggle CLI expects the key to be readable only by its owner.
mkdir -p ~/.kaggle
if [ -f kaggle.json ]; then
    mv kaggle.json ~/.kaggle/
fi
chmod 600 ~/.kaggle/kaggle.json

if [ ! -f simpsons-faces.zip ]; then
    kaggle datasets download -d kostastokis/simpsons-faces
fi

# The archive's `cropped` folder becomes the input directory `simpsons-faces-200`.
# Guarded so a re-run neither prompts for overwrites nor nests `cropped` inside it.
if [ ! -d simpsons-faces-200 ]; then
    unzip -q simpsons-faces.zip
    mv cropped simpsons-faces-200
fi
mkdir -p simpsons-faces-800

Downloading simpsons-faces.zip to /content



Cloning into 'SRGAN-PyTorch'...
Downloading...
From: https://drive.google.com/uc?id=1Mu4O06B_eFu8qkZhpVkr5CMQIZ6GZYYb
To: /content/SRGAN_x4-ImageNet-c71a4860.pth.tar
  0%|          | 0.00/6.28M [00:00<?, ?B/s]100%|██████████| 6.28M/6.28M [00:00<00:00, 70.2MB/s]
  0%|          | 0.00/442M [00:00<?, ?B/s]  1%|          | 5.00M/442M [00:00<00:11, 38.2MB/s]  7%|▋         | 32.0M/442M [00:00<00:02, 161MB/s]  12%|█▏        | 51.0M/442M [00:00<00:02, 176MB/s] 17%|█▋        | 76.0M/442M [00:00<00:01, 207MB/s] 22%|██▏       | 97.0M/442M [00:00<00:01, 196MB/s] 27%|██▋       | 121M/442M [00:00<00:01, 212MB/s]  32%|███▏      | 142M/442M [00:00<00:01, 198MB/s] 39%|███▊      | 171M/442M [00:00<00:01, 228MB/s] 44%|████▍     | 196M/442M [00:00<00:01, 237MB/s] 50%|████▉     | 220M/442M [00:01<00:00, 235MB/s] 55%|█████▌    | 245M/442M [00:01<00:00, 241MB/s] 61%|██████    | 269M/442M [00:01<00:00, 201MB/s] 67%|██████▋   | 295M/442M [00:01<00:00, 218MB/s] 72%|███████▏  | 317M/442M [00:01<

In [None]:
import os
import math
import random
import cv2
import torch
import numpy as np

from typing import Any
from tqdm.notebook import tqdm
from model import Generator
from torchvision.transforms import functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Turn an OpenCV/NumPy image array into a PyTorch tensor.

    ``torchvision``'s ``to_tensor`` moves the channel axis first, so an (H, W, C)
    array becomes a (C, H, W) tensor (no batch dimension is added); uint8 input
    is rescaled from [0, 255] to [0, 1].

    Args:
        image (np.ndarray): Image as returned by ``cv2.imread``, values in [0, 255] or [0, 1].
        range_norm (bool): If true, stretch [0, 1] values to [-1, 1].
        half (bool): If true, return the tensor as ``torch.half``.
    Returns:
        torch.Tensor: The converted image tensor.
    Examples:
        >>> example_image = cv2.imread("lr_image.bmp")
        >>> example_tensor = image2tensor(example_image, range_norm=True, half=False)
    """
    converted = F.to_tensor(image)
    if range_norm:
        # [0, 1] -> [-1, 1]
        converted = converted * 2.0 - 1.0
    return converted.half() if half else converted


def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
    """Convert a PyTorch image tensor (C, H, W) into an (H, W, C) uint8 ``np.ndarray``.

    Args:
        tensor (torch.Tensor): Image tensor of shape (C, H, W), or (1, C, H, W) whose
            singleton batch dimension is squeezed away. Values are expected in [0, 1]
            (or [-1, 1] when ``range_norm`` is set); anything outside is clamped.
        range_norm (bool): Map [-1, 1] data to [0, 1] before conversion.
        half (bool): Cast the tensor to ``torch.half`` before scaling.
    Returns:
        image (np.ndarray): uint8 array of shape (H, W, C) usable by PIL or OpenCV;
            the channel order is whatever the tensor had (no RGB/BGR swap here).
    Examples:
        >>> example_tensor = torch.rand(3, 64, 64)
        >>> example_image = tensor2image(example_tensor, range_norm=False, half=False)
    """
    # Accept a single-image batch as well as a bare CHW tensor.
    if tensor.dim() == 4 and tensor.size(0) == 1:
        tensor = tensor.squeeze(0)
    if range_norm:
        tensor = tensor.add(1.0).div(2.0)
    if half:
        tensor = tensor.half()

    # detach() lets this accept tensors that still track gradients (.numpy() refuses them).
    scaled = tensor.detach().permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy()
    # Round to the nearest level instead of truncating: a bare astype("uint8") floors,
    # darkening pixels by up to one level (e.g. 0.999 -> 254 instead of 255).
    return np.rint(scaled).astype("uint8")

In [None]:
class SuperResDataset(Dataset):
    """Dataset over the image files in a directory, yielding ``(tensor, filename)`` pairs.

    Each item is read with OpenCV, converted from BGR to RGB, and turned into a
    half-precision tensor via ``image2tensor``. The filename is returned alongside so
    the super-resolved result can be saved under the same name.

    Args:
        root (str): Directory containing the images. Only regular, non-hidden files
            directly inside it are used, in sorted order so item indices are stable.
    """

    def __init__(self, root):
        self.root = root
        # Sort for a deterministic order (os.listdir order is arbitrary) and skip
        # sub-directories and hidden files such as .DS_Store, which are not images.
        self.files = sorted(
            name for name in os.listdir(root)
            if not name.startswith('.') and os.path.isfile(os.path.join(root, name))
        )

    def __len__(self):
        return len(self.files)

    def __getitem__(self, item):
        filename = self.files[item]
        path = os.path.join(self.root, filename)
        # IMREAD_COLOR always yields an 8-bit 3-channel BGR array, so grayscale, RGBA or
        # 16-bit files cannot break cvtColor/to_tensor or mismatch channels in a batch.
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        if image is None:
            # cv2.imread signals an unreadable file by returning None instead of raising.
            raise OSError(f"Could not read image file: {path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = image2tensor(image, range_norm=False, half=True)
        return image, filename


def save_image(image_tensor, path, filename):
    """Write one RGB image tensor of shape (C, H, W) to ``path/filename`` with OpenCV.

    Args:
        image_tensor (torch.Tensor): RGB image tensor with values in [0, 1]
            (values outside are clamped during conversion).
        path (str): Output directory; it must already exist.
        filename (str): Output file name; its extension selects the image format.
    Raises:
        OSError: If OpenCV reports that the file could not be written.
    """
    image = tensor2image(image_tensor, range_norm=False, half=True)
    # OpenCV writes BGR channel order.
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    target = os.path.join(path, filename)
    # cv2.imwrite reports failures such as a missing output directory by returning
    # False rather than raising; without this check outputs would be lost silently.
    if not cv2.imwrite(target, image):
        raise OSError(f"Could not write image file: {target}")

In [None]:
data_path = 'simpsons-faces-200'
out_path = 'simpsons-faces-800'
model_path = 'SRGAN_x4-ImageNet-c71a4860.pth.tar'
batch_size = 32
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# fp16 inference is only worthwhile (and reliably supported) on CUDA; use fp32 on CPU.
dtype = torch.half if device.type == 'cuda' else torch.float32

# Make re-runs safe even if the setup cell's mkdir was skipped.
os.makedirs(out_path, exist_ok=True)

model = Generator().to(device=device, memory_format=torch.channels_last)
print("Build SRGAN model successfully.")

# Load the super-resolution model weights.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(model_path, map_location=lambda storage, loc: storage)
model.load_state_dict(checkpoint["state_dict"])
print(f"Load SRGAN model weights `{os.path.abspath(model_path)}` successfully.")

model.eval()
model.to(dtype=dtype)

dataset = SuperResDataset(data_path)
# Pinned host memory is what makes the non_blocking host->GPU copies below asynchronous.
dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    num_workers=2,
    shuffle=False,
    pin_memory=device.type == 'cuda',
)

with torch.no_grad():
    for lr_images, filenames in tqdm(dataloader):
        # Cast the inputs to the model's dtype; the dataset yields half tensors regardless of device.
        lr_images = lr_images.to(
            device=device, dtype=dtype, memory_format=torch.channels_last, non_blocking=True
        )
        sr_images = model(lr_images)

        for sr_image, filename in zip(sr_images, filenames):
            save_image(sr_image, out_path, filename)

Build SRGAN model successfully.
Load SRGAN model weights `/content/SRGAN_x4-ImageNet-c71a4860.pth.tar` successfully.


  0%|          | 0/309 [00:00<?, ?it/s]

In [None]:
# Mount Google Drive so the finished archive can be copied to persistent storage.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [None]:
%%bash
# Stop at the first failure so a failed tar is never followed by moving a
# missing or truncated archive to Drive.
set -euo pipefail

# Drive is mounted at /content/drive; this relative path presumably runs from /content.
if [ ! -d drive/MyDrive ]; then
    echo "Google Drive is not mounted at ./drive (run drive.mount first)" >&2
    exit 1
fi

# Build the archive locally first, then move it onto Drive.
tar czf simpsons-faces-800.tar.gz simpsons-faces-800
mv simpsons-faces-800.tar.gz drive/MyDrive/