In [None]:
#%pip install torch torchvision matplotlib numpy lgbt scikit-learn ipywidgets

Collecting torch
  Using cached torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Using cached torchvision-0.21.0-cp312-cp312-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting matplotlib
  Using cached matplotlib-3.10.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)
Collecting numpy
  Using cached numpy-2.2.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (62 kB)
Collecting lgbt
  Using cached lgbt-0.2.2-py3-none-any.whl.metadata (1.0 kB)
Collecting scikit-learn
  Using cached scikit_learn-1.6.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)
Collecting ipywidgets
  Using cached ipywidgets-8.1.5-py3-none-any.whl.metadata (2.3 kB)
Collecting filelock (from torch)
  Using cached filelock-3.18.0-py3-none-any.whl.metadata (2.9 kB)
Collecting typing-extensions>=4.10.0 (from torch)
  Downloading typing_extensions-4.13.1-py3-none-any.whl.metadata (3.0 kB)
Collecting networkx (f

In [None]:
import torch.nn as nn
from torch.utils.data import DataLoader
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import torch.optim as optim
from sklearn.manifold import TSNE
import ipywidgets as widgets
from IPython.display import display

ModuleNotFoundError: No module named 'torch'

In [None]:
# Reproducibility: seed PyTorch's global RNG before anything random happens
# (DataLoader shuffling below, weight initialisation in later cells).
SEED = 42
torch.manual_seed(SEED)

# Values a reader is likely to tune, kept in one place.
DATA_ROOT = './data'
BATCH_SIZE = 128

# ToTensor turns each image into a float tensor; no further normalisation is
# applied, so pixel values stay in the range the Sigmoid decoders can output.
transform = transforms.Compose([transforms.ToTensor()])

# Both splits are downloaded into DATA_ROOT on first use.
train_dataset = datasets.FashionMNIST(
    root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(
    root=DATA_ROOT, train=False, download=True, transform=transform)

# Only the training split is shuffled; the test order stays fixed so that
# "first test batch" means the same images every time.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


# Human-readable names for the integer labels 0-9, in label order.
classes = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
           'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']


def show_images(images, title, n=5):
    """Display the first ``n`` images of a batch side by side in one row.

    Args:
        images: Tensor batch indexed as ``images[i][0]`` -> 2-D image, e.g. a
            ``(N, 1, H, W)`` batch. It is detached and copied to the CPU
            before plotting, so tensors that require grad are fine.
        title: Text used as the figure's super-title.
        n: Maximum number of images to show. Defaults to 5, the previously
            hard-coded value. Clamped to the batch size so that short batches
            no longer raise ``IndexError``.

    Raises:
        ValueError: If there is nothing to display (empty batch or ``n <= 0``).
    """
    images = images.cpu().detach().numpy()
    count = min(n, len(images))
    if count <= 0:
        raise ValueError('show_images needs at least one image to display')

    # 2 inches per image keeps the original 10x2 figure for the default n=5.
    fig, axes = plt.subplots(1, count, figsize=(2 * count, 2), squeeze=False)
    for ax, image in zip(axes[0], images):
        ax.imshow(image[0], cmap='gray')  # channel 0 of a single-channel image
        ax.axis('off')
    fig.suptitle(title)
    plt.show()

In [None]:
def train_ae(model, train_loader, test_loader, epochs=20, lr=0.001):
    """Train an autoencoder to reconstruct its input and plot the loss curves.

    The model is optimised with Adam on the mean-squared error between its
    output and the input batch. After every epoch the same loss is measured
    on ``test_loader`` with gradients disabled. Each reported loss is the
    mean of the per-batch losses of that epoch.

    Args:
        model: ``nn.Module`` whose output has the same shape as its input.
        train_loader: Iterable of ``(data, target)`` batches supporting
            ``len()``; the targets are ignored.
        test_loader: Same structure as ``train_loader``, used for evaluation.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate. Defaults to 0.001, the previously hard-coded
            value.

    Returns:
        The same ``model`` instance, trained in place (it is in eval mode
        after the last epoch).
    """
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Batches are moved to wherever the model lives, so a model that was put
    # on a GPU before the call trains without device-mismatch errors. For a
    # CPU model this is a no-op.
    device = next(model.parameters()).device

    train_losses = []
    test_losses = []

    for epoch in range(epochs):
        model.train()
        train_loss = 0
        for data, _ in train_loader:
            data = data.to(device)
            optimizer.zero_grad()
            outputs = model(data)
            loss = criterion(outputs, data)  # target is the input itself
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader)
        train_losses.append(train_loss)

        model.eval()
        test_loss = 0
        with torch.no_grad():
            for data, _ in test_loader:
                data = data.to(device)
                outputs = model(data)
                loss = criterion(outputs, data)
                test_loss += loss.item()

        test_loss /= len(test_loader)
        test_losses.append(test_loss)

        print(f'Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')

    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    # The second curve is measured on test_loader, so the title says "Test"
    # to agree with the legend and the printed log.
    plt.title('Training and Test Loss')
    plt.show()

    return model

