In [8]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [9]:
import os

import torch
from data.neural_field_datasets_shapenet import AllWeights3D, FlattenTransform3D, ImageTransform3D, ModelTransform3D, ShapeNetDataset
from torch.utils.data import DataLoader
from networks.shapenet_ae import GlobalAutoencoder

from networks.shapenet_ae import VanillaDecoder, VanillaEncoder, PositionEncoder


image_transform = ImageTransform3D()

dataset_weights = ShapeNetDataset(os.path.join("./", "datasets", "shapenet_nefs", "pretrained"), transform=image_transform)
dataset = ShapeNetDataset(os.path.join("./", "datasets", "shapenet_nefs", "pretrained"))

model_transform = ModelTransform3D(dataset[0][0]["model_config"])

input_size = 3712
dimensions = [input_size, int(input_size//1.25), int(input_size//1.5), int(input_size//1.75), int(input_size//2)]
encoder = VanillaEncoder(dimensions)
dimensions.reverse()
decoder = VanillaDecoder(dimensions)



In [10]:
import torch

# Assuming tensor of shape [3482, 116, 32]
tensor = torch.cat([weights[0] for weights in dataset_weights])

tensor = tensor.permute(0, 2, 1)
tensor_centered = (tensor - tensor.mean(dim=0).mean(dim=0))

num_nef, num_vec, dim_vec = tensor.shape
pos_enc = torch.Tensor([i for _ in range(num_nef) for i in range(num_vec)]).view(num_nef, num_vec, 1).to(device="cuda")
tensor_w_pos = torch.cat((tensor_centered, pos_enc), dim =2)

# Define the split ratio
split_ratio = 0.8
num_train = int(split_ratio * tensor_centered.shape[0])

# Generate random indices
indices = torch.randperm(tensor_centered.shape[0])

# Split the indices for training and validation
train_indices = indices[:num_train]
val_indices = indices[num_train:]

# Index the tensor to create training and validation sets
train_data = tensor_w_pos[train_indices]
val_data = tensor_w_pos[val_indices]

print(f'Training data shape: {train_data.shape}')
print(f'Validation data shape: {val_data.shape}')


Training data shape: torch.Size([2785, 116, 33])
Validation data shape: torch.Size([697, 116, 33])


In [11]:
16*16

256

In [13]:
## Train Autoencoder


import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
import wandb
import math

from networks.nano_gpt import GPTConfig
from networks.nano_gpt_adapted import TransformerDecoder, TransformerDecoderConfig
from networks.shapenet_ae import VanillaAutoencoder



def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return 0.0
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (
        lr_decay_iters - warmup_iters
    )
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return coeff * (learning_rate)

def interpolate_loss(iter, max_iter, loss_1, loss_2, max_loss_2 = 1.0):
    
    split_1, split_2 =  1 - iter / max_iter, iter / max_iter
    
    if split_2 > max_loss_2:
        split_1 = 1 - max_loss_2
        split_2 = max_loss_2
    
    return loss_1 * split_1, loss_2 * split_2, (split_1, split_2)
    

device = torch.device("cuda")


v_ae_config = {
    "input_dim": 32,
    "latent_dim": 24,
    "num_layers_enc": 3,
    "num_layers_dec": 2,
    "emb_dim": 1024,
    "num_vec": num_vec,
    "local_latent_emb": 1024,
    "with_upscaling": True
}

model_config = TransformerDecoderConfig(n_embd=64, block_size=num_vec, n_head=8, n_layer=12, output_size=32, dropout=0.0)


autoencoder =VanillaAutoencoder(**v_ae_config)

autoencoder.to(device)


lr_decay_iters = 2000
warmup_iters = 0.05 * lr_decay_iters
learning_rate = 0.005

eval_iters = 4



wandb.init(project="autoencoder", name=f"Autoencoder Transformer")



criterion = nn.MSELoss()
optimizer = optim.Adam(list(autoencoder.parameters()) + list(decoder.parameters()), lr=learning_rate)

# Training the autoencoder
batch_size = 128
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)


exp_avg_loss = None

iter = 0
epoch = 0

while True:
    batch_bat = tqdm(train_dataloader)
    for batch in batch_bat:
        if iter%eval_iters == 0:
            for val_batch in val_dataloader:
                val_batch.to(device)
                autoencoder.eval()
                with torch.no_grad():
                    reconstructed_local, latent = autoencoder(val_batch)
                    local_loss = criterion(reconstructed_local[:, :, :-1], val_batch[:, :, :-1])
                    
                    val_loss = local_loss
                autoencoder.train()
        batch.to(device)
        lr = get_lr(iter)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        optimizer.zero_grad()
        
        
        reconstructed_local, latent = autoencoder(batch)
        local_loss = criterion(reconstructed_local[:, :, :-1], batch[:, :, :-1])

        loss = local_loss
        loss.backward()
        optimizer.step()
        if exp_avg_loss == None:
            exp_avg_loss = loss
        exp_avg_loss = 0.95*exp_avg_loss + 0.05*loss
        batch_bat.set_description(f"Avg. Loss {exp_avg_loss}")
        wandb.log({"iter": iter, "loss": loss.item(), "local_loss": local_loss, "val_loss": val_loss, "epoch": epoch, "lr": lr})
        iter += 1
        
        if iter > lr_decay_iters:
            break
    
    epoch += 1
        
    if iter > lr_decay_iters:
        break




0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
iter,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
local_loss,█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
loss,█▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr,▂▅███████▇▇▇▇▇▆▆▆▅▅▅▅▄▄▄▃▃▃▃▂▂▂▂▂▁▁▁▁▁▁▁
val_loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,90.0
iter,2000.0
local_loss,0.00612
loss,0.00612
lr,0.0
val_loss,0.00696


Avg. Loss 0.7663930058479309: 100%|██████████| 22/22 [00:00<00:00, 29.07it/s]
Avg. Loss 0.41943973302841187: 100%|██████████| 22/22 [00:00<00:00, 30.68it/s]
Avg. Loss 0.27471262216567993: 100%|██████████| 22/22 [00:00<00:00, 29.00it/s]
Avg. Loss 0.18572014570236206: 100%|██████████| 22/22 [00:00<00:00, 30.67it/s]
Avg. Loss 0.12310180813074112: 100%|██████████| 22/22 [00:00<00:00, 29.02it/s]
Avg. Loss 0.0802365392446518: 100%|██████████| 22/22 [00:00<00:00, 30.56it/s] 
Avg. Loss 0.054042112082242966: 100%|██████████| 22/22 [00:00<00:00, 26.94it/s]
Avg. Loss 0.036516208201646805: 100%|██████████| 22/22 [00:00<00:00, 30.60it/s]
Avg. Loss 0.026528624817728996: 100%|██████████| 22/22 [00:00<00:00, 28.90it/s]
Avg. Loss 0.021740617230534554: 100%|██████████| 22/22 [00:00<00:00, 30.65it/s]
Avg. Loss 0.018000302836298943: 100%|██████████| 22/22 [00:00<00:00, 28.91it/s]
Avg. Loss 0.015384499914944172: 100%|██████████| 22/22 [00:00<00:00, 30.69it/s]
Avg. Loss 0.014040342532098293: 100%|██████████

In [12]:
reconstruction.shape

torch.Size([1, 116, 32])

In [15]:
reconstruction.permute(0, 2, 1)

tensor([[[ 0.1138,  0.4034, -0.1642,  ..., -0.2770,  0.0835, -2.5250],
         [ 0.4246, -0.2133, -0.1942,  ...,  1.1884,  0.2018, -0.8811],
         [-0.7344, -0.8400,  0.9013,  ...,  0.0668,  0.1827, -0.4337],
         ...,
         [ 0.6513, -1.4336, -0.5136,  ..., -2.5951, -2.1750, -0.7618],
         [-3.0182, -0.5292,  0.4926,  ...,  1.5645,  0.4605, -0.2840],
         [-0.4659, -2.0861,  0.6799,  ...,  1.1112, -0.7230, -1.3389]]],
       device='cuda:0', grad_fn=<PermuteBackward0>)

In [12]:
from utils.visualization3d import model_to_mesh


autoencoder = torch.load("./models/autoencoder.pth")
for params in autoencoder.parameters():
    params.requiere_grad = False
    
mean = tensor.mean(dim=0).mean(dim=0)
    
sample = val_data[0].unsqueeze(0)

reconstruction, latent = autoencoder(sample)
reconstruction = reconstruction[:, :, :-1] + mean
sample = sample[:, :, :-1] + mean

model_dict_reconstructed = image_transform.inverse(reconstruction.permute(0, 2, 1))
model_dict_sample = image_transform.inverse(sample.permute(0, 2, 1))

model_reconstructed = model_transform(model_dict_reconstructed)[0]
model_sample = model_transform(model_dict_reconstructed)[0]

# First original
mesh, sdf = model_to_mesh(model_reconstructed, res=256)
mesh.show()


In [15]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
import wandb
import math

from networks.nano_gpt import GPTConfig
from networks.nano_gpt_adapted import TransformerDecoder, TransformerDecoderConfig
from networks.shapenet_ae import VanillaAutoencoder



def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return 0.0
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (
        lr_decay_iters - warmup_iters
    )
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return coeff * (learning_rate)

def interpolate_loss(iter, max_iter, loss_1, loss_2, max_loss_2 = 1.0):
    
    split_1, split_2 =  1 - iter / max_iter, iter / max_iter
    
    if split_2 > max_loss_2:
        split_1 = 1 - max_loss_2
        split_2 = max_loss_2
    
    return loss_1 * split_1, loss_2 * split_2, (split_1, split_2)
    
     
input_dim = 32
num_layers_dec_local = 2

# local encoder
num_layers_enc_local_list = [3]


autoencoder_latent_dim = [int(3712//2)]

emb_dims = [2048]

# dimension of latent space positional encoding
emb_local_latents = [2048]
latent_dims_local = [21]

num_layers_global = [2]

device = torch.device("cuda")


v_ae_config = {
    "input_dim": 32,
    "latent_dim": 16,
    "num_layers_enc": 3,
    "num_layers_dec": 2,
    "emb_dim": 256,
    "num_vec": num_vec,
    "local_latent_emb": 256,
    "with_upscaling": True
}

model_config = TransformerDecoderConfig(n_embd=128, block_size=num_vec, n_head=8, n_layer=4, output_size=128, dropout=0.0)


class MLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.layer1 = nn.Linear(input_dim, v_ae_config["latent_dim"])
        self.layer2 = nn.Linear(v_ae_config["latent_dim"], output_dim)
        self.activation = nn.GELU()
        
    def forward(self, x):
        x = self.activation(self.layer1(x))
        x = self.layer2(x)
        return x
    
mlps = [MLP(autoencoder.encoder.encoder[-1].out_features, model_config.n_embd) for _ in range(116)]

[mlp.to(device) for mlp in mlps]

decoder = nn.Sequential(
    TransformerDecoder(model_config),
    nn.Linear(model_config.output_size, 64),
    nn.GELU(),
    nn.Linear(64, 32),

)

autoencoder.to(device)
decoder.to(device)


lr_decay_iters = 10000
warmup_iters = 0.05 * lr_decay_iters
learning_rate = 0.001

eval_iters = 4



wandb.init(project="autoencoder", name=f"Autoencoder Transformer")



criterion = nn.MSELoss()
optimizer = optim.Adam(list(autoencoder.parameters()) + list(decoder.parameters()), lr=learning_rate)

# Training the autoencoder
batch_size = 16
train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)


exp_avg_loss = None

iter = 0
epoch = 0

while True:
    batch_bat = tqdm(train_dataloader)
    for batch in batch_bat:
        if iter%eval_iters == 0:
            for val_batch in val_dataloader:
                val_batch.to(device)
                decoder.eval()
                with torch.no_grad():
                    reconstructed_local, latent = autoencoder(val_batch)
                    latent_transformed = torch.stack([mlps[i](latent[:, i, :]) for i in range(116)], dim = 1)
                    reconstruction_transformer = decoder(latent_transformed)
                    local_loss = criterion(reconstructed_local[:, :, :-1], val_batch[:, :, :-1])
                    global_loss = criterion(reconstruction_transformer, val_batch[:, :, :-1])
                    local_loss_split, global_loss_split, splits = interpolate_loss(iter, lr_decay_iters, local_loss, global_loss, max_loss_2=0.9)
                    
                    val_loss = global_loss# local_loss_split + global_loss_split
                decoder.train()
                break               
        batch.to(device)
        lr = get_lr(iter)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        optimizer.zero_grad()
        
        
        reconstructed_local, latent = autoencoder(batch)
        latent_transformed = torch.stack([mlps[i](latent[:, i, :]) for i in range(116)], dim = 1)
        reconstruction_transformer = decoder(latent_transformed)
        local_loss = criterion(reconstructed_local[:, :, :-1], batch[:, :, :-1])
        global_loss = criterion(reconstruction_transformer, batch[:, :, :-1])
        local_loss_split, global_loss_split, splits = interpolate_loss(iter, lr_decay_iters, local_loss, global_loss, max_loss_2=0.9)

        loss = global_loss # local_loss_split + global_loss_split
        loss.backward()
        optimizer.step()
        if exp_avg_loss == None:
            exp_avg_loss = loss
        exp_avg_loss = 0.95*exp_avg_loss + 0.05*loss
        batch_bat.set_description(f"Avg. Loss {exp_avg_loss} split {splits}")
        wandb.log({"iter": iter, "loss": loss.item(), "local_loss": local_loss, "global_loss": global_loss, "val_loss": val_loss, "epoch": epoch, "lr": lr, "split 0": splits[0], "split 1": splits[1]})
        iter += 1
        
        if iter > lr_decay_iters:
            break
    
    epoch += 1
        
    if iter > lr_decay_iters:
        break





number of parameters: 1.65M


Avg. Loss 0.2875300645828247 split (0.9826, 0.0174): 100%|██████████| 175/175 [00:03<00:00, 45.25it/s] 
Avg. Loss 0.15839022397994995 split (0.9651, 0.0349): 100%|██████████| 175/175 [00:03<00:00, 46.15it/s]            
Avg. Loss 0.12080341577529907 split (0.9476, 0.0524): 100%|██████████| 175/175 [00:03<00:00, 45.57it/s]            
Avg. Loss 0.11903344839811325 split (0.9301, 0.0699): 100%|██████████| 175/175 [00:04<00:00, 43.26it/s]            
Avg. Loss 0.1005176529288292 split (0.9126, 0.0874): 100%|██████████| 175/175 [00:04<00:00, 43.38it/s]             
Avg. Loss 0.09641324728727341 split (0.8951, 0.1049): 100%|██████████| 175/175 [00:03<00:00, 43.77it/s]            
Avg. Loss 0.0867348462343216 split (0.8776, 0.1224): 100%|██████████| 175/175 [00:03<00:00, 44.31it/s]             
Avg. Loss 0.085593581199646 split (0.8601, 0.1399): 100%|██████████| 175/175 [00:03<00:00, 46.41it/s]              
Avg. Loss 0.08202693611383438 split (0.8426, 0.1574): 100%|██████████| 175/175 [00:0

In [12]:
decoder

TransformerDecoder(
  (transformer): ModuleDict(
    (wpe): Embedding(116, 128)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0-5): 6 x Block(
        (ln_1): LayerNorm()
        (attn): CausalSelfAttention(
          (c_attn): Linear(in_features=128, out_features=384, bias=True)
          (c_proj): Linear(in_features=128, out_features=128, bias=True)
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm()
        (mlp): MLP(
          (c_fc): Linear(in_features=128, out_features=512, bias=True)
          (gelu): GELU(approximate='none')
          (c_proj): Linear(in_features=512, out_features=128, bias=True)
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm()
  )
  (lm_head): Linear(in_features=128, out_features=32, bias=False)
)

In [11]:
autoencoder

VanillaAutoencoder(
  (encoder): PositionEncoder(
    (emb): Embedding(116, 256)
    (encoder): Sequential(
      (0): Linear(in_features=288, out_features=197, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=197, out_features=107, bias=True)
      (3): GELU(approximate='none')
      (4): Linear(in_features=107, out_features=16, bias=True)
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (upscaling): Linear(in_features=16, out_features=256, bias=True)
  (decoder): PositionEncoder(
    (emb): Embedding(116, 256)
    (encoder): Sequential(
      (0): Linear(in_features=512, out_features=272, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=272, out_features=32, bias=True)
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [None]:
for key in dataset_weights[0][0]["state_dict"].keys():
    print(key)

IndexError: too many indices for tensor of dimension 3