# MobileNet v2 with CIFAR10

## Librerías

In [1]:
import os
import random

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm

from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torchinfo import summary

import torchvision
from torchvision import models
import torchvision.transforms as transforms

import torchattacks

from utils.evaluation import NormalizationLayer, get_topk_accuracy
from utils.evaluation import plot_adversarial, get_same_predictions, get_different_predictions
from utils.mobilenetv2 import build_mobilenet_v2
from utils.training import train

import warnings
# Install a global "ignore" filter: every warning category is silenced for the rest of the session.
# NOTE(review): this also hides deprecation notices from torch/torchvision; consider
# narrowing the filter to the specific categories or modules that are noisy.
warnings.filterwarnings('ignore')

In [2]:
# Reproducibility: every RNG used by the notebook is seeded from a single constant.
SEED = 42
random.seed(SEED)        # Python standard-library RNG
np.random.seed(SEED)     # NumPy legacy global RNG
torch.manual_seed(SEED)  # PyTorch default generator
# Ask cuDNN for deterministic kernels and keep its autotuner off, so the same
# convolution algorithms are selected on every run.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

In [3]:
# Use the first CUDA GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
device

device(type='cuda', index=0)

## Modelo
Usaré una versión modificada de MobileNet v2 para trabajar con CIFAR10; ver [PyTorch models trained on CIFAR-10 dataset](https://github.com/huyvnphan/PyTorch_CIFAR10). De hecho, ya existe una versión preentrenada, pero la versión de PyTorch que usa no es compatible con TorchAttacks, de modo que lo más sencillo es entrenarla desde cero. La salida contiene puntuaciones no normalizadas; para obtener probabilidades hay que aplicar un softmax a la salida.

Por la forma en que funciona [Adversarial-Attacks-PyTorch](https://github.com/Harry24k/adversarial-attacks-pytorch), las imágenes de entrada que se le pasan deben estar en el rango [0,1], pero los modelos preentrenados de PyTorch esperan imágenes normalizadas, las cuales no están en el rango [0,1]. La forma de resolver esto es añadir una capa de normalización al inicio del modelo. Ver [Demo - White Box Attack (Imagenet)](https://nbviewer.jupyter.org/github/Harry24k/adversarial-attacks-pytorch/blob/master/demos/White%20Box%20Attack%20%28ImageNet%29.ipynb) para un ejemplo con los modelos entrenados en ImageNet.

Lo único que cambia es que las medias y las desviaciones estándar (std) son diferentes; ver [How to use models](https://github.com/huyvnphan/PyTorch_CIFAR10#how-to-use-pretrained-models).

In [4]:
# The attack library needs inputs in [0, 1], so per-channel normalization is done
# by the first layer of the model instead of by the data pipeline.
input_norm = NormalizationLayer(
    mean=[0.4914, 0.4822, 0.4465],
    std=[0.2471, 0.2435, 0.2616],
)
# No pre-trained weights are loaded; the network is trained later in this notebook.
backbone = build_mobilenet_v2(pretrained=False)
# Kept positional (indices 0 and 1) so the state_dict key names do not change.
mobilenet_v2 = nn.Sequential(input_norm, backbone)

In [5]:
# Show torchinfo's layer-by-layer table of parameter counts for the wrapped model.
summary(mobilenet_v2)

Layer (type:depth-idx)                             Param #
Sequential                                         --
├─NormalizationLayer: 1-1                          --
├─MobileNetV2: 1-2                                 --
│    └─Sequential: 2-1                             --
│    │    └─ConvBNReLU: 3-1                        928
│    │    └─InvertedResidual: 3-2                  896
│    │    └─InvertedResidual: 3-3                  5,136
│    │    └─InvertedResidual: 3-4                  8,832
│    │    └─InvertedResidual: 3-5                  10,000
│    │    └─InvertedResidual: 3-6                  14,848
│    │    └─InvertedResidual: 3-7                  14,848
│    │    └─InvertedResidual: 3-8                  21,056
│    │    └─InvertedResidual: 3-9                  54,272
│    │    └─InvertedResidual: 3-10                 54,272
│    │    └─InvertedResidual: 3-11                 54,272
│    │    └─InvertedResidual: 3-12                 66,624
│    │    └─InvertedResidual: 3-13   

In [6]:
# Move the model to the selected device (the GPU, when one is available).
mobilenet_v2 = mobilenet_v2.to(device)

In [7]:
# Smoke test: feed one all-zero 3x32x32 image through the model and display the raw output.
# The tensor is allocated directly on the target device.
x = torch.zeros(1, 3, 32, 32, device=device)
mobilenet_v2(x)

tensor([[ 0.4590,  0.1412,  0.0520,  0.2391, -0.0676,  0.0619, -0.0529, -0.1412,
         -0.0436,  0.0306]], device='cuda:0', grad_fn=<AddmmBackward>)

## Dataset & dataloader

In [8]:
# Images are only converted to tensors here; normalization happens inside the model.
transform = transforms.Compose([transforms.ToTensor()])
batch_size = 128

# CIFAR-10 train and test splits, downloaded into ./data when not already present.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform,
)
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform,
)

# Only the training loader shuffles; the test loader keeps a fixed order.
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4)

