This notebook contains some experiments.
It does not contain important code.

In [None]:
import numpy as np
import torch, torchvision
import torch.nn as nn
from torchvision import datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt 
from PIL import Image
from models.SimpleUpscale import SimpleUpscale
from models.SimpleConv import SimpleConv
from train import train_model
from dataclass.ImageClass import UpscaledImages
from time import time
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, Resize, Compose
from models.Generator import Generator

In [None]:
class MicroImages(Dataset):
    """In-memory dataset of (low-resolution, high-resolution) DIV2K image pairs.

    Every pair is read, converted to a tensor and resized once in the
    constructor, so indexing afterwards does no disk I/O.
    """

    def __init__(self, root_dir, resize_size=48, nbr_images=64):
        """Loads the first ``nbr_images`` training pairs found under ``root_dir``.

        Args:
            root_dir: Directory containing ``DIV2K_train_HR/`` and
                ``DIV2K_train_LR_bicubic/X2/``. A trailing slash is optional.
            resize_size: Side length of the square low-resolution tensors; the
                high-resolution tensors are resized to twice that side length.
            nbr_images: Number of pairs to load, starting with image ``0001``.
        """
        import os  # local import: only needed here to assemble the file paths

        super().__init__()
        self.root_dir = root_dir
        self.resize_size = resize_size
        self.transform_lr = Compose([ToTensor(), Resize((resize_size, resize_size))])
        self.transform_hr = Compose([ToTensor(), Resize((2*resize_size, 2*resize_size))])
        # self.transform_hr = Compose([ToTensor(), torch.nn.ZeroPad2d(2*resize_size),
        #                         transforms.CenterCrop(2*resize_size)])

        hr_dir = os.path.join(root_dir, 'DIV2K_train_HR')
        lr_dir = os.path.join(root_dir, 'DIV2K_train_LR_bicubic', 'X2')

        self.data = []
        for number in range(1, nbr_images + 1):
            hr_path = os.path.join(hr_dir, f"{number:0>4}.png")
            lr_path = os.path.join(lr_dir, f"{number:0>4}x2.png")
            # The context managers close both files as soon as the tensors exist.
            with Image.open(hr_path) as image_hr, Image.open(lr_path) as image_lr:
                self.data.append((self.transform_lr(image_lr), self.transform_hr(image_hr)))

        self.size = len(self.data)

    def __len__(self):
        """Returns the number of image pairs held in memory."""
        return len(self.data)

    def __getitem__(self, index):
        """Returns the ``(low_res, high_res)`` tensor pair stored at ``index``.

        ``index`` is forwarded to the underlying list, so a slice returns a
        list of pairs.
        """
        return self.data[index]

In [None]:
# Build the training set from ../data/; MicroImages reads and resizes every
# image pair inside its constructor, so this cell does all the disk I/O.
trainset = MicroImages("../data/")


