# In [None]:
# =============================================================================
#           03_dinov2_GIANT_train.ipynb: DINOv2 Giant 版训练脚本
# =============================================================================

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error
from tqdm.notebook import tqdm
from PIL import Image
import os
import gc #用于清理内存

# --- Release memory left over from earlier runs (e.g. a previously loaded model) ---
# Run the garbage collector first so tensors owned by unreachable Python objects
# are freed, and only then ask PyTorch to release its cached CUDA blocks.
# Doing it the other way round leaves the just-freed blocks in the cache.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# -----------------------------------------------------------------------------
# 1. Configuration (key changes for the Giant variant)
# -----------------------------------------------------------------------------
IMAGE_DIR = './'           # directory that the relative image paths are joined onto
TRAIN_CSV = 'train.csv'    # CSV file the training table is read from
# WARNING: the Giant model is very large, so the batch size must stay very small.
# Lower it to 2 if GPU memory runs out. Do not go down to 1: the regression
# head contains BatchNorm1d, which cannot train on a single-sample batch.
BATCH_SIZE = 4
NUM_EPOCHS = 15            # author's note: Giant converges fast, 15 epochs is usually enough
LR = 1e-3                  # the backbone is frozen, so a somewhat larger LR is acceptable
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

print(f"正在使用设备: {device} | Batch Size: {BATCH_SIZE}")

# -----------------------------------------------------------------------------
# 2. 数据准备 (保持不变)
# -----------------------------------------------------------------------------
# Long-format table; the pivot below uses its image_path / target_name / target columns.
df = pd.read_csv(TRAIN_CSV)

# Regression targets, in the order the model predicts them.
target_cols = ['Dry_Clover_g', 'Dry_Dead_g', 'Dry_Green_g', 'GDM_g', 'Dry_Total_g']

# One row per image, one column per target_name; duplicate entries are averaged.
df_wide = pd.pivot_table(df, index=['image_path'], columns='target_name', values='target', aggfunc='mean').reset_index()

# Fail here with a clear message instead of a bare KeyError later when the
# dataset selects the target columns.
missing_cols = [col for col in target_cols if col not in df_wide.columns]
if missing_cols:
    raise ValueError(f"Target columns missing from {TRAIN_CSV}: {missing_cols}")

# An image that lacks a value for some target gets NaN from the pivot. NaN would
# pass through log1p into the loss and turn every gradient into NaN, so such rows
# are excluded. With complete data this is a no-op and the split is unchanged.
df_wide = df_wide.dropna(subset=target_cols).reset_index(drop=True)

# Fixed seed so the 80/20 train/validation split is reproducible.
train_df, val_df = train_test_split(df_wide, test_size=0.2, random_state=42)

