### Import Libraries

In [1]:
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from mnn_torch.devices import load_SiOx_multistate
from mnn_torch.models import MSNN, MCSNN
from snntorch import surrogate
from mnn_torch.effects import compute_PooleFrenkel_parameters


In [2]:
# test cuda availability
x = torch.rand(5, 3)
print(x)

torch.cuda.is_available()
torch.cuda.device_count()
torch.cuda.current_device()
torch.cuda.get_device_name(0)

tensor([[0.9292, 0.4785, 0.2735],
        [0.7004, 0.6785, 0.4075],
        [0.5779, 0.8990, 0.7539],
        [0.1442, 0.6784, 0.9076],
        [0.3517, 0.0471, 0.5894]])


'NVIDIA GeForce RTX 4090'

### Load Data and Initialize Parameters

In [36]:
# Initialize device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load experimental data
current_dir = os.getcwd()
experimental_data = load_SiOx_multistate("../data/SiO_x-multistate-data.mat")
(G_off, G_on, R, c, d_epsilon) = compute_PooleFrenkel_parameters(experimental_data)

# Hyperparameters
batch_size = 64
num_epochs = 3
num_inputs = 28 * 28
num_hidden = 100
num_outputs = 10
num_steps = 25
beta = 0.95
data_path = "../data"
lr = 5e-4

# Memristive configuration
PF_config = {
    "ideal": False,
    "k_V": 0.5,
    "G_off": G_off,
    "G_on": G_on,
    "R": R,
    "c": c,
    "d_epsilon": d_epsilon,
    "disturb_conductance": True,
    "disturb_mode": "fixed",
    "disturbance_probability": 0.1,
    "homeostasis_dropout": True,
    "homeostasis_threshold": 10,
}

# Data loading
transform = transforms.Compose(
    [
        transforms.Resize((28, 28)),
        transforms.Grayscale(),
        transforms.ToTensor(),
        transforms.Normalize((0,), (1,)),
    ]
)

mnist_train = datasets.MNIST(
    data_path, train=True, download=True, transform=transform
)
mnist_test = datasets.MNIST(
    data_path, train=False, download=True, transform=transform
)

training_loader = DataLoader(
    mnist_train, batch_size=batch_size, shuffle=True, drop_last=True
)
validation_loader = DataLoader(
    mnist_test, batch_size=batch_size, shuffle=True, drop_last=True
)


In [37]:
# Memristive configuration
def train_model_with_dropout(homeostasis_dropout=True):
    PF_config = {
        "ideal": False,
        "k_V": 0.5,
        "G_off": G_off,
        "G_on": G_on,
        "R": R,
        "c": c,
        "d_epsilon": d_epsilon,
        "disturb_conductance": True,
        "disturb_mode": "device",
        "disturbance_probability": 0.01,
        "homeostasis_dropout": homeostasis_dropout,
        "homeostasis_threshold": 10,
    }

    # Initialize network
    net = MSNN(num_inputs, num_hidden, num_outputs, num_steps, beta, PF_config).to(device)
    loss = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    loss_hist = []
    test_loss_hist = []
    test_acc_hist = []

    # Training loop
    start_time = time.time()
    for epoch in range(num_epochs):
        for iter_counter, (data, targets) in enumerate(training_loader):
            data, targets = data.to(device), targets.to(device)

            # Forward pass
            net.train()
            spk_rec, mem_rec = net(data.view(batch_size, -1))
            loss_val = sum(loss(mem_rec[step], targets) for step in range(num_steps))

            # Backward pass
            optimizer.zero_grad()
            loss_val.backward()
            optimizer.step()

            loss_hist.append(loss_val.item())

            # Evaluate on validation set
            if iter_counter % 50 == 0:
                with torch.no_grad():
                    net.eval()
                    test_data, test_targets = next(iter(validation_loader))
                    test_data, test_targets = test_data.to(device), test_targets.to(device)

                    test_spk, test_mem = net(test_data.view(batch_size, -1))
                    test_loss = sum(
                        loss(test_mem[step], test_targets) for step in range(num_steps)
                    )
                    test_loss_hist.append(test_loss.item())

                    # Compute accuracy
                    _, idx = test_spk.sum(dim=0).max(1)
                    acc = (idx == test_targets).float().mean().item()
                    test_acc_hist.append(acc)

                    print(
                        f"Epoch {epoch}, Iteration {iter_counter}\n"
                        f"Train Loss: {loss_val.item():.2f}, Test Loss: {test_loss.item():.2f}, "
                        f"Test Accuracy: {acc * 100:.2f}%"
                    )

    print(f"Training completed in {time.time() - start_time:.2f} seconds")
    return loss_hist, test_loss_hist, test_acc_hist

In [38]:
# Train and collect data for both configurations
loss_hist_dropout = []
test_loss_hist_dropout = []
test_acc_hist_dropout = []

loss_hist_no_dropout = []
test_loss_hist_no_dropout = []
test_acc_hist_no_dropout = []

# Train with homeostasis dropout
loss_hist_dropout, test_loss_hist_dropout, test_acc_hist_dropout = train_model_with_dropout(homeostasis_dropout=True)

# Train without homeostasis dropout
loss_hist_no_dropout, test_loss_hist_no_dropout, test_acc_hist_no_dropout = train_model_with_dropout(homeostasis_dropout=False)

# Plotting the results
plt.figure(figsize=(12, 6))

# Loss comparison
plt.subplot(1, 2, 1)
plt.plot(test_loss_hist_dropout, label='With Homeostasis Dropout', color='blue')
plt.plot(test_loss_hist_no_dropout, label='Without Homeostasis Dropout', color='red')
plt.title('Test Loss Comparison')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()

# Accuracy comparison
plt.subplot(1, 2, 2)
plt.plot(test_acc_hist_dropout, label='With Homeostasis Dropout', color='blue')
plt.plot(test_acc_hist_no_dropout, label='Without Homeostasis Dropout', color='red')
plt.title('Test Accuracy Comparison')
plt.xlabel('Iterations')
plt.ylabel('Accuracy')
plt.legend()

plt.tight_layout()
plt.show()

Epoch 0, Iteration 0
Train Loss: 70.11, Test Loss: 70.99, Test Accuracy: 12.50%
Epoch 0, Iteration 50
Train Loss: 69.35, Test Loss: 68.31, Test Accuracy: 4.69%
Epoch 0, Iteration 100
Train Loss: 68.93, Test Loss: 64.91, Test Accuracy: 4.69%
Epoch 0, Iteration 150
Train Loss: 70.19, Test Loss: 66.88, Test Accuracy: 6.25%
Epoch 0, Iteration 200
Train Loss: 64.52, Test Loss: 68.70, Test Accuracy: 9.38%
Epoch 0, Iteration 250
Train Loss: 61.91, Test Loss: 66.10, Test Accuracy: 10.94%
Epoch 0, Iteration 300
Train Loss: 70.96, Test Loss: 70.69, Test Accuracy: 7.81%
Epoch 0, Iteration 350
Train Loss: 66.84, Test Loss: 63.45, Test Accuracy: 9.38%
Epoch 0, Iteration 400
Train Loss: 68.34, Test Loss: 64.42, Test Accuracy: 3.12%
Epoch 0, Iteration 450
Train Loss: 67.05, Test Loss: 71.20, Test Accuracy: 9.38%
Epoch 0, Iteration 500
Train Loss: 65.80, Test Loss: 67.54, Test Accuracy: 6.25%
Epoch 0, Iteration 550
Train Loss: 67.61, Test Loss: 66.07, Test Accuracy: 6.25%


KeyboardInterrupt: 