print(f'Trainset: {len(trainset)}')
print(f'Testset: {len(testset)}')

Files already downloaded and verified
Files already downloaded and verified
Trainset: 50000
Testset: 10000


## Entrenamiento

In [9]:
# Train for 20 epochs with lr=1e-3. The test loader is passed as the third argument
# (presumably used for per-epoch evaluation; confirm in utils.training.train).
# The two returned values are kept as loss and accuracy histories (structure not
# visible from this notebook; verify against utils.training).
loss_hist, acc_hist = train(
    mobilenet_v2,
    trainloader,
    testloader,
    lr=1e-3,
    epochs=20,
)

  5%|▌         | 1/20 [00:43<13:50, 43.71s/it]

E00 loss=[111.59,117.64] acc=[59.66,57.26]


 10%|█         | 2/20 [01:27<13:11, 43.99s/it]

E01 loss=[ 82.77, 95.32] acc=[70.84,66.80]


 15%|█▌        | 3/20 [02:11<12:28, 44.03s/it]

E02 loss=[ 68.56, 82.31] acc=[76.25,71.63]


 20%|██        | 4/20 [02:56<11:45, 44.10s/it]

E03 loss=[ 52.32, 71.46] acc=[81.82,75.50]


 25%|██▌       | 5/20 [03:40<11:00, 44.01s/it]

E04 loss=[ 43.66, 69.68] acc=[85.15,76.16]


 30%|███       | 6/20 [04:23<10:15, 43.95s/it]

E05 loss=[ 36.78, 65.88] acc=[87.69,77.96]


 35%|███▌      | 7/20 [05:07<09:29, 43.83s/it]

E06 loss=[ 30.14, 65.25] acc=[89.79,78.58]


 40%|████      | 8/20 [05:51<08:45, 43.76s/it]

E07 loss=[ 25.31, 62.44] acc=[91.59,79.64]


 45%|████▌     | 9/20 [06:34<08:00, 43.70s/it]

E08 loss=[ 21.91, 62.78] acc=[92.69,80.22]


 50%|█████     | 10/20 [07:18<07:17, 43.75s/it]

E09 loss=[ 19.78, 66.82] acc=[93.08,80.16]


 55%|█████▌    | 11/20 [08:02<06:35, 43.94s/it]

E10 loss=[ 15.28, 65.24] acc=[94.90,80.69]


 60%|██████    | 12/20 [08:46<05:51, 43.92s/it]

E11 loss=[ 12.54, 69.05] acc=[95.74,80.70]


 65%|██████▌   | 13/20 [09:31<05:08, 44.03s/it]

E12 loss=[ 14.78, 74.32] acc=[94.69,79.95]


 70%|███████   | 14/20 [10:16<04:26, 44.37s/it]

E13 loss=[ 12.14, 75.67] acc=[95.83,81.06]


 75%|███████▌  | 15/20 [11:00<03:41, 44.30s/it]

E14 loss=[  9.14, 73.80] acc=[96.93,81.05]


 80%|████████  | 16/20 [11:44<02:57, 44.32s/it]

E15 loss=[  9.86, 76.11] acc=[96.65,80.86]


 85%|████████▌ | 17/20 [12:28<02:12, 44.23s/it]

E16 loss=[  9.45, 77.78] acc=[96.69,81.04]


 90%|█████████ | 18/20 [13:12<01:28, 44.20s/it]

E17 loss=[  9.30, 81.76] acc=[96.78,81.15]


 95%|█████████▌| 19/20 [13:56<00:44, 44.17s/it]

E18 loss=[  7.76, 80.52] acc=[97.33,81.29]


100%|██████████| 20/20 [14:41<00:00, 44.06s/it]

E19 loss=[  8.32, 81.60] acc=[97.07,81.29]





Guardamos el modelo.

In [11]:
# Save only the learned parameters (the state_dict), not the pickled module object.
MODEL_PATH = 'models/mobilenet_v2.pt'
# torch.save does not create missing parent directories; create it first so a fresh
# checkout does not lose the training run at the very last step.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(mobilenet_v2.state_dict(), MODEL_PATH)