In [None]:
# | default_exp sparse_autoenc

%load_ext autoreload
%autoreload 2

%env PYDEVD_DISABLE_FILE_VALIDATION=1

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

# Pick the best available compute backend: CUDA GPU first, then Apple MPS,
# and fall back to the CPU when no accelerator is present.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Hyperparameters
batch_size = 128
learning_rate = 0.001
num_epochs = 10
sparsity_lambda = 1e-3  # Weight of the L1 sparsity penalty in the training loss

# MNIST preprocessing.
# ToTensor() alone scales pixel values to [0, 1], which is the range the
# decoder's final Sigmoid can produce. Normalising with mean 0.5 / std 0.5
# would shift the targets to [-1, 1], a range the Sigmoid output can never
# reach, so the reconstruction loss could not be driven towards zero.
transform = transforms.Compose([transforms.ToTensor()])

# Datasets are downloaded to ./data on first use.
train_dataset = datasets.MNIST(
    root="./data", train=True, transform=transform, download=True
)
test_dataset = datasets.MNIST(
    root="./data", train=False, transform=transform, download=True
)

# Shuffle only the training split; keep the test order fixed so the
# visualised samples are the same on every run.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Autoencoder definition
class SparseAutoencoder(nn.Module):
    """Autoencoder with one hidden layer on each side.

    The encoder is ``Linear -> ReLU`` and the decoder is ``Linear -> Sigmoid``.
    Because of the final Sigmoid the reconstruction is bounded to (0, 1), so
    reconstruction targets should lie in that range.

    Args:
        input_size: Number of features of a flattened input sample.
        hidden_size: Number of units in the hidden (code) layer.
    """

    def __init__(self, input_size=784, hidden_size=128):
        super().__init__()
        # Build the encoder layers first, then the decoder layers.
        encoder_layers = [nn.Linear(input_size, hidden_size), nn.ReLU()]
        # Sigmoid keeps the reconstructed pixels between 0 and 1.
        decoder_layers = [nn.Linear(hidden_size, input_size), nn.Sigmoid()]
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return the tuple ``(hidden code, reconstruction)`` for the batch ``x``."""
        code = self.encoder(x)
        return code, self.decoder(code)


# Model initialisation
input_size = 28 * 28  # Size of a flattened MNIST image
hidden_size = 128
model = SparseAutoencoder(input_size=input_size, hidden_size=hidden_size).to(device)

# Optimizer and loss function
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training
# Training is disabled by default; set RUN_TRAINING to True to fit the model.
RUN_TRAINING = False
if RUN_TRAINING:
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0.0

        for images, _ in train_loader:
            images = images.view(-1, input_size).to(device)  # Flatten the images

            # Deliberately NOT named `encoded` / `decoded`: `encoded` is also the
            # module-level list that the encoder forward hook appends to, and
            # rebinding it to a tensor here would make that hook fail on the
            # next forward pass.
            codes, reconstructions = model(images)

            # Reconstruction loss
            reconstruction_loss = criterion(reconstructions, images)

            # L1 penalty on the hidden layer activations to encourage sparsity
            l1_loss = torch.mean(torch.abs(codes))
            loss = reconstruction_loss + sparsity_lambda * l1_loss

            # Backpropagation
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Average the accumulated loss over the number of batches
        train_loss /= len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {train_loss:.4f}")

In [None]:
# Test the autoencoder (visualisation)


# Encoder outputs captured by the forward hook, one entry per hook call.
encoded = []


def hook_fn(module, input, output):
    """Forward hook that stores the module's output in the global ``encoded`` list.

    Args:
        module: The module the hook is attached to (unused).
        input: The positional inputs passed to the module (unused).
        output: The module's output for this forward pass.
    """
    # detach() so a stored activation never keeps an autograd graph alive
    # when the model is called outside of torch.no_grad().
    encoded.append(output.detach())


# Re-running this cell must not stack another hook on the same encoder:
# every extra hook would append a duplicate entry per forward pass. Remove
# the handle left over from a previous execution before registering again.
if "encoder_hook_handle" in globals():
    encoder_hook_handle.remove()
encoder_hook_handle = model.encoder.register_forward_hook(hook_fn)

# Reconstruct the first test batch and plot the originals (top row) above
# their reconstructions (bottom row).
model.eval()
with torch.no_grad():
    for images, _ in test_loader:
        flat_batch = images.view(-1, 28 * 28).to(device)
        _, decoded = model(flat_batch)

        # Move everything back to the CPU as 28x28 images for matplotlib
        images = flat_batch.cpu().view(-1, 28, 28)
        decoded = decoded.cpu().view(-1, 28, 28)

        # Show the original and reconstructed images
        fig, axes = plt.subplots(2, 10, figsize=(10, 2))
        for row, batch in enumerate((images, decoded)):
            for col in range(10):
                panel = axes[row, col]
                panel.imshow(batch[col], cmap="gray")
                panel.axis("off")
        plt.show()

        # Only the first batch is visualised
        break

In [None]:
# Number of encoder outputs the forward hook has collected in `encoded`
len(encoded)