In [2]:
import math
import os
import warnings

import pandas as pd
import torch
from PIL import Image
from sklearn.model_selection import KFold
from timm import create_model
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
class Config():
    """Run-wide settings for the fine-tuning notebook.

    The data and weights locations default to the original local Windows paths,
    but can be overridden with the ``ISICDM_DATA_ROOT`` and ``ISICDM_WEIGHTS_DIR``
    environment variables so the notebook runs on another machine unedited.
    """
    DATA_ROOT = os.environ.get('ISICDM_DATA_ROOT', r'E:\ISICDM2025\competition_image')
    CSV_PATH = os.path.join(DATA_ROOT, 'isicdm2025.csv')  # metadata table (split, crop_filename, category_id, ...)
    WEIGHTS_DIR = os.environ.get('ISICDM_WEIGHTS_DIR', r'E:\ISICDM2025')
    IMG_SIZES = [384, 512, 768]  # input resolutions trained in turn; options: 384, 512, 768
    BATCH_SIZE = 16
    NUM_EPOCHS = 200      # upper bound per fold; early stopping may end a fold earlier
    LEARNING_RATE = 5e-5
    PATIENCE = 7          # epochs without val-loss improvement before a fold stops
    RANDOM_SEED = 42      # random seed (e.g. for the K-fold split)

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
class ISICDMDataset(Dataset):
    """Grayscale image dataset driven by a metadata DataFrame.

    Each row provides ``split`` and ``crop_filename``, which are joined under
    ``Config.DATA_ROOT`` to locate the image file, plus a ``category_id`` label
    (presumably an integer class id — confirm against the CSV).

    Args:
        dataframe: metadata rows; a copy with a fresh 0..n-1 index is stored.
        transform: optional callable applied to the PIL image before returning.
    """

    def __init__(self, dataframe, transform=None):
        # Discard the caller's index; the stored frame is indexed 0..n-1.
        self.dataframe = dataframe.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        return self.dataframe.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the row at position ``idx``."""
        record = self.dataframe.iloc[idx]
        # e.g. E:\ISICDM2025\competition_image\train\ISICDM2025_000399.png
        path = os.path.join(Config.DATA_ROOT, record['split'], record['crop_filename'])
        sample = Image.open(path).convert("L")  # single-channel grayscale
        target = record['category_id']

        if self.transform:
            sample = self.transform(sample)
        return sample, target

class EarlyStopping:
    """Track validation loss, checkpoint the best epoch, and signal when to stop.

    Call the instance once per epoch with that epoch's validation loss. An epoch
    counts as an improvement when ``val_loss <= best_val_loss - delta``; the
    first finite loss always counts. On improvement, ``torch.save`` writes the
    epoch, model/optimizer state dicts and loss to ``path``. After ``patience``
    consecutive non-improving calls, ``early_stop`` becomes True. The caller is
    responsible for breaking out of its training loop.

    Args:
        patience: consecutive non-improving epochs tolerated before stopping.
        verbose: if True, print a message on every save / non-improvement.
        delta: minimum loss decrease required to count as an improvement.
        path: file the best checkpoint is written to.
    """

    def __init__(self, patience=Config.PATIENCE, verbose=False, delta=0, path='best_model.pth'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0                  # consecutive calls without improvement
        self.best_score = None            # best -val_loss so far; None until the first finite loss
        self.early_stop = False           # set once counter reaches patience
        self.val_loss_min = float('inf')  # loss of the most recently saved checkpoint
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model, optimizer, epoch):
        """Record one epoch's validation loss and checkpoint on improvement.

        Args:
            val_loss: validation loss for this epoch (a float).
            model: model whose ``state_dict`` is saved on improvement.
            optimizer: optimizer whose ``state_dict`` is saved alongside it.
            epoch: epoch index stored in the checkpoint.

        A non-finite ``val_loss`` (NaN or +/-inf) never counts as an improvement.
        """
        if not math.isfinite(val_loss):
            # NaN compares False with everything, so without this guard a NaN loss
            # would fall through to the "improved" branch. That would overwrite the
            # best checkpoint and leave best_score as NaN, so every later epoch
            # would also look like an improvement and training could never stop early.
            if self.verbose:
                print(f'Validation loss is not finite ({val_loss}); not saving.')
            self._count_no_improvement()
            return

        score = -val_loss
        if self.best_score is not None and score < self.best_score + self.delta:
            self._count_no_improvement()
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, optimizer, epoch)
            self.counter = 0

    def _count_no_improvement(self):
        """Bump the patience counter and raise ``early_stop`` once it is exhausted."""
        self.counter += 1
        if self.verbose:
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, val_loss, model, optimizer, epoch):
        """Write epoch, model/optimizer state dicts and ``val_loss`` to ``self.path``."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
        }, self.path)
        self.val_loss_min = val_loss

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return aggregate metrics.

    Args:
        model: network to train; switched to ``train()`` mode here.
        loader: iterable of ``(images, labels)`` batches (e.g. a ``DataLoader``).
        criterion: loss returning the *mean* loss of a batch (e.g.
            ``nn.CrossEntropyLoss()`` with its default ``reduction='mean'``).
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        device: device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over every sample in the epoch. The loss
        is weighted by batch size, so a short final batch does not carry the same
        weight as a full one.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    correct = 0
    total = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = labels.size(0)
        # criterion gives the batch mean; rescale to a sum so the epoch
        # average is per-sample rather than per-batch.
        loss_sum += loss.item() * batch_size
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError('train_one_epoch: loader yielded no samples')

    return loss_sum / total, 100.0 * correct / total

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating it.

    Args:
        model: network to evaluate; switched to ``eval()`` mode here.
        loader: iterable of ``(images, labels)`` batches (e.g. a ``DataLoader``).
        criterion: loss returning the *mean* loss of a batch (e.g.
            ``nn.CrossEntropyLoss()`` with its default ``reduction='mean'``).
        device: device each batch is moved to.

    Returns:
        ``(mean_loss, accuracy_percent)`` over every sample. The loss is weighted
        by batch size, so it is a true per-sample mean even when the last batch
        is short.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # criterion gives the batch mean; rescale to a sum so the final
            # average is per-sample rather than per-batch.
            loss_sum += loss.item() * batch_size
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_size

    if total == 0:
        raise ValueError('validate: loader yielded no samples')

    return loss_sum / total, 100.0 * correct / total

        
