In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import collections
import functools
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torch.optim as optim
from lightning import Trainer
from omegaconf import OmegaConf

In [None]:
import sys
sys.path.insert(1, '/home/buehlern/Documents/Masterarbeit/models')
from src.data.mri_datamodule import MRIDataModule
from src.models.vit_mae_module import VisionTransformerMAE
from src.models.vit_mae_probe_module import ViTMAELinearProbingClassifier

# Reconstruction Visualization

In [None]:
# Select the MAE checkpoint to visualize.
#
# Previously every name below was assigned to `mae_name` in turn and immediately
# overwritten, so only the last assignment ever took effect. The alternatives are
# kept here for reference; copy one into `mae_name` to inspect a different run.
#
#   "ViT-L MAE"                       Old pretrained model
#   "ViT-L MAE FT-1"                  FT PT Model normal, 50 epochs
#   "ViT-L MAE FT-2"                  FT PT Model normal, 90 epochs
#   "ViT-L MAE FT-3"                  FT PT Model normal, 300 epochs
#   "ViT-L MAE FT-4"                  FT PT Model normal, 1000 epochs
#   "ViT-L MAE FT-5"                  FT PT Model normal, 10k samples, 3 epochs
#   "ViT-L_MAE_FT-10k_1"              FT PT Model normal, 10k samples, 10 epochs
#   "ViT-L_MAE_FT-10k_2"              FT PT Model normal, 10k samples, 10 epochs, 50% mask ratio
#   "ViT-L_MAE_FT-10k_3"              FT PT Model normal, 10k samples, 10 epochs, patch_size 32
#   10k samples, 30 epochs:
#   "ViT-L_MAE_FT-10k/default"        Default FT PT Model, 10k samples, 30 epochs
#   "ViT-L_MAE_FT-10k/downsampling"
#   "ViT-L_MAE_FT-10k/maskratio"
#   "ViT-L_MAE_FT-10k/patchsize"
#   Full pretraining:
#   "ViT-L-MAE/Default/epoch_000"
#   "ViT-L-MAE/Default/epoch_001"
#   "ViT-L-MAE/Default/epoch_002"
#   "ViT-L-MAE/Overfit/epoch_000"
#
# Final pretraining (currently selected):
mae_name = "ViT-B-MAE/Default/epoch_009"
mae_checkpoint = f"/home/buehlern/Documents/Masterarbeit/models/checkpoints/{mae_name}.ckpt"

In [None]:
# Sanity check: number of CUDA devices visible to this kernel.
torch.cuda.device_count()

In [None]:
# Use the first GPU when one is visible, otherwise fall back to the CPU so the
# notebook also runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
checkpoint = torch.load(mae_checkpoint, map_location=device)
# The model weights live under the 'state_dict' entry of the checkpoint dict.
state_dict = checkpoint['state_dict']

In [None]:
# Inspect the raw parameter names stored in the checkpoint (before any prefix remapping).
state_dict.keys()

In [None]:
def remap_keys(state_dict, unwanted_prefix, new_prefix):
    """Return a copy of `state_dict` with a leading key prefix swapped out.

    Only keys that *start with* `unwanted_prefix` are renamed, and only that
    leading occurrence is replaced; the same text appearing later in a key is
    left untouched. All other keys are copied unchanged. Insertion order is
    preserved and the input mapping is not modified.

    Args:
        state_dict: Mapping from parameter name to value.
        unwanted_prefix: Prefix to strip from the start of matching keys.
        new_prefix: Text put in place of the stripped prefix.

    Returns:
        collections.OrderedDict with the remapped keys.
    """
    prefix_len = len(unwanted_prefix)
    return collections.OrderedDict(
        # Slice instead of str.replace so only the leading prefix is rewritten.
        (new_prefix + key[prefix_len:] if key.startswith(unwanted_prefix) else key, value)
        for key, value in state_dict.items()
    )

In [None]:
# Rename keys starting with 'net._orig_mod.' to start with 'net.' instead
# (presumably the '_orig_mod' segment comes from a torch.compile-wrapped module — TODO confirm).
new_state_dict = remap_keys(state_dict, 'net._orig_mod.', 'net.')

In [None]:
# Inspect the parameter names after prefix remapping.
new_state_dict.keys()

