In [1]:
# Enable automatic reloading of edited modules before each cell runs
%load_ext autoreload
%autoreload 2

In [2]:
# Log the Python version for reproducibility
import sys
print(sys.version)

3.10.14 (main, May 26 2024, 13:34:58) [GCC 13.2.1 20230801]


In [2]:
from training import naive_approach
from networks.regression_transformer import RegressionTransformerConfig, RegressionTransformer

from data.neural_field_datasets import DWSNetsDataset, MnistNeFDataset, FlattenTransform, MinMaxTransform

import os
import torch
import torchinfo

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Report whether PyTorch can see a CUDA device in this environment.
torch.cuda.is_available()

True

In [9]:
# Data loading: dataset locations are resolved relative to the parent of the
# current working directory.
cwd = os.path.abspath(os.getcwd())
dir_path = os.path.dirname(cwd)

# Both datasets live under the same "<parent>/adl4cv/datasets" directory.
datasets_dir = os.path.join(dir_path, "adl4cv", "datasets")
data_root_dwsnet = os.path.join(datasets_dir, "DWSNets", "mnist-inrs")
data_root_ours = os.path.join(datasets_dir, "mnist-nerfs")

class FlattenMinMaxTransform(torch.nn.Module):
    """Apply ``FlattenTransform`` followed by ``MinMaxTransform`` to an item.

    Args:
        min_max: Optional arguments unpacked positionally into
            ``MinMaxTransform``. When omitted (or otherwise falsy),
            ``MinMaxTransform`` is constructed with no arguments.
    """

    def __init__(self, min_max: tuple = None):
        super().__init__()
        self.flatten = FlattenTransform()
        # A falsy ``min_max`` unpacks to nothing, i.e. ``MinMaxTransform()``.
        self.minmax = MinMaxTransform(*(min_max or ()))

    def forward(self, x, y):
        """Run both transforms on ``x`` and return it with the ``y`` passed in.

        The second value returned by each transform is discarded, so the
        caller gets back exactly the ``y`` it supplied.
        """
        flattened, _ = self.flatten(x, y)
        scaled, _ = self.minmax(flattened, y)
        return scaled, y


# DWSNets views: flattened + min-max scaled, flattened only, and untransformed.
# NOTE(review): all three names are rebound to MnistNeFDataset objects further
# down in this cell, so these instances are discarded - confirm they are needed.
dataset = DWSNetsDataset(data_root_dwsnet, transform=FlattenMinMaxTransform())
dataset_wo_min_max = DWSNetsDataset(data_root_dwsnet, transform=FlattenTransform())
dataset_no_transform = DWSNetsDataset(data_root_dwsnet)

# Keyword arguments forwarded unchanged to every MnistNeFDataset below.
kwargs = dict(type="pretrained", fixed_label=5)

# The flattened-only view is built first because its min/max values
# parameterize the scaled view.
dataset_wo_min_max = MnistNeFDataset(
    data_root_ours,
    transform=FlattenTransform(),
    **kwargs,
)
min_ours, max_ours = dataset_wo_min_max.min_max()

dataset = MnistNeFDataset(
    data_root_ours,
    transform=FlattenMinMaxTransform((min_ours, max_ours)),
    **kwargs,
)
dataset_no_transform = MnistNeFDataset(data_root_ours, **kwargs)

In [10]:
# Number of items in the dataset built in the data-loading cell.
len(dataset)

409

In [7]:
# Training configuration
config = naive_approach.Config()
config.batch_size = 1
config.detailed_folder = "training_sample_5"

# Optimizer settings
config.learning_rate = 5e-4
config.weight_decay = 0

# Learning-rate schedule: decay spans the whole run, warm-up is 10% of it
config.max_iters = 14000
config.decay_lr = True
config.lr_decay_iters = config.max_iters
config.warmup_iters = 0.1 * config.max_iters

# Transformer configuration
# block_size is one less than the length of a flattened item, so an item can
# be split into an input window and the same window shifted by one position.
model_config = RegressionTransformerConfig(
    n_embd=32,
    n_head=8,
    n_layer=16,
    block_size=len(dataset[0][0]) - 1,
)

In [8]:
# Take the first n items whose label (second entry of a dataset item) equals 5.
# Each entry of `samples` is (dataset index, first entry of the item).
n = 5

samples = []
for i in range(len(dataset)):
    # Stop as soon as enough matches are collected instead of loading the
    # whole dataset; also covers n == 0.
    if len(samples) >= n:
        break
    item = dataset[i]  # load each item once, not once per field
    if item[1] == 5:
        samples.append((i, item[0]))


def get_batch(split: str):
    """Build one (x, y) batch for next-value prediction from `samples`.

    The first 80% of `samples` form the training split and the remainder the
    validation split. For every batch row one sample is drawn uniformly from
    the requested split; `x` is its first `block_size` values and `y` is the
    same window shifted by one position.

    Args:
        split: "train" selects the training split; any other value selects
            the validation split.

    Returns:
        Tuple `(x, y)`, each with a trailing singleton dimension added and
        moved to `config.device`.

    Raises:
        ValueError: if the requested split contains no samples.
    """
    # Derive the split boundary from the samples actually collected, so a
    # short `samples` list cannot produce an out-of-range index.
    n_samples = len(samples)
    boundary = int(0.8 * n_samples)
    lo, hi = (0, boundary) if split == "train" else (boundary, n_samples)
    if lo >= hi:
        raise ValueError(
            f"split {split!r} is empty: only {n_samples} sample(s) available"
        )

    # One independent draw per batch row. With batch_size == 1 this is the
    # same single draw as before; larger batches no longer repeat one sample.
    picks = torch.randint(lo, hi, (config.batch_size,))
    rows = [samples[int(p)][1] for p in picks]

    block = model_config.block_size
    x = torch.stack([row[:block] for row in rows])
    y = torch.stack([row[1 : 1 + block] for row in rows])

    # Add the trailing feature dimension and move to the training device.
    x = x.unsqueeze(-1).to(config.device)
    y = y.unsqueeze(-1).to(config.device)
    return x, y

