In [1]:
import warnings

import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import CosineAnnealingLR
import numpy as np
import matplotlib.pyplot as plt

from ViTGSOM import AutoEncoder, ViTSOMLoss
from help_functions import get_grid_coords, decay_exponential, calculate_purity, plot_umap_som_weights, get_node_labels

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Input pipeline: images are only converted to tensors (no normalisation or augmentation).
transform = transforms.Compose([transforms.ToTensor()])

# MNIST training split, downloaded into ./data on first use.
dataset = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

In [3]:
# All tunable hyper-parameters for this run; the cells below read from this dict.
config = {
    # --- ViT autoencoder architecture (passed to AutoEncoder) ---
    'img_size': 28,            # MNIST images are 28x28
    'patch_size': 4,
    'num_of_channels': 1,      # single (grayscale) channel
    'embed_dim': 16,
    'enc_depth': 4,
    'dec_depth': 2,
    'num_heads': 2,
    'mlp_dim': 64,
    # --- optimisation ---
    'epochs': 400,
    'lr': 0.0005,
    # --- SOM growth ---
    'grow_after_epochs': 10,   # SOM growth is checked once every this many epochs
    'grow_threshold': 0.5,     # NOTE(review): not passed to AutoEncoder in the visible cells — confirm it is used
    'spread_factor': 0.5,
    # --- initial SOM grid size ---
    'som_rows': 5,
    'som_cols': 5,
    # --- reproducibility ---
    'seed': 42,
}

# Seed torch's global RNG before the model is built and the DataLoader starts shuffling,
# so random weight initialisation and batch order are repeatable between runs.
torch.manual_seed(config['seed'])

In [4]:
# Build the ViT autoencoder + SOM. Each keyword argument is taken from the config
# entry with the same name, so adding a hyper-parameter only needs the tuple below.
autoencoder = AutoEncoder(**{
    key: config[key]
    for key in (
        'img_size',
        'patch_size',
        'num_of_channels',
        'embed_dim',
        'enc_depth',
        'dec_depth',
        'num_heads',
        'mlp_dim',
        'spread_factor',
        'som_rows',
        'som_cols',
    )
})

In [5]:
# --- Device / model placement ------------------------------------------------
# Move the model to its device *before* building the optimizer (PyTorch recommendation).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
autoencoder.to(device)
autoencoder.train()

# --- Optimisation objects ------------------------------------------------------
optimizer = optim.AdamW(autoencoder.parameters(), lr=config['lr'])
loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=32, shuffle=True)
criterion = ViTSOMLoss()
scheduler = CosineAnnealingLR(optimizer, T_max=config['epochs'])


def _rebuild_optimizer_and_scheduler(model, base_lr, t_max, completed_steps):
    """Recreate AdamW + CosineAnnealingLR for `model`, resuming the cosine schedule.

    Called after the SOM grows. Presumably growth creates parameters the old
    optimizer does not track, hence the rebuild (Adam moment state is reset).

    Building ``CosineAnnealingLR(..., last_epoch=k)`` on a fresh optimizer raises
    ``KeyError`` because its param groups lack ``'initial_lr'``. Instead, the new
    scheduler starts from scratch and is stepped ``completed_steps`` times. It then
    sits at the same point of the cosine curve as the scheduler it replaces.

    Args:
        model: module whose current parameters the new optimizer should own.
        base_lr: initial (maximum) learning rate of the cosine schedule.
        t_max: ``T_max`` of the cosine schedule (total number of epochs).
        completed_steps: number of ``scheduler.step()`` calls already made.

    Returns:
        ``(optimizer, scheduler)`` tuple.
    """
    new_optimizer = optim.AdamW(model.parameters(), lr=base_lr)
    new_scheduler = CosineAnnealingLR(new_optimizer, T_max=t_max)
    with warnings.catch_warnings():
        # Stepping a scheduler before any optimizer.step() emits a UserWarning; it is expected here.
        warnings.simplefilter("ignore", UserWarning)
        for _ in range(completed_steps):
            new_scheduler.step()
    return new_optimizer, new_scheduler


