Model creation and loading

In [1]:
import torch
from pathlib import Path
from utils.models.complete_model import create_complete_model, save_complete_model, load_complete_model, save_checkpoint, load_checkpoint


# ALL PATHS
MODELS_DIR = "models/"
SEGMENTER_MODEL_PATH = f"{MODELS_DIR}dino_unet_decoder_finetuned.pth"
save_path = f"{MODELS_DIR}complete_model.pth"
checkpoint_path = f"{MODELS_DIR}model_checkpoint.pth"

# Example usage
device = "cuda" if torch.cuda.is_available() else "cpu"
model = create_complete_model(device=device, SEGMENTER_MODEL_PATH=SEGMENTER_MODEL_PATH)

# Load the model
if Path(save_path).exists():
    model = load_complete_model(model, save_path, device=device, strict=True)

Loaded segmenter weights from models/dino_unet_decoder_finetuned.pth
Loaded complete model weights from models/complete_model.pth


Data loader creation

In [2]:
from utils.data.dataloaders import create_dataloaders

# CheXpert
CHEXPERT_DIR = "Datasets/CheXpertPlus"
chexpert_paths = {
    "chexpert_data_path": f"{CHEXPERT_DIR}/PNG",  # base PNG folder
    "chexpert_data_csv": f"{CHEXPERT_DIR}/df_chexpert_plus_240401_findings.csv",
}

# MIMIC
MIMIC_DIR = "Datasets/MIMIC"
mimic_paths = {
    "mimic_data_path": MIMIC_DIR,
    "mimic_splits_csv": f"{MIMIC_DIR}/mimic-cxr-2.0.0-split.csv.gz",
    "mimic_metadata_csv": f"{MIMIC_DIR}/mimic-cxr-2.0.0-metadata-findings-only.csv",
    "mimic_reports_path": f"{MIMIC_DIR}/cxr-record-list.csv.gz",  # must contain 'path'
    "mimic_images_dir": f"{MIMIC_DIR}/matched_images_and_masks_mimic_224/images",
}

import os
kwargs = {
    # "num_workers": os.cpu_count() // 2 if os.cpu_count() else 4,  # adjust on your VM
    # "persistent_workers": True,           # reuses workers between iterations
    # "prefetch_factor": 4,                 # each worker prefetches batches
    # "pin_memory": True,                   # if using CUDA
    # "drop_last": False
}

train_loader = create_dataloaders(
    chexpert_paths, 
    mimic_paths, 
    batch_size=4,
    split="train", 
    sampling_ratio=0.7,
    **kwargs
)

valid_loader = create_dataloaders(
    chexpert_paths,
    mimic_paths,
    batch_size=4,
    split="valid",
    sampling_ratio=0.7,
    **kwargs
)

images, findings, image_paths, _ = next(iter(train_loader))
print("Batch image tensor shape:", getattr(images, "shape", "N/A"))
print("Batch findings shape:", getattr(findings, "shape", len(findings)))
print("Batch image paths shape:", getattr(image_paths, "shape", len(image_paths)))

Batch image tensor shape: torch.Size([4, 3, 512, 512])
Batch findings shape: 4
Batch image paths shape: 4


Training

In [None]:
from utils.training import train, EarlyStopping, EarlyStoppingConfig
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
optimizer = Adam(model.parameters(), lr=3e-4, weight_decay=0.1)
# scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=5, min_lr=1e-6)

epochs = (5 * len(train_loader)) // 100
from transformers import get_cosine_schedule_with_warmup
total_steps = 5 * len(train_loader)
warmup_steps = max(1, int(0.05 * total_steps))
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

early = EarlyStopping(EarlyStoppingConfig(
    patience=15, min_delta=1e-4, mode="min", restore_best=True,
    best_ckpt_path="checkpoints/model_best.pth"
))
print(f"Current memory before training: {torch.cuda.memory_allocated(device) / 1e9:.2f} GB") if device == "cuda" else None
torch.cuda.empty_cache() if device == "cuda" else None
train(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    optimizer=optimizer,
    epochs=epochs,                       # total target; not "remaining"
    device=device,
    log_dir="runs/chestx_exp1_fixed",       # SAME dir to keep appending
    checkpoint_path="checkpoints/model_epoch.pth",
    validate_every=1,
    ckpt_every=2,
    scheduler=scheduler,
    scheduler_step_on="step",
    early_stopping=early,
    resume_from=None,  # or model_best.pth if you prefer to start from best weights
    # start_epoch=...,                 # optional override
    # start_global_step=...,           # optional override
)


