## 1. 导入相关包

In [1]:
# 加载和预处理数据集
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader

In [2]:
# Training-time pipeline: light augmentation, tensor conversion, per-channel normalisation.
trans_train = transforms.Compose(
    [transforms.RandomCrop(32, padding=4),  # pad every border by 4 px, then cut a random 32x32 window
     transforms.RandomHorizontalFlip(),  # mirror the image left-right; the default probability is 0.5
     transforms.ToTensor(),  # PIL image / ndarray -> float tensor scaled to [0, 1]
     # NOTE(review): these look like the usual ImageNet statistics rather than
     # CIFAR-10 ones -- confirm this is intended.
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])])
# Evaluation-time pipeline: no augmentation, same normalisation as training.
trans_valid = transforms.Compose(
    [   # transforms.Resize(64),  # would scale the shorter edge to 64 px, keeping the aspect ratio
        # transforms.CenterCrop(28),  # would cut a 28x28 window from the image centre
        transforms.ToTensor(),
        # ToTensor rescales to [0, 1] by dividing by 255; data on another scale
        # must be rescaled by hand before this step.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
# Normalize works per channel: subtract the channel mean, then divide by the channel std.
# trainset = torchvision.datasets.CIFAR10(root=r"H:\datasets\data", train=True, download=True, transform=trans_train)
# trainloader = DataLoader(trainset, batch_size=256, shuffle=True)
# CIFAR-10 test split; download=False, so the files must already be present under `root`.
testset = torchvision.datasets.CIFAR10(root=r'H:\datasets\data', train=False,
                                       download=False, transform=trans_valid)
testloader = DataLoader(testset, batch_size=256, shuffle=False)
# Human-readable label names (stray trailing spaces removed from 'horse' and 'truck').
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [3]:
class Attention(nn.Module):
    """Multi-head self-attention that layer-normalises its input first.

    Args:
        dim: size of the last axis of the input and of the output.
        heads: number of attention heads.
        dim_head: width of each head; q, k and v are each ``heads * dim_head`` wide.
        dropout: dropout probability used on the attention weights and after
            the output projection.
    """

    def __init__(self, dim=128, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = heads * dim_head
        # A single head that is already ``dim`` wide needs no output projection.
        needs_projection = heads != 1 or dim_head != dim

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        # One bias-free linear layer produces q, k and v side by side.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout),
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        """Apply self-attention to ``x``, read as (batch, tokens, dim)."""
        normed = self.norm(x)
        # Split the fused projection into its query / key / value thirds.
        thirds = self.to_qkv(normed).chunk(3, dim=-1)
        q, k, v = [rearrange(part, 'b n (h d)->b h n d', h=self.heads) for part in thirds]
        # Scaled dot-product scores between every pair of tokens, per head.
        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.dropout(self.attend(scores))
        per_head = torch.matmul(weights, v)
        # Concatenate the heads back onto the feature axis.
        merged = rearrange(per_head, 'b h n d -> b n (h d)')
        return self.to_out(merged)


class ViT(nn.Module):
    """Vision Transformer image classifier.

    The image is cut into non-overlapping patches, each patch is linearly
    embedded, a learnable class token is prepended, learnable positional
    embeddings are added, and the sequence is run through ``Encoder``. The
    pooled representation is mapped to class logits by a linear head.

    Args:
        num_classes: number of output logits.
        dim: embedding width used throughout the transformer.
        depth: number of encoder layers.
        heads: attention heads per layer.
        mlp_dim: hidden width of each feed-forward block.
        pool: ``'mean'`` averages all tokens; any other value (default
            ``'cls'``) reads the class token at position 0.
        channels: number of input image channels.
        dim_head: width of each attention head.
        dropout: dropout probability inside the encoder.
        emb_dropout: dropout probability applied to the embedded sequence.
        image_size: input size as an int (square) or ``(height, width)``.
            Defaults to 32, the value that used to be hard-coded.
        patch_size: patch size as an int (square) or ``(height, width)``.
            Defaults to 4, the value that used to be hard-coded.

    Raises:
        ValueError: if the image size is not a multiple of the patch size.
    """

    def __init__(self, num_classes=10, dim=512, depth=6, heads=8, mlp_dim=512, pool='cls', channels=3, dim_head=64,
                 dropout=0.1, emb_dropout=0.1, image_size=32, patch_size=4):
        super().__init__()
        image_height, image_width = self._as_pair(image_size)
        patch_height, patch_width = self._as_pair(patch_size)
        if image_height % patch_height or image_width % patch_width:
            raise ValueError(
                f"image size {(image_height, image_width)} must be divisible "
                f"by patch size {(patch_height, patch_width)}"
            )
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        self.to_patch_embedding = nn.Sequential(
            # (batch, channels, H, W) -> (batch, patches, flattened patch pixels)
            Rearrange('b c (h p1) (w p2)->b (h w) (p1 p2 c) ', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim), nn.LayerNorm(dim), )
        # One extra position for the class token that is prepended in forward().
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Encoder(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)

    @staticmethod
    def _as_pair(value):
        """Return ``value`` as a ``(height, width)`` tuple, duplicating a lone int."""
        return (value, value) if isinstance(value, int) else tuple(value)

    def forward(self, img):
        """Return class logits of shape (batch, num_classes) for ``img``.

        ``img`` is read as (batch, channels, height, width) by the patch
        embedding.
        """
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        # Give every sample in the batch its own copy of the class token.
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x)
        # Pool either by averaging every token or by reading the class token.
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)


class Encoder(nn.Module):
    """Stack of residual attention + feed-forward layers ending in a LayerNorm.

    Args:
        dim: width of the token embeddings.
        depth: number of (attention, feed-forward) layer pairs.
        heads: attention heads per layer.
        dim_head: width of each attention head.
        mlp_dim: hidden width of each feed-forward block.
        dropout: dropout probability handed to both sub-blocks.
    """

    def __init__(self, dim=512, depth=6, heads=8, dim_head=64, mlp_dim=512, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        # Each entry is an [attention, feed-forward] pair.
        self.layers = nn.ModuleList([
            nn.ModuleList([
                Attention(dim=dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        """Run ``x`` through every layer pair, then apply the final LayerNorm."""
        for pair in self.layers:
            attention, feed_forward = pair
            # Both sub-blocks are wrapped in a residual connection.
            x = attention(x) + x
            x = feed_forward(x) + x
        return self.norm(x)


class FeedForward(nn.Module):
    """Two-layer MLP block: LayerNorm, expand with GELU, project back.

    Args:
        dim: width of the input and of the output.
        hidden_dim: width of the intermediate layer.
        dropout: dropout probability applied after each linear layer.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        stages = [
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Return the block's output for ``x``."""
        out = self.net(x)
        return out


In [4]:
import time


def train(epoch):
    """Train a freshly initialised ``ViT`` on the CIFAR-10 training split.

    After every epoch the model is scored with ``te`` and the accuracy and
    elapsed wall-clock time are printed. When training ends, the per-epoch
    losses are written to ``loss.txt``, the per-epoch accuracies to
    ``acc.txt``, and the model weights to ``./model1.pt``.

    Args:
        epoch: number of epochs to run (the name is singular, but it is used
            as the loop bound).

    Notes:
        * Runs on CUDA when available and falls back to the CPU otherwise.
        * The value logged per epoch is the sample-weighted mean training
          loss of that epoch (previously only the last batch's loss was kept).
    """
    print(' \nEpoch: %d' % epoch)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    net = ViT().to(device)
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # The dataset and loader do not change between epochs, so build them once.
    trainset = torchvision.datasets.CIFAR10(root=r"H:\datasets\data", train=True, download=True, transform=trans_train)
    trainloader = DataLoader(trainset, batch_size=256, shuffle=True)

    loss_all = []
    acc_ = []
    t = time.time()
    for e in range(epoch):
        # te() switches the model to eval mode, so re-enable training mode here.
        net.train()
        epoch_loss = 0.0
        seen = 0
        for inputs, targets in trainloader:
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
            # Weight by batch size so a short final batch does not skew the mean.
            batch_size = targets.size(0)
            epoch_loss += loss.item() * batch_size
            seen += batch_size
        loss_all.append(epoch_loss / seen if seen else float('nan'))
        acc1 = te(net, device, criterion)
        acc_.append(float(acc1))
        print(e, acc1)
        print(time.time() - t)
    with open('loss.txt', 'w', encoding="utf-8") as f:
        f.write(str(loss_all))
    with open('acc.txt', 'w', encoding="utf-8") as f:
        f.write(str(acc_))
    torch.save(net.state_dict(), './model1.pt')


def te(net, device, criterion, loader=None):
    """Return the top-1 accuracy of ``net`` over a data loader.

    Args:
        net: model to evaluate; it is switched to eval mode and left there.
        device: device the batches are moved to before the forward pass.
        criterion: unused; kept so existing three-argument calls keep working.
        loader: iterable of ``(inputs, targets)`` batches. Defaults to the
            module-level ``testloader``.

    Returns:
        Fraction of correctly classified samples as a float in [0, 1], or
        ``0.0`` when the loader yields no samples.
    """
    if loader is None:
        loader = testloader
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            # The index of the largest logit is the predicted class.
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    # Guard against an empty loader instead of dividing by zero.
    return correct / total if total else 0.0

In [None]:
# Start training; the argument is the number of epochs that train() loops over.
train(400)

 
Epoch: 400
Files already downloaded and verified
0 0.4761
29.307168006896973
Files already downloaded and verified
1 0.5389
58.45121216773987
Files already downloaded and verified
2 0.5878
87.4048445224762
Files already downloaded and verified
3 0.6133
116.34339594841003
Files already downloaded and verified
4 0.6341
145.18039512634277
Files already downloaded and verified
5 0.6398
173.89497756958008
Files already downloaded and verified
6 0.6512
202.7138066291809
Files already downloaded and verified
7 0.6406
231.46543073654175
Files already downloaded and verified
8 0.6779
260.24078011512756
Files already downloaded and verified
9 0.692
289.24566769599915
Files already downloaded and verified
10 0.6712
318.88187623023987
Files already downloaded and verified
11 0.7031
348.45001435279846
Files already downloaded and verified
12 0.7064
377.93659353256226
Files already downloaded and verified
13 0.7188
407.6235933303833
Files already downloaded and verified
14 0.731
437.13007640838623

Files already downloaded and verified
123 0.8601
3603.6176307201385
Files already downloaded and verified
124 0.8613
3632.452630996704
Files already downloaded and verified
125 0.8611
3661.4248089790344
Files already downloaded and verified
126 0.86
3690.3648071289062
Files already downloaded and verified
127 0.8583
3719.347805738449
Files already downloaded and verified
128 0.8644
3748.252805709839
Files already downloaded and verified
129 0.8624
3777.2108058929443
Files already downloaded and verified
130 0.8616
3806.202753305435
Files already downloaded and verified
131 0.8597
3835.084849834442
Files already downloaded and verified
132 0.8607
3864.048240184784
Files already downloaded and verified
133 0.8588
3892.964912414551
Files already downloaded and verified
134 0.856
3921.906912088394
Files already downloaded and verified
135 0.8525
3950.734563589096
Files already downloaded and verified
136 0.861
3979.632709503174
Files already downloaded and verified
137 0.8571
4008.50316047

245 0.8703
7135.537796735764
Files already downloaded and verified
246 0.8665
7164.451796770096
Files already downloaded and verified
247 0.8716
7193.568591594696
Files already downloaded and verified
248 0.8713
7222.418637275696
Files already downloaded and verified
249 0.8704
7251.385093688965
Files already downloaded and verified
250 0.8704
7280.374188661575
Files already downloaded and verified
251 0.8657
7309.295281887054
Files already downloaded and verified
252 0.8719
7338.296778917313
Files already downloaded and verified
253 0.8694
7367.257781505585
Files already downloaded and verified
254 0.8697
7396.2707777023315
Files already downloaded and verified
255 0.8695
7425.350219249725
Files already downloaded and verified
256 0.8714
7454.29833316803
Files already downloaded and verified
257 0.8667
7483.201336860657
Files already downloaded and verified
258 0.8703
7512.001064538956
Files already downloaded and verified
259 0.872
7541.01106262207
Files already downloaded and verifi