### Import Libraries

In [16]:
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from mnn_torch.devices import load_SiOx_multistate
from mnn_torch.models import MSNN, MCSNN
from snntorch import surrogate
from mnn_torch.effects import compute_PooleFrenkel_parameters


In [3]:
# Sanity check: confirm torch works and report which accelerator is visible.
x = torch.rand(5, 3)
print(x)

# current_device() / get_device_name() raise on CPU-only machines, so only query
# them when CUDA is actually available; this keeps "Run All" working anywhere.
if torch.cuda.is_available():
    cuda_summary = (
        f"{torch.cuda.device_count()} CUDA device(s); "
        f"current: {torch.cuda.current_device()} "
        f"({torch.cuda.get_device_name(torch.cuda.current_device())})"
    )
else:
    cuda_summary = "CUDA not available; using CPU"
cuda_summary

tensor([[0.5899, 0.3754, 0.5193],
        [0.1065, 0.6849, 0.8509],
        [0.3237, 0.6957, 0.5388],
        [0.7886, 0.1921, 0.6163],
        [0.3079, 0.7732, 0.7179]])


'NVIDIA GeForce RTX 4090'

### Load Data and Initialize Parameters

In [14]:
# Reproducibility: seed the global RNGs used for model weight initialisation and
# DataLoader shuffling (and presumably the stochastic conductance disturbances
# in mnn_torch — confirm which RNG it draws from).
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Default hyperparameters
batch_size = 64
num_epochs = 1
num_steps = 25  # simulation steps per sample; the training loss is summed over all of them
beta = 0.95
data_path = "../data"
lr = 5e-4

# Load experimental data and Poole-Frenkel parameters.
# The device measurements live in the same data directory as the MNIST download.
current_dir = os.getcwd()
experimental_data = load_SiOx_multistate(os.path.join(data_path, "SiO_x-multistate-data.mat"))
G_off, G_on, R, c, d_epsilon = compute_PooleFrenkel_parameters(experimental_data)

# Memristive configuration
PF_config = {
    "ideal": False,
    "k_V": 0.5,
    "G_off": G_off,
    "G_on": G_on,
    "R": R,
    "c": c,
    "d_epsilon": d_epsilon,
    "disturb_conductance": True,
    "disturb_mode": "fixed",
    "disturbance_probability": 0.1,
    "homeostasis_dropout": True,
    "homeostasis_threshold": 10,
}

# Data loading.
# MNIST images are already 28x28 single-channel, so Resize/Grayscale are
# effectively no-ops here; Normalize with mean 0 / std 1 is the identity map.
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)
mnist_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)
mnist_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)

# drop_last=True keeps every batch exactly `batch_size` long, which code built
# around a fixed batch size (e.g. MCSNN(batch_size=...)) relies on. The
# validation loader is shuffled so each `next(iter(validation_loader))` in
# train_model draws a different batch.
training_loader = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, drop_last=True)
validation_loader = DataLoader(mnist_test, batch_size=batch_size, shuffle=True, drop_last=True)

# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


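### Define the Training Function

`train_model` builds the requested network (`MSNN` or `MCSNN`) with the memristive configuration above, trains it with Adam, and periodically reports loss and accuracy on a single validation batch.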

In [17]:
def _prepare_input(data, model_type):
    """Return an image batch in the layout the selected network consumes.

    Args:
        data: Batch from the MNIST loaders (Grayscale + ToTensor, so presumably
            ``(B, 1, 28, 28)`` — TODO confirm).
        model_type: ``"MSNN"`` or ``"MCSNN"``.

    Returns:
        For ``"MSNN"`` the batch flattened to ``(B, -1)``; otherwise unchanged.
    """
    if model_type == "MSNN":
        # MSNN is built with 28 * 28 inputs, so each image is flattened. Use the
        # real batch size rather than the global one so a short batch still works.
        return data.view(data.size(0), -1)
    # NOTE(review): MCSNN is configured with conv / pooling hyperparameters, so it
    # presumably expects image-shaped input (like the snntorch CSNN tutorial)
    # rather than the flattened batch the previous code passed — confirm
    # against mnn_torch.models.MCSNN.forward.
    return data


def _summed_step_loss(loss_fn, mem_rec, targets, steps):
    """Sum ``loss_fn(mem_rec[step], targets)`` over the first ``steps`` steps."""
    return sum(loss_fn(mem_rec[step], targets) for step in range(steps))


def _evaluate_one_batch(net, loss_fn, model_type):
    """Evaluate ``net`` on a single batch drawn from ``validation_loader``.

    A fresh iterator is created on every call; when ``validation_loader``
    shuffles, each call sees a different random batch, so the results are a
    noisy single-batch estimate rather than full-test-set metrics.

    Returns:
        ``(test_loss, acc)``: a scalar loss tensor and a float accuracy in [0, 1].
    """
    net.eval()
    with torch.no_grad():
        test_data, test_targets = next(iter(validation_loader))
        test_data, test_targets = test_data.to(device), test_targets.to(device)

        test_spk, test_mem = net(_prepare_input(test_data, model_type))
        test_loss = _summed_step_loss(loss_fn, test_mem, test_targets, num_steps)

        # Predicted class = output neuron with the most spikes summed over dim 0
        # (assumes the spike record is indexed [step, batch, class] — TODO confirm).
        _, idx = test_spk.sum(dim=0).max(1)
        acc = (idx == test_targets).float().mean().item()
    return test_loss, acc


def train_model(model_type="MSNN", eval_every=50, return_net=False):
    """Build and train a memristive spiking network on MNIST.

    Uses the notebook-level configuration (``num_steps``, ``beta``,
    ``batch_size``, ``num_epochs``, ``lr``, ``PF_config``, ``device``) and the
    ``training_loader`` / ``validation_loader`` defined earlier.

    Args:
        model_type: ``"MSNN"`` or ``"MCSNN"``.
        eval_every: Evaluate on one validation batch every ``eval_every``
            iterations of each epoch (default 50, the original interval).
        return_net: If True, also return the trained network so it can be used
            after the call; the default keeps the original 3-tuple return.

    Returns:
        ``(loss_hist, test_loss_hist, test_acc_hist)``, or
        ``(loss_hist, test_loss_hist, test_acc_hist, net)`` if ``return_net``.

    Raises:
        ValueError: If ``model_type`` is unknown or ``eval_every`` < 1.
    """
    # Validate arguments before building anything (model construction is costly).
    if model_type not in ("MSNN", "MCSNN"):
        raise ValueError("Invalid model_type. Choose 'MSNN' or 'MCSNN'.")
    if eval_every < 1:
        raise ValueError("eval_every must be a positive integer.")

    if model_type == "MSNN":
        net = MSNN(28 * 28, 100, 10, num_steps, beta, PF_config).to(device)
    else:
        net = MCSNN(
            beta=beta,
            spike_grad=surrogate.fast_sigmoid(slope=25),
            num_steps=num_steps,
            batch_size=batch_size,
            num_kernels=5,
            num_conv1=12,
            num_conv2=64,
            max_pooling=2,
            num_outputs=10,
            memristive_config=PF_config,
        ).to(device)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    loss_hist = []
    test_loss_hist = []
    test_acc_hist = []

    start_time = time.perf_counter()
    for epoch in range(num_epochs):
        for iter_counter, (data, targets) in enumerate(training_loader):
            data, targets = data.to(device), targets.to(device)

            # Re-enter training mode every step: the evaluation helper calls eval().
            net.train()
            _, mem_rec = net(_prepare_input(data, model_type))
            loss_val = _summed_step_loss(loss_fn, mem_rec, targets, num_steps)

            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            loss_hist.append(loss_val.item())

            if iter_counter % eval_every == 0:
                test_loss, acc = _evaluate_one_batch(net, loss_fn, model_type)
                test_loss_hist.append(test_loss.item())
                test_acc_hist.append(acc)

                print(
                    f"Epoch {epoch}, Iteration {iter_counter}\n"
                    f"Train Loss: {loss_val.item():.2f}, Test Loss: {test_loss.item():.2f}, "
                    f"Test Accuracy: {acc * 100:.2f}%"
                )

    print(f"Training completed in {time.perf_counter() - start_time:.2f} seconds")

    if return_net:
        return loss_hist, test_loss_hist, test_acc_hist, net
    return loss_hist, test_loss_hist, test_acc_hist


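### Train the Model

Note: the output below comes from a run that was stopped manually (`KeyboardInterrupt`) during the first epoch. As a result, `train_model` never returned, and `loss_hist`, `test_loss_hist` and `test_acc_hist` were never assigned. The plotting cell needs a completed run.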

In [18]:
# Train the MSNN variant; the three metric histories feed the plotting cell below.
(loss_hist, test_loss_hist, test_acc_hist) = train_model("MSNN")

Epoch 0, Iteration 0
Train Loss: 569.54, Test Loss: 568.59, Test Accuracy: 18.75%
Epoch 0, Iteration 50
Train Loss: 353.54, Test Loss: 355.43, Test Accuracy: 67.19%
Epoch 0, Iteration 100
Train Loss: 313.02, Test Loss: 347.79, Test Accuracy: 68.75%
Epoch 0, Iteration 150
Train Loss: 285.08, Test Loss: 303.24, Test Accuracy: 75.00%
Epoch 0, Iteration 200
Train Loss: 291.11, Test Loss: 303.96, Test Accuracy: 82.81%
Epoch 0, Iteration 250
Train Loss: 273.17, Test Loss: 285.61, Test Accuracy: 82.81%
Epoch 0, Iteration 300
Train Loss: 299.91, Test Loss: 262.17, Test Accuracy: 87.50%
Epoch 0, Iteration 350
Train Loss: 297.71, Test Loss: 290.14, Test Accuracy: 76.56%
Epoch 0, Iteration 400
Train Loss: 261.68, Test Loss: 285.95, Test Accuracy: 87.50%
Epoch 0, Iteration 450
Train Loss: 300.36, Test Loss: 293.65, Test Accuracy: 78.12%
Epoch 0, Iteration 500
Train Loss: 298.51, Test Loss: 294.39, Test Accuracy: 82.81%
Epoch 0, Iteration 550
Train Loss: 290.59, Test Loss: 305.08, Test Accuracy: 84

KeyboardInterrupt: 

### Plot Metrics

In [None]:
# Test metrics are recorded whenever the per-epoch iteration counter is a
# multiple of `eval_every` (the counter restarts at 0 each epoch). Place each
# point at the global training iteration where it was actually measured,
# instead of spreading the points evenly over the whole run.
eval_every = 50  # keep in sync with the evaluation interval used by train_model
iters_per_epoch = len(training_loader)
eval_iterations = [
    epoch * iters_per_epoch + it
    for epoch in range(num_epochs)
    for it in range(0, iters_per_epoch, eval_every)
][: len(test_loss_hist)]

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(loss_hist, label="Training Loss")
ax_loss.plot(eval_iterations, test_loss_hist, label="Test Loss")
ax_loss.set(xlabel="Iterations", ylabel="Loss", title="Loss over Time")
ax_loss.legend()

ax_acc.plot(eval_iterations, test_acc_hist, label="Test Accuracy")
ax_acc.set(xlabel="Iterations", ylabel="Accuracy", title="Accuracy over Time")
ax_acc.legend()

fig.tight_layout()
plt.show()
