In [1]:
# Enable the autoreload extension in mode 2, so imported modules are reloaded
# before each cell executes and edits to the project's .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from torchinfo import summary
import torch
import os

from data.DWSNets_dataset import DWSNetsDataset, LayerOneHotTransform, FlattenTransform, BiasFlagTransform
from networks.naive_rq_ae import RQAutoencoder, RQAutoencoderConfig
from collections import OrderedDict

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

# The dataset is looked up under <parent of the working directory>/adl4cv/datasets/DWSNets/mnist-inrs.
working_dir = os.path.abspath(os.getcwd())
dir_path = os.path.dirname(working_dir)
data_root = os.path.join(dir_path, *("adl4cv", "datasets", "DWSNets", "mnist-inrs"))

class AutoencoderTransform(torch.nn.Module):
  """Combine the three DWSNets transforms into one autoencoder input.

  The bias-flag, layer-one-hot and flatten transforms are each applied to the
  same incoming sample; their first return values are stacked column-wise in
  the order (flattened values, layer one-hot, bias flag). The label `y` is
  returned as received.
  """

  def __init__(self):
    super().__init__()
    self.bias = BiasFlagTransform()
    self.flatten = FlattenTransform()
    self.layer_one_hot = LayerOneHotTransform()

  def forward(self, x, y):
    """Return `(stacked_features, y)` for one `(x, y)` sample."""
    # Every transform sees the original `x`; only the first element of each
    # returned pair is kept. The call order matches the original code.
    bias_flags = self.bias(x, y)[0]
    layer_codes = self.layer_one_hot(x, y)[0]
    flat_values = self.flatten(x, y)[0]
    stacked = torch.hstack((flat_values, layer_codes, bias_flags))
    return stacked, y


# Dataset without any transform (no `split` argument is passed, so the
# constructor's default split applies — presumably the train split; confirm).
dataset_no_transform = DWSNetsDataset(data_root)
# Transformed datasets; each one gets its own AutoencoderTransform instance.
train_dataset = DWSNetsDataset(data_root, transform=AutoencoderTransform())
test_dataset = DWSNetsDataset(data_root, split="test", transform=AutoencoderTransform())

# NOTE(review): `path` is not referenced by any other visible cell — confirm it is still needed.
path = "datasets/DWSNets/mnist-inrs/mnist_splits.json"


In [4]:
# Build an (untrained) autoencoder from its config.
ae_config = RQAutoencoderConfig(dim_l=(5, 5, 5))
ae = RQAutoencoder(ae_config)

# Fetch the sample once so the printed input is exactly the tensor that is fed
# to the model, instead of loading and transforming the same item twice.
sample_idx = 203
sample_input = train_dataset[sample_idx][0]
print(sample_input)
print(ae(sample_input))

tensor([[-0.0784,  1.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0276,  1.0000,  0.0000,  0.0000,  0.0000],
        [-0.0030,  1.0000,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.0686,  0.0000,  0.0000,  1.0000,  0.0000],
        [ 0.0870,  0.0000,  0.0000,  1.0000,  0.0000],
        [-0.1302,  0.0000,  0.0000,  1.0000,  1.0000]])
tensor([[ 0.3800, -0.2634,  0.4970,  0.1051,  0.3101],
        [ 0.3809, -0.2645,  0.4972,  0.1039,  0.3103],
        [ 0.3808, -0.2642,  0.4972,  0.1042,  0.3102],
        ...,
        [ 0.3828, -0.2663,  0.4978,  0.1019,  0.3105],
        [ 0.3844, -0.2680,  0.4983,  0.1000,  0.3107],
        [ 0.3863, -0.2706,  0.4971,  0.0986,  0.3117]],
       grad_fn=<AddmmBackward0>)


In [14]:
# Show the first training sample as a single 1-D vector. The former trailing
# `.reshape()` had no target shape and cannot succeed; `.flatten()` alone
# yields the 1-D tensor shown in the recorded output.
train_dataset[0][0].flatten()

tensor([-0.0041,  1.0000,  0.0000,  ...,  0.0000,  1.0000,  1.0000])

In [18]:
# Shape of the first transformed training sample (rows x feature columns).
train_dataset[0][0].shape

torch.Size([1185, 5])

