# In [None]:
import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from timm import create_model
from tensorboardX import SummaryWriter

# Switch into the dataset root (every path below is absolute anyway).
DATA_ROOT = '/kaggle/input/threeclass/Three'
os.chdir(DATA_ROOT)

# Standard ImageNet channel mean/std, shared by both pipelines.
_imagenet_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

# Training-time augmentation: resize, random resized crop to 224, horizontal
# flip, random rotation (+/-15 degrees) and scaling (0.8-1.5x), then tensor
# conversion and normalization.
train_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomAffine(degrees=15, scale=(0.8, 1.5)),
    transforms.ToTensor(),
    _imagenet_normalize,
])

# Evaluation-time preprocessing is deterministic: resize + center crop.
val_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    _imagenet_normalize,
])

# One ImageFolder per split (labels are taken from the class sub-directory
# names); only the training loader shuffles.
trainset = datasets.ImageFolder(root=os.path.join(DATA_ROOT, 'train'), transform=train_transform)
valset = datasets.ImageFolder(root=os.path.join(DATA_ROOT, 'val'), transform=val_transform)

train_loader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=4)
val_loader = DataLoader(valset, batch_size=64, shuffle=False, num_workers=4)

# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Swin-B (patch 4, window 7, 224px input) from timm, loaded with pretrained
# weights and a 3-class output head.
model = create_model('swin_base_patch4_window7_224', pretrained=True, num_classes=3)
model.to(device)

# Multi-class cross-entropy loss, optimized with AdamW (decoupled weight decay).
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=1e-4, weight_decay=0.05)

# TensorBoard event files are written under /kaggle/working/runs.
writer = SummaryWriter(log_dir='/kaggle/working/runs')

# Best validation accuracy (in percent) seen so far; the training loop
# raises it whenever a new best checkpoint is saved.
BEST_VAL_ACC = 0.0

# Training loop: train, evaluate, log to TensorBoard, keep the best checkpoint.
NUM_EPOCHS = 100
LOG_EVERY = 20  # print running training stats every this many iterations
CHECKPOINT_DIR = '/kaggle/working/checkpoints'


def _train_one_epoch(epoch):
    """Train ``model`` for one full pass over ``train_loader``.

    Prints the running loss/accuracy every ``LOG_EVERY`` iterations and
    returns ``(mean_loss, accuracy_percent)`` computed over every sample
    seen in the epoch.
    """
    model.train()
    epoch_loss_sum = 0.0  # sum of per-sample losses over the whole epoch
    epoch_correct = 0
    epoch_total = 0
    run_loss = 0.0        # sum of per-batch mean losses in the current window
    run_correct = 0
    run_total = 0

    for i, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)  # forward pass
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        batch_loss = loss.item()
        _, preds = torch.max(outputs, 1)
        batch_correct = (preds == labels).sum().item()

        # nn.CrossEntropyLoss() averages over the batch by default, so weight
        # by the batch size to get an exact per-sample mean for the epoch even
        # when the final batch is smaller.
        epoch_loss_sum += batch_loss * batch_size
        epoch_correct += batch_correct
        epoch_total += batch_size

        run_loss += batch_loss
        run_correct += batch_correct
        run_total += batch_size  # real sample count, not LOG_EVERY * last batch

        if i % LOG_EVERY == LOG_EVERY - 1:
            print(f'Epoch {epoch}, Iter {i}, Loss: {run_loss / LOG_EVERY:.4f}, '
                  f'Accuracy: {100 * run_correct / run_total:.2f}%')
            run_loss = 0.0
            run_correct = 0
            run_total = 0

    return epoch_loss_sum / epoch_total, epoch_correct / epoch_total * 100


@torch.no_grad()
def _evaluate():
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Returns ``(mean_loss, accuracy_percent)`` over the whole validation set.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    for images, labels in val_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)

        loss_sum += criterion(outputs, labels).item() * labels.size(0)
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return loss_sum / total, correct / total * 100


try:
    for epoch in range(NUM_EPOCHS):
        start_time = time.time()

        # Log the true epoch-mean training loss and accuracy.
        train_loss, train_accuracy = _train_one_epoch(epoch)
        print(f'Epoch {epoch} Train Loss: {train_loss:.4f}, Train Accuracy: {train_accuracy:.4f}%')
        writer.add_scalar('Loss/train', train_loss, epoch)
        writer.add_scalar('Accuracy/train', train_accuracy, epoch)

        val_loss, val_accuracy = _evaluate()
        print(f'Epoch {epoch} Validation Accuracy: {val_accuracy:.4f}%')
        writer.add_scalar('Loss/val', val_loss, epoch)
        writer.add_scalar('Accuracy/val', val_accuracy, epoch)

        # Keep only the weights with the highest validation accuracy so far.
        if val_accuracy > BEST_VAL_ACC:
            BEST_VAL_ACC = val_accuracy
            os.makedirs(CHECKPOINT_DIR, exist_ok=True)
            torch.save(model.state_dict(), os.path.join(CHECKPOINT_DIR, 'swin_best.pth'))

        minutes, seconds = divmod(time.time() - start_time, 60)
        print(f'Epoch finished in: {int(minutes)}m {seconds:.1f}s')
finally:
    # Flush buffered TensorBoard events even if training is interrupted.
    writer.close()

print(f'Training complete. Best validation accuracy: {BEST_VAL_ACC:.4f}%')
