In [1]:
import os
import torch
import numpy as np
from torch import optim, nn, utils, Tensor
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset,DataLoader
import lightning as L

In [38]:
np.random.seed(0)
dim = 200
actual_dim = 75
latent_dim = 150
encoder = nn.Sequential(nn.Linear(200,175),nn.GELU(),nn.Linear(175,latent_dim))
decoder = nn.Sequential(nn.Linear(latent_dim, 100), nn.GELU(), nn.Linear(100, 200))

class AutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        z = self.encoder(batch)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, batch)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-4)
        return optimizer
AE = AutoEncoder(encoder,decoder)

In [39]:
class VectorDataset(Dataset):
    def __init__(self, n, idxs, dim=dim, actual_dim=actual_dim):
        self.x = torch.zeros(n,dim)
        self.x[:,idxs] = torch.randn(n,actual_dim)

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx]

In [40]:
idxs = np.random.choice(dim,actual_dim)
train_ds,test_ds = VectorDataset(10000,idxs), VectorDataset(100,idxs)
train_dl = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=0)

In [41]:
trainer = L.Trainer(limit_train_batches=100, max_epochs=50)
trainer.fit(model=AE, train_dataloaders=train_dl)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | encoder | Sequential | 61.6 K | train
1 | decoder | Sequential | 35.3 K | train
-----------------------------------------------
96.9 K    Trainable params
0         Non-trainable params
96.9 K    Total params
0.387     Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=50` reached.


In [37]:
x = np.zeros(dim)
x[idxs] = np.random.randn(actual_dim)
x = torch.tensor(x).float()
reconstructed = AE.decoder.eval()(AE.encoder.eval()(x))
nn.functional.mse_loss(x,reconstructed)


tensor(0.0121, grad_fn=<MseLossBackward0>)