In [5]:
from training.training_autoencoder import train_model, TrainingConfig

from torch import nn
from torch.nn import MSELoss, CrossEntropyLoss
from torch.utils.data import DataLoader

# Each loader wraps a one-element list, so only the first train / test sample
# is ever seen (presumably a deliberate single-sample overfitting check — confirm).
train_dataloader = DataLoader([train_dataset[0]], batch_size=1, shuffle=True)
test_dataloader = DataLoader([test_dataset[0]], batch_size=1, shuffle=True)

# Start from the default TrainingConfig and override the fields below.
train_config = TrainingConfig()
train_config.max_iters = 5000
train_config.always_save_checkpoint = True
train_config.weight_decay = 0.0
train_config.learning_rate = 1e-3
# Set to the same value as max_iters.
train_config.lr_decay_iters = 5000
train_config.log_interval = 1

class AutoencoderLoss(nn.Module):
  """Composite reconstruction loss for rows of `[value, one-hot(3), bias flag]`.

  The loss is the sum of three terms:
    * MSE on column 0 (the continuous value),
    * cross entropy on columns 1-3 (class logits vs. one-hot target),
    * binary cross entropy with logits on column 4 (the bias flag).

  Args (forward):
    pred: tensor of shape (N, 5) holding raw model outputs (logits for
      columns 1-4).
    true: float tensor of shape (N, 5) holding the targets.

  Returns:
    Scalar tensor with the summed loss.
  """

  def __init__(self):
    super().__init__()
    self.mse = MSELoss()
    self.ce = CrossEntropyLoss()
    # The bias flag is a single 0/1 value per row, so it needs a binary loss.
    # CrossEntropyLoss on a 1-D column would softmax across the rows instead.
    self.bce = nn.BCEWithLogitsLoss()

  def forward(self, pred, true):
    # take first column of pred and true and compare them with mse
    x_pred = pred[:, :1]
    x_true = true[:, :1]
    mse_loss = self.mse(x_pred, x_true)

    # compare one hot encoding (eg feature 1, 2, 3) with cross entropy loss
    layer_pred = pred[:, 1:4]
    layer_true = true[:, 1:4]
    ce_loss = self.ce(layer_pred, layer_true)

    # compare bias flag: independent binary decision per row
    bias_pred = pred[:, 4]
    bias_true = true[:, 4]
    bias_loss = self.bce(bias_pred, bias_true)

    return mse_loss + ce_loss + bias_loss

