# Jupyter Notebook pour S2M2R

### Quelques éléments de rappel

La méthode S2M2R permet d'entraîner efficacement un backbone. Pour cela, elle combine deux ingrédients : la self-supervision par prédiction de rotations et le manifold-mixup.

Dans ce notebook, nous allons tenter de mettre cette méthode à l'épreuve sur le jeu de données CIFAR-FS, plus petit que miniImageNet, ce qui permet d'entraîner plus vite.

Nous allons nous concentrer sur des architectures de type ResNet, connues pour leur bonne capacité de généralisation.

### Importation des données

Tout d'abord, il nous faut charger quelques bibliothèques utiles.

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import random
import pickle
import time
import datetime
from torchvision import datasets, transforms
# Log the library versions used, one per line (numpy first, then torch).
print(np.__version__, torch.__version__, sep="\n")

# Directory where train_model saves the novel-class features (keep the trailing slash:
# file names are appended with plain string concatenation).
save_path = "/home/tesbed/datasets/"

Ensuite, il nous faut définir les « splits », c'est-à-dire les classes qui composent le base dataset, le val dataset et le novel dataset (les classes de base sont toutes celles qui ne sont ni dans le novel ni dans le val dataset).

In [3]:
# CIFAR-FS novel (test) classes: 20 CIFAR-100 class names.
novel_labels = [
    "baby", "bed", "bicycle", "chimpanzee", "fox",
    "leopard", "man", "pickup_truck", "plain", "poppy",
    "rocket", "rose", "snail", "sweet_pepper", "table",
    "telephone", "wardrobe", "whale", "woman", "worm",
]
# CIFAR-FS validation classes: 16 CIFAR-100 class names.
val_labels = [
    "otter", "motorcycle", "television", "lamp",
    "crocodile", "shark", "butterfly", "beaver",
    "beetle", "tractor", "flatfish", "maple_tree",
    "camel", "crab", "sea", "cattle",
]

Nous pouvons à présent charger les données et identifier les trois datasets.

In [24]:
path = "/home/tesbed/datasets"

# The few-shot splits are defined on classes, so both official CIFAR-100 splits are
# downloaded (if needed) and merged into a single pool of images.
data_train = datasets.CIFAR100(path, train=True, download=True)
data_val = datasets.CIFAR100(path, train=False, download=True)

all_data = np.concatenate([data_train.data, data_val.data])
all_labels = np.concatenate([data_train.targets, data_val.targets])

# Class names -> CIFAR-100 indices; every class that is neither novel nor val is a base class.
novel_targets = list(map(data_train.class_to_idx.__getitem__, novel_labels))
val_targets = list(map(data_train.class_to_idx.__getitem__, val_labels))
train_targets = [c for c in np.arange(100) if not (c in novel_targets or c in val_targets)]

Files already downloaded and verified
Files already downloaded and verified


In [27]:
# Mini-batch size used by the three CIFAR-FS DataLoaders built further down in this cell.
batch_size = 60

# Names needed by the in-memory CIFAR dataset class defined just below.
from torchvision.datasets import VisionDataset
from typing import Any, Callable, Optional, Tuple
from PIL import Image
class CIFAR(VisionDataset):
    """In-memory image dataset returning ``(PIL image, label)`` pairs.

    By default it serves the module-level ``all_data`` / ``all_labels`` arrays (the merged
    CIFAR-100 train and test images). ``data`` / ``targets`` can be given explicitly to wrap
    any other images instead, so the class is no longer tied to those globals.

    Args:
        root: forwarded to ``VisionDataset``; no file is read from it here.
        transform: optional callable applied to each PIL image.
        target_transform: optional callable applied to each label.
        data: optional indexable collection of images, each convertible with
            ``Image.fromarray`` (presumably ``(H, W, 3)`` uint8 arrays — TODO confirm for
            other inputs); defaults to ``all_data``.
        targets: optional labels aligned with ``data``; defaults to ``all_labels``.

    Raises:
        ValueError: if ``data`` and ``targets`` do not have the same length.
    """

    def __init__(
            self,
            root : str,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            data: Optional[Any] = None,
            targets: Optional[Any] = None,
    ) -> None:

        super(CIFAR, self).__init__(root, transform=transform, target_transform=target_transform)

        # Fall back to the notebook-level arrays so existing callers keep working.
        self.data = all_data if data is None else data
        self.targets = all_labels if targets is None else targets
        if len(self.data) != len(self.targets):
            raise ValueError("data and targets must have the same length")

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        img, target = self.data[index], self.targets[index]

        # Return a PIL Image, consistent with the other torchvision datasets.
        img = Image.fromarray(img)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        # len() equals shape[0] for numpy arrays and also works for plain sequences.
        return len(self.data)

# Light augmentation, applied to training images only.
train_transforms = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
]

# Tensor conversion + per-channel mean/std normalization, applied to every split.
standard_transforms = [
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
]

transform_train = transforms.Compose([*train_transforms, *standard_transforms])
transform_all = transforms.Compose(standard_transforms)


def _cifar_loader(class_ids, transform, shuffle_batches):
    """DataLoader over the CIFAR images whose label belongs to ``class_ids``."""
    dataset = CIFAR(path, transform=transform)
    indices = np.nonzero(np.isin(all_labels, class_ids))[0]
    return torch.utils.data.DataLoader(
        torch.utils.data.Subset(dataset, indices),
        batch_size=batch_size, shuffle=shuffle_batches, num_workers=4)


train_loader = _cifar_loader(train_targets, transform_train, True)
val_loader = _cifar_loader(val_targets, transform_all, False)
novel_loader = _cifar_loader(novel_targets, transform_all, False)


On peut vérifier la quantité de données dans chaque dataset.

In [28]:
# Approximate number of images per split (number of batches times the batch size), one per line.
print(*(batch_size * len(loader) for loader in (train_loader, val_loader, novel_loader)), sep="\n")

38400
9600
12000


### Définition des architectures

Il est à présent temps de construire un réseau de neurones pour notre backbone. Il faut non seulement bien définir sa structure, mais aussi s'assurer qu'on peut récupérer facilement l'avant-dernière couche. On va ici considérer un resnet18 ou un resnet20, avec un multiplicateur pour le nombre de feature maps.

In [6]:
# Un resnet est obtenu par assemblage de blocks, ici on considère des resnet très simples où on utilise uniquement des basicblocks
class BasicBlock(nn.Module):
    expansion = 1
    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        return out
        
class BasicBlockWRN(nn.Module):
    """Pre-activation Wide-ResNet block: BN-ReLU-conv3x3, BN-ReLU-(dropout)-conv3x3 + shortcut.

    When the channel count changes, the shortcut is a 1x1 convolution applied to the
    *activated* input; otherwise the raw input is added back unchanged.
    """

    def __init__(self, in_planes, out_planes, stride, drop_rate):
        super(BasicBlockWRN, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_planes)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_planes, out_planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.droprate = drop_rate
        self.equalInOut = (in_planes == out_planes)
        if self.equalInOut:
            self.convShortcut = None
        else:
            self.convShortcut = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride,
                                          padding=0, bias=False)

    def forward(self, x):
        # relu1 is in-place but acts on the fresh BatchNorm output, so ``x`` itself is untouched.
        activated = self.relu1(self.bn1(x))
        out = self.relu2(self.bn2(self.conv1(activated)))
        if self.droprate > 0:
            out = F.dropout(out, p=self.droprate, training=self.training)
        out = self.conv2(out)
        residual = x if self.equalInOut else self.convShortcut(activated)
        return torch.add(residual, out)

class NetworkBlock(nn.Module):
    """A group of ``nb_layers`` WRN blocks chained in an ``nn.Sequential`` (``self.layer``).

    Only the first block changes the width (``in_planes -> out_planes``) and uses
    ``stride``; the following ones map ``out_planes -> out_planes`` with stride 1.
    """

    def __init__(self, nb_layers, in_planes, out_planes, block, stride, drop_rate):
        super(NetworkBlock, self).__init__()
        self.layer = self._make_layer(block, in_planes, out_planes, nb_layers, stride, drop_rate)

    def _make_layer(self, block, in_planes, out_planes, nb_layers, stride, drop_rate):
        """Build the blocks; ``nb_layers`` may be a float and is truncated with ``int``."""
        layers = []
        for i in range(int(nb_layers)):
            first = i == 0
            # Conditional expressions instead of ``cond and a or b``, which silently falls back
            # to ``b`` whenever ``a`` is falsy.
            layers.append(block(in_planes if first else out_planes,
                                out_planes,
                                stride if first else 1,
                                drop_rate))
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.layer(x)

Pour pouvoir implémenter manifold-mixup, il faut modifier le calcul dans le réseau. On va commencer par introduire les fonctions d'interpolation :

In [7]:
def mixup_data(x, y, alpha=1.0, lam = None, use_cuda=True):
    """Mix a batch with a random permutation of itself (mixup / manifold mixup).

    Args:
        x: batch tensor; the first dimension indexes samples.
        y: labels aligned with the first dimension of ``x``.
        alpha: Beta(alpha, alpha) parameter used to sample ``lam`` when it is not given;
            ``alpha <= 0`` disables mixing (``lam = 1``).
        lam: fixed mixing coefficient; ``None`` means "sample it".
        use_cuda: kept only for backward compatibility. The permutation is now always
            moved to ``x``'s device, so the function also works on CPU.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a = y`` and ``y_b = y[perm]``.
    """
    if lam is None:
        lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = x.size(0)
    # Draw the permutation with the CPU generator (as before), then move it next to ``x``.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mix_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

À présent, on peut définir notre réseau :

In [8]:
# ResNet backbone with a classification head and a 4-way rotation head.
class ResNet(nn.Module):
    """ResNet backbone supporting manifold mixup and exposing its penultimate features.

    Args:
        block: residual block class. It must have an ``expansion`` attribute and be
            constructible as ``block(in_planes, planes, stride)``.
        num_blocks: number of blocks in each stage. Stage ``i`` uses
            ``feature_maps * 2**i`` planes; the first stage has stride 1 and every later
            stage stride 2. Any number of stages is accepted (three for ResNet20, four for
            ResNet18); previously exactly four were required.
        feature_maps: number of planes of the stem and of the first stage.
        num_classes: number of outputs of the classification head.
    """

    def __init__(self, block, num_blocks, feature_maps, num_classes=100):
        super(ResNet, self).__init__()
        self.in_planes = feature_maps
        self.length = len(num_blocks)
        self.conv1 = nn.Conv2d(3, feature_maps, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(feature_maps)
        layers = []
        for i, blocks_in_stage in enumerate(num_blocks):
            stage_stride = 1 if i == 0 else 2
            layers.append(self._make_layer(block, feature_maps * 2 ** i, blocks_in_stage, stride=stage_stride))
        self.layers = nn.Sequential(*layers)
        # _make_layer leaves self.in_planes at the output width of the last stage
        # (8 * feature_maps * expansion for four stages, as before).
        self.linear = nn.Linear(self.in_planes, num_classes, bias=False)
        self.rotationLinear = nn.Linear(self.in_planes, 4, bias=False)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Build one stage: the first block uses ``stride``, the others stride 1."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for i in range(len(strides)):
            stride = strides[i]
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x, mixup_target = None, type_features = "post", lam = 0.4):
        """Run the network, optionally applying manifold mixup at a random depth.

        Args:
            x: input image batch.
            mixup_target: labels of ``x``. When given, mixup is applied exactly once,
                either to the input (layer 0) or to the output of a randomly chosen stage.
            type_features: which features to return: ``"post"`` (pooled vector, default),
                ``"all"`` (last feature map after ReLU, before pooling) or ``"pre"`` (last
                feature map before the final ReLU). Accepted so that this model can be used
                with ``features_of_dataset`` like ``WideResNet``.
            lam: mixing coefficient forwarded to ``mixup_data`` (0.4, as previously
                hard-coded; ``None`` lets ``mixup_data`` sample it).

        Returns:
            ``(logits, rotation_logits, features, target_a, target_b, lam)``; the last three
            are ``None`` when ``mixup_target`` is ``None``. Both heads always use the pooled
            features.

        Raises:
            ValueError: if ``type_features`` is not one of the accepted values.
        """
        if type_features not in ("pre", "all", "post"):
            raise ValueError("type_features must be 'pre', 'all' or 'post', got " + repr(type_features))
        if mixup_target is not None:
            # randint is inclusive: 0 mixes the input, i mixes the output of stage i.
            mixup_layer = random.randint(0, len(self.layers))
        else:
            mixup_layer, target_a, target_b, lam = -1, None, None, None
        out = x
        if mixup_layer == 0:
            out, target_a, target_b, lam = mixup_data(out, mixup_target, lam = lam)
        out = F.relu(self.bn1(self.conv1(out)))
        pre_activation = out
        for i in range(len(self.layers)):
            out = self.layers[i](out)
            if mixup_layer == i + 1:
                out, target_a, target_b, lam = mixup_data(out, mixup_target, lam = lam)
            pre_activation = out
            out = F.relu(out)
        pooled = F.avg_pool2d(out, out.shape[2]).view(out.size(0), -1)
        if type_features == "post":
            features = pooled
        elif type_features == "all":
            features = out
        else:
            features = pre_activation
        return self.linear(pooled), self.rotationLinear(pooled), features, target_a, target_b, lam
    
class WideResNet(nn.Module):
    """Wide ResNet backbone with a classification head and a 4-way rotation head.

    Args:
        feature_maps: width of the stem convolution; the three groups have
            ``widen_factor * feature_maps``, ``2 * widen_factor * feature_maps`` and
            ``4 * widen_factor * feature_maps`` channels.
        depth: network depth; each group holds ``int((depth - 4) / 6)`` blocks.
        widen_factor: width multiplier of the groups.
        num_classes: number of outputs of the classification head.
        drop_rate: dropout probability inside each block.
    """

    def __init__(self, feature_maps, depth = 28, widen_factor = 10, num_classes = 100, drop_rate = 0.5):
        super(WideResNet, self).__init__()
        nChannels = [feature_maps, feature_maps*widen_factor, 2 * feature_maps*widen_factor, 4 * feature_maps*widen_factor]
        n = (depth - 4) / 6
        self.conv1 = nn.Conv2d(3, nChannels[0], kernel_size=3, stride=1, padding=1, bias=False)

        self.blocks = torch.nn.ModuleList()
        self.blocks.append(NetworkBlock(n, nChannels[0], nChannels[1], BasicBlockWRN, 1, drop_rate))
        self.blocks.append(NetworkBlock(n, nChannels[1], nChannels[2], BasicBlockWRN, 2, drop_rate))
        self.blocks.append(NetworkBlock(n, nChannels[2], nChannels[3], BasicBlockWRN, 2, drop_rate))
        self.bn = nn.BatchNorm2d(nChannels[3])
        self.linear = nn.Linear(nChannels[3], int(num_classes))
        self.rotationLinear = nn.Linear(nChannels[3], 4, bias=False)

    def forward(self, x, type_features = "post", mixup_target= None, lam = None):
        """Run the network, optionally applying manifold mixup at a random depth.

        Args:
            x: input image batch.
            type_features: which features to return: ``"post"`` (pooled vector, default),
                ``"all"`` (final feature map after BN + ReLU, before pooling) or ``"pre"``
                (output of the last group, before the final BN + ReLU).
            mixup_target: labels of ``x``. When given, mixup is applied exactly once, to the
                input (layer 0) or to the output of a randomly chosen group.
            lam: mixing coefficient forwarded to ``mixup_data`` (``None`` lets it sample).

        Returns:
            ``(logits, rotation_logits, features, target_a, target_b, lam)``; the last three
            are ``None`` when ``mixup_target`` is ``None``. Both heads are always applied to
            the pooled vector. Previously they were applied to ``features``, which crashed
            for the 4-D ``"pre"`` / ``"all"`` maps.

        Raises:
            ValueError: if ``type_features`` is not one of the accepted values.
        """
        if type_features not in ("pre", "all", "post"):
            raise ValueError("type_features must be 'pre', 'all' or 'post', got " + repr(type_features))
        if mixup_target is not None:
            # randint is inclusive: 0 mixes the input, i mixes the output of group i.
            mixup_layer = random.randint(0, len(self.blocks))
        else:
            mixup_layer, target_a, target_b, lam = -1, None, None, None

        out = x
        if mixup_layer == 0:
            out, target_a, target_b, lam = mixup_data(out, mixup_target, lam=lam)

        out = self.conv1(out)
        for i, group in enumerate(self.blocks):
            out = group(out)
            if mixup_layer == i + 1:
                out, target_a, target_b, lam = mixup_data(out, mixup_target, lam=lam)
        pre_activation = out
        out = torch.relu(self.bn(out))
        pooled = F.avg_pool2d(out, out.size()[2:]).view(out.size(0), -1)
        if type_features == "pre":
            features = pre_activation
        elif type_features == "all":
            features = out
        else:
            features = pooled
        return self.linear(pooled), self.rotationLinear(pooled), features, target_a, target_b, lam

On définit à présent un resnet particulier en choisissant le nombre de blocks et leur composition.

In [9]:
def ResNet20(feature_maps):
    """ResNet made of three stages of three BasicBlocks each."""
    stages = [3] * 3
    return ResNet(BasicBlock, stages, feature_maps)

def ResNet18(feature_maps):
    """ResNet made of four stages of two BasicBlocks each."""
    stages = [2] * 4
    return ResNet(BasicBlock, stages, feature_maps)

In [10]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("We are running on " + str(device))

def generate_model(feature_maps, model_type):
    """Build ``model_type(feature_maps)`` and move it to the global ``device``."""
    return model_type(feature_maps).to(device)

We are running on cuda


### Entraînement du modèle

Il est à présent temps d'entraîner notre modèle. Attention, cela peut être très long !


On peut à présent écrire la fonction d'entraînement pour une époque :

In [11]:
# Arguments: the model, the device, the training DataLoader, the optimizer, then three
# flags selecting which losses are used: plain classification, manifold mixup (mm) and
# rotation self-supervision.
def train(model, device, train_loader, optimizer, classification = True, mm = False, rotations = True):
    """Train ``model`` for one epoch and return the mean loss per batch.

    For every batch, each enabled loss is back-propagated on its own so that gradients
    accumulate, then a single optimizer step is taken:

    * ``mm``: cross-entropy with manifold mixup (``model(data, mixup_target=target)``);
    * ``rotations``: the batch is split into four quarters, quarter ``r`` is rotated by
      ``r * 90`` degrees and the rotation head must predict ``r``;
    * ``classification``: plain cross-entropy on the unmodified batch.

    Returns:
        The sum of the enabled losses averaged over batches, or ``0.0`` when
        ``train_loader`` yields no batch (this used to raise ``UnboundLocalError``).
    """
    model.train()  # switch to training mode, which matters notably for BatchNorm
    criterion = torch.nn.CrossEntropyLoss()
    total_loss = 0.0
    num_batches = 0
    for data, target in train_loader:
        num_batches += 1
        data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        if mm:
            output, _, _, target_a, target_b, lam = model(data, mixup_target = target)
            loss = mix_criterion(criterion, output, target_a, target_b, lam)
            loss.backward()
            total_loss += loss.item()

        if rotations:
            # Only each quarter is rotated (the previous code rotated the whole batch three
            # times and then sliced). The last quarter absorbs the remainder, and takes the
            # whole batch when it has fewer than 4 images. Square images are assumed so that
            # the rotated quarters can be concatenated.
            split = data.shape[0] // 4
            quarters = [data[:split], data[split:2 * split], data[2 * split:3 * split], data[3 * split:]]
            data_rot = torch.cat([torch.rot90(q, k = r, dims = [2, 3]) for r, q in enumerate(quarters)], dim = 0)
            target_rotation = torch.cat([
                torch.full((q.shape[0],), r, dtype = torch.long, device = device)
                for r, q in enumerate(quarters)
            ])
            _, output_rotation, _, _, _, _ = model(data_rot)
            loss = criterion(output_rotation, target_rotation)
            loss.backward()
            total_loss += loss.item()

        if classification:
            output, _, _, _, _, _ = model(data)
            loss = criterion(output, target)
            loss.backward()
            total_loss += loss.item()

        optimizer.step()

    return total_loss / num_batches if num_batches > 0 else 0.0


Puis la méthodologie complète pour entraîner un réseau.

In [12]:
start_tick = time.time()  # reference time for the elapsed-time display, shared by every train_model call
def train_model(model, epochs, lr, validate, val_loader, novel_loader, type_features = "post", classification = False, mm = True, rotations = True, validate_each_epoch = True):
    """Train ``model`` for ``epochs`` epochs, then save and report its novel-class features.

    Args:
        model: network to train (modified in place).
        epochs: number of epochs.
        lr: SGD learning rate (momentum 0.9, weight decay 5e-4); a non-positive value
            selects Adam with its default settings instead.
        validate: function ``validate(model, loader, type_features=...)`` returning
            ``(score, features)``.
        val_loader, novel_loader: loaders used for validation and final evaluation.
        type_features: feature type forwarded to ``validate``.
        classification, mm, rotations: loss flags forwarded to ``train``.
        validate_each_epoch: if True, the model is scored on ``val_loader`` after every
            epoch and, each time the validation score improves, on ``novel_loader``; the
            reported novel score and saved features are those of the best validation
            epoch. If False, or if no novel evaluation happened during training, the novel
            set is evaluated once after the last epoch.

    Uses the notebook globals ``train_loader``, ``device``, ``save_path``, ``train`` and
    ``start_tick``. The novel features are saved to
    ``save_path + "<YYYY-mm-dd-HH-MM>_novel.pt"``.
    """
    if lr > 0:
        optimizer = torch.optim.SGD(model.parameters(), lr = lr, momentum = 0.9, weight_decay = 5e-4)
    else:
        optimizer = torch.optim.Adam(model.parameters())
    best_score_val = 0
    n, features = None, None
    for epoch in range(epochs):
        loss = train(model, device, train_loader, optimizer, classification = classification, mm = mm, rotations = rotations)
        print("\rEpoch: {:4d}, loss: {:.5f}".format(epoch, loss), end="")
        if validate_each_epoch:
            m, _ = validate(model, val_loader, type_features = type_features)
            print(" val score: {:.5f}".format(m), end="")
            if m > best_score_val:
                # Keep the novel features of the best validation epoch (previously the best
                # score was never updated, so the novel set was re-evaluated every epoch).
                best_score_val = m
                n, features = validate(model, novel_loader, type_features = type_features)
                print(" novel score: {:.5f}".format(n), end = "")
        spent_time = int(time.time() - start_tick)
        print(" {:d}h{:02d}m{:02d}".format(spent_time // 3600, (spent_time % 3600) // 60, spent_time % 60), end ="")
    print()
    if features is None:
        # No novel evaluation happened during training (validation disabled, zero epochs,
        # or the validation score never exceeded 0): evaluate the final model.
        n, features = validate(model, novel_loader, type_features = type_features)
    # %m is the month; the previous format used %M (minutes) in the date part.
    torch.save(features, save_path + datetime.datetime.now().strftime("%Y-%m-%d-%H-%M") + "_novel.pt")
    print("\nFinal score is: " + str(n))

### Validation et choix des hyperparamètres

En principe, il est possible de choisir les hyperparamètres du modèle en utilisant le val_dataset. Pour ce faire, il faut définir une mesure de qualité. Comme on a déjà programmé simpleshot, on va l'utiliser pour obtenir un score de généralisation.

Dans un premier temps, on va récupérer les features de toutes les données de validation, passées à travers le backbone.

In [13]:
elements_per_class = 600 # number of samples per class in CIFAR-FS (the miniImageNet cells below rely on the same value)

def features_of_dataset(model, loader, type_features):
    """Extract features for every sample of ``loader`` and group them by class.

    Args:
        model: backbone whose forward accepts ``type_features=`` and returns the features
            as its third output.
        loader: iterable of ``(data, target)`` batches.
        type_features: feature type forwarded to the model.

    Returns:
        A CPU tensor of shape ``(num_classes, elements_per_class, feature_dim)``, with
        classes in order of first appearance in ``loader``. Non-vector features (e.g. the
        4-D ``"pre"`` / ``"all"`` maps) are flattened per sample. Previously they could not
        be stacked and the function crashed.
        NOTE(review): flattened feature maps can be very large (WRN on 80x80 inputs gives
        640 x 20 x 20 values per image).

    Raises:
        ValueError: if a class does not have exactly ``elements_per_class`` samples.

    The model is left in eval mode.
    """
    model.eval()
    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for data, target in loader:
            _, _, features, _, _, _ = model(data.to(device), type_features = type_features)
            feature_batches.append(features.reshape(features.size(0), -1).cpu())
            label_batches.append(torch.as_tensor(target).reshape(-1))
    all_features = torch.cat(feature_batches, dim = 0)
    all_labels = torch.cat(label_batches, dim = 0)

    # dict.fromkeys keeps classes in order of first appearance, as the previous loop did.
    per_class = []
    for label in dict.fromkeys(all_labels.tolist()):
        class_features = all_features[all_labels == label]
        if class_features.shape[0] != elements_per_class:
            raise ValueError("class {} has {} samples, expected {}".format(label, class_features.shape[0], elements_per_class))
        per_class.append(class_features.unsqueeze(0))
    return torch.cat(per_class, dim = 0)


À présent, on peut calculer un score à l'aide de l'algorithme simpleshot :

In [14]:
shuffle_indices = np.arange(0, elements_per_class)  # identity ordering of the samples of one class; updated by shuffle()
def shuffle(data):
    """Shuffle, in place, the samples of each class of ``data`` independently.

    ``data`` is indexed as ``data[class, sample, feature]``. A fresh permutation of the
    sample axis is drawn from ``np.random`` for every class, so any number of samples per
    class is supported. Previously the permutation length was tied to
    ``elements_per_class``. The last permutation drawn is stored in the module-level
    ``shuffle_indices``.
    """
    global shuffle_indices
    for i in range(data.shape[0]):
        shuffle_indices = np.random.permutation(data.shape[1])
        data[i] = data[i, shuffle_indices]
        
# A run (few-shot episode) can now be generated simply.
def generate_run(w, k, q, data):
    """Sample a ``w``-way episode with ``k`` shots and ``q`` queries per class.

    ``data`` (indexed ``[class, sample, feature]``) is first shuffled in place within each
    class; ``w`` distinct classes are then drawn at random and the first ``k + q`` samples
    of each are returned.
    """
    shuffle(data)
    class_order = np.random.permutation(np.arange(data.shape[0]))
    chosen_classes = class_order[:w]
    samples_per_class = k + q
    return data[chosen_classes, :samples_per_class, :]

def stats(precisions):
    """Return ``(mean, half_width)``, where ``half_width = std * 1.96 / sqrt(n)`` is the
    half-width of a normal-approximation 95% confidence interval on the mean."""
    sample_count = len(precisions)
    half_width = np.std(precisions) * 1.96 / math.sqrt(sample_count)
    return np.mean(precisions), half_width

def center(dataset, data):
    """Subtract from ``dataset`` the mean feature vector of all samples in ``data``."""
    feature_dim = data.shape[-1]
    global_mean = data.view(-1, feature_dim).mean(dim=0)
    return dataset - global_mean[None, None, :]

def normalize(dataset):
    """Scale every feature vector (last axis of a 3-D tensor) to unit L2 norm."""
    norms = dataset.norm(p=2, dim=2, keepdim=True)
    return dataset / norms

def ncm(dataset, k):
    """Nearest-class-mean accuracy of one episode.

    ``dataset`` is indexed ``[class, sample, feature]``: the first ``k`` samples of each
    class are shots (averaged into a centroid), the rest are queries assigned to the class
    of the closest centroid (Euclidean distance). Returns the mean over classes of the
    per-class query accuracy.
    """
    centroids = dataset[:, :k, :].mean(dim=1).unsqueeze(1)  # (w, 1, d)
    per_class_accuracy = []
    for true_class, queries in enumerate(dataset[:, k:, :]):
        distances = torch.norm(queries - centroids, dim = 2, p = 2)  # (w, q)
        predicted = distances.argmin(dim = 0)
        per_class_accuracy.append((predicted == true_class).float().mean())
    return np.mean(per_class_accuracy)

def simpleshot(dataset, data, k, no_feature_transforms = False):
    """SimpleShot accuracy of one episode.

    Unless ``no_feature_transforms`` is set, features are first centered on the mean of
    ``data`` and L2-normalized, then classified with the nearest class mean.
    """
    if no_feature_transforms:
        return ncm(dataset, k)
    return ncm(normalize(center(dataset, data)), k)

def perfs(w, k, q, runs, data, no_feature_transforms = False):
    """Mean SimpleShot accuracy over ``runs`` random ``w``-way ``k``-shot episodes with
    ``q`` queries per class (the confidence interval computed by ``stats`` is discarded)."""
    accuracies = [
        simpleshot(generate_run(w, k, q, data), data, k, no_feature_transforms = no_feature_transforms)
        for _ in range(runs)
    ]
    mean_accuracy, _ = stats(accuracies)
    return mean_accuracy

On peut assembler tout ça dans une fonction de validation très simple :

In [15]:
def validate(model, loader, type_features = "post", runs = 1000):
    """Score ``model`` on ``loader`` with ``runs`` 5-way 5-shot episodes (15 queries per class).

    Returns ``(mean accuracy, per-class features)``.
    """
    class_features = features_of_dataset(model, loader, type_features)
    score = perfs(5, 5, 15, runs, class_features)
    return score, class_features

### Expériences

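Comme indiqué précédemment, l'entraînement d'un modèle peut être long (plusieurs heures) ; il vaut donc mieux exécuter les cellules qui suivent avec précaution.
Dans un premier temps, on va évaluer les performances sans rotations. Attention : les cellules ci-dessous chargent miniImageNet et appellent `train_model` avec ses valeurs par défaut (`mm = True`, `rotations = True`) ; il faut passer `rotations = False` pour réellement entraîner sans rotations.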

In [None]:
root = "/home/tesbed/datasets/miniimagenet/"

def load_datasets(root):
    """Load the miniImageNet pickle caches of the train, validation and test splits.

    Args:
        root: directory containing ``mini-imagenet-cache-{train,validation,test}.pkl``
            (with or without a trailing slash).

    Returns:
        dict mapping each split name to ``[images, targets]``, where ``images`` is the
        pickle's ``'image_data'`` array and ``targets`` an int64 tensor. Class indices are
        assigned consecutively across splits (train, then validation, then test) in the
        order of each pickle's ``'class_dict'``, so the three splits use disjoint labels.

    NOTE: uses ``pickle.load``; only point ``root`` at trusted files.
    """
    import os

    splits = {}
    class_index = 0
    for subset in ["train", "validation", "test"]:
        file_name = os.path.join(root, "mini-imagenet-cache-" + subset + ".pkl")
        # ``with`` closes the file (the previous code leaked the handle).
        with open(file_name, "rb") as f:
            dataset = pickle.load(f)
        data = dataset['image_data']
        target = torch.zeros(data.shape[0], dtype=int)
        for sample_indices in dataset['class_dict'].values():
            for elt in sample_indices:
                target[elt] = class_index
            class_index += 1
        splits[subset] = [data, target]
    return splits

class MiniImageNet(VisionDataset):
    """One split of miniImageNet, read through ``load_datasets``.

    Args:
        root: directory containing the ``mini-imagenet-cache-*.pkl`` files.
        subset: ``"train"``, ``"validation"`` or ``"test"``.
        transform: optional callable applied to each PIL image.
        target_transform: optional callable applied to each label.

    NOTE(review): every instance calls ``load_datasets(root)``, which reads all three pickle
    files even though only ``subset`` is kept.
    """

    def __init__(
            self,
            root : str,
            subset = "train",
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None
    ) -> None:
        super(MiniImageNet, self).__init__(root, transform=transform, target_transform=target_transform)
        # load_datasets maps every split name to [images, targets].
        self.data, self.targets = load_datasets(root)[subset]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return ``(image, target)`` for ``index``; the image is a PIL Image (as in the
        torchvision datasets) unless ``transform`` converts it."""
        image = Image.fromarray(self.data[index])
        label = self.targets[index]
        if self.transform is not None:
            image = self.transform(image)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return image, label

    def __len__(self) -> int:
        return self.data.shape[0]

In [None]:
# From here on the notebook switches to miniImageNet; batch_size and the loaders are rebound.
from torchvision import datasets, transforms
# Mini-batch size for the miniImageNet DataLoaders built below.
batch_size = 64

from PIL import ImageEnhance

# Maps the enhancement names accepted by ImageJitter to the corresponding PIL enhancer classes.
transformtypedict=dict(Brightness=ImageEnhance.Brightness, Contrast=ImageEnhance.Contrast, Sharpness=ImageEnhance.Sharpness, Color=ImageEnhance.Color)

class ImageJitter:
    """Random PIL enhancement jitter.

    ``transformdict`` maps keys of ``transformtypedict`` to a strength ``alpha``. Each call
    draws one uniform value ``u`` in [0, 1) per enhancement and applies the enhancer with
    factor ``alpha * (2u - 1) + 1``, i.e. within ``[1 - alpha, 1 + alpha)``. Enhancements are
    applied in the order of ``transformdict``, and the image is converted to RGB after each one.
    """

    def __init__(self, transformdict):
        self.transforms = [(transformtypedict[name], transformdict[name]) for name in transformdict]

    def __call__(self, img):
        draws = torch.rand(len(self.transforms))
        for (enhancer, alpha), u in zip(self.transforms, draws):
            factor = alpha * (u * 2.0 - 1.0) + 1
            img = enhancer(img).enhance(factor).convert('RGB')
        return img

# Training-time augmentation for miniImageNet (random 80x80 crops, color jitter, flips).
train_transforms = [
    transforms.RandomResizedCrop(80),
    ImageJitter(dict(Brightness=0.4, Contrast=0.4, Color=0.4)),
    transforms.RandomHorizontalFlip(),
]

# Deterministic resize + 80x80 center crop for validation / novel images.
test_transforms = [
    transforms.Resize([int(80*1.15), int(80*1.15)]),
    transforms.CenterCrop(80),
]

# Tensor conversion + per-channel mean/std normalization, used for every split.
standard_transforms = [
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
]

transform_train = transforms.Compose([*train_transforms, *standard_transforms])
transform_all = transforms.Compose([*test_transforms, *standard_transforms])


def _mini_loader(subset, transform, shuffle_batches):
    """DataLoader over one miniImageNet split (``batch_size`` images per batch, 4 workers)."""
    return torch.utils.data.DataLoader(
        MiniImageNet(root, subset=subset, transform=transform),
        batch_size=batch_size, shuffle=shuffle_batches, num_workers=4)


train_loader = _mini_loader("train", transform_train, True)
val_loader = _mini_loader("validation", transform_all, False)
novel_loader = _mini_loader("test", transform_all, False)

In [None]:
feature_maps = 16
epochs = 100

# For each kind of evaluation features, train a fresh WRN-28-10 (WideResNet defaults)
# whose stem width is ``feature_maps``. Previously 16 was hard-coded and
# ``feature_maps`` was unused. Training uses a step-wise learning-rate schedule, and
# only the last stage validates after every epoch.
# NOTE(review): "pre"/"all" evaluate full 640 x 20 x 20 feature maps per 80x80 image,
# which is memory-hungry.
for type_features in ["post", "all", "pre"]:
    print(type_features)
    model = WideResNet(feature_maps).to(device)
    for lr, valid in [(0.1, False), (0.01, False), (0.001, False), (0.0001, True)]:
        train_model(model, epochs, lr, validate, val_loader, novel_loader, type_features = type_features, validate_each_epoch = valid)