In [3]:
from torch_snippets import *
from torchvision.datasets import MNIST
from torchvision import transforms
import torch
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [4]:
img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5],[0.5]),
    transforms.Lambda(lambda x: x.to(device))
])

In [5]:
trn_ds = MNIST('', transform=img_transform, train=True, download=True)
val_ds = MNIST('', transform=img_transform, train=False, download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to MNIST/raw/train-images-idx3-ubyte.gz


100%|███████████████████████████████████████| 9.91M/9.91M [00:29<00:00, 336kB/s]


Extracting MNIST/raw/train-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to MNIST/raw/train-labels-idx1-ubyte.gz


100%|███████████████████████████████████████| 28.9k/28.9k [00:00<00:00, 201kB/s]


Extracting MNIST/raw/train-labels-idx1-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████████████████████████████████| 1.65M/1.65M [00:00<00:00, 1.71MB/s]


Extracting MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████████████████████████████████| 4.54k/4.54k [00:00<00:00, 3.53MB/s]

Extracting MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST/raw






In [8]:
import torch.nn as nn
import torch

class AutoEncoder(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim
        self.encoder = nn.Sequential(
            nn.Linear(28*28,128), nn.ReLU(True),
            nn.Linear(128,64), nn.ReLU(True),
            nn.Linear(64,latent_dim))
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim,64),nn.ReLU(True),
            nn.Linear(64,128),nn.ReLU(True),
            nn.Linear(128,28*28),nn.Tanh())

    def forward(self,x):
        x = x.view(len(x), -1)
        x = self.encoder(x)
        x = self.decoder(x)
        x = x.view(len(x), 1, 28, 28)
        return x

In [9]:
def train_batch(input,model,criterio,optimizer):
    model.train()
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output,input)
    loss.backward()
    optimizer.step()
    return loss

@torch.no_grad()
def validate_batch(input,model,criterion):
    model.eval()
    output = model(input)
    loss = criterion(output,input)
    return loss

In [10]:
from fastprogress import master_bar, progress_bar
from fastprogress.fastprogress import Report

model = AutoEncoder(3).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(),lr=0.001,weight_decay=1e-5)

num_epochs = 5
log = Report(num_epochs)

for epoch in range(num_epochs):
    N = len(trn_dl)
    for ix, (data,_) in enumerate(trn_dl):
        loss = train_batch(data,model,criterion,optimizer)
        log.record(pos=(epoch+(ix+1)/N), trn_loss=loss, end='\r')

    N = len(val_dl)
    for ix, (data,_) in enumerate(val_dl):
        loss = validate_batch(data,model,criterion)
        log.record(pos=(epoch+(ix+1)/N),val_loss=loss,end='\r')
    log.report_avgs(epoch+1)
log.plot(log=True)

ImportError: cannot import name 'Report' from 'fastprogress.fastprogress' (/opt/anaconda3/lib/python3.12/site-packages/fastprogress/fastprogress.py)

In [None]:
model = AutoEncoder(3).to(device)

for _ in range(3):
    ix = np.random.randint(len(val_ds))
    im, _ = val_ds[ix]
    _im = model(im[None])[0]
    fig, ax = plt.subplots(1,2,figsize=(3,3))
    show(im[0], ax=ax[0], title='input')
    show(_im[0], ax=ax[1], title='prediction')
    plt.tight_layout()
    plt.show()