<a href="https://colab.research.google.com/github/ocean8800v/manuscript-pipeline/blob/main/2_finetune_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ==============================
# 🔧 Setup Python 3.11 + TabPFN env in Colab
# ==============================

# 1. Install Conda (via condacolab)
!pip install -q condacolab
import condacolab
condacolab.install()

# 2. Create a new Conda environment with Python 3.11
!conda create -n py311 python=3.11.13 -y -c conda-forge

# 3. Install PyTorch (2.1.0 + cu121) and all required dependencies inside py311
!conda run -n py311 pip install \
    pandas==2.2.2 \
    scikit-learn==1.5.2 \
    scipy==1.15.3 \
    tabpfn==2.0.9 \
    ruff==0.12.6 \
    pynvml==12.0.0 \
    wandb==0.21.0 \
    schedulefree==1.4.1 \
    seaborn==0.12.2 \
    tqdm==4.67.1 \
    pyyaml==6.0.2\
    numpy==1.24.1
!conda run -n py311 pip install torch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121

!conda run -n py311 python -c "import sys, torch, sklearn, numpy, scipy, tabpfn, pandas; \
print(f'Python: {sys.version}'); \
print(f'Torch: {torch.__version__}'); \
print(f'scikit-learn: {sklearn.__version__}'); \
print(f'Numpy: {numpy.__version__}'); \
print(f'Pandas: {pandas.__version__}'); \
print(f'SciPy: {scipy.__version__}'); \
print(f'TabPFN: {tabpfn.__version__}')"

In [None]:
from google.colab import drive
drive.mount('/content/drive')
!git clone https://github.com/ocean8800v/cross-national-cognitive-ai.git

In [None]:
%%writefile /content/finetune_experiment.py
import os, subprocess, sys
import pandas as pd
from sklearn.model_selection import train_test_split

sys.path.append('/content/manuscript-pipeline/finetune_experiment')
from finetuning_scripts.finetune_tabpfn_main import fine_tune_tabpfn

import torch
sys.stdout.reconfigure(line_buffering=True)  # ensure real-time log flushing

# Show the GPU status in the log (best effort: a failing nvidia-smi is ignored),
# then refuse to continue without CUDA.
subprocess.run(["nvidia-smi"], check=False)
if not torch.cuda.is_available():
    raise SystemError('GPU device not found. For fast training, please enable GPU.')

torch.cuda.empty_cache()
# Restrict scaled-dot-product attention to the flash backend only; the
# memory-efficient and math backends are switched off.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(False)
torch.backends.cuda.enable_math_sdp(False)

# ========== Fine-tuning Experiment ==========
# Feature columns passed to the model as X_train.
# NOTE(review): "id" is included as a predictor — confirm it is a meaningful
# feature rather than a row identifier.
predictors = [
    "asinmiss", "logresvar", "mdistance", "gnorm", "lz", "u3", "extreme",
    "language", "country", "scale", "id"
]

# Label column passed to the model as y_train.
target_col = 'health_vs_mci_vs_dementia'

# Cohort name -> integer code; the codes are matched against `country_col`
# when sub-sampling the fine-tuning data.
country_codes = {'HRS': 1, 'ELSA': 2, 'LASI': 3, 'MHAS': 4, 'CHARLS': 5}
scale_col = 'scale'
country_col = 'country'

# Base checkpoint that stage-1 fine-tuning starts from.
ckpt_path = "/content/drive/MyDrive/manuscript/5_weights/default_tabpfn-v2-classifier_weights.ckpt"
assert os.path.exists(ckpt_path), "Base model checkpoint not found. Please verify the path."

data_path = "/content/drive/MyDrive/manuscript/3_split_data/stacked_finetune.csv"
weights_output_path = "/content/drive/MyDrive/manuscript/5_weights/Finetuned_Weights"
os.makedirs(weights_output_path, exist_ok=True)

stacked_finetune = pd.read_csv(data_path)
print("Loaded fine-tuning data", flush=True)

# Fail fast with a clear message instead of a late KeyError inside the loop.
missing_columns = [
    col for col in [*predictors, target_col] if col not in stacked_finetune.columns
]
if missing_columns:
    raise KeyError(f"Fine-tuning data is missing required columns: {missing_columns}")

# ========== Fine-tuning Parameters ==========
# One entry per experiment:
#   n             - percentage of the data to use, or "all" for the full set
#   batch_size    - forwarded in finetuning_config
#   max_time      - forwarded as time_limit (logged with an "s" suffix)
#   learning_rate - forwarded in finetuning_config
finetune_configurations = [
    {"n": 25, "batch_size": 10, "max_time": 3000, "learning_rate": 0.0001},
    {"n": 50, "batch_size": 10, "max_time": 5000, "learning_rate": 0.0001},
    {"n": 75, "batch_size": 10, "max_time": 7000, "learning_rate": 0.0001},
    {"n": "all", "batch_size": 10, "max_time": 9000, "learning_rate": 0.0001}
]

