# Variational Autoencoders (VAEs) for CryptoPunks Dataset

## Introduction

This Jupyter notebook explores the application of Variational Autoencoders (VAEs) to the CryptoPunks dataset. CryptoPunks are unique digital collectibles on the Ethereum blockchain, consisting of 10,000 algorithmically generated characters with distinct attributes.


In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from datasets import load_dataset

In [None]:
# Pick the compute device: Apple's Metal Performance Shaders (MPS) backend when
# it is available (e.g. a MacBook Pro with an M1 chip), otherwise the CPU.
# `device` is bound on every machine, so later cells can use `.to(device)`
# without hitting a NameError when MPS is missing.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    mps_device = device  # original name, kept for backward compatibility
    # Allocate a tiny tensor on the device as a smoke test.
    x = torch.ones(1, device=mps_device)
    print(x)
else:
    device = torch.device("cpu")
    print("MPS device not found.")

In [None]:
# Load the CryptoPunks images from the Hugging Face Hub ("train" split).
dataset = load_dataset("huggingnft/cryptopunks", split="train")

# Fraction of the data used for training; the remainder becomes the test set.
train_ratio = 0.8

# Fixed seed so the train/test split, and therefore every reported loss and
# plot, is identical across notebook runs.
split_seed = 42

# Create the train-test split
split_dataset = dataset.train_test_split(test_size=1 - train_ratio, seed=split_seed)

# Access the train and test splits
train_data = split_dataset['train']
test_data = split_dataset['test']

In [None]:
# Function to convert image to tensor
def transform_func(img):
    """Convert a PIL image into a float tensor for the VAE.

    The image is forced to 3-channel RGB, resized to 128x128 and converted
    to a ``(3, 128, 128)`` ``torch.float32`` tensor with values in [0, 1].

    Args:
        img: a ``PIL.Image.Image``.

    Returns:
        ``torch.Tensor`` of shape ``(3, 128, 128)``.
    """
    # The encoder's first convolution takes exactly 3 input channels. RGBA,
    # palette ("P") or grayscale images would otherwise become 4- or
    # 1-channel tensors. Images that are already RGB are left untouched.
    # Note: converting RGBA drops the alpha channel.
    if img.mode != "RGB":
        img = img.convert("RGB")
    transform = transforms.Compose([
        transforms.Resize((128, 128)),               # resize to 128x128
        transforms.PILToTensor(),                    # uint8 tensor, (C, H, W)
        transforms.ConvertImageDtype(torch.float),   # uint8 [0, 255] -> float [0, 1]
    ])
    return transform(img)

# Define class for Cryptopunks dataset
class PunkDataset(Dataset):
    """Map-style dataset over a split of CryptoPunks records.

    Args:
        data_list: indexable collection whose items are mappings with an
            ``"image"`` key (here, a Hugging Face ``datasets`` split).
        transform: optional callable applied to each PIL image. When omitted,
            ``transform_func`` is used, so ``PunkDataset(data)`` behaves
            exactly as before.
    """

    def __init__(self, data_list, transform=None):
        self.data_list = data_list
        self.transform = transform_func if transform is None else transform

    def __len__(self):
        """Return the number of records in the wrapped split."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the image at ``idx``.

        ``PIL.Image.Image`` values are passed through ``self.transform``;
        any other value is returned unchanged.
        """
        image = self.data_list[idx]["image"]
        if isinstance(image, Image.Image):
            image = self.transform(image)
        return image
    
# Wrap both Hugging Face splits so each item comes back as an image tensor.
train_dataset, test_dataset = (PunkDataset(split) for split in (train_data, test_data))

In [None]:
# Batch both splits in groups of 64. The training set is reshuffled every
# epoch; the test set keeps a fixed order.
train_loader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=64)
test_loader = DataLoader(dataset=test_dataset, shuffle=False, batch_size=64)

