In [2]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms

from torchvision.models.resnet import BasicBlock

from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchvision.utils import make_grid

from tqdm import tqdm

import matplotlib.pyplot as plt

from structures import LateralInhibition, LIBlock


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"using {device}")

using cuda


## Load and normalize ImageNet

In [3]:
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

TRAIN_NORMALIZE = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN.tolist(), std=IMAGENET_STD.tolist()),
        transforms.RandomHorizontalFlip(),
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0), ratio=(0.75, 1.33)),
    ]
)

TEST_NORMALIZE = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=IMAGENET_MEAN.tolist(), std=IMAGENET_STD.tolist()),
    ]
)

def deprocess(img):
    transform = transforms.Compose(
        [
            transforms.Normalize(mean=[0, 0, 0], std=(1.0 / IMAGENET_STD).tolist()),
            transforms.Normalize(mean=(-IMAGENET_MEAN).tolist(), std=[1, 1, 1]),
            transforms.ToPILImage(),
        ]
    )
    return transform(img)


def load_datas(batch_size=128):
    train_dataset = ImageFolder(root="imagenet-mini/train", transform=TRAIN_NORMALIZE)
    test_dataset = ImageFolder(root="imagenet-mini/val", transform=TEST_NORMALIZE)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=8, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=8, shuffle=True)

    return train_loader, test_loader

train_loader, test_loader = load_datas()
TRAIN_SIZE, TEST_SIZE = len(train_loader.dataset), len(test_loader.dataset)

print(f"train dataset: {TRAIN_SIZE}, test_dataset: {TEST_SIZE}")

train dataset: 34745, test_dataset: 3923


## Utils