# --- SOM neighbourhood-width (sigma) schedule -----------------------------------
# beta is chosen so that sigma decays from sigma_start to sigma_end over the whole run,
# i.e. it reaches sigma_end at the last epoch index (epochs - 1).
# NOTE(review): assumes decay_exponential(start, beta, t) == start * beta ** t — confirm in help_functions.
sigma_start = autoencoder.get_sigma()
sigma_end = 0.5
decay_steps = max(config['epochs'] - 1, 1)  # guard against epochs == 1
beta = (sigma_end / sigma_start) ** (1 / decay_steps)

grid_coords = get_grid_coords(config['som_rows'], config['som_cols'], device)
current_sigma = sigma_start   # updated every epoch to the sigma actually used
current_lambda = 1            # last argument passed to criterion; constant for the whole run

history = {'total': [], 'mse': [], 'som': [], 'purity': []}

# Epochs (1-based, counted after the epoch finishes) at which SOM weights and node labels
# are snapshotted. Key 0 holds the state before any training.
checkpoints = [0, 10, 25, 50, 100, 150, 200]
snapshot_som_weights = {
    0: (
        autoencoder.get_som_weights().detach().cpu().numpy(),
        get_node_labels(autoencoder, loader, device),
    )
}

print("Start training")
num_batches = len(loader)
for epoch in range(config['epochs']):
    # Helpers run between epochs (get_node_labels / check_growth / calculate_purity) may
    # switch the model to eval mode; re-enable training mode every epoch.
    # NOTE(review): harmless if they don't — confirm in help_functions / ViTGSOM.
    autoencoder.train()

    running_loss = 0.0
    running_mse = 0.0
    running_som = 0.0

    sigma_t = decay_exponential(sigma_start, beta, epoch)
    current_sigma = sigma_t

    for images, _ in loader:
        images = images.to(device)

        reconstructed, latent = autoencoder(images)
        som_weights = autoencoder.get_som_weights()

        total_loss, l_nn, l_som = criterion(images, reconstructed, latent, som_weights, grid_coords, sigma_t, current_lambda)

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()
        running_mse += l_nn.item()
        running_som += l_som.item()

    if epoch + 1 in checkpoints:
        weights_np = autoencoder.get_som_weights().detach().cpu().numpy()
        labels_np = get_node_labels(autoencoder, loader, device)
        snapshot_som_weights[epoch + 1] = (weights_np, labels_np)

    if (epoch + 1) % config['grow_after_epochs'] == 0:
        if autoencoder.check_growth(loader, device):
            # The current scheduler has been stepped once per finished epoch, i.e. `epoch` times.
            optimizer, scheduler = _rebuild_optimizer_and_scheduler(
                autoencoder, config['lr'], config['epochs'], completed_steps=epoch
            )
            grid_coords = get_grid_coords(autoencoder.current_row_num, autoencoder.current_col_num, device)

    # Advance the cosine learning-rate schedule by one epoch.
    scheduler.step()
    purity = calculate_purity(autoencoder, loader, device)

    avg_total = running_loss / num_batches
    avg_mse = running_mse / num_batches
    avg_som = running_som / num_batches

    history['total'].append(avg_total)
    history['mse'].append(avg_mse)
    history['som'].append(avg_som)
    history['purity'].append(purity)

    print(f"Epoch {epoch+1}/{config['epochs']} | Sigma: {sigma_t:.2f} | Loss: {avg_total:.8f} (MSE: {avg_mse:.8f} | SOM: {avg_som:.8f}) | Purity: {purity:.5f}")


Start training


KeyboardInterrupt: 

In [None]:
# Visualise the SOM snapshots collected by the training cell.
# `snapshot_som_weights` maps an epoch number (0 = before training, then each checkpoint
# reached) to a (SOM-weights numpy array, get_node_labels(...) result) tuple.
# If training was interrupted, only the checkpoints reached so far are present.
plot_umap_som_weights(snapshot_som_weights)