# Analogous Recurrent ANN Trainer

This notebook generates and trains a recurrent ANN so that its trained weights can be copied over to an SNN with the same architecture.

## Imports

In [1]:
import os
import torch
import struct
import numpy as np
from torch import nn
import lightning as L
import matplotlib.pyplot as plt
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from lightning.pytorch.callbacks import ModelCheckpoint
cwd = os.getcwd()

In [6]:
import random

def set_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's `random`, NumPy's legacy global generator and PyTorch's
    CPU / CUDA generators with the same value, then switches cuDNN to
    deterministic mode with auto-tuning (benchmark) disabled.

    Args:
        seed: integer seed passed unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

# Fix the seed once, up front, so runs are repeatable
set_seed(42397)

## Load Data

In [7]:
# load all data and prepare for vector conversion

# load data -- the context manager closes the file handle as soon as it is read,
# and os.path.join keeps the relative path valid on non-Windows systems too
with open(os.path.join("..", "Data", "Training Data", "train_5500.txt")) as f:
    data = f.read()

# split data into sentences (one per line)
sents = data.split('\n')

# split each sentence into space-separated words; the final token of every
# line is dropped by the [:-1] slice
sents = [sent.split(' ')[:-1] for sent in sents]

In [8]:
# prepare word2vector vocabulary (ty chatGPT :) )

def read_word_vectors(filepath):
    """Read a binary word-vector file into a ``{word: vector}`` dictionary.

    The file is expected to start with a text header line holding two
    integers (vocabulary size and vector length). Each entry that follows is
    the UTF-8 word terminated by a single space, then ``vector length``
    float32 values in native byte order. Newline bytes encountered while
    reading a word are skipped.

    Args:
        filepath: path of the binary vector file.

    Returns:
        dict mapping each word (str) to a writable 1-D float32 numpy array.

    Raises:
        ValueError: if the header is malformed or the file ends before the
            announced number of words / vector bytes has been read.
    """
    float_size = np.dtype('float32').itemsize
    word_vectors = {}

    with open(filepath, 'rb') as f:
        header = f.readline()
        vocab_size, vector_size = map(int, header.split())
        binary_len = float_size * vector_size

        for _ in range(vocab_size):
            chars = []
            while True:
                ch = f.read(1)
                if not ch:
                    # read() returns b'' forever at EOF; without this guard the
                    # loop would never terminate on a truncated file
                    raise ValueError(
                        f"unexpected end of file while reading a word in {filepath!r}"
                    )
                if ch == b' ':
                    break
                if ch != b'\n':
                    chars.append(ch)
            word = b''.join(chars).decode('utf-8')

            raw = f.read(binary_len)
            if len(raw) != binary_len:
                raise ValueError(
                    f"truncated vector for word {word!r} in {filepath!r}: "
                    f"expected {binary_len} bytes, got {len(raw)}"
                )
            # frombuffer yields a read-only view of the bytes object; copy so
            # callers (e.g. torch.from_numpy) receive a writable array
            word_vectors[word] = np.frombuffer(raw, dtype='float32').copy()

    return word_vectors

def get_word_vector(word, word_vectors):
    """Look up `word` in the `word_vectors` mapping.

    Returns the stored vector, or None when the word is not a key.
    """
    if word in word_vectors:
        return word_vectors[word]
    return None

# Load the word vectors (path assembled with os.path.join so the notebook is
# not tied to Windows path separators)
word_vectors = read_word_vectors(os.path.join('..', 'Data', 'true_vectors.bin'))

In [9]:
# test vocabulary

# Define a list of words to convert to vectors
words = ['example', 'word', 'vector', 'king', 'queen']

# Convert words to vectors
for word in words:
    vector = get_word_vector(word, word_vectors)
    if vector is not None:
        # only read .shape once the lookup is known to have succeeded;
        # get_word_vector returns None for out-of-vocabulary words
        print(vector.shape)
        print(f"Word: {word}\nVector: {vector}\n")
    else:
        print(f"Word: {word} not found in vocabulary.\n")

