<a href="https://colab.research.google.com/github/eR3R3/visionTransformerClassification/blob/main/VisionTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Config and Embedding

In [None]:
import torch
from torch import nn
import numpy as np
import torch.optim as optim
import os
import yaml
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from typing import Dict, List, Optional, Union, Tuple, Iterable
from google.colab import files

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class VisionConfig:
    """Hyper-parameters shared by the vision transformer modules.

    Every constructor argument is stored on the instance under the same name.

    Args:
        emb_dim: Width of each token embedding.
        num_layer: Number of times the transformer block is applied.
        batch_size: Number of samples per batch.
        num_attention_head: Number of attention heads.
        in_channel: Number of channels of the input images.
        image_size: Side length of the (square) input images, in pixels.
        kernel_size: Side length of one patch; used as convolution kernel
            size and stride by the embedding layer.
        layer_norm_eps: Epsilon intended for the layer-norm layers.
        attention_dropout: Dropout probability applied to attention weights.
        num_feature: Number of patch tokens per image. When omitted it is
            derived as ``(image_size // kernel_size) ** 2`` so that it always
            matches the patch grid (16 for the default 32x32 image with 8x8
            patches).
        num_image_tokens: Optional token count; stored as given.
    """

    def __init__(
        self,
        emb_dim=224,
        num_layer=12,
        batch_size=32,
        num_attention_head=8,
        in_channel=3,
        image_size=32,
        kernel_size=8,
        layer_norm_eps=1e-6,
        attention_dropout=0.0,
        num_feature: Optional[int] = None,
        num_image_tokens: Optional[int] = None):

        # Derive the token count from the patch grid unless the caller pins it;
        # a fixed default would silently go stale when image_size or
        # kernel_size is changed.
        if num_feature is None:
            num_feature = (image_size // kernel_size) ** 2

        self.emb_dim = emb_dim
        self.batch_size = batch_size
        self.kernel_size = kernel_size
        self.num_layer = num_layer
        self.num_feature = num_feature
        self.num_attention_head = num_attention_head
        self.image_size = image_size
        self.in_channel = in_channel
        self.attention_dropout = attention_dropout
        self.layer_norm_eps = layer_norm_eps
        self.num_image_tokens = num_image_tokens


class VisionEmbedding(nn.Module):
    """Split images into non-overlapping patches and embed them as tokens.

    A strided convolution projects every ``kernel_size`` x ``kernel_size``
    patch to an ``emb_dim`` vector; a learned positional embedding (one entry
    per patch position) is then added.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        self.emb_dim = config.emb_dim
        self.in_channel = config.in_channel
        self.out_channel = self.emb_dim
        self.kernel_size = config.kernel_size
        self.image_size = config.image_size
        # stride == kernel_size, so neighbouring patches never overlap.
        self.conv_1 = nn.Conv2d(
            in_channels=self.in_channel,
            out_channels=self.out_channel,
            kernel_size=self.kernel_size,
            stride=self.kernel_size,
        )
        # Number of patches in the (image_size // kernel_size)-wide square grid.
        self.num_feature = (self.image_size // self.kernel_size) ** 2
        self.pos_embedding = nn.Embedding(self.num_feature, self.emb_dim)
        # Non-persistent buffer: follows the module across devices but is not
        # written to the state dict.
        self.register_buffer(
            "pos_emb_index", torch.arange(self.num_feature), persistent=False
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed ``x`` of shape [batch, in_channel, height, width].

        Returns:
            Tensor of shape [batch, num_patches, emb_dim].
        """
        # The four-way unpack doubles as a check that the input is 4-D.
        _batch, _channels, _height, _width = x.shape
        # [batch, emb_dim, height // kernel_size, width // kernel_size]
        patch_grid = self.conv_1(x)
        # [batch, emb_dim, num_patches] -> [batch, num_patches, emb_dim]
        tokens = patch_grid.flatten(2).transpose(1, 2)
        # [num_feature, emb_dim], broadcast over the batch dimension.
        positions = self.pos_embedding(self.pos_emb_index)
        return tokens + positions


LayerNorm and MLP


In [None]:
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` that always normalises in float32.

    The input is cast to float32 for the normalisation and the result is cast
    back to the dtype the caller passed in.
    """

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        caller_dtype = input.dtype
        normalized = super().forward(input.to(torch.float32))
        return normalized.to(caller_dtype)


class Mlp(nn.Module):
    """Two-layer feed-forward block: expand to 4x width, GELU, project back."""

    def __init__(self, emb_dim: int):
        super().__init__()
        hidden_dim = emb_dim * 4
        self.linear_1 = nn.Linear(emb_dim, hidden_dim)
        self.gelu = nn.GELU()
        self.linear_2 = nn.Linear(hidden_dim, emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to the last dimension of ``x``; the shape is preserved."""
        return self.linear_2(self.gelu(self.linear_1(x)))


Attention Mechanism and Residual Transformer

In [None]:
class VisionAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over patch tokens.

    Queries, keys and values come from one fused linear layer. The per-head
    outputs are concatenated and returned directly (there is no output
    projection).

    Raises:
        ValueError: If ``config.emb_dim`` is not divisible by
            ``config.num_attention_head``.
    """

    def __init__(self, config: VisionConfig):
        super().__init__()
        if config.emb_dim % config.num_attention_head != 0:
            raise ValueError(
                f"emb_dim ({config.emb_dim}) must be divisible by "
                f"num_attention_head ({config.num_attention_head})"
            )
        self.attention_dropout = config.attention_dropout
        self.emb_dim = config.emb_dim
        # NOTE: `mlp` and `layer_norm` are not used by forward(); they are kept
        # so parameter names and state-dict keys stay unchanged.
        self.mlp = Mlp(emb_dim=config.emb_dim)
        self.linear_qkv = nn.Linear(self.emb_dim, self.emb_dim * 3)
        self.dropout = nn.Dropout(config.attention_dropout)
        self.layer_norm = LayerNorm(self.emb_dim)
        self.num_head = config.num_attention_head
        self.num_feature = config.num_feature
        self.head_dim = config.emb_dim // config.num_attention_head
        # 1 / sqrt(head_dim) scaling of the dot products.
        self.scale = self.head_dim ** -0.5
        self.batch_size = config.batch_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over the token dimension.

        Args:
            x: Tensor of shape [batch, num_tokens, emb_dim].

        Returns:
            Tensor of shape [batch, num_tokens, emb_dim].
        """
        # Read the sizes from the input rather than from the config so that a
        # smaller final batch or a different token count does not break the
        # reshapes below.
        batch_size, num_tokens, _ = x.shape

        # [batch, num_tokens, emb_dim * 3] -> 3 * [batch, num_tokens, emb_dim]
        q, k, v = self.linear_qkv(x).chunk(3, dim=-1)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # [batch, num_tokens, emb_dim] -> [batch, num_head, num_tokens, head_dim]
            return t.reshape(
                batch_size, num_tokens, self.num_head, self.head_dim
            ).permute(0, 2, 1, 3)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # [batch, num_head, num_tokens, num_tokens]
        attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scale
        attn_weights = attn_weights.softmax(dim=-1)
        attn_weights = self.dropout(attn_weights)

        # [batch, num_head, num_tokens, head_dim] -> [batch, num_tokens, emb_dim]
        x = torch.matmul(attn_weights, v)
        x = x.transpose(1, 2)
        x = x.reshape(batch_size, num_tokens, self.emb_dim)
        return x

class ResidualVisionTransformer(nn.Module):
    """Apply one norm -> attention -> MLP block ``num_layer`` times with a skip.

    A single attention module, a single MLP and a single layer norm are
    created, and the loop in ``forward`` reuses them on every iteration, so
    all iterations share the same weights.
    NOTE(review): confirm whether weight sharing across layers is intended.

    Args:
        config: Model hyper-parameters.
        attention: Callable that builds the attention module from ``config``.
        Mlp: Callable that builds the feed-forward module from ``emb_dim``.
    """

    def __init__(self, config: VisionConfig, attention=VisionAttention, Mlp=Mlp):
        super().__init__()
        self.attention = attention(config)
        self.num_layer = config.num_layer
        self.mlp = Mlp(emb_dim=config.emb_dim)
        self.layer_norm = LayerNorm(config.emb_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the shared block ``num_layer`` times; the shape of ``x`` is kept."""
        for _ in range(self.num_layer):
            skip = x
            # One residual connection wraps the whole norm -> attention -> MLP chain.
            x = self.mlp(self.attention(self.layer_norm(x))) + skip
        return x

Final Transformer Model

In [None]:
class VisionModel(nn.Module):
    """Image classifier: patch embedding -> transformer -> linear read-out.

    Args:
        config: Model hyper-parameters.
        num_classes: Number of output logits per image (defaults to 10).
    """

    def __init__(self, config: VisionConfig, num_classes: int = 10):
        super().__init__()
        self.embedding = VisionEmbedding(config)
        self.transformer = ResidualVisionTransformer(config)
        # NOTE(review): config.layer_norm_eps is not forwarded here, so the
        # layer norm runs with its default epsilon.
        self.layer_norm = LayerNorm(config.emb_dim)
        # Layer sizes follow the config / embedding instead of the literals
        # 224 and 16, so a non-default configuration builds a consistent head.
        # Collapses every token's embedding to one scalar.
        self.condensation = nn.Linear(config.emb_dim, 1)
        # Maps the per-token scalars to class logits; the token count is the
        # one the embedding layer actually produces.
        self.projection = nn.Linear(self.embedding.num_feature, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape [batch, num_classes] for an image batch."""
        x = self.embedding(x)      # [batch, num_tokens, emb_dim]
        x = self.transformer(x)    # [batch, num_tokens, emb_dim]
        x = self.layer_norm(x)
        x = self.condensation(x)   # [batch, num_tokens, 1]
        x = x.flatten(1)           # [batch, num_tokens]
        x = self.projection(x)     # [batch, num_classes]
        return x

In [None]:
# Build the model from the default hyper-parameters.
vision_config = VisionConfig()
vision_model = VisionModel(vision_config)
# Move the parameters to the selected device. This is left as a bare trailing
# expression so the notebook cell displays the module summary.
vision_model.to(device)

VisionModel(
  (embedding): VisionEmbedding(
    (conv_1): Conv2d(3, 224, kernel_size=(8, 8), stride=(8, 8))
    (pos_embedding): Embedding(16, 224)
  )
  (transformer): ResidualVisionTransformer(
    (attention): VisionAttention(
      (mlp): Mlp(
        (linear_1): Linear(in_features=224, out_features=896, bias=True)
        (gelu): GELU(approximate='none')
        (linear_2): Linear(in_features=896, out_features=224, bias=True)
      )
      (linear_qkv): Linear(in_features=224, out_features=672, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
      (layer_norm): LayerNorm((224,), eps=1e-05, elementwise_affine=True)
    )
    (mlp): Mlp(
      (linear_1): Linear(in_features=224, out_features=896, bias=True)
      (gelu): GELU(approximate='none')
      (linear_2): Linear(in_features=896, out_features=224, bias=True)
    )
    (layer_norm): LayerNorm((224,), eps=1e-05, elementwise_affine=True)
  )
  (layer_norm): LayerNorm((224,), eps=1e-05, elementwise_affine=True)
  (cond

In [None]:
#uploaded = files.upload()


In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset files are read from there.
drive.mount('/content/drive')

def unpickle(file):
    """Load and return the object stored in the pickle file at ``file``.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object. The file is decoded with ``encoding='bytes'``,
        which is why callers index the result with ``bytes`` keys.
    """
    import pickle

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')

# Read the class names from the dataset's metadata file; they are stored as
# bytes and decoded to str here.
meta_file_path = "/content/drive/MyDrive/dataset/cifer_10/batches.meta"
meta_file = unpickle(meta_file_path)
label_name = [raw_name.decode() for raw_name in meta_file[b'label_names']]
# Quick sanity check: show the first class name.
print(label_name[0])

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
airplane


In [None]:
def unpickle(file):
    """Read the pickle file at ``file`` and return its content.

    Args:
        file: Path of the pickle file to read.

    Returns:
        The unpickled object, decoded with ``encoding='bytes'`` (callers
        therefore look entries up with ``bytes`` keys).
    """
    import pickle

    # NOTE: pickle.load can run arbitrary code; only use it on trusted files.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='bytes')

# Load one data batch and move the pixel rows and labels to the compute device.
file = "./data_batch_1"
img_data = unpickle(file)
img_data_keys = list(img_data)[:4]
img_filename = img_data[b'filenames']
img_label = torch.tensor(img_data[b'labels']).to(device)
# Rebind last: the lines above still need `img_data` to be the loaded dict.
img_data = torch.from_numpy(img_data[b'data']).to(device)
# Show the first keys of the batch plus the shapes of pixels and labels.
print(img_data_keys, img_data.shape, img_label.shape, sep="\n")

[b'batch_label', b'labels', b'data', b'filenames']
torch.Size([10000, 3072])
torch.Size([10000])


In [1]:
def preprocess(data):
    """Reshape flat pixel rows into 3x32x32 images and normalise each channel.

    Args:
        data: Tensor whose elements can be viewed as ``(-1, 3, 32, 32)``;
            values are presumably raw 0-255 intensities, since they are
            divided by 255 here.

    Returns:
        float32 tensor of shape ``(N, 3, 32, 32)`` on the same device as
        ``data``, scaled to [0, 1] and then standardised per channel.
    """
    images = data.reshape(-1, 3, 32, 32).float() / 255.0
    # Create the statistics on the input's own device instead of relying on a
    # global `device`, so the function works wherever the data lives.
    # The values are the commonly used ImageNet per-channel mean and std.
    mean = torch.tensor([0.485, 0.456, 0.406], device=images.device)
    std = torch.tensor([0.229, 0.224, 0.225], device=images.device)
    # Shape (1, 3, 1, 1) so they broadcast over batch, height and width.
    mean = mean[None, :, None, None]
    std = std[None, :, None, None]
    return (images - mean) / std
# Only report the preprocessed shape. `img_data` is deliberately left raw:
# VisionDataset applies `preprocess` itself, so rebinding `img_data` to the
# normalised tensor here would scale and standardise every image twice.
print("preprocessed data shape:", preprocess(img_data).shape)

class VisionDataset(Dataset):
    """In-memory dataset of (image, label) pairs.

    Args:
        data: Samples; passed through ``preprocess`` once, at construction.
        label: Labels, indexable in step with the preprocessed samples.
        preprocess: Callable applied to ``data``; also kept on the instance.
    """

    def __init__(self, data, label, preprocess):
        super().__init__()
        self.label = label
        self.preprocess = preprocess
        # Preprocess eagerly so __getitem__ is a plain index lookup.
        self.data = preprocess(data)

    def __len__(self):
        """Return the number of preprocessed samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair stored at ``idx``."""
        return self.data[idx], self.label[idx]


# Split the loaded batch into training and validation parts. The end index
# 9984 (= 8000 + 62 * 32) was presumably chosen so that both splits divide
# evenly into batches of 32.
train_set = VisionDataset(img_data[0:8000], img_label[0:8000], preprocess)
val_set = VisionDataset(img_data[8000:9984], img_label[8000:9984], preprocess)

batch_size = 32
train_loader = DataLoader(train_set, batch_size, shuffle=True)
val_loader = DataLoader(val_set, batch_size, shuffle=False)

# Accuracy of the still-untrained model on the training split, as a baseline.
right_num_total = 0
vision_model.eval()
# No gradients are needed for a pure measurement; this avoids building an
# autograd graph for every batch.
with torch.no_grad():
    for img, label in train_loader:
        logits = vision_model(img)
        # Softmax is monotonic, so the arg-max of the logits is the prediction.
        max_index = logits.argmax(dim=-1)
        right_num_total += (max_index == label).sum().item()
# Divide by the real split size rather than a hard-coded sample count.
right_rate = right_num_total / len(train_set)
print("original accuracy:", f"{right_rate*100}%")

NameError: name 'img_data' is not defined

In [None]:
from torch.optim import Adam
from torch.nn import CrossEntropyLoss

optimizer = Adam(vision_model.parameters(), lr=1e-4)
criterion = CrossEntropyLoss()

# Early-stopping state: best validation loss so far and the number of
# consecutive epochs without an improvement.
best_val_loss = float('inf')
patience = 10
counter = 0

for epoch in range(100):  # train for at most 100 epochs
    # Training phase
    vision_model.train()
    for img, label in train_loader:
        optimizer.zero_grad()
        output = vision_model(img)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()

    # Validation phase
    vision_model.eval()
    val_loss = 0
    right_num_total = 0
    label_num_total = 0
    with torch.no_grad():
        for img, label in val_loader:
            # One forward pass serves both the loss and the accuracy.
            output = vision_model(img)
            val_loss += criterion(output, label).item()
            max_index = output.argmax(dim=-1)
            right_num_total += (max_index == label).sum().item()
            # Count the samples actually seen instead of assuming 32 per batch.
            label_num_total += label.size(0)
    # Mean of the per-batch losses.
    val_loss /= len(val_loader)
    accuracy = right_num_total / label_num_total

    print(f"Epoch {epoch}, Validation Loss: {val_loss:.3f}, Accuracy: {accuracy*100:.3f}%")

    # Early stopping: give up after `patience` epochs without a new best loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping triggered.")
            break

Epoch 0, Validation Loss: 2.3058
Epoch 1, Validation Loss: 2.3030
Epoch 2, Validation Loss: 2.3030
Epoch 3, Validation Loss: 2.3021
Epoch 4, Validation Loss: 2.3054
Epoch 5, Validation Loss: 2.3025
Epoch 6, Validation Loss: 2.3038
Epoch 7, Validation Loss: 2.3019
Epoch 8, Validation Loss: 2.3027
Epoch 9, Validation Loss: 2.3023
Epoch 10, Validation Loss: 2.3018
Epoch 11, Validation Loss: 2.3037
Epoch 12, Validation Loss: 2.3020
Epoch 13, Validation Loss: 2.3016
Epoch 14, Validation Loss: 2.3016
Epoch 15, Validation Loss: 2.3015
Epoch 16, Validation Loss: 2.3013
Epoch 17, Validation Loss: 2.3014
Epoch 18, Validation Loss: 2.3029
Epoch 19, Validation Loss: 2.3007
Epoch 20, Validation Loss: 2.3011
Epoch 21, Validation Loss: 2.3002
Epoch 22, Validation Loss: 2.3010
Epoch 23, Validation Loss: 2.2958
Epoch 24, Validation Loss: 2.2871
Epoch 25, Validation Loss: 2.2932
Epoch 26, Validation Loss: 2.1632
Epoch 27, Validation Loss: 2.1308
Epoch 28, Validation Loss: 2.1378
Epoch 29, Validation Los