In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb
from ff import FF, FFLayer, FFEncoder
from data import MNIST, MergedDataset
from tqdm import tqdm
# pip install lion-pytorch
from lion_pytorch import Lion


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# device = "mps" if torch.backends.mps.is_available() else "cpu"
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
batch_size_train = 512
batch_size_test = 512

In [4]:
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./datasets/MNIST/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_train, shuffle=True)
train_loader_negative = torch.utils.data.DataLoader(MergedDataset(torchvision.datasets.MNIST('./datasets/MNIST/', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ]))), batch_size=batch_size_train, shuffle=True)

In [5]:
test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('./datasets/MNIST/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=batch_size_test, shuffle=True)
  

In [6]:
squared_error = lambda x: x.pow(2).mean(1)
deviation_error = lambda x: -((x - x.mean(1).unsqueeze(1)).pow(2).mean(1))


In [7]:
threshold = 1.5
epochs_per_layer = 50
model = FF(logging=False, device=device)
optim_config = {
    "lr": 0.01,
}
positive_optim_config = {
    "lr": 0.001,

}
negative_optim_config = {
    "lr": 0.001,
}

goodness_function = squared_error
awake_period = 1
sleep_period = 1

model.add_layer(FFLayer(nn.Linear(784, 500).to(device), optimizer=torch.optim.Adam, epochs=epochs_per_layer, threshold=threshold, activation=nn.ReLU(), optim_config=optim_config, positive_optim_config=positive_optim_config, negative_optim_config=negative_optim_config, logging=False, name="layer 1", device = device, goodness_function=goodness_function).to(device))
model.add_layer(FFLayer(nn.Linear(500, 500).to(device), optimizer=torch.optim.Adam, epochs=epochs_per_layer, threshold=threshold, activation=nn.ReLU(), optim_config=optim_config, positive_optim_config=positive_optim_config, negative_optim_config=negative_optim_config, logging=False, name="layer 2", device = device, goodness_function=goodness_function).to(device))
model.add_layer(FFLayer(nn.Linear(500, 500).to(device), optimizer=torch.optim.Adam, epochs=epochs_per_layer, threshold=threshold, activation=nn.ReLU(), optim_config=optim_config, positive_optim_config=positive_optim_config, negative_optim_config=negative_optim_config, logging=False, name="layer 3", device = device, goodness_function=goodness_function).to(device))

for a, b in zip(enumerate(train_loader), enumerate(train_loader_negative)):
    _, (x,y) = a
    _, x2 = b
    x_pos, _ = MNIST.overlay_y_on_x(x, y)

    rnd = torch.randperm(x.size(0))
    x_neg = x2
    print(x_neg.shape)
    print(x_pos.shape)
    x_pos, x_neg = x_pos.to(device), x_neg.to(device)
    model.forward(x_pos, x_neg)

predictions, real = MNIST.predict(test_loader, model, device)
acc = np.sum(predictions == real)/len(real)
print(acc)

torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size([512, 784])
torch.Size([512, 784]) torch.Size(

# Training

In [8]:
wandb.init(project="MNIST", entity="ffalgo", name="Pretrained-model-1-awake-1-sleep-0.001")
wandb.config = {
  "learning_rate": 0.01,
  "awake_period": awake_period,
  "sleep_period": sleep_period,
  "epochs_per_layer": epochs_per_layer,
  "batch_size": 512,
  "activation": "relu",
  "positive_lr": 0.001,
  "negative_lr": 0.001,
  "threshold": threshold,
  "optimizer": torch.optim.Adam,
  "device": device
}

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mrd211[0m ([33mffalgo[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
model = model.to(device)
epochs = 500
best_acc = 0.0
hour = 0
for i in tqdm(range(epochs)):
    if i % 1 == 0:
        predictions, real = MNIST.predict(test_loader, model, device)
        acc = np.sum(predictions == real)/len(real)
        wandb.log({"Accuracy on test data": acc})
        if acc > best_acc and acc > 0.8:
            best_acc = acc
            # torch.save(model.state_dict(), 'best_mnist_80%.ph')
        
    predictions, real = MNIST.predict(train_loader, model, device)
    acc = np.sum(predictions == real)/len(real)
    wandb.log({"Accuracy on train data": acc})
    model.train()
    for _, (x, y) in enumerate(train_loader):
        x_pos, _ = MNIST.overlay_y_on_x(x, y)
        rnd = torch.randperm(x.size(0))
        x_neg, _ = MNIST.overlay_y_on_x(x, y[rnd])
        x_pos, x_neg = x_pos.to(device), x_neg.to(device)
        if hour % (awake_period + sleep_period) < awake_period:
            model.forward_positive(x_pos)
        else:
            model.forward_negative(x_neg)
        # model.forward(x_pos, x_neg)
        
        hour += 1

        
wandb.finish()

  0%|          | 1/500 [01:26<12:03:07, 86.95s/it]


KeyboardInterrupt: 