In [3]:
!pip install matplotlib

Collecting matplotlib
  Downloading matplotlib-3.10.0-cp311-cp311-macosx_11_0_arm64.whl.metadata (11 kB)
Collecting contourpy>=1.0.1 (from matplotlib)
  Downloading contourpy-1.3.1-cp311-cp311-macosx_11_0_arm64.whl.metadata (5.4 kB)
Collecting cycler>=0.10 (from matplotlib)
  Using cached cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)
Collecting fonttools>=4.22.0 (from matplotlib)
  Downloading fonttools-4.55.3-cp311-cp311-macosx_10_9_universal2.whl.metadata (165 kB)
Collecting kiwisolver>=1.3.1 (from matplotlib)
  Downloading kiwisolver-1.4.8-cp311-cp311-macosx_11_0_arm64.whl.metadata (6.2 kB)
Collecting pyparsing>=2.3.1 (from matplotlib)
  Downloading pyparsing-3.2.0-py3-none-any.whl.metadata (5.0 kB)
Downloading matplotlib-3.10.0-cp311-cp311-macosx_11_0_arm64.whl (8.0 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.0/8.0 MB[0m [31m32.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hDownloading contourpy-1.3.1-cp311-cp311-macosx_11_0_arm64.whl (254 kB)


In [5]:
import copy

import torch
from torch.utils.data import DataLoader

from learnitation.shape_classifier import ShapeClassifier, get_shape_probabilities
from learnitation.shape_dataset import create_shape_dataset, visualize_shapes
from learnitation.trainer import (
    analyze_parameter_differences,
    analyze_weight_spectrum,
    calculate_l2_distances,
    perform_model_pca,
    plot_pca_pairwise,
    print_distance_matrix,
    pure_test,
    train_model,
)

In [6]:
# Experiment configuration. Everything a reader might want to tune lives here.
SEED = 42
NUM_SAMPLES = 10_000
BATCH_SIZE = 32
NUM_EPOCHS = 50  # Increased epochs to get more snapshots

# Reproducibility: seed torch's global RNG before anything stochastic runs
# (dataset generation, weight init). This makes Restart & Run All repeatable.
# NOTE(review): if create_shape_dataset draws from Python's `random` or numpy
# instead of torch, those RNGs need seeding as well. TODO confirm.
torch.manual_seed(SEED)

# Create datasets
train_dataset, test_dataset, _, _ = create_shape_dataset(
    num_samples=NUM_SAMPLES,
    rotation_range=(-45, 45),
    scale_range=(0.7, 1.3),
    translation_range=(-3, 3),
)

# Create data loaders. The shuffling loader gets its own seeded generator, so
# the batch order does not depend on how many random draws happened above.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE)

# Create and train model with snapshot saving enabled
model = ShapeClassifier()
history = train_model(
    model,
    train_loader,
    test_loader,
    num_epochs=NUM_EPOCHS,
    save_last_snapshots=True,  # Enable snapshot saving
)

# Model snapshots recorded by train_model. The original note says they come
# from the last 10% of training; that is decided inside train_model (TODO confirm).
if "model_snapshots" in history:
    snapshots = history["model_snapshots"]
    print(f"\nNumber of model snapshots captured: {len(snapshots)}")

    # Pairwise distances between snapshots. param_distances is not used in this
    # cell; presumably it holds per-parameter distances for later cells.
    distances, param_distances = calculate_l2_distances(snapshots)
    print_distance_matrix(snapshots, distances)
else:
    # Say so explicitly instead of silently skipping the analysis.
    print("\nNo model snapshots found in training history; skipping distance analysis.")

Epoch 1/50:
Train Loss: 0.9341, Train Accuracy: 58.77%
Test Loss: 0.6096, Test Accuracy: 68.75%
--------------------
Epoch 2/50:
Train Loss: 0.5724, Train Accuracy: 72.77%
Test Loss: 0.4601, Test Accuracy: 79.15%
--------------------
Epoch 3/50:
Train Loss: 0.4387, Train Accuracy: 80.00%
Test Loss: 0.2755, Test Accuracy: 89.30%
--------------------
Epoch 4/50:
Train Loss: 0.2583, Train Accuracy: 89.72%
Test Loss: 0.1332, Test Accuracy: 96.45%
--------------------
Epoch 5/50:
Train Loss: 0.1461, Train Accuracy: 94.66%
Test Loss: 0.0574, Test Accuracy: 97.70%
--------------------
Epoch 6/50:
Train Loss: 0.0862, Train Accuracy: 97.00%
Test Loss: 0.0294, Test Accuracy: 99.60%
--------------------
Epoch 7/50:
Train Loss: 0.0579, Train Accuracy: 97.94%
Test Loss: 0.0116, Test Accuracy: 99.65%
--------------------
Epoch 8/50:
Train Loss: 0.0384, Train Accuracy: 98.71%
Test Loss: 0.0099, Test Accuracy: 99.80%
--------------------
Epoch 9/50:
Train Loss: 0.0345, Train Accuracy: 98.94%
Test Loss