In [1]:
import sys
sys.path.append("../") 

from src.utils.driver import set_seed

set_seed(57)

### Model and Dataset

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
import numpy as np
import os
import imageio

from src.utils.model import DistLayer

# Define the model class
class SimpleNN(nn.Module):
    def __init__(self, harmonic=False):
        super(SimpleNN, self).__init__()
        self.harmonic = harmonic
        if harmonic:
            self.fc1 = DistLayer(28 * 28, 10, n=1.)
        else:
            self.fc1 = nn.Linear(28 * 28, 10)
        nn.init.normal_(self.fc1.weight, mean=0, std=1/28.)

    def forward(self, x):
        x = x.view(-1, 28 * 28)  # Flatten the input
        x = self.fc1(x)
        if self.harmonic:
            prob = x/torch.sum(x, dim=1, keepdim=True)
            logits = (-1)*torch.log(prob)
            return logits
        return x

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hyperparameters
batch_size = 64
learning_rate = 0.001
max_epochs = 100

# Load MNIST dataset
transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import os
import imageio

def save_weight_visualization_gif(model, train_loader, test_loader, 
                                   max_epochs=100, 
                                   learning_rate=0.001, 
                                   device='cuda', 
                                   output_dir='../results/mnist_vis',
                                   save_prefix='',
                                   selected_classes=[3, 5, 7, 9]):
    """
    Train the model and save weight visualizations as a GIF
    
    Args:
        model (nn.Module): Neural network model
        train_loader (DataLoader): Training data loader
        test_loader (DataLoader): Test data loader
        max_epochs (int): Maximum number of training epochs
        learning_rate (float): Learning rate for optimizer
        device (str): Training device (cuda/cpu)
        output_dir (str): Directory to save visualizations
        selected_classes (list): Classes to visualize
    """
    # Ensure output directory exists
    os.makedirs(output_dir, exist_ok=True)
    
    # Move model to device
    model = model.to(device)
    
    # Optimizer and loss
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    
    # Visualization storage
    weight_images = []
    
    # Training loop
    for epoch in [1]:
        # Training phase
        model.train()
        running_loss = 0.0
        
        for batch_idx, (data, targets) in enumerate(train_loader):
            data, targets = data.to(device), targets.to(device)
            
            # Forward pass
            outputs = model(data)
            loss = outputs[range(targets.size(0)), targets].mean()
            
            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
        
            # Periodically save weight visualization
            if (batch_idx+1) % 10 == 0:
                
                plt.figure(figsize=(16, 12))
                plt.suptitle(f'Weight Visualization - Epoch {epoch}')
                
                for i, cls in enumerate(selected_classes, 1):
                    # Extract weights for specific class
                    weights = model.fc1.weight.detach().cpu().numpy()[cls].reshape(28, 28)
                    weights = np.where(weights < 0.01, 1, 0)
                    
                    plt.subplot(2, 2, i)
                    plt.title(f'Class {cls}')
                    plt.imshow(weights, cmap='viridis')
#                    plt.colorbar()
                    plt.axis('off')
                
                plt.tight_layout()
                
                # Save plot to a temporary file
                temp_plot_path = os.path.join(output_dir, f'{save_prefix}_mnist_{(batch_idx+1)}.png')
                torch.save(model.state_dict(), temp_plot_path.replace('.png', '.pt'))
                plt.savefig(temp_plot_path)
                plt.close()
                
                # Read the image and append to list
                weight_images.append(imageio.imread(temp_plot_path))
                print(batch_idx)
                if batch_idx > 500:
                    break
    
    # Save as GIF
#    imageio.mimsave('../figures/mnist_weights_evolution.gif', weight_images, duration=0.5)
    
    # Evaluation
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device), targets.to(device)
            outputs = (-1)*model(data)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == targets).sum().item()
    
    accuracy = correct / len(test_loader.dataset) * 100
    print(f"Test Accuracy: {accuracy:.2f}%")



In [4]:
model = SimpleNN(harmonic=True).to(device)
save_weight_visualization_gif(model, train_loader, test_loader, save_prefix='harmonic')

  weight_images.append(imageio.imread(temp_plot_path))


9
19
29
39
49
59
69
79
89
99
109
119
129
139
149
159
169
179
189
199
209
219
229
239
249
259
269
279
289
299
309
319
329
339
349
359
369
379
389
399
409
419
429
439
449
459
469
479
489
499
509
Test Accuracy: 73.27%


In [5]:
model = SimpleNN(harmonic=False).to(device)
save_weight_visualization_gif(model, train_loader, test_loader, save_prefix='standard')

  weight_images.append(imageio.imread(temp_plot_path))


9
19
29
39
49
59
69
79
89
99
109
119
129
139
149
159
169
179
189
199
209
219
229
239
249
259
269
279
289
299
309
319
329
339
349
359
369
379
389
399
409
419
429
439
449
459
469
479
489
499
509
Test Accuracy: 67.38%
