In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.models.resnet import ResNet, BasicBlock
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import wandb
from ff import FF, FFLayer, FFEncoder
from data import MNIST, MergedDataset, CIFAR10
from tqdm import tqdm


# ResNet backbone with its stem convolution swapped for a CIFAR-sized one
class ResNetCIFAR10(ResNet):
    """torchvision ``ResNet`` whose first convolution is replaced for CIFAR-10.

    The network is built from ``BasicBlock`` with stage depths ``[3, 4, 6, 3]``.
    That is the layout torchvision uses for ResNet-34 (ResNet-50 uses
    ``Bottleneck`` blocks), so despite earlier naming this is not a ResNet-50.

    No ``num_classes`` is passed, so the classifier head keeps the parent
    class's default output width.
    """

    def __init__(self):
        super().__init__(block=BasicBlock, layers=[3, 4, 6, 3])
        # Override the stem created by the parent constructor with a 3x3,
        # stride-1, bias-free convolution over 3 input channels.
        # NOTE(review): presumably so small 32x32 images are not downsampled
        # by the stock stem - confirm.
        self.conv1 = nn.Conv2d(
            in_channels=3,
            out_channels=64,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )

# Preprocessing shared by both splits: PIL image -> tensor -> per-channel normalisation.
# NOTE(review): these mean/std values look like the usual ImageNet statistics
# rather than CIFAR-10 statistics - confirm this is intended.
channel_normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)
transform_train = transforms.Compose([transforms.ToTensor(), channel_normalize])
transform_test = transforms.Compose([transforms.ToTensor(), channel_normalize])

# CIFAR-10 train / test splits (downloaded on first use).
# NOTE(review): the root directory is named after MNIST although it stores CIFAR-10.
dataset_root = './datasets/MNIST/'
train_dataset = datasets.CIFAR10(
    root=dataset_root, train=True, transform=transform_train, download=True
)
test_dataset = datasets.CIFAR10(
    root=dataset_root, train=False, transform=transform_test, download=True
)

# Batched loaders; only the training split is shuffled.
loader_options = dict(batch_size=128, num_workers=2)
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_options)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_options)

# Instantiate the CIFAR-10 ResNet variant defined above (randomly initialised,
# no checkpoint is loaded here) and freeze every parameter.
# NOTE(review): the module is never switched to eval() in this cell, so any
# batch-norm layers stay in training mode - confirm that is intended.
resnet = ResNetCIFAR10()
resnet.requires_grad_(False)

# Input width of the final fully connected layer; printed in a later cell.
num_out = resnet.fc.in_features

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
resnet = resnet.to(device)

  from .autonotebook import tqdm as notebook_tqdm


Files already downloaded and verified
Files already downloaded and verified


In [2]:
def squared_error(x):
    """Return the mean of the squared entries of ``x`` along dimension 1."""
    squared = x.pow(2)
    return squared.mean(dim=1)


def deviation_error(x):
    """Return the negated mean squared deviation of ``x`` from its dim-1 mean.

    For each index along dimension 0 this is minus the mean, over dimension 1,
    of ``(x - mean)**2`` where ``mean`` is taken over dimension 1.
    """
    centered = x - x.mean(dim=1, keepdim=True)
    return -(centered.pow(2).mean(dim=1))


In [3]:
# Hyper-parameters handed to every FF layer.
threshold = 1.5
epochs_per_layer = 50
goodness_function = squared_error

# Optimiser settings: a general config plus separate positive / negative configs.
optim_config = {"lr": 0.01}
positive_optim_config = {"lr": 0.001}
negative_optim_config = {"lr": 0.001}

# Defined for the notebook but not referenced anywhere in this cell.
awake_period = 1
sleep_period = 1

model = FF(logging=False, device=device)

# NOTE(review): num_out is only printed; the layers below hard-code a
# 1_000-wide input instead of using it - confirm this matches the ResNet output.
print(num_out)

# Keyword arguments that are identical for every layer.
shared_layer_kwargs = dict(
    optimizer=torch.optim.Adam,
    epochs=epochs_per_layer,
    threshold=threshold,
    optim_config=optim_config,
    positive_optim_config=positive_optim_config,
    negative_optim_config=negative_optim_config,
    logging=False,
    device=device,
    goodness_function=goodness_function,
)

# Three 1_000 -> 1_000 linear layers with ReLU, named "layer 1" .. "layer 3".
for layer_index in (1, 2, 3):
    linear = nn.Linear(1_000, 1_000).to(device)
    ff_layer = FFLayer(
        linear,
        activation=nn.ReLU(),
        name=f"layer {layer_index}",
        **shared_layer_kwargs,
    )
    model.add_layer(ff_layer.to(device))

# Accuracy on the test loader before the training cell has run.
model.eval()
predictions, real = CIFAR10.predict_resnet(test_loader, model, resnet, device)
acc = np.sum(predictions == real) / len(real)
print(acc)

512
0.1015


In [4]:
# Single pass over the training loader, feeding ResNet outputs for positive
# and negative samples to the FF model, followed by a test-set evaluation.
model.train()

# Iterate the loader directly: tqdm can then read its length and show a real
# progress bar with total/ETA (wrapping enumerate() hides the length, and the
# index was never used).
for x, y in tqdm(train_loader):
    # Positive samples: each image overlaid with its own label.
    x_pos, _ = CIFAR10.overlay_y_on_x(x, y)
    x_pos = x_pos.reshape((-1, 3, 32, 32))

    # Negative samples: the same images overlaid with labels shuffled within the batch.
    # NOTE(review): a shuffled label can coincide with the true one, so some
    # "negative" samples may actually be correctly labelled - confirm acceptable.
    rnd = torch.randperm(x.size(0))
    x_neg, _ = CIFAR10.overlay_y_on_x(x, y[rnd])

    x_pos, x_neg = x_pos.to(device), x_neg.to(device)
    x_neg = x_neg.reshape((-1, 3, 32, 32))

    # The ResNet is used as a fixed feature extractor, so make it explicit
    # that no autograd graph is needed for these forward passes.
    with torch.no_grad():
        x_pos = resnet(x_pos)
        x_neg = resnet(x_neg)

    model.forward(x_pos, x_neg)

# Fraction of test samples whose prediction matches the true label.
model.eval()
predictions, real = CIFAR10.predict_resnet(test_loader, model, resnet, device)
acc = np.sum(predictions == real) / len(real)
print(acc)

391it [01:51,  3.51it/s]


0.1016
