## 1. 設定

In [None]:
from google.colab import drive
import zipfile
import os

# Mount Google Drive so the dataset archive can be read from the Colab VM.
drive.mount("/content/drive")

# Path of the zip archive (e.g. where it was uploaded on Google Drive).
zip_path = "/content/drive/MyDrive/algonauts_2023_challenge_data.zip"  # change this

# Directory the archive is extracted into.
extract_path = "/content/drive/MyDrive/"  # change this

# Extract the archive when it exists; otherwise only report the missing path.
if not os.path.exists(zip_path):
    print(f"Zip file not found at: {zip_path}")
else:
    print(f"Extracting {zip_path} to {extract_path}")
    with zipfile.ZipFile(zip_path, mode="r") as archive:
        archive.extractall(path=extract_path)
    print("Extraction complete!")

In [None]:
# =============================================================================
# Training settings - edit as needed
# =============================================================================

# Training mode: "dummy" | "light" | "standard"
TRAIN_MODE = "dummy"

# Subject ID
SUBJECT = "subj01"

# Dataset root (on Google Drive)
DATA_ROOT = "/content/drive/MyDrive/algonauts_2023_challenge_data"

# Directory checkpoints are saved to
CHECKPOINT_DIR = "/content/drive/MyDrive/mindeye_checkpoints"

# Existing trained checkpoint for transfer learning (None skips transfer learning)
PRETRAINED_CKPT = None  # e.g. "/content/drive/MyDrive/train_logs/multisubject_subj01_1024hid_nolow_300ep"

# =============================================================================
# Per-mode settings (derived automatically from TRAIN_MODE)
# =============================================================================
# One preset per supported mode. Keeping them in a table means a mistyped
# TRAIN_MODE is rejected instead of silently running the "standard" preset.
_MODE_PRESETS = {
    "dummy": {
        "hidden_dim": 256,
        "batch_size": 2,
        "num_epochs": 1,
        "use_prior": False,
        "blurry_recon": False,
        "dummy_mode": True,
    },
    "light": {
        "hidden_dim": 512,
        "batch_size": 4,
        "num_epochs": 10,
        "use_prior": False,
        "blurry_recon": False,
        "dummy_mode": False,
    },
    "standard": {
        "hidden_dim": 1024,
        "batch_size": 8,
        "num_epochs": 50,
        "use_prior": True,
        "blurry_recon": False,
        "dummy_mode": False,
    },
}

if TRAIN_MODE not in _MODE_PRESETS:
    raise ValueError(
        f"Unknown TRAIN_MODE {TRAIN_MODE!r}; expected one of {sorted(_MODE_PRESETS)}"
    )

_preset = _MODE_PRESETS[TRAIN_MODE]
HIDDEN_DIM = _preset["hidden_dim"]
BATCH_SIZE = _preset["batch_size"]
NUM_EPOCHS = _preset["num_epochs"]
USE_PRIOR = _preset["use_prior"]
BLURRY_RECON = _preset["blurry_recon"]
DUMMY_MODE = _preset["dummy_mode"]

print(f"Mode: {TRAIN_MODE}")
print(f"Hidden Dim: {HIDDEN_DIM}")
print(f"Batch Size: {BATCH_SIZE}")
print(f"Epochs: {NUM_EPOCHS}")
print(f"Dummy Mode: {DUMMY_MODE}")

## 2. 環境構築

In [None]:
# Check the GPU available to this runtime
!nvidia-smi

In [None]:
# Install the dependency packages
!pip install -q torch torchvision einops omegaconf h5py tqdm accelerate

# Install the additional packages only when not running in dummy mode
if not DUMMY_MODE:
    !pip install -q open_clip_torch diffusers transformers kornia webdataset dalle2_pytorch

In [None]:
# Clone the repository (skipped when the directory already exists)
import os
if not os.path.exists("/content/MindEyeV2"):
    # TODO: change this to the URL of your own fork
    !git clone https://github.com/YOUR_USERNAME/MindEyeV2.git /content/MindEyeV2
else:
    print("Repository already exists")

# Add the source directories to the import path. Both are inserted at index 0,
# so "src" ends up ahead of "mysrc" in sys.path.
import sys
sys.path.insert(0, "/content/MindEyeV2/mysrc")
sys.path.insert(0, "/content/MindEyeV2/src")

# Run the rest of the notebook from the repository root
os.chdir("/content/MindEyeV2")
print(f"Working directory: {os.getcwd()}")

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount("/content/drive")

# Check that the data for the selected subject exists
import os
subj_dir = os.path.join(DATA_ROOT, SUBJECT)
if os.path.exists(subj_dir):
    print(f"✓ Data found: {subj_dir}")
    !ls -la {subj_dir}