# -----------------------------------------------------------------------------
# 3. 定义类：VisualModel (Giant 版)
# -----------------------------------------------------------------------------
class BiomassDataset(Dataset):
    """Image / multi-target regression dataset backed by a DataFrame.

    Targets are stored log1p-transformed, so consumers must apply ``expm1``
    to get values back on the original scale.

    Args:
        dataframe: table with an ``image_path`` column plus every column
            named in ``target_cols``.
        image_dir: directory that each ``image_path`` value is joined onto.
        target_cols: names of the columns used as regression targets.
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, dataframe, image_dir, target_cols, transform=None):
        self.df = dataframe
        self.image_dir = image_dir
        self.transform = transform
        # Transform once up front: float32 array of shape (len(df), len(target_cols)).
        self.targets = np.log1p(self.df[target_cols].values.astype(np.float32))

    def __len__(self):
        """Return the number of rows (images) in the dataset."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``(image, target)`` for the row at position ``idx``.

        ``image`` is the RGB image after ``transform`` (if one was given);
        ``target`` is a float tensor holding the log1p-transformed targets.
        """
        # Column-first access avoids materialising the whole row as a Series.
        img_path_rel = self.df['image_path'].iloc[idx]
        img_path = os.path.join(self.image_dir, img_path_rel)
        # The context manager closes the file handle deterministically;
        # convert() returns an independent in-memory image.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform:
            image = self.transform(image)
        target = torch.tensor(self.targets[idx], dtype=torch.float)
        return image, target

class DINOv2VisualModel(nn.Module):
    """Frozen DINOv2 Giant backbone followed by a small trainable regression head.

    Args:
        num_targets: number of values predicted per image (width of the
            final linear layer).
    """

    def __init__(self, num_targets=5):
        super().__init__()

        print("正在加载 DINOv2 (GIANT) 模型... 模型文件很大(>4GB)，请耐心等待...")
        # Load the Giant variant (dinov2_vitg14) through torch.hub.
        self.backbone = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14')

        # The backbone must stay frozen: the author notes that training a
        # ~1.1B-parameter model on ~300 images would certainly overfit.
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.backbone.eval()

        # Read the feature width from the backbone when it exposes it, so the
        # head stays correct if the hub entry is swapped; 1536 is the Giant width.
        self.embed_dim = getattr(self.backbone, 'embed_dim', 1536)

        # Regression head: the only trainable part of the model.
        self.head = nn.Sequential(
            nn.Linear(self.embed_dim, 512),  # hidden layer widened to suit the larger embedding
            nn.BatchNorm1d(512),             # needs batches of more than one sample in training mode
            nn.ReLU(),
            nn.Dropout(0.4),                 # extra dropout against overfitting
            nn.Linear(512, num_targets)
        )

    def train(self, mode=True):
        """Switch the head between train/eval mode; the frozen backbone stays in eval.

        Without this override every ``model.train()`` call would flip the
        backbone back into training mode even though it is never updated.
        """
        super().train(mode)
        self.backbone.eval()
        return self

    def forward(self, image):
        """Map a batch of images to ``num_targets`` predictions per image.

        The normalised CLS token from the backbone's ``forward_features``
        output is used as the image embedding.
        """
        features = self.backbone.forward_features(image)['x_norm_clstoken']
        output = self.head(features)
        return output

# -----------------------------------------------------------------------------
# 4. 初始化
# -----------------------------------------------------------------------------
def _build_transform(augment):
    """Assemble the preprocessing pipeline; ``augment`` adds the train-only steps."""
    # Author's note: Giant can handle larger inputs, but 224x224 is kept to save GPU memory.
    steps = [transforms.Resize((224, 224))]
    if augment:
        # Random augmentation is applied to the training pipeline only.
        steps.append(transforms.RandomHorizontalFlip())
        steps.append(transforms.ColorJitter(brightness=0.2, contrast=0.2))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
    return transforms.Compose(steps)


# Pipelines keyed by split name.
data_transforms = {
    'train': _build_transform(True),
    'val': _build_transform(False),
}

# Each split gets a fresh 0..n-1 index and its own transform pipeline.
train_dataset = BiomassDataset(train_df.reset_index(drop=True), IMAGE_DIR, target_cols, transform=data_transforms['train'])
val_dataset = BiomassDataset(val_df.reset_index(drop=True), IMAGE_DIR, target_cols, transform=data_transforms['val'])

# The regression head contains BatchNorm1d, which raises in training mode when a
# batch holds a single sample. That happens whenever the training split leaves a
# remainder of exactly one, so drop the last batch in that case only.
drop_singleton_batch = BATCH_SIZE > 1 and len(train_dataset) % BATCH_SIZE == 1

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0, drop_last=drop_singleton_batch)
# Validation runs in eval mode, so every sample is kept and order is preserved.
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# Build the model with one output per target column and move it to the chosen device.
model = DINOv2VisualModel(num_targets=len(target_cols))
model = model.to(device)

# Mean squared error on the (log1p-transformed) targets.
criterion = nn.MSELoss()

# Only parameters that still require gradients are handed to the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=LR, weight_decay=1e-4)

# -----------------------------------------------------------------------------
# 5. 训练循环
# -----------------------------------------------------------------------------
best_rmse = float('inf')                      # lowest validation RMSE seen so far
save_path = 'best_dinov2_giant_model.pth'     # checkpoint file for the best epoch

print(f"\n开始训练 DINOv2 GIANT 模型 (Batch Size: {BATCH_SIZE})...")

for epoch in range(NUM_EPOCHS):
    # ---------------- training pass ----------------
    model.train()
    running_loss = 0.0
    samples_seen = 0
    # The Giant model is slow, so the progress bar may advance slowly.
    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [训练]")

    for images, targets in pbar:
        images, targets = images.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # criterion returns the batch mean; weight it by batch size so the
        # epoch figure is a per-sample average.
        batch_len = images.size(0)
        running_loss += loss.item() * batch_len
        samples_seen += batch_len

    # Normalise by the samples actually iterated, not by the dataset length:
    # the two differ if the loader drops or resamples items. The max() guards
    # against division by zero when no batch was produced.
    epoch_train_loss = running_loss / max(samples_seen, 1)

    # ---------------- validation pass ----------------
    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        pbar_val = tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} [验证]")
        for images, targets in pbar_val:
            images, targets = images.to(device), targets.to(device)
            outputs = model(images)
            # Targets are log1p-transformed by the dataset, so expm1 brings both
            # predictions and targets back to the original scale.
            preds_orig = np.expm1(outputs.cpu().numpy())
            targets_orig = np.expm1(targets.cpu().numpy())
            all_preds.append(preds_orig)
            all_targets.append(targets_orig)

    # RMSE over all validation samples and all targets, on the original scale.
    val_rmse = np.sqrt(mean_squared_error(np.concatenate(all_targets), np.concatenate(all_preds)))

    print(f"Epoch {epoch+1}/{NUM_EPOCHS} -> 训练损失: {epoch_train_loss:.4f} | 验证 RMSE: {val_rmse:.4f}")

    # Keep only the checkpoint with the best validation RMSE. The state dict
    # covers the whole model, frozen backbone included.
    if val_rmse < best_rmse:
        best_rmse = val_rmse
        torch.save(model.state_dict(), save_path)
        print(f"  -> 新的最佳模型(Giant)已保存: {save_path} (RMSE: {best_rmse:.4f})")

print(f"\n--- 训练完成 ---\n最好的 DINOv2 Giant 模型 RMSE 是: {best_rmse:.4f}")
