In [None]:
# Enable the IPython autoreload extension (mode 2) so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Print the Python interpreter version string so the runtime used for this
# run is visible in the notebook output.
import sys
print(sys.version)

In [None]:
from training import training_regression_transformer
from networks.regression_transformer import RegressionTransformerConfig, RegressionTransformer

from data.nef_mnist_dataset import MnistNeFDataset, FlattenTransform, MinMaxTransform

import os
import torch
import torchinfo

In [None]:
# Display whether PyTorch reports a usable CUDA device (cell output is a bool).
torch.cuda.is_available()

In [None]:
# Dataloading
# The dataset folder is resolved relative to the parent of the current
# working directory: <parent of cwd>/adl4cv/datasets/mnist-nerfs
dir_path = os.path.split(os.path.abspath(os.getcwd()))[0]
data_root_ours = os.path.join(
    dir_path,
    os.path.join("adl4cv", "datasets", "mnist-nerfs"),
)

class FlattenMinMaxTransform(torch.nn.Module):
    """Apply `FlattenTransform` followed by `MinMaxTransform` to a sample.

    Args:
        min_max: optional arguments forwarded positionally to
            `MinMaxTransform`. When omitted (or falsy), `MinMaxTransform`
            is constructed without arguments.
    """

    def __init__(self, min_max: tuple = None):
        super().__init__()
        self.flatten = FlattenTransform()
        # A falsy `min_max` unpacks to no arguments, i.e. MinMaxTransform().
        self.minmax = MinMaxTransform(*(min_max or ()))

    def forward(self, x, y):
        """Return the flattened, min-max transformed `x` and the original `y`.

        Whatever the wrapped transforms return in the second position is
        discarded; `y` is passed to both of them unchanged.
        """
        flattened, _ = self.flatten(x, y)
        scaled, _ = self.minmax(flattened, y)
        return scaled, y


# Options passed to every MnistNeFDataset constructed below.
kwargs = dict(type="pretrained", fixed_label=5)

# Flatten-only dataset, queried here for its min/max statistics.
dataset_wo_min_max = MnistNeFDataset(
    data_root_ours,
    transform=FlattenTransform(),
    **kwargs,
)
min_ours, max_ours = dataset_wo_min_max.min_max()

# Dataset whose samples are flattened and then min-max transformed with the
# statistics obtained above.
dataset = MnistNeFDataset(
    data_root_ours,
    transform=FlattenMinMaxTransform((min_ours, max_ours)),
    **kwargs,
)

# Same data without any transform applied.
dataset_no_transform = MnistNeFDataset(data_root_ours, **kwargs)

In [None]:
# Training configuration
config = training_regression_transformer.Config()
config.learning_rate = 5e-4
config.max_iters = 14_000
config.weight_decay = 0
config.decay_lr = True
config.lr_decay_iters = config.max_iters  # same value as max_iters (14000)
config.warmup_iters = 0.1 * config.max_iters  # 10% of the iterations
config.batch_size = 1
config.detailed_folder = "training_sample_5"

# Transformer configuration; block_size is one less than the length of the
# first dataset sample.
model_config = RegressionTransformerConfig(
    n_embd=32,
    block_size=len(dataset[0][0]) - 1,
    n_head=8,
    n_layer=16,
)

In [None]:
# Take the first n samples whose label equals the configured fixed label
# (the label is the second entry of a dataset item). Each collected entry is
# a (dataset_index, sample) pair.
n = 5


def _first_samples_with_label(data, label, limit):
    """Return up to `limit` (index, sample) pairs from `data` whose label matches.

    Every item is indexed exactly once, and the scan stops as soon as `limit`
    matches have been found, so the rest of the dataset is not touched.
    """
    found = []
    if limit <= 0:
        return found
    for index in range(len(data)):
        item = data[index]  # single __getitem__ call per element
        if item[1] == label:
            found.append((index, item[0]))
            if len(found) == limit:
                break
    return found


samples = _first_samples_with_label(dataset, kwargs["fixed_label"], n)


def get_batch(split: str):
    """Build one (x, y) batch from a randomly chosen entry of `samples`.

    The first 80% of `samples` form the training pool and the remainder the
    validation pool. One sample is drawn uniformly from the requested pool;
    `x` is its first `model_config.block_size` values and `y` is the same
    window shifted by one position. Every row of the batch holds that same
    window.

    Args:
        split: "train" draws from the training pool; any other value draws
            from the validation pool.

    Returns:
        Tuple (x, y) moved to `config.device`, each with a trailing singleton
        dimension added. Assuming each sample is a 1-D tensor, the shape is
        (config.batch_size, model_config.block_size, 1).

    Raises:
        ValueError: if the requested pool contains no samples.
    """
    # Use the number of samples actually collected rather than the requested
    # count, so a short `samples` list cannot produce out-of-range indices.
    num_samples = len(samples)
    boundary = int(0.8 * num_samples)
    if split == "train":
        split_start, split_end = 0, boundary
    else:
        split_start, split_end = boundary, num_samples

    if split_start >= split_end:
        raise ValueError(
            f"No samples available for split '{split}' "
            f"({num_samples} samples in total)"
        )

    # Randomly select a sample index within the pool; convert to a plain int
    # before indexing the Python list.
    sample_index = torch.randint(split_start, split_end, (1,)).item()
    sample = samples[sample_index][1]

    # Every batch row starts at offset 0 of the selected sample.
    block_size = model_config.block_size
    starts = [0] * config.batch_size

    x = torch.stack([sample[s : s + block_size] for s in starts])
    # y is the same window shifted by one position.
    y = torch.stack([sample[s + 1 : s + 1 + block_size] for s in starts])

    # Add the trailing feature dimension and move both to the target device.
    x = x.unsqueeze(-1).to(config.device)
    y = y.unsqueeze(-1).to(config.device)
    return x, y

In [None]:
# Start training: pass the batch sampler together with the training and model
# configurations to the project's train() entry point.
training_regression_transformer.train(get_batch, config, model_config)