In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import torch
import torch.nn as nn
from data.neural_field_datasets_shapenet import AllWeights3D, FlattenTransform3D, ImageTransform3D, ModelTransform3D, ShapeNetDataset
from torch.utils.data import DataLoader
from networks.shapenet_ae import GlobalAutoencoder

from networks.shapenet_ae import VanillaDecoder, VanillaEncoder, PositionEncoder
device = torch.device("cuda")



image_transform = ImageTransform3D()

dataset_weights = ShapeNetDataset(os.path.join("./", "datasets", "shapenet_nefs", "pretrained"), transform=image_transform)
dataset = ShapeNetDataset(os.path.join("./", "datasets", "shapenet_nefs", "pretrained"))

model_transform = ModelTransform3D(dataset[0][0]["model_config"])


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
import torch

# Assuming tensor of shape [3482, 116, 32]
tensor = torch.cat([weights[0] for weights in dataset_weights])

tensor = tensor.permute(0, 2, 1)
tensor_centered = (tensor - tensor.mean(dim=0).mean(dim=0))

num_nef, num_vec, dim_vec = tensor.shape
pos_enc = torch.Tensor([i for _ in range(num_nef) for i in range(num_vec)]).view(num_nef, num_vec, 1).to(device="cuda")
tensor_w_pos = torch.cat((tensor_centered, pos_enc), dim =2)

# Define the split ratio
split_ratio = 0.8
num_train = int(split_ratio * tensor_centered.shape[0])

# Generate random indices
indices = torch.randperm(tensor_centered.shape[0])

# Split the indices for training and validation
train_indices = indices[:num_train]
val_indices = indices[num_train:]

# Index the tensor to create training and validation sets
train_data = tensor_w_pos[train_indices]
val_data = tensor_w_pos[val_indices]

print(f'Training data shape: {train_data.shape}')
print(f'Validation data shape: {val_data.shape}')


Training data shape: torch.Size([2785, 116, 33])
Validation data shape: torch.Size([697, 116, 33])


In [11]:
from networks.nano_gpt_adapted import TransformerDecoder, TransformerDecoderConfig
from networks.shapenet_ae import construct_pos_decoder, construct_pos_encoder

latent_space = 16
    

encoder_layer_1 = TransformerDecoderConfig(input_size=32, n_embd=256, block_size=num_vec, n_head=16, n_layer=3, output_size=latent_space, vocab_size=2048, dropout=0.0)
decoder_layer_1 = TransformerDecoderConfig(input_size= latent_space, n_embd=256, block_size=num_vec, n_head=16, n_layer=2, output_size=32, vocab_size=2048, dropout=0.0)

# encoder level
pos_emb_enc = 512
feature_extractor_1 = construct_pos_encoder(input_dim=latent_space, emb_dim=pos_emb_enc, latent_dim=16, num_layers_enc=2, num_vec=num_vec)
encoder_1 = TransformerDecoder(encoder_layer_1)
encoder_2 = construct_pos_encoder(input_dim=32, emb_dim=pos_emb_enc, latent_dim=16, num_layers_enc=3, num_vec=num_vec)

# decoder level
pos_emb_dec = 1024
decoder_1 = TransformerDecoder(decoder_layer_1)
upscaler_1, decoder_2= construct_pos_decoder(input_dim=16, output_dim=32, emb_dim=pos_emb_dec, num_layers_dec=2, num_vec=num_vec)

blocks = [feature_extractor_1,encoder_1, encoder_2, decoder_1, decoder_2, upscaler_1]

params = []
for element in [feature_extractor_1,encoder_1, encoder_2, decoder_1, decoder_2, upscaler_1]:
    params = params + list(element.parameters())

for element in blocks:
    element.to(device)

        
def add_pos_enc(x, enc):
    return torch.cat((x, enc), dim = 2).to(device)