In [None]:
class ResidualBlock(nn.Module):
    """Conv-BN-PReLU-Conv-BN stack whose output is added back onto its input."""

    def __init__(self, nbr_channels=64):
        """Builds the block; input and output both have ``nbr_channels`` channels."""
        super().__init__()

        def conv3x3():
            # 3x3, stride 1, padding 1: spatial size is preserved. No bias term;
            # each convolution is directly followed by a BatchNorm layer.
            return nn.Conv2d(nbr_channels, nbr_channels, kernel_size=3, stride=1, padding=1, bias=False)

        layers = [
            conv3x3(), nn.BatchNorm2d(nbr_channels), nn.PReLU(),
            conv3x3(), nn.BatchNorm2d(nbr_channels),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Returns ``x`` plus the output of the convolutional stack applied to ``x``."""
        residual = self.net(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """Residual convolutional network that upscales a 3-channel image by 2x.

    Pipeline: 9x9 entry convolution, a chain of ``ResidualBlock``s wrapped in
    a long skip connection, one PixelShuffle(2) upscaling stage, and a final
    9x9 convolution back to 3 channels.
    """

    def __init__(self, nbr_channels=64, nbr_blocks=3):
        """Builds the generator.

        Args:
            nbr_channels: Number of feature channels used inside the network.
            nbr_blocks: Number of residual blocks in the trunk.
        """
        super(Generator, self).__init__()

        self.entry_block = nn.Sequential(nn.Conv2d(3, nbr_channels, kernel_size=9, stride=1, padding=4), nn.PReLU())

        # The blocks must use the same width as the entry convolution; a fixed
        # width of 64 here would break every non-default ``nbr_channels``.
        self.residual_blocks = nn.Sequential(*[ResidualBlock(nbr_channels=nbr_channels) for _ in range(nbr_blocks)])

        # The convolution produces 4x the channels, which PixelShuffle(2)
        # rearranges into 2x height and 2x width at ``nbr_channels`` channels.
        self.upscale_block = nn.Sequential( nn.Conv2d(nbr_channels, nbr_channels*4, kernel_size=3, stride=1, padding=1),
                                            nn.PixelShuffle(2),
                                            nn.PReLU())

        self.end_block = nn.Conv2d(nbr_channels, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Maps a ``(N, 3, H, W)`` batch to a ``(N, 3, 2H, 2W)`` batch."""
        x = self.entry_block(x)
        x = self.residual_blocks(x) + x  # long skip connection around the trunk
        x = self.upscale_block(x)
        x = self.end_block(x)
        return x

In [None]:
class DownBlock(nn.Module):
    """Stride-2 3x3 convolution followed by BatchNorm and LeakyReLU."""

    def __init__(self, nbr_channels=64):
        """Builds the block; the channel count is unchanged by the block."""
        super().__init__()

        # Stride 2 with padding 1 reduces height and width to ceil(size / 2).
        strided_conv = nn.Conv2d(nbr_channels, nbr_channels, kernel_size=3, stride=2, padding=1, bias=False)
        self.net = nn.Sequential(strided_conv, nn.BatchNorm2d(nbr_channels), nn.LeakyReLU())

    def forward(self, x):
        """Returns the downsampled feature map (plain feed-forward, no skip)."""
        downsampled = self.net(x)
        return downsampled

class ConvBlock(nn.Module):
    """Size-preserving 3x3 convolution followed by BatchNorm and LeakyReLU."""

    def __init__(self, nbr_channels=64):
        """Builds the block; the channel count is unchanged by the block."""
        super().__init__()

        # Stride 1 with padding 1 keeps height and width as they are.
        same_conv = nn.Conv2d(nbr_channels, nbr_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.net = nn.Sequential(same_conv, nn.BatchNorm2d(nbr_channels), nn.LeakyReLU())

    def forward(self, x):
        """Returns the transformed feature map, same shape as ``x``."""
        features = self.net(x)
        return features

class Discriminator(nn.Module):
    """Convolutional critic producing one raw logit per input image."""

    def __init__(self, nbr_channels=64):
        """Builds the discriminator with ``nbr_channels`` feature channels."""
        super().__init__()

        self.entry_block = nn.Sequential(
            nn.Conv2d(3, nbr_channels, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
        )

        # Three (DownBlock, ConvBlock) pairs, one more DownBlock, then a
        # convolution that collapses the features to a single channel.
        stages = []
        for _ in range(3):
            stages.append(DownBlock(nbr_channels=nbr_channels))
            stages.append(ConvBlock(nbr_channels=nbr_channels))
        stages.append(DownBlock(nbr_channels=nbr_channels))
        stages.append(nn.Conv2d(nbr_channels, 1, kernel_size=3, stride=1, padding=1))
        self.conv_blocks = nn.Sequential(*stages)

        # Forces a 32x32 map whatever the input size, giving the 1 * 32 * 32 =
        # 1024 features the linear head expects.
        self.pool = nn.AdaptiveAvgPool2d((32, 32))

        # No sigmoid at the end: BCEWithLogitsLoss applies it inside the loss.
        self.end_block = nn.Sequential(
            nn.Linear(1024, 32),
            nn.LeakyReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, x):
        """Returns a ``(N, 1)`` tensor of logits for a ``(N, 3, H, W)`` batch."""
        features = self.conv_blocks(self.entry_block(x))
        pooled = self.pool(features)
        flat = pooled.view(pooled.size(0), 1024)
        return self.end_block(flat)

In [None]:
def train_gan(trainset: torch.utils.data.Dataset, lr: float=0.0001, batch_size: int=16, gpu: bool=True, save_file: str=None, use_amp:bool=True, log:bool=True, auto_tuner:bool=True, epochs: int=4, gen_args: dict=None, dis_args: dict=None):
    """Trains a Generator against a Discriminator on (low-res, high-res) pairs.

    Each step first updates the discriminator on real targets and detached
    generator outputs, then updates the generator on
    ``MSE(content) + 0.001 * adversarial`` loss.

    Args:
        trainset: Dataset yielding ``(x, y)`` pairs; ``x`` is fed to the
            generator and ``y`` is the target it is compared against.
        lr: Learning rate of both Adam optimizers.
        batch_size: Batch size of the shuffled DataLoader.
        gpu: Use ``cuda:0`` when CUDA is available; otherwise train on CPU.
        save_file: If given, a checkpoint is written there after every epoch.
        use_amp: Use mixed precision (autocast + gradient scaling). Only takes
            effect when training on CUDA.
        log: Print a progress line after every epoch.
        auto_tuner: Value assigned to ``torch.backends.cudnn.benchmark``
            (a process-wide setting).
        epochs: Number of passes over ``trainset``.
        gen_args: Optional keyword arguments for ``Generator``.
        dis_args: Optional keyword arguments for ``Discriminator``.

    Returns:
        Tuple ``(model, discriminator, epoch_train_losses, d_losses)``: the
        trained generator and discriminator plus the per-epoch mean generator
        and discriminator losses.
    """
    # Initialization
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True)
    torch.backends.cudnn.benchmark = auto_tuner
    device = "cuda:0" if gpu and torch.cuda.is_available() else "cpu"
    model = Generator(**(gen_args or {})).to(device)
    discriminator = Discriminator(**(dis_args or {})).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
    criterion = nn.MSELoss()
    d_criterion = nn.BCEWithLogitsLoss()

    # CUDA autocast is pointless on CPU; gradient scaling guards float16
    # gradients against underflow when autocast is active.
    amp_enabled = use_amp and device != "cpu"
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

    # Training
    start = time()
    epoch_train_losses = []
    d_losses = []
    for i in range(epochs):
        epoch_start = time()
        tmp_loss = []
        d_tmp_loss = []
        for (x, y) in trainloader: # [batch_size x 3 x w x h]
            # Move each batch to the device once and reuse it for both updates.
            x = x.to(device)
            y = y.to(device)
            real_label = torch.ones((x.size(0), 1), dtype=x.dtype, device=device)
            fake_label = torch.zeros_like(real_label)

            # Update D (autocast wraps only the forward pass and the loss)
            discriminator.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=amp_enabled):
                outputs = model(x)
                d_loss_real = d_criterion(discriminator(y), real_label)
                # detach(): the discriminator update must not reach the generator
                d_loss_fake = d_criterion(discriminator(outputs.detach()), fake_label)
                d_loss = d_loss_real + d_loss_fake

            d_tmp_loss.append(d_loss.detach())

            scaler.scale(d_loss).backward()
            scaler.step(d_optimizer)

            # Update G
            model.zero_grad(set_to_none=True)
            with torch.cuda.amp.autocast(enabled=amp_enabled):
                content_loss = criterion(outputs, y)
                gan_loss = d_criterion(discriminator(outputs), real_label)
                g_loss = content_loss + 0.001 * gan_loss

            tmp_loss.append(g_loss.detach())

            # This backward also deposits gradients in the discriminator; they
            # are cleared by discriminator.zero_grad() at the next step.
            scaler.scale(g_loss).backward()
            scaler.step(optimizer)
            scaler.update()  # one scale update per iteration, after both steps

        d_losses.append(torch.tensor(d_tmp_loss).mean())
        epoch_train_losses.append(torch.tensor(tmp_loss).mean())
        if log:
            epoch_time = time() - epoch_start
            # i is zero-based, so epochs - i - 1 epochs are left after this one.
            print(f"Epoch {i+1} in {epoch_time:.2f}s, total time: {(time() - start):.2f}s, loss: {epoch_train_losses[-1]:.8f}, remaing time: {(epochs - i - 1) * epoch_time:.2f}s")
        if save_file is not None:
            torch.save({
                    'epoch': i+1,
                    'model_state_dict': model.state_dict(),
                    'discriminator_state_dict': discriminator.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'd_optimizer_state_dict': d_optimizer.state_dict(),
                    'g_loss': epoch_train_losses[-1],
                    }, save_file)

    end = time()
    # Guard the summary so epochs=0 neither divides by zero nor indexes an empty list.
    seconds_per_epoch = (end - start) / epochs if epochs else 0.0
    print(f"Training took {end - start:.2f} seconds for {epochs} epochs, or {seconds_per_epoch:.2f} seconds per epochs")
    if epoch_train_losses:
        print(f"Final loss {epoch_train_losses[-1]:.8f}")

    return model, discriminator, epoch_train_losses, d_losses

In [None]:
# Run GAN training with the default hyper-parameters. No save_file is passed,
# so no checkpoint is written.
train_gan(trainset)

In [None]:
# Notebook-level loader for the manual training loop: shuffled batches of 16.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True, pin_memory=True)

In [None]:
# Turn on cuDNN benchmark mode for the whole process.
torch.backends.cudnn.benchmark = True

In [None]:
# Use the first CUDA device when available, otherwise fall back to the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# Fresh generator / discriminator with default sizes, one Adam optimizer each
# (same learning rate), and the two loss functions used by the manual loop.
model = Generator().to(device)
discriminator = Discriminator().to(device)
learning_rate=0.0001
# gan_lr = 0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# g_optimizer = torch.optim.Adam(model.parameters(), lr=gan_lr)
d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
criterion = nn.MSELoss()  # content loss between generator output and target
d_criterion = nn.BCEWithLogitsLoss()  # adversarial loss on raw discriminator logits

In [None]:
# Checkpoint file overwritten after every epoch of the manual training loop.
PATH = 'models/save/gan_model_exp.pt'

In [None]:
# Switch for torch.cuda.amp.autocast in the manual training loop.
use_amp = True

In [None]:
# Print the CUDA allocator statistics before the manual training loop runs.
print(torch.cuda.memory_summary())

In [None]:
# Training
# Manual GAN training loop over the notebook-level model, discriminator,
# optimizers, criteria, trainloader, device, use_amp and PATH. It leaves
# epochs, epoch_train_losses, d_losses, start, end and the last CPU batch
# (x, y) behind as notebook globals.
test = False
epochs = 4

# CUDA autocast is pointless on CPU; gradient scaling guards float16 gradients
# against underflow when autocast is active.
amp_enabled = use_amp and str(device).startswith("cuda")
scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)