In [None]:
# Build the MAE module for single-channel images with 48px patches and attention outputs enabled.
# NOTE(review): image_size/patch_size presumably have to match the selected checkpoint — confirm when switching `mae_name`.
mae = VisionTransformerMAE(image_size = 3072, patch_size = 48, image_channels=1, output_attentions=True)
# Load the remapped weights; the call's return value is displayed as the cell output.
mae.load_state_dict(new_state_dict)

In [None]:
# Put the module into evaluation mode before running the visualization.
mae.eval()

In [None]:
# Disable masking
#mae.net.config.mask_ratio = 0

In [None]:
# Print the attention-related settings of the wrapped network's config;
# `visualize` below requests attention maps from the forward pass.
print('output_attentions:', mae.net.config.output_attentions)
print('_attn_implementation:', mae.net.config._attn_implementation)

In [None]:
# DataModule used for the reconstruction visualization.
# Images are padded to a multiple of 48, which equals the patch_size the MAE above was built with.
# Batch binning is disabled here (commented out) and batch_size is 1.
mri_datamodule = MRIDataModule(
    df_name="df_min_ft_test_114",
    label="bodypart",
    pad_to_multiple_of=48,
    #batch_binning="smart",
    #batch_bins=[1152, 1536, 1920, 2304, 2688, 3072],
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
    val_size=0.05,
    test_size=0.15,
    output_channels=1,
    fix_inverted=True
)

In [None]:
def show_image(image, title=''):
    """Draw one single-channel image on the current matplotlib axes.

    Args:
        image: Array-like indexed as [H, W, 1]; the last axis must have size 1.
        title: Text shown above the image.

    Raises:
        AssertionError: If the third dimension of `image` is not 1.
    """
    channels = image.shape[2]
    assert channels == 1
    # Bone colormap, no interpolation between pixels, axes hidden.
    plt.imshow(image, interpolation='none', cmap=plt.cm.bone)
    plt.title(title, fontsize=16)
    plt.axis('off')