def compute_output(x):
    pos = x[:, :, -1:]
    x = x[:, :, :-1]
    
    # add positional encoding extract features from it
    #x = feature_extractor_1(add_pos_enc(x, pos)) + x
        
    # encode information
    x = encoder_1(x)
    #x = encoder_2(add_pos_enc(x, pos))
    
    
    # decode information
    x = decoder_1(x)
    #x = upscaler_1(x)
    #x = decoder_2(add_pos_enc(x, pos))

    return x


number of parameters: 3.49M
number of parameters: 2.67M


In [12]:
tensor_centered[:, 0, :].min()

tensor(-9.7245, device='cuda:0')

In [13]:
torch.nn.functional.mse_loss(tensor_centered[:, 0, 2], torch.zeros_like(tensor_centered[:, 0, 2]))

tensor(0.3862, device='cuda:0')

In [15]:
from networks.nano_gpt_adapted import TransformerDecoder
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
import wandb
import math

from networks.nano_gpt import GPTConfig
from networks.shapenet_ae import VanillaAutoencoder



def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return 0.0
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (
        lr_decay_iters - warmup_iters
    )
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return coeff * (learning_rate)

def interpolate_loss(iter, max_iter, loss_1, loss_2, max_loss_2 = 1.0):
    
    split_1, split_2 =  1 - iter / max_iter, iter / max_iter
    
    if split_2 > max_loss_2:
        split_1 = 1 - max_loss_2
        split_2 = max_loss_2
    
    return loss_1 * split_1, loss_2 * split_2, (split_1, split_2)



lr_decay_iters = 10000
warmup_iters = 0.0125 * lr_decay_iters
learning_rate = 0.00075

eval_iters = 4



wandb.init(project="autoencoder", name=f"Autoencoder Transformer")



criterion = nn.MSELoss()
optimizer = optim.Adam(params, lr=learning_rate)

# Training the autoencoder
batch_size = 256
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)


exp_avg_loss = None

iter = 0
epoch = 0

while True:
    batch_bat = tqdm(train_dataloader)
    for batch in batch_bat:
        if iter%eval_iters == 0:
            for val_batch in val_dataloader:
                val_batch.to(device)
                with torch.no_grad():
                    reconstructed = compute_output(val_batch)
                    val_loss = criterion(reconstructed, val_batch[:, :, :-1])
                break               
        batch.to(device)
        lr = get_lr(iter)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        optimizer.zero_grad()
        
        reconstructed = compute_output(batch)
        loss = criterion(reconstructed, batch[:, :, :-1])
        loss.backward()
        optimizer.step()
        if exp_avg_loss == None:
            exp_avg_loss = loss
        exp_avg_loss = 0.95*exp_avg_loss + 0.05*loss
        batch_bat.set_description(f"Avg. Loss {exp_avg_loss}")
        wandb.log({"iter": iter, "loss": loss.item(), "val_loss": val_loss, "epoch": epoch, "lr": lr})
        iter += 1
        
        if iter > lr_decay_iters:
            break
    
    epoch += 1
        
    if iter > lr_decay_iters:
        break





0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
iter,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇███
loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▁▃▅▇████████████████████████████████████
val_loss,█▄▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,122.0
iter,1350.0
loss,0.16954
lr,0.00072
val_loss,0.16004


  0%|          | 0/11 [00:00<?, ?it/s]

In [None]:
autoencoder

VanillaAutoencoder(
  (encoder): PositionEncoder(
    (emb): Embedding(116, 256)
    (encoder): Sequential(
      (0): Linear(in_features=288, out_features=197, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=197, out_features=107, bias=True)
      (3): GELU(approximate='none')
      (4): Linear(in_features=107, out_features=16, bias=True)
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (upscaling): Linear(in_features=16, out_features=256, bias=True)
  (decoder): PositionEncoder(
    (emb): Embedding(116, 256)
    (encoder): Sequential(
      (0): Linear(in_features=512, out_features=272, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=272, out_features=32, bias=True)
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [None]:
for key in dataset_weights[0][0]["state_dict"].keys():
    print(key)

IndexError: too many indices for tensor of dimension 3