else:
    # Only reports the problem; later cells will still run against the missing path
    print(f"✗ Data NOT found: {subj_dir}")
    print("Please upload Algonauts2023 data to Google Drive")

## 3. データ読み込み

In [None]:
import torch
import numpy as np
from tqdm.auto import tqdm

# Set dummy mode through an environment variable.
# NOTE(review): presumably read by the mysrc modules at import time, which is
# why it is set before the imports below — confirm against config.py.
os.environ["MINDEYE_DUMMY"] = "1" if DUMMY_MODE else "0"

# Import the mysrc modules
from algonauts_dataset import AlgonautsDataset, get_dataloader, get_total_vertices
from config import print_config, DEVICE

print_config()
print(f"\nDevice: {DEVICE}")
print(f"Total vertices for {SUBJECT}: {get_total_vertices(SUBJECT)}")

In [None]:
# Build the training data loader
_loader_kwargs = {
    "data_root": DATA_ROOT,
    "subject": SUBJECT,
    "split": "train",
    "batch_size": BATCH_SIZE,
    "shuffle": True,
    "num_workers": 2,
}
train_loader = get_dataloader(**_loader_kwargs)

print(f"\nDataset size: {len(train_loader.dataset)}")
print(f"Number of batches: {len(train_loader)}")

# Fetch one sample batch and report the tensor shapes it contains
sample_batch = next(iter(train_loader))
print("\nSample batch:")
_fmri_shape = sample_batch['fmri'].shape
_image_shape = sample_batch['image'].shape
print(f"  fMRI shape: {_fmri_shape}")
print(f"  Image shape: {_image_shape}")

## 4. モデル作成

In [None]:
from models_algonauts import AlgonautsMindEye, create_algonauts_model
from transfer_utils import (
    load_pretrained_without_ridge,
    freeze_layers,
    get_trainable_params,
    print_parameter_summary,
)

# Create the model from the settings chosen in the configuration cell
_model_kwargs = {
    "subjects": [SUBJECT],
    "hidden_dim": HIDDEN_DIM,
    "seq_len": 1,
    "n_blocks": 4,
    "use_prior": USE_PRIOR,
    "blurry_recon": BLURRY_RECON,
    "device": DEVICE,
}
model = create_algonauts_model(**_model_kwargs)

print("Model created!")
print_parameter_summary(model)

In [None]:
# Transfer learning (only when an existing checkpoint is configured)
if PRETRAINED_CKPT:
    # A configured but missing checkpoint previously fell through to the
    # "training from scratch" branch with a misleading message. Fail loudly so
    # a wrong path is noticed before a long training run starts.
    if not os.path.exists(PRETRAINED_CKPT):
        raise FileNotFoundError(
            f"PRETRAINED_CKPT is set but does not exist: {PRETRAINED_CKPT}"
        )

    print(f"Loading pretrained weights from: {PRETRAINED_CKPT}")
    loaded, missing = load_pretrained_without_ridge(
        model,
        PRETRAINED_CKPT,
        freeze_backbone=True,
        freeze_prior=True,
    )
    print(f"\nAfter transfer learning:")
    print_parameter_summary(model)
else:
    # When training from scratch the backbone is not frozen either
    print("No pretrained checkpoint specified. Training from scratch.")

## 5. CLIP特徴の準備

In [None]:
# Choose the CLIP image embedder: real OpenCLIP unless dummy mode is active.
if not DUMMY_MODE:
    # Use the real CLIP
    try:
        import open_clip

        # Load ViT-bigG-14 (heavy)
        print("Loading OpenCLIP ViT-bigG-14... (this may take a while)")
        clip_model, _, preprocess = open_clip.create_model_and_transforms(
            "ViT-bigG-14", pretrained="laion2b_s39b_b160k"
        )
        clip_model = clip_model.to(DEVICE).eval()

        # Freeze every CLIP parameter
        for param in clip_model.parameters():
            param.requires_grad = False

        print("OpenCLIP loaded!")
    except Exception as e:
        # NOTE(review): any failure above overwrites DUMMY_MODE, so the rest of
        # the notebook continues with the dummy embedder instead of real CLIP.
        print(f"Failed to load OpenCLIP: {e}")
        print("Falling back to dummy mode")
        DUMMY_MODE = True
        from dummy_models import DummyCLIPImageEmbedder
        clip_embedder = DummyCLIPImageEmbedder().to(DEVICE)
