In [2]:
# Enable autoreload of module
%load_ext autoreload
%autoreload 2

In [3]:
import torch
from torchvision import transforms
from networks.mlp_models import MLP3D
from data.neural_field_datasets_shapenet import AllWeights3D, ModelTransform3D, ShapeNetDataset, FlattenTransform3D, ZScore3D, get_neuron_mean_n_std, get_total_mean_n_std


shapeNetData = ShapeNetDataset("./datasets/plane_mlp_weights", transform=AllWeights3D())
mean, std = get_total_mean_n_std(shapeNetData)
normalizer = ZScore3D(mean, std)
shapeNetData_normalized = ShapeNetDataset("./datasets/plane_mlp_weights", transform=[AllWeights3D(), normalizer])

all_weights = torch.stack([sample[0] for sample in shapeNetData])
all_weights_normalized = torch.stack([sample[0] for sample in shapeNetData_normalized])

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
import wandb
import math


warmup_iters = 100
lr_decay_iters = 150000
learning_rate = 0.003

def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return 0.0
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (
        lr_decay_iters - warmup_iters
    )
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return coeff * (learning_rate)

wandb.init(project="autoencoder")


data = all_weights_normalized.view(-1, 128)

pos_enc = torch.Tensor([i for _ in range(3883) for i in range(287)]).unsqueeze(-1)
nef_enc = torch.Tensor([i for i in range(3883) for _ in range(287)]).unsqueeze(-1)
concatenated_data = torch.cat((data, pos_enc, nef_enc), dim=1)

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.n_emb_input = 2048
        self.n_emb_latent = 2048
        
        self.nef_input = nn.Embedding(3883, self.n_emb_input + 128)
        self.emb_input = nn.Embedding(287, self.n_emb_input)
        self.emb_latent = nn.Embedding(287, self.n_emb_latent)
        
        self.encoder = nn.Sequential(
            nn.Linear(128 + self.n_emb_input, 112 + int(self.n_emb_input/2)),
            nn.GELU(),
            nn.Linear(112 + int(self.n_emb_input/2), 96 + int(self.n_emb_input/4)),
            nn.GELU(),
            nn.Linear(96 + int(self.n_emb_input/4), 96 + int(self.n_emb_input/8)),
            nn.GELU(),
            nn.Linear(96 + int(self.n_emb_input/8), 96),
        )
        self.decoder = nn.Sequential(
            nn.Linear(96 + self.n_emb_latent, 64 + self.n_emb_latent),
            nn.GELU(),
            nn.Linear(64 + self.n_emb_latent, 64 + self.n_emb_latent),
            nn.GELU(),
            nn.Linear(64 + self.n_emb_latent, 128),
            nn.GELU(),
            nn.Linear(128, 128)
        )

    def forward(self, x):
        pos = x[:, -2].int()
        nef = x[:, -1].int()
        embedding_input = self.emb_input(pos)
        nef_input = self.nef_input(nef)
        x = nef_input + torch.cat((x[:, :-2], embedding_input), dim=1)
        latent = self.encoder(x)
        embedding_latent = self.emb_latent(pos)

        reconstructed = self.decoder(torch.cat((latent, embedding_latent), dim=1))
        return latent, reconstructed

# Initialize the model, loss function, and optimizer
model = Autoencoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training the autoencoder
num_epochs = lr_decay_iters
batch_size = 2048
data_loader = torch.utils.data.DataLoader(concatenated_data, batch_size=batch_size, shuffle=True)

exp_avg_loss = None

iters = 0

