## 1. 导入相关包

In [1]:
# 加载和预处理数据集
import torch
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from torch import nn
from torchvision import transforms
import torchvision
from torch.utils.data import DataLoader

In [2]:
trans_train = transforms.Compose(
    [transforms.RandomCrop(32,padding=6),  # 将给定图像随机裁剪为不同的大小和宽高比，#然后缩放所裁剪得到的图像为制定的大小;
     # （即先随机采集,然后对裁剪得到的图像缩放为同一大小）默认scale=(0.08,1.0)
     transforms.RandomHorizontalFlip(),  # 以给定的概率随机水平旋转给定的PIL的图像，默认为0.5;
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])])
trans_valid = transforms.Compose(
    [   #transforms.Resize(64),  # 是按照比例把图像最小的一个边长放缩到256，另一边按照相同比例放缩。
        #transforms.CenterCrop(28),#依据给定的size从中心裁剪
        transforms.ToTensor(),
        # 将PIL Image或者ndarray 转换为tensor，并且归一化至[0-1]#归一化至[0-1]是直接除以255，若自己的ndarray数据尺度有变化，则需要自行修改。
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
# 对数据按通道进行标准化，即先减均值，再除以标准差，注意是 hwc
trainset = torchvision.datasets.CIFAR10(root=r"H:\datasets\data", train=True, download=True, transform=trans_train)
trainloader = DataLoader(trainset, batch_size=256, shuffle=True)
testset = torchvision.datasets.CIFAR10(root=r'H:\datasets\data', train=False,
                                       download=False, transform=trans_valid)
testloader = DataLoader(testset, batch_size=256, shuffle=False)
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse ', 'ship', 'truck ')

Files already downloaded and verified


In [3]:
class Attention(nn.Module):
    def __init__(self, dim=128, heads=8, dim_head=64, dropout=0.):
        super(Attention, self).__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)
        self.attend = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(dropout)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim), nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        x = self.norm(x)
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d)->b h n d', h=self.heads), qkv)
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        attn = self.dropout(attn)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class ViT(nn.Module):
    def __init__(self, num_classes=10, dim=512, depth=6, heads=8, mlp_dim=512, pool='cls', channels=3, dim_head=64,
                 dropout=0.1, emb_dropout=0.1):
        super().__init__()
        image_height = 32
        patch_height = 4
        image_width = 32
        patch_width = 4
        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2)->b (h w) (p1 p2 c) ', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim), nn.LayerNorm(dim), )
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(
            torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)
        self.transformer = Encoder(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool = pool
        self.to_latent = nn.Identity()
        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        x = self.to_patch_embedding(img)
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)
        x = self.transformer(x)
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x)


class Encoder(nn.Module):
    def __init__(self, dim=512, depth=6, heads=8, dim_head=64, mlp_dim=512, dropout=0.):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim=dim, heads=heads, dim_head=dim_head, dropout=dropout),
                FeedForward(dim, mlp_dim, dropout=dropout)
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return self.norm(x)


class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim), nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim), nn.Dropout(dropout)
        )

    def forward(self, x):
        return self.net(x)


In [4]:
import time


def train(epoch):
    print(' \nEpoch: %d' % epoch)
    model = ViT()
    device = 'cuda'
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    net = model.to(device)
    net.to(device)
    train_loss = 0
    correct = 0
    total = 0
    criterion = torch.nn.CrossEntropyLoss()
    t = time.time()
    loss_all=[]
    acc_=[]
    for e in range(epoch):
        net.train()
        
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            optimizer.zero_grad()
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            # sparse_selection()
            optimizer.step()
            train_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        loss_all.append(float(loss))
        acc1=te(net, device, criterion)
        acc_.append(float(acc1))
        print(e, acc1)
        print(time.time() - t)
    with open('loss.txt','w',encoding="utf-8") as f:
        f.write(str(loss_all))
    with open('acc.txt','w',encoding="utf-8") as f:
        f.write(str(acc_))
    torch.save(model.state_dict(),'./model1.pt')


def te(net, device, criterion):
    test_loss = 0
    correct = 0
    total = 0
    net.eval()
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)
            test_loss += loss.item()
            _,predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return correct / total

In [5]:
train(350)

 
Epoch: 350
0 0.4804
28.50688886642456
1 0.5269
56.97822833061218
2 0.57
85.51609563827515
3 0.5782
114.19909739494324
4 0.6055
142.78727746009827
5 0.6081
171.57024955749512
6 0.6385
200.20650720596313
7 0.6483
228.89482879638672
8 0.6462
257.56299448013306
9 0.6743
286.2913861274719
10 0.6811
315.4403626918793
11 0.6927
344.35603857040405
12 0.6761
373.19810366630554
13 0.6996
401.95400047302246
14 0.7202
430.8852138519287
15 0.7189
459.7382140159607
16 0.7213
488.5429470539093
17 0.7251
517.4018700122833
18 0.7371
546.2356767654419
19 0.7413
575.1401631832123
20 0.7443
603.9621634483337
21 0.7579
632.8306040763855
22 0.7506
661.5786106586456
23 0.7589
690.3726100921631
24 0.7659
719.287611246109
25 0.7619
748.0530767440796
26 0.7697
776.8490719795227
27 0.767
805.6145973205566
28 0.7795
834.4780602455139
29 0.7853
863.4260601997375
30 0.7853
892.457897901535
31 0.781
921.4418935775757
32 0.7843
950.5259001255035
33 0.7963
979.3449013233185
34 0.7954
1008.3544545173645
35 0.8087
103

KeyboardInterrupt: 