In [1]:
import torch
import torch.nn as nn

from src import MNISTDM, Encoder, Decoder

In [2]:
# Experiment configuration. Keys are namespaced as "<group>/<name>" so the
# data, model and training settings stay in a single flat dictionary.
HPARAMS = {
    # --- data ---
    "data/batch_size": 128,
    "data/image_size": 32,
    "data/num_workers": 4,
    # --- model ---
    "model/width": 6,
    "model/in_channels": 1,
    # --- training (not read in this cell) ---
    "train/epochs": 100,
    "train/lr": 8e-4,
    "train/weight_decay": 1e-2,
}

# Data module rooted at "data"; the remaining positional arguments are pulled
# from the "data/*" entries in the order batch_size, image_size, num_workers.
dm = MNISTDM(
    "data",
    *(HPARAMS[f"data/{name}"] for name in ("batch_size", "image_size", "num_workers")),
)
dm.prepare_data()
dm.setup()

# Encoder and decoder are built from the same (in_channels, width) pair.
# The encoder is constructed first, then the decoder.
encoder, decoder = (
    module_cls(HPARAMS["model/in_channels"], HPARAMS["model/width"])
    for module_cls in (Encoder, Decoder)
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/train/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 21235293.88it/s]


Extracting data/train/MNIST/raw/train-images-idx3-ubyte.gz to data/train/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/train/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1160403.62it/s]


Extracting data/train/MNIST/raw/train-labels-idx1-ubyte.gz to data/train/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/train/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 10539400.54it/s]


Extracting data/train/MNIST/raw/t10k-images-idx3-ubyte.gz to data/train/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/train/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 4353411.51it/s]


Extracting data/train/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/train/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/test/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 20596299.74it/s]


Extracting data/test/MNIST/raw/train-images-idx3-ubyte.gz to data/test/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/test/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 1347689.18it/s]


Extracting data/test/MNIST/raw/train-labels-idx1-ubyte.gz to data/test/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/test/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 10374018.64it/s]


Extracting data/test/MNIST/raw/t10k-images-idx3-ubyte.gz to data/test/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/test/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3518103.19it/s]

Extracting data/test/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/test/MNIST/raw






In [3]:
# Split the combined autoencoder checkpoint into standalone encoder and
# decoder weight files.


def _extract_submodule_state(state_dict, prefix):
    """Return the entries of ``state_dict`` whose key starts with ``prefix``.

    The prefix is removed from the front of each key only. Anchoring the match
    at the start of the key (instead of a substring test plus ``str.replace``)
    keeps nested modules that happen to share the name intact and prevents
    keys of one sub-network from leaking into the other.

    Args:
        state_dict: Mapping from parameter name to tensor.
        prefix: Leading part of the key to select and strip, e.g. ``"encoder."``.

    Returns:
        A new dict with the selected entries and the prefix stripped.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


# NOTE(security): weights_only=False unpickles arbitrary Python objects, so
# only load checkpoints from a trusted source.
# map_location="cpu" lets the checkpoint load on machines that do not have the
# device it was saved from; the modules are moved to a device later.
ckpt = torch.load(
    "checkpoints/autoencoder.ckpt",
    map_location="cpu",
    weights_only=False,
)
encoder_state_dict = _extract_submodule_state(ckpt["state_dict"], "encoder.")
decoder_state_dict = _extract_submodule_state(ckpt["state_dict"], "decoder.")

# Strict loading: raises if any key is missing or unexpected.
encoder.load_state_dict(encoder_state_dict)
decoder.load_state_dict(decoder_state_dict)

torch.save(encoder.state_dict(), "checkpoints/encoder.pt")
torch.save(decoder.state_dict(), "checkpoints/decoder.pt")

In [4]:
# Encode the train and validation splits with the encoder and store the
# resulting (latent, label) pairs as TensorDatasets.


def _select_device():
    """Pick the best available device: Apple MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


def _encode_loader(loader, model, device):
    """Run ``model`` over every batch of ``loader`` and collect the outputs.

    Args:
        loader: Iterable yielding ``(x, y)`` batches.
        model: Module applied to ``x`` after it is moved to ``device``.
        device: Device on which the forward pass runs.

    Returns:
        A ``(latents, labels)`` tuple, each concatenated along dim 0. Latents
        are moved back to the CPU; labels are kept as yielded by the loader.
    """
    batch_latents = []
    batch_labels = []
    # Local no_grad so the helper does not depend on the global autograd switch.
    with torch.no_grad():
        for x, y in loader:
            batch_latents.append(model(x.to(device)).cpu())
            batch_labels.append(y)
    return torch.cat(batch_latents, dim=0), torch.cat(batch_labels, dim=0)


device = _select_device()
encoder.to(device)
# Inference mode for layers such as dropout / batch-norm, so the latents are
# deterministic and no running statistics are updated while encoding.
# This is a no-op if the encoder contains no such layers.
encoder.eval()
# NOTE(review): this disables autograd globally for the rest of the kernel
# session. It is kept because later cells may rely on it.
torch.set_grad_enabled(False)

latents, labels = _encode_loader(dm.train_dataloader(), encoder, device)
train_dataset = torch.utils.data.TensorDataset(latents, labels)
# The whole TensorDataset object is pickled, so it must be read back with
# torch.load(..., weights_only=False).
torch.save(train_dataset, "checkpoints/train_latents.pt")

latents, labels = _encode_loader(dm.val_dataloader(), encoder, device)
val_dataset = torch.utils.data.TensorDataset(latents, labels)
torch.save(val_dataset, "checkpoints/val_latents.pt")