In [None]:
# Check the split sizes: print each count together with its share of the whole
# dataset, so the intended 80% / 20% ratio can actually be verified.
total_samples = len(train_data) + len(test_data)
share_denominator = max(total_samples, 1)  # avoid ZeroDivisionError on an empty dataset
print(f"Number of training samples: {len(train_data)} "
      f"({len(train_data) / share_denominator:.1%})")
print(f"Number of test samples: {len(test_data)} "
      f"({len(test_data) / share_denominator:.1%})")

In [None]:
# Display images in a grid
def plot_images_from_loader(data_loader, title, num_images=8):
    """Display a set of images from the DataLoader in a grid."""
    # Get a batch of images
    images = next(iter(data_loader))
    
    # Create a grid from the images
    fig, axes = plt.subplots(1, num_images, figsize=(32, 3))
    fig.suptitle(title)
    
    # Plot each image
    for i in range(num_images):
        ax = axes[i]
        image = images[i].permute(1, 2, 0).numpy()  # Convert tensor to NumPy array
        ax.imshow(image)
    plt.show()

# Show one batch from each split: first training, then test.
for loader, caption in ((train_loader, "Training Images"), (test_loader, "Test Images")):
    plot_images_from_loader(loader, title=caption)

In [None]:
# Convolutional encoding network: image -> mean / log-variance of the latent Gaussian
class Encoder(nn.Module):
    """Convolutional VAE encoder.

    Three stages of 3x3 convolution -> ReLU -> 2x2 max-pool -> dropout(0.25)
    raise the channel count from ``in_channels`` to 32, 64 and 128 while
    halving the spatial size each time (``image_size`` -> ``image_size // 8``).
    The flattened features pass through a 512-unit ReLU layer and two linear
    heads producing the mean and log-variance of the latent distribution.

    Args:
        latent_dim: size of the latent vector.
        image_size: side length of the square input images; must be a positive
            multiple of 8. Defaults to 128, the size this notebook's
            ``transform_func`` resizes to.
        in_channels: number of channels of the input images (3 for RGB).

    Raises:
        ValueError: if ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, latent_dim, image_size=128, in_channels=3):
        super(Encoder, self).__init__()
        if image_size <= 0 or image_size % 8 != 0:
            raise ValueError(f"image_size must be a positive multiple of 8, got {image_size}")
        # Spatial side length left after the three 2x2 max-pools.
        feature_size = image_size // 8

        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.dropout2 = nn.Dropout(0.25)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.dropout3 = nn.Dropout(0.25)

        self.fc1 = nn.Linear(128 * feature_size * feature_size, 512)
        self.fc_mean = nn.Linear(512, latent_dim)
        self.fc_logvar = nn.Linear(512, latent_dim)

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: float tensor of shape ``(batch, in_channels, image_size, image_size)``.

        Returns:
            ``(mean, logvar)``, each of shape ``(batch, latent_dim)``.
        """
        x = self.dropout1(self.pool1(F.relu(self.conv1(x))))
        x = self.dropout2(self.pool2(F.relu(self.conv2(x))))
        x = self.dropout3(self.pool3(F.relu(self.conv3(x))))

        x = torch.flatten(x, 1)  # (batch, 128 * feature_size**2)
        x = F.relu(self.fc1(x))
        mean = self.fc_mean(x)      # mean of the latent Gaussian
        logvar = self.fc_logvar(x)  # log-variance of the latent Gaussian
        return mean, logvar