(64,)
Word: example
Vector: [ 0.30070093 -0.14708109 -0.05512658 -0.28891692  0.16964374 -0.26381055
 -0.1238706  -0.26599023  0.04002719 -0.0139405  -0.09649079 -0.20608869
 -0.34638372 -0.05804674  0.373583   -0.03040116  0.1421226  -0.1942786
  0.09814846 -0.19097279 -0.11167672  0.28619     0.06731004  0.33642802
  0.16544527  0.12827992 -0.06165301 -0.15542434  0.32104218  0.01503407
  0.39649448 -0.06764489  0.31964797  0.18220599 -0.2872599  -0.03778282
  0.00464729  0.37776193  0.05840794 -0.00786143 -0.40444744  0.2015574
  0.5840568   0.18758695  0.04742671  0.32666302 -0.1868712   0.30280098
  0.237733   -0.6091492  -0.08265247  0.44745898 -0.14270085 -0.6882093
  0.05940733 -0.14634833 -0.03241995  0.115207   -0.07995818 -0.19041005
  0.2315593   0.15892205  0.13351633 -0.43538508]

(64,)
Word: word
Vector: [ 0.21292834 -0.45759308 -0.18516563  0.12465494  0.02143879 -0.23055013
 -0.14937697 -0.43313178  0.23816882 -0.33727258 -0.25391486 -0.3715287
 -0.39470986  0.10469216

In [10]:
# perform word to vector conversion

# length of every word vector / padding vector used below
EMBED_DIM = 64

# perform conversion
vec_data = []
for sent in sents[:-1]:
    # index 0 keeps the raw label token; the word vectors follow it
    vecs = [sent[0]]
    for word in sent[1:]:
        vec = get_word_vector(word.lower(), word_vectors)
        if vec is None:
            # out-of-vocabulary word: skipped on purpose (previously hidden
            # behind a bare `except: pass` that also swallowed real errors)
            continue
        vecs.append(torch.from_numpy(vec))
    # an all-zero vector is appended after the last word of every sentence
    vecs.append(torch.zeros(EMBED_DIM))
    vec_data.append(vecs)

# pad all sentences to length of longest sentence
max_len = max(len(sent) for sent in vec_data)
vec_data_pad = []
for sent in vec_data:
    sent.extend(torch.zeros(EMBED_DIM) for _ in range(max_len - len(sent)))
    vec_data_pad.append(sent)
vec_data = vec_data_pad

# split into training and test data
train_data = vec_data[:5000]
# NOTE(review): `sents[:-1]` above already skips the final entry, so this extra
# `-1` appears to leave one more sentence out of the test split -- confirm that
# this is intended before changing it.
test_data = vec_data[5000:-1]

# NOTE: first word of each sentence is correct categ. -- last sentence is empty (excluded)

  vecs.append(torch.from_numpy(vec))


In [11]:
# Dataset wrapper so the converted sentences can be served by a PyTorch DataLoader

# integer class index for each coarse question category
ans_key = dict(DESC=0, ENTY=1, ABBR=2, HUM=3, LOC=4, NUM=5)

class QuestionDataset(Dataset):
    """Labelled sentences: item i is (words of sentence i, class-index tensor)."""

    def __init__(self, data):
        """
        data = list of sentences; element 0 of each sentence is the label
        string (only the part before the first ":" is looked up in ans_key),
        the remaining elements are stored as the sentence itself
        """
        self.labels, self.sents = [], []
        for entry in data:
            category = entry[0].partition(":")[0]
            self.labels.append(torch.tensor(ans_key[category]))
            self.sents.append(entry[1:])

    def __len__(self):
        """Number of stored sentences."""
        return len(self.labels)

    def __getitem__(self, idx):
        """Return (sentence, label) for position idx; tensor indices are unwrapped first."""
        index = idx.tolist() if torch.is_tensor(idx) else idx
        return self.sents[index], self.labels[index]

# wrap the training split first, then the test split
train_DSet, test_DSet = (QuestionDataset(split) for split in (train_data, test_data))

In [12]:
# Build one loader per split: training batches are shuffled, test batches are not.

batch_size = 1
train_dataloader = DataLoader(dataset=train_DSet, shuffle=True, batch_size=batch_size)
test_dataloader = DataLoader(dataset=test_DSet, shuffle=False, batch_size=batch_size)

# show the first test batch only, as a quick sanity check of the collated format
for X, y in test_dataloader:
    print(X, y, sep="\n")
    break

[tensor([[ 0.0411, -0.0476,  0.3912,  0.1446,  0.2843,  0.0124, -0.0079, -0.0237,
         -0.0122, -0.1286, -0.1066, -0.1421, -0.2869,  0.1265,  0.2813,  0.1080,
          0.1238, -0.3376,  0.1359, -0.0410,  0.2352,  0.1185, -0.1682,  0.3894,
          0.4837, -0.1884,  0.0923,  0.3109,  0.1763,  0.0841,  0.0832, -0.2112,
          0.0484, -0.1428, -0.2887, -0.0050,  0.0670,  0.0890,  0.2170,  0.0736,
          0.0262,  0.0147,  0.3572,  0.3137,  0.0837,  0.3432, -0.2798,  0.0383,
          0.3950, -0.5202,  0.1438,  0.5015, -0.4562, -0.6651,  0.0707,  0.1524,
          0.0379,  0.3060,  0.4334, -0.2978,  0.1729,  0.5553, -0.1821, -0.1045]]), tensor([[ 0.0643, -0.1233,  0.1670, -0.0083,  0.3737,  0.0114, -0.1574, -0.0829,
         -0.0848, -0.1920, -0.1029, -0.0367, -0.3347, -0.0631,  0.4439,  0.1288,
          0.0839, -0.1391,  0.1688, -0.0889, -0.0923,  0.1780,  0.1029,  0.0891,
          0.0895,  0.1983, -0.1704, -0.1576,  0.3401, -0.1080,  0.2319, -0.0226,
          0.4397,  0.285

## Create and Optimize Model

In [13]:
# Define model
class LitNeuralNetwork(L.LightningModule):
    """Feed-forward + recurrent classifier over sequences of word vectors.

    Every word vector goes through a bias-free Linear(64, 48) + ReLU layer and
    is then fed, one word at a time, into a bias-free ReLU RNN with 16 hidden
    units. The RNN output for the last processed word is mapped by a bias-free
    Linear(16, 6) + LogSoftmax layer to class log-probabilities, which are
    scored with NLLLoss.

    The hidden-to-hidden weights are multiplied by a zero-diagonal mask on
    every forward pass, so no recurrent unit is connected to itself.
    """

    def __init__(self):
        """ Builds recurrent neural network model """
        super().__init__()

        # build layers
        self.ff = nn.Sequential(
            nn.Linear(64, 48, bias=False),
            nn.ReLU(),
        )
        self.rnn = nn.RNN(48, 16, nonlinearity='relu', bias=False)
        self.out = nn.Sequential(
            nn.Linear(16, 6, bias=False),
            nn.LogSoftmax(dim=1),
        )

        # build mask for recurrent layer (hh weights -- no self-connect):
        # ones everywhere, zeros on the diagonal. Registered as a
        # non-persistent buffer so it follows the module across devices
        # (no reliance on a global `device`) without adding a key to the
        # state_dict / checkpoints.
        hidden = self.rnn.hidden_size
        self.register_buffer(
            "mask", torch.ones(hidden, hidden) - torch.eye(hidden), persistent=False
        )

        # set loss function
        self.loss_fn = nn.NLLLoss()

    def forward(self, q):
        """ Implements feed-forward then recurrent layer

        Args:
            q: the words of ONE sentence stacked along dim 0, as produced by
               `torch.stack(X)` in the step methods (a plain sequence of
               per-word tensors is stacked here as well). The hidden state is
               created with shape (1, hidden_size), so sentences are processed
               one at a time.

        Returns:
            Log-probabilities over the 6 classes for the sentence.
        """
        if not torch.is_tensor(q):
            q = torch.stack(list(q))

        # the feed-forward layer acts on the last dimension only, so all words
        # can be transformed in one call
        ff_q = self.ff(q)

        hidden = self.rnn.hidden_size
        h_N = ff_q.new_zeros(1, hidden)      # initial hidden state
        rnn_out = ff_q.new_zeros(1, hidden)  # used as-is if no word is processed

        # zero the self-connections in place before the recurrent weights are used
        self.rnn.weight_hh_l0.data.mul_(self.mask)

        for word in ff_q:
            # An all-zero feed-forward output marks the end of the sentence:
            # the layer has no bias, so zero padding vectors map to zeros.
            # NOTE(review): a real word whose ReLU outputs are all zero would
            # also stop the sentence here -- confirm this is acceptable.
            if torch.all(word.eq(0)):
                break
            # scale so the largest activation is 1 (the max is > 0 here because
            # the values are ReLU outputs and not all zero)
            word = word / torch.max(word)
            rnn_out, h_N = self.rnn(word, h_N)

        return self.out(rnn_out)

    def training_step(self, batch, batch_idx):
        """Run one training batch, log the loss and current learning rate, return the loss."""
        # perform training
        X, y = batch
        X = torch.stack(X)
        pred = self(X)
        loss = self.loss_fn(pred, y)
        self.log("train_loss", loss)

        # Log learning rate
        lr = self.optimizers().param_groups[0]['lr']
        self.log('learning_rate', lr, on_step=True, on_epoch=True, prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation batch and log its loss and accuracy."""
        # perform validation (using test dataset)
        X, y = batch
        X = torch.stack(X)
        out = self(X)
        loss = self.loss_fn(out, y)

        # calculate acc: fraction of samples whose arg-max class equals the label
        labels_hat = torch.argmax(out, dim=1)
        val_acc = torch.sum(y == labels_hat).item() / len(y)

        # log the outputs
        self.log_dict({'val_loss': loss, 'val_acc': val_acc})

    def configure_optimizers(self):
        """Adam (lr=1e-3) with a StepLR schedule: lr is multiplied by 0.1 every 10 scheduler steps."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]

# Define a model checkpoint callback to save the model every 5 epochs
checkpoint_callback = ModelCheckpoint(
    save_top_k=-1,               # Keep every checkpoint that is written (nothing is pruned)
    every_n_epochs=5,            # Save every 5 epochs
    filename='{epoch}-{val_loss:.2f}',  # File name built from the epoch and the logged 'val_loss' metric
    verbose=True                 # Print information when saving
)

# Get cpu, gpu or mps device for training.
# Preference order: CUDA, then Apple MPS, then CPU. The resulting string is a
# module-level global that later cells also read (e.g. `model.to(device)`).
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

# create model
model = LitNeuralNetwork()
print(model)

# use lightning for training
trainer = L.Trainer(
    max_epochs=20,            # Number of epochs
    devices=1,                # Number of devices (GPUs/CPUs)
    accelerator='auto',       # Automatically select the device
    precision=32,             # Full 32-bit precision (this is NOT mixed precision)
    gradient_clip_val=0.5,    # Clip gradients -- unnecessary?
    callbacks=[checkpoint_callback]
)
# the test split doubles as the validation set during fitting
trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Using cuda device
LitNeuralNetwork(
  (ff): Sequential(
    (0): Linear(in_features=64, out_features=48, bias=False)
    (1): ReLU()
  )
  (rnn): RNN(48, 16, bias=False)
  (out): Sequential(
    (0): Linear(in_features=16, out_features=6, bias=False)
    (1): LogSoftmax(dim=1)
  )
  (loss_fn): NLLLoss()
)


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | ff      | Sequential | 3.1 K  | train
1 | rnn     | RNN        | 1.0 K  | train
2 | out     | Sequential | 96     | train
3 | loss_fn | NLLLoss    | 0      | train
-----------------------------------------------
4.2 K     Trainable params
0         Non-trainable params
4.2 K     Total params
0.017     Total estimated model params size (MB)


Sanity Checking: |                                                                               | 0/? [00:00<…

C:\Users\Liamr\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
C:\Users\Liamr\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |                                                                                      | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

`Trainer.fit` stopped: `max_epochs=20` reached.


## Check Model and Continue Training

In [14]:
# check performance on test dataset

def test(dataloader, model):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    i = 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = torch.stack(X).to(device), y.to(device) #torch.FloatTensor(y).to(device)
            pred = model(X)
            test_loss += model.loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            print("Actual: {0} -- Inferred: {1}".format(y.item(), pred.argmax(1).item()))
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")

# nn.Module.to returns the module itself, so move and evaluate in one call
test(test_dataloader, model.to(device))

Actual: 0 -- Inferred: 0
Actual: 0 -- Inferred: 0
Actual: 5 -- Inferred: 5
Actual: 3 -- Inferred: 3
Actual: 2 -- Inferred: 1
Actual: 3 -- Inferred: 5
Actual: 0 -- Inferred: 0
Actual: 0 -- Inferred: 0
Actual: 5 -- Inferred: 5
Actual: 0 -- Inferred: 0
Actual: 4 -- Inferred: 4
Actual: 0 -- Inferred: 0
Actual: 2 -- Inferred: 2
Actual: 4 -- Inferred: 4
Actual: 4 -- Inferred: 4
Actual: 3 -- Inferred: 3
Actual: 0 -- Inferred: 0
Actual: 1 -- Inferred: 1
Actual: 0 -- Inferred: 0
Actual: 5 -- Inferred: 5
Actual: 1 -- Inferred: 1
Actual: 3 -- Inferred: 3
Actual: 3 -- Inferred: 3
Actual: 1 -- Inferred: 1
Actual: 4 -- Inferred: 4
Actual: 1 -- Inferred: 1
Actual: 0 -- Inferred: 0
Actual: 0 -- Inferred: 0
Actual: 3 -- Inferred: 3
Actual: 5 -- Inferred: 5
Actual: 5 -- Inferred: 1
Actual: 4 -- Inferred: 4
Actual: 0 -- Inferred: 0
Actual: 1 -- Inferred: 1
Actual: 3 -- Inferred: 3
Actual: 5 -- Inferred: 5
Actual: 1 -- Inferred: 1
Actual: 1 -- Inferred: 1
Actual: 0 -- Inferred: 5
Actual: 3 -- Inferred: 3


In [9]:
# load model from checkpoint (lightning has automatic, most-recent-epoch checkpointing) and continue training
# NOTE: the version directory and file name below belong to one specific earlier
# run -- update them to match the checkpoint you want to resume from.

checkpoint_path = os.path.join(
    cwd, "lightning_logs", "version_1", "checkpoints", "epoch=19-val_loss=1.01.ckpt"
)
model = LitNeuralNetwork.load_from_checkpoint(checkpoint_path)
trainer = L.Trainer(max_epochs=10)
trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params | Mode 
-----------------------------------------------
0 | ff      | Sequential | 3.1 K  | train
1 | rnn     | RNN        | 1.0 K  | train
2 | out     | Sequential | 96     | train
3 | loss_fn | NLLLoss    | 0      | train
-----------------------------------------------
4.2 K     Trainable params
0         Non-trainable params
4.2 K     Total params
0.017     Total estimated model params size (MB)


Sanity Checking: |                                                                               | 0/? [00:00<…

C:\Users\Liamr\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.
C:\Users\Liamr\AppData\Local\Programs\Python\Python311\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=15` in the `DataLoader` to improve performance.


Training: |                                                                                      | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

Validation: |                                                                                    | 0/? [00:00<…

`Trainer.fit` stopped: `max_epochs=10` reached.


In [12]:
# check all weights: print the shape and values of every parameter tensor,
# numbered from 1 in the order model.parameters() yields them

param_list = list(model.parameters())
for i, lay in enumerate(param_list, start=1):
    print("Layer {0}".format(i))
    print(lay.shape)
    print(lay)

Layer 1
torch.Size([48, 64])
Parameter containing:
tensor([[ 0.3191, -0.1007,  0.2364,  ...,  0.0308, -0.6877, -0.3279],
        [-0.3180, -0.0725,  0.1145,  ...,  0.0960, -0.7081, -0.1855],
        [ 0.2543,  0.0214,  0.3525,  ...,  0.0748, -0.0897,  0.1017],
        ...,
        [ 0.5771, -0.0987, -0.1120,  ..., -0.0190, -0.4856, -0.4656],
        [ 0.1736, -0.1623,  0.1439,  ...,  0.1375, -0.3094,  0.2778],
        [-0.5711, -0.0402, -0.2456,  ..., -0.1565,  0.0552, -0.5107]],
       device='cuda:0', requires_grad=True)
Layer 2
torch.Size([16, 48])
Parameter containing:
tensor([[ 0.6916,  0.6156,  0.8338,  0.7775,  0.1623,  0.7049,  0.6421,  0.1841,
          0.7779,  0.7551,  0.9258,  0.5336,  0.6808,  0.9846,  0.5271,  0.5804,
          0.9415,  0.5120, -0.0985,  0.8340,  0.7431,  0.7384,  0.7556,  1.1425,
          0.6219,  0.2811,  0.9587,  0.9397,  0.7397,  0.8261,  0.9777,  0.1242,
          0.9205, -0.7322, -0.1207,  1.2899,  0.5577,  1.1331,  0.2080,  0.2077,
          0.752

## Save Model

In [15]:
model_dir = "Recurrent ANN Models//"
model_name = "RANN_2.pth"

# create the target folder on first use so torch.save does not fail on a fresh checkout
os.makedirs(model_dir, exist_ok=True)

# only the state_dict (the weights) is saved, not the whole module object
torch.save(model.state_dict(), model_dir + model_name)
print("Saved PyTorch Model State to " + model_dir + model_name)

Saved PyTorch Model State to Recurrent ANN Models//RANN_2.pth
