In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
import torchvision


In [None]:
# Both MNIST splits share one transform: ToTensor turns each PIL image into a
# float tensor scaled to [0, 1].
mnist_transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

train_dataset = torchvision.datasets.MNIST(
    root="data", train=True, transform=mnist_transform, download=True
)
test_dataset = torchvision.datasets.MNIST(
    root="data", train=False, transform=mnist_transform, download=True
)

# Only the training split is shuffled; the test split is iterated in its stored order.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)


In [None]:
class Encoder(nn.Module):
    """MLP encoder that maps images to a ``latent_dim``-dimensional code.

    Args:
        input_channels: number of image channels. Together with ``image_size``
            this fixes the flattened input width
            (``input_channels * image_size * image_size``).
        latent_dim: size of the output latent vector.
        image_size: height/width of the square input images. Defaults to 28
            (MNIST), so ``Encoder(1, d)`` expects 784 input features.
    """

    def __init__(self, input_channels, latent_dim, image_size=28):
        super().__init__()
        input_dim = input_channels * image_size * image_size
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, latent_dim),
        )

    def forward(self, x):
        """Flatten every non-batch dimension of ``x`` and return ``(batch, latent_dim)`` codes."""
        # flatten (unlike .view) also works on non-contiguous inputs.
        x = x.flatten(start_dim=1)
        return self.layers(x)
    
class Decoder(nn.Module):
    """MLP decoder that maps a latent code back to an image with values in (0, 1).

    Args:
        output_channels: number of channels in the reconstructed image.
        latent_dim: size of the input latent vector.
        image_size: height/width of the square output image. Defaults to 28
            (MNIST), so ``Decoder(1, d)`` produces ``(batch, 1, 28, 28)``.
    """

    def __init__(self, output_channels, latent_dim, image_size=28):
        super().__init__()
        self.output_channels = output_channels
        self.image_size = image_size
        output_dim = output_channels * image_size * image_size
        self.layers = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, output_dim),
            # Keeps outputs in (0, 1), which F.binary_cross_entropy requires.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Decode ``(batch, latent_dim)`` codes into ``(batch, C, H, W)`` images."""
        return self.layers(x).view(
            x.size(0), self.output_channels, self.image_size, self.image_size
        )


In [None]:
# Seed torch's global RNG so weight initialisation is reproducible on a fresh
# run (later torch-RNG draws, such as the default DataLoader shuffle, follow from it).
SEED = 0
torch.manual_seed(SEED)

# Single-channel MNIST images compressed to a 2-D latent code.
LATENT_DIM = 2
encoder = Encoder(input_channels=1, latent_dim=LATENT_DIM)
decoder = Decoder(output_channels=1, latent_dim=LATENT_DIM)


In [None]:
# One Adam optimizer updates encoder and decoder weights jointly.
autoencoder_params = [*encoder.parameters(), *decoder.parameters()]
optimizer = torch.optim.Adam(autoencoder_params, lr=0.001)


In [None]:
# NOTE(review): the model is a deterministic autoencoder (no sampling / KL term),
# not a VAE; the wandb project name is kept unchanged.
wandb.init(project="vae")

NUM_EPOCHS = 1
LOG_EVERY = 100  # iterations between logged reconstruction samples / progress prints

# The evaluation code switches the models to eval mode; switch back explicitly
# so re-running this cell trains in the right mode.
encoder.train()
decoder.train()

for epoch in range(NUM_EPOCHS):
    for i, (images, _) in enumerate(train_loader):
        optimizer.zero_grad()
        z = encoder(images)
        x_hat = decoder(z)
        # Targets are in [0, 1] (ToTensor) and the decoder ends in Sigmoid, so BCE applies.
        loss = F.binary_cross_entropy(x_hat, images)
        loss.backward()
        optimizer.step()

        metrics = {"loss": loss.item()}
        if i % LOG_EVERY == 0:
            # Logged in the same wandb.log call as the loss so both share one step;
            # separate calls would each advance wandb's step counter.
            metrics["reconstruction"] = [
                wandb.Image(images[0]),
                wandb.Image(x_hat[0].detach()),
            ]
            print(f"Epoch {epoch} Iteration {i} Loss {metrics['loss']}")
        wandb.log(metrics)

encoder.eval()
decoder.eval()

# Embed the whole test set in a single pass. no_grad avoids building an
# autograd graph for every test batch, and labels are collected alongside the
# embeddings instead of iterating the loader a second time.
embedding_batches = []
label_batches = []
with torch.no_grad():
    for images, batch_labels in test_loader:
        embedding_batches.append(encoder(images))
        label_batches.append(batch_labels)
embeddings = torch.cat(embedding_batches, dim=0)
labels = torch.cat(label_batches, dim=0)

# Rows are built from Python lists so labels stay integers; concatenating the
# integer label tensor with the float embeddings would promote labels to floats.
columns = ["label"] + [f"dim{d}" for d in range(embeddings.size(1))]
rows = [[label, *coords] for label, coords in zip(labels.tolist(), embeddings.tolist())]
wandb.log({"embeddings": wandb.Table(data=rows, columns=columns)})

wandb.finish()
