# Benchmarking Sleep Deprivation FF

## Imports

In [None]:
import ssl
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
from FF import FF
from data import CIFAR10, MNIST, FMNIST, MergedDataset
# Make the default HTTPS context an unverified one when this Python build
# exposes it (presumably so dataset downloads work on machines with broken
# certificate stores -- TODO confirm). NOTE: this disables TLS certificate
# verification for every default HTTPS context created in this kernel.
_create_unverified_https_context = getattr(ssl, "_create_unverified_context", None)
if _create_unverified_https_context is not None:
    ssl._create_default_https_context = _create_unverified_https_context

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

## Data setup

In [None]:
# --- Tunable settings -------------------------------------------------------
base_dataset = FMNIST  # Dataset wrapper to benchmark on; can be CIFAR10, MNIST, or FMNIST.
batch_size_train = 512  # Value used for the paper; a larger batch size might be beneficial for normal use.
batch_size_test = 10_000
negative_data_with_masks_iteration = 5  # Number of convolution steps to apply to negative data masks.

# --- Datasets, cached in memory to speed up training ------------------------
train_dataset = base_dataset.get_in_memory(train=True)
train_negative_dataset = base_dataset.get_in_memory(train=True)
test_dataset = base_dataset.get_in_memory(train=False)

# --- Loaders for the ordinary train and test data ---------------------------
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size_train,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size_test,
    shuffle=False,
)

# --- Loader for the masked negative data ------------------------------------
# The loader itself uses batch_size=1: the training loop reshapes every item it
# yields to `batch_size_train` rows, so each item already holds a full batch
# produced by MergedDataset.
masked_negative_dataset = MergedDataset(
    train_negative_dataset,
    negative_data_with_masks_iteration,
    batch_size_train,
    base_dataset.negative_data_with_masks,
)
train_loader_negative = torch.utils.data.DataLoader(
    masked_negative_dataset,
    batch_size=1,
    shuffle=False,
)

Showing some of the images of the dataset

In [None]:
# Draw one (shuffled) batch from the training loader.
dataiter = iter(train_loader)
images, labels = next(dataiter)

# Tile the first few images into a single grid and display it.
# The grid is channel-first, so the axes are reordered to (H, W, C) for imshow.
n_preview = 4
preview_grid = torchvision.utils.make_grid(images[:n_preview], padding=2, normalize=True)
plt.imshow(np.transpose(preview_grid.cpu(), (1, 2, 0)))
plt.show()

# Print the matching labels, each right-aligned in a 5-character column.
label_cells = []
for j in range(n_preview):
    label_cells.append('%5s' % labels[j].item())
print(' '.join(label_cells))

Showing what an image looks like after its label has been overlaid on it

In [None]:
# Take the first image of a fresh training batch ...
sample_batch = next(iter(train_loader))
image = sample_batch[0][0]

# ... overlay label 1 on it (the first element of the returned pair is kept)
# and restore the dataset's image shape.
overlaid = base_dataset.overlay_y_on_x(image, 1)
image = overlaid[0].reshape(base_dataset.get_image_shape())

# Reorder channel-first data to (H, W, C) for imshow.
plt.imshow(np.transpose(image, (1, 2, 0)))

Showing what a negative-data image looks like

In [None]:
# Fetch the first item from the negative loader and bring it to the CPU.
negative_item = next(iter(train_loader_negative))
image = negative_item[0].cpu()

# Split the item into individual images of the dataset's shape; show the first one.
negative_images = image.reshape((-1, *base_dataset.get_image_shape()))
image = negative_images[0]

# Reorder channel-first data to (H, W, C) for imshow.
plt.imshow(np.transpose(image, (1, 2, 0)))

## Model training

Here we set all parameters and create the model.
It uses an architecture of 
input_size -> 500 -> 500 -> 500

In [None]:
threshold = 1.5
epochs_per_layer = 50  # Represents how many updates per batch of training data a layer will get.

awake_period = 1  # The period that the model will be awake for. Setting it to -1 will make the model act without phase separation.
sleep_period = 1  # The period that the model will be asleep for.

assert awake_period > 0 or awake_period == -1, "awake_period must be greater than 0 or -1"
assert sleep_period > 0, "sleep_period must be greater than 0"

optim_lr = 0.001  # The learning rate for the optimizer when awake_period is set to -1.

# Learning rate for the positive pass when awake_period is positive: the base
# rate is divided by awake_period to scale it down. In the awake_period == -1
# mode the division would yield a negative rate (-0.001), so the unscaled base
# rate is used there instead.
positive_lr = 0.001 / awake_period if awake_period > 0 else 0.001

negative_lr = 0.001  # The learning rate for the negative forward pass optimizer.

epochs = 500
hour = 0  # The figurative hour of the day that the model will start on. This is used to determine the phase of the model.
print_every = 1
with_masks = True  # Toggles between using masks and not using masks.

model = FF(device=device)
model.to(device)

# Architecture: flattened input -> 500 -> 500 -> 500 (three layers of 500 neurons each).
# Each consecutive pair of sizes becomes one layer; all layers share the same hyperparameters.
layer_sizes = [np.prod(base_dataset.get_image_shape()), 500, 500, 500]
for in_features, out_features in zip(layer_sizes, layer_sizes[1:]):
    model.add_layer(in_features, out_features, optim_lr, positive_lr, negative_lr, threshold, epochs_per_layer)

