In [2]:
import os

import pandas as pd
import torch
import torchmetrics
from sklearn.metrics import classification_report
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

In [3]:
os.makedirs("../dataset", exist_ok=True)

In [4]:
transform = transforms.Compose(
    [
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)

In [5]:
train_set = datasets.CIFAR10(
    "../dataset", train=True, download=True, transform=transform
)
test_set = datasets.CIFAR10(
    "../dataset", train=False, download=False, transform=transform
)

Files already downloaded and verified


In [6]:
train_set

Dataset CIFAR10
    Number of datapoints: 50000
    Root location: ./dataset
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=224, interpolation=bilinear, max_size=None, antialias=None)
               CenterCrop(size=(224, 224))
               ToTensor()
               Normalize(mean=(0.1307,), std=(0.3081,))
           )

In [7]:
train_set.class_to_idx

{'airplane': 0,
 'automobile': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9}

In [8]:
# Reduce data dut to avoide OOM
train_set = torch.utils.data.Subset(train_set, list(range(0, 500)))
valid_set = torch.utils.data.Subset(train_set, list(range(0, 50)))
test_set = torch.utils.data.Subset(train_set, list(range(50, 100)))

In [9]:
train_loader = DataLoader(train_set, shuffle=True, batch_size=16)
valid_loader = DataLoader(train_set, shuffle=True, batch_size=10)
test_loader = DataLoader(test_set, shuffle=False, batch_size=1)

### Define model

#### Use vit_b_16 pretrained model


In [10]:
model = models.vit_b_16(weights="DEFAULT")

In [11]:
model

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

#### Fine tuning setting


In [12]:
# for param in model.parameters():
#     param.requires_grad = False

In [13]:
# Change output unit number
num_features = model.heads[0].in_features
print(num_features)

model.heads[0] = torch.nn.Linear(num_features, 10)

768


In [14]:
model

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

### Training


In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(device)

cuda


In [16]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [17]:
train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)
val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=10).to(device)

In [18]:
num_epoch = 10

for epoch in range(num_epoch):
    running_loss = 0.0
    model.train()
    for batch in train_loader:
        inputs = batch[0].to(device)
        labels = batch[1].to(device)

        optimizer.zero_grad()  # Initialize params
        outputs = model(inputs)  # Forward
        loss = criterion(outputs, labels)  # Calc loss
        loss.backward()  # Backpropagation
        optimizer.step()  # Update params
        running_loss += loss.item()  # Calc total
        train_acc.update(outputs, labels)  # Calc Acc

    model.eval()
    val_running_loss = 0.0
    val_running_acc = 0.0
    with torch.no_grad():
        for batch in valid_loader:
            inputs = batch[0].to(device)
            labels = batch[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_running_loss += loss.item()
            val_acc.update(outputs, labels)

    print(
        f"epoch:{epoch+1}, loss:{running_loss / len(train_loader):.2f}, train_acc:{train_acc.compute():.2f}, val_acc:{val_acc.compute():.2f}"
    )

epoch:1, loss:2.51, train_acc:0.10, val_acc:0.10
epoch:2, loss:2.37, train_acc:0.10, val_acc:0.11
epoch:3, loss:2.30, train_acc:0.11, val_acc:0.10
epoch:4, loss:2.36, train_acc:0.11, val_acc:0.11
epoch:5, loss:2.29, train_acc:0.12, val_acc:0.12
epoch:6, loss:2.26, train_acc:0.12, val_acc:0.13
epoch:7, loss:2.25, train_acc:0.13, val_acc:0.14
epoch:8, loss:2.22, train_acc:0.13, val_acc:0.15
epoch:9, loss:2.19, train_acc:0.14, val_acc:0.15
epoch:10, loss:2.19, train_acc:0.14, val_acc:0.16


### Save and load model


In [19]:
torch.save(model.state_dict(), "model_weight.pth")

In [20]:
model = models.vit_b_16(weights="DEFAULT")
model.heads[0] = torch.nn.Linear(num_features, 10)
model.load_state_dict(torch.load("model_weight.pth"))
model.to(device)

VisionTransformer(
  (conv_proj): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
  (encoder): Encoder(
    (dropout): Dropout(p=0.0, inplace=False)
    (layers): Sequential(
      (encoder_layer_0): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_attention): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
        )
        (dropout): Dropout(p=0.0, inplace=False)
        (ln_2): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (mlp): MLPBlock(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU(approximate='none')
          (2): Dropout(p=0.0, inplace=False)
          (3): Linear(in_features=3072, out_features=768, bias=True)
          (4): Dropout(p=0.0, inplace=False)
        )
      )
      (encoder_layer_1): EncoderBlock(
        (ln_1): LayerNorm((768,), eps=1e-06, elementwise_affine=True)
        (self_a

### Predict


In [21]:
model.eval()
result = []
with torch.no_grad():
    for batch in test_loader:
        inputs = batch[0].to(device)
        labels = batch[1].to(device)
        outputs = model(inputs)
        result.append(
            {
                "pred": outputs.argmax().to("cpu").detach().numpy().flatten()[0],
                "true": labels.to("cpu").detach().numpy()[0],
            }
        )


In [22]:
result = pd.DataFrame(result)


In [23]:
result.head()


Unnamed: 0,pred,true
0,8,9
1,0,5
2,1,7
3,0,9
4,4,2


In [24]:
print(classification_report(result["true"].values, result["pred"].values))


              precision    recall  f1-score   support

           0       0.11      1.00      0.20         2
           1       0.19      0.30      0.23        10
           2       0.00      0.00      0.00         5
           3       0.00      0.00      0.00         5
           4       0.50      0.33      0.40         6
           5       0.00      0.00      0.00         5
           6       0.00      0.00      0.00         2
           7       0.00      0.00      0.00         6
           8       0.18      0.67      0.29         3
           9       0.00      0.00      0.00         6

    accuracy                           0.18        50
   macro avg       0.10      0.23      0.11        50
weighted avg       0.11      0.18      0.12        50



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [25]:
result


Unnamed: 0,pred,true
0,8,9
1,0,5
2,1,7
3,0,9
4,4,2
5,8,2
6,6,5
7,1,2
8,4,4
9,1,3
