In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange   # einops模块 提供 rearrange数据维度重排， reduce 数据统计， repeat 复制方式扩展维度

# from einops.layers.torch import Rearrange

In [2]:
class Embedding(nn.Module):  # Patch Embedding + Position Embedding + Class Embedding
    """Turn an image batch into the token sequence consumed by the encoder.

    Pipeline: non-overlapping patch projection (a Conv2d whose kernel and stride
    both equal ``patch_size``) -> flatten the patch grid into tokens -> prepend a
    learnable class token -> add a learnable position embedding -> dropout.

    Args:
        image_channels: number of channels of the input images.
        image_size: side length the position embedding is sized for; the patch
            grid of the input must yield exactly ``(image_size // patch_size) ** 2`` tokens.
        patch_size: side length of each square patch.
        dim: embedding size of every token.
        drop_ratio: dropout probability applied to the final tokens.

    Returns:
        Tensor of shape (B, num_patches + 1, dim); index 0 along dim 1 is the class token.
    """

    def __init__(self, image_channels=3, image_size=224, patch_size=16, dim=768, drop_ratio=0.):
        super(Embedding, self).__init__()
        patches_per_side = image_size // patch_size
        self.num_patches = patches_per_side * patches_per_side

        # kernel == stride, so each output location sees exactly one patch
        self.patch_conv = nn.Conv2d(image_channels, dim, patch_size, patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))  # class embedding
        self.pos_emb = nn.Parameter(torch.zeros(1, self.num_patches + 1, dim))  # one slot per patch + class token
        self.dropout = nn.Dropout(drop_ratio)

    def forward(self, x):
        # (B, C, H, W) -> (B, dim, H', W') -> (B, dim, H'*W') -> (B, H'*W', dim)
        patch_tokens = self.patch_conv(x).flatten(2).transpose(1, 2)
        batch = patch_tokens.shape[0]
        # share the single learnable class token across the batch: (1, 1, dim) -> (B, 1, dim)
        class_tokens = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((class_tokens, patch_tokens), dim=1)
        return self.dropout(tokens + self.pos_emb)