start = time()
epoch_train_losses = []
d_losses = []
for i in range(epochs):
    tmp_loss = []
    d_tmp_loss = []
    for (x, y) in trainloader: # [batch_size x 3 x w x h]
        # Move each batch to the device once; x and y themselves stay untouched.
        lr_batch = x.to(device)
        hr_batch = y.to(device)
        real_label = torch.ones((x.size(0), 1), dtype=x.dtype, device=device)
        fake_label = torch.zeros_like(real_label)

        # Update D (autocast wraps only the forward pass and the loss)
        discriminator.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=amp_enabled):
            outputs = model(lr_batch)
            d_loss_real = d_criterion(discriminator(hr_batch), real_label)
            # detach(): the discriminator update must not reach the generator
            d_loss_fake = d_criterion(discriminator(outputs.detach()), fake_label)
            d_loss = d_loss_real + d_loss_fake

        d_tmp_loss.append(d_loss.detach())

        scaler.scale(d_loss).backward()
        scaler.step(d_optimizer)

        # Update G
        model.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast(enabled=amp_enabled):
            content_loss = criterion(outputs, hr_batch)
            gan_loss = d_criterion(discriminator(outputs), real_label)
            g_loss = content_loss + 0.001 * gan_loss

        tmp_loss.append(g_loss.detach())

        # This backward also deposits gradients in the discriminator; they are
        # cleared by discriminator.zero_grad() at the next step.
        scaler.scale(g_loss).backward()
        scaler.step(optimizer)
        scaler.update()  # one scale update per iteration, after both steps

    d_losses.append(torch.tensor(d_tmp_loss).mean())
    epoch_train_losses.append(torch.tensor(tmp_loss).mean())
    print(f"Epoch {i+1}")
    torch.save({
            'epoch': i+1,
            'model_state_dict': model.state_dict(),
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'd_optimizer_state_dict': d_optimizer.state_dict(),
            'g_loss': epoch_train_losses[-1],
            }, PATH)

