In [12]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb
from ff import FF, FFLayer
from data import MNIST
from tqdm import tqdm

In [13]:
# device = "mps" if torch.backends.mps.is_available() else "cpu"
device = "cuda" if torch.cuda.is_available() else "cpu"

In [14]:
batch_size_train = 512
batch_size_test = 512

In [15]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./datasets/MNIST/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)

In [16]:
test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./datasets/MNIST/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [17]:
squared_error = lambda x: x.pow(2).mean(1)
deviation_error = lambda x: -((x - x.mean(1).unsqueeze(1)).pow(2).mean(1))
mean_error = lambda x: x.mean(1)
#samples sin at 15 degree intervals and computes the mean squared error
sin_error = lambda x: -(((torch.sin(torch.FloatTensor([i * 2 * np.pi / 15 for i in range(x.shape[1])]).to(device)).to(device) - x).pow(2)).mean(1))

In [18]:
threshold = 1.5
epochs_per_layer = 5
model = FF(logging=False, device=device)
optim_config = {
    "lr": 0.01,
}
positive_optim_config = {
    "lr": 0.001,

}
negative_optim_config = {
    "lr": 0.001,
}

goodness_function = squared_error
awake_period = 1
sleep_period = 1

model.add_layer(FFLayer(nn.Linear(784, 500).to(device), optimizer=torch.optim.Adam, epochs=epochs_per_layer, threshold=threshold, activation=nn.ReLU(), optim_config=optim_config, positive_optim_config=positive_optim_config, negative_optim_config=negative_optim_config, logging=False, name="layer 1", device = device, goodness_function=goodness_function).to(device))
model.add_layer(FFLayer(nn.Linear(500, 500).to(device), optimizer=torch.optim.Adam, epochs=epochs_per_layer, threshold=threshold, activation=nn.ReLU(), optim_config=optim_config, positive_optim_config=positive_optim_config, negative_optim_config=negative_optim_config, logging=False, name="layer 2", device = device, goodness_function=goodness_function).to(device))
model.add_layer(FFLayer(nn.Linear(500, 500).to(device), optimizer=torch.optim.Adam, epochs=epochs_per_layer, threshold=threshold, activation=nn.ReLU(), optim_config=optim_config, positive_optim_config=positive_optim_config, negative_optim_config=negative_optim_config, logging=False, name="layer 3", device = device, goodness_function=goodness_function).to(device))

# Training

In [19]:
wandb.init(project="MNIST", entity="ffalgo", name="MNIST-FF-3L-500-500-500-relu-awake-1-sleep-1-epochs-5-threshold-1.5-lr-0.01-0.001-0.001-squared_error")
wandb.config = {
  "learning_rate": 0.01,
  "awake_period": awake_period,
  "sleep_period": sleep_period,
  "epochs_per_layer": epochs_per_layer,
  "batch_size": 32,
  "activation": "relu",
  "positive_lr": 0.005,
  "negative_lr": 0.01,
  "threshold": threshold,
  "optimizer": torch.optim.Adam,
  "device": device
}

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrd211[0m ([33mffalgo[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [20]:
model = model.to(device)
epochs = 25
best_acc = 0.0
hour = 0
for i in tqdm(range(epochs)):
    if i % 10 == 0:
        predictions, real = MNIST.predict(test_loader, model, device)
        acc = np.sum(predictions == real)/len(real)
        wandb.log({"Accuracy on test data": acc})
        if acc > best_acc:
            best_acc = acc
            torch.save(model.state_dict(), 'best_mnist.ph')
        
    predictions, real = MNIST.predict(train_loader, model, device)
    acc = np.sum(predictions == real)/len(real)
    wandb.log({"Accuracy on train data": acc})
    model.train()
    for _, (x, y) in enumerate(train_loader):
        x_pos, _ = MNIST.overlay_y_on_x(x, y)
        rnd = torch.randperm(x.size(0))
        x_neg, _ = MNIST.overlay_y_on_x(x, y[rnd])
        x_pos, x_neg = x_pos.to(device), x_neg.to(device)
        if hour % (awake_period + sleep_period) < awake_period:
            model.forward_positive(x_pos)
        else:
            model.forward_negative(x_neg)
        
        hour += 1

        
wandb.finish()

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 25/25 [13:28<00:00, 32.32s/it]


0,1
Accuracy on test data,▁▇█
Accuracy on train data,▁▅▄▃▂▄▅▅▆▆▆▆▆▆▆▇▇▇▇▇▇▇███

0,1
Accuracy on test data,0.5746
Accuracy on train data,0.6316


Exception in thread SockSrvRdThr:
Traceback (most recent call last):
  File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/threading.py", line 932, in _bootstrap_inner
    self.run()
  File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/wandb/sdk/service/server_sock.py", line 100, in run
    sreq = self._sock_client.read_server_request()
  File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/wandb/sdk/lib/sock_client.py", line 274, in read_server_request
    data = self._read_packet_bytes()
  File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/wandb/sdk/lib/sock_client.py", line 248, in _read_packet_bytes
    rec = self._extract_packet_bytes()
  File "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/wandb/sdk/lib/sock_client.py", line 230, in _extract_packet_bytes
    assert magic == ord("W")
AssertionError
Exception in thread SockSrvRdThr:
Traceback (most recent call last):
  File "/anaconda/envs/azureml_py38_PT_TF/lib/pyth