In [1]:
# Enable autoreload of module
%load_ext autoreload
%autoreload 2

In [2]:
from data.neural_field_datasets_shapenet import FlattenTransform3D, ShapeNetDataset
# load dataset
dataset = ShapeNetDataset("./datasets/plane_mlp_weights")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from collections import OrderedDict
from os import listdir
from os.path import isfile, join

import torch

from utils import get_default_device


def get_weight_vector(state_dict: OrderedDict, layer: int, neuron: int) -> torch.Tensor:
    return state_dict[f"layers.{layer}.weight"][:, neuron]

def get_bias(state_dict: OrderedDict, layer: int, neuron: int) -> torch.Tensor:
    return  state_dict[f"layers.{layer}.bias"][neuron]

def get_all_weights(files: list, layer: int, neuron: int) -> torch.Tensor:
    all_weights = torch.stack([get_weight_vector(mlp3d[0], layer, neuron) for mlp3d in dataset])
    return all_weights



def get_all_biases(dataset: ShapeNetDataset, layer: int, neuron: int) -> torch.Tensor:
    all_biases = torch.stack([get_bias(mlp3d[0], layer, neuron) for mlp3d in dataset])



In [4]:
len(dataset)

4045

In [6]:
compute_weights = True
if compute_weights:
    import h5py

    all_weights_layer_0 = torch.zeros(27, 4045, 128)
    all_weights_layer_1 = torch.zeros(128, 4045, 128)
    all_weights_layer_2 = torch.zeros(128, 4045, 128)
    all_weights_layer_3 = torch.zeros(128, 4045, 1)


    for i in range(27):
        all_weights_layer_0[i] = get_all_weights(dataset, 0, i)

    for i in range(128):
        all_weights_layer_1[i] = get_all_weights(dataset, 1, i)

    for i in range(128):
        all_weights_layer_2[i] = get_all_weights(dataset, 2, i)

    for i in range(128):
        all_weights_layer_3[i] = get_all_weights(dataset, 3, i)
        
    all_weights = torch.cat((all_weights_layer_0.view(-1, 128), all_weights_layer_1.view(-1, 128), all_weights_layer_2.view(-1, 128), all_weights_layer_3.view(-1, 128)))

    # Save the tensor to an HDF5 file
    with h5py.File('datasets/plane_mlp_weights.h5', 'w') as f:
        f.create_dataset('dataset', data=all_weights)




In [25]:
import torch
from torch.utils.data import Dataset, DataLoader
import h5py

class HDF5Dataset(Dataset):
    def __init__(self, hdf5_file, dataset_name):
        self.hdf5_file = hdf5_file
        self.dataset_name = dataset_name
        with h5py.File(self.hdf5_file, 'r') as f:
            self.dataset_length = f[self.dataset_name].shape[0]
    
    def __len__(self):
        return self.dataset_length
    
    def __getitem__(self, idx):
        with h5py.File(self.hdf5_file, 'r') as f:
            data = f[self.dataset_name][idx]
            return torch.tensor(data)

# Instantiate the dataset and dataloader
hdf5_dataset = HDF5Dataset('datasets/plane_mlp_weights.h5', 'dataset')

In [26]:
batch_size = 2**14
train_dataloader = DataLoader(hdf5_dataset, batch_size=batch_size, shuffle=True)

In [27]:
from networks.vq_ae import VQAutoencoderConfig
from training.training_autoencoder import train_model, TrainingConfig

from torch import nn
from torch.nn import MSELoss, CrossEntropyLoss, L1Loss
from torch.utils.data import DataLoader

import wandb

wandb.login()

batch_size = 2**14
train_dataloader = DataLoader(hdf5_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = train_dataloader#DataLoader([test_dataset[0][0][0]], batch_size=1, shuffle=True)

train_config = TrainingConfig()
train_config.max_iters = 10
train_config.always_save_checkpoint = True
train_config.weight_decay = 0.0
train_config.learning_rate = 5e-3
train_config.lr_decay_iters = train_config.max_iters*len(train_dataloader)
train_config.warmup_iters = 0.05*train_config.max_iters*len(train_dataloader)
train_config.log_interval = 1
train_config.out_dir = "models/vq_ae"

vq_config = VQAutoencoderConfig()
vq_config.dim_enc = (128, 64, 32)
vq_config.dim_dec = (32, 64, 128)
vq_config.with_vq = False


ae_trained = train_model(train_config, vq_config, train_dataloader, test_dataloader, L1Loss())



0,1
batch,▂▃▄▆▇▁▃▄▅▇▁▂▄▅▇█▂▄▄▆█▂▃▅▆▇▁▃▅▅▇▂▂▄▆▇█▂▄▆
epoch,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss,██▆▆▅▅▄▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁▂▄▅▆████████████████▇▇▇▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▆

0,1
batch,52.0
epoch,7.0
loss,0.23942
lr,0.0036


                                                                         

Checkpoint saved: models/vq_ae/model_epoch_0.pth


                                                                         

Checkpoint saved: models/vq_ae/model_epoch_1.pth


                                                                         

Checkpoint saved: models/vq_ae/model_epoch_2.pth


                                                                         

Checkpoint saved: models/vq_ae/model_epoch_3.pth


                                                                         

Checkpoint saved: models/vq_ae/model_epoch_4.pth


                                                                         

KeyboardInterrupt: 

wandb: Network error (ConnectionError), entering retry loop.


In [12]:
from networks.vq_ae import VQAutoencoder


vq_ae = VQAutoencoder(vq_config)
vq_ae_dict = torch.load("models/vq_ae/model_epoch_6.pth")["model_state_dict"]
vq_ae.load_state_dict(vq_ae_dict)


for batch in train_dataloader:
    X = batch
    Y = vq_ae(X)
    error = X-Y
    break
    


In [19]:
error.abs()

tensor([[0.1078, 0.0341, 0.0054,  ..., 0.0309, 0.0983, 0.3144],
        [0.0984, 0.2903, 0.0172,  ..., 0.3958, 0.7926, 1.1895],
        [0.0341, 0.0267, 0.0785,  ..., 0.0658, 0.0469, 0.0328],
        ...,
        [0.1093, 0.0593, 0.0043,  ..., 0.0463, 0.0622, 0.0515],
        [0.3041, 0.0401, 0.0287,  ..., 0.7800, 0.6764, 3.0190],
        [0.2775, 1.1532, 0.0257,  ..., 0.7382, 0.0043, 0.2169]],
       grad_fn=<AbsBackward0>)