end = time()
print(f"Training took {end - start} seconds for {epochs} epochs, or {(end - start)/epochs} seconds per epochs")

torch.cuda.empty_cache()



In [None]:
# Mean generator loss per epoch (x axis is the zero-based epoch index).
plt.plot(range(epochs), epoch_train_losses)

In [None]:
# Mean discriminator loss per epoch (x axis is the zero-based epoch index).
plt.plot(range(epochs), d_losses)

In [None]:
# First (low-resolution, high-resolution) tensor pair of the training set.
image_lr, image_hr = trainset[0]

In [None]:
def show_images(img):
    """Displays an image tensor with matplotlib.

    Args:
        img: Tensor convertible with ``.numpy()`` whose first axis is moved to
            the end before plotting (channel-first to channel-last).
    """
    channels_last = img.numpy().transpose(1, 2, 0)
    plt.imshow(channels_last)
    plt.show()

In [None]:
# Display the low-resolution input image.
show_images(torchvision.utils.make_grid([image_lr]))


In [None]:
# Display the high-resolution target image.
show_images(torchvision.utils.make_grid([(image_hr)]))


In [None]:
# Run the generator on the low-resolution image (batch of one, no gradients)
# and display the result.
with torch.no_grad():
    ouput_lr = model(image_lr.unsqueeze(0).to(device))
