In [1]:
import numpy as np
import matplotlib.pyplot as plt
from src.util.util import run_experiment

# Smoke test: one short 2-bit run (2 epochs, BASELINE=False) to check the
# training/evaluation pipeline end to end before launching the full sweep.
smoke_run = run_experiment(BITS=2, ROUND="ROUND", EPOCHS=2, BASELINE=False)
model, train_losses, test_losses, l0_norms, l1_norms, l2_norms = smoke_run

Files already downloaded and verified
Files already downloaded and verified


  return super().rename(names)



Test set: Average loss: 0.0127, Accuracy: 4221/10000 (42.21%)

Epoch 0: Train Loss: 1.761767, Test Loss: 0.012697
L0 Norm: 17703.000000, L1 Norm: 2041.247505, L2 Norm: 32.253666

Test set: Average loss: 0.0121, Accuracy: 4554/10000 (45.54%)

Epoch 1: Train Loss: 1.580117, Test Loss: 0.012068
L0 Norm: 18122.000000, L1 Norm: 2426.363026, L2 Norm: 35.610047

Test set: Average loss: 0.0119, Accuracy: 4616/10000 (46.16%)

Epoch 2: Train Loss: 1.517968, Test Loss: 0.011906
L0 Norm: 20440.000000, L1 Norm: 2940.692379, L2 Norm: 38.658958


