In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import pytorch_lightning as pl

## Goals: in jax

- [ ] Module class (with linear)
- [ ] Trainer class (with plotting and logging and training step, etc...)
- [ ] Dataclass class
- [ ] Passes pylint, pydoc, and mypy



In [None]:



class Module:
    '''Base module class for nn models'''

    def __init__(self) -> None:
        pass


class Linear(Module):
    r"""Linear nn layer with learnable weights w and bias b. 
        The return on input data x is y = x @ w.T + b.

    Paramaters:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: If set to ``False``, the layer will not learn an additive bias.
            Default: ``True``
    """
    def __init__(self, in_features, out_features, bias=True, device=None, dtype=jnp.float32) -> None:
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.dtype = dtype


    def init(self, key) -> Array:
        '''Initialize the weights and bias from a key'''
        key, w_key, b_key = jax.random.split(key: KeyArray, num=3)
        w = jax.random.normal(w_key, shape=(self.in_features,self.out_features), dtype=self.dtype)
        jnp.array




In [2]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        return self.l1(x)

In [3]:
class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [5]:
dataset = MNIST(os.getcwd(), download=False, transform=transforms.ToTensor())
#train_loader = DataLoader(dataset)

RuntimeError: Dataset not found. You can use download=True to download it

In [4]:
# model
autoencoder = LitAutoEncoder(Encoder(), Decoder())

# train model
trainer = pl.Trainer(max_epochs=4,  enable_progress_bar=False)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


NameError: name 'train_loader' is not defined

In [16]:
# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = Tensor(4, 28 * 28)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ 
Predictions (4 image embeddings):
 tensor([[-1.2964e+23, -1.8063e+23, -2.2563e+23],
        [ 2.0754e+22, -7.8390e+21, -1.8989e+23],
        [ 5.1086e+23,  1.4693e+23, -1.6050e+23],
        [ 1.4023e+24, -8.8235e+22,  4.0475e+23]], grad_fn=<AddmmBackward0>) 
 ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡


In [29]:
%reload_ext tensorboard
%tensorboard --logdir=lightning_logs/

Reusing TensorBoard on port 6006 (pid 42316), started 0:00:10 ago. (Use '!kill 42316' to kill it.)