Current memory before training: 0.85 GB
ðŸš€ Starting training from scratch.


Epoch 1 Training:   0%|          | 0/39209 [00:00<?, ?batch/s]`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.
Epoch 1 Training:   0%|          | 101/39209 [00:36<3:52:54,  2.80batch/s]


Epoch 1 | Loss: 2.3688 | LR: 0.0000 | Tokens: 30285


Epoch 1 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:12<00:24,  8.30batch/s]


Epoch 1 | Validation Loss: 2.1451 | Validation Tokens: 30252
New best model found during early stopping.
Saved checkpoint to checkpoints/model_best.pth


Epoch 2 Training:   0%|          | 101/39209 [00:32<3:27:35,  3.14batch/s]


Epoch 2 | Loss: 2.2618 | LR: 0.0000 | Tokens: 59651


Epoch 2 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:12<00:24,  8.26batch/s]


Epoch 2 | Validation Loss: 2.1483 | Validation Tokens: 30057
Saving periodic checkpoint.
Saved checkpoint to checkpoints/model_epoch.pth


Epoch 3 Training:   0%|          | 101/39209 [00:32<3:30:40,  3.09batch/s]


Epoch 3 | Loss: 2.2612 | LR: 0.0000 | Tokens: 89613


Epoch 3 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:12<00:25,  8.17batch/s]


Epoch 3 | Validation Loss: 2.1267 | Validation Tokens: 29055
New best model found during early stopping.
Saved checkpoint to checkpoints/model_best.pth


Epoch 4 Training:   0%|          | 101/39209 [00:35<3:47:10,  2.87batch/s]


Epoch 4 | Loss: 2.2398 | LR: 0.0000 | Tokens: 119112


Epoch 4 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.70batch/s]


Epoch 4 | Validation Loss: 2.1034 | Validation Tokens: 29468
New best model found during early stopping.
Saved checkpoint to checkpoints/model_best.pth
Saving periodic checkpoint.
Saved checkpoint to checkpoints/model_epoch.pth


Epoch 5 Training:   0%|          | 101/39209 [00:32<3:30:50,  3.09batch/s]


Epoch 5 | Loss: 2.2998 | LR: 0.0000 | Tokens: 149166


Epoch 5 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.74batch/s]


Epoch 5 | Validation Loss: 2.1439 | Validation Tokens: 29866


Epoch 6 Training:   0%|          | 101/39209 [00:32<3:28:16,  3.13batch/s]


Epoch 6 | Loss: 2.3368 | LR: 0.0000 | Tokens: 179170


Epoch 6 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.73batch/s]


Epoch 6 | Validation Loss: 2.1940 | Validation Tokens: 30156
Saving periodic checkpoint.
Saved checkpoint to checkpoints/model_epoch.pth


Epoch 7 Training:   0%|          | 101/39209 [00:34<3:44:07,  2.91batch/s]


Epoch 7 | Loss: 2.3742 | LR: 0.0000 | Tokens: 209644


Epoch 7 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.72batch/s]


Epoch 7 | Validation Loss: 2.2693 | Validation Tokens: 29262


Epoch 8 Training:   0%|          | 101/39209 [00:31<3:24:35,  3.19batch/s]


Epoch 8 | Loss: 2.3991 | LR: 0.0000 | Tokens: 239777


Epoch 8 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.61batch/s]


Epoch 8 | Validation Loss: 2.2749 | Validation Tokens: 29148
Saving periodic checkpoint.
Saved checkpoint to checkpoints/model_epoch.pth


Epoch 9 Training:   0%|          | 101/39209 [00:32<3:28:00,  3.13batch/s]


Epoch 9 | Loss: 2.4870 | LR: 0.0000 | Tokens: 269531


Epoch 9 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.68batch/s]


Epoch 9 | Validation Loss: 2.2134 | Validation Tokens: 29810


Epoch 10 Training:   0%|          | 101/39209 [00:31<3:24:53,  3.18batch/s]


Epoch 10 | Loss: 2.4981 | LR: 0.0000 | Tokens: 298985


Epoch 10 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.73batch/s]


Epoch 10 | Validation Loss: 2.3680 | Validation Tokens: 29402
Saving periodic checkpoint.
Saved checkpoint to checkpoints/model_epoch.pth


Epoch 11 Training:   0%|          | 101/39209 [00:32<3:29:54,  3.11batch/s]


Epoch 11 | Loss: 2.6267 | LR: 0.0000 | Tokens: 329022


Epoch 11 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:24,  8.53batch/s]


Epoch 11 | Validation Loss: 2.3494 | Validation Tokens: 29127


Epoch 12 Training:   0%|          | 101/39209 [00:31<3:24:25,  3.19batch/s]


Epoch 12 | Loss: 2.6531 | LR: 0.0000 | Tokens: 357944


Epoch 12 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.69batch/s]


Epoch 12 | Validation Loss: 2.5602 | Validation Tokens: 29545
Saving periodic checkpoint.
Saved checkpoint to checkpoints/model_epoch.pth


Epoch 13 Training:   0%|          | 101/39209 [00:32<3:30:12,  3.10batch/s]


Epoch 13 | Loss: 2.7064 | LR: 0.0000 | Tokens: 386780


Epoch 13 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.75batch/s]


Epoch 13 | Validation Loss: 2.4568 | Validation Tokens: 29990


Epoch 14 Training:   0%|          | 101/39209 [00:32<3:30:23,  3.10batch/s]


Epoch 14 | Loss: 3.0792 | LR: 0.0000 | Tokens: 416935


Epoch 14 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.81batch/s]


Epoch 14 | Validation Loss: 2.7211 | Validation Tokens: 29493
Saving periodic checkpoint.
Saved checkpoint to checkpoints/model_epoch.pth


Epoch 15 Training:   0%|          | 101/39209 [00:33<3:34:25,  3.04batch/s]


Epoch 15 | Loss: 2.8218 | LR: 0.0000 | Tokens: 445722


Epoch 15 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.75batch/s]


Epoch 15 | Validation Loss: 2.5994 | Validation Tokens: 29696


Epoch 16 Training:   0%|          | 101/39209 [00:31<3:26:03,  3.16batch/s]


Epoch 16 | Loss: 2.8876 | LR: 0.0000 | Tokens: 475678


Epoch 16 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.65batch/s]


Epoch 16 | Validation Loss: 2.6751 | Validation Tokens: 29330
Saving periodic checkpoint.
Saved checkpoint to checkpoints/model_epoch.pth


Epoch 17 Training:   0%|          | 101/39209 [00:31<3:26:24,  3.16batch/s]


Epoch 17 | Loss: 2.9560 | LR: 0.0001 | Tokens: 504598


Epoch 17 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:24,  8.51batch/s]


Epoch 17 | Validation Loss: 2.6047 | Validation Tokens: 29106


Epoch 18 Training:   0%|          | 101/39209 [00:31<3:26:29,  3.16batch/s]


Epoch 18 | Loss: 3.0558 | LR: 0.0001 | Tokens: 534145


Epoch 18 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.78batch/s]


Epoch 18 | Validation Loss: 4.8317 | Validation Tokens: 28727
Saving periodic checkpoint.
Saved checkpoint to checkpoints/model_epoch.pth


Epoch 19 Training:   0%|          | 101/39209 [00:37<4:00:14,  2.71batch/s]


Epoch 19 | Loss: 3.5807 | LR: 0.0001 | Tokens: 563995


Epoch 19 Validation:  33%|â–ˆâ–ˆâ–ˆâ–Ž      | 101/307 [00:11<00:23,  8.84batch/s]

Epoch 19 | Validation Loss: 2.7993 | Validation Tokens: 28653
Early stopping triggered at epoch 19. Best epoch: 4 with val loss: 2.103352892398834.
ðŸŽ‰ Training complete.



