<a href="https://colab.research.google.com/github/kaminglui/Domain-Adaptation-with-ME-IIS/blob/main/ME_IIS_Colab.ipynb" target="_blank">Open in Colab</a>

# ME-IIS Domain Adaptation (Colab)
Single-pass pipeline that trains a source-only ResNet-50 and then adapts it with ME-IIS. Edit the configuration cell in Section 3 to switch datasets, domains, and options.

## 0. Setup & GPU check
- Mount Google Drive and choose a working directory.
- Clone/pull the repo into that folder.
- Confirm CUDA availability and show the GPU name.

In [None]:
# Section 0 - Setup & GPU check
import os
from pathlib import Path

USE_DRIVE = True  # Set False to keep everything in /content
DRIVE_ROOT = "/content/drive/MyDrive" if USE_DRIVE else "/content"
WORK_DIR = os.path.join(DRIVE_ROOT, "MEIIS-Colab")
REPO_URL = "https://github.com/kaminglui/Domain-Adaptation-with-ME-IIS.git"
REPO_NAME = "Domain-Adaptation-with-ME-IIS"
REPO_DIR = os.path.join(WORK_DIR, REPO_NAME)

if USE_DRIVE:
    from google.colab import drive  # type: ignore
    drive.mount("/content/drive")

os.makedirs(WORK_DIR, exist_ok=True)
os.chdir(WORK_DIR)
print("WORK_DIR:", WORK_DIR)

if not os.path.isdir(REPO_DIR):
    print("[Repo] Cloning repository...")
    !git clone {REPO_URL}
else:
    print("[Repo] Repository exists; pulling latest changes...")
    os.chdir(REPO_DIR)
    !git pull
    os.chdir(WORK_DIR)

os.chdir(REPO_DIR)
print("[Repo] Current repo dir:", os.getcwd())

try:
    import torch
    print("torch.cuda.is_available():", torch.cuda.is_available())
    if torch.cuda.is_available():
        print("GPU:", torch.cuda.get_device_name(0))
    else:
        print("GPU not detected - switch runtime to GPU for full runs.")
except ImportError as exc:
    print("PyTorch not installed yet; run Section 1 to install dependencies. GPU check skipped:", exc)


## 1. Install dependencies
Re-run this cell if the Colab runtime restarts.

In [None]:
# Section 1 - Install dependencies (re-run if the runtime restarts)
import os

os.chdir(REPO_DIR)
req_file = "env/requirements_colab.txt" if os.path.exists("env/requirements_colab.txt") else "env/requirements.txt"
print("Installing dependencies from:", req_file)
!pip install -r {req_file}


## 2. Download datasets via KaggleHub
Downloads Office-Home and Office-31 with KaggleHub. For each download it finds the folder that contains the domain sub-directories and symlinks it under `datasets/` so the scripts can use it.

In [None]:
# Section 2 - Download datasets via KaggleHub
import os
import pathlib

# Downloads and the datasets/ links below use paths relative to the repo root.
os.chdir(REPO_DIR)
print("Repo dir:", os.getcwd())

# kagglehub may not be preinstalled; install it on first use, then import it.
try:
    import kagglehub  # type: ignore
except ImportError:
    print("[Data] Installing kagglehub...")
    !pip install kagglehub
    import kagglehub  # type: ignore

def _iter_dir_listings(base_dir: pathlib.Path):
    """Yield ``(directory, child_dir_names)`` for ``base_dir`` and its subtree, shallowest first.

    Breadth-first order makes the finders below return the *shallowest* matching
    directory. Because this is a generator, they stop at the first match instead
    of listing every image file of the dataset up front. Symlinked directories
    are reported as children but not descended into, which rules out symlink
    cycles. Unreadable directories are skipped rather than aborting the search.
    """
    from collections import deque

    pending = deque([base_dir])
    while pending:
        current = pending.popleft()
        try:
            # Sorted so the search order (and thus the result) does not depend on
            # filesystem listing order.
            subdirs = sorted(p for p in current.iterdir() if p.is_dir())
        except OSError:
            continue
        yield current, {p.name for p in subdirs}
        pending.extend(p for p in subdirs if not p.is_symlink())

def _find_office_home_root(base_dir: pathlib.Path) -> pathlib.Path:
    """Return the shallowest directory under ``base_dir`` holding the Office-Home domains.

    A match has ``Art``, ``Clipart`` and ``Product`` sub-directories plus one
    whose name starts with "real" (case-insensitive, e.g. ``Real World``).
    Falls back to ``base_dir`` when no directory matches.
    """
    for cand, names in _iter_dir_listings(base_dir):
        if {"Art", "Clipart", "Product"} <= names and any(n.lower().startswith("real") for n in names):
            return cand
    return base_dir

def _find_office31_root(base_dir: pathlib.Path) -> pathlib.Path:
    """Return the shallowest directory under ``base_dir`` holding the Office-31 domains.

    A match has ``amazon``, ``dslr`` and ``webcam`` sub-directories, compared
    case-insensitively. Falls back to ``base_dir`` when no directory matches.
    """
    for cand, names in _iter_dir_listings(base_dir):
        if {"amazon", "dslr", "webcam"} <= {n.lower() for n in names}:
            return cand
    return base_dir

# Download each dataset, then locate the folder that actually holds its domain
# sub-directories (the archive may nest it below the download root).
_dataset_specs = (
    ("Office-Home", "lhrrraname/officehome", _find_office_home_root),
    ("Office-31", "xixuhu/office31", _find_office31_root),
)
_located = {}
for _label, _handle, _finder in _dataset_specs:
    print(f"[Data] Downloading {_label} ({_handle})...")
    _raw = pathlib.Path(kagglehub.dataset_download(_handle))
    _located[_label] = (_raw, _finder(_raw))
    print(f"  {_label} root:", _located[_label][1])

office_home_raw, office_home_root = _located["Office-Home"]
office31_raw, office31_root = _located["Office-31"]

datasets_dir = pathlib.Path("datasets")
datasets_dir.mkdir(exist_ok=True)

def _ensure_link(link_path: pathlib.Path, target: pathlib.Path) -> None:
    """Make ``link_path`` a directory symlink to ``target`` (resolved).

    - A real (non-symlink) file or directory at ``link_path`` is never modified.
    - An existing symlink is kept if it already resolves to ``target``.
      Otherwise it is replaced.
    - Failure to create the link is reported, not raised.
    """
    resolved_target = target.resolve()
    is_link = link_path.is_symlink()

    if not is_link and link_path.exists():
        print(f"[Data] {link_path} exists and is not a symlink; leaving as-is.")
        return

    if is_link:
        if link_path.resolve() == resolved_target:
            print(f"[Data] {link_path} already points to {resolved_target}")
            return
        link_path.unlink()  # stale link: drop it so it can be re-created below

    try:
        link_path.symlink_to(resolved_target, target_is_directory=True)
    except OSError as err:
        print(f"[Data] Could not create symlink {link_path} -> {resolved_target}: {err}")
    else:
        print(f"[Data] Linked {link_path} -> {resolved_target}")

OFFICE_HOME_ROOT = datasets_dir / "Office-Home"
OFFICE31_ROOT = datasets_dir / "Office-31"

# Expose each located dataset under datasets/ (the paths DATA_ROOT refers to).
for _link, _target in ((OFFICE_HOME_ROOT, office_home_root), (OFFICE31_ROOT, office31_root)):
    _ensure_link(_link, _target)

print("[Data] Office-Home DATA_ROOT:", OFFICE_HOME_ROOT)
print("[Data] Office-31 DATA_ROOT:", OFFICE31_ROOT)


## 3. Configuration (single cell with all knobs)
Change experiment settings only in this cell; the later cells read these values.

In [None]:
# Edit only in this cell to change experiment settings.

# Dataset & domains
DATASET_NAME = "office_home"  # or "office31"
SOURCE_DOMAIN = "Ar"          # e.g. Ar/Cl/Pr/Rw for Office-Home, A/D/W for Office-31
TARGET_DOMAIN = "Cl"
SEED = 0

# Paths (relative to the repo root). DATA_ROOT is derived from DATASET_NAME so
# that switching datasets cannot leave the scripts reading the other dataset's
# images. The entries match the links created under datasets/ in Section 2.
# Add an entry here to use another dataset.
_DATA_ROOTS = {
    "office_home": "datasets/Office-Home",
    "office31": "datasets/Office-31",
}
if DATASET_NAME not in _DATA_ROOTS:
    raise ValueError(
        f"Unknown DATASET_NAME {DATASET_NAME!r}; expected one of {sorted(_DATA_ROOTS)}"
    )
DATA_ROOT = _DATA_ROOTS[DATASET_NAME]

# Source training hyperparameters
NUM_EPOCHS_SRC = 50
BATCH_SIZE_SRC = 32
LR_BACKBONE = 1e-3
LR_CLASSIFIER = 1e-2
WEIGHT_DECAY = 1e-3
NUM_WORKERS = 4
DETERMINISTIC = True  # True passes --deterministic to the scripts (minimizes randomness)

# ME-IIS / adaptation hyperparameters
FEATURE_LAYERS = "layer3,layer4,avgpool"
GMM_SELECTION_MODE = "bic"        # "fixed" or "bic"
NUM_LATENT_STYLES = 5             # default components per layer (fixed mode or BIC init)
GMM_BIC_MIN_COMPONENTS = 2        # BIC lower bound
GMM_BIC_MAX_COMPONENTS = 8        # BIC upper bound

# How source class probabilities enter the constraints
SOURCE_PROB_MODE = "softmax"      # or "onehot" (use GT labels for source constraints)

# Optional per-layer GMM component override
# e.g. "layer3:10,layer4:5" or "7,8,9" for 3 layers; leave "" to disable
COMPONENTS_OVERRIDE = ""

# Pseudo-label adaptation (ME-IIS+PL)
USE_PSEUDO_LABELS = False         # default False; set True to use pseudo-labels
PSEUDO_CONF_THRESH = 0.9
PSEUDO_MAX_RATIO = 1.0            # 1.0 = no cap; or e.g. 0.3 to limit pseudo-target count
PSEUDO_LOSS_WEIGHT = 1.0

# IIS configuration
IIS_ITERS = 15
IIS_TOL = 1e-3
ADAPT_EPOCHS = 10
FINETUNE_BACKBONE = False
BACKBONE_LR_SCALE = 0.1
CLASSIFIER_LR = 1e-2
ADAPT_WEIGHT_DECAY = 1e-3

# Optional: force fresh runs instead of auto-resume
FORCE_FRESH_SOURCE_TRAIN = False
FORCE_FRESH_ADAPT = False


## 4. Helper: construct checkpoint paths and cleanup
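When `FORCE_FRESH_SOURCE_TRAIN` or `FORCE_FRESH_ADAPT` is set in Section 3, this cell deletes the matching existing checkpoint. The scripts then start fresh instead of auto-resuming.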

In [None]:
# Section 4 - Checkpoint path helpers and optional cleanup
import os

# The checkpoint paths below are relative ("checkpoints/..."), so work from the repo root.
os.chdir(REPO_DIR)

def get_source_ckpt_path():
    """Return the relative source-only checkpoint path for the configured domains and seed."""
    stem = "source_only_{}_to_{}_seed{}".format(SOURCE_DOMAIN, TARGET_DOMAIN, SEED)
    return "checkpoints/" + stem + ".pth"

def get_me_iis_ckpt_path():
    """Return the relative ME-IIS checkpoint path for the configured domains, layers and seed.

    FEATURE_LAYERS commas become dashes and spaces are dropped, so
    "layer3, layer4" is tagged as "layer3-layer4".
    """
    layer_tag = "-".join(FEATURE_LAYERS.split(",")).replace(" ", "")
    stem = "me_iis_{}_to_{}_{}_seed{}".format(SOURCE_DOMAIN, TARGET_DOMAIN, layer_tag, SEED)
    return "checkpoints/" + stem + ".pth"

def _delete_if_present(ckpt_path, label):
    """Remove ``ckpt_path`` and report it; a missing file is a no-op."""
    if not os.path.exists(ckpt_path):
        return
    os.remove(ckpt_path)
    print(f"Deleted existing {label} checkpoint:", ckpt_path)


# Honour the FORCE_FRESH_* switches from the config cell: removing the matching
# checkpoint bypasses auto-resume.
if FORCE_FRESH_SOURCE_TRAIN:
    src_ckpt = get_source_ckpt_path()
    _delete_if_present(src_ckpt, "source")

if FORCE_FRESH_ADAPT:
    adapt_ckpt = get_me_iis_ckpt_path()
    _delete_if_present(adapt_ckpt, "ME-IIS")


## 5. Train source-only model (full training)
Builds the training command from the Section 3 config and streams its output directly in the cell.

In [None]:
import os

det_flag = "--deterministic" if DETERMINISTIC else ""
os.chdir(REPO_DIR)

SRC_CMD = (
    'python scripts/train_source.py \
'
    f'  --dataset_name {DATASET_NAME} \
'
    f'  --data_root "{DATA_ROOT}" \
'
    f'  --source_domain {SOURCE_DOMAIN} \
'
    f'  --target_domain {TARGET_DOMAIN} \
'
    f'  --num_epochs {NUM_EPOCHS_SRC} \
'
    f'  --batch_size {BATCH_SIZE_SRC} \
'
    f'  --lr_backbone {LR_BACKBONE} \
'
    f'  --lr_classifier {LR_CLASSIFIER} \
'
    f'  --weight_decay {WEIGHT_DECAY} \
'
    f'  --num_workers {NUM_WORKERS} \
'
    f'  {det_flag} \
'
    f'  --seed {SEED}'
).strip()

print("Training command:\n", SRC_CMD)
!{SRC_CMD}


## 6. Run ME-IIS adaptation
Uses the source checkpoint above plus the ME-IIS settings in the config cell.

In [None]:
import os

det_flag = "--deterministic" if DETERMINISTIC else ""
os.chdir(REPO_DIR)

ADAPT_CMD = (
    'python scripts/adapt_me_iis.py \
'
    f'  --dataset_name {DATASET_NAME} \
'
    f'  --data_root "{DATA_ROOT}" \
'
    f'  --source_domain {SOURCE_DOMAIN} \
'
    f'  --target_domain {TARGET_DOMAIN} \
'
    f'  --checkpoint {get_source_ckpt_path()} \
'
    f'  --batch_size {BATCH_SIZE_SRC} \
'
    f'  --num_workers {NUM_WORKERS} \
'
    f'  --feature_layers "{FEATURE_LAYERS}" \
'
    f'  --num_latent_styles {NUM_LATENT_STYLES} \
'
    f'  --gmm_selection_mode {GMM_SELECTION_MODE} \
'
    f'  --gmm_bic_min_components {GMM_BIC_MIN_COMPONENTS} \
'
    f'  --gmm_bic_max_components {GMM_BIC_MAX_COMPONENTS} \
'
    f'  --iis_iters {IIS_ITERS} \
'
    f'  --iis_tol {IIS_TOL} \
'
    f'  --adapt_epochs {ADAPT_EPOCHS} \
'
    f'  {'--finetune_backbone' if FINETUNE_BACKBONE else ''} \
'
    f'  --backbone_lr_scale {BACKBONE_LR_SCALE} \
'
    f'  --classifier_lr {CLASSIFIER_LR} \
'
    f'  --weight_decay {ADAPT_WEIGHT_DECAY} \
'
    f'  --source_prob_mode {SOURCE_PROB_MODE} \
'
    f'  {det_flag} \
'
    f'  --seed {SEED}'
).strip()

if USE_PSEUDO_LABELS:
    ADAPT_CMD += (
        f" --use_pseudo_labels --pseudo_conf_thresh {PSEUDO_CONF_THRESH} "
        f"--pseudo_max_ratio {PSEUDO_MAX_RATIO} --pseudo_loss_weight {PSEUDO_LOSS_WEIGHT}"
    )

if COMPONENTS_OVERRIDE.strip():
    ADAPT_CMD += f' --components_per_layer "{COMPONENTS_OVERRIDE}"'

print("Adaptation command:\n", ADAPT_CMD)
!{ADAPT_CMD}


## 7. Optional: experiment driver examples
Leave `RUN_EXPERIMENT_DRIVER = False` for the standard train + adapt path.

In [None]:
RUN_EXPERIMENT_DRIVER = False

if RUN_EXPERIMENT_DRIVER:
    import os
    os.chdir(REPO_DIR)
    det_flag = "--deterministic" if DETERMINISTIC else ""
    EXP_CMD = (
        'python scripts/run_me_iis_experiments.py \
'
        f'  --dataset_name {DATASET_NAME} \
'
        f'  --source_domain {SOURCE_DOMAIN} \
'
        f'  --target_domain {TARGET_DOMAIN} \
'
        '  --experiment_family layers \
'
        f'  --seeds {SEED} \
'
        f'  --base_data_root "{DATA_ROOT}" \
'
        f'  --feature_layers "{FEATURE_LAYERS}" \
'
        f'  --num_latent_styles {NUM_LATENT_STYLES} \
'
        f'  --gmm_selection_mode {GMM_SELECTION_MODE} \
'
        f'  --gmm_bic_min_components {GMM_BIC_MIN_COMPONENTS} \
'
        f'  --gmm_bic_max_components {GMM_BIC_MAX_COMPONENTS} \
'
        f'  --source_prob_mode {SOURCE_PROB_MODE} \
'
        f'  --num_epochs {NUM_EPOCHS_SRC} \
'
        f'  --batch_size {BATCH_SIZE_SRC} \
'
        f'  --num_workers {NUM_WORKERS} \
'
        f'  --iis_iters {IIS_ITERS} \
'
        f'  --iis_tol {IIS_TOL} \
'
        f'  --adapt_epochs {ADAPT_EPOCHS} \
'
        f'  --backbone_lr_scale {BACKBONE_LR_SCALE} \
'
        f'  --classifier_lr {CLASSIFIER_LR} \
'
        f'  --weight_decay {ADAPT_WEIGHT_DECAY} \
'
        f'  {'--finetune_backbone' if FINETUNE_BACKBONE else ''} \
'
        f'  {det_flag} \
'
        f'  --output_csv results/me_iis_layers_{SOURCE_DOMAIN}_to_{TARGET_DOMAIN}_seed{SEED}.csv'
    ).strip()

    if COMPONENTS_OVERRIDE.strip():
        EXP_CMD += f' --components_per_layer "{COMPONENTS_OVERRIDE}"'

    if USE_PSEUDO_LABELS:
        EXP_CMD += (
            f" --pseudo_conf_thresh {PSEUDO_CONF_THRESH} "
            f"--pseudo_max_ratio {PSEUDO_MAX_RATIO} --pseudo_loss_weight {PSEUDO_LOSS_WEIGHT}"
        )

    if det_flag:
        EXP_CMD += f" {det_flag}"

    print("Experiment driver command:\n", EXP_CMD)
    !{EXP_CMD}
else:
    print("Experiment driver is disabled (set RUN_EXPERIMENT_DRIVER = True to run).")


## 8. Notes on outputs
- Source checkpoints: `checkpoints/source_only_*` (auto-resume uses the matching file name).
- ME-IIS checkpoints and IIS weights: `checkpoints/me_iis_*` and `results/me_iis_weights_*.npz`.
- CSV log of source/adapt runs: `results/office_home_me_iis.csv` (dataset column distinguishes Office-Home vs Office-31).
- Experiment driver summaries (optional): `results/me_iis_experiments_summary.csv` or the path you pass via `--output_csv`.
- TensorBoard logs live under `runs/` (`runs/source_only` and `runs/adapt_me_iis`).