In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

import torch
from data.neural_field_datasets_shapenet import AllWeights3D, FlattenTransform3D, ImageTransform3D, ShapeNetDataset
from torch.utils.data import DataLoader
from networks.shapenet_ae import GlobalAutoencoder

from networks.shapenet_ae import VanillaDecoder, VanillaEncoder, PositionEncoder



dataset_weights = ShapeNetDataset(os.path.join("./", "datasets", "shapenet_nefs", "pretrained"), transform=ImageTransform3D())

input_size = 3712
dimensions = [input_size, int(input_size//1.25), int(input_size//1.5), int(input_size//1.75), int(input_size//2)]
encoder = VanillaEncoder(dimensions)
dimensions.reverse()
decoder = VanillaDecoder(dimensions)



In [None]:
import torch

# Assuming tensor of shape [3482, 116, 32]
tensor = torch.stack([weights for weights in dataset_weights])

num_vec = tensor.shape[1]

# Define the split ratio
split_ratio = 0.8
num_train = int(split_ratio * tensor.shape[0])

# Generate random indices
indices = torch.randperm(tensor.shape[0])

# Split the indices for training and validation
train_indices = indices[:num_train]
val_indices = indices[num_train:]

# Index the tensor to create training and validation sets
train_data = tensor[train_indices]
val_data = tensor[val_indices]

print(f'Training data shape: {train_data.shape}')
print(f'Validation data shape: {val_data.shape}')


In [4]:
input_dim = 32

# local encoder
latent_dims_local = [32] #[24]
num_layers_enc_local_list = [2, 3]


autoencoder_latent_dim = [32*116]

emb_dims = [1, 64]

num_layers_global = [2, 3, 4]



for latent_dim_global in autoencoder_latent_dim:
    for num_layer_global in num_layers_global:
        for num_layers_enc_local in num_layers_enc_local_list:
            for emb_dim in emb_dims:
                    for latent_dim_local in latent_dims_local:
                        
                        config = {
                            "input_dim": input_dim,
                            # local encoder
                            "latent_dim_local": latent_dim_local,
                            "num_layers_enc_local": num_layers_enc_local,
                            # global encoder
                            "latent_dim_global": latent_dim_global,
                            "num_layer_enc_global": num_layer_global,
                            # global decoder
                            "num_layers_dec_global": num_layer_global,
                            "emb_dim_local": emb_dim,
                            "num_vec_local": num_vec,      
                        }
                        
                        model = GlobalAutoencoder(**config)
                        

In [11]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
import wandb
import math

def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return 0.0
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (
        lr_decay_iters - warmup_iters
    )
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return coeff * (learning_rate)

def interpolate_loss(iter, max_iter, loss_1, loss_2, max_loss_2 = 1.0):
    
    split_1, split_2 =  1 - iter / max_iter, iter / max_iter
    
    if split_2 > max_loss_2:
        split_1 = 1 - max_loss_2
        split_2 = max_loss_2
    
    return loss_1 * split_1, loss_2 * split_2, (split_1, split_2)
    
     
input_dim = 32

# local encoder
latent_dims_local = [16]
num_layers_enc_local_list = [2, 3]


autoencoder_latent_dim = [int(3712//2)]

emb_dims = [64, 128, 256, 512]

num_layers_global = [2, 3]



for latent_dim_global in autoencoder_latent_dim:
    for num_layer_global in num_layers_global:
        for num_layers_enc_local in num_layers_enc_local_list:
            for emb_dim in emb_dims:
                    for latent_dim_local in latent_dims_local:
                        
                        config = {
                            "input_dim": input_dim,
                            # local encoder
                            "latent_dim_local": latent_dim_local,
                            "num_layers_enc_local": num_layers_enc_local,
                            # global encoder
                            "latent_dim_global": latent_dim_global,
                            "num_layer_enc_global": num_layer_global,
                            # global decoder
                            "num_layers_dec_global": num_layer_global,
                            "emb_dim_local": emb_dim,
                            "num_vec_local": num_vec,  
                            "only_global_decode": True,    
                        }
                        
                        model = GlobalAutoencoder(**config)
                        print(model)

                        
                        lr_decay_iters = 2000
                        warmup_iters = 0.05 * lr_decay_iters
                        learning_rate = 0.001

                        eval_iters = 4

                        wandb.init(project="autoencoder", name=f"GRID SEARCH LOCAL+GLOBAL ## {config} ##")
                        
                        

                        criterion = nn.MSELoss()
                        optimizer = optim.Adam(model.parameters(), lr=learning_rate)

                        # Training the autoencoder
                        num_epochs = lr_decay_iters
                        batch_size = 16
                        train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
                        val_dataloader = DataLoader(val_data, batch_size=batch_size, shuffle=True)


                        exp_avg_loss = None

                        iter = 0

                        for epoch in range(num_epochs):
                            batch_bat = tqdm(train_dataloader)
                            for batch in batch_bat:
                                if iter%eval_iters == 0:
                                    for batch in val_dataloader:
                                        model.eval()
                                        reconstructed_global, reconstructed_local = model(batch)
                                        local_loss = criterion(reconstructed_local, batch[:, :, :-1])
                                        global_loss = criterion(reconstructed_global, torch.flatten(batch[:, :, :-1], start_dim=1))
                                        local_loss, global_loss, splits = interpolate_loss(iter, lr_decay_iters, local_loss, global_loss, max_loss_2=0.05)
                                        
                                        val_loss = local_loss + global_loss
                                        model.train() 
                                        break               
                                
                                lr = get_lr(iter)
                                for param_group in optimizer.param_groups:
                                    param_group["lr"] = lr
                                optimizer.zero_grad()
                                reconstructed_global, reconstructed_local = model(batch)
                                local_loss = criterion(reconstructed_local, batch[:, :, :-1])
                                global_loss = criterion(reconstructed_global, torch.flatten(batch[:, :, :-1], start_dim=1))
                                local_loss, global_loss, splits = interpolate_loss(iter, lr_decay_iters, local_loss, global_loss, max_loss_2=0.05)

                                loss = local_loss + global_loss
                                loss.backward()
                                optimizer.step()
                                if exp_avg_loss == None:
                                    exp_avg_loss = loss
                                exp_avg_loss = 0.95*exp_avg_loss + 0.05*loss
                                batch_bat.set_description(f"Avg. Loss {exp_avg_loss} split {splits}")
                                wandb.log({"iter": iter, "loss": loss.item(), "local_loss": local_loss, "global_loss": global_loss, "val_loss": val_loss, "epoch": epoch, "lr": lr})
                                iter += 1
                                
                                if iter > lr_decay_iters:
                                    break
                                
                            if iter > lr_decay_iters:
                                break
                        
                        print(config)



GlobalAutoencoder(
  (local_autoencoder): VanillaAutoencoder(
    (encoder): PositionEncoder(
      (emb): Embedding(116, 64)
      (encoder): Sequential(
        (0): Linear(in_features=96, out_features=56, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=56, out_features=16, bias=True)
      )
    )
    (decoder): PositionEncoder(
      (emb): Embedding(116, 64)
      (encoder): Sequential(
        (0): Linear(in_features=80, out_features=56, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=56, out_features=32, bias=True)
      )
    )
  )
  (flattened_encoder): Identity()
  (decoder): VanillaDecoder(
    (decoder): Sequential(
      (0): Linear(in_features=1856, out_features=2784, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=2784, out_features=3712, bias=True)
    )
  )
)


Avg. Loss 0.3406434953212738 split (0.913, 0.087): 100%|██████████| 175/175 [00:14<00:00, 12.23it/s]             
Avg. Loss 0.1345338523387909 split (0.9, 0.1): 100%|██████████| 175/175 [00:12<00:00, 14.38it/s]      
Avg. Loss 0.1110081672668457 split (0.9, 0.1): 100%|██████████| 175/175 [00:13<00:00, 13.22it/s] 
Avg. Loss 0.09851808845996857 split (0.9, 0.1): 100%|██████████| 175/175 [00:12<00:00, 14.56it/s]
Avg. Loss 0.09111965447664261 split (0.9, 0.1): 100%|██████████| 175/175 [00:12<00:00, 13.55it/s]
Avg. Loss 0.08379268646240234 split (0.9, 0.1): 100%|██████████| 175/175 [00:12<00:00, 13.71it/s]
Avg. Loss 0.0837230235338211 split (0.9, 0.1): 100%|██████████| 175/175 [00:12<00:00, 14.44it/s] 
Avg. Loss 0.08161348849534988 split (0.9, 0.1): 100%|██████████| 175/175 [00:12<00:00, 13.51it/s]
Avg. Loss 0.08715872466564178 split (0.9, 0.1): 100%|██████████| 175/175 [00:11<00:00, 14.60it/s]
Avg. Loss 0.07588676363229752 split (0.9, 0.1): 100%|██████████| 175/175 [00:13<00:00, 13.43it/s]

{'input_dim': 32, 'latent_dim_local': 16, 'num_layers_enc_local': 2, 'latent_dim_global': 1856, 'num_layer_enc_global': 2, 'num_layers_dec_global': 2, 'emb_dim_local': 64, 'num_vec_local': 116, 'only_global_decode': True}
GlobalAutoencoder(
  (local_autoencoder): VanillaAutoencoder(
    (encoder): PositionEncoder(
      (emb): Embedding(116, 128)
      (encoder): Sequential(
        (0): Linear(in_features=160, out_features=88, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=88, out_features=16, bias=True)
      )
    )
    (decoder): PositionEncoder(
      (emb): Embedding(116, 128)
      (encoder): Sequential(
        (0): Linear(in_features=144, out_features=88, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=88, out_features=32, bias=True)
      )
    )
  )
  (flattened_encoder): Identity()
  (decoder): VanillaDecoder(
    (decoder): Sequential(
      (0): Linear(in_features=1856, out_features=2784, bias=True)
      

Avg. Loss 0.2282196432352066 split (0.913, 0.087): 100%|██████████| 175/175 [01:28<00:00,  1.98it/s]              
Avg. Loss 0.111301489174366 split (0.9, 0.1): 100%|██████████| 175/175 [00:13<00:00, 12.51it/s]       
Avg. Loss 0.08954267203807831 split (0.9, 0.1): 100%|██████████| 175/175 [00:12<00:00, 14.35it/s]
Avg. Loss 0.0796278715133667 split (0.9, 0.1): 100%|██████████| 175/175 [00:12<00:00, 14.02it/s] 
Avg. Loss 0.07925333082675934 split (0.9, 0.1): 100%|██████████| 175/175 [00:12<00:00, 14.45it/s]
Avg. Loss 0.07724597305059433 split (0.9, 0.1): 100%|██████████| 175/175 [00:12<00:00, 13.60it/s]
Avg. Loss 0.07713853567838669 split (0.9, 0.1): 100%|██████████| 175/175 [00:12<00:00, 13.84it/s]
Avg. Loss 0.08295440673828125 split (0.9, 0.1): 100%|██████████| 175/175 [00:11<00:00, 14.71it/s]
Avg. Loss 0.07804232835769653 split (0.9, 0.1): 100%|██████████| 175/175 [00:12<00:00, 13.65it/s]
Avg. Loss 0.08287503570318222 split (0.9, 0.1): 100%|██████████| 175/175 [00:13<00:00, 13.39it/s

{'input_dim': 32, 'latent_dim_local': 16, 'num_layers_enc_local': 2, 'latent_dim_global': 1856, 'num_layer_enc_global': 2, 'num_layers_dec_global': 2, 'emb_dim_local': 128, 'num_vec_local': 116, 'only_global_decode': True}
GlobalAutoencoder(
  (local_autoencoder): VanillaAutoencoder(
    (encoder): PositionEncoder(
      (emb): Embedding(116, 256)
      (encoder): Sequential(
        (0): Linear(in_features=288, out_features=152, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=152, out_features=16, bias=True)
      )
    )
    (decoder): PositionEncoder(
      (emb): Embedding(116, 256)
      (encoder): Sequential(
        (0): Linear(in_features=272, out_features=152, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=152, out_features=32, bias=True)
      )
    )
  )
  (flattened_encoder): Identity()
  (decoder): VanillaDecoder(
    (decoder): Sequential(
      (0): Linear(in_features=1856, out_features=2784, bias=True)
 

Avg. Loss 0.17148073017597198 split (0.913, 0.087): 100%|██████████| 175/175 [00:16<00:00, 10.73it/s]             
Avg. Loss 0.08757571130990982 split (0.9, 0.1): 100%|██████████| 175/175 [01:26<00:00,  2.01it/s]     
Avg. Loss 0.0776762068271637 split (0.9, 0.1): 100%|██████████| 175/175 [00:19<00:00,  8.83it/s] 
Avg. Loss 0.08828021585941315 split (0.9, 0.1): 100%|██████████| 175/175 [00:18<00:00,  9.54it/s]
Avg. Loss 0.07314692437648773 split (0.9, 0.1): 100%|██████████| 175/175 [00:13<00:00, 13.20it/s]
Avg. Loss 0.07379056513309479 split (0.9, 0.1): 100%|██████████| 175/175 [00:14<00:00, 11.68it/s]
Avg. Loss 0.0725371465086937 split (0.9, 0.1): 100%|██████████| 175/175 [00:15<00:00, 11.51it/s] 
Avg. Loss 0.07358840107917786 split (0.9, 0.1): 100%|██████████| 175/175 [00:16<00:00, 10.88it/s]
Avg. Loss 0.07755842059850693 split (0.9, 0.1): 100%|██████████| 175/175 [00:16<00:00, 10.90it/s]
Avg. Loss 0.06854362785816193 split (0.9, 0.1): 100%|██████████| 175/175 [00:17<00:00,  9.89it/s

{'input_dim': 32, 'latent_dim_local': 16, 'num_layers_enc_local': 2, 'latent_dim_global': 1856, 'num_layer_enc_global': 2, 'num_layers_dec_global': 2, 'emb_dim_local': 256, 'num_vec_local': 116, 'only_global_decode': True}
GlobalAutoencoder(
  (local_autoencoder): VanillaAutoencoder(
    (encoder): PositionEncoder(
      (emb): Embedding(116, 512)
      (encoder): Sequential(
        (0): Linear(in_features=544, out_features=280, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=280, out_features=16, bias=True)
      )
    )
    (decoder): PositionEncoder(
      (emb): Embedding(116, 512)
      (encoder): Sequential(
        (0): Linear(in_features=528, out_features=280, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=280, out_features=32, bias=True)
      )
    )
  )
  (flattened_encoder): Identity()
  (decoder): VanillaDecoder(
    (decoder): Sequential(
      (0): Linear(in_features=1856, out_features=2784, bias=True)
 

Avg. Loss 0.15787476301193237 split (0.913, 0.087): 100%|██████████| 175/175 [00:26<00:00,  6.60it/s]             
Avg. Loss 0.0814746767282486 split (0.9, 0.1): 100%|██████████| 175/175 [00:20<00:00,  8.48it/s]      
Avg. Loss 0.0755898505449295 split (0.9, 0.1): 100%|██████████| 175/175 [00:22<00:00,  7.84it/s] 
Avg. Loss 0.07231734693050385 split (0.9, 0.1): 100%|██████████| 175/175 [00:27<00:00,  6.40it/s]
Avg. Loss 0.07872050255537033 split (0.9, 0.1): 100%|██████████| 175/175 [00:22<00:00,  7.90it/s]
Avg. Loss 0.06980642676353455 split (0.9, 0.1): 100%|██████████| 175/175 [00:18<00:00,  9.42it/s]
Avg. Loss 0.06942547857761383 split (0.9, 0.1): 100%|██████████| 175/175 [00:19<00:00,  8.86it/s]
Avg. Loss 0.07027234137058258 split (0.9, 0.1): 100%|██████████| 175/175 [00:19<00:00,  8.86it/s]
Avg. Loss 0.06385580450296402 split (0.9, 0.1): 100%|██████████| 175/175 [00:20<00:00,  8.67it/s]
Avg. Loss 0.06243082135915756 split (0.9, 0.1): 100%|██████████| 175/175 [00:21<00:00,  8.25it/s

{'input_dim': 32, 'latent_dim_local': 16, 'num_layers_enc_local': 2, 'latent_dim_global': 1856, 'num_layer_enc_global': 2, 'num_layers_dec_global': 2, 'emb_dim_local': 512, 'num_vec_local': 116, 'only_global_decode': True}
GlobalAutoencoder(
  (local_autoencoder): VanillaAutoencoder(
    (encoder): PositionEncoder(
      (emb): Embedding(116, 64)
      (encoder): Sequential(
        (0): Linear(in_features=96, out_features=69, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=69, out_features=43, bias=True)
        (3): GELU(approximate='none')
        (4): Linear(in_features=43, out_features=16, bias=True)
      )
    )
    (decoder): PositionEncoder(
      (emb): Embedding(116, 64)
      (encoder): Sequential(
        (0): Linear(in_features=80, out_features=64, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=64, out_features=48, bias=True)
        (3): GELU(approximate='none')
        (4): Linear(in_features=48, out_fea

Avg. Loss 0.3793942928314209 split (0.913, 0.087): 100%|██████████| 175/175 [00:16<00:00, 10.45it/s]             
Avg. Loss 0.15212559700012207 split (0.9, 0.1): 100%|██████████| 175/175 [00:13<00:00, 12.86it/s]     
Avg. Loss 0.11776819080114365 split (0.9, 0.1): 100%|██████████| 175/175 [00:14<00:00, 11.70it/s]
Avg. Loss 0.0959445908665657 split (0.9, 0.1): 100%|██████████| 175/175 [00:16<00:00, 10.41it/s] 
Avg. Loss 0.08959411829710007 split (0.9, 0.1): 100%|██████████| 175/175 [00:13<00:00, 12.71it/s]
Avg. Loss 0.08385311812162399 split (0.9, 0.1): 100%|██████████| 175/175 [00:14<00:00, 11.78it/s]
Avg. Loss 0.08855654299259186 split (0.9, 0.1): 100%|██████████| 175/175 [00:13<00:00, 12.71it/s]
Avg. Loss 0.08104921132326126 split (0.9, 0.1): 100%|██████████| 175/175 [00:15<00:00, 11.01it/s]
Avg. Loss 0.08004723489284515 split (0.9, 0.1): 100%|██████████| 175/175 [00:21<00:00,  8.20it/s]
Avg. Loss 0.08007019758224487 split (0.9, 0.1): 100%|██████████| 175/175 [00:22<00:00,  7.80it/s]

{'input_dim': 32, 'latent_dim_local': 16, 'num_layers_enc_local': 3, 'latent_dim_global': 1856, 'num_layer_enc_global': 2, 'num_layers_dec_global': 2, 'emb_dim_local': 64, 'num_vec_local': 116, 'only_global_decode': True}
GlobalAutoencoder(
  (local_autoencoder): VanillaAutoencoder(
    (encoder): PositionEncoder(
      (emb): Embedding(116, 128)
      (encoder): Sequential(
        (0): Linear(in_features=160, out_features=112, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=112, out_features=64, bias=True)
        (3): GELU(approximate='none')
        (4): Linear(in_features=64, out_features=16, bias=True)
      )
    )
    (decoder): PositionEncoder(
      (emb): Embedding(116, 128)
      (encoder): Sequential(
        (0): Linear(in_features=144, out_features=107, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=107, out_features=69, bias=True)
        (3): GELU(approximate='none')
        (4): Linear(in_features=69, 

Avg. Loss 0.21710175275802612 split (0.913, 0.087): 100%|██████████| 175/175 [00:29<00:00,  5.84it/s]             
Avg. Loss 0.10726000368595123 split (0.9, 0.1): 100%|██████████| 175/175 [00:21<00:00,  8.15it/s]     
Avg. Loss 0.09101463854312897 split (0.9, 0.1): 100%|██████████| 175/175 [00:19<00:00,  8.97it/s]
Avg. Loss 0.08633514493703842 split (0.9, 0.1): 100%|██████████| 175/175 [00:20<00:00,  8.42it/s]
Avg. Loss 0.07535450905561447 split (0.9, 0.1): 100%|██████████| 175/175 [00:21<00:00,  8.03it/s]
Avg. Loss 0.08257859945297241 split (0.9, 0.1): 100%|██████████| 175/175 [00:19<00:00,  8.93it/s]
Avg. Loss 0.07806489616632462 split (0.9, 0.1): 100%|██████████| 175/175 [00:17<00:00, 10.23it/s]
Avg. Loss 0.07491916418075562 split (0.9, 0.1): 100%|██████████| 175/175 [00:15<00:00, 11.23it/s]
Avg. Loss 0.07799524068832397 split (0.9, 0.1): 100%|██████████| 175/175 [00:15<00:00, 11.49it/s]
Avg. Loss 0.07448960840702057 split (0.9, 0.1): 100%|██████████| 175/175 [00:14<00:00, 12.28it/s

{'input_dim': 32, 'latent_dim_local': 16, 'num_layers_enc_local': 3, 'latent_dim_global': 1856, 'num_layer_enc_global': 2, 'num_layers_dec_global': 2, 'emb_dim_local': 128, 'num_vec_local': 116, 'only_global_decode': True}
GlobalAutoencoder(
  (local_autoencoder): VanillaAutoencoder(
    (encoder): PositionEncoder(
      (emb): Embedding(116, 256)
      (encoder): Sequential(
        (0): Linear(in_features=288, out_features=197, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=197, out_features=107, bias=True)
        (3): GELU(approximate='none')
        (4): Linear(in_features=107, out_features=16, bias=True)
      )
    )
    (decoder): PositionEncoder(
      (emb): Embedding(116, 256)
      (encoder): Sequential(
        (0): Linear(in_features=272, out_features=192, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=192, out_features=112, bias=True)
        (3): GELU(approximate='none')
        (4): Linear(in_features=

Avg. Loss 0.1685052514076233 split (0.913, 0.087): 100%|██████████| 175/175 [00:16<00:00, 10.56it/s]              
Avg. Loss 0.09335536509752274 split (0.9, 0.1): 100%|██████████| 175/175 [00:18<00:00,  9.34it/s]     
Avg. Loss 0.08205533772706985 split (0.9, 0.1): 100%|██████████| 175/175 [00:18<00:00,  9.31it/s]
Avg. Loss 0.0757419690489769 split (0.9, 0.1): 100%|██████████| 175/175 [00:17<00:00, 10.26it/s] 
Avg. Loss 0.0761820450425148 split (0.9, 0.1): 100%|██████████| 175/175 [00:23<00:00,  7.50it/s] 
Avg. Loss 0.07536119222640991 split (0.9, 0.1): 100%|██████████| 175/175 [00:21<00:00,  8.20it/s]
Avg. Loss 0.07377185672521591 split (0.9, 0.1):  62%|██████▏   | 108/175 [00:20<00:12,  5.23it/s]


KeyboardInterrupt: 

In [None]:
# local encoder
latent_dims_local = [24]
num_layers_enc_local_list = [2, 3]


autoencoder_latent_dim = [int(3712//2)]

emb_dims = [256, 512]

num_layers_global = [2, 3]

config = {
        "input_dim": input_dim,
        # local encoder
        "latent_dim_local": latent_dim_local,
        "num_layers_enc_local": num_layers_enc_local,
        # global encoder
        "latent_dim_global": latent_dim_global,
        "num_layer_enc_global": num_layer_global,
        # global decoder
        "num_layers_dec_global": num_layer_global,
        "emb_dim_local": emb_dim,
        "num_vec_local": num_vec,      
    } 
    
model = GlobalAutoencoder(**config)


In [None]:
batch[0].shape

torch.Size([33])

In [None]:


train_dataloader = DataLoader(dataset_weights, batch_size=64, shuffle=True)

In [None]:
for key in dataset_weights[0][0]["state_dict"].keys():
    print(key)

IndexError: too many indices for tensor of dimension 3