In [None]:
def visualize(pixel_values, model, imgname=None, check_cls_token=False):
    """Plot original, masked input, reconstruction and an attention map for one batch.

    Runs `model` once on `pixel_values` (without gradient tracking) and shows five
    panels for the first sample: original, masked input, reconstruction,
    reconstruction pasted over the visible patches, and a per-patch attention map.

    Args:
        pixel_values: Image batch indexed as (N, C, H, W).
        model: MAE network exposing `config.patch_size`, `unpatchify(...)` and a
            forward pass whose output has `logits`, `mask`, `loss` and `attentions`.
            If the output also carries `ids_restore`, it is used to place each
            encoder token's score on its original patch position
            (assumes the Hugging Face ViT-MAE output layout — TODO confirm).
        imgname: If given, the figure is saved as `<imgname>.png` in a folder named
            after the module-level `mae_name`.
        check_cls_token: If True, score patches by the attention the [CLS] query
            pays to them; otherwise use the mean of each patch token's attention row.
    """
    patch_size = model.config.patch_size
    print(pixel_values.size())
    # NCHW layout (see the 'nchw->nhwc' einsums below): the last two dims are (height, width).
    image_height, image_width = pixel_values.size()[-2:]
    num_channels = pixel_values.size()[1]
    num_patches_h = image_height // patch_size
    num_patches_w = image_width // patch_size
    print(f"Size: {image_width}x{image_height}")
    print(f"Patches: {num_patches_w}x{num_patches_h}")

    # Forward pass; no gradients are needed for visualization.
    with torch.no_grad():
        outputs = model(pixel_values, interpolate_pos_encoding=True, output_attentions=True)
    y = model.unpatchify(outputs.logits, original_image_size=pixel_values.size()[-2:])
    y = torch.einsum('nchw->nhwc', y).detach().cpu()

    # Expand the per-patch mask to pixel resolution.
    mask = outputs.mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, patch_size**2 * num_channels)  # (N, H*W, p*p*C)
    mask = model.unpatchify(mask, original_image_size=pixel_values.size()[-2:])
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()

    # Keep the input on the CPU as well so it can be combined with `mask` and `y`.
    x = torch.einsum('nchw->nhwc', pixel_values).detach().cpu()

    # Masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # Attention map.
    # The encoder only sees the visible patches plus [CLS], so the attention
    # sequence is shorter than the patch grid. Example: a 3072x3072 image with
    # patch_size 48 has 64*64 = 4096 patches; with mask ratio 0.75 only
    # 4096 * (1 - 0.75) = 1024 remain, giving seq_len = 1025 including [CLS].
    attentions = outputs.attentions
    print("len(attentions):", len(attentions))
    print("attentions[0].shape:", attentions[0].shape)
    masklist = outputs.mask.detach().cpu().type(torch.int64) > 0
    print("masklist.shape:", masklist.shape)

    # Accumulate one score per visible patch over all layers except the last.
    patch_contrib = torch.zeros(attentions[0].shape[-1] - 1)  # -1 for [CLS] token
    for layer_attn in attentions[:-1]:
        if check_cls_token:
            # Query index 0 is the [CLS] token; keep its attention to the patch tokens.
            attn = layer_attn[:, :, 0, 1:].detach().cpu()  # [N, heads, visible]
            attn = attn.mean(dim=1)  # average over heads
            attn = attn.mean(dim=0)  # average over batch
        else:
            # Drop the [CLS] row and column.
            attn = layer_attn[:, :, 1:, 1:].detach().cpu()  # [N, heads, visible, visible]
            attn = attn.mean(dim=1)  # average over heads
            attn = attn.mean(dim=0)  # average over batch
            # NOTE(review): this averages each query row (over keys). If the intent is
            # "attention received by each patch", the mean should be over dim=0 — confirm.
            attn = attn.mean(dim=1)
        patch_contrib += attn

    # Normalize to [0, 1]; skip the division when all scores are equal.
    patch_contrib -= patch_contrib.min()
    peak = patch_contrib.max()
    if peak > 0:
        patch_contrib /= peak
    print("patch_contrib.shape:", patch_contrib.shape)
    print(pd.DataFrame(patch_contrib).describe())

    # Scatter the visible-patch scores back onto the full patch grid.
    print("len(masklist[0]):", len(masklist[0]))
    visible = ~masklist[0]
    patch_scores = torch.zeros(masklist.shape[1])
    ids_restore = getattr(outputs, "ids_restore", None)
    if ids_restore is not None:
        # The encoder sequence is in shuffled order; ids_restore[j] is the position of
        # original patch j in that sequence, so use it to look up each visible patch.
        encoder_pos = ids_restore[0].detach().cpu()
        patch_scores[visible] = patch_contrib[encoder_pos[visible]]
    else:
        # Fallback: assume visible tokens are already in original patch order.
        patch_scores[visible] = patch_contrib[: int(visible.sum())]

    # Patch index i sits at grid row i // num_patches_w and column i % num_patches_w;
    # blow each grid cell up to patch_size x patch_size pixels.
    patch_grid = patch_scores.reshape(num_patches_h, num_patches_w)
    upsampled = patch_grid.repeat_interleave(patch_size, dim=0).repeat_interleave(patch_size, dim=1)
    full_attn_map = torch.zeros((1, image_height, image_width))
    full_attn_map[0, :upsampled.shape[0], :upsampled.shape[1]] = upsampled

    # Make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 10]

    plt.subplot(1, 5, 1)
    show_image(x[0], "original")

    plt.subplot(1, 5, 2)
    show_image(im_masked[0], "masked")

    plt.subplot(1, 5, 3)
    show_image(y[0], f"reconstruction (loss: {outputs.loss.item():.4f})")

    plt.subplot(1, 5, 4)
    show_image(im_paste[0], "reconstruction + visible")

    plt.subplot(1, 5, 5)
    # 'gray' is available in every matplotlib version; the 'grey' alias is not.
    plt.imshow(full_attn_map[0], cmap='gray', interpolation='nearest')
    plt.title("Attention Map")
    plt.axis('off')

    if imgname is not None:
        base_path = Path(f"/home/buehlern/Documents/Masterarbeit/notebooks/Data Exploration Graphics/Model Eval/ViT MAE FT/{mae_name}/")
        # `mae_name` may contain '/', so intermediate folders have to be created too.
        base_path.mkdir(parents=True, exist_ok=True)
        plt.savefig(base_path / (str(imgname) + '.png'))
    plt.show()

In [None]:
# Create iterators over the train and validation dataloaders of the visualization DataModule.
# Only `val_iter` is consumed in the next cell.
train_iter = iter(mri_datamodule.train_dataloader())
val_iter = iter(mri_datamodule.val_dataloader())

In [None]:
# Take the next validation batch and visualize it; re-running this cell advances the iterator.
item = next(val_iter)
# item[0] is passed to the model as pixel_values.
print(item[0].shape)
# `mae.net` is the wrapped network; the figure is saved as "val_example.png".
visualize(item[0], mae.net, imgname="val_example")

# Finetuning