else:
    # Dummy mode: use the dummy CLIP embedder
    from dummy_models import DummyCLIPImageEmbedder, get_dummy_clip_features

    clip_embedder = DummyCLIPImageEmbedder().to(DEVICE)
    print("Using DummyCLIPImageEmbedder")

In [None]:
def get_clip_features(images):
    """Extract CLIP features from a batch of images without tracking gradients.

    Uses the notebook-level ``clip_embedder`` when ``DUMMY_MODE`` is true and
    ``clip_model.encode_image`` (OpenCLIP) otherwise.
    """
    with torch.no_grad():
        if not DUMMY_MODE:
            # Real OpenCLIP image encoder
            return clip_model.encode_image(images)
        return clip_embedder(images)

# Smoke test: embed the images of the sample batch and report the output shape
test_images = sample_batch['image'].to(DEVICE)
test_features = get_clip_features(test_images)
print(f"CLIP features shape: {test_features.shape}")

## 6. 学習ループ

In [None]:
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR

# Optimizer (over the trainable parameters only)
trainable_params = get_trainable_params(model, mode="all_unfrozen")
print(f"Trainable parameters: {sum(p.numel() for p in trainable_params):,}")

optimizer = AdamW(params=trainable_params, lr=3e-4, weight_decay=1e-2)

# Scheduler: one-cycle schedule spanning every optimisation step of the run
total_steps = NUM_EPOCHS * len(train_loader)
scheduler = OneCycleLR(optimizer, max_lr=3e-4, total_steps=total_steps, pct_start=0.1)

In [None]:
def soft_clip_loss(preds, targets, temp=0.006):
    """Symmetric CLIP-style contrastive loss between predictions and targets.

    Both inputs are flattened to ``(batch, dim)`` and L2-normalised, so the
    logits are cosine similarities divided by ``temp``. Row ``i`` of ``preds``
    is paired with row ``i`` of ``targets``: the labels are the hard diagonal
    indices, and cross-entropy is averaged over both matching directions.

    Args:
        preds: Tensor whose first axis is the batch axis.
        targets: Tensor with the same batch size and the same number of
            elements per sample as ``preds``.
        temp: Temperature the similarity logits are divided by.

    Returns:
        A scalar tensor holding the averaged loss.
    """
    # reshape (not view) so non-contiguous inputs, e.g. the result of a
    # transpose/permute, are flattened instead of raising a RuntimeError.
    preds = preds.reshape(preds.shape[0], -1)
    targets = targets.reshape(targets.shape[0], -1)

    # Normalise so the matrix product below is a cosine similarity
    preds = F.normalize(preds, dim=-1)
    targets = F.normalize(targets, dim=-1)

    # Pairwise similarity logits; the matching pair sits on the diagonal
    logits = (preds @ targets.T) / temp
    labels = torch.arange(len(logits), device=logits.device)

    loss_i = F.cross_entropy(logits, labels)
    loss_t = F.cross_entropy(logits.T, labels)

    return (loss_i + loss_t) / 2

def train_step(batch):
    """Run one optimisation step on a single batch and return the loss as a float.

    Relies on the notebook-level ``model``, ``optimizer``, ``scheduler``,
    ``trainable_params`` and ``DEVICE``.
    """
    model.train()

    # Move the inputs to the training device
    voxels = batch['fmri'].to(DEVICE)
    pictures = batch['image'].to(DEVICE)

    # CLIP features of the images are the regression target (no gradients)
    with torch.no_grad():
        clip_target = get_clip_features(pictures)

    # Forward pass; only the CLIP-space output is used for the loss
    _, clip_voxels, _ = model(voxels)

    # Per the original notes clip_voxels is (batch, seq, emb_dim); a 2-D target
    # gets a singleton axis inserted at position 1 to line up with it
    if clip_target.dim() == 2:
        clip_target = clip_target.unsqueeze(1)

    loss = soft_clip_loss(clip_voxels, clip_target)

    # Backward pass with gradient-norm clipping, then optimizer and scheduler steps
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(trainable_params, 1.0)
    optimizer.step()
    scheduler.step()

    return loss.item()

In [None]:
# Report current GPU memory usage (skipped when no CUDA device is available)
if torch.cuda.is_available():
    _allocated_gb = torch.cuda.memory_allocated() / 1024**3
    _reserved_gb = torch.cuda.memory_reserved() / 1024**3
    print(f"GPU Memory: {_allocated_gb:.2f} GB allocated")
    print(f"GPU Memory: {_reserved_gb:.2f} GB reserved")

In [None]:
# Training loop
_rule = "=" * 60
print(f"\n{_rule}")
print(f"Starting training: {NUM_EPOCHS} epochs, {len(train_loader)} batches/epoch")
print(f"{_rule}\n")