show_images(torchvision.utils.make_grid(ouput_lr.cpu().detach()))


In [None]:
# Same as for the low-resolution image, but feeding the high-resolution image
# itself to the generator, i.e. upscaling the target.
with torch.no_grad():
    ouput_hr = model(image_hr.unsqueeze(0).to(device))
show_images(torchvision.utils.make_grid(ouput_hr.cpu().detach()))


In [None]:
# Grid of the first five low-resolution images; slicing the dataset yields a
# list of (low_res, high_res) pairs and x[0] picks the low-resolution tensor.
show_images(torchvision.utils.make_grid(list(map(lambda x: x[0], trainset[0:5]))))

In [None]:
with torch.no_grad():
    show_images(torchvision.utils.make_grid(list(map(lambda x: model(x[0].unsqueeze(0).to(device))[0].detach().to("cpu"), trainset[0:5]))))

In [None]:
# Shape of the notebook-level x, which is the last batch drawn by the manual
# training loop if that cell ran most recently.
x.size()

In [None]:
# Print the CUDA allocator statistics again, after training and inference.
print(torch.cuda.memory_summary())

In [None]:
def test_mem(bs=1, w=256):
    """Pushes one random ``(bs, 3, w, w)`` batch through the notebook-level model.

    The output is discarded; the call exists to see whether a forward pass of
    this size runs on ``device``. Gradients are disabled.
    """
    batch = torch.rand((bs, 3, w, w))
    with torch.no_grad():
        model(batch.to(device))

In [None]:
# for i in range(40, 800, 20):
#     try:
#         test_mem(bs=32, w=i)
#     except Exception as e:
#         print(f"Largest size was {i-20}")
#         print(e)
#         break
# Optimal conditions:
# Max for bs=1 : 680
# Max for bs=2 : 360
# Max for bs=4 : 340
# Max for bs=8 : 220
# Max for bs=16 : 140
# Max for bs=32 : 140

In [None]:
# max_width = []
# times = []
# pixels = []
# step = 20
# start_ = 0
# end_ = 0
# for bs in range(1, 32, 1):
#     for w in range(40, 800, step):
#         try:
#             start=time()
#             test_mem(bs=bs, w=w)
#             end=time()
#         except Exception as e:
#             torch.cuda.empty_cache()
#             max_width.append(w-step)
#             times.append(end_-start_)
#             pixels.append((w-step)*(w-step)*bs)
#             print(f"Batch size {bs}, Largest size was {w-step}")
#             print(e)
#             break
#         finally:
#             start_=start
#             end_=end

In [None]:
# plt.plot(range(1, 32), max_width)
# plt.plot(range(1, 32), [p/1000 for p in pixels])
# plt.plot(range(1, 32), [t*100 for t in times])
# plt.plot(range(1, 32), [(p/t)/1000 for p,t in zip(pixels, times)])

# plt.xlabel("width")
# plt.ylabel("batch size")
# plt.savefig("batch-size-analysis-complete.pdf")



In [None]:
# del x
# Hand cached, currently unused GPU memory back from PyTorch's allocator.
torch.cuda.empty_cache()

In [None]:
# Drop the notebook-level references to the last batch so its memory can be
# reclaimed, then release cached GPU memory. Rebinding to None is enough: a
# following `if x: del x` could never run because None is falsy.
x = y = None
torch.cuda.empty_cache()

In [None]:
# Random stand-in for 16 pooled single-channel 32x32 maps, flattened to 1024 features each.
test_x = torch.rand(16, 1, 32, 32).reshape(16, 32 * 32)

In [None]:
# Shape experiment for a 1024 -> 32 -> 1 linear head using freshly initialised
# (untrained) layers. This probe uses ReLU and an explicit Sigmoid.
test_y = nn.Linear(1024, 32)(test_x)
test_y = nn.ReLU()(test_y)
test_y = nn.Linear(32, 1)(test_y)
test_y = nn.Sigmoid()(test_y)

In [None]:
# Shape of the flattened probe input.
test_x.size()

In [None]:
# Print the layer structure of the notebook-level discriminator.
print(discriminator)

In [None]:
# Number of recorded per-epoch discriminator losses.
len(d_losses)