while True:
    bar = tqdm(data_loader)
    for batch in bar:
        lr = get_lr(iters)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        optimizer.zero_grad()
        latent, reconstructed = model(batch)
        loss = criterion(reconstructed, batch[:, :128])
        loss.backward()
        optimizer.step()
        if exp_avg_loss == None:
            exp_avg_loss = loss
        exp_avg_loss = 0.95*exp_avg_loss + 0.05*loss
        bar.set_description("Avg. Loss %f, Loss %f" % (exp_avg_loss, loss))
        iters += 1
    
        wandb.log({"lr": lr, "iters": iters, "loss": exp_avg_loss.item()})
    if iters > lr_decay_iters:
            break


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
[34m[1mwandb[0m: Currently logged in as: [33mluis-muschal[0m ([33madl-for-cv[0m). Use [1m`wandb login --relogin`[0m to force relogin


0.00s - make the debugger miss breakpoints. Please pass -Xfrozen_modules=off
0.00s - to python to disable frozen modules.
0.00s - Note: Debugging will proceed. Set PYDEVD_DISABLE_FILE_VALIDATION=1 to disable this validation.
Avg. Loss 0.461147, Loss 0.427890: 100%|██████████| 545/545 [05:08<00:00,  1.77it/s]
Avg. Loss 0.395498, Loss 0.376550: 100%|██████████| 545/545 [05:44<00:00,  1.58it/s]
Avg. Loss 0.344375, Loss 0.341331: 100%|██████████| 545/545 [06:14<00:00,  1.46it/s]
Avg. Loss 0.312484, Loss 0.249385: 100%|██████████| 545/545 [07:19<00:00,  1.24it/s]  
Avg. Loss 0.289615, Loss 0.277658: 100%|██████████| 545/545 [08:06<00:00,  1.12it/s]
Avg. Loss 0.279371, Loss 0.284654:  68%|██████▊   | 368/545 [04:36<01:16,  2.32it/s]

In [None]:
concatenated_data.shape

torch.Size([1114421, 129])

In [19]:
concatenated_data[0]

tensor([-1.1579e+00, -1.2348e-01,  1.2735e-01, -1.9410e-01,  0.0000e+00,
         1.4119e-01, -2.1411e-01, -1.0901e-01,  7.1479e-02,  7.8582e-01,
         9.6256e-01, -1.0795e-01,  9.3229e-02, -1.8610e-02,  8.4163e-02,
        -1.1825e-01,  2.4034e-02,  1.5581e-01,  0.0000e+00, -1.8264e-01,
         2.2689e-01,  2.4634e-01, -6.5313e-02,  2.2073e-01, -4.3676e-01,
         3.6150e-01,  5.8303e-01, -7.4506e-09,  6.9505e-02, -8.6208e-02,
         4.4996e-02, -1.3389e-02, -2.3219e-01, -1.0744e-01,  2.8798e-02,
         1.2213e+00, -3.6251e-02, -1.4901e-08, -2.8241e-01, -1.7725e-01,
        -1.2497e-01,  2.3557e-01,  7.7375e-02,  3.4081e-01,  2.3977e-01,
        -1.8440e-02, -1.2832e-01, -1.0434e-01, -2.2410e-01,  2.2442e-02,
         5.5499e-02,  1.2483e-01,  7.7722e-02,  2.9068e-02, -3.6923e-02,
         1.4052e-03,  5.6786e-01, -1.3231e-01,  6.0694e-02,  6.6704e-02,
         2.3971e-01, -3.0080e-01,  6.2760e-01,  5.7211e-01, -4.7014e-02,
         4.7261e-03,  1.9259e-01, -1.0223e+00,  0.0

In [28]:
(model(concatenated_data[0].unsqueeze(0))[1] - concatenated_data[0][:128])

tensor([[-0.0105, -0.0242, -0.5174,  0.0126,  0.0018,  0.0026, -0.0056,  0.1566,
         -0.0732, -0.0027, -0.0163,  0.0014,  0.0024, -0.0027, -0.0122, -0.0124,
         -0.0525, -0.1384, -0.0062,  0.1295,  0.0293,  0.0072, -0.0201, -0.1308,
          0.0160,  0.0346,  0.0219,  0.0151, -0.0260,  0.0174,  0.0251,  0.0021,
         -0.0230,  0.1000,  0.0115, -0.5205, -0.0046,  0.0082, -0.0076, -0.0259,
          0.0161, -0.0170, -0.0746, -0.0072, -0.0131,  0.0195,  0.0669,  0.0038,
         -0.0059, -0.0287,  0.0188,  0.0007, -0.0128,  0.0606, -0.1089, -0.0363,
          0.0075, -0.0084,  0.0048,  0.2535, -0.0047,  0.3058, -0.0111,  0.0151,
          0.0021, -0.0139, -0.0429, -0.0248,  0.0041, -0.0117, -0.0936,  0.0323,
         -0.0492, -0.0247,  0.0375, -0.0074,  0.0176,  0.0823, -0.0094,  0.0222,
          0.0109,  0.0145, -0.0183, -0.0116, -0.0057,  0.0011,  0.0222,  0.0036,
          0.0516,  0.0510, -0.0074,  0.0091, -0.5948, -0.0168, -0.0750,  0.0274,
          0.0444,  0.0500,  

In [1]:
import torch
from torchvision import transforms
from networks.mlp_models import MLP3D
from data.neural_field_datasets_shapenet import AllWeights3D, ModelTransform3D, ShapeNetDataset, FlattenTransform3D, ZScore3D, get_neuron_mean_n_std, get_total_mean_n_std


shapeNetData = ShapeNetDataset("./datasets/plane_mlp_weights", transform=AllWeights3D())
means, stds = get_total_mean_n_std(shapeNetData)
shapeNetData_simple_normalized = ShapeNetDataset("./datasets/plane_mlp_weights", transform=[AllWeights3D(), ZScore3D(means, stds)])

all_weights_simple_normalized = torch.stack([sample[0] for sample in shapeNetData_simple_normalized])

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
import wandb
import math


warmup_iters = 100
lr_decay_iters = 150000
learning_rate = 0.003

def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return 0.0
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (
        lr_decay_iters - warmup_iters
    )
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return coeff * (learning_rate)

wandb.init(project="autoencoder")


# Generate class labels
labels = torch.arange(287).repeat_interleave(4045)

data = all_weights_simple_normalized.view(-1, 128)

pos_enc = torch.Tensor([i for _ in range(3883) for i in range(287)]).unsqueeze(-1)
concatenated_data = torch.cat((data, pos_enc), dim=1)

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.n_emb_input = 512
        self.n_emb_latent = 512
        
        self.emb_input = nn.Embedding(287, self.n_emb_input)
        self.emb_latent = nn.Embedding(287, self.n_emb_latent)
        
        self.encoder = nn.Sequential(
            nn.Linear(128 + self.n_emb_input, 112 + int(self.n_emb_input/2)),
            nn.GELU(),
            nn.Linear(112 + int(self.n_emb_input/2), 96 + int(self.n_emb_input/4)),
            nn.GELU(),
            nn.Linear(96 + int(self.n_emb_input/4), 96 + int(self.n_emb_input/8)),
            nn.GELU(),
            nn.Linear(96 + int(self.n_emb_input/8), 96 + int(self.n_emb_input/16)),
            nn.GELU(),
            nn.Linear(96 + int(self.n_emb_input/16), 96),
        )
        self.decoder = nn.Sequential(
            nn.Linear(96 + self.n_emb_latent, 64 + self.n_emb_latent),
            nn.GELU(),
            nn.Linear(64 + self.n_emb_latent, 64 + self.n_emb_latent),
            nn.GELU(),
            nn.Linear(64 + self.n_emb_latent, 128),
            nn.GELU(),
            nn.Linear(128, 128)
        )

    def forward(self, x):
        pos = x[:, -1].int()
        embedding_input = self.emb_input(pos)
        x = torch.cat((x[:, :-1], embedding_input), dim=1)
        latent = self.encoder(x)
        embedding_latent = self.emb_latent(pos)

        reconstructed = self.decoder(torch.cat((latent, embedding_latent), dim=1))
        return latent, reconstructed

# Initialize the model, loss function, and optimizer
model = Autoencoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training the autoencoder
num_epochs = lr_decay_iters
batch_size = 2048
data_loader = torch.utils.data.DataLoader(concatenated_data, batch_size=batch_size, shuffle=True)

exp_avg_loss = None

iters = 0

while True:
    bar = tqdm(data_loader)
    for batch in bar:
        lr = get_lr(iters)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        optimizer.zero_grad()
        latent, reconstructed = model(batch)
        loss = criterion(reconstructed, batch[:, :128])
        loss.backward()
        optimizer.step()
        if exp_avg_loss == None:
            exp_avg_loss = loss
        exp_avg_loss = 0.95*exp_avg_loss + 0.05*loss
        bar.set_description("Avg. Loss %f, Loss %f" % (exp_avg_loss, loss))
        iters += 1
    
        wandb.log({"lr": lr, "iters": iters, "loss": exp_avg_loss.item()})
    if iters > lr_decay_iters:
            break

# Project the data into the latent space
with torch.no_grad():
    latent_space, _ = model(data)

# Visualization
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Generate a color map for the 27 classes
cmap = plt.get_cmap("tab20", 27)
colors = cmap(labels)

# # Plot each class with its respective color
for class_idx in range(27):
    class_data = latent_space[labels == class_idx]
    ax.scatter(class_data[:, 0], class_data[:, 1], class_data[:, 2], label=f'Class {class_idx}', alpha=0.6, c=colors[labels == class_idx])


ax.set_xlabel('Latent Dimension 1')
ax.set_ylabel('Latent Dimension 2')
ax.set_zlabel('Latent Dimension 3')
plt.title('Latent Space Projection with Autoencoder')
plt.legend()
plt.show()


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mluis-muschal[0m ([33madl-for-cv[0m). Use [1m`wandb login --relogin`[0m to force relogin


Avg. Loss 0.342845, Loss 0.358489: 100%|██████████| 545/545 [01:14<00:00,  7.35it/s]
Avg. Loss 0.233364, Loss 0.242267: 100%|██████████| 545/545 [01:17<00:00,  7.00it/s]
Avg. Loss 0.189202, Loss 0.221404: 100%|██████████| 545/545 [01:14<00:00,  7.30it/s]
Avg. Loss 0.153499, Loss 0.155273: 100%|██████████| 545/545 [01:04<00:00,  8.43it/s]
Avg. Loss 0.131173, Loss 0.153633: 100%|██████████| 545/545 [01:05<00:00,  8.35it/s]
Avg. Loss 0.118248, Loss 0.116163: 100%|██████████| 545/545 [01:19<00:00,  6.84it/s]
Avg. Loss 0.104193, Loss 0.105730: 100%|██████████| 545/545 [01:09<00:00,  7.85it/s]
Avg. Loss 0.099532, Loss 0.094433: 100%|██████████| 545/545 [01:04<00:00,  8.42it/s]
Avg. Loss 0.086249, Loss 0.095640: 100%|██████████| 545/545 [01:10<00:00,  7.70it/s]
Avg. Loss 0.078611, Loss 0.080406: 100%|██████████| 545/545 [01:10<00:00,  7.72it/s]
Avg. Loss 0.071784, Loss 0.075807: 100%|██████████| 545/545 [01:08<00:00,  8.01it/s]
Avg. Loss 0.074445, Loss 0.064656: 100%|██████████| 545/545 [01:1

KeyboardInterrupt: 

In [12]:
normalizer = ZScore3D(means, stds)

normalizer.reverse(model(concatenated_data[0].unsqueeze(0))[1], None)[0] - concatenated_data[0][:128]

tensor([[ 1.2950e+00, -1.6470e-01, -2.7521e-01,  2.1858e-02, -1.5570e+00,
         -5.1293e-02, -1.2569e-01, -4.1379e-02, -1.6014e+00, -6.5952e-01,
         -3.4389e-01, -1.7295e-01, -2.5812e-01, -1.0768e-01, -1.6675e-01,
         -3.8412e-01,  3.4075e-03,  3.2163e-01, -1.1021e-01,  2.8555e-01,
          1.5521e-01, -1.0024e-01,  4.7126e-02, -5.6690e-01, -4.3606e-01,
         -7.9278e-01,  3.1277e-02, -1.6087e-01, -2.6900e-01, -8.4168e-01,
         -2.1610e-01, -1.1836e-01,  7.2839e-02, -1.7513e-02, -5.2016e-01,
         -1.2546e+00,  1.6268e-03, -7.9230e-02, -3.2796e-01, -2.7438e-01,
          9.3172e-02, -3.0386e-01,  1.1159e-01, -5.0528e-01, -6.7823e-01,
         -2.7406e-02,  9.4909e-02, -1.0364e-01, -2.8155e-01, -8.8741e-01,
         -3.1426e-01, -1.5171e-01, -4.6998e-01, -3.4399e-01,  6.7584e-01,
         -3.5733e-02, -3.2437e-01, -1.0999e-01, -3.6226e-01, -1.8213e-01,
         -3.2022e-01,  1.6424e-01, -4.9512e-01, -1.7287e+00, -1.7672e-02,
         -2.2956e-01,  4.1811e-01,  1.

In [11]:
concatenated_data[0][:128]

tensor([-2.5495,  0.8123,  0.3704, -0.1775,  1.6708, -0.0851, -0.9333,  0.2162,
         1.6531,  2.1256,  0.0671,  0.3314,  0.6256,  0.0960,  0.0909, -1.0794,
         0.1248, -0.6179,  0.0723, -0.8027, -0.9267, -0.1643,  0.0611,  0.2004,
         1.1189,  1.4231,  2.2143,  0.2767,  0.5708,  0.9710,  0.1077, -0.0419,
         0.1057, -0.1394,  0.4941,  1.5916, -0.3157,  0.0285, -1.8591,  0.3200,
         0.3421,  0.2940,  0.1683,  0.1650,  0.5711,  0.1115, -0.4845, -0.2298,
         0.6599,  1.1922,  0.7761,  0.4648,  0.6727,  0.4525, -1.5696,  0.1103,
         0.3395,  0.3859,  0.0215,  0.1389,  0.4945, -0.2744,  1.4628,  2.5776,
        -0.0059,  0.3401, -0.8799, -1.3363,  0.3709,  0.9580,  0.7778,  1.0956,
         1.5019,  0.4930,  0.0260,  0.9515, -0.0705,  0.0360,  0.2311,  0.0883,
         1.0115,  0.8481, -1.2776,  0.1004,  1.1366, -0.0948,  0.1798,  0.3943,
         0.7724,  0.0790,  1.6246,  1.3793,  0.9421,  0.1084,  0.0102,  0.6966,
        -0.5904, -1.5365,  3.2594,  0.13

In [2]:
import torch
from torchvision import transforms
from networks.mlp_models import MLP3D
from data.neural_field_datasets_shapenet import AllWeights3D, ModelTransform3D, ShapeNetDataset, FlattenTransform3D, ZScore3D, get_neuron_mean_n_std, get_total_mean_n_std


shapeNetData_flatten = ShapeNetDataset("./datasets/plane_mlp_weights", transform=[FlattenTransform3D()])
all_weights_flatten = torch.stack([sample[0] for sample in shapeNetData_flatten])

In [8]:
all_weights_flatten.view(3883, -1, 17).shape

torch.Size([3883, 2161, 17])

In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from tqdm import tqdm
import wandb
import math


warmup_iters = 100
lr_decay_iters = 20000
learning_rate = 0.008

def get_lr(it):
    # 1) linear warmup for warmup_iters steps
    if it < warmup_iters:
        return learning_rate * it / warmup_iters
    # 2) if it > lr_decay_iters, return min learning rate
    if it > lr_decay_iters:
        return 0.0
    # 3) in between, use cosine decay down to min learning rate
    decay_ratio = (it - warmup_iters) / (
        lr_decay_iters - warmup_iters
    )
    assert 0 <= decay_ratio <= 1
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))  # coeff ranges 0..1
    return coeff * (learning_rate)

wandb.init(project="autoencoder")


# Generate class labels
labels = torch.arange(287).repeat_interleave(4045)

data = all_weights_flatten.view(-1, 17)

emb_size = 2161

pos_enc = torch.Tensor([i for _ in range(3883) for i in range(emb_size)]).unsqueeze(-1)
concatenated_data = torch.cat((data, pos_enc), dim=1)

class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self.n_emb_input = 128
        self.n_emb_latent = 128
        
        self.emb_input = nn.Embedding(emb_size, self.n_emb_input)
        self.emb_latent = nn.Embedding(emb_size, self.n_emb_latent)
        
        self.encoder = nn.Sequential(
            nn.Linear(17 + self.n_emb_input, 17 + self.n_emb_input//4),
            nn.GELU(),
            nn.Linear(17 + self.n_emb_input//4, 17 + self.n_emb_input//8),
            nn.GELU(),
            nn.Linear(17 + self.n_emb_input//8, 17 + self.n_emb_input//16),
            nn.GELU(),
            nn.Linear(17 + self.n_emb_input//16, 8),
            nn.GELU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(8 + self.n_emb_latent, 12 + self.n_emb_latent),
            nn.GELU(),
            nn.Linear(12 + self.n_emb_latent, 17),
        )

    def forward(self, x):
        pos = x[:, -1].int()
        embedding_input = self.emb_input(pos)
        x = torch.cat((x[:, :-1], embedding_input), dim=1)
        latent = self.encoder(x)
        embedding_latent = self.emb_latent(pos)

        reconstructed = self.decoder(torch.cat((latent, embedding_latent), dim=1))
        return latent, reconstructed

# Initialize the model, loss function, and optimizer
model = Autoencoder()
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training the autoencoder
num_epochs = lr_decay_iters
batch_size = 2048
data_loader = torch.utils.data.DataLoader(concatenated_data, batch_size=batch_size, shuffle=True)

exp_avg_loss = None

iters = 0

while True:
    bar = tqdm(data_loader)
    for batch in bar:
        lr = get_lr(iters)
        for param_group in optimizer.param_groups:
            param_group["lr"] = lr
        optimizer.zero_grad()
        latent, reconstructed = model(batch)
        loss = criterion(reconstructed, batch[:, :17])
        loss.backward()
        optimizer.step()
        if exp_avg_loss == None:
            exp_avg_loss = loss
        exp_avg_loss = 0.95*exp_avg_loss + 0.05*loss
        bar.set_description("Avg. Loss %f, Loss %f" % (exp_avg_loss, loss))
        iters += 1
    
        wandb.log({"lr": lr, "iters": iters, "loss": exp_avg_loss.item()})
    if iters > lr_decay_iters:
            break

# Project the data into the latent space
with torch.no_grad():
    latent_space, _ = model(data)

# Visualization
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Generate a color map for the 27 classes
cmap = plt.get_cmap("tab20", 27)
colors = cmap(labels)

# # Plot each class with its respective color
for class_idx in range(27):
    class_data = latent_space[labels == class_idx]
    ax.scatter(class_data[:, 0], class_data[:, 1], class_data[:, 2], label=f'Class {class_idx}', alpha=0.6, c=colors[labels == class_idx])


ax.set_xlabel('Latent Dimension 1')
ax.set_ylabel('Latent Dimension 2')
ax.set_zlabel('Latent Dimension 3')
plt.title('Latent Space Projection with Autoencoder')
plt.legend()
plt.show()


Problem at: /Users/luis/uni/adl4cv/adl4cv/adl4cv/.venv/lib/python3.12/site-packages/wandb/sdk/wandb_init.py 849 getcaller


CommError: Run initialization has timed out after 90.0 sec. 
Please refer to the documentation for additional information: https://docs.wandb.ai/guides/track/tracking-faq#initstarterror-error-communicating-with-wandb-process-