In [16]:
def train_model(model, max_epochs=10, batch_accumulation=2, eval_freq=2, comment=""):
    writer = SummaryWriter(comment=comment)

    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(
        model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

    for epoch in range(max_epochs):
        total_loss, correct = 0, 0

        for batch_idx, (images, labels) in enumerate(tqdm(train_loader)):
            images, labels = images.to(device), labels.to(device)

            output = model(images)
            pred = output.argmax(dim=1)

            loss = criterion(output, labels) / batch_accumulation
            total_loss += loss

            correct += torch.sum(labels == pred).sum().item()

            loss.backward()

            if (batch_idx + 1) % batch_accumulation:
                optimizer.step()
                optimizer.zero_grad()

        if scheduler:
            scheduler.step()

        total_loss /= TRAIN_SIZE
        accuracy = correct / TRAIN_SIZE

        writer.add_scalar("Train/loss", total_loss, epoch)
        writer.add_scalar("Train/accuracy", accuracy, epoch)

        print(
            f"[{epoch + 1:2d}/{max_epochs}] loss_train: {total_loss:.2E} accuracy_train: {accuracy:.2%}"
        )

        if not epoch % eval_freq:
            continue

        with torch.no_grad():
            test_loss, correct = 0, 0

            for images, labels in test_loader:
                images, labels = images.to(device), labels.to(device)

                output = model(images)
                pred = output.argmax(dim=1)

                loss = criterion(output, labels)
                test_loss += loss

                correct += torch.sum(labels == pred).sum().item()

            test_loss /= TEST_SIZE
            accuracy = correct / TEST_SIZE

            writer.add_scalar("Test/loss", test_loss, epoch)
            writer.add_scalar("Test/accuracy", accuracy, epoch)

        print(f"loss_test: {total_loss:.2E} accuracy_test: {accuracy:.2%}")

        if hasattr(model, "log"):
            W = torch.concatenate([t.flatten() for t in model.log["W"]])
            m = torch.tensor([t.item() for t in model.log["m"]])
            v = torch.tensor([t.item() for t in model.log["v"]])
            b = torch.tensor([t.item() for t in model.log["b"]])

            print("LI-layers params (avg, min, max):")
            print(f"* W: ({W.mean():.2f}, {W.min():.2f}, {W.max():.2f})")
            print(f"* m: ({m.mean():.2f}, {m.min():.2f}, {b.max():.2f})")
            print(f"* v: ({v.mean():.2f}, {v.min():.2f}, {v.max():.2f})")
            print(f"* b: ({b.mean():.2f}, {b.min():.2f}, {b.max():.2f})")


def evaluate(model, data_loader, mode="train"):
    correct = 0
    for _, (images, labels) in enumerate(data_loader):
        images, labels = images.to(device), labels.to(device)
        pred = torch.argmax(model(images), dim=1)
        correct += torch.sum(labels == pred).sum().item()

    print(f"{mode} accuracy: {correct / len(data_loader.dataset):.2%}")


## AlexNet

In [5]:
train_loader, test_loader = load_datas(batch_size=128)

### Baseline

In [9]:
alexnet = models.alexnet(weights="DEFAULT").to(device)
alexnet.eval()

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [28]:
evaluate(alexnet, train_loader, mode="train")

train accuracy: 53.05%


In [48]:
evaluate(alexnet, test_loader, mode="test")

test accuracy: 52.23%


In [13]:
alexnet = models.alexnet(weights="DEFAULT").to(device)
train_model(alexnet, 10, 2, 2, f"ALEXNET")

100%|██████████| 272/272 [00:32<00:00,  8.30it/s]


[ 1/10] loss_train: 1.15E-02 accuracy_train: 37.27%


100%|██████████| 272/272 [00:32<00:00,  8.30it/s]

[ 2/10] loss_train: 1.01E-02 accuracy_train: 42.92%





loss_test: 1.01E-02 accuracy_test: 37.45%


100%|██████████| 272/272 [00:31<00:00,  8.69it/s]


[ 3/10] loss_train: 9.50E-03 accuracy_train: 45.72%


100%|██████████| 272/272 [00:32<00:00,  8.25it/s]

[ 4/10] loss_train: 8.76E-03 accuracy_train: 49.15%





loss_test: 8.76E-03 accuracy_test: 34.92%


100%|██████████| 272/272 [00:32<00:00,  8.47it/s]


[ 5/10] loss_train: 8.18E-03 accuracy_train: 51.58%


100%|██████████| 272/272 [00:32<00:00,  8.38it/s]

[ 6/10] loss_train: 7.78E-03 accuracy_train: 53.63%





loss_test: 7.78E-03 accuracy_test: 33.34%


100%|██████████| 272/272 [00:32<00:00,  8.35it/s]


[ 7/10] loss_train: 7.37E-03 accuracy_train: 55.69%


100%|██████████| 272/272 [00:32<00:00,  8.33it/s]

[ 8/10] loss_train: 7.14E-03 accuracy_train: 57.23%





loss_test: 7.14E-03 accuracy_test: 32.32%


100%|██████████| 272/272 [00:33<00:00,  8.19it/s]


[ 9/10] loss_train: 6.68E-03 accuracy_train: 59.71%


100%|██████████| 272/272 [00:32<00:00,  8.37it/s]

[10/10] loss_train: 6.57E-03 accuracy_train: 60.54%





loss_test: 6.57E-03 accuracy_test: 31.68%


### Alexnet+LI

In [6]:
class AlexnetLI(nn.Module):
    def __init__(self, weights="DEFAULT", li_weights="zeros", freeze=False):
        super(AlexnetLI, self).__init__()

        alexnet = models.alexnet(weights=weights)

        if freeze:
            for param in alexnet.parameters():
                param.require_grad = False

        self.log = {"W": [], "m": [], "v": [], "b": []}

        # Rebuild alexnet features, by adding a LI layer after each convolutions'
        # activation function
        features = list(alexnet.features.children())
        new_features = []

        for i, l in enumerate(features):
            new_features.append(l)
            if isinstance(l, nn.ReLU):
                li = LateralInhibition(features[i - 1].out_channels, weights=li_weights)
                new_features.append(li)
                self.log["W"].append(li.weights)
                self.log["m"].append(li.m)
                self.log["v"].append(li.v)
                self.log["b"].append(li.b)

        self.features = nn.Sequential(*new_features)

        # Copy all the non-convolutional parts of AlexNet
        self.avg_pool = alexnet.avgpool
        self.classifier = alexnet.classifier

    def forward(self, x):
        x = self.features(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x


In [17]:
alexnetLI = AlexnetLI().to(device)
train_model(alexnetLI, 10, 2, 2, f"ALEXNET_LI 256 log")

100%|██████████| 272/272 [00:52<00:00,  5.18it/s]


[ 1/10] loss_train: 1.12E-02 accuracy_train: 38.42%


100%|██████████| 272/272 [00:51<00:00,  5.23it/s]

[ 2/10] loss_train: 1.01E-02 accuracy_train: 43.23%





loss_test: 1.01E-02 accuracy_test: 36.45%
LI layers params (avg, min, max) :
* W: (0.00, -0.02, 0.10)
* m: (-0.09, -0.16, -0.07)
* v: (0.96, 0.73, 1.11)
* b: (-0.10, -0.16, -0.07)


100%|██████████| 272/272 [00:52<00:00,  5.20it/s]


[ 3/10] loss_train: 9.32E-03 accuracy_train: 46.33%


100%|██████████| 272/272 [00:52<00:00,  5.19it/s]

[ 4/10] loss_train: 8.55E-03 accuracy_train: 50.21%





loss_test: 8.55E-03 accuracy_test: 36.04%
LI layers params (avg, min, max) :
* W: (0.00, -0.03, 0.15)
* m: (-0.11, -0.24, -0.09)
* v: (0.97, 0.63, 1.18)
* b: (-0.13, -0.23, -0.09)


100%|██████████| 272/272 [00:52<00:00,  5.14it/s]


[ 5/10] loss_train: 8.08E-03 accuracy_train: 52.36%


100%|██████████| 272/272 [00:51<00:00,  5.23it/s]

[ 6/10] loss_train: 7.56E-03 accuracy_train: 54.88%





loss_test: 7.56E-03 accuracy_test: 34.44%
LI layers params (avg, min, max) :
* W: (0.00, -0.04, 0.20)
* m: (-0.15, -0.28, -0.13)
* v: (0.98, 0.60, 1.25)
* b: (-0.17, -0.27, -0.13)


100%|██████████| 272/272 [00:52<00:00,  5.20it/s]


[ 7/10] loss_train: 7.16E-03 accuracy_train: 57.12%


100%|██████████| 272/272 [00:51<00:00,  5.31it/s]


[ 8/10] loss_train: 6.70E-03 accuracy_train: 59.29%
loss_test: 6.70E-03 accuracy_test: 31.76%
LI layers params (avg, min, max) :
* W: (0.00, -0.05, 0.22)
* m: (-0.16, -0.30, -0.13)
* v: (0.98, 0.57, 1.29)
* b: (-0.19, -0.29, -0.13)


100%|██████████| 272/272 [00:51<00:00,  5.26it/s]


[ 9/10] loss_train: 6.41E-03 accuracy_train: 61.46%


100%|██████████| 272/272 [00:52<00:00,  5.19it/s]

[10/10] loss_train: 6.11E-03 accuracy_train: 62.84%





loss_test: 6.11E-03 accuracy_test: 30.89%
LI layers params (avg, min, max) :
* W: (0.00, -0.05, 0.25)
* m: (-0.18, -0.35, -0.15)
* v: (0.97, 0.52, 1.31)
* b: (-0.21, -0.33, -0.15)


### Alexnet+BatchNorm

In [24]:
class AlexnetBatchNorm(nn.Module):
    def __init__(self, weights="DEFAULT"):
        super(AlexnetBatchNorm, self).__init__()

        alexnet = models.alexnet(weights=weights)

        # Rebuild alexnet features, by adding a BatchNorm layer after each convolutions'
        # activation function
        features = list(alexnet.features.children())
        new_features = []

        for i, l in enumerate(features):
            new_features.append(l)
            if isinstance(l, nn.ReLU):
                new_features.append(nn.BatchNorm2d(features[i-1].out_channels))

        self.features = nn.Sequential(*new_features)

        # Copy all the non-convolutional parts of AlexNet
        self.avg_pool = alexnet.avgpool
        self.classifier = alexnet.classifier

    def forward(self, x):
        x = self.features(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [28]:
alexnetBN = AlexnetBatchNorm().to(device)
train_model(alexnetBN, 10, 2, 2, "ALEXNET_BatchNorm 256")

100%|██████████| 272/272 [00:39<00:00,  6.92it/s]


[ 1/10] loss_train: 1.28E-02 accuracy_train: 32.65%


100%|██████████| 272/272 [00:36<00:00,  7.40it/s]

[ 2/10] loss_train: 1.03E-02 accuracy_train: 42.07%





loss_test: 1.03E-02 accuracy_test: 35.15%


100%|██████████| 272/272 [00:37<00:00,  7.23it/s]


[ 3/10] loss_train: 9.16E-03 accuracy_train: 47.54%


100%|██████████| 272/272 [00:34<00:00,  8.00it/s]

[ 4/10] loss_train: 8.37E-03 accuracy_train: 51.39%





loss_test: 8.37E-03 accuracy_test: 34.49%


100%|██████████| 272/272 [00:35<00:00,  7.73it/s]


[ 5/10] loss_train: 7.65E-03 accuracy_train: 54.78%


100%|██████████| 272/272 [00:34<00:00,  7.79it/s]

[ 6/10] loss_train: 7.15E-03 accuracy_train: 57.26%





loss_test: 7.15E-03 accuracy_test: 34.72%


100%|██████████| 272/272 [00:35<00:00,  7.72it/s]


[ 7/10] loss_train: 6.74E-03 accuracy_train: 59.41%


100%|██████████| 272/272 [00:34<00:00,  7.79it/s]

[ 8/10] loss_train: 6.39E-03 accuracy_train: 61.51%





loss_test: 6.39E-03 accuracy_test: 33.32%


100%|██████████| 272/272 [00:35<00:00,  7.63it/s]


[ 9/10] loss_train: 5.97E-03 accuracy_train: 64.01%


100%|██████████| 272/272 [00:35<00:00,  7.65it/s]

[10/10] loss_train: 5.77E-03 accuracy_train: 65.23%





loss_test: 5.77E-03 accuracy_test: 32.09%


### Alexnet+GroupNorm

In [26]:
class AlexnetGroupNorm(nn.Module):
    def __init__(self, weights="DEFAULT"):
        super(AlexnetGroupNorm, self).__init__()

        alexnet = models.alexnet(weights=weights)

        # Rebuild alexnet features, by adding a LayerNorm layer after each convolutions'
        # activation function
        features = list(alexnet.features.children())
        new_features = []
        
        for i, l in enumerate(features):
            new_features.append(l)
            if isinstance(l, nn.ReLU):
                new_features.append(nn.GroupNorm(1, features[i-1].out_channels))

        self.features = nn.Sequential(*new_features)

        # Copy all the non-convolutional parts of AlexNet
        self.avg_pool = alexnet.avgpool
        self.classifier = alexnet.classifier

    def forward(self, x):
        x = self.features(x)
        x = self.avg_pool(x)
        x = x.view(x.size(0), -1)
        x = self.classifier(x)
        return x

In [29]:
alexnetGN = AlexnetGroupNorm().to(device)
train_model(alexnetGN, 10, 2, 2, "ALEXNET_GroupNorm 256")

100%|██████████| 272/272 [00:40<00:00,  6.70it/s]


[ 1/10] loss_train: 1.29E-02 accuracy_train: 32.49%


100%|██████████| 272/272 [00:40<00:00,  6.75it/s]

[ 2/10] loss_train: 1.04E-02 accuracy_train: 41.79%





loss_test: 1.04E-02 accuracy_test: 37.39%


100%|██████████| 272/272 [00:41<00:00,  6.62it/s]


[ 3/10] loss_train: 9.26E-03 accuracy_train: 46.92%


100%|██████████| 272/272 [00:43<00:00,  6.24it/s]

[ 4/10] loss_train: 8.60E-03 accuracy_train: 50.08%





loss_test: 8.60E-03 accuracy_test: 35.36%


100%|██████████| 272/272 [00:40<00:00,  6.76it/s]


[ 5/10] loss_train: 7.95E-03 accuracy_train: 53.60%


100%|██████████| 272/272 [00:40<00:00,  6.69it/s]

[ 6/10] loss_train: 7.30E-03 accuracy_train: 56.55%





loss_test: 7.30E-03 accuracy_test: 35.59%


100%|██████████| 272/272 [00:40<00:00,  6.76it/s]


[ 7/10] loss_train: 6.90E-03 accuracy_train: 58.89%


100%|██████████| 272/272 [00:39<00:00,  6.83it/s]

[ 8/10] loss_train: 6.42E-03 accuracy_train: 61.42%





loss_test: 6.42E-03 accuracy_test: 34.31%


100%|██████████| 272/272 [00:40<00:00,  6.73it/s]


[ 9/10] loss_train: 6.17E-03 accuracy_train: 62.96%


100%|██████████| 272/272 [00:40<00:00,  6.75it/s]

[10/10] loss_train: 5.73E-03 accuracy_train: 65.26%





loss_test: 5.73E-03 accuracy_test: 32.04%


## Resnet

In [6]:
train_loader, test_loader = load_datas(batch_size=64)

### Baseline

In [7]:
resnet18 = models.resnet18(weights="DEFAULT").to(device)
resnet18.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [23]:
evaluate(resnet18, train_loader, mode="train")

train accuracy: 65.26%


In [22]:
evaluate(resnet18, test_loader, mode="test")

test accuracy: 66.51%


### Resnet+LI

In [20]:
class ResnetLI(nn.Module):
    def __init__(self, weights="DEFAULT", freeze=False):
        super(ResnetLI, self).__init__() 
        
        resnet = models.resnet18(weights=weights)

        if freeze:
            for param in resnet.parameters():
                param.require_grad = False

        self.log = {"W": [], "m": [], "v": [], "b": []}
        
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.li1 = LateralInhibition(self.conv1.out_channels)
        self.maxpool = resnet.maxpool

        self.layer1 = self.convert_layer_blocks(resnet.layer1)
        self.layer2 = self.convert_layer_blocks(resnet.layer2)
        self.layer3 = self.convert_layer_blocks(resnet.layer3)
        self.layer4 = self.convert_layer_blocks(resnet.layer4)
        
        # Copy all the non-convolutional parts of ResNet
        self.avgpool = resnet.avgpool
        self.fc = resnet.fc

        # Add the "plain" li params to the log
        self.log["W"].append(self.li1.weights)
        self.log["m"].append(self.li1.m)
        self.log["v"].append(self.li1.v)
        self.log["b"].append(self.li1.b)
    
    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.li1(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x
    
    def convert_layer_blocks(self, layer: nn.Sequential):
        new_layer = []

        for l in layer:
            if isinstance(l, BasicBlock):
                liblock = LIBlock(l)
                new_layer.append(liblock)
                
                self.log["W"].append(self.liblock.li.weights)
                self.log["m"].append(self.liblock.li.m)
                self.log["v"].append(self.liblock.li.v)
                self.log["b"].append(self.liblock.li.b)
            else:
                new_layer.append(l)

        return nn.Sequential(*new_layer)

In [None]:
resnetLI = ResnetLI().to(device)
train_model(resnetLI, 10, 2, 2, "RESNET_LI")

### Testing

In [None]:
batch = enumerate(train_loader)
idx, (image, label) = next(batch)

image, label = image.to(device), label.to(device)

LI = LateralInhibition().cuda()
output = LI(image)

In [None]:
idx = 13
img = deprocess(image[idx])
display(img)

li_img = deprocess(output[idx])
display(li_img)