In [None]:
def scheduler_fn(optimizer):
    """Return a constant learning-rate schedule (factor 1) for `optimizer`."""
    return torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1)


def optim_fn(params):
    """Build an AdamW optimizer over `params` using the module-level `lr`.

    `run_finetuning` binds the learning rate directly and does not use this
    helper; it remains available for cells that call it explicitly.
    """
    global lr
    return torch.optim.AdamW(params, lr=lr)


def run_finetuning(mae_checkpoint, datamodule, num_classes, set_lr, eff_batch_size, num_epochs=1, device=0):
    """Finetune a classifier on top of an MAE checkpoint and return the trained module.

    Args:
        mae_checkpoint: Path of the MAE checkpoint handed to the classifier.
        datamodule: DataModule providing `batch_size`, `label`, `train_dataloader()`
            and `val_dataloader()`.
        num_classes: Number of target classes.
        set_lr: Learning rate for AdamW.
        eff_batch_size: Desired effective batch size; reached through gradient
            accumulation over `eff_batch_size // datamodule.batch_size` batches.
        num_epochs: Maximum number of training epochs.
        device: Index of the GPU to train on.

    Returns:
        The `ViTMAELinearProbingClassifier` after `trainer.fit`.
    """
    global lr
    # Still mirrored into the module-level `lr` so `optim_fn` keeps working.
    lr = set_lr
    # At least one batch per optimizer step, even if eff_batch_size < batch_size.
    grad_acc = max(1, eff_batch_size // datamodule.batch_size)
    print(f"Testing {mae_checkpoint} on task {datamodule.label} with lr={set_lr} and effective batch_size={datamodule.batch_size}*{grad_acc}={datamodule.batch_size*grad_acc}")
    ft_probe_model = ViTMAELinearProbingClassifier(
        # Bind this run's learning rate instead of reading a mutable global, so a
        # later call with another lr cannot affect this model's optimizer.
        optimizer=functools.partial(torch.optim.AdamW, lr=set_lr),
        scheduler=scheduler_fn,
        mae_checkpoint=mae_checkpoint,
        num_classes=num_classes,
        freeze_encoder=False,
        mean_pooling=True
    )
    trainer = Trainer(max_epochs=num_epochs, accumulate_grad_batches=grad_acc, accelerator='gpu', devices=[device])
    trainer.fit(ft_probe_model, datamodule.train_dataloader(), datamodule.val_dataloader())
    return ft_probe_model

In [None]:
# Effective batch size for all finetuning runs; passed to run_finetuning as `eff_batch_size`.
effective_batch_size = 64

## Default Model

In [None]:
# Checkpoint used for the "Default Model" finetuning runs below.
mae_name = "ViT-B-MAE/Default/epoch_009"
mae_checkpoint = f"/home/buehlern/Documents/Masterarbeit/models/checkpoints/{mae_name}.ckpt"

### Bodypart classification

In [None]:
label = "bodypart"
# Bodypart finetuning DataModule; uses smart batch binning instead of fixed padding.
ft_bp_datamodule = MRIDataModule(
    df_name="df_min_ft_pt_1k",
    label=label,
    #pad_to_multiple_of=48,
    batch_binning="smart",
    batch_bins=[1152, 1536, 1920, 2304, 2688, 3072],
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
    val_size=0.10,
    test_size=0.15,
    output_channels=1,
    fix_inverted=True
)
# Count the distinct label values directly; nunique() works for any column dtype
# and the result is converted to a plain int.
num_classes = int(ft_bp_datamodule.dsbase.df[label].nunique())
print(f"Label {label} has {num_classes} classes")

In [None]:
# Sum of `fracture_bool` divided by split size, for each split of the bodypart DataModule.
bp_fracture_flags = ft_bp_datamodule.dsbase.df["fracture_bool"]
train_lbl = bp_fracture_flags.iloc[ft_bp_datamodule.data_train.indices]
val_lbl = bp_fracture_flags.iloc[ft_bp_datamodule.data_val.indices]
test_lbl = bp_fracture_flags.iloc[ft_bp_datamodule.data_test.indices]
print(f"Ratio of fracture in train data:", train_lbl.sum() / len(train_lbl))
print(f"Ratio of fracture in val data:", val_lbl.sum() / len(val_lbl))
print(f"Ratio of fracture in test data:", test_lbl.sum() / len(test_lbl))

#### HParam Search: Learning rate

Test learning rates 0.00001, 0.00003, 0.0001, 0.0003, 0.001, 0.003:

In [None]:
# Learning-rate sweep: lr = 1e-5.
# NOTE(review): this run trains for 10 epochs while the other sweep runs use the default of 1.
ft_probe_model = run_finetuning(
    mae_checkpoint=mae_checkpoint,
    datamodule=ft_bp_datamodule,
    num_classes=num_classes,
    set_lr=1e-5,
    eff_batch_size=effective_batch_size,
    num_epochs=10,
)
_ = ft_probe_model

In [None]:
# Learning-rate sweep: lr = 3e-5, default number of epochs.
ft_probe_model = run_finetuning(
    mae_checkpoint=mae_checkpoint,
    datamodule=ft_bp_datamodule,
    num_classes=num_classes,
    set_lr=3e-5,
    eff_batch_size=effective_batch_size,
)
_ = ft_probe_model

In [None]:
# Learning-rate sweep: lr = 1e-4, default number of epochs.
ft_probe_model = run_finetuning(
    mae_checkpoint=mae_checkpoint,
    datamodule=ft_bp_datamodule,
    num_classes=num_classes,
    set_lr=1e-4,
    eff_batch_size=effective_batch_size,
)
_ = ft_probe_model

In [None]:
# Learning-rate sweep: lr = 3e-4, default number of epochs; the trained model is discarded.
_ = run_finetuning(
    mae_checkpoint=mae_checkpoint,
    datamodule=ft_bp_datamodule,
    num_classes=num_classes,
    set_lr=3e-4,
    eff_batch_size=effective_batch_size,
)

In [None]:
# Learning-rate sweep: lr = 1e-3, default number of epochs; the trained model is discarded.
_ = run_finetuning(
    mae_checkpoint=mae_checkpoint,
    datamodule=ft_bp_datamodule,
    num_classes=num_classes,
    set_lr=1e-3,
    eff_batch_size=effective_batch_size,
)

In [None]:
# Learning-rate sweep: lr = 3e-3, default number of epochs; the trained model is discarded.
_ = run_finetuning(
    mae_checkpoint=mae_checkpoint,
    datamodule=ft_bp_datamodule,
    num_classes=num_classes,
    set_lr=3e-3,
    eff_batch_size=effective_batch_size,
)

#### Finetuning

In [None]:
# Learning rate used for all finetuning runs below
# (presumably picked from the sweep above — the selection itself is not recorded in this notebook).
best_lr = 0.00003

In [None]:
# Bodypart finetuning of the default checkpoint: 30 epochs on GPU 0.
ft_probe_model = run_finetuning(
    mae_checkpoint=mae_checkpoint,
    datamodule=ft_bp_datamodule,
    num_classes=num_classes,
    set_lr=best_lr,
    eff_batch_size=effective_batch_size,
    num_epochs=30,
    device=0,
)

In [None]:
#val_iter = iter(ft_bp_datamodule.val_dataloader())
#item = next(val_iter)
#pred = ft_probe_model(item[0])
#torch.nn.Softmax(pred)
#torch.argmax(pred)

### Fracture Detection

In [None]:
label = "fracture"
# Fracture finetuning DataModule; images are padded to a multiple of 48.
ft_frac_datamodule = MRIDataModule(
    df_name="df_min_ft_pt_1k",
    label=label,
    pad_to_multiple_of=48,
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    val_size=0.10,
    test_size=0.15,
    output_channels=1,
    fix_inverted=True
)
# Count the distinct label values directly; nunique() works for any column dtype
# and the result is converted to a plain int.
num_classes = int(ft_frac_datamodule.dsbase.df[label].nunique())
print(f"Label {label} has {num_classes} classes")

In [None]:
# Sum of `fracture_bool` divided by split size, for each split of the fracture DataModule.
frac_fracture_flags = ft_frac_datamodule.dsbase.df["fracture_bool"]
train_lbl = frac_fracture_flags.iloc[ft_frac_datamodule.data_train.indices]
val_lbl = frac_fracture_flags.iloc[ft_frac_datamodule.data_val.indices]
test_lbl = frac_fracture_flags.iloc[ft_frac_datamodule.data_test.indices]
print(f"Ratio of fracture in train data:", train_lbl.sum() / len(train_lbl))
print(f"Ratio of fracture in val data:", val_lbl.sum() / len(val_lbl))
print(f"Ratio of fracture in test data:", test_lbl.sum() / len(test_lbl))

In [None]:
# Fracture finetuning of the default checkpoint: 30 epochs on GPU 0.
ft_frac_model = run_finetuning(
    mae_checkpoint=mae_checkpoint,
    datamodule=ft_frac_datamodule,
    num_classes=num_classes,
    set_lr=best_lr,
    eff_batch_size=effective_batch_size,
    num_epochs=30,
    device=0,
)

### Foreign Material Detection

In [None]:
label = "foreignmaterial"
# Foreign material finetuning DataModule; images are padded to a multiple of 48.
# NOTE(review): this cell reads "df_min" while the other finetuning DataModules read
# "df_min_ft_pt_1k" — confirm this is intended.
ft_fm_datamodule = MRIDataModule(
    df_name="df_min",
    label=label,
    pad_to_multiple_of=48,
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    val_size=0.10,
    test_size=0.15,
    output_channels=1,
    fix_inverted=True
)
# Count the distinct label values directly; nunique() works for any column dtype
# and the result is converted to a plain int.
num_classes = int(ft_fm_datamodule.dsbase.df[label].nunique())
print(f"Label {label} has {num_classes} classes")

In [None]:
# Sum of `fracture_bool` divided by split size, for each split of the foreign material DataModule.
# NOTE(review): this reports the fracture column although the DataModule's label is
# "foreignmaterial" — confirm whether the foreign material ratio was intended.
fm_fracture_flags = ft_fm_datamodule.dsbase.df["fracture_bool"]
train_lbl = fm_fracture_flags.iloc[ft_fm_datamodule.data_train.indices]
val_lbl = fm_fracture_flags.iloc[ft_fm_datamodule.data_val.indices]
test_lbl = fm_fracture_flags.iloc[ft_fm_datamodule.data_test.indices]
print(f"Ratio of fracture in train data:", train_lbl.sum() / len(train_lbl))
print(f"Ratio of fracture in val data:", val_lbl.sum() / len(val_lbl))
print(f"Ratio of fracture in test data:", test_lbl.sum() / len(test_lbl))

In [None]:
# Foreign material finetuning of the default checkpoint: 10 epochs on the default device.
ft_fm_model = run_finetuning(
    mae_checkpoint=mae_checkpoint,
    datamodule=ft_fm_datamodule,
    num_classes=num_classes,
    set_lr=best_lr,
    eff_batch_size=effective_batch_size,
    num_epochs=10,
)

## Downsampling Model

In [None]:
# Checkpoint used for the "Downsampling Model" finetuning runs below.
mae_name = "ViT-L_MAE_FT-10k/downsampling"
mae_checkpoint = f"/home/buehlern/Documents/Masterarbeit/models/checkpoints/{mae_name}.ckpt"

### Bodypart classification

In [None]:
# When the notebook runs top to bottom, the global `num_classes` still holds the
# count from the foreign material DataModule at this point, so derive the class
# count from the bodypart DataModule itself.
ft_bp_model_downsampling = run_finetuning(
    mae_checkpoint,
    ft_bp_datamodule,
    int(ft_bp_datamodule.dsbase.df[ft_bp_datamodule.label].nunique()),
    best_lr,
    effective_batch_size,
    num_epochs=10,
)

### Fracture Detection

In [None]:
# When the notebook runs top to bottom, the global `num_classes` still holds the
# count from the foreign material DataModule at this point, so derive the class
# count from the fracture DataModule itself.
ft_frac_model_downsampling = run_finetuning(
    mae_checkpoint,
    ft_frac_datamodule,
    int(ft_frac_datamodule.dsbase.df[ft_frac_datamodule.label].nunique()),
    best_lr,
    effective_batch_size,
    num_epochs=10,
)

### Foreign Material Detection

In [None]:
label = "foreignmaterial"
# Foreign material finetuning DataModule; images are padded to a multiple of 48.
ft_fm_datamodule = MRIDataModule(
    df_name="df_min_ft_pt_1k",
    label=label,
    pad_to_multiple_of=48,
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    val_size=0.10,
    test_size=0.15,
    output_channels=1,
    fix_inverted=True
)
# Count the distinct label values directly; nunique() works for any column dtype
# and the result is converted to a plain int.
num_classes = int(ft_fm_datamodule.dsbase.df[label].nunique())
print(f"Label {label} has {num_classes} classes")

In [None]:
# Sum of `fracture_bool` divided by split size, for each split of the foreign material DataModule.
# NOTE(review): this reports the fracture column although the DataModule's label is
# "foreignmaterial" — confirm whether the foreign material ratio was intended.
fm_fracture_flags = ft_fm_datamodule.dsbase.df["fracture_bool"]
train_lbl = fm_fracture_flags.iloc[ft_fm_datamodule.data_train.indices]
val_lbl = fm_fracture_flags.iloc[ft_fm_datamodule.data_val.indices]
test_lbl = fm_fracture_flags.iloc[ft_fm_datamodule.data_test.indices]
print(f"Ratio of fracture in train data:", train_lbl.sum() / len(train_lbl))
print(f"Ratio of fracture in val data:", val_lbl.sum() / len(val_lbl))
print(f"Ratio of fracture in test data:", test_lbl.sum() / len(test_lbl))

In [None]:
# Foreign material finetuning of the downsampling checkpoint.
# Stored under its own name (matching ft_bp_model_downsampling / ft_frac_model_downsampling)
# instead of overwriting `ft_frac_model`, which holds the fracture model.
ft_fm_model_downsampling = run_finetuning(
    mae_checkpoint,
    ft_fm_datamodule,
    num_classes,
    best_lr,
    effective_batch_size,
    num_epochs=10,
)

## Maskratio Model

In [None]:
# Checkpoint used for the "Maskratio Model" finetuning runs below.
mae_name = "ViT-L_MAE_FT-10k/maskratio"
mae_checkpoint = f"/home/buehlern/Documents/Masterarbeit/models/checkpoints/{mae_name}.ckpt"

### Bodypart classification

In [None]:
# When the notebook runs top to bottom, the global `num_classes` still holds the
# count from the foreign material DataModule at this point, so derive the class
# count from the bodypart DataModule itself.
ft_bp_model_maskratio = run_finetuning(
    mae_checkpoint,
    ft_bp_datamodule,
    int(ft_bp_datamodule.dsbase.df[ft_bp_datamodule.label].nunique()),
    best_lr,
    effective_batch_size,
    num_epochs=10,
)

### Fracture Detection

In [None]:
# When the notebook runs top to bottom, the global `num_classes` still holds the
# count from the foreign material DataModule at this point, so derive the class
# count from the fracture DataModule itself.
ft_frac_model_maskratio = run_finetuning(
    mae_checkpoint,
    ft_frac_datamodule,
    int(ft_frac_datamodule.dsbase.df[ft_frac_datamodule.label].nunique()),
    best_lr,
    effective_batch_size,
    num_epochs=10,
)

### Foreign Material Detection

In [None]:
label = "foreignmaterial"
# Foreign material finetuning DataModule; images are padded to a multiple of 48.
ft_fm_datamodule = MRIDataModule(
    df_name="df_min_ft_pt_1k",
    label=label,
    pad_to_multiple_of=48,
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    val_size=0.10,
    test_size=0.15,
    output_channels=1,
    fix_inverted=True
)
# Count the distinct label values directly; nunique() works for any column dtype
# and the result is converted to a plain int.
num_classes = int(ft_fm_datamodule.dsbase.df[label].nunique())
print(f"Label {label} has {num_classes} classes")

In [None]:
# Sum of `fracture_bool` divided by split size, for each split of the foreign material DataModule.
# NOTE(review): this reports the fracture column although the DataModule's label is
# "foreignmaterial" — confirm whether the foreign material ratio was intended.
fm_fracture_flags = ft_fm_datamodule.dsbase.df["fracture_bool"]
train_lbl = fm_fracture_flags.iloc[ft_fm_datamodule.data_train.indices]
val_lbl = fm_fracture_flags.iloc[ft_fm_datamodule.data_val.indices]
test_lbl = fm_fracture_flags.iloc[ft_fm_datamodule.data_test.indices]
print(f"Ratio of fracture in train data:", train_lbl.sum() / len(train_lbl))
print(f"Ratio of fracture in val data:", val_lbl.sum() / len(val_lbl))
print(f"Ratio of fracture in test data:", test_lbl.sum() / len(test_lbl))

In [None]:
# Foreign material finetuning of the maskratio checkpoint.
# Stored under its own name (matching ft_bp_model_maskratio / ft_frac_model_maskratio)
# instead of overwriting `ft_frac_model`, which holds the fracture model.
ft_fm_model_maskratio = run_finetuning(
    mae_checkpoint,
    ft_fm_datamodule,
    num_classes,
    best_lr,
    effective_batch_size,
    num_epochs=10,
)

## Patch Size Model

In [None]:
# Checkpoint used for the "Patch Size Model" finetuning runs below.
mae_name = "ViT-L_MAE_FT-10k/patchsize"
mae_checkpoint = f"/home/buehlern/Documents/Masterarbeit/models/checkpoints/{mae_name}.ckpt"

### Bodypart classification

In [None]:
label = "bodypart"
# Bodypart finetuning DataModule, padded to a multiple of the new patch_size (32).
# This replaces the `ft_bp_datamodule` created for the earlier sections.
ft_bp_datamodule = MRIDataModule(
    df_name="df_min_ft_pt_1k",
    label=label,
    pad_to_multiple_of=32,
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    val_size=0.10,
    test_size=0.15,
    output_channels=1,
    fix_inverted=True
)
# Count the distinct label values directly; nunique() works for any column dtype
# and the result is converted to a plain int.
num_classes = int(ft_bp_datamodule.dsbase.df[label].nunique())
print(f"Label {label} has {num_classes} classes")

In [None]:
# Bodypart finetuning of the patchsize checkpoint: 10 epochs on the default device.
ft_bp_model_patchsize = run_finetuning(
    mae_checkpoint=mae_checkpoint,
    datamodule=ft_bp_datamodule,
    num_classes=num_classes,
    set_lr=best_lr,
    eff_batch_size=effective_batch_size,
    num_epochs=10,
)

### Fracture Detection

In [None]:
label = "fracture"
# Fracture finetuning DataModule, padded to a multiple of the new patch_size (32).
# This replaces the `ft_frac_datamodule` created for the earlier sections.
ft_frac_datamodule = MRIDataModule(
    df_name="df_min_ft_pt_1k",
    label=label,
    pad_to_multiple_of=32,
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    val_size=0.10,
    test_size=0.15,
    output_channels=1,
    fix_inverted=True
)
# Count the distinct label values directly; nunique() works for any column dtype
# and the result is converted to a plain int.
num_classes = int(ft_frac_datamodule.dsbase.df[label].nunique())
print(f"Label {label} has {num_classes} classes")

In [None]:
# Fracture finetuning of the patchsize checkpoint: 10 epochs on the default device.
ft_frac_model_patchsize = run_finetuning(
    mae_checkpoint=mae_checkpoint,
    datamodule=ft_frac_datamodule,
    num_classes=num_classes,
    set_lr=best_lr,
    eff_batch_size=effective_batch_size,
    num_epochs=10,
)

### Foreign Material Detection

In [None]:
label = "foreignmaterial"
# Foreign material finetuning DataModule, padded to a multiple of the new patch_size (32).
ft_fm_datamodule = MRIDataModule(
    df_name="df_min_ft_pt_1k",
    label=label,
    pad_to_multiple_of=32,
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    val_size=0.10,
    test_size=0.15,
    output_channels=1,
    fix_inverted=True
)
# Count the distinct label values directly; nunique() works for any column dtype
# and the result is converted to a plain int.
num_classes = int(ft_fm_datamodule.dsbase.df[label].nunique())
print(f"Label {label} has {num_classes} classes")

In [None]:
# Sum of `fracture_bool` divided by split size, for each split of the foreign material DataModule.
# NOTE(review): this reports the fracture column although the DataModule's label is
# "foreignmaterial" — confirm whether the foreign material ratio was intended.
fm_fracture_flags = ft_fm_datamodule.dsbase.df["fracture_bool"]
train_lbl = fm_fracture_flags.iloc[ft_fm_datamodule.data_train.indices]
val_lbl = fm_fracture_flags.iloc[ft_fm_datamodule.data_val.indices]
test_lbl = fm_fracture_flags.iloc[ft_fm_datamodule.data_test.indices]
print(f"Ratio of fracture in train data:", train_lbl.sum() / len(train_lbl))
print(f"Ratio of fracture in val data:", val_lbl.sum() / len(val_lbl))
print(f"Ratio of fracture in test data:", test_lbl.sum() / len(test_lbl))

In [None]:
# Foreign material finetuning of the patchsize checkpoint.
# Stored under its own name (matching ft_bp_model_patchsize / ft_frac_model_patchsize)
# instead of overwriting `ft_frac_model`, which holds the fracture model.
ft_fm_model_patchsize = run_finetuning(
    mae_checkpoint,
    ft_fm_datamodule,
    num_classes,
    best_lr,
    effective_batch_size,
    num_epochs=10,
)