In [None]:
# Convolutional decoding network: latent vector -> image with pixels in [0, 1]
class Decoder(nn.Module):
    """Convolutional VAE decoder.

    Two ReLU linear layers expand the latent vector into a 128-channel feature
    map of side ``image_size // 8``. Three stride-2 transposed convolutions
    (128 -> 64 -> 32 -> ``out_channels`` channels) then double the spatial size
    each time, back to ``image_size``. A final sigmoid keeps outputs in [0, 1],
    the range ``F.binary_cross_entropy`` requires of its input.

    Args:
        latent_dim: size of the latent vector.
        image_size: side length of the square output images; must be a
            positive multiple of 8. Defaults to 128.
        out_channels: number of channels of the output images (3 for RGB).

    Raises:
        ValueError: if ``image_size`` is not a positive multiple of 8.
    """

    def __init__(self, latent_dim, image_size=128, out_channels=3):
        super(Decoder, self).__init__()
        if image_size <= 0 or image_size % 8 != 0:
            raise ValueError(f"image_size must be a positive multiple of 8, got {image_size}")
        # Side length of the feature map the transposed convolutions start from.
        self.feature_size = image_size // 8

        self.fc1 = nn.Linear(latent_dim, 512)
        self.fc2 = nn.Linear(512, 128 * self.feature_size * self.feature_size)

        # kernel 3 / stride 2 / padding 1 / output_padding 1 maps size n -> 2n.
        self.conv1 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dropout1 = nn.Dropout(0.25)

        self.conv2 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.dropout2 = nn.Dropout(0.25)

        self.conv3 = nn.ConvTranspose2d(32, out_channels, kernel_size=3, stride=2, padding=1, output_padding=1)

    def forward(self, x):
        """Decode latent vectors of shape ``(batch, latent_dim)``.

        Returns:
            Tensor of shape ``(batch, out_channels, image_size, image_size)``
            with values in [0, 1].
        """
        # F.relu avoids constructing a new nn.ReLU module on every call.
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = x.view(x.size(0), 128, self.feature_size, self.feature_size)

        x = self.dropout1(F.relu(self.conv1(x)))
        x = self.dropout2(F.relu(self.conv2(x)))

        return torch.sigmoid(self.conv3(x))

In [None]:
# Define the Variational Autoencoder model
class VAE(nn.Module):
    """Variational autoencoder built from an ``Encoder`` and a ``Decoder``.

    ``forward`` returns ``(reconstruction, mu, logvar)``, where ``mu`` and
    ``logvar`` parameterize the approximate posterior over the latent space.

    Args:
        latent_dim: size of the latent vector shared by encoder and decoder.
    """

    def __init__(self, latent_dim):
        super(VAE, self).__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        """Sample a latent vector with the reparameterization trick.

        In training mode, returns ``mu + eps * exp(0.5 * logvar)`` with
        ``eps ~ N(0, I)``, so gradients flow through ``mu`` and ``logvar``.
        In eval mode, returns ``mu`` (the posterior mean), so reconstructions
        are deterministic and repeatable.
        """
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        """Encode ``x``, sample ``z`` and decode it; returns ``(recon, mu, logvar)``."""
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), mu, logvar

In [None]:
# Build the VAE with a 128-dimensional latent space and optimize all of its
# parameters with Adam at a learning rate of 1e-4.
latent_dim = 128
vae = VAE(latent_dim=latent_dim)
optimizer = optim.Adam(params=vae.parameters(), lr=1e-4)

# VAE objective terms: a summed binary cross-entropy reconstruction loss and
# the KL divergence of the latent posterior from a standard normal prior.
def vae_loss(recon_x, x, mu, logvar):
    """Compute the two terms of the VAE loss, each summed over the batch.

    Args:
        recon_x: reconstructions with values in [0, 1].
        x: target images, same shape as ``recon_x``.
        mu: latent means.
        logvar: latent log-variances, same shape as ``mu``.

    Returns:
        ``(bce, kld)`` as scalar tensors: the reconstruction BCE and
        ``-0.5 * sum(1 + logvar - mu**2 - exp(logvar))``. The caller adds them.
    """
    reconstruction_term = F.binary_cross_entropy(recon_x, x, reduction='sum')
    per_dim_kl = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * torch.sum(per_dim_kl)
    return reconstruction_term, kl_term

