In [None]:
#@title Setup project root

import os, sys

# Root of the repository checkout. The MV_MAMMO_PROJECT_ROOT environment
# variable, when set, overrides the hard-coded default so the notebook can run
# on another machine without editing this cell.
PROJECT_ROOT = os.environ.get(
    "MV_MAMMO_PROJECT_ROOT",
    "/Users/yunfanbao/Documents/work/mv-mammo-transformer",
)

# Later cells import `config` and launch `python -m scripts.train`; both are
# presumably resolved relative to the project root, hence the chdir and the
# sys.path entry.
os.chdir(PROJECT_ROOT)

# Guarded so that re-running this cell does not keep appending duplicates.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

print("Current directory:", os.getcwd())

In [None]:
# @title Experiment Configuration

# @markdown # experiment
# experiment_name is reused by later cells to build the checkpoint directory,
# the YAML file name and the --config argument of the training command.
experiment_name = "local_exp_v1"  # @param {type:"string"}

# @markdown # model
# The values below are read as module-level globals by the config_* builder
# functions defined in later cells:
#   - model_name selects the branch in config_model_kwargs()
#   - backbone selects the backbone kwargs and the freeze schedule
#   - head selects the branch in config_head_kwargs()
# in_chans, birads_classes and density_classes are copied unchanged into the
# model section by config_model().
model_name = "mv_transformer"  # @param ["sv_baseline", "mv_concat", "mv_transformer"]
backbone = "swin_t"  # @param ["resnet50", "effnetv2_s", "swin_t", "swin_s", "swin_b"]
head = "softmax"  # @param ["softmax", "evidential"]
in_chans = 1  # @param {type:"number"}
birads_classes = 5  # @param {type:"number"}
density_classes = 4  # @param {type:"number"}

In [None]:
# @title Backbone Configuration

# model.backbone_kwargs
def config_backbone_kwargs():
    """Build the keyword arguments for the backbone selected in `backbone`.

    Reads the module-level `backbone` name. Each known backbone has its own
    form field for `pretrained`; the matching one is returned as
    ``{"pretrained": <bool>}``.

    Returns:
        dict: the kwargs for the selected backbone, or an empty dict when
        `backbone` is not one of the names handled here.
    """
    # @markdown # resnet50
    if backbone == "resnet50":
        pretrained = True  # @param {type:"boolean"}
        return dict(pretrained=pretrained)

    # @markdown # effnetv2_s
    if backbone == "effnetv2_s":
        pretrained = True  # @param {type:"boolean"}
        return dict(pretrained=pretrained)

    # @markdown # swin_t
    if backbone == "swin_t":
        pretrained = True  # @param {type:"boolean"}
        return dict(pretrained=pretrained)

    # @markdown # swin_s
    if backbone == "swin_s":
        pretrained = True  # @param {type:"boolean"}
        return dict(pretrained=pretrained)

    # @markdown # swin_b
    if backbone == "swin_b":
        pretrained = True  # @param {type:"boolean"}
        return dict(pretrained=pretrained)

    # Unknown backbone: no extra arguments.
    return {}

In [None]:
# @title Head Configuration

# model.head_kwargs
def config_head_kwargs():
    """Build the keyword arguments for the head selected in `head`.

    Reads the module-level `head` name.

    Returns:
        dict: ``{"dropout": ...}`` for the softmax head,
        ``{"hidden": ..., "dropout": ...}`` for the evidential head, or an
        empty dict for any other value of `head`.
    """
    # @markdown # softmax
    if head == "softmax":
        dropout = 0.0  # @param {type:"number"}
        return dict(dropout=dropout)

    # @markdown # evidential
    if head == "evidential":
        hidden = 512  # @param {type:"number"}
        dropout = 0.05  # @param {type:"number"}
        return dict(hidden=hidden, dropout=dropout)

    # Unknown head: no extra arguments.
    return {}