In [None]:
def run_multiple_experiments(bits, num_runs, epochs, round_mode="ROUND", baseline_bits=32):
    """Repeat ``run_experiment`` for each bit width and summarise the per-epoch metrics.

    Parameters
    ----------
    bits : iterable of int
        Bit widths to evaluate; each is passed to ``run_experiment`` as ``BITS``.
    num_runs : int
        Number of independent runs per bit width (must be >= 1).
    epochs : int
        Passed to ``run_experiment`` as ``EPOCHS``.
    round_mode : str, optional
        Passed to ``run_experiment`` as ``ROUND``. Defaults to ``"ROUND"``,
        the value previously hard-coded here.
    baseline_bits : int, optional
        The bit width that is run with ``BASELINE=True``; every other bit
        width uses ``BASELINE=False``. Defaults to 32, as before.

    Returns
    -------
    dict
        Maps each bit width to a dict containing, for every metric in
        ``train_losses``, ``test_losses``, ``l0_norms``, ``l1_norms`` and
        ``l2_norms``:

        * ``'avg_<metric>'`` -- element-wise mean over runs (the original keys);
        * ``'std_<metric>'`` -- element-wise population std (ddof=0) over
          runs, so the run-to-run spread is not discarded.

    Raises
    ------
    ValueError
        If ``num_runs`` < 1 (a mean over zero runs would silently be NaN).
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    metric_names = ('train_losses', 'test_losses', 'l0_norms', 'l1_norms', 'l2_norms')
    results = {}

    for bit in bits:
        per_run = {name: [] for name in metric_names}

        for _ in range(num_runs):
            # The trained model (first return value) is not needed for the summary.
            _, train_losses, test_losses, l0_norms, l1_norms, l2_norms = run_experiment(
                BITS=bit, ROUND=round_mode, EPOCHS=epochs, BASELINE=(bit == baseline_bits)
            )
            histories = (train_losses, test_losses, l0_norms, l1_norms, l2_norms)
            for name, history in zip(metric_names, histories):
                per_run[name].append(history)

        summary = {}
        for name in metric_names:
            # Row = one run, column = one recorded point. Assumes every run
            # records the same number of points (TODO confirm in run_experiment).
            stacked = np.asarray(per_run[name], dtype=float)
            summary[f'avg_{name}'] = stacked.mean(axis=0)
            summary[f'std_{name}'] = stacked.std(axis=0)
        results[bit] = summary

    return results

def plot_norms(results, bits, epochs, filename='norms_plot.png'):
    """Plot run-averaged L0/L1/L2 weight norms per epoch for each bit width and save to disk.

    One panel per norm, one line per bit width.

    Parameters
    ----------
    results : dict
        Output of ``run_multiple_experiments``; ``results[bit]`` must hold
        ``'avg_l0_norms'``, ``'avg_l1_norms'`` and ``'avg_l2_norms'``.
    bits : iterable of int
        Bit widths to draw.
    epochs : int
        Number of training epochs requested. Only used to decide where the
        x-axis starts; the number of plotted points comes from the data.
    filename : str, optional
        Where the figure is written (default ``'norms_plot.png'``).
    """
    panels = (
        ('avg_l0_norms', 'L0 Norm'),
        ('avg_l1_norms', 'L1 Norm'),
        ('avg_l2_norms', 'L2 Norm'),
    )
    fig, axes = plt.subplots(len(panels), 1, figsize=(10, 15))

    for ax, (key, title) in zip(axes, panels):
        for bit in bits:
            values = np.asarray(results[bit][key])
            # Size the x-axis from the data: the sample run with EPOCHS=2 logs
            # Epoch 0..2 (EPOCHS + 1 points), so range(1, epochs + 1) would not
            # match the data length. Keep the 1-based axis when lengths agree.
            start = 1 if len(values) == epochs else 0
            ax.plot(np.arange(start, start + len(values)), values, label=f'{bit}-bit')
        ax.set_title(title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Norm Value')
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    fig.savefig(filename)
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def plot_losses(results, bits, epochs, filename='losses_plot.png'):
    """Plot run-averaged train/test losses and their difference per bit width, then save to disk.

    Top panel: train loss (solid) and test loss (dashed) for each bit width;
    both curves of one bit width share a colour. Bottom panel: test loss
    minus train loss for each bit width.

    Parameters
    ----------
    results : dict
        Output of ``run_multiple_experiments``; ``results[bit]`` must hold
        ``'avg_train_losses'`` and ``'avg_test_losses'`` of equal length.
    bits : iterable of int
        Bit widths to draw.
    epochs : int
        Number of training epochs requested. Only used to decide where the
        x-axis starts; the number of plotted points comes from the data.
    filename : str, optional
        Where the figure is written (default ``'losses_plot.png'``).

    NOTE(review): in the sample run at the top of this notebook the logged
    test loss (~0.012) is about 100x smaller than the logged train loss
    (~1.5-1.8). That suggests the two are normalised differently inside
    ``run_experiment`` (TODO confirm). Until then, the "generalization gap"
    panel may compare quantities on different scales.
    """
    fig, (loss_ax, gap_ax) = plt.subplots(2, 1, figsize=(10, 10))

    for bit in bits:
        train = np.asarray(results[bit]['avg_train_losses'], dtype=float)
        test = np.asarray(results[bit]['avg_test_losses'], dtype=float)
        # Size the x-axis from the data: the sample run with EPOCHS=2 logs
        # Epoch 0..2 (EPOCHS + 1 points), so range(1, epochs + 1) would not
        # match the data length. Keep the 1-based axis when lengths agree.
        start = 1 if len(train) == epochs else 0
        x = np.arange(start, start + len(train))

        (train_line,) = loss_ax.plot(x, train, label=f'{bit}-bit (Train)')
        # Reuse the train line's colour so each bit width reads as one pair.
        loss_ax.plot(x, test, label=f'{bit}-bit (Test)', linestyle='--',
                     color=train_line.get_color())
        gap_ax.plot(x, test - train, label=f'{bit}-bit')

    loss_ax.set_title('Train and Test Losses')
    loss_ax.set_xlabel('Epochs')
    loss_ax.set_ylabel('Loss')
    gap_ax.set_title('Test Loss - Train Loss')
    gap_ax.set_xlabel('Epochs')
    gap_ax.set_ylabel('Generalization gap')
    for ax in (loss_ax, gap_ax):
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    fig.savefig(filename)
    # Close this specific figure so repeated calls do not accumulate open figures.
    plt.close(fig)

# Experiment grid: bit widths to compare (32 is the one run_multiple_experiments
# launches with BASELINE=True), independent runs per bit width, and training
# epochs per run.
bits = [2, 3, 4, 8, 32]
num_runs = 5
epochs = 50

results = run_multiple_experiments(bits, num_runs, epochs)

# Render both summary figures (each plotting helper saves its own PNG to disk).
for make_plot in (plot_norms, plot_losses):
    make_plot(results, bits, epochs)