# Training loop
def train_vae(dataloader, model, optimizer, num_epochs=20):
    """Train ``model`` on ``dataloader`` with the summed BCE + KL objective.

    Each batch is moved to the device holding the model's parameters, so the
    loop keeps working after ``model.to(device)`` (e.g. MPS). After each
    epoch, it prints the total, BCE and KL losses averaged per sample.

    Args:
        dataloader: DataLoader yielding batches of image tensors in [0, 1].
        model: VAE whose forward returns ``(recon, mu, logvar)``.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over the data.
    """
    # Parameterless models have no device; their batches are left where they are.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    num_samples = len(dataloader.dataset)

    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        running_bce = 0.0
        running_kld = 0.0
        for images in dataloader:
            if device is not None:
                images = images.to(device)
            optimizer.zero_grad()

            recon_images, mu, logvar = model(images)
            bce_loss, kl_loss = vae_loss(recon_images, images, mu, logvar)
            loss = bce_loss + kl_loss

            # Backward pass and optimization
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            running_bce += bce_loss.item()
            running_kld += kl_loss.item()

        # The loss terms are sums over each batch, so dividing by the dataset
        # size gives per-sample averages.
        epoch_loss = running_loss / num_samples
        epoch_bce = running_bce / num_samples
        epoch_kld = running_kld / num_samples

        print(f"Epoch [{epoch+1}/{num_epochs}], Total Loss: {epoch_loss:.4f}, "
              f"BCE Loss: {epoch_bce:.4f}, KL Loss: {epoch_kld:.4f}")

# Run 20 epochs of training on the training split.
train_vae(dataloader=train_loader, model=vae, optimizer=optimizer, num_epochs=20)

In [None]:
# Function to display original and reconstructed images
def plot_vae_results(data_loader, model, num_images=8):
    """Plot one batch of originals (top row) above their VAE reconstructions.

    The model runs in eval mode under ``torch.no_grad()``. Its previous
    train/eval mode is restored afterwards, so calling this between training
    steps does not silently disable dropout. Inputs are moved to the model's
    device before the forward pass.

    Args:
        data_loader: DataLoader yielding batches of channels-first image tensors.
        model: VAE whose forward returns ``(recon, mu, logvar)``.
        num_images: number of columns to show. Capped at the batch size;
            nothing is drawn when it is < 1.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            images = next(iter(data_loader))
            first_param = next(model.parameters(), None)
            if first_param is not None:
                images = images.to(first_param.device)
            recon_images, _, _ = model(images)
    finally:
        model.train(was_training)

    # Convert images and reconstructions to NumPy arrays for visualization
    images = images.cpu().numpy()
    recon_images = recon_images.cpu().numpy()

    num_images = min(num_images, len(images))
    if num_images < 1:
        return

    # squeeze=False keeps `axes` 2-D even when num_images == 1.
    fig, axes = plt.subplots(2, num_images, figsize=(15, 4), squeeze=False)
    fig.suptitle("Original and Reconstructed Images (VAE)")

    # Row 0: originals, row 1: reconstructions; (C, H, W) -> (H, W, C) for imshow.
    for row, batch in enumerate((images, recon_images)):
        for col in range(num_images):
            ax = axes[row, col]
            ax.imshow(np.clip(batch[col].transpose(1, 2, 0), 0, 1))
            ax.axis('off')

    plt.show()

# Compare eight test images with their reconstructions.
plot_vae_results(data_loader=test_loader, model=vae, num_images=8)

In [None]:
# Persist the trained VAE in two forms:
#  - 'autoencoder.pth': parameters only; reload with
#    VAE(latent_dim).load_state_dict(torch.load('autoencoder.pth')).
#  - 'autoencoder_complete.pth': the whole pickled module. Loading it needs the
#    VAE/Encoder/Decoder classes to be importable, and recent torch versions
#    may require torch.load(..., weights_only=False) (verify for your version).
weights = vae.state_dict()
torch.save(obj=weights, f='autoencoder.pth')
torch.save(obj=vae, f='autoencoder_complete.pth')