In [None]:
# @title Model Configuration

def config_model_kwargs():
    """Build the model-specific keyword arguments for `model_name`.

    Reads the module-level `model_name`.

    Returns:
        dict: empty for "sv_baseline" and for unrecognised names;
        ``{"proj_dim": ...}`` for "mv_concat" (``None`` unless `use_proj` is
        enabled); the transformer hyper-parameters for "mv_transformer".
    """
    # @markdown # sv_baseline
    if model_name == "sv_baseline":
        # The single-view baseline takes no extra arguments.
        return {}

    # @markdown # mv_concat
    if model_name == "mv_concat":
        use_proj = False  # @param {type:"boolean"}
        proj_dim = 512  # @param {type:"number"}
        # The projection size is only forwarded when projection is enabled.
        return dict(proj_dim=proj_dim if use_proj else None)

    # @markdown # mv_transformer
    if model_name == "mv_transformer":
        token_dim = 512  # @param {type:"number"}
        nhead = 8  # @param {type:"number"}
        expansion = 4  # @param {type:"number"}
        dropout = 0.1  # @param {type:"number"}
        num_layers = 2  # @param {type:"number"}
        return dict(
            token_dim=token_dim,
            nhead=nhead,
            expansion=expansion,
            dropout=dropout,
            num_layers=num_layers,
        )

    return {}


# model
def config_model():
    """Assemble the "model" section of the experiment configuration.

    Combines the module-level model settings (`model_name`, `backbone`,
    `head`, `in_chans`, `birads_classes`, `density_classes`) with the nested
    kwargs produced by the backbone, head and model builders.

    Returns:
        dict: the model section, with keys in the order they are listed below.
    """
    model_section = dict(
        name=model_name,
        backbone=backbone,
        head=head,
        in_chans=in_chans,
        birads_classes=birads_classes,
        density_classes=density_classes,
        backbone_kwargs=config_backbone_kwargs(),
        head_kwargs=config_head_kwargs(),
        model_kwargs=config_model_kwargs(),
    )
    return model_section


In [None]:
# @title Transform Configuration

# Selected transform option; config_data() stores it under the "transform" key
# of the data section.
transform = "mv_consistent"  # @param ["mv_consistent", "mv_test", "sv_baseline", "sv_test"]


# data.transform_kwargs
def config_transform():
    """Build the keyword arguments passed along with the selected transform.

    Reads the module-level `backbone`. The Swin backbones have their own
    branch, but it currently yields the same 224x224 image size as the
    default.

    Returns:
        dict: ``{"image_size": [224, 224]}``.
    """
    swin_backbones = ("swin_t", "swin_s", "swin_b")

    if backbone in swin_backbones:
        # Swin-specific settings (identical to the default at the moment).
        return {"image_size": [224, 224]}

    return {"image_size": [224, 224]}

In [None]:
# @title Data Configuration

# @markdown # Dataset
# All names in this cell are module-level settings that config_data() copies
# into the "data" section of the experiment config.
version = "first_10"  # @param ["full", "first_1000", "first_10"]
mode = "multi"  # @param ["multi", "single"]
# Split identifiers used for training and validation respectively.
train_splits = ["F1", "F2", "F3", "F4"]
val_splits = ["F5"]

# Label columns and, per column, the mapping from the string label to an
# integer index. The mappings have 5 and 4 entries, matching the defaults of
# birads_classes and density_classes in the experiment configuration cell.
label_columns = ["breast_birads", "breast_density"]
label_mapping = {
    "breast_birads": {
        "BI-RADS 1": 0,
        "BI-RADS 2": 1,
        "BI-RADS 3": 2,
        "BI-RADS 4": 3,
        "BI-RADS 5": 4,
    },
    "breast_density": {
        "DENSITY A": 0,
        "DENSITY B": 1,
        "DENSITY C": 2,
        "DENSITY D": 3,
    }
}

