In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import numpy as np
import random
import os

In [2]:
# 设置多GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.backends.cudnn.benchmark = True

In [3]:
# 定义数据增强策略，包含CutMix
def cutmix_data(x, y, alpha=1.0):
    lam = np.random.beta(alpha, alpha)
    rand_index = torch.randperm(x.size()[0]).to(device)
    target_a = y
    target_b = y[rand_index]
    bbx1, bby1, bbx2, bby2 = rand_bbox(x.size(), lam)
    x[:, :, bbx1:bbx2, bby1:bby2] = x[rand_index, :, bbx1:bbx2, bby1:bby2]
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (x.size()[-1] * x.size()[-2]))
    return x, target_a, target_b, lam

def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = int(W * cut_rat)
    cut_h = int(H * cut_rat)
    cx = np.random.randint(W)
    cy = np.random.randint(H)
    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)
    return bbx1, bby1, bbx2, bby2

In [4]:
# 加载CIFAR-100数据集并进行预处理
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])

In [5]:
# 定义CNN模型
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv_layers = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(2, 2)
        )
        self.fc1 = nn.Linear(256 * 4 * 4, 1024)
        self.fc2 = nn.Linear(1024, 100)

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(x.size(0), -1)  # 展平操作
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [6]:
# Transformer Block
class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        attention = self.attention(x, x, x)[0]
        x = self.dropout1(self.norm1(attention + x))
        forward = self.feed_forward(x)
        out = self.dropout2(self.norm2(forward + x))
        return out

# Transformer Classifier
class TransformerClassifier(nn.Module):
    def __init__(self, num_classes, embed_size=512, heads=8, dropout=0.1, forward_expansion=4):
        super(TransformerClassifier, self).__init__()
        self.embedding = nn.Linear(3 * 32 * 32, embed_size)
        self.transformer = TransformerBlock(embed_size, heads, dropout, forward_expansion)
        self.fc = nn.Linear(embed_size, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input
        x = self.embedding(x).unsqueeze(1)  # Add sequence dimension
        x = self.transformer(x)
        x = self.fc(x.squeeze(1))  # Remove sequence dimension
        return x

In [7]:
# 训练模型
def train_model(model, train_loader, test_loader, optimizer, criterion, epochs):
    model = nn.DataParallel(model)
    model.to(device)
    writer = SummaryWriter()

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for i, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            # 使用CutMix
            inputs, targets_a, targets_b, lam = cutmix_data(inputs, targets)
            
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = lam * criterion(outputs, targets_a) + (1 - lam) * criterion(outputs, targets_b)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            if i % 100 == 99:
                print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / 100}')
                running_loss = 0.0

        # 验证模型
        model.eval()
        correct = 0
        total = 0
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in test_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += targets.size(0)
                correct += (predicted == targets).sum().item()

        accuracy = 100 * correct / total
        print(f'Validation Accuracy: {accuracy:.2f}%')

        writer.add_scalar('Loss/train', running_loss, epoch)
        writer.add_scalar('Loss/val', val_loss / len(test_loader), epoch)
        writer.add_scalar('Accuracy/val', accuracy, epoch)

    writer.close()

In [8]:
trainset = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False, num_workers=4)

Files already downloaded and verified
Files already downloaded and verified


In [9]:
# 初始化模型和优化器
num_classes = 100

model_cnn = CNN()

criterion = nn.CrossEntropyLoss()
optimizer_cnn = optim.Adam(model_cnn.parameters(), lr=0.001)


In [10]:
train_model(model_cnn, train_loader, test_loader, optimizer_cnn, criterion, epochs=100)
torch.save(model_cnn.state_dict(), 'cnn_model.pth')


Epoch 1, Batch 100, Loss: 4.581295104026794
Epoch 1, Batch 200, Loss: 4.301905860900879
Epoch 1, Batch 300, Loss: 4.171725716590881
Validation Accuracy: 13.89%
Epoch 2, Batch 100, Loss: 4.062727067470551
Epoch 2, Batch 200, Loss: 4.010521094799042
Epoch 2, Batch 300, Loss: 3.952156367301941
Validation Accuracy: 20.73%
Epoch 3, Batch 100, Loss: 3.862959442138672
Epoch 3, Batch 200, Loss: 3.8389627861976625
Epoch 3, Batch 300, Loss: 3.8163853430747987
Validation Accuracy: 25.82%
Epoch 4, Batch 100, Loss: 3.7059045886993407
Epoch 4, Batch 200, Loss: 3.7132679438591003
Epoch 4, Batch 300, Loss: 3.690778527259827
Validation Accuracy: 30.18%
Epoch 5, Batch 100, Loss: 3.621175343990326
Epoch 5, Batch 200, Loss: 3.5585072898864745
Epoch 5, Batch 300, Loss: 3.5666079473495484
Validation Accuracy: 32.08%
Epoch 6, Batch 100, Loss: 3.5947267985343934
Epoch 6, Batch 200, Loss: 3.4393190932273865
Epoch 6, Batch 300, Loss: 3.5309542179107667
Validation Accuracy: 35.31%
Epoch 7, Batch 100, Loss: 3.480

In [11]:
# 创建模型
model_transformer = TransformerClassifier(num_classes)
# 训练参数
optimizer_transformer = optim.Adam(model_transformer.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()
train_model(model_transformer, train_loader, test_loader, optimizer_transformer, criterion, epochs=100)


  return F.linear(input, self.weight, self.bias)


Epoch 1, Batch 100, Loss: 4.489048414230346
Epoch 1, Batch 200, Loss: 4.374918823242187
Epoch 1, Batch 300, Loss: 4.319014465808868
Validation Accuracy: 9.31%
Epoch 2, Batch 100, Loss: 4.273751482963562
Epoch 2, Batch 200, Loss: 4.260625388622284
Epoch 2, Batch 300, Loss: 4.237093567848206
Validation Accuracy: 10.16%
Epoch 3, Batch 100, Loss: 4.218091058731079
Epoch 3, Batch 200, Loss: 4.2090130305290225
Epoch 3, Batch 300, Loss: 4.203894271850586
Validation Accuracy: 10.07%
Epoch 4, Batch 100, Loss: 4.169821660518647
Epoch 4, Batch 200, Loss: 4.164753770828247
Epoch 4, Batch 300, Loss: 4.180019207000733
Validation Accuracy: 10.14%
Epoch 5, Batch 100, Loss: 4.153359167575836
Epoch 5, Batch 200, Loss: 4.139689223766327
Epoch 5, Batch 300, Loss: 4.160777134895325
Validation Accuracy: 11.03%
Epoch 6, Batch 100, Loss: 4.142562098503113
Epoch 6, Batch 200, Loss: 4.170066967010498
Epoch 6, Batch 300, Loss: 4.142720885276795
Validation Accuracy: 11.35%
Epoch 7, Batch 100, Loss: 4.096260452270

In [12]:
# 保存权重
torch.save(model_transformer.state_dict(), 'transformer_model.pth')