# Fine-tune an EfficientNet-B0 backbone with K-fold cross-validation at each
# input resolution in Config.IMG_SIZES, saving the best checkpoint of each fold.
N_SPLITS = 5     # number of cross-validation folds
NUM_CLASSES = 7  # output size of the new classification head

# The metadata CSV does not depend on the image size, so read it once.
df = pd.read_csv(Config.CSV_PATH)
# Normalisation statistics are computed from the original 'train' split only.
# NOTE(review): the K-fold below pools train/val/test, so these same images also
# land in validation folds; this does not keep fold validation data unseen.
train_temp_df = df[df['split'] == 'train']
# Pool every split and let K-fold re-partition it.
all_data_df = pd.concat([
    df[df['split'] == 'train'],
    df[df['split'] == 'val'],
    df[df['split'] == 'test']
], ignore_index=True)

for IMG_SIZE in Config.IMG_SIZES:
    WEIGHT_PATH = os.path.join(Config.WEIGHTS_DIR, f'{IMG_SIZE}_efficientnet_b0_expand.pth')
    # Prefix of the per-fold checkpoints, written as f'{SAVE_PATH}_fold_{k}.pth'.
    # The doubled ".pth" is kept so existing file names stay unchanged.
    SAVE_PATH = f'finetuned_all_{IMG_SIZE}_efficientnet_b0.pth'

    # Load the pretrained weights once per resolution (fails fast if missing) and
    # reuse them for every fold. load_state_dict copies values into the model, so
    # folds do not share parameter storage.
    # NOTE(review): weights_only=False unpickles arbitrary objects - only load trusted files.
    checkpoint = torch.load(WEIGHT_PATH, map_location=DEVICE, weights_only=False)
    pretrained_dict = checkpoint['model_state_dict']

    # ---- Per-resolution normalisation statistics (single grayscale channel) ----
    temp_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ])
    temp_dataset = ISICDMDataset(train_temp_df, temp_transform)
    temp_loader = DataLoader(temp_dataset, batch_size=64, pin_memory=True)

    mean = 0.
    std = 0.
    nb_samples = 0.
    for data, _ in temp_loader:
        batch_samples = data.size(0)
        data = data.view(batch_samples, data.size(1), -1)  # [B, C, H*W]
        mean += data.mean(2).sum(0)
        # NOTE(review): averaging per-image stds underestimates the pooled pixel
        # std; left unchanged so normalisation matches earlier runs.
        std += data.std(2).sum(0)
        nb_samples += batch_samples

    if nb_samples == 0:
        raise ValueError(f"No rows with split == 'train' in {Config.CSV_PATH}; cannot compute MEAN/STD")

    MEAN = [mean.item() / nb_samples]
    STD = [std.item() / nb_samples]

    train_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.3, contrast=0.3),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD)
    ])

    val_transform = transforms.Compose([
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=MEAN, std=STD)
    ])

    # ---- K-fold cross-validation ----
    kf = KFold(n_splits=N_SPLITS, shuffle=True, random_state=Config.RANDOM_SEED)
    all_metrics = []
    for fold, (train_idx, val_idx) in enumerate(kf.split(all_data_df), 1):
        print(f'====== Starting Fold {fold} ======')
        # Seed per fold so head initialisation and batch shuffling are reproducible.
        torch.manual_seed(Config.RANDOM_SEED + fold)

        # Split data for the current fold
        train_df = all_data_df.iloc[train_idx].reset_index(drop=True)
        val_df = all_data_df.iloc[val_idx].reset_index(drop=True)

        train_dataset = ISICDMDataset(train_df, train_transform)
        val_dataset = ISICDMDataset(val_df, val_transform)

        train_loader = DataLoader(train_dataset, batch_size=Config.BATCH_SIZE, shuffle=True, pin_memory=True)
        val_loader = DataLoader(val_dataset, batch_size=Config.BATCH_SIZE, shuffle=False, pin_memory=True)

        # Fresh single-channel model with a new NUM_CLASSES-way head, then copy in
        # every pretrained tensor except the old classifier.
        model = create_model('tf_efficientnet_b0.ns_jft_in1k', pretrained=False, in_chans=1)
        model.classifier = nn.Linear(model.classifier.in_features, NUM_CLASSES)

        model_dict = model.state_dict()
        filtered_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict and 'classifier' not in k  # skip the old classification head
        }
        model_dict.update(filtered_dict)
        model.load_state_dict(model_dict)
        model = model.to(DEVICE)

        # Loss, optimizer, scheduler
        criterion = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=Config.LEARNING_RATE)
        # T_max is measured in scheduler.step() calls, and step() runs once per
        # epoch below, so the cosine period must be an epoch count.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=Config.NUM_EPOCHS,
            eta_min=1e-6
        )

        early_stopping = EarlyStopping(patience=Config.PATIENCE, verbose=True, path=f'{SAVE_PATH}_fold_{fold}.pth')

        for epoch in range(Config.NUM_EPOCHS):
            print(f"\n---------- Epoch {epoch+1}/{Config.NUM_EPOCHS} ----------")

            train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, DEVICE)
            val_loss, val_acc = validate(model, val_loader, criterion, DEVICE)

            print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
            print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

            early_stopping(val_loss, model, optimizer, epoch)
            if early_stopping.early_stop:
                print(f"Early stopping triggered for Fold {fold} at epoch {epoch+1}")
                break

            scheduler.step()

        # Record the best (lowest saved) validation loss for this fold
        all_metrics.append({'fold': fold, 'best_val_loss': early_stopping.val_loss_min})
        print(f'====== Completed Fold {fold} ======\n')

    print(f"\n{IMG_SIZE} - 微调训练完成！各折最佳模型已保存至:")
    for i in range(N_SPLITS):
        print(f"  - {SAVE_PATH}_fold_{i+1}.pth")
    print("\n各折最佳验证损失:")
    for metric in all_metrics:
        print(f"  - Fold {metric['fold']}: {metric['best_val_loss']:.6f}")