In [3]:
class MultiHeadAttention(nn.Module):  # Multi-Head Attention
    """Multi-head scaled dot-product self-attention.

    Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V, computed independently per
    head, then the heads are concatenated and passed through a Linear projection.

    Args:
        dim: token embedding size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        drop_ratio: dropout probability applied both to the attention weights
            and to the projected output.

    Raises:
        ValueError: if ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads=8, drop_ratio=0.):
        super(MultiHeadAttention, self).__init__()
        # Fail early with a clear message instead of an opaque reshape error in forward().
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(d_k)

        self.qkv = nn.Linear(dim, dim * 3, bias=False)  # a single Linear produces Q, K and V together
        self.dropout = nn.Dropout(drop_ratio)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        # B: batch size / P: number of tokens / H: number of heads / d: per-head dim
        B, P, _ = x.shape
        # (B, P, 3*H*d) -> (B, P, 3, H, d) -> (3, B, H, P, d); layout along the last axis is (qkv, head, d)
        qkv = self.qkv(x).reshape(B, P, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # Q K^T / sqrt(d_k): (B, H, P, P)
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        x = torch.matmul(attn, v)  # (B, H, P, d)
        x = x.transpose(1, 2).reshape(B, P, self.num_heads * self.head_dim)  # concatenate heads
        x = self.proj(x)
        return self.dropout(x)

In [4]:

class MLP(nn.Module):  # MLP
    """Position-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        in_dims: input and output feature size.
        hidden_dims: hidden feature size; ``4 * in_dims`` when left as None.
        drop_ratio: dropout probability used after the activation and after the
            output projection (the same Dropout module serves both places).
    """

    def __init__(self, in_dims, hidden_dims=None, drop_ratio=0.):
        super(MLP, self).__init__()
        hidden = in_dims * 4 if hidden_dims is None else hidden_dims

        self.fc1 = nn.Linear(in_dims, hidden)
        self.fc2 = nn.Linear(hidden, in_dims)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(drop_ratio)

    def forward(self, x):
        activated = self.dropout(self.gelu(self.fc1(x)))
        return self.dropout(self.fc2(activated))

In [5]:

class EncoderBlock(nn.Module):  # Transformer Encoder Block
    """Pre-norm Transformer encoder block with two residual connections.

        y   = x + Dropout(MultiHeadAttention(LayerNorm1(x)))
        out = y + Dropout(MLP(LayerNorm2(y)))

    Args:
        dim: token embedding size.
        num_heads: number of heads of the inner MultiHeadAttention.
        drop_ratio: dropout probability applied to the output of each residual branch.

    NOTE(review): ``drop_ratio`` is not forwarded to the attention or MLP
    sub-modules; they are built with their own default dropout (0.). Confirm
    whether that is intended.
    """

    def __init__(self, dim, num_heads=8, drop_ratio=0.):
        super(EncoderBlock, self).__init__()

        self.layernorm1 = nn.LayerNorm(dim)
        self.multiheadattn = MultiHeadAttention(dim, num_heads)
        self.dropout = nn.Dropout(drop_ratio)
        self.layernorm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim)

    def forward(self, x):
        # residual 1: around the attention sub-layer
        attn_branch = self.dropout(self.multiheadattn(self.layernorm1(x)))
        x = attn_branch + x
        # residual 2: around the MLP sub-layer
        mlp_branch = self.dropout(self.mlp(self.layernorm2(x)))
        return mlp_branch + x

In [6]:
class MLPHead(nn.Module):  # MLP Head
    """Classification head: LayerNorm over every token, then a Linear on the class token.

    Args:
        dim: token embedding size.
        num_classes: number of output logits.

    Returns:
        Logits of shape (B, num_classes), computed from token index 0 (the class token).
    """

    def __init__(self, dim, num_classes=1000):
        super(MLPHead, self).__init__()
        self.layernorm = nn.LayerNorm(dim)
        # A single Linear layer; for ImageNet-21k pre-training this would instead be Linear + Tanh + Linear.
        self.mlphead = nn.Linear(dim, num_classes)

    def forward(self, x):
        normalized = self.layernorm(x)
        class_token = normalized[:, 0]  # select (not remove) the class token
        return self.mlphead(class_token)

In [7]:

class ViT(nn.Module):  # Vision Transformer
    """Vision Transformer classifier: Embedding -> ``layers`` x EncoderBlock -> MLPHead.

    Args:
        image_channels: channels of the input images.
        image_size: side length of the input images.
        num_classes: number of output logits.
        patch_size: side length of each square patch.
        dim: token embedding size.
        num_heads: attention heads per encoder block.
        layers: number of stacked encoder blocks (L in the paper).
        drop_ratio: dropout probability forwarded to the Embedding and to every
            EncoderBlock. The default 0. reproduces the previous behaviour, where
            these modules were always built without dropout.

    Returns:
        Logits of shape (B, num_classes).
    """

    def __init__(self, image_channels=3, image_size=224, num_classes=1000, patch_size=16, dim=768, num_heads=12,
                 layers=12, drop_ratio=0.):
        super(ViT, self).__init__()
        self.embedding = Embedding(image_channels, image_size, patch_size, dim, drop_ratio=drop_ratio)
        self.encoder = nn.Sequential(
            *[EncoderBlock(dim, num_heads, drop_ratio=drop_ratio) for _ in range(layers)]
        )
        self.head = MLPHead(dim, num_classes)

    def forward(self, x):
        tokens = self.embedding(x)       # (B, num_patches + 1, dim)
        features = self.encoder(tokens)  # same shape
        return self.head(features)       # (B, num_classes)

In [8]:
def _vit_224_patch16(num_classes, dim, num_heads, layers):
    """Shared constructor for the standard variants: 224x224 RGB input, 16x16 patches."""
    return ViT(image_channels=3, image_size=224, num_classes=num_classes, patch_size=16,
               dim=dim, num_heads=num_heads, layers=layers)


def vit_base(num_classes=1000):  # ViT-Base
    """ViT-Base/16: dim 768, 12 heads, 12 encoder blocks."""
    return _vit_224_patch16(num_classes, dim=768, num_heads=12, layers=12)


def vit_large(num_classes=1000):  # ViT-Large
    """ViT-Large/16: dim 1024, 16 heads, 24 encoder blocks."""
    return _vit_224_patch16(num_classes, dim=1024, num_heads=16, layers=24)


def vit_huge(num_classes=1000):  # ViT-Huge
    """ViT-Huge/16: dim 1280, 16 heads, 32 encoder blocks."""
    return _vit_224_patch16(num_classes, dim=1280, num_heads=16, layers=32)

In [10]:
# Smoke test: a forward pass on a random batch must yield one logit vector per image.
images = torch.randn(8, 3, 224, 224)
vb = vit_base()
vb.eval()
with torch.no_grad():  # inference only: don't build an autograd graph for this check
    logits = vb(images)
assert logits.shape == (8, 1000), logits.shape
print(logits.shape)

torch.Size([8, 1000])


In [17]:
from torchvision.datasets import mnist
import torchvision
import torchvision.transforms as transforms
import torch.utils as utils

# ToTensor gives pixels in [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# Both splits share every argument except `train`.
mnist_common = dict(root='./data', download=True, transform=transform)
training_data = torchvision.datasets.MNIST(train=True, **mnist_common)
testing_data = torchvision.datasets.MNIST(train=False, **mnist_common)

In [18]:
from torch.utils.data import DataLoader  # was used below but never imported (NameError)

# Hold out 10,000 of the 60,000 training images for validation.
# A fixed generator makes the split identical across kernel restarts.
train_db, val_db = utils.data.random_split(training_data, [50000, 10000],
                                           generator=torch.Generator().manual_seed(0))

train_loader = DataLoader(train_db, batch_size=32, shuffle=True)
# Evaluation loaders only aggregate metrics, so shuffling them adds nothing.
val_loader = DataLoader(val_db, batch_size=32, shuffle=False)
test_loader = DataLoader(testing_data, batch_size=32, shuffle=False)

In [26]:
def vit_base(num_classes=10):  # ViT-Base adapted to MNIST
    """ViT-Base encoder (dim 768, 12 heads, 12 blocks) sized for 1x28x28 MNIST digits.

    This intentionally shadows the 224x224 RGB ``vit_base`` defined earlier in the
    notebook. Previously ``num_classes`` was ignored and 10 was hard-coded. The
    default is now 10, so ``vit_base()`` is unchanged, and an explicit value is honoured.

    NOTE(review): with patch_size=9 the patch grid is 3x3 (28 // 9 == 3), so the
    strided conv never sees the last pixel row/column. patch_size=7 would tile
    28x28 exactly.
    """
    return ViT(image_channels=1, image_size=28, num_classes=num_classes, patch_size=9, dim=768, num_heads=12,
               layers=12)

In [27]:
# Use the first GPU when present; fall back to the CPU instead of failing when the model is moved.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)  # reproducible weight initialisation and training-loader shuffling
net = vit_base().to(device=device, dtype=torch.float32)
optimizer = torch.optim.Adam(net.parameters(), lr=0.0001, weight_decay=1e-8)
criterion = nn.CrossEntropyLoss()

In [30]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def evaluate(model, loader, loss_fn, device):
    """Score `model` on every batch of `loader` in eval mode without gradients.

    Returns:
        (mean of the per-batch losses, accuracy in percent over len(loader.dataset) samples)
    """
    model.eval()
    batch_losses = []
    correct = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            batch_losses.append(loss_fn(logits, labels).item())
            correct += (logits.argmax(1) == labels).sum().item()
    # derive the denominator from the data instead of hard-coding the split sizes
    return np.mean(batch_losses), 100.0 * correct / len(loader.dataset)


epochs = 10

train_loss = []
val_loss = []

train_acc = []
val_acc = []

for epoch in range(1, epochs + 1):
    # Train
    net.train()
    for input_img, label in train_loader:
        input_img, label = input_img.to(device), label.to(device)
        optimizer.zero_grad()
        loss = criterion(net(input_img), label)
        loss.backward()
        optimizer.step()

    # Verify: re-score the full training split and the validation split in eval mode
    epoch_train_loss, epoch_train_acc = evaluate(net, train_loader, criterion, device)
    train_loss.append(epoch_train_loss)
    train_acc.append(epoch_train_acc)

    epoch_val_loss, epoch_val_acc = evaluate(net, val_loader, criterion, device)
    val_loss.append(epoch_val_loss)
    val_acc.append(epoch_val_acc)

    print("epochs[%3d/%3d] val_loss: %.5f, val_acc: %.3f"%(epoch, epochs, val_loss[-1],val_acc[-1]))

epochs[  1/ 10] val_loss: 0.17300, val_acc: 94.550
epochs[  2/ 10] val_loss: 0.14623, val_acc: 95.340
epochs[  3/ 10] val_loss: 0.14138, val_acc: 95.630


In [None]:
#  Plot loss curve (x axis follows the number of recorded epochs instead of a hard-coded 50)
epoch_axis = np.arange(1, len(train_loss) + 1)
fig, ax = plt.subplots()
ax.plot(epoch_axis, train_loss, label="train")
ax.plot(epoch_axis, val_loss, label="val")
ax.legend()
ax.set_xticks(np.arange(1, len(train_loss) + 1, max(1, len(train_loss) // 10)))  # x-axis tick positions
ax.set_xlabel("Epochs")
ax.set_ylabel("loss")
ax.set_title("Loss curve")
plt.show()

In [None]:
#  Plot accuracy curve
plt.plot(range(50), train_acc)
plt.plot(range(50), val_acc)
plt.legend(["train","val"])
plt.xticks(np.arange(0, 50, 5))  # 横坐标的值和步长
plt.yticks(np.arange(90, 100, 1))  # 横坐标的值和步长
plt.xlabel("Epochs")
plt.ylabel("accuracy(100%)")
plt.title("Accuracy curve")
plt.show()