In [None]:
import torch
import torchvision.transforms as transforms
from PIL import Image
import os
import csv
from timm import create_model

# Preprocessing applied to every test image: resize, centre-crop to the
# 224x224 input size, convert to a tensor, then normalise each channel.
test_transform = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        # Per-channel mean / std (these are the commonly used ImageNet
        # statistics; they should match what was used during training).
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Directory containing the test images; every entry in it is treated as an
# image file by the prediction loop.
test_data_path = '/kaggle/input/threetest/test'  # replace with the path to your test set

# Load the Swin Transformer model.
# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Build the architecture with a 3-class head; pretrained=False because the
# weights come from the checkpoint loaded below.
model = create_model('swin_base_patch4_window7_224', pretrained=False, num_classes=3)
# map_location remaps the stored tensors onto `device`, so a checkpoint that
# was saved from CUDA tensors can still be loaded on a CPU-only machine.
state_dict = torch.load('/kaggle/working/checkpoints/swin_best.pth', map_location=device)
model.load_state_dict(state_dict)  # load the trained weights
model.to(device)
model.eval()  # switch to evaluation mode for inference

# Run the model on every file in the test directory and save the predictions
# as a CSV with one row per image (columns: ImageName, Prediction).
output_csv_path = '/kaggle/working/predictions_swin.csv'
class_names = ['Class_0', 'Class_1', 'Class_2']  # replace with the actual class names

# Explicit UTF-8 so image names containing non-ASCII characters are written
# the same way regardless of the platform's default encoding.
with open(output_csv_path, 'w', newline='', encoding='utf-8') as csvfile:
    fieldnames = ['ImageName', 'Prediction']
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()

    # os.listdir returns entries in arbitrary order; sorting makes the row
    # order of the CSV reproducible between runs.
    for image_name in sorted(os.listdir(test_data_path)):
        image_path = os.path.join(test_data_path, image_name)
        try:
            # The context manager closes the underlying file as soon as the
            # RGB copy exists, instead of leaving the handle to the garbage
            # collector.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')
            # Apply the preprocessing and add the batch dimension.
            image_tensor = test_transform(image).unsqueeze(0).to(device)

            # Inference only: no gradients are needed.
            with torch.no_grad():
                outputs = model(image_tensor)
                # Index of the highest-scoring class along dimension 1.
                _, prediction = torch.max(outputs, 1)

            writer.writerow({'ImageName': image_name, 'Prediction': class_names[prediction.item()]})
        except Exception as e:
            # Best effort: report the failing entry and carry on with the
            # remaining files; the failed entry gets no row in the CSV.
            print(f"Error processing {image_name}: {e}")

print(f'Predictions saved to {output_csv_path}')