In [9]:
# Prepare model parameters and train: the batch sampler plus the training and
# transformer configs are handed to the training routine.
naive_approach.train(get_batch, config, model_config)

Initializing a new model from scratch
num decayed parameter tensors: 67, with 215,616 parameters
num non-decayed parameter tensors: 131, with 6,752 parameters
using fused AdamW: True


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


step 0: train loss 0.9639, val loss 0.9633
saving checkpoint to models
step 100: train loss 0.3275, val loss 0.3255
saving checkpoint to models
saving checkpoint to models
step 200: train loss 0.0824, val loss 0.0804
saving checkpoint to models
saving checkpoint to models
step 300: train loss 0.0560, val loss 0.0556
saving checkpoint to models
saving checkpoint to models
step 400: train loss 0.0408, val loss 0.0415
saving checkpoint to models
saving checkpoint to models
step 500: train loss 0.0314, val loss 0.0323
saving checkpoint to models
saving checkpoint to models
step 600: train loss 0.0264, val loss 0.0290
saving checkpoint to models
saving checkpoint to models
step 700: train loss 0.0290, val loss 0.0302
saving checkpoint to models
step 800: train loss 0.0283, val loss 0.0292
saving checkpoint to models
step 900: train loss 0.0238, val loss 0.0246
saving checkpoint to models
saving checkpoint to models
step 1000: train loss 0.0249, val loss 0.0255
saving checkpoint to models
st

In [None]:
import matplotlib.pyplot as plt
import torch
from animation.util import backtransform_weights
from data.neural_field_datasets import MinMaxTransform  

# Names this cell expects from earlier cells:
# - dataset, dataset_no_transform
#   (assumes both use the same index order - TODO confirm)
# - model_config, RegressionTransformer
# - min_ours, max_ours

# Configuration
inr_kwargs = {"n_layers": 3, "in_dim": 2, "up_scale": 16}
image_size = (28, 28)
idx = 0

# Untransformed item; its first entry is the reference passed to
# backtransform_weights below.
original_dict = dataset_no_transform[idx][0]

# Restore the trained transformer. The model is built on the CPU, so the
# checkpoint is mapped to the CPU as well; this also lets a checkpoint saved
# on a GPU load on a machine without one.
model = RegressionTransformer(model_config)
checkpoint = torch.load("./models/final_model.pth", map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.eval()  # inference mode: turns off train-only layers such as dropout

sample = dataset[idx][0]
X = sample[: model_config.block_size].unsqueeze(-1).unsqueeze(0)
Y = sample[1 : 1 + model_config.block_size].unsqueeze(-1).unsqueeze(0)

# Autoregressive process: seed with the first true value, then fill position
# i + 1 with the prediction made at position i.
# The length is block_size + 1 (previously hard-coded as 593) so it follows
# the model configuration.
seq = torch.zeros((1, model_config.block_size + 1, 1))
seq[0][0][0] = X[0][0][0]

# no_grad keeps the in-place writes into `seq` from building an autograd
# graph that grows with every iteration.
with torch.no_grad():
    for i in range(model_config.block_size):
        # Y is still passed so the call is identical to the original one;
        # the returned loss is not used.
        pred, _loss = model(seq[:, :-1], Y)
        seq[0][i + 1][0] = pred[0][i][0]

# Undo the min-max scaling with the same bounds the dataset transform used.
minmax_transformer = MinMaxTransform(min_value=min_ours, max_value=max_ours)
dataset_ele_flattened = minmax_transformer.reverse(seq[0]).unsqueeze(0)

# Backtransform weights
reconstructed_dict = backtransform_weights(dataset_ele_flattened, original_dict)

from animation.util import reconstruct_image
from networks.mlp_models import MLP3D


# Constructor arguments for the MLP3D that the weight dicts are loaded into
model_config_nef = dict(
    out_size=1,
    hidden_neurons=[16, 16],
    use_leaky_relu=False,
    output_type="logits",
    input_dims=2,
    multires=4,
)

# One MLP3D instance is reused for both renders; only its weights change.
model = MLP3D(**model_config_nef)

# Render from the reconstructed weights first ...
model.load_state_dict(reconstructed_dict)
reconstructed_tensor = reconstruct_image(model)

# ... then from the original weights, which stay loaded afterwards.
model.load_state_dict(original_dict)
ground_truth_tensor = reconstruct_image(model)

# Plot ground truth and reconstruction side by side as grayscale heatmaps
fig, axes = plt.subplots(1, 2, figsize=(20, 10))

panels = (
    (axes[0], ground_truth_tensor, 'Ground Truth'),
    (axes[1], reconstructed_tensor, 'Reconstructed'),
)
for ax, image, title in panels:
    # Draw each image once and reuse the returned handle for its colorbar,
    # instead of calling imshow a second time just to get a handle.
    im = ax.imshow(image, cmap='gray', aspect='auto')
    ax.set_title(title)
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')
    fig.colorbar(im, ax=ax)

plt.show()