In [None]:
class FullyConnectedAE(nn.Module):
    """Fully connected autoencoder for 28x28 single-channel images.

    The encoder maps a flattened 784-value image through 256 -> 128 -> 64
    hidden units to a ``latent_dim``-dimensional code; the decoder mirrors
    it and ends in a Sigmoid so every output value lies in [0, 1].

    Args:
        latent_dim: Size of the bottleneck code. Defaults to 32, the
            previously hard-coded value, so ``FullyConnectedAE()`` builds the
            same architecture (and the same state-dict keys) as before.
    """

    def __init__(self, latent_dim=32):
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(28*28, 256),
            nn.ReLU(),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim),  # no activation: the code is unbounded
        )

        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 28*28),
            nn.Sigmoid()  # squash to [0, 1]
        )

    def forward(self, x):
        """Reconstruct ``x`` and return it with shape ``(N, 1, 28, 28)``.

        ``x`` may have any shape whose elements group into rows of 784
        values, e.g. ``(N, 1, 28, 28)`` or ``(N, 784)``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced batches, are flattened instead of raising a RuntimeError.
        x = x.reshape(-1, 28*28)
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded.view(-1, 1, 28, 28)  # restore image shape

In [None]:
# Build a fresh fully connected autoencoder and train it with train_ae's
# default settings; train_ae hands back the model it was given.
fc_ae = train_ae(FullyConnectedAE(), train_loader, test_loader)

In [None]:
# Take the first test batch and reconstruct it; gradients are not needed
# for a pure forward pass.
test_images, _ = next(iter(test_loader))
with torch.no_grad():
    reconstructed = fc_ae(test_images)

# Originals first, reconstructions second, for a visual comparison.
show_images(images=test_images, title='Original Images')
show_images(images=reconstructed, title='Reconstructed Images')

In [None]:
# Persist the whole trained module object (not just its weights).
# NOTE(review): a whole-module pickle can only be reloaded where the
# FullyConnectedAE class is importable, and recent PyTorch releases appear to
# default torch.load to weights_only=True, which rejects such files -- consider
# saving fc_ae.state_dict() instead; confirm with whoever loads this file.
torch.save(obj=fc_ae, f='fc_autoencoder.pth')

In [None]:
def get_latent_representations(model, dataloader):
    """Encode every batch of ``dataloader`` and collect the latent codes.

    The model is switched to eval mode and only ``model.encoder`` is run,
    with gradients disabled.

    Whether a batch must be flattened first is decided from the encoder
    itself rather than from the model's class: if the encoder's first layer
    is an ``nn.Linear`` and the batch's last dimension does not already match
    its ``in_features``, the batch is reshaped to ``(-1, in_features)``.
    Encoders that start with any other layer receive the batch unchanged.

    Args:
        model: Module exposing an ``encoder`` sub-module.
        dataloader: Iterable of ``(data, label)`` tensor batches.

    Returns:
        Tuple ``(latents, labels)``: latents of shape ``(num_samples, -1)``
        (each code flattened to one row) and the concatenated labels.
    """
    model.eval()

    # Run the encoder on the device the model lives on (no-op for CPU models).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    # A leaf encoder (no children) is treated as its own first layer.
    first_layer = next(model.encoder.children(), model.encoder)
    flat_features = first_layer.in_features if isinstance(first_layer, nn.Linear) else None

    latent_vectors = []
    labels = []
    with torch.no_grad():
        for data, label in dataloader:
            if device is not None:
                data = data.to(device)
            if flat_features is not None and data.shape[-1] != flat_features:
                data = data.reshape(-1, flat_features)
            latent = model.encoder(data)
            # Flatten e.g. (N, C, 1, 1) conv codes to (N, C); already-flat
            # codes are left as they are.
            latent_vectors.append(latent.reshape(latent.size(0), -1))
            labels.append(label)
    return torch.cat(latent_vectors), torch.cat(labels)

In [None]:
# Encode the whole test set with the fully connected model.
latent, labels = get_latent_representations(fc_ae, test_loader)

# Project the latent codes to 2-D; the fixed random_state is passed so the
# embedding uses the same seed on every run.
tsne = TSNE(n_components=2, random_state=42)
latent_2d = tsne.fit_transform(latent.cpu().numpy())

fig = plt.figure(figsize=(10, 8))
# Unpacking the transposed array yields the x and y columns; points are
# coloured by their integer class label.
scatter = plt.scatter(
    *latent_2d.T,
    c=labels.cpu().numpy(),
    cmap='tab10',
    alpha=0.6,
)
# One legend handle per distinct label value, named via `classes`.
fig.legend(scatter.legend_elements()[0], classes, title="Classes")

plt.title('t-SNE visualization')
plt.show()

In [None]:
class ConvAE(nn.Module):
    """Convolutional autoencoder with a 64-channel bottleneck.

    Two stride-2 convolutions followed by a 7x7 convolution form the
    encoder; the decoder mirrors them with transposed convolutions and ends
    in a Sigmoid. The shape comments below assume a 1x28x28 input.
    """

    def __init__(self):
        super().__init__()

        # Settings shared by the stride-2 layers on each side.
        halve = dict(kernel_size=3, stride=2, padding=1)
        double = dict(kernel_size=3, stride=2, padding=1, output_padding=1)

        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, **halve),    # -> 16x14x14
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, **halve),   # -> 32x7x7
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7),  # -> 64x1x1
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=7),  # -> 32x7x7
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=16, **double),  # -> 16x14x14
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=16, out_channels=1, **double),   # -> 1x28x28
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` and decode the result back to an image batch."""
        return self.decoder(self.encoder(x))

