In [19]:
import torch
import numpy as np
import torch.nn as nn
from torchvision.datasets import MNIST
import torchvision.transforms as T
from torch.utils.data import DataLoader
from IPython.display import clear_output as co
from tqdm import tqdm
import matplotlib.pyplot as plt
from torch.optim import Adam
import seaborn as sns
# Global plot styling for the notebook: seaborn "darkgrid" theme with enlarged fonts.
sns.set_theme(font_scale=1.4, style="darkgrid")

In [20]:
# Every image is resized to 64x64 and then converted to a tensor.
mnist_transforms = T.Compose([T.Resize((64, 64)), T.ToTensor()])

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Using {device}')

# Training split first, then the held-out split; both share root and transform.
train_set, valid_set = (
    MNIST(root='./data', train=is_train, download=True, transform=mnist_transforms)
    for is_train in (True, False)
)

# Only the training loader shuffles; validation keeps a fixed order.
train_loader, valid_loader = (
    DataLoader(split, batch_size=128, shuffle=do_shuffle)
    for split, do_shuffle in ((train_set, True), (valid_set, False))
)

def plot_stats(train_loss, valid_loss):
    """Plot training and validation loss curves on one figure.

    Args:
        train_loss: sequence of training loss values, one point per entry.
        valid_loss: sequence of validation loss values, one point per entry.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(train_loss, label='Training Loss')
    plt.plot(valid_loss, label='Validation Loss')
    plt.legend()
    # A bare plt.grid() *toggles* visibility; with a theme that already enables
    # the grid that would switch it off, so request it explicitly.
    plt.grid(True)
    plt.show()


# Reconstruction objective: mean of the element-wise squared errors.
criterion = nn.MSELoss(reduction='mean')

def train(model, optimizer = None):
    """Run one training epoch over ``train_loader`` and return the mean batch loss.

    The model output is compared against its own input with ``criterion``
    (reconstruction objective); labels from the loader are ignored.

    Args:
        model: module to optimise; called as ``model(x)`` on each batch.
        optimizer: optimizer to step. When omitted, a fresh ``Adam`` with
            ``lr=1e-3`` is created, which starts with empty moment estimates.
            Pass the same instance on every call to keep optimizer state
            across epochs.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    if optimizer is None:
        optimizer = Adam(model.parameters(), lr = 1e-3)
    model.train()
    running_loss = 0.0
    for x, _ in tqdm(train_loader, desc = 'Train'):
        x = x.to(device)
        optimizer.zero_grad()
        loss = criterion(model(x), x)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(train_loader)

@torch.inference_mode()
def valid(model):
    """Evaluate ``model`` on ``valid_loader`` and return the mean batch loss.

    The decorator already disables gradient tracking for the whole call, so
    no additional ``no_grad`` context is needed inside the loop.

    Args:
        model: module to evaluate; switched to eval mode and called as
            ``model(x)`` on each batch.

    Returns:
        Sum of the per-batch losses divided by the number of batches.
    """
    model.eval()
    running_loss = 0.0
    for x, _ in tqdm(valid_loader, desc = 'Valid'):
        x = x.to(device)
        running_loss += criterion(model(x), x).item()
    return running_loss / len(valid_loader)


Using cpu


In [21]:
from torchvision.utils import make_grid

@torch.inference_mode()
def visual(model, xs):
    """Show the first 10 inputs (top row) above the model's outputs for them (bottom row).

    Args:
        model: module called as ``model(xs)``; its output must be an image
            batch with the same spatial size as ``xs`` so both grids can be
            stacked.
        xs: batch of image tensors; only the first 10 are drawn.
    """
    model.eval()
    xs = xs.cpu()  # grids are assembled on the CPU regardless of where xs lives
    recon = model(xs.to(device)).cpu()

    # make_grid yields a C x H x W image; stack the two rows along the height
    # axis (dim=1). Concatenating on dim=0 would pile up channels instead.
    grid = torch.cat((
        make_grid(xs[:10], nrow = 10),
        make_grid(recon[:10], nrow = 10),
    ), dim = 1)

    # ToPILImage is a transform class: instantiate it, then call it on the tensor.
    # Float tensors are scaled by 255 and cast to uint8, so clamp to [0, 1] first
    # to keep out-of-range reconstruction values from wrapping around.
    img = T.ToPILImage()(grid.clamp(0, 1))

    plt.figure(figsize = (20, 4))
    plt.imshow(img)
    plt.axis('off')
    plt.show()

In [25]:
def octava(model, epochs = 10):
    """Train ``model`` for ``epochs`` epochs with a live progress display.

    After every epoch the cell output is cleared and redrawn with the
    reconstructions of a fixed validation batch and the loss curves so far.

    Args:
        model: module passed to ``train``, ``valid`` and ``visual``.
        epochs: number of train/validate rounds to run.
    """
    train_losses, valid_losses = [], []

    # valid_loader does not shuffle, so its first batch is identical on every
    # epoch; fetch it once instead of building a new iterator per epoch.
    preview_batch = next(iter(valid_loader))[0]

    for _ in range(epochs):
        train_losses.append(train(model))
        valid_losses.append(valid(model))

        co(wait = True)
        visual(model, preview_batch)
        plot_stats(train_losses, valid_losses)


In [23]:
class Block(nn.Module):
    """Conv2d -> BatchNorm2d -> ELU, optionally preceded by 2x bilinear upsampling.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the convolution.
        kernel: convolution kernel size.
        stride: convolution stride.
        padding: convolution padding.
        up: when true, the input is upsampled by a factor of 2 before the
            convolution is applied.
    """
    def __init__(self, in_channels, out_channels, kernel, stride = 1, padding = 1, up = False):
        super().__init__()
        self.up = up
        self.conv = nn.Conv2d(in_channels, out_channels, kernel, stride, padding)
        self.norm = nn.BatchNorm2d(out_channels)
        self.act = nn.ELU()

    def forward(self, x):
        """Apply the optional upsampling, then convolution, normalisation and activation."""
        if self.up:
            # recompute_scale_factor is a boolean flag, not a numeric factor.
            x = nn.functional.interpolate(
                x,
                scale_factor = 2,
                mode = 'bilinear',
                align_corners = False,
                recompute_scale_factor = True,
            )

        return self.act(self.norm(self.conv(x)))

class AutoEncoder(nn.Module):
    """Convolutional autoencoder built from ``base_block`` stages.

    The encoder applies ``num_blocks`` stride-2 blocks followed by a plain
    convolution; the decoder applies ``num_blocks`` upsampling blocks followed
    by a plain convolution back to ``in_channels``. With the default ``Block``
    each encoder stage halves height and width and each decoder stage doubles
    them, so inputs whose sides are multiples of ``2 ** num_blocks`` are
    reconstructed at their original resolution.

    Args:
        in_channels: number of channels of the input (and reconstructed) image.
        base_block: block class called as
            ``base_block(in_ch, out_ch, kernel, stride, padding, up=...)``;
            it must expose its convolution as ``.conv``.
        baze_size: channel width used by every hidden stage.
        num_blocks: number of down-sampling (and up-sampling) stages.
    """
    def __init__(self, in_channels, base_block = Block, baze_size = 32, num_blocks = 4):
        super().__init__()
        self.baze_size = baze_size

        # Stride 2 is what makes the encoder shrink the feature map; with
        # stride 1 the decoder's upsampling would overshoot the input size.
        encoder_blocks = [
            base_block(baze_size if i else in_channels, baze_size, 3, 2, 1)
            for i in range(num_blocks)
        ]
        # Final encoder layer is a bare convolution (no norm / activation).
        encoder_blocks.append(base_block(baze_size, baze_size, 3, 1, 1).conv)
        self.encoder = nn.Sequential(*encoder_blocks)

        decoder_blocks = [
            base_block(baze_size, baze_size, 3, 1, 1, up = True)
            for _ in range(num_blocks)
        ]
        # Final decoder layer is a bare convolution mapping back to image channels.
        decoder_blocks.append(base_block(baze_size, in_channels, 3).conv)
        self.decoder = nn.Sequential(*decoder_blocks)

        self.flatten = nn.Flatten()

    def forward(self, x):
        """Encode ``x`` and decode it again, returning the reconstruction."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def encode(self, x):
        """Return the encoder output flattened to one vector per sample."""
        return self.flatten(self.encoder(x))

    def decode(self, x):
        """Decode flattened latents produced by ``encode``.

        The latent is reshaped to ``(-1, baze_size, size, size)``, i.e. the
        spatial part of the latent is assumed to be square.
        """
        size = int(np.sqrt(x.shape[1] // self.baze_size))
        return self.decoder(x.view(-1, self.baze_size, size, size))

In [26]:
# One input channel; the model is moved to the selected device and trained
# with octava's default number of epochs.
model = AutoEncoder(in_channels=1)
model = model.to(device)
octava(model)

Train:   0%|          | 0/469 [00:00<?, ?it/s]