In [1]:
# Note: This is a hack to allow importing from the parent directory
import sys
from pathlib import Path

sys.path.append(str(Path().resolve().parent))

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from models import AbstractAutoencoder


class ShallowAutoencoder(AbstractAutoencoder):
    def __init__(self, input_dim, latent_dim, use_bias=True):
        super(ShallowAutoencoder, self).__init__()
        self.type = "ShallowAutoencoder"
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, latent_dim, bias=use_bias),
            nn.ReLU(inplace=True),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, input_dim, bias=use_bias),
            nn.Sigmoid(),
        )

In [3]:
from data import CIFAR10GaussianSplatsDataset
from utils import train, test, transform, noop_collate, transform_and_collate

dataset = CIFAR10GaussianSplatsDataset(
    root="../data/CIFAR10GS",
    train=True,
    init_type="grid",
)

latent_dim = 128
input_dim = 23552

In [4]:
model = ShallowAutoencoder(input_dim=input_dim, latent_dim=latent_dim)
train_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    collate_fn=transform_and_collate,
)
val_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    collate_fn=transform_and_collate,
)
test_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    collate_fn=transform_and_collate,
)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()
epochs = 1
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
train(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    epochs=epochs,
    device=device,
)

Epoch 1/1: 100%|██████████| 12/12 [00:00<00:00, 12.07batch/s]


Epoch 1/1 | Train Loss: 0.5303 | Val Loss: 0.5041


{'train_loss': [0.5303028970956802], 'val_loss': [0.5040653645992279]}

In [6]:
test(model=model, test_loader=test_loader, criterion=criterion, device=device)

Test Loss: 0.5041


0.5040653645992279

In [7]:
data_loader = torch.utils.data.DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
    collate_fn=noop_collate,
)

results = transform(model=model, data_loader=data_loader, device=device)

In [8]:
source = results[0][-2]
solution = results[0][-1]
source["means"], solution["means"]

(Parameter containing:
 tensor([[-1.0000, -1.0000,  0.0000],
         [-1.0000, -0.9355,  0.0000],
         [-1.0000, -0.8710,  0.0000],
         ...,
         [ 1.0000,  0.8710,  0.0000],
         [ 1.0000,  0.9355,  0.0000],
         [ 1.0000,  1.0000,  0.0000]], requires_grad=True),
 Parameter containing:
 tensor([[0.6119, 0.6858, 0.2839],
         [0.4844, 0.4107, 0.5451],
         [0.6521, 0.2578, 0.1811],
         ...,
         [0.3059, 0.4258, 0.4643],
         [0.3053, 0.6044, 0.4042],
         [0.4018, 0.4646, 0.3707]], requires_grad=True))

In [12]:
from utils import transform_autoencoder_input, transform_autoencoder_output

transform_autoencoder_output(transform_autoencoder_input(source))["means"]

Parameter containing:
tensor([[-1.0000, -1.0000,  0.0000],
        [-1.0000, -0.9355,  0.0000],
        [-1.0000, -0.8710,  0.0000],
        ...,
        [ 1.0000,  0.8710,  0.0000],
        [ 1.0000,  0.9355,  0.0000],
        [ 1.0000,  1.0000,  0.0000]], requires_grad=True)