In [None]:
# Build a fresh convolutional autoencoder and train it with train_ae's
# default settings; train_ae hands back the model it was given.
conv_ae = train_ae(ConvAE(), train_loader, test_loader)

In [None]:
# Take the first test batch and reconstruct it with the convolutional
# model; gradients are not needed for a pure forward pass.
test_images, _ = next(iter(test_loader))
with torch.no_grad():
    reconstructed = conv_ae(test_images)

# Originals first, reconstructions second, for a visual comparison.
show_images(images=test_images, title='Original Images')
show_images(images=reconstructed, title='Reconstructed Images')

In [None]:
# Persist the whole trained module object (not just its weights).
# NOTE(review): a whole-module pickle can only be reloaded where the ConvAE
# class is importable, and recent PyTorch releases appear to default
# torch.load to weights_only=True, which rejects such files -- consider saving
# conv_ae.state_dict() instead; confirm with whoever loads this file.
torch.save(obj=conv_ae, f='conv_autoencoder.pth')

In [None]:
# Encode the whole test set with the convolutional model.
latent, labels = get_latent_representations(conv_ae, test_loader)

# Project the latent codes to 2-D; the fixed random_state is passed so the
# embedding uses the same seed on every run.
tsne = TSNE(n_components=2, random_state=42)
latent_2d = tsne.fit_transform(latent.cpu().numpy())

fig = plt.figure(figsize=(10, 8))
# Unpacking the transposed array yields the x and y columns; points are
# coloured by their integer class label.
scatter = plt.scatter(
    *latent_2d.T,
    c=labels.cpu().numpy(),
    cmap='tab10',
    alpha=0.6,
)
# One legend handle per distinct label value, named via `classes`.
fig.legend(scatter.legend_elements()[0], classes, title="Classes")

plt.title('t-SNE visualization')
plt.show()

In [None]:
def interactive_decoder(model, latent_dim=10):
    """Show sliders that drive the first ``latent_dim`` latent dimensions.

    Every slider change decodes the current latent vector with
    ``model.decoder`` and redraws the resulting 28x28 image. Latent
    dimensions without a slider are held at zero.

    Args:
        model: Module exposing an indexable ``decoder`` whose first layer has
            either ``in_features`` (e.g. ``nn.Linear``) or ``in_channels``
            (e.g. ``nn.ConvTranspose2d``).
        latent_dim: Number of sliders to create; must lie between 0 and the
            decoder's input size.

    Raises:
        TypeError: If the decoder's input size cannot be determined.
        ValueError: If ``latent_dim`` is negative or exceeds the decoder's
            input size (previously this only failed later, inside the slider
            callback).
    """
    first_layer = model.decoder[0]
    if hasattr(first_layer, 'in_features'):
        latent_size = first_layer.in_features
        latent_shape = (1, latent_size)
    elif hasattr(first_layer, 'in_channels'):
        latent_size = first_layer.in_channels
        # NOTE(review): assumes the convolutional code is C x 1 x 1, as
        # produced by ConvAE in this notebook -- confirm for other models.
        latent_shape = (1, latent_size, 1, 1)
    else:
        raise TypeError(
            'model.decoder[0] must expose in_features or in_channels')

    if not 0 <= latent_dim <= latent_size:
        raise ValueError(
            f'latent_dim must be between 0 and {latent_size}, got {latent_dim}')

    # Build the latent vector on the model's device (no-op for CPU models).
    device = next(model.parameters()).device

    sliders = [
        widgets.FloatSlider(
            value=0, min=-3, max=3, step=0.1,
            description=f'Dim {i}', continuous_update=True
        )
        for i in range(latent_dim)
    ]

    def update_image(**kwargs):
        # Values are read from the sliders in creation order instead of
        # relying on the ordering of the keyword arguments.
        latent_vector = torch.zeros(1, latent_size, device=device)
        if sliders:
            latent_vector[0, :latent_dim] = torch.tensor(
                [slider.value for slider in sliders], dtype=torch.float32)
        with torch.no_grad():
            decoded = model.decoder(latent_vector.view(latent_shape)).view(1, 1, 28, 28)

        plt.imshow(decoded[0][0].cpu().numpy(), cmap='gray')
        plt.axis('off')
        plt.show()

    ui = widgets.VBox(sliders)
    # interactive_output wires the sliders to the callback and returns the
    # Output widget it renders into, so nothing is created and then dropped.
    output = widgets.interactive_output(
        update_image, {f'dim_{i}': slider for i, slider in enumerate(sliders)})
    display(ui, output)

In [None]:
# Explore the fully connected decoder with one slider for each of the first
# 10 latent dimensions.
interactive_decoder(model=fc_ae, latent_dim=10)