In [None]:
import numpy as np
import torch
import torchvision
from datetime import datetime
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
from PIL import Image

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
# The bare expression on the last line displays the selected device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

In [None]:
# drive_folder = '/content/drive'
# from google.colab import drive
# drive.mount(drive_folder)

In [None]:
# !mkdir data
# !tar -xf /content/drive/MyDrive/Colab/data/features.tar.gz -C data/
# !tar -xf /content/drive/MyDrive/Colab/data/labels.tar.gz -C data/

In [None]:
# Each launch logs to its own TensorBoard directory under runs/, named after
# the current date and time down to the minute.
log_dir = "runs/" + datetime.now().strftime('%Y-%m-%d_%Hh%M')
writer = SummaryWriter(log_dir)

## Data

In [None]:
# Directory that get_batch reads the features_*.npy / labels_*.npy files from.
data_dir = 'data'

In [None]:
def get_batch(index):
    """Load the feature/label file pair number `index` as float tensors.

    Reads ``features_<index>.npy`` and ``labels_<index>.npy`` from
    ``data_dir``, with the index zero-padded to five digits.

    Returns:
        A ``(features, labels)`` tuple of float tensors placed on ``device``.
        The labels get an extra singleton dimension inserted at position 1.
    """
    features_path = '%s/features_%05d.npy' % (data_dir, index)
    features = torch.from_numpy(np.load(features_path)).float().to(device)

    labels_path = '%s/labels_%05d.npy' % (data_dir, index)
    labels = torch.from_numpy(np.load(labels_path)).unsqueeze(1).float().to(device)

    return features, labels

In [None]:
def image(x):
    """Convert a two-channel grid into three colour planes.

    Args:
        x: indexable whose first two entries ``x[0]`` and ``x[1]`` are the
            channel planes (presumably values in [0, 1] -- TODO confirm).

    Returns:
        A NumPy array stacking three planes: ``1 - x[0]`` followed by
        ``1 - x[0] - x[1]`` twice. Where both channels are zero every plane
        equals 1.
    """
    first_plane = 1 - x[0]
    # The second and third planes are identical: the first plane minus x[1].
    other_planes = first_plane - x[1]
    return np.array([first_plane, other_planes, other_planes])

In [None]:
def imshow(x):
    """Render a two-channel grid as a PIL image.

    The grid is converted to three colour planes with ``image``, scaled by
    255, rearranged from channels-first to channels-last and cast to uint8.

    Args:
        x: grid accepted by ``image`` (channels first).

    Returns:
        A ``PIL.Image.Image`` built from the resulting uint8 array.
    """
    pixels = np.transpose(255 * image(x), (1, 2, 0))
    # Clip before the cast: converting out-of-range floats (e.g. negative
    # values when both channels are set on the same cell) to uint8 wraps
    # around or is platform-dependent instead of saturating.
    return Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))

In [None]:
# Load batch 0 as a sample and display the feature / label tensor shapes.
x, y = get_batch(0)
x.shape, y.shape

In [None]:
# Render the first grid of the sample batch (moved to the CPU for NumPy/PIL).
imshow(x[0].cpu())

In [None]:
# Tile every grid of the sample batch into a single image tensor on the CPU.
x_grid = torchvision.utils.make_grid(x).cpu()

In [None]:
# Display the tiled batch.
imshow(x_grid)

In [None]:
# Log the tiled batch to TensorBoard under the tag 'one batch'.
writer.add_image('one batch', image(x_grid))

### Data augmentation

In [None]:
def flip_grids(grids):
    """Return `grids` mirrored along the last dimension.

    Args:
        grids: tensor of any rank >= 1.

    Returns:
        A new tensor with the order of the last dimension reversed.
    """
    return grids.flip(-1)

In [None]:
def rotate_grids(grids, quarter_turns=1):
    """Rotate `grids` in the plane of the last two dimensions.

    Args:
        grids: tensor with at least two dimensions.
        quarter_turns: number of 90-degree turns; only its value modulo 4
            matters, so negative and large values are accepted.

    Returns:
        The rotated tensor. For a multiple of four turns the input object
        itself is returned unchanged (no copy is made).
    """
    turns = quarter_turns % 4
    if turns == 0:
        return grids
    if turns == 2:
        # A half turn is the same as reversing both axes.
        return torch.flip(grids, dims=[-2, -1])
    # One turn rotates from dim -2 towards dim -1; three turns is one turn
    # in the opposite direction, obtained by swapping the two axes.
    plane = [-2, -1] if turns == 1 else [-1, -2]
    return torch.rot90(grids, dims=plane)

In [None]:
def transform_grids(grids, seed):
    """Apply one of eight flip/rotation combinations chosen by `seed`.

    When ``seed % 8`` is 4 or more the grids are first mirrored with
    ``flip_grids``; in every case they are then rotated by ``seed`` quarter
    turns with ``rotate_grids`` (which uses the value modulo 4).

    Args:
        grids: tensor of grids to transform.
        seed: integer selecting the transformation.

    Returns:
        The transformed grids.
    """
    if seed % 8 >= 4:
        grids = flip_grids(grids)
    return rotate_grids(grids, seed)

In [None]:
# Preview all eight augmented variants of the first sample grid side by side.
imshow(
    torchvision.utils.make_grid(
        torch.stack([transform_grids(x[0].cpu(), seed) for seed in range(8)])
    )
)

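Projector (optional; uncomment the cell below to log the sample batch to the TensorBoard embedding projector):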

In [None]:
# writer.add_embedding(x.view(-1, 2*96*96), metadata=y, label_img=torch.tensor(np.array([image(grid) for grid in x])).mean(1).unsqueeze(1))
# writer.close()

## Model

In [None]:
class ResNet(torch.nn.Module):
    """Residual wrapper that adds a module's input back onto its output."""

    def __init__(self, module):
        """Wrap `module`, the sub-network evaluated on the residual branch.

        The module's output must be addable to its input.
        """
        super().__init__()
        self.module = module

    def forward(self, inputs):
        """Return ``module(inputs) + inputs``."""
        branch_output = self.module(inputs)
        return branch_output + inputs

In [None]:
def ResNetBlock(channels=32):
    """Build a residual block of two conv -> batch-norm -> ReLU stages.

    Both convolutions use a 3x3 kernel with stride 1 and padding 1, so the
    spatial size and channel count are preserved and the block's output can
    be added to its input by the ``ResNet`` wrapper.

    Args:
        channels: number of input and output channels of every layer.
            Defaults to 32, the value that was previously hard-coded.

    Returns:
        A ``ResNet`` module wrapping the two-stage sequential branch.
    """
    def stage():
        # One conv -> batch-norm -> ReLU stage; a fresh set of layers is
        # created on every call so the two stages do not share weights.
        return [
            torch.nn.Conv2d(channels, channels, stride=1, kernel_size=3, padding=1),
            torch.nn.BatchNorm2d(channels, eps=1e-05, momentum=0.1),
            torch.nn.ReLU(),
        ]

    return ResNet(torch.nn.Sequential(*stage(), *stage()))

In [None]:
# Regression network: a strided convolutional stem, four 2x2 convolutions,
# global average pooling and a three-layer MLP head with a single output.
# The side lengths quoted below assume 96 x 96 inputs.
net = torch.nn.Sequential(
    # Stem: 3x3 kernel with stride 3 -> 32 x 32 pixels
    torch.nn.Conv2d(2, 16, kernel_size=3, stride=3, padding=0),
    torch.nn.ReLU(),
    # Each unpadded 2x2 convolution trims one pixel per side length.
    torch.nn.Conv2d(16, 32, kernel_size=2, stride=1, padding=0),  # -> 31 x 31
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 32, kernel_size=2, stride=1, padding=0),  # -> 30 x 30
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 32, kernel_size=2, stride=1, padding=0),  # -> 29 x 29
    torch.nn.ReLU(),
    torch.nn.Conv2d(32, 32, kernel_size=2, stride=1, padding=0),  # -> 28 x 28
    torch.nn.ReLU(),
    # Optional residual stages (currently disabled):
    # ResNetBlock(),
    # ResNetBlock(),
    # Collapse each of the 32 feature maps to a single value.
    torch.nn.AdaptiveAvgPool2d(1),
    torch.nn.Flatten(),
    # MLP head: 32 -> 16 -> 8 -> 1
    torch.nn.Linear(32, 16),
    torch.nn.ReLU(),
    torch.nn.Linear(16, 8),
    torch.nn.ReLU(),
    torch.nn.Linear(8, 1),
).to(device)

In [None]:
# Record the network graph in TensorBoard, using the sample batch as input.
writer.add_graph(net, x)
# NOTE(review): the training loop further down logs through this same writer
# after it has been closed here -- confirm the writer reopens as intended.
writer.close()

## Training

In [None]:
# Mean squared error between the network output and the labels.
loss_fct = torch.nn.MSELoss()

In [None]:
# Adam over all network parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

In [None]:
# Number of batch files (indices 0 .. n_train - 1) visited in every epoch.
n_train = 5000

In [None]:
# Training schedule; both values were previously inlined in the loop.
n_epochs = 200
log_every = 100  # number of mini-batches averaged into each logged point

for epoch in tqdm(range(n_epochs), position=0):

    # Start each epoch with an empty accumulator so that a partial window
    # left over when n_train is not a multiple of log_every cannot leak into
    # (and inflate) the first average of the next epoch.
    running_loss = 0.0

    for i in tqdm(range(n_train), position=1, leave=False):

        # Load batch i and apply the augmentation selected by the epoch
        # number, so the same transform is used for a whole epoch.
        inputs, labels = get_batch(i)
        inputs = transform_grids(inputs, epoch)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = loss_fct(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % log_every == log_every - 1:    # every log_every mini-batches...

            # ...log the mean loss of the window at the global step index
            writer.add_scalar('training loss', running_loss / log_every,
                              epoch * n_train + i)
            running_loss = 0.0

# Make sure the last logged points reach disk once training is finished.
writer.flush()