In [8]:
import os
import torch
from custom_byol_bolts import CustomBYOL  # Adjust this import if needed

def sanitize_checkpoint(old_checkpoint_path, output_dir, num_classes=5, input_height=32):
    """Re-save a training checkpoint as a minimal ``{"state_dict", "config"}`` file.

    The checkpoint is loaded, its weights are pushed through a freshly built
    ``CustomBYOL`` (constructed with ``transformations=None``) using a strict
    ``load_state_dict``, and the model's state dict plus the stored config are
    written to ``output_dir`` under the original file name with ``_clean``
    inserted before the extension.

    Args:
        old_checkpoint_path: Path to the existing checkpoint. It must contain
            ``"state_dict"`` and ``["hyper_parameters"]["config"]`` entries.
        output_dir: Directory that receives the cleaned file; created if missing.
        num_classes: Forwarded to ``CustomBYOL`` (default 5, the previous
            hard-coded value).
        input_height: Forwarded to ``CustomBYOL`` (default 32, the previous
            hard-coded value).

    Raises:
        KeyError: If the checkpoint or its config lacks an expected key.
        RuntimeError: If the stored weights do not match the rebuilt model
            (raised by the strict ``load_state_dict``).
    """
    print(f"Processing: {old_checkpoint_path}")

    # SECURITY: weights_only=False unpickles arbitrary objects, so only run this
    # on checkpoints from a trusted source.
    # map_location="cpu" keeps the conversion usable on machines that lack the
    # device the checkpoint was saved from; the rebuilt model lives on CPU anyway.
    checkpoint = torch.load(old_checkpoint_path, map_location="cpu", weights_only=False)
    state_dict = checkpoint["state_dict"]
    config = checkpoint["hyper_parameters"]["config"]

    # Rebuild the model (DO NOT include transformations)
    model = CustomBYOL(
        num_classes=num_classes,
        learning_rate=config["lr"],
        weight_decay=config["weight_decay"],
        input_height=input_height,
        batch_size=config["batch_size"],
        num_workers=config["dataset"]["num_workers"],
        warmup_epochs=config["warm_up"],
        max_epochs=config["epochs"],
        config=config,
        transformations=None
    )

    # strict=True makes any missing/unexpected key fail loudly instead of
    # silently producing a partially initialised model.
    model.load_state_dict(state_dict, strict=True)

    # Insert "_clean" before the extension only; a plain str.replace would also
    # rewrite any ".ckpt" that happens to appear earlier in the file name.
    stem, ext = os.path.splitext(os.path.basename(old_checkpoint_path))
    output_path = os.path.join(output_dir, f"{stem}_clean{ext}")

    # Allow direct calls with a destination that does not exist yet.
    os.makedirs(output_dir, exist_ok=True)

    # Save only the state_dict and config.
    # NOTE(review): presumably config holds only plain Python types, which is
    # what would make the result loadable with weights_only=True -- confirm.
    torch.save({"state_dict": model.state_dict(), "config": config}, output_path)
    print(f"Saved clean checkpoint: {output_path}\n")

def find_and_sanitize_all_checkpoints(root_dir, output_dir):
    """Run ``sanitize_checkpoint`` on every "best" checkpoint below ``root_dir``.

    A file is selected when its name ends with ``.ckpt`` and contains ``best``.
    Any directory below ``root_dir`` whose name contains ``extras`` is skipped
    together with everything inside it. The name of ``root_dir`` itself is not
    tested, so a root path that happens to contain "extras" is still processed.

    Failures are reported per file and do not stop the remaining conversions.

    Args:
        root_dir: Directory tree to search.
        output_dir: Single flat directory that receives all cleaned files;
            created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)
    for subdir, dirs, files in os.walk(root_dir):
        # Prune in place: os.walk (top-down) only descends into the names left
        # in `dirs`, so excluded folders are never traversed at all.
        dirs[:] = [d for d in dirs if 'extras' not in d]
        for file in files:
            if not (file.endswith(".ckpt") and 'best' in file):
                continue
            old_checkpoint_path = os.path.join(subdir, file)
            try:
                print(file)
                sanitize_checkpoint(old_checkpoint_path, output_dir)
            except Exception as e:
                # Best-effort batch job: report the failure and keep going.
                print(f"Failed to process {old_checkpoint_path}: {e}")


# Root folder that is walked for "best" checkpoints.
base_dir = "./finished_models/pretrained_models_byol/"
# Flat destination folder for the "*_clean.ckpt" copies.
output_dir = "./finished_models/pretrained_models_byol_new/"

# Convert every matching checkpoint under base_dir; progress is printed per file.
find_and_sanitize_all_checkpoints(base_dir, output_dir)


best_pretrained_byol_DTW_w=3_r=5_epoch=149-val_loss=-3.9991.ckpt
Processing: ./finished_models/pretrained_models_byol/16-04-2025-20-51_byol_DTW_w=3_r=5\checkpoints\best_pretrained_byol_DTW_w=3_r=5_epoch=149-val_loss=-3.9991.ckpt
Saved clean checkpoint: ./finished_models/pretrained_models_byol_new/best_pretrained_byol_DTW_w=3_r=5_epoch=149-val_loss=-3.9991_clean.ckpt

best_pretrained_byol_DTW_w=1_r=10_epoch=147-val_loss=-3.9979.ckpt
Processing: ./finished_models/pretrained_models_byol/16-04-2025-22-49_byol_DTW_w=1_r=10\checkpoints\best_pretrained_byol_DTW_w=1_r=10_epoch=147-val_loss=-3.9979.ckpt
Saved clean checkpoint: ./finished_models/pretrained_models_byol_new/best_pretrained_byol_DTW_w=1_r=10_epoch=147-val_loss=-3.9979_clean.ckpt

best_pretrained_byol_DTW_w=3_r=10_epoch=149-val_loss=-3.9976.ckpt
Processing: ./finished_models/pretrained_models_byol/16-04-2025-23-34_byol_DTW_w=3_r=10\checkpoints\best_pretrained_byol_DTW_w=3_r=10_epoch=149-val_loss=-3.9976.ckpt
Saved clean checkpoint: 