This is the main training loop.

In [None]:
def get_random_number_besides(x):
    """Draw a random integer in [0, 9] that differs from ``x``.

    Rejection sampling: keep drawing with ``random.randint(0, 9)`` until the
    draw is not equal to ``x``, then return that draw.
    """
    candidate = random.randint(0, 9)
    while candidate == x:
        # The draw hit the excluded value; try again.
        candidate = random.randint(0, 9)
    return candidate

def get_negative_y(y, num_classes=10):
    """Return wrong labels: same shape as ``y``, every entry different from ``y``.

    The whole batch is drawn in one vectorised torch call instead of a Python
    loop over single elements, so no per-label host round trip is needed when
    the labels live on the GPU.

    Args:
        y: Class labels (tensor or sequence of ints), expected to lie in
            ``[0, num_classes - 1]``.
        num_classes: Number of distinct classes. Defaults to 10, i.e. labels
            0 to 9.

    Returns:
        A ``torch.long`` tensor on the notebook-level ``device``. For labels in
        the expected range, each entry is drawn uniformly from the
        ``num_classes - 1`` labels that differ from the matching entry of ``y``.
    """
    labels = torch.as_tensor(y, dtype=torch.long, device=device)
    # A shift in [1, num_classes - 1] taken modulo num_classes can never map a
    # label back onto itself, and reaches every other label with equal probability.
    shifts = torch.randint(1, num_classes, labels.shape, device=labels.device)
    return (labels + shifts) % num_classes

# Test accuracy history. Accuracy is measured *before* the epoch's training, so
# entry k reflects the model after k * print_every completed epochs.
accuracies = []

# NOTE: `hour` comes from the configuration cell and is advanced once per batch
# below. Re-running this cell without re-running the configuration cell resumes
# from the previous hour instead of restarting at 0.

# Training loop
for epoch in tqdm(range(epochs), desc="Epochs"):

    # Evaluate on the test set every `print_every` epochs.
    if epoch % print_every == 0:
        predictions, real = base_dataset.predict(test_loader, model, device)
        acc = np.sum(predictions == real) / len(real)
        print("Accuracy on test data: ", acc)
        accuracies.append(acc)

    model.train()

    if with_masks:
        # Negative samples come from the masked negative-data loader, paired
        # batch-by-batch with the positive loader (zip stops at the shorter one).
        for (x_pos, y), x_neg in zip(train_loader, train_loader_negative):
            x_pos = x_pos.to(device)
            y = y.to(device)
            # Moved to `device` like the positive batch so both live on the same device.
            # NOTE(review): the row count is fixed to `batch_size_train`, while the last
            # positive batch of an epoch may be smaller -- confirm MergedDataset and
            # overlay_y_on_x cope with that case.
            x_neg = x_neg.reshape(batch_size_train, -1).to(device)

            if awake_period == -1:
                # No phase separation: positive and negative data in the same step.
                x_pos, _ = base_dataset.overlay_y_on_x(x_pos, y)
                x_neg, _ = base_dataset.overlay_y_on_x(x_neg, get_negative_y(y))
                model.forward(x_pos, x_neg)
            elif hour % (awake_period + sleep_period) < awake_period:
                # Awake phase: positive pass only, with the true labels overlaid.
                x_pos, _ = base_dataset.overlay_y_on_x(x_pos, y)
                model.forward_positive(x_pos)
            else:
                # Sleep phase: negative pass only, with wrong labels overlaid.
                x_neg, _ = base_dataset.overlay_y_on_x(x_neg, get_negative_y(y))
                model.forward_negative(x_neg)
            hour += 1
    else:
        # Without masks, negative samples are the training images themselves
        # with a wrong label overlaid.
        for x_pos, y in train_loader:
            x_pos = x_pos.to(device)
            y = y.to(device)

            if awake_period == -1:
                # No phase separation: the negative batch is built from a copy of
                # the already label-overlaid positive batch.
                x_pos, _ = base_dataset.overlay_y_on_x(x_pos, y)
                x_neg, _ = base_dataset.overlay_y_on_x(x_pos.clone(), get_negative_y(y))
                model.forward(x_pos, x_neg)
            elif hour % (awake_period + sleep_period) < awake_period:
                # Awake phase: positive pass only.
                x_pos, _ = base_dataset.overlay_y_on_x(x_pos, y)
                model.forward_positive(x_pos)
            else:
                # Sleep phase: negative pass built from the raw batch with wrong labels.
                x_neg, _ = base_dataset.overlay_y_on_x(x_pos, get_negative_y(y))
                model.forward_negative(x_neg)
            hour += 1

Plotting the test accuracy history

In [None]:
import matplotlib.pyplot as plt

# The training loop appends one accuracy every `print_every` epochs, so entry k
# was measured at epoch k * print_every. Plot against those epoch numbers so the
# "Epochs" axis stays correct when print_every is not 1.
evaluated_epochs = [k * print_every for k in range(len(accuracies))]

plt.plot(evaluated_epochs, accuracies)
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.title("Test accuracy of the model over time")
plt.show()