# ========== Fine-tuning Helpers ==========
def _run_finetune_stage(train_df, base_model_path, output_model_path, task_type, config):
    """Run one fine_tune_tabpfn stage with the settings shared by both stages.

    Args:
        train_df: DataFrame holding the `predictors` columns and `target_col`.
        base_model_path: Checkpoint passed as ``path_to_base_model``.
        output_model_path: Path passed as ``save_path_to_fine_tuned_model``.
        task_type: Value forwarded as ``task_type`` ("binary" or "multiclass").
        config: One entry of `finetune_configurations`; supplies ``max_time``,
            ``learning_rate`` and ``batch_size``.
    """
    # Arguments are rebuilt on every call so the two stages never share
    # mutable objects (config dict, feature frame).
    fine_tune_tabpfn(
        path_to_base_model=base_model_path,
        save_path_to_fine_tuned_model=output_model_path,
        time_limit=config["max_time"],
        finetuning_config={
            "learning_rate": config["learning_rate"],
            "batch_size": config["batch_size"]
        },
        small_improvement_threshold=0.003,
        small_improvement_patience=25,
        validation_metric="log_loss",
        # NOTE(review): no categorical indices are declared although the
        # predictors include "language", "country" and "scale" — confirm intended.
        categorical_features_index=None,
        X_train=train_df[predictors],
        y_train=train_df[target_col],
        device="cuda",
        task_type=task_type,
        show_training_curve=False,
        logger_level=1,
        random_seed=101
    )


# ========== Fine-tuning Loop ==========
for config in finetune_configurations:

    target_percentage = config["n"]
    use_all_data = target_percentage == "all"

    print(f"\n{'='*80}", flush=True)
    print("Fine-tuning Configuration:", flush=True)
    if use_all_data:
        print("Data usage: 100% (all data)", flush=True)
    else:
        print(f"Data usage: {target_percentage}%", flush=True)
    print(f"Batch size: {config['batch_size']}", flush=True)
    print(f"Max training time: {config['max_time']}s", flush=True)
    print(f"Learning rate: {config['learning_rate']}", flush=True)
    print(f"{'='*80}", flush=True)

    if use_all_data:
        multiclass_finetune_df = stacked_finetune.copy()
        percentage_label = "100pct"
    else:
        # Draw the requested percentage separately within every
        # (country, scale) group, stratified on the target, so each group
        # keeps its class balance. The fixed random_state makes the draw
        # repeatable across runs.
        multiclass_finetune_list = []
        for country_code in country_codes.values():
            country_data = stacked_finetune[stacked_finetune[country_col] == country_code]
            for s in sorted(country_data[scale_col].unique()):
                scale_df = country_data[country_data[scale_col] == s]
                # train_test_split returns (kept, discarded); only the kept
                # part is used.
                selected_part, _ = train_test_split(
                    scale_df,
                    test_size=(1 - (target_percentage / 100.0)),
                    stratify=scale_df[target_col],
                    random_state=42
                )
                multiclass_finetune_list.append(selected_part)
        multiclass_finetune_df = pd.concat(multiclass_finetune_list, ignore_index=True)
        percentage_label = f"{target_percentage}pct"

    # Stage 1: Binary classification fine-tuning (cognitive normal vs impaired)
    # NOTE(review): y_train here is still the three-class `target_col`; confirm
    # that fine_tune_tabpfn binarises the labels itself when task_type="binary".
    temp_binary_model_path = f"/tmp/binary_temp_{percentage_label}.ckpt"

    print("Stage 1: Binary classification fine-tuning...", flush=True)
    _run_finetune_stage(
        multiclass_finetune_df,
        base_model_path=ckpt_path,
        output_model_path=temp_binary_model_path,
        task_type="binary",
        config=config,
    )
    print(">>> Binary fine-tuning done", flush=True)

    # Stage 2: Three-class classification fine-tuning (initialised from binary weights)
    final_model_path = os.path.join(
        weights_output_path,
        f"tabpfn_finetuned_{percentage_label}_threeclass.ckpt"
    )

    print("Stage 2: Three-class classification fine-tuning...", flush=True)
    _run_finetune_stage(
        multiclass_finetune_df,
        base_model_path=temp_binary_model_path,
        output_model_path=final_model_path,
        task_type="multiclass",
        config=config,
    )
    print(">>> Three-class fine-tuning done", flush=True)

Writing /content/finetune_experiment.py


In [None]:
# Run the script written to /content/finetune_experiment.py inside the py311
# Conda env. --no-capture-output stops conda from capturing stdout/stderr and
# `python -u` disables output buffering, so training logs appear as they are produced.
!conda run --no-capture-output -n py311 python -u /content/finetune_experiment.py

Tue Aug 26 01:18:18 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off |   00000000:00:04.0 Off |                    0 |
| N/A   30C    P0             44W /  400W |       0MiB /  40960MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+
                                                