# @markdown # Dataloader
# Loader settings, forwarded unchanged by config_data().
batch_size = 8  # @param {type:"slider", min:1, max:32}
shuffle = True  # @param {type:"boolean"}
num_workers = 4  # @param {type:"number"}


# data
def config_data():
    """Assemble the "data" section of the experiment configuration.

    Copies the module-level dataset, label, transform and dataloader settings
    into one dict and adds the transform kwargs from `config_transform()`.

    Returns:
        dict: the data section, with keys in the order they are listed below.
    """
    data_section = dict(
        # dataset selection
        version=version,
        mode=mode,
        train_splits=train_splits,
        val_splits=val_splits,
        # labels
        label_columns=label_columns,
        label_mapping=label_mapping,
        # transform
        transform=transform,
        transform_kwargs=config_transform(),
        # dataloader
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
    )
    return data_section

In [None]:
# @title Loss Configuration

# Both values are read by config_loss_kwargs(): loss_name is stored under
# "name" and also decides (by containing "gradcam") whether the gradcam
# sub-section is added; lambda_density is stored under "lambda_density".
loss_name = "softmax_lsce"  # @param ["softmax_ce", "softmax_lsce", "evidential_ce", "evidential_klce", "softmax+gradcam", "evidential+gradcam"]
lambda_density = 0.1  # @param {type:"number"}


# training.loss_kwargs
def config_loss_kwargs():
    """Build the loss settings from `loss_name` and `lambda_density`.

    Returns:
        dict: always contains "name" and "lambda_density". When `loss_name`
        contains "gradcam", a nested "gradcam_kwargs" dict with "start_epoch"
        and "weight" is added as well.
    """
    loss_settings = dict(name=loss_name, lambda_density=lambda_density)

    # Plain losses need nothing beyond the two base entries.
    if "gradcam" not in loss_name:
        return loss_settings

    start_epoch = 3  # @param {type:"number"}
    weight = 0.05  # @param {type:"number"}

    loss_settings["gradcam_kwargs"] = dict(start_epoch=start_epoch, weight=weight)
    return loss_settings

In [None]:
# @title Backbone Freeze Schedule Configuration

# training.freeze_schedule
def config_freeze_schedule():
    """Return the backbone freeze/unfreeze schedule for `backbone`.

    Reads the module-level `backbone`. Each schedule is a list of steps; a
    step is a dict with "epoch" and "action", plus "n" for the
    "unfreeze_from" action.

    Returns:
        list[dict]: the schedule for the selected backbone, or an empty list
        when `backbone` has no schedule defined here.
    """

    def step(epoch, action, **extra):
        # Keeps the key order "epoch", "action", then any extras such as "n".
        return {"epoch": epoch, "action": action, **extra}

    schedules = {
        # @markdown # resnet50
        "resnet50": [
            step(1, "freeze_all"),  # warmup
            step(3, "unfreeze_from", n=3),  # train layer3 & layer4
            step(5, "unfreeze_from", n=2),  # train layer2~4
            step(8, "unfreeze_all"),  # full fine-tune
        ],
        # @markdown # effnetv2_s
        "effnetv2_s": [
            step(1, "freeze_all"),  # warmup
            step(3, "unfreeze_from", n=5),  # train last 2 blocks
            step(5, "unfreeze_from", n=3),  # train last 4 blocks
            step(8, "unfreeze_all"),  # full fine-tune
        ],
        # @markdown # swin_t
        "swin_t": [
            step(1, "freeze_all"),  # warmup
            step(2, "unfreeze_from", n=3),  # train stage3 + stage4
            step(4, "unfreeze_from", n=2),  # train stage2~4
            step(7, "unfreeze_all"),  # full fine-tune
        ],
        # @markdown # swin_s
        "swin_s": [
            step(1, "freeze_all"),  # warmup
            step(3, "unfreeze_from", n=3),  # train stage3 + stage4
            step(5, "unfreeze_from", n=2),  # train stage2~4
            step(9, "unfreeze_all"),  # full fine-tune
        ],
        # @markdown # swin_b
        "swin_b": [
            step(1, "freeze_all"),  # warmup
            step(4, "unfreeze_from", n=3),  # train stage3 + stage4
            step(7, "unfreeze_from", n=2),  # train stage2~4
            step(12, "unfreeze_all"),  # full fine-tune
        ],
    }

    return schedules.get(backbone, [])

