## 1 导入相关包

In [1]:
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchvision

## 2 获取数据集

In [2]:
def get_dataset(path, batch_size=32, transform=None):
    """
    加载MNIST数据集并将其转换为DataLoader对象。
    :param path: 数据集路径
    :param batch_size: 批处理大小
    :param transform: 数据预处理
    :return: 训练集与测试集的DataLoader对象
    """
    if transform is None:
        transform = torchvision.transforms.Compose([  # 对图像进行预处理
            torchvision.transforms.ToTensor(),  # 将图片转换成张量
            torchvision.transforms.Normalize((0.5,), (0.5,))  # 对图像进行归一化处理
        ])

    train = CIFAR_Dataset(path, train=True, transform=transform)
    tset = CIFAR_Dataset(path, train=False, transform=transform)

    # 创建dataloader对象
    train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(tset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader

## 3 数据集类

In [3]:
class CIFAR_Dataset(Dataset):
    def __init__(self, data_dir, train, transform):  # 数据集的位置，训练集还是测试集，以及数据预处理的变换
        super(CIFAR_Dataset, self).__init__()
        self.data_dir = data_dir
        self.train = train
        self.transform = transform
        self.data = []
        self.targets = []

        # 判断是否为训练集
        if self.train:
            for i in range(5):  # CIFAR-10训练数据集有5个文件，所以要循环5次读取
                with open(data_dir + '/cifar-10-batches-py/data_batch_' + str(i + 1), 'rb') as f:  # 二进制格式读取文件
                    entry = pickle.load(f, encoding='latin1')  # 对文件进行反序列化成python对象
                    self.data.append(entry['data'])  # 读取文件中data部分的数据并将其添加到self.data中
                    self.targets.extend(entry['labels'])  # 读取文件中labels部分的数据并将其添加到self.targets中
        else:  # 操作与上述相同，只是读取的是测试集
            with open(data_dir + '/cifar-10-batches-py/test_batch', 'rb') as f:
                entry = pickle.load(f, encoding='latin1')
                self.data.append(entry['data'])
                self.targets.extend(entry['labels'])
        # 上面的操作是将数据添加到列表中，就会对整体数据添加一个纬度，
        # 比如训练集本身是n*3*32*32,现在变成了 5*(n/5)*3*32*32,所以需要reshape一下,
        # -1将5与n/5这两个纬度合并，变成n
        self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
        # 对纬度进行转置，这个跟图片数组相关
        self.data = self.data.transpose((0, 2, 3, 1))

    # 获取数据集长度
    def __len__(self):
        return len(self.data)

    # 让对象能像数组一样根据下标访问
    def __getitem__(self, idx):
        # 这里是自己构建one-hot数组，可以利用torch.nn.functional 中的ont_hot函数进行变换
        label = torch.zeros(10)
        label[self.targets[idx]] = 1.

        # 判断是否有预处理函数，如果有则对数据进行预处理
        if self.transform:
            image = self.transform(self.data[idx])
        if self.train and idx > 0 and idx % 5 == 0:
            # 获取一个数据集长度的随机数
            mixup_idx = random.randint(0, len(self.data) - 1)
            # 设置one_hot数组
            mixup_label = torch.zeros(10)
            label[self.targets[mixup_idx]] = 1.

            # 如果存在预处理函数，则对数据集进行预处理
            if self.transform:
                mixup_image = self.transform(self.data[mixup_idx])

            # 根据beta分布的随机数，对数据进行cutmix操作
            mask = np.ones_like(image)  # 生成mask矩阵，用于对图像进行cut操作
            la = float(np.random.beta(0.5, 0.5, 1))  # 生成一个符合beta分布的随机数
            # 随机获取切割的部分
            rx = np.int8(np.random.uniform(0, 32, 1))[0]
            ry = np.int8(np.random.uniform(0, 32, 1))[0]
            rw = np.int8(np.power(1 - la, 0.5) * 32)
            rh = np.int8(np.power(1 - la, 0.5) * 32)
            if rx > rw:
                rx, rw = rw, rx
            if ry > rh:
                ry, rh = rh, ry
            #  对图像进行cut操作
            mask[rx:rw, ry:rh, :] = 0
            #   对图像进行mix操作
            image = image * mask + mixup_image * (1 - mask)
            label = la * label + (1 - la) * mixup_label
        return image, label

## 4 取patch

In [4]:
def image2embed(image, patch_size):
    """
    将图像转换为嵌入向量
    :param image: 图片  batch_size * channel * h * w
    :param patch_size: 块大小
    :return:
    """
    patch = F.unfold(image, kernel_size=patch_size, stride=patch_size).transpose(-1, -2)  # 将图片分成块，它实质是将卷积的部分直接取出来
    return patch 

## 5 Embedding层

In [5]:
class Embedding(nn.Module):
    def __init__(self, channel, batchsize, psize, patchsize, emb_dim, device):
        """
        词嵌入层
        :param batchsize: 批量大小
        :param psize: 用于位置编码的一个参数，它的大小等于  图片通道数 * (一张图片一行数据的大小//patchsize)²
        :param patchsize: 提取图块的边长
        :param emb_dim: 嵌入维度
        :param device: 运算设备
        """
        super(Embedding, self).__init__()
        self.pathF = image2embed  # 导入提取图片块的函数
        self.patchszie = patchsize  # 边长
        self.emb_dim = emb_dim  # 嵌入纬度
        self.l1 = nn.Linear(patchsize * patchsize * channel, emb_dim)  # 用于将图片块映射为为嵌入纬度大小
        # 定义一个矩阵嵌入到输入数据开头，表示数据的开始
        self.cls_token_emb = torch.randn(batchsize, 1, self.emb_dim, requires_grad=True, device=device)
        # 位置编码
        self.position_emb = torch.randn(1, psize, self.emb_dim, requires_grad=True, device=device)

    def forward(self, x):  # 前向传播
        """
        这里将图片块转换为嵌入纬度，加入了开头与位置编码
        :param x:
        :return:
        """

        x = self.pathF(x, self.patchszie)
        x = self.l1(x)
        x = torch.cat((self.cls_token_emb[:x.shape[0]], x), dim=1)
        x += self.position_emb
        return x

## 6 注意力

In [6]:
class Attention(nn.Module):
    def __init__(self, emb_dim=128, head=8):
        """
        注意力机制
        :param emb_dim: 词嵌入纬度
        :param head: 多头头数
        """
        super(Attention, self).__init__()
        assert emb_dim % head == 0  # 保证emb_dim可以整除head，注意力机制的词嵌入维度需要是多头的n倍
        self.emb_dim = emb_dim  # 词嵌入纬度
        self.head = head  # 多头
        self.head_dim = emb_dim // head

        # q k v 三个输入的线性层  维度变换 emb_dim → emb_dim
        self.query_L = nn.Linear(emb_dim, emb_dim)
        self.key_L = nn.Linear(emb_dim, emb_dim)
        self.value_L = nn.Linear(emb_dim, emb_dim)

    def forward(self, q, k, v):
        """
        前向传播 q,k,v为transformer的三个输入，这里做了注意力机制的运算
        :return:
        """
        # q,k,v的形状为 batchsize 长度 词嵌入纬度 ，下面batchsize，长度，词嵌入纬度，头数，分别用 B L D H 代替
        # 这里进行多头注意力机制进行计算，因此需要进行纬度变换
        x_q = self.query_L(q)  # q 进行线性层变换 B,L,D → B,L,D
        x_q = x_q.reshape(q.shape[0], q.shape[1], self.head, self.head_dim)  # B,L,D → B,L,H,D/H
        x_q = x_q.transpose(1, 2)  # B,L,H,D/H → B,H,L,D/H
        x_q = x_q.reshape(-1, q.shape[1], self.head_dim)  # B,H,L,D/H  → BH,L,D/H

        # k,v操作与q相同
        x_k = self.key_L(k).reshape(k.shape[0], k.shape[1], self.head, self.head_dim)
        x_k = x_k.transpose(1, 2)
        x_k = x_k.reshape(-1, k.shape[1], self.head_dim)

        x_v = self.value_L(v).reshape(v.shape[0], v.shape[1], self.head, self.head_dim)
        x_v = x_v.transpose(1, 2)
        x_v = x_v.reshape(-1, v.shape[1], self.head_dim)

        # 注意力机制计算，这里需要对x_K进行转置才符合运算规则
        x_k = x_k.transpose(1, 2)  # BH,L,BH  →  BH,D/H,L
        x_atten = torch.matmul(x_q, x_k) / (self.head_dim ** 0.5)  # q,k相乘并除以根号D → BH,L,L
        x_atten = F.softmax(x_atten, dim=-1)

        x_out = torch.matmul(x_atten, x_v)  # → BH,L,D/H
        x_out = x_out.reshape(-1, self.head, x_out.shape[1], x_out.shape[2])  # BH,L,D/H → B,H,L,D/H
        x_out = x_out.transpose(1, 2)  # B,H,L,D/H → B,L,H,D/H
        x = x_out.reshape(-1, x_out.shape[1], self.head * self.head_dim)  # B,L,H,D/H->B,L,D
        return x

## 7 Encoder

In [7]:
class Encoder(nn.Module):
    def __init__(self, emb_dim=128, head=8):
        """
        编码器
        :param emb_dim: 嵌入维度
        :param head: 多头头数
        """
        super(Encoder, self).__init__()
        self.Attention = Attention(emb_dim, head)  # 注意力机制
        # 前馈全连接子层
        self.l1 = nn.Linear(emb_dim, 1024)
        self.l2 = nn.Linear(1024, emb_dim)
        # 规范化层
        self.norm1 = nn.LayerNorm(emb_dim)
        self.norm2 = nn.LayerNorm(emb_dim)

    def forward(self, q, k, v):  # 前向传播计算
        # 注意力机制
        x = self.Attention(q, k, v)
        # 规范化层
        x = self.norm1(x + q)
        # 全连接层
        x_ = self.l1(x)
        x_ = F.gelu(x_)
        x_ = self.l2(x_)
        # 规范化层
        x = self.norm2(x + x_)
        return x

## 8 VIT

In [8]:
class VIT(nn.Module):
    def __init__(self, channel, batchsize, psize, patchsize, emb_dim, head, device, N=3):
        """
        VIT模型
        :param batchsize: 批量
        :param psize: 用于位置编码的一个参数，它的大小等于  图片通道数 * (一张图片一行数据的大小//patchsize)²
        :param patchsize: 图片块边长
        :param emb_dim: 嵌入维度
        :param head: 多头
        :param device: 运算设备
        """
        super(VIT, self).__init__()
        self.Embed = Embedding(channel, batchsize, psize, patchsize, emb_dim, device)  # 词嵌入层
        self.Encoder = torch.nn.ModuleList([Encoder(emb_dim, head) for _ in range(N)])
        # 用于分类的全连接层
        self.l1 = nn.Linear(emb_dim, 1024)
        self.l2 = nn.Linear(1024, 10)  # CIFAR10 10分类

    def forward(self, x):
        #  词嵌入层
        x = self.Embed(x)
        #  编码器层
        for i in self.Encoder:
            x = i(x, x, x)
        #  分类层
        x = self.l1(x)
        x = F.relu(x)
        x = self.l2(x)
        return x

## 9 准确率函数

In [9]:
def testacc(model, test, epoch, device):
    """
    测试准确率
    :param model: 模型
    :param test: 测试集
    :param epoch: 第epoch轮
    :param device: 设备
    :return:
    """
    all = 0  # 样本总数
    right = 0  # 正确个数
    model.eval()
    for i, (data, label) in enumerate(test):
        all += 128
        data = data.to(device)
        label = label.to(device)
        pre = model(data)[:, 0, :]
        pre = torch.argmax(pre, dim=-1)  # 获取最大值标签
        label=torch.argmax(label, dim=-1)
        right += (pre == label).sum()  # 统计每轮正确的数量
    print(epoch, right / all)
    return right / all


## 10 训练函数

In [10]:
def train(path, batchsize, patchsize, emb_dim=256, head=8, device='cpu', lr=3e-4, N=6):
    """
    训练模型
    :param path: 数据集路径
    :param batchsize: 批量大小
    :param patchsize: 块大小
    :param emb_dim: 嵌入纬度
    :param head: 多头
    :param device: 设备
    :param lr: 学习率
    :param N: Encoder层数
    :return: 模型
    """
    train, test = get_dataset(path, batchsize)
    # 损失函数
    lossf = nn.CrossEntropyLoss()

    # 用于位置编码的一个参数，它的大小等于  图片通道数 * (一张图片一行数据的大小//patchsize)²
    psize = (32 // patchsize) * (32 // patchsize) + 1
    channel = 3  # 图片通道数

    # 创建VIT模型
    model = VIT(channel, batchsize, psize, patchsize, emb_dim, head, device, N=N)
    # 设置优化器
    optm = torch.optim.Adam(model.parameters(), lr=lr)
    model = model.to(device)
    loss_all=[]
    acc_=[]
    for epo in range(30):
        model.train()
        for i, (data, label) in enumerate(train):
            data = data.to(device)
            label = label.to(device)
            optm.zero_grad()
            pre = model(data)[:, 0, :]
            loss = lossf(pre, label.float())
            loss.backward()
            optm.step()
            if i%60==0:
                loss_all.append(float(loss))
                print(epo,i)
                acc_.append(testacc(model, test, epo, device)) 
    with open('loss.txt','w',encoding="utf-8") as f:
        f.write(str(loss_all))
    with open('acc.txt','w',encoding="utf-8") as f:
        f.write(str(acc_))
    return model

In [11]:
batchsize = 128
patchsize = 4
path = r'C:\Users\30535\Desktop\CodeProgram\Python\deepstudy\data'

model = train(path, batchsize, patchsize, device='cuda')

0 0
0 tensor(0.0992, device='cuda:0')
0 60
0 tensor(0.2447, device='cuda:0')
0 120
0 tensor(0.3160, device='cuda:0')
0 180
0 tensor(0.3461, device='cuda:0')
0 240
0 tensor(0.3861, device='cuda:0')
0 300
0 tensor(0.4229, device='cuda:0')
0 360
0 tensor(0.4019, device='cuda:0')
1 0
1 tensor(0.4522, device='cuda:0')
1 60
1 tensor(0.4715, device='cuda:0')
1 120
1 tensor(0.4833, device='cuda:0')
1 180
1 tensor(0.4949, device='cuda:0')
1 240
1 tensor(0.4910, device='cuda:0')
1 300
1 tensor(0.4928, device='cuda:0')
1 360
1 tensor(0.5100, device='cuda:0')
2 0
2 tensor(0.5118, device='cuda:0')
2 60
2 tensor(0.5145, device='cuda:0')
2 120
2 tensor(0.5345, device='cuda:0')
2 180
2 tensor(0.5357, device='cuda:0')
2 240
2 tensor(0.5302, device='cuda:0')
2 300
2 tensor(0.5432, device='cuda:0')
2 360
2 tensor(0.5402, device='cuda:0')
3 0
3 tensor(0.5415, device='cuda:0')
3 60
3 tensor(0.5611, device='cuda:0')
3 120
3 tensor(0.5490, device='cuda:0')
3 180
3 tensor(0.5628, device='cuda:0')
3 240
3 tens

28 300
28 tensor(0.6014, device='cuda:0')
28 360
28 tensor(0.6080, device='cuda:0')
29 0
29 tensor(0.6072, device='cuda:0')
29 60
29 tensor(0.6119, device='cuda:0')
29 120
29 tensor(0.6162, device='cuda:0')
29 180
29 tensor(0.6046, device='cuda:0')
29 240
29 tensor(0.6069, device='cuda:0')
29 300
29 tensor(0.6171, device='cuda:0')
29 360
29 tensor(0.6123, device='cuda:0')


In [15]:
%matplotlib qt
import matplotlib.pyplot as plt
import re 
with open('loss.txt','r',encoding="utf-8") as f:
    data=eval(f.read())
# d=[]
# s='\d.\d+'
# for i in data:
#     aa=re.findall(s,i)[0]
#     d.append(float(aa))

plt.figure()
plt.plot([i*60 for i in range(len(data))],data)
plt.xlabel('batch_num')
plt.ylabel('loss')
plt.show()