# NOTE(review): a plain MSELoss() is passed here; the AutoencoderLoss class
# defined in this cell is never used — confirm this is intended.
train_model(train_config, ae_config, train_dataloader, test_dataloader, MSELoss())

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mluis-muschal[0m ([33madl-for-cv[0m). Use [1m`wandb login --relogin`[0m to force relogin


                                                                       

0,1
batch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,██▇▆▄▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▂▃▅▆██████▇▇▇▇▇▆▆▆▆▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁

0,1
batch,1.0
epoch,4999.0
loss,0.00014
lr,0.0


In [6]:
PATH = "./models/model_epoch_0.pth"
ae_trained = RQAutoencoder(ae_config)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a machine
# without CUDA. NOTE: torch.load unpickles the file — only load checkpoints
# from a trusted source.
checkpoint = torch.load(PATH, map_location="cpu")
ae_trained.load_state_dict(checkpoint["model_state_dict"])

<All keys matched successfully>

In [7]:
# Print the transformed input of the first training sample.
print(train_dataset[0][0])

tensor([[-0.0041,  1.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0303,  1.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0380,  1.0000,  0.0000,  0.0000,  0.0000],
        ...,
        [-0.1311,  0.0000,  0.0000,  1.0000,  0.0000],
        [ 0.1318,  0.0000,  0.0000,  1.0000,  0.0000],
        [-0.1846,  0.0000,  0.0000,  1.0000,  1.0000]])


In [8]:
# Print the output of the loaded model for the first training sample.
print(ae_trained(train_dataset[0][0]))

tensor([[-0.2182,  1.8510, -0.2103, -0.2100, -0.2099],
        [-0.2182,  1.8485, -0.2101, -0.2091, -0.2099],
        [-0.2182,  1.8479, -0.2100, -0.2089, -0.2099],
        ...,
        [-0.2441,  1.6777, -0.2271, -0.1919, -0.1876],
        [-0.2441,  1.6609, -0.2250, -0.1852, -0.1880],
        [-0.2556,  1.6280, -0.2380, -0.1948, -0.1770]],
       grad_fn=<AddmmBackward0>)


In [11]:
loss = MSELoss()
idx = 0
# Load the sample and run the trained model once; the row count used for
# normalisation now comes from the same sample `idx` instead of a second
# forward pass on hard-coded index 0.
eval_input = train_dataset[idx][0]
trained_output = ae_trained(eval_input)
# NOTE(review): this compares the trained model against the *untrained* `ae`
# output, not against the input itself — confirm this is the intended metric.
loss(trained_output, ae(eval_input)) / len(trained_output)

tensor(0.0007, grad_fn=<DivBackward0>)

In [10]:
# Encode sample `idx` (defined in another cell) with `encode_to_cb`, which
# returns three values, then decode the first of them again.
# NOTE(review): "cb" presumably stands for codebook, i.e. `x` would be the
# quantised latent — confirm against RQAutoencoder.
x, indices, commit_loss = ae_trained.encode_to_cb(train_dataset[idx][0])
ae_trained.decode(x)

tensor([[-0.1976,  1.8893, -0.2167, -0.2187, -0.2080],
        [-0.1976,  1.8893, -0.2167, -0.2187, -0.2080],
        [-0.1976,  1.8893, -0.2167, -0.2187, -0.2080],
        ...,
        [-0.1402,  1.7051, -0.1938, -0.1777, -0.2014],
        [-0.1720,  1.7908, -0.2006, -0.2040, -0.2080],
        [-0.1041,  1.5429, -0.1611, -0.1606, -0.2079]],
       grad_fn=<AddmmBackward0>)

In [None]:
def backtransform_weights(flattened_weights, original_weights_dict):
    """Split a flat weight vector back into per-layer tensors.

    The vector is consumed front to back in the iteration order of
    `original_weights_dict`; each entry takes as many elements as the
    corresponding original tensor holds and is reshaped to that tensor's shape.

    Args:
        flattened_weights: 1-D tensor containing all weights concatenated.
        original_weights_dict: mapping of name -> tensor that provides the
            target shapes (its values are not modified).

    Returns:
        OrderedDict with the same keys and order as `original_weights_dict`,
        holding the reshaped slices of `flattened_weights`.

    Raises:
        ValueError: if `flattened_weights` is not 1-D, or holds fewer elements
            than the original tensors require.
    """
    # A 2-D input (e.g. the full model output instead of a single column)
    # would otherwise be sliced along rows and fail later with an opaque
    # "shape ... is invalid for input of size ..." error.
    if flattened_weights.dim() != 1:
        raise ValueError(
            "flattened_weights must be 1-D, got shape "
            f"{tuple(flattened_weights.shape)}"
        )

    required = sum(tensor.numel() for tensor in original_weights_dict.values())
    if flattened_weights.numel() < required:
        raise ValueError(
            f"flattened_weights has {flattened_weights.numel()} elements, "
            f"but the original tensors require {required}"
        )

    reconstructed_dict = OrderedDict()
    start = 0
    for key, tensor in original_weights_dict.items():
        stop = start + tensor.numel()
        # reshape (rather than view) also accepts slices that are not contiguous.
        reconstructed_dict[key] = flattened_weights[start:stop].reshape(tensor.shape)
        start = stop

    return reconstructed_dict


# Rebuild a per-layer weight dict for sample 1 from the trained model's output.
idx = 1
# Untransformed sample: provides the original tensors whose shapes are reused.
dataset_ele = dataset_no_transform[idx][0]
# Keep only column 0 of the model output, giving a 1-D vector.
# NOTE(review): the RuntimeError recorded below this cell (size 320 for shape
# [32, 2]) presumably stems from an earlier run that passed the full
# 5-column output instead — confirm by re-running.
dataset_ele_flattened= ae_trained(train_dataset[idx][0])[:, 0]

reconstructed_dict = backtransform_weights(dataset_ele_flattened, dataset_ele)


RuntimeError: shape '[32, 2]' is invalid for input of size 320