# Trenowanie modelu

In [1]:
import os
from bindsnet.datasets import MNIST
from bindsnet.encoding import PoissonEncoder
from SNN_supervised import create_network, create_monitorings, train_network, save_network, test_network, load_config
from torchvision import transforms
import torch

## Parametry modelu

Parametry mogą być zmieniane bezpośrednio w notebooku

In [2]:
# Hyper-parameters for network construction, input encoding and training;
# the later cells read every setting from this dict.
config = {
    "n_neurons": 100,  # Number of excitatory (and inhibitory) neurons
    "n_train": 50000,  # Training split size; n_train + n_test must equal len(dataset) (60000 for the MNIST training split) because both are passed to random_split
    "n_test": 10000,  # Size of the held-out split from random_split (used as val_set for training and testing below)
    "n_eval": 500,  # Evaluation set size
    "exc": 22.5,  # Strength of synapse weights from excitatory to inhibitory layer
    "inh": 120,  # Strength of synapse weights from inhibitory to excitatory layer
    "theta_plus": 0.05,  # On-spike increment of DiehlAndCookNodes membrane threshold potential
    "time": 1,  # Simulation time per sample; NOTE(review): with dt=1.0 this is presumably a single simulation step — confirm intended
    "dt": 1.0,  # Simulation time step
    "intensity": 32,  # Scale factor applied to the [0, 1] pixel tensor before Poisson encoding (see the dataset transform)
    "update_interval": 10,  # Accuracy update interval
    "eval_interval": 1000,  # Evaluation interval
    "n_classes": 10,  # Number of dataset classes
    "lr": [1e-10, 1e-3]  # Learning rates passed to create_network; presumably [pre-synaptic, post-synaptic] — TODO confirm
}

lub wczytane z pliku konfiguracyjnego `config.json` (uruchomienie poniższej komórki nadpisuje słownik `config` z komórki powyżej).

In [None]:
# Optional alternative to the hand-written dict: rebind `config` to whatever
# load_config reads from config.json. Running this cell replaces the values
# defined above, so skip it when tuning parameters directly in the notebook.
config_file = "config.json"
config = load_config(config_file)

### Zainicjalizowanie modelu sieci

In [3]:
# Build the spiking network from the hyper-parameters in `config`, then load
# the weights stored in models/standard_5000.pth into it (presumably so that
# train_network below continues from this checkpoint).
network = create_network(config["n_neurons"], config["exc"], config["inh"], config["dt"], config["theta_plus"], config["lr"])
# map_location="cpu" lets a checkpoint saved from a GPU run be loaded on a
# CPU-only machine; load_state_dict then copies the tensors into the network's
# own parameters/buffers, wherever they live.
network.load_state_dict(torch.load("models/standard_5000.pth", map_location="cpu"))
# create_monitorings returns the network together with `spikes`
# (presumably the spike monitors) used by the training and testing cells.
network, spikes = create_monitorings(network, config["time"])

In [4]:
# MNIST (downloaded to ../../data/MNIST if missing) whose images are encoded by
# a PoissonEncoder with config["time"] / config["dt"]; the second argument
# (None) presumably disables label encoding.
# Each image is converted to a [0, 1] float tensor and multiplied by
# config["intensity"] before encoding. The intensity is bound as a lambda
# default argument so it is captured when the dataset is built, just like
# time/dt in the encoder. Reading the global `config` inside the lambda would
# let a later re-run of a config cell silently change the scaling of this
# already-built dataset.
dataset = MNIST(
    PoissonEncoder(time=config["time"], dt=config["dt"]),
    None,
    root=os.path.join("..", "..", "data", "MNIST"),
    download=True,
    transform=transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Lambda(lambda x, intensity=config["intensity"]: x * intensity),
        ]
    ),
)

In [5]:
# Split `dataset` into train/validation subsets of config["n_train"] and
# config["n_test"] samples (random_split requires the lengths to sum to
# len(dataset)).
# A seeded generator makes the split identical across runs and kernel
# restarts. Without it, each run draws a new split, so weights trained in an
# earlier run (e.g. saved with save_network and reloaded) may already have seen
# samples that land in val_set. Set config["seed"] to pick a different split.
split_generator = torch.Generator().manual_seed(config.get("seed", 0))
train_set, val_set = torch.utils.data.random_split(
    dataset, [config["n_train"], config["n_test"]], generator=split_generator
)

#### Trenowanie modelu

In [None]:
# Train the network on train_set. val_set, the spike monitors and the full
# config are handed to train_network too (presumably for the periodic
# evaluation driven by "eval_interval"/"n_eval"); the returned network
# replaces the current one.
network = train_network(
    network,
    train_set,
    val_set,
    spikes,
    config,
)

In [None]:
# Save the trained network via save_network under the relative path "test.pth".
# This is a different file from the checkpoint loaded at construction
# (models/standard_5000.pth).
save_network(
    network,
    "test.pth",
)

#### Testowanie modelu

In [None]:
# Evaluate the trained network with test_network on the `val_set` split,
# using the same config and spike monitors as training.
# NOTE(review): val_set comes from random_split of the same MNIST data as
# train_set and is also passed to train_network. The MNIST(...) call in this
# notebook has no `train=False` argument (torchvision's default is the training
# split), so the official 10k MNIST test split is never used. Consider
# reporting final accuracy on MNIST(..., train=False) instead.
test_network(
    network,
    config,
    spikes,
    val_set,
)