This notebook contains exploratory experiments on 2× image super-resolution with the DIV2K dataset.
The code here is experimental scratch work and is not meant for reuse.

In [None]:
import os
from time import time

import matplotlib.pyplot as plt 
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import Compose, Resize, ToTensor

from dataclass.ImageClass import UpscaledImages
from models.Generator import Generator
from models.SimpleConv import SimpleConv
from models.SimpleUpscale import SimpleUpscale
from train import train_model

In [None]:
class MicroImages(Dataset):
    """In-memory dataset of (low-res, high-res) DIV2K image tensor pairs for 2x upscaling.

    All pairs are read and transformed eagerly in ``__init__``; afterwards indexing
    is a plain list lookup.
    """

    def __init__(self, root_dir, n_images=800, resize_size=128):
        """Load ``n_images`` LR/HR pairs from ``root_dir``.

        Args:
            root_dir: Directory containing ``DIV2K_train_HR/`` and
                ``DIV2K_train_LR_bicubic/X2/``. A trailing separator is optional.
            n_images: Number of consecutive images to load, starting at ``0001``.
            resize_size: Side length of the square low-res tensor; the high-res
                tensor is resized to twice this side length. Note that resizing to a
                square ignores the source aspect ratio.
        """
        super().__init__()
        self.root_dir = root_dir
        self.transform_lr = Compose([ToTensor(), Resize((resize_size, resize_size))])
        self.transform_hr = Compose([ToTensor(), Resize((2 * resize_size, 2 * resize_size))])

        hr_dir = os.path.join(root_dir, "DIV2K_train_HR")
        lr_dir = os.path.join(root_dir, "DIV2K_train_LR_bicubic", "X2")

        self.data = []
        for idx in range(1, n_images + 1):
            hr_path = os.path.join(hr_dir, f"{idx:04d}.png")
            lr_path = os.path.join(lr_dir, f"{idx:04d}x2.png")
            # Context managers close the underlying files promptly. convert("RGB")
            # guarantees 3 channels, which the generator's first convolution expects.
            with Image.open(lr_path) as image_lr, Image.open(hr_path) as image_hr:
                self.data.append(
                    (
                        self.transform_lr(image_lr.convert("RGB")),
                        self.transform_hr(image_hr.convert("RGB")),
                    )
                )

        self.size = len(self.data)

    def __len__(self):
        """Return the number of loaded image pairs."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the ``(lr, hr)`` pair at ``index``.

        ``self.data`` is a list, so a slice also works and returns a list of pairs.
        """
        return self.data[index]

In [None]:
# Builds the full training set up front (every image pair is decoded in the constructor).
trainset = MicroImages(root_dir="../data/")


In [None]:
# Mini-batches of 8 (lr, hr) pairs, reshuffled every epoch.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=8,
    shuffle=True,
)

In [None]:
class ResidualBlock(nn.Module):
    """Conv-BN-PReLU-Conv-BN block with an identity skip connection.

    Channel count and spatial size are preserved, so the input can be added to the output.
    """

    def __init__(self, nbr_channels=64):
        super().__init__()

        def conv3x3():
            # Kernel 3, stride 1, padding 1 keeps height and width unchanged.
            return nn.Conv2d(nbr_channels, nbr_channels, kernel_size=3, stride=1, padding=1)

        layers = [
            conv3x3(),
            nn.BatchNorm2d(nbr_channels),
            nn.PReLU(),
            conv3x3(),
            nn.BatchNorm2d(nbr_channels),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x + net(x)``."""
        residual = self.net(x)
        return x + residual

In [None]:
class Generator(nn.Module):
    """Super-resolution generator: entry conv -> residual trunk -> PixelShuffle upscaling -> output conv.

    NOTE: this class shadows the ``Generator`` imported from ``models.Generator`` in the
    imports cell; cells after this one use this local definition.
    """

    def __init__(self, nbr_channels=64, nbr_residual_blocks=3, upscale_factor=2):
        """
        Args:
            nbr_channels: Number of feature channels used throughout the network.
            nbr_residual_blocks: Number of ``ResidualBlock`` modules in the trunk.
            upscale_factor: Spatial upscaling factor applied by ``PixelShuffle``.
        """
        super().__init__()

        self.entry_block = nn.Sequential(
            nn.Conv2d(3, nbr_channels, kernel_size=9, stride=1, padding=4),
            nn.PReLU(),
        )

        # Use nbr_channels (not a fixed 64) so the trunk matches entry_block's output.
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(nbr_channels=nbr_channels) for _ in range(nbr_residual_blocks)]
        )

        # The conv expands channels by upscale_factor**2; PixelShuffle folds those extra
        # channels into an upscale_factor-times larger height and width.
        self.upscale_block = nn.Sequential(
            nn.Conv2d(nbr_channels, nbr_channels * upscale_factor ** 2, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(upscale_factor),
            nn.PReLU(),
        )

        self.end_block = nn.Conv2d(nbr_channels, 3, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        """Map an (N, 3, H, W) batch to (N, 3, H * upscale_factor, W * upscale_factor)."""
        x = self.entry_block(x)
        x = self.residual_blocks(x) + x  # long skip connection around the whole trunk
        x = self.upscale_block(x)
        x = self.end_block(x)
        return x

In [None]:
# Run on the first GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [None]:
# Seed torch's global RNG so weight initialisation is reproducible. This should also make
# the DataLoader's shuffling reproducible on a linear Run All, since the loader draws its
# seed from the same default generator.
SEED = 0
torch.manual_seed(SEED)

model = Generator().to(device)
learning_rate = 0.0005
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()  # pixel-wise MSE between the upscaled output and the HR target

In [None]:
# CUDA allocator report; skipped on machines without a usable GPU.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary())

In [None]:
# Training: one optimisation step per mini-batch, recording the mean loss of each epoch.
epochs = 4

# Inference cells further down switch the model to eval(); make sure a re-run trains in train mode.
model.train()

start = time()
epoch_train_losses = []
for i in range(epochs):
    batch_losses = []
    for (x, y) in trainloader:  # x: low-res batch, y: matching high-res batch
        outputs = model(x.to(device))
        loss = criterion(outputs, y.to(device))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Keep a detached scalar on the device; the host sync happens once per epoch.
        batch_losses.append(loss.detach())

    epoch_loss = torch.stack(batch_losses).mean().item() if batch_losses else float("nan")
    epoch_train_losses.append(epoch_loss)
    print(f"Epoch {i+1}: mean loss {epoch_loss:.6f}")

end = time()
print(f"Training took {end - start} seconds for {epochs} epochs, or {(end - start)/epochs} seconds per epochs")

torch.cuda.empty_cache()



In [None]:
# Derive the x axis from the recorded losses so an interrupted training run still plots.
fig, ax = plt.subplots()
ax.plot(range(1, len(epoch_train_losses) + 1), epoch_train_losses, marker="o")
ax.set(title="Training loss", xlabel="Epoch", ylabel="Mean MSE loss")
plt.show()

In [None]:
# First (low-res, high-res) pair of the training set, used by the visualisation cells below.
[image_lr, image_hr] = trainset[0]

In [None]:
def show_images(img):
    """Display a (C, H, W) image tensor with matplotlib.

    The tensor is detached and moved to the CPU first, so model outputs on the GPU can be
    passed directly. Values are clamped to [0, 1], the range imshow expects for floats.
    """
    # imshow expects channels last: (C, H, W) -> (H, W, C).
    npimg = img.detach().cpu().clamp(0, 1).permute(1, 2, 0).numpy()
    plt.imshow(npimg)
    plt.show()

In [None]:
# Low-res input of the first sample.
lr_grid = torchvision.utils.make_grid([image_lr])
show_images(lr_grid)


In [None]:
# High-res target of the first sample.
hr_grid = torchvision.utils.make_grid([image_hr])
show_images(hr_grid)


In [None]:
# eval() makes BatchNorm normalise with its running statistics. In train mode a forward
# pass, even under no_grad, would use the single image's own statistics and also
# overwrite the running statistics learned during training.
model.eval()
with torch.no_grad():
    output_lr = model(image_lr.unsqueeze(0).to(device))
show_images(torchvision.utils.make_grid(output_lr.cpu()))


In [None]:
# Experiment: feed the high-res image itself through the generator (another 2x upscale).
# eval() keeps BatchNorm on its running statistics and leaves them unmodified.
model.eval()
with torch.no_grad():
    output_hr = model(image_hr.unsqueeze(0).to(device))
show_images(torchvision.utils.make_grid(output_hr.cpu()))


In [None]:
# Grid of the first five low-res inputs (trainset supports slicing because its data is a list).
first_five_lr = [lr for lr, _ in trainset[0:5]]
show_images(torchvision.utils.make_grid(first_five_lr))

In [None]:
# Upscale the first five low-res inputs one by one and show them side by side.
# eval() so BatchNorm uses (and does not update) its running statistics.
model.eval()
with torch.no_grad():
    upscaled_first_five = [
        model(lr.unsqueeze(0).to(device))[0].cpu()
        for lr, _ in trainset[0:5]
    ]
show_images(torchvision.utils.make_grid(upscaled_first_five))

In [None]:
# Shape of one input batch, drawn explicitly from the loader instead of relying on the
# `x` left over from the training loop (which a later cleanup cell sets to None).
next(iter(trainloader))[0].size()

In [None]:
# CUDA allocator report after training/inference; skipped without a usable GPU.
if torch.cuda.is_available():
    print(torch.cuda.memory_summary())

In [None]:
def test_mem(bs=1, w=256):
    x_size = (bs, 3, w, w)
    x = torch.rand(x_size)
    with torch.no_grad():
        y = model(x.to(device))

In [None]:
# for i in range(40, 800, 20):
#     try:
#         test_mem(bs=32, w=i)
#     except Exception as e:
#         print(f"Largest size was {i-20}")
#         print(e)
#         break
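# Results: largest input width w (searched in steps of 20) for which test_mem(bs, w) ran without raising (presumably CUDA out-of-memory):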
# Max for bs=1 : 680
# Max for bs=2 : 360
# Max for bs=4 : 340
# Max for bs=8 : 220
# Max for bs=16 : 140
# Max for bs=32 : 140

In [None]:
# max_width = []
# times = []
# pixels = []
# step = 20
# start_ = 0
# end_ = 0
# for bs in range(1, 32, 1):
#     for w in range(40, 800, step):
#         try:
#             start=time()
#             test_mem(bs=bs, w=w)
#             end=time()
#         except Exception as e:
#             torch.cuda.empty_cache()
#             max_width.append(w-step)
#             times.append(end_-start_)
#             pixels.append((w-step)*(w-step)*bs)
#             print(f"Batch size {bs}, Largest size was {w-step}")
#             print(e)
#             break
#         finally:
#             start_=start
#             end_=end

In [None]:
# plt.plot(range(1, 32), max_width)
# plt.plot(range(1, 32), [p/1000 for p in pixels])
# plt.plot(range(1, 32), [t*100 for t in times])
# plt.plot(range(1, 32), [(p/t)/1000 for p,t in zip(pixels, times)])

# plt.xlabel("width")
# plt.ylabel("batch size")
# plt.savefig("batch-size-analysis-complete.pdf")



In [None]:
# Hand cached-but-unused CUDA memory back to the driver (skipped when CUDA is unavailable).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
# Drop the module-level references to tensors left over from the training loop (last
# input/target batch, its prediction and its loss) so their device memory can be freed,
# then release the allocator's cached blocks.
x = y = outputs = loss = None
torch.cuda.empty_cache()