In [None]:
# @title Learning Rate Schedule Configuration

# training.lr_schedule
def config_lr_schedule(name="cosine"):
    """Return the learning-rate schedule settings for the scheduler `name`.

    The alternatives used to be unused local variables, so switching
    schedulers meant editing the function body. They are now selectable via
    `name`; the default keeps the previous behaviour (cosine schedule).

    Args:
        name: One of "cosine" (default), "step", "multistep" or "plateau".

    Returns:
        dict: ``{"name": name, ...}`` followed by that scheduler's settings.

    Raises:
        ValueError: if `name` is not one of the supported schedulers.
    """
    presets = {
        "cosine": {"warmup_epochs": 3, "eta_min": 1e-6},
        "step": {"step_size": 5, "gamma": 0.1},
        "multistep": {"milestones": [10, 20, 30], "gamma": 0.1},
        "plateau": {"factor": 0.1, "patience": 5},
    }

    if name not in presets:
        supported = ", ".join(presets)
        raise ValueError(f"Unknown lr schedule {name!r}; expected one of: {supported}")

    # "name" goes first so the emitted YAML keeps the original key order.
    return {"name": name, **presets[name]}

In [None]:
# @title Training Configuration

from config import CHECKPOINT_DIR

# Optimisation settings; config_training() copies each of these into the
# "training" section under a key of the same name.
lr = 3e-4  # @param {type:"number"}
weight_decay = 1e-5  # @param {type:"number"}
max_epoch = 10  # @param {type:"slider", min:1, max:100}

use_amp = True  # @param {type:"boolean"}
eval_every = 1  # @param {type:"number"}
# Per-experiment output directory: CHECKPOINT_DIR/<experiment_name>.
# Relies on `os` imported in the setup cell.
save_dir = os.path.join(CHECKPOINT_DIR, experiment_name)


# training
def config_training():
    """Assemble the "training" section of the experiment configuration.

    Combines the module-level optimisation settings with the loss, learning
    rate schedule and freeze schedule sub-sections from their builders.

    Returns:
        dict: the training section, with keys in the order they are listed
        below.
    """
    training_section = dict(
        loss_kwargs=config_loss_kwargs(),
        lr=lr,
        weight_decay=weight_decay,
        max_epoch=max_epoch,
        use_amp=use_amp,
        eval_every=eval_every,
        save_dir=save_dir,
        lr_schedule=config_lr_schedule(),
        freeze_schedule=config_freeze_schedule(),
    )
    return training_section

In [None]:
# @title Save Config to YAML

import yaml
from config import CONFIG_DIR

# Collect the three sections produced by the builder functions above.
cfg = {
    "model": config_model(),
    "data": config_data(),
    "training": config_training(),
}

# Create the config directory on first use so the write below does not fail
# on a fresh checkout.
os.makedirs(CONFIG_DIR, exist_ok=True)
yaml_path = os.path.join(CONFIG_DIR, f"{experiment_name}.yaml")

# Explicit encoding keeps the file identical across platforms; sort_keys=False
# preserves the key order chosen by the builders.
with open(yaml_path, "w", encoding="utf-8") as f:
    yaml.dump(cfg, f, sort_keys=False)

print(f"Saved {yaml_path}")

In [None]:
#@title Train model

cmd = (
    f"python -m scripts.train "
    f"--config {experiment_name}.yaml"
)

print("Running command:")
print(cmd)

!{cmd}