# Per-step losses accumulated across all epochs
losses = []

for epoch in range(NUM_EPOCHS):
    _epoch_tag = f"Epoch {epoch+1}/{NUM_EPOCHS}"
    epoch_losses = []

    pbar = tqdm(train_loader, desc=_epoch_tag)
    for batch in pbar:
        loss = train_step(batch)
        epoch_losses.append(loss)
        pbar.set_postfix({"loss": f"{loss:.4f}"})

    losses.extend(epoch_losses)
    avg_loss = np.mean(epoch_losses)

    print(f"{_epoch_tag} - Average Loss: {avg_loss:.4f}")

    # Report GPU memory after each epoch
    if torch.cuda.is_available():
        _allocated_gb = torch.cuda.memory_allocated() / 1024**3
        print(f"  GPU Memory: {_allocated_gb:.2f} GB")

print(f"\n{_rule}")
print("Training complete!")
print(f"{_rule}")

In [None]:
# Plot the per-step training loss
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4))
plt.plot(losses)
plt.title("Training Loss")
plt.xlabel("Step")
plt.ylabel("Loss")
plt.grid(True)
plt.show()

## 7. チェックポイント保存

In [None]:
from transfer_utils import save_checkpoint

# Create the output directory for this subject / mode combination
save_dir = os.path.join(CHECKPOINT_DIR, f"algonauts_{SUBJECT}_{TRAIN_MODE}")
os.makedirs(save_dir, exist_ok=True)

# Save the checkpoint together with a little run metadata
_last_ckpt_path = os.path.join(save_dir, "last.pth")
_run_info = {
    "train_mode": TRAIN_MODE,
    "subject": SUBJECT,
    "hidden_dim": HIDDEN_DIM,
    # None when no training step was recorded
    "final_loss": losses[-1] if losses else None,
}
save_checkpoint(
    model=model,
    optimizer=optimizer,
    epoch=NUM_EPOCHS,
    save_path=_last_ckpt_path,
    extra_info=_run_info,
)

print(f"\nCheckpoint saved to: {save_dir}")

## 8. 簡易検証

In [None]:
# Inference smoke test (note: the batch is drawn from the training loader)
model.eval()

with torch.no_grad():
    test_batch = next(iter(train_loader))
    test_fmri = test_batch['fmri'].to(DEVICE)
    test_images = test_batch['image'].to(DEVICE)

    # fMRI -> predicted CLIP tokens
    backbone, clip_voxels, blurry = model(test_fmri)

    # CLIP features of the actual images
    clip_target = get_clip_features(test_images)
    if clip_target.dim() == 2:
        clip_target = clip_target.unsqueeze(1)

    # Cosine similarity between each prediction and its own target
    pred_flat = clip_voxels.view(clip_voxels.shape[0], -1)
    pred_flat = F.normalize(pred_flat, dim=-1)
    target_flat = clip_target.view(clip_target.shape[0], -1)
    target_flat = F.normalize(target_flat, dim=-1)

    similarity = (pred_flat * target_flat).sum(dim=-1).mean()

    print("\nInference test:")
    print(f"  Input fMRI shape: {test_fmri.shape}")
    print(f"  Output CLIP shape: {clip_voxels.shape}")
    print(f"  Average cosine similarity: {similarity.item():.4f}")

In [None]:
# 入力画像の可視化
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

def denormalize(tensor):
    """Undo ImageNet mean/std normalisation on an image tensor.

    The per-channel statistics are shaped (3, 1, 1), so they broadcast over
    any tensor whose third-from-last axis holds the 3 colour channels.
    """
    channel_shape = (3, 1, 1)
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(channel_shape)
    imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(channel_shape)
    imagenet_mean = imagenet_mean.to(tensor.device)
    imagenet_std = imagenet_std.to(tensor.device)
    return imagenet_std * tensor + imagenet_mean

# Display the first four images of the batch side by side
sample_images = denormalize(test_images[:4])
grid = make_grid(sample_images, nrow=4)
# Channels-last layout on the CPU, as expected by imshow
grid = grid.permute(1, 2, 0).cpu().numpy()
grid = np.clip(grid, 0, 1)

plt.figure(figsize=(12, 3))
plt.imshow(grid)
plt.axis("off")
plt.title("Sample Training Images")
plt.show()

## 次のステップ

1. **ダミーモードで動作確認** → エラーなく完了すればOK
2. **軽量モード（light）で実学習** → T4で数時間
3. **標準モード（standard）で本格学習** → Colab Pro または研究室PC
4. **推論ノートブック** → `mindeye_inference_colab.ipynb` で画像再構成