In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torchvision.models import resnet50  
from torch.utils.data import DataLoader, SubsetRandomSampler 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir='runs/result')

In [3]:
# 1. 设置设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 2. 加载并预处理CIFAR-100数据集
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # ViT期望的输入尺寸
    transforms.ToTensor(),
    transforms.Normalize(0.5, 0.5)
])

In [11]:
type(trainset)

torchvision.datasets.cifar.CIFAR100

In [22]:
trainset = torchvision.datasets.CIFAR100(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True)

testset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False)

Files already downloaded and verified
Files already downloaded and verified


In [23]:
# 假设我们只想要前50%的样本  
num_samples = len(trainset)  
indices = list(range(num_samples))  
split = int(0.05 * num_samples)  
subset_indices = indices[:split]  
  
# 使用SubsetRandomSampler来随机选择样本（但在这里我们只是按顺序选择）  
subset_sampler = SubsetRandomSampler(subset_indices)  
  
num_test_samples = len(testset)
indices = list(range(num_test_samples))  
split = int(0.02 * num_test_samples)  
test_subset_indices = indices[:split] 

test_subset_sampler = SubsetRandomSampler(test_subset_indices)  

# 创建DataLoader  
trainloader = DataLoader(trainset, batch_size=64, sampler=subset_sampler) 
testloader = DataLoader(testset, batch_size=64, sampler=test_subset_sampler) 

In [24]:
# 3.1.1 定义ViT模型
weights = ViT_B_16_Weights.DEFAULT
model = vit_b_16(weights=weights)
model.heads[0] = nn.Linear(model.heads[0].in_features, 100)  # 修改分类头为100类

# 如果有可用的GPU，则将模型转到GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

# 3.1.2. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [4]:
# 4.1.1 定义CNN模型  
class SimpleCNN(nn.Module):  
    def __init__(self, num_classes=100):  
        super(SimpleCNN, self).__init__()  
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)  
        self.relu = nn.ReLU(inplace=True)  
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)  
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)  
        self.fc = nn.Linear(64 * 7 * 7, num_classes)  # 假设输入图像是224x224，经过两次卷积和池化后，特征图大小为7x7  
  
    def forward(self, x):  
        x = self.conv1(x)  
        x = self.relu(x)  
        x = self.maxpool(x)  
        x = self.conv2(x)  
        x = self.relu(x)  
        x = self.maxpool(x)  
        x = x.view(x.size(0), -1)  # 展平特征图  
        x = self.fc(x)  
        return x  
# 实例化模型  
cnn_model = SimpleCNN(num_classes=100)  
  
# 如果有可用的GPU，则将模型转到GPU  
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  
cnn_model.to(device)  
  
# 4.1.2 定义损失函数和优化器  
cnn_criterion = nn.CrossEntropyLoss()  
cnn_optimizer = optim.Adam(cnn_model.parameters(), lr=0.001)

In [26]:
# 5. 训练模型
val_acc_list = []
train_loss_list = []
val_loss_list = []
for epoch in range(10):  # 遍历数据集多次
    train_loss = 0 
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        cnn_optimizer.zero_grad()
        
        outputs = model(inputs)
        cnn_outputs = cnn_model(inputs)
        
        loss = criterion(outputs, labels)
        cnn_loss = cnn_criterion(cnn_outputs, labels)
        
        loss.backward()
        cnn_loss.backward()
        
        optimizer.step()
        cnn_optimizer.step()

        running_loss += loss.item()
        train_loss   += loss.item()
        if i % 3 == 2:  # 每20个批次打印一次
            print("epoch:"+f'{epoch + 1}, batch:{i + 1}, loss: {running_loss / 10:.3f}')
            running_loss = 0.0
        # print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss}')

    train_loss_list.append(train_loss / len(trainloader))
    val_loss = 0
    # 6. 评估模型
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_acc_list.append(correct / total)
    
    writer.add_scalar(tag="train_accuracy", 
                      scalar_value=correct / total, 
                      global_step=epoch  
                      )
    
    writer.add_scalar(tag="train_loss", 
                      scalar_value=train_loss / len(trainloader),  
                      global_step=epoch  
                      )
    writer.add_scalar(tag="loss",
                      scalar_value=val_loss,  
                      global_step=epoch  
                      )
    train_loss = 0
print('Finished Training')

epoch:1, batch:3, loss: 1.454
epoch:1, batch:6, loss: 1.469
epoch:1, batch:9, loss: 1.413
epoch:1, batch:12, loss: 1.432
epoch:1, batch:15, loss: 1.415
epoch:1, batch:18, loss: 1.391
epoch:1, batch:21, loss: 1.389
epoch:1, batch:24, loss: 1.401
epoch:1, batch:27, loss: 1.388
epoch:1, batch:30, loss: 1.387
epoch:1, batch:33, loss: 1.392
epoch:1, batch:36, loss: 1.392
epoch:1, batch:39, loss: 1.404
epoch:2, batch:3, loss: 1.373
epoch:2, batch:6, loss: 1.368
epoch:2, batch:9, loss: 1.379
epoch:2, batch:12, loss: 1.373
epoch:2, batch:15, loss: 1.367
epoch:2, batch:18, loss: 1.369
epoch:2, batch:21, loss: 1.373
epoch:2, batch:24, loss: 1.378
epoch:2, batch:27, loss: 1.364
epoch:2, batch:30, loss: 1.370
epoch:2, batch:33, loss: 1.393
epoch:2, batch:36, loss: 1.361
epoch:2, batch:39, loss: 1.380
epoch:3, batch:3, loss: 1.340
epoch:3, batch:6, loss: 1.369
epoch:3, batch:9, loss: 1.357
epoch:3, batch:12, loss: 1.354
epoch:3, batch:15, loss: 1.369
epoch:3, batch:18, loss: 1.363
epoch:3, batch:21