# Auto Encoders

In [None]:
import torch

# Pick the compute device used by every cell below.
# Apple-silicon MPS stays the first choice (the original hard-coded value);
# CUDA and CPU are fallbacks so the notebook also runs on other machines.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
from ml_zoo import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import tqdm as tqdm
from sklearn.decomposition import PCA
import numpy as np
import plotly.express as px
import matplotlib.pyplot as plt

In [None]:
# Build the data module configuration first, then the module itself.
dm_config = CIFARDataModuleConfig(
    "data",
    batch_size=128,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
)
dm = CIFARDataModule(dm_config)

# Prepare the data and set the module up before asking it for loaders.
dm.prepare_data()
dm.setup()

# Loaders shared by every training / evaluation cell below.
train_loader, val_loader = dm.train_dataloader(), dm.val_dataloader()

In [None]:
def evaluate(model, loader):
    """Return the mean per-batch loss of ``model`` over ``loader``.

    Args:
        model: Module called as ``model(x, calc_loss=True)`` and returning a
            ``(output, loss)`` pair, where ``loss`` is a scalar tensor.
        loader: Iterable of ``(x, target)`` batches; the target is ignored.

    Returns:
        The sum of the batch losses divided by the number of batches.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    # Remember the current mode so the caller's train/eval state is restored;
    # otherwise a training loop calling this would keep running in eval mode.
    was_training = model.training

    # Move batches to wherever the model lives. A model without parameters
    # gives no device hint, so its batches are left where they are.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    total_loss = 0.0
    num_batches = 0
    model.eval()
    try:
        with torch.no_grad():
            for x, _ in loader:
                if target_device is not None:
                    x = x.to(target_device)
                _, loss = model(x, calc_loss=True)
                total_loss += loss.item()
                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0:
        raise ValueError("evaluate() received an empty loader")
    return total_loss / num_batches

## Regular Auto Encoder

In [None]:
class Model(nn.Module):
    def __init__(self, latent_dim=16):
        super(Model, self).__init__()
        self.encoder = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128 * 4 * 4),
            nn.ReLU(),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x, calc_loss) -> tuple[torch.Tensor, torch.Tensor | None]:
        if not calc_loss:
            z = self.encoder(x)
            x_hat = self.decoder(z)
            return x_hat, None

        else:
            z = self.encoder(x)
            x_hat = self.decoder(z)

            loss = F.mse_loss(x_hat, x)
            return x_hat, loss


# Instantiate the autoencoder with its default latent size and move it to
# the selected device (Module.to moves the parameters in place).
model = Model()
model.to(device)

In [None]:
# Adam over all model parameters. The plateau scheduler halves the learning
# rate when the value passed to scheduler.step() stops decreasing
# (patience=5).
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    mode="min",
    factor=0.5,
    patience=5,
)

In [None]:
# Train the regular autoencoder. roll_loss is an exponential moving average
# of the training loss; it drives the plateau scheduler once per epoch.
roll_loss = 0
for epoch in range(100):
    # Validate first, then enable train mode, so evaluation cannot leave
    # BatchNorm in eval mode during the optimisation steps.
    val_loss = evaluate(model, val_loader)
    model.train()
    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")

    for x, _ in pbar:
        x = x.to(device)
        _, loss = model(x, calc_loss=True)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call forces a device sync.
        loss_value = loss.item()
        roll_loss = 0.9 * roll_loss + 0.1 * loss_value

        # LR is printed in scientific notation so values below 1e-4 stay
        # visible after repeated halving.
        pbar.set_postfix_str(
            f"Loss: {loss_value:.4f}, Val Loss: {val_loss:.4f}, Overfit: {loss_value / val_loss:.4f}, Roll Loss: {roll_loss:.4f}, LR: {optimizer.param_groups[0]['lr']:.2e}"
        )

    scheduler.step(roll_loss)

In [None]:
# Show the learning rate(s) last computed by the scheduler.
last_lrs = scheduler.get_last_lr()
print(last_lrs)

In [None]:
# Encode the whole validation set and scatter-plot the codes in 2-D.
model.eval()
values = []
classes = []
with torch.no_grad():
    for x, y in val_loader:
        z = model.encoder(x.to(device))

        # One row per sample, so PCA always receives a 2-D matrix.
        values.append(z.flatten(start_dim=1).cpu().numpy())
        classes.append(y.cpu().numpy().flatten())

values = np.concatenate(values, axis=0)
classes = np.concatenate(classes, axis=0)

# Project the codes onto their first two principal components.
pca = PCA(n_components=2)
values = pca.fit_transform(values)

# Class ids are categorical: pass them as strings so plotly uses discrete
# colours and category_orders actually applies.
fig = px.scatter(
    x=values[:, 0],
    y=values[:, 1],
    color=classes.astype(str),
    labels={"x": "PC 1", "y": "PC 2", "color": "class"},
    title="Latent Space",
    category_orders={"color": [str(c) for c in np.unique(classes)]},
    height=800,
)
fig.show()

In [None]:
# Show five validation images (top row) above their reconstructions
# (bottom row).
model.eval()
x, y = next(iter(val_loader))

x = x.to(device)
x_hat = model(x, calc_loss=False)[0]
x_hat = x_hat.detach()

fig, axs = plt.subplots(2, 5, figsize=(20, 8))
for col in range(5):
    # Row 0 holds the inputs, row 1 the model outputs.
    for row, images in enumerate((x, x_hat)):
        ax = axs[row, col]
        ax.imshow(images[col].cpu().numpy().squeeze(), cmap="gray")
        ax.axis("off")

plt.suptitle("Reconstructed Images")
plt.show()

## Denoising Auto Encoder

In [None]:
class Model(nn.Module):
    def __init__(self, latent_dim=4):
        super(Model, self).__init__()
        self.encoder = nn.Sequential(
            nn.Dropout(0.5), # Dropout layer
            nn.Conv2d(1, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, latent_dim),
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128 * 4 * 4),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, stride=2, padding=1),
            nn.Sigmoid(),
            
        )

    def forward(self, x, calc_loss) -> tuple[torch.Tensor, torch.Tensor | None]:
        if not calc_loss:
            z = self.encoder(x)
            x_hat = self.decoder(z)
            return x_hat, None

        else:
            z = self.encoder(x)
            x_hat = self.decoder(z)

            loss = F.mse_loss(x_hat, x)
            return x_hat, loss


# Instantiate the denoising autoencoder and move it to the selected device
# (Module.to moves the parameters in place).
model = Model()
model.to(device)

In [None]:
# Fresh Adam optimizer bound to the parameters of the current model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Train the denoising autoencoder for 10 epochs.
for epoch in range(10):
    # Validate first, then enable train mode. The input Dropout that
    # provides the corruption is only active in train mode, so the order
    # matters.
    val_loss = evaluate(model, val_loader)
    model.train()
    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")
    for x, _ in pbar:
        x = x.to(device)
        _, loss = model(x, calc_loss=True)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call forces a device sync.
        loss_value = loss.item()
        pbar.set_postfix_str(
            f"Loss: {loss_value:.4f}, Val Loss: {val_loss:.4f}, Overfit: {loss_value / val_loss:.4f}"
        )

In [None]:
# Encode the whole validation set and scatter-plot the codes in 2-D.
model.eval()
values = []
classes = []
with torch.no_grad():
    for x, y in val_loader:
        z = model.encoder(x.to(device))

        # One row per sample, so PCA always receives a 2-D matrix.
        values.append(z.flatten(start_dim=1).cpu().numpy())
        classes.append(y.cpu().numpy().flatten())

values = np.concatenate(values, axis=0)
classes = np.concatenate(classes, axis=0)

# Project the codes onto their first two principal components.
pca = PCA(n_components=2)
values = pca.fit_transform(values)

# Class ids are categorical: pass them as strings so plotly uses discrete
# colours and category_orders actually applies.
fig = px.scatter(
    x=values[:, 0],
    y=values[:, 1],
    color=classes.astype(str),
    labels={"x": "PC 1", "y": "PC 2", "color": "class"},
    title="Latent Space",
    category_orders={"color": [str(c) for c in np.unique(classes)]},
    height=800,
)
fig.show()

In [None]:
# Show five validation images (top row) above their reconstructions
# (bottom row).
model.eval()
x, y = next(iter(val_loader))

x = x.to(device)
x_hat = model(x, calc_loss=False)[0]
x_hat = x_hat.detach()

fig, axs = plt.subplots(2, 5, figsize=(20, 8))
for col in range(5):
    # Row 0 holds the inputs, row 1 the model outputs.
    for row, images in enumerate((x, x_hat)):
        ax = axs[row, col]
        ax.imshow(images[col].cpu().numpy().squeeze(), cmap="gray")
        ax.axis("off")

plt.suptitle("Reconstructed Images")
plt.show()

## Variational Auto Encoder

In [None]:
class Model(nn.Module):
    def __init__(self, latent_dim=128):
        super(Model, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Flatten(),
        )
        self.logvar = nn.Linear(128 * 4 * 4, latent_dim) # Log variance
        self.mu = nn.Linear(128 * 4 * 4, latent_dim) # Mean

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 128 * 4 * 4),
            nn.Unflatten(1, (128, 4, 4)),
            nn.ConvTranspose2d(128, 64, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),
            nn.Sigmoid(),
            
        )

    def forward(self, x, calc_loss) -> tuple[torch.Tensor, torch.Tensor | None]:
        if calc_loss:
            z = self.encoder(x)
            mean = self.mu(z)
            logvar = self.logvar(z)
            z = self.reparameterize(mean, logvar)

            x_hat = self.decoder(z)

            loss = F.mse_loss(x_hat, x) + (-0.5 * torch.sum(1 + logvar - mean.pow(2) - logvar.exp()))
            return x_hat, loss
        
        else:
            z = self.encoder(x)
            mean = self.mu(z)
            logvar = self.logvar(z)
            z = self.reparameterize(mean, logvar)
            x_hat = self.decoder(z)
            return x_hat, None
        
    def reparameterize(self, mean, logvar): # Reparameterization trick
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std, device=device)
        return mean + eps * std
    


# Instantiate the variational autoencoder and move it to the selected device
# (Module.to moves the parameters in place).
model = Model()
model.to(device)

In [None]:
# Fresh Adam optimizer bound to the parameters of the current model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Train the variational autoencoder for 10 epochs.
for epoch in range(10):
    # Validate first, then enable train mode, so evaluation cannot leave the
    # model in eval mode during the optimisation steps.
    val_loss = evaluate(model, val_loader)
    model.train()
    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")
    for x, _ in pbar:
        x = x.to(device)
        _, loss = model(x, calc_loss=True)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call forces a device sync.
        loss_value = loss.item()
        pbar.set_postfix_str(
            f"Loss: {loss_value:.4f}, Val Loss: {val_loss:.4f}, Overfit: {loss_value / val_loss:.4f}"
        )

In [None]:
# Encode the whole validation set and scatter-plot the latent means in 2-D.
model.eval()
values = []
classes = []
with torch.no_grad():
    for x, y in val_loader:
        # The encoder alone yields the pre-latent feature vector; the latent
        # code is the output of the mean head applied to those features.
        z = model.mu(model.encoder(x.to(device)))

        # One row per sample, so PCA always receives a 2-D matrix.
        values.append(z.flatten(start_dim=1).cpu().numpy())
        classes.append(y.cpu().numpy().flatten())

values = np.concatenate(values, axis=0)
classes = np.concatenate(classes, axis=0)

# Project the codes onto their first two principal components.
pca = PCA(n_components=2)
values = pca.fit_transform(values)

# Class ids are categorical: pass them as strings so plotly uses discrete
# colours and category_orders actually applies.
fig = px.scatter(
    x=values[:, 0],
    y=values[:, 1],
    color=classes.astype(str),
    labels={"x": "PC 1", "y": "PC 2", "color": "class"},
    title="Latent Space",
    category_orders={"color": [str(c) for c in np.unique(classes)]},
    height=800,
)
fig.show()

In [None]:
# Show five validation images (top row) above their reconstructions
# (bottom row).
model.eval()
x, y = next(iter(val_loader))

x = x.to(device)
x_hat = model(x, calc_loss=False)[0]
x_hat = x_hat.detach()

fig, axs = plt.subplots(2, 5, figsize=(20, 8))
for col in range(5):
    # Row 0 holds the inputs, row 1 the model outputs. Colour images are
    # permuted from channel-first to channel-last for imshow.
    for row, images in enumerate((x, x_hat)):
        ax = axs[row, col]
        ax.imshow(images[col].cpu().permute(1, 2, 0))
        ax.axis("off")

plt.suptitle("Reconstructed Images")
plt.show()

## Sparse Auto Encoder

In [None]:
class Model(nn.Module):
    def __init__(self, latent_dim=4):
        super(Model, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 3,  stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 3, stride=2, padding=1),
            
        )

    def forward(self, x, calc_loss) -> tuple[torch.Tensor, torch.Tensor | None]:
        if not calc_loss:
            z = self.encoder(x)
            x_hat = self.decoder(z)
            return x_hat, None

        else:
            z = self.encoder(x)
            x_hat = self.decoder(z)

            loss = F.mse_loss(x_hat, x) + 1e-3 * F.l1_loss(z, torch.zeros_like(z)) # L1 regularization
            return x_hat, loss


# Instantiate the sparse autoencoder on the selected device, then push one
# random 3x32x32 image through it; the trailing expression displays the
# shape of the reconstruction as the cell output.
model = Model()
model.to(device)
probe = torch.randn(1, 3, 32, 32, device=device)
model(probe, calc_loss=False)[0].shape

In [None]:
# Fresh Adam optimizer bound to the parameters of the current model.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Train the sparse autoencoder for 10 epochs.
for epoch in range(10):
    # Validate first, then enable train mode, so evaluation cannot leave the
    # model in eval mode during the optimisation steps.
    val_loss = evaluate(model, val_loader)
    model.train()
    pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch}")
    for x, _ in pbar:
        x = x.to(device)
        _, loss = model(x, calc_loss=True)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Read the scalar once: every .item() call forces a device sync.
        loss_value = loss.item()
        pbar.set_postfix_str(
            f"Loss: {loss_value:.4f}, Val Loss: {val_loss:.4f}, Overfit: {loss_value / val_loss:.4f}"
        )

In [None]:
# Encode the whole validation set and scatter-plot the codes in 2-D.
model.eval()
values = []
classes = []
with torch.no_grad():
    for x, y in val_loader:
        z = model.encoder(x.to(device))

        # The convolutional code is 4-D (batch, channels, height, width).
        # PCA needs one flat feature row per sample.
        values.append(z.flatten(start_dim=1).cpu().numpy())
        classes.append(y.cpu().numpy().flatten())

values = np.concatenate(values, axis=0)
classes = np.concatenate(classes, axis=0)

# Project the codes onto their first two principal components.
pca = PCA(n_components=2)
values = pca.fit_transform(values)

# Class ids are categorical: pass them as strings so plotly uses discrete
# colours and category_orders actually applies.
fig = px.scatter(
    x=values[:, 0],
    y=values[:, 1],
    color=classes.astype(str),
    labels={"x": "PC 1", "y": "PC 2", "color": "class"},
    title="Latent Space",
    category_orders={"color": [str(c) for c in np.unique(classes)]},
    height=800,
)
fig.show()

In [None]:
# Show five validation images (top row) above their reconstructions
# (bottom row).
model.eval()
x, y = next(iter(val_loader))

x = x.to(device)
with torch.no_grad():
    x_hat = model(x, calc_loss=False)[0]

fig, axs = plt.subplots(2, 5, figsize=(20, 8))
for i in range(5):
    # The model works on 3-channel images. imshow expects channel-last
    # (H, W, C) arrays, so permute instead of squeezing into a grayscale
    # plot.
    axs[0, i].imshow(x[i].cpu().permute(1, 2, 0))
    axs[0, i].axis("off")
    # The decoder has no output activation, so clamp the reconstruction to
    # the displayable [0, 1] range.
    # NOTE(review): assumes inputs are scaled to [0, 1] — confirm.
    axs[1, i].imshow(x_hat[i].cpu().permute(1, 2, 0).clamp(0, 1))
    axs[1, i].axis("off")

plt.suptitle("Reconstructed Images")
plt.show()

In [None]:
# Print the encoder output for the first image of one validation batch.
model.eval()
x, y = next(iter(val_loader))
encoded = model.encoder(x.to(device))
print(encoded[0])