In [1]:
import torch
import collections
import matplotlib.pyplot as plt
import numpy as np
import random
import torch.nn.functional as F
import pandas as pd
from pathlib import Path
import torch.optim as optim
from lightning import Trainer
from omegaconf import OmegaConf

In [2]:
import sys
sys.path.insert(1, '/home/buehlern/Documents/Masterarbeit/models')
from src.data.mri_datamodule import MRIDataModule
from src.models.vit_mae_module import VisionTransformerMAE
from src.models.vit_mae_probe_module import MAEFineProbeClassifier

In [3]:
# Load model checkpoint
mae_name = "ViT-L MAE" # Old pretrained model
mae_name = "ViT-L MAE FT-1" # FT PT Model normal, 50 epochs
mae_name = "ViT-L MAE FT-2" # FT PT Model normal, 90 epochs
mae_name = "ViT-L MAE FT-3" # FT PT Model normal, 300 epochs
mae_name = "ViT-L MAE FT-4" # FT PT Model normal, 1000 epochs
mae_name = "ViT-L MAE FT-5" # FT PT Model normal, 10k samples, 3 epochs
mae_name = "ViT-L_MAE_FT-10k_1" # FT PT Model normal, 10k samples, 10 epochs
mae_checkpoint = f"/home/buehlern/Documents/Masterarbeit/models/checkpoints/{mae_name}.ckpt"

In [None]:
checkpoint = torch.load(mae_checkpoint)
state_dict = checkpoint['state_dict']

In [4]:
state_dict.keys()

odict_keys(['net.vit.embeddings.cls_token', 'net.vit.embeddings.position_embeddings', 'net.vit.embeddings.patch_embeddings.projection.weight', 'net.vit.embeddings.patch_embeddings.projection.bias', 'net.vit.encoder.layer.0.attention.attention.query.weight', 'net.vit.encoder.layer.0.attention.attention.query.bias', 'net.vit.encoder.layer.0.attention.attention.key.weight', 'net.vit.encoder.layer.0.attention.attention.key.bias', 'net.vit.encoder.layer.0.attention.attention.value.weight', 'net.vit.encoder.layer.0.attention.attention.value.bias', 'net.vit.encoder.layer.0.attention.output.dense.weight', 'net.vit.encoder.layer.0.attention.output.dense.bias', 'net.vit.encoder.layer.0.intermediate.dense.weight', 'net.vit.encoder.layer.0.intermediate.dense.bias', 'net.vit.encoder.layer.0.output.dense.weight', 'net.vit.encoder.layer.0.output.dense.bias', 'net.vit.encoder.layer.0.layernorm_before.weight', 'net.vit.encoder.layer.0.layernorm_before.bias', 'net.vit.encoder.layer.0.layernorm_after.wei

In [5]:
def remap_keys(state_dict, unwanted_prefix, new_prefix):
    new_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith(unwanted_prefix):
            new_key = key.replace(unwanted_prefix, new_prefix)
        else:
            new_key = key
        new_state_dict[new_key] = value
    return collections.OrderedDict(new_state_dict)

In [6]:
new_state_dict = remap_keys(state_dict, 'net._orig_mod.', 'net.')

In [7]:
new_state_dict.keys()

odict_keys(['net.vit.embeddings.cls_token', 'net.vit.embeddings.position_embeddings', 'net.vit.embeddings.patch_embeddings.projection.weight', 'net.vit.embeddings.patch_embeddings.projection.bias', 'net.vit.encoder.layer.0.attention.attention.query.weight', 'net.vit.encoder.layer.0.attention.attention.query.bias', 'net.vit.encoder.layer.0.attention.attention.key.weight', 'net.vit.encoder.layer.0.attention.attention.key.bias', 'net.vit.encoder.layer.0.attention.attention.value.weight', 'net.vit.encoder.layer.0.attention.attention.value.bias', 'net.vit.encoder.layer.0.attention.output.dense.weight', 'net.vit.encoder.layer.0.attention.output.dense.bias', 'net.vit.encoder.layer.0.intermediate.dense.weight', 'net.vit.encoder.layer.0.intermediate.dense.bias', 'net.vit.encoder.layer.0.output.dense.weight', 'net.vit.encoder.layer.0.output.dense.bias', 'net.vit.encoder.layer.0.layernorm_before.weight', 'net.vit.encoder.layer.0.layernorm_before.bias', 'net.vit.encoder.layer.0.layernorm_after.wei

In [8]:
mae = VisionTransformerMAE(image_size = 3072, patch_size = 48, image_channels=1, output_attentions=True)
mae.load_state_dict(new_state_dict)

<All keys matched successfully>

In [9]:
mae.eval()

VisionTransformerMAE(
  (net): ViTMAEForPreTraining(
    (vit): ViTMAEModel(
      (embeddings): ViTMAEEmbeddings(
        (patch_embeddings): ViTMAEPatchEmbeddings(
          (projection): Conv2d(1, 1024, kernel_size=(48, 48), stride=(48, 48))
        )
      )
      (encoder): ViTMAEEncoder(
        (layer): ModuleList(
          (0-23): 24 x ViTMAELayer(
            (attention): ViTMAEAttention(
              (attention): ViTMAESelfAttention(
                (query): Linear(in_features=1024, out_features=1024, bias=True)
                (key): Linear(in_features=1024, out_features=1024, bias=True)
                (value): Linear(in_features=1024, out_features=1024, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
              (output): ViTMAESelfOutput(
                (dense): Linear(in_features=1024, out_features=1024, bias=True)
                (dropout): Dropout(p=0.0, inplace=False)
              )
            )
            (intermediate): Vi

In [10]:
# Disable masking
#mae.net.config.mask_ratio = 0

In [11]:
# For outputting attentions
print('output_attentions:', mae.net.config.output_attentions)
print('_attn_implementation:', mae.net.config._attn_implementation)

output_attentions: True
_attn_implementation: eager


In [12]:
# Load the DataModule
mri_datamodule = MRIDataModule(
    df_name="df_min_ft_test_114",
    label="bodypart",
    batch_binning="smart",
    batch_bins=[1152, 1536, 1920, 2304, 2688, 3072],
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    pin_memory=True,
    val_size=0.05,
    test_size=0.15,
    output_channels=1,
    fix_inverted=True)

Using label bodypart as stratification_target
Initializing MRIDatasetBase...
Loading dataframe from /home/buehlern/Documents/Masterarbeit/data/df_min_ft_test_114.pkl...
MRIDatasetBase(len=114) initialized
Getting train indices...
Done. Train len: 92
Getting val indices...
Done. Val len: 6
Getting test indices...
WARN: Including test data
Done. Test len: 16
Initializing train dataset...
Done.
Initializing val dataset...
Done.
Initializing test dataset...
Done.


In [13]:
def show_image(image, title=''):
    # image is [H, W, 1]
    assert image.shape[2] == 1
    plt.imshow(image, cmap=plt.cm.bone)
    plt.title(title, fontsize=16)
    plt.axis('off')
    return

In [14]:
def visualize(pixel_values, model, imgname=None):
    patch_size = model.config.patch_size
    print(pixel_values.size())
    image_width, image_height = pixel_values.size()[-2:]
    num_patches_x = image_width // patch_size
    num_patches_y = image_height // patch_size
    print(f"Size: {image_width}x{image_height}")
    print(f"Patches: {num_patches_x}x{num_patches_y}")
    
    # forward pass
    outputs = model(pixel_values, interpolate_pos_encoding=True, output_attentions=True)
    y = model.unpatchify(outputs.logits, original_image_size=pixel_values.size()[-2:])
    y = torch.einsum('nchw->nhwc', y).detach().cpu()
    
    # visualize the mask
    mask = outputs.mask.detach()
    mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size**2 *1)  # (N, H*W, p*p*1)
    mask = model.unpatchify(mask, original_image_size=pixel_values.size()[-2:])
    mask = torch.einsum('nchw->nhwc', mask).detach().cpu()
    
    x = torch.einsum('nchw->nhwc', pixel_values)

    # masked image
    im_masked = x * (1 - mask)

    # MAE reconstruction pasted with visible patches
    im_paste = x * (1 - mask) + y * mask

    # Attention map calculations
    # seq_len = 1025
    # Without [CLS] token: seq_len = 1024
    # num_patches_x_out = num_patches_y = sqrt(seq_len) = sqrt(1024) = 32
    # num_patches_x_in = image_size // patch_size = 3072 // 48 = 64
    # Reason for difference is masking! 4096 * (1-0.75) = 1024

    #image_size = model.config.image_size
    #patch_size = model.config.patch_size
    #num_patches = image_size // patch_size
    #attn_map = attn.squeeze().numpy()
    ## Scale attention map to original image size
    #attn_map_scaled = np.zeros((image_size, image_size))
    #for i in range(num_patches):
    #    for j in range(num_patches):
    #        x_start = i * patch_size
    #        x_end = x_start + patch_size
    #        y_start = j * patch_size
    #        y_end = y_start + patch_size
    #        
    #        attn_map_scaled[x_start:x_end, y_start:y_end] = attn_map[i * num_patches + j]
    
    #num_patches = attn.size(-1)
    #print(f"num_patches (after masking {model.config.mask_ratio}):", num_patches) # 1024
    #attn = attn[0].reshape(model.config.patch_size, model.config.patch_size, -1)
    # Rescale attention to image size
    #attn = F.interpolate(attn.unsqueeze(0).unsqueeze(0), scale_factor=model.config.patch_size, mode="nearest")[0][0]

    attentions = outputs.attentions
    print("len(attentions):", len(attentions)) # 24 layers
    print("attentions[0].shape:", attentions[0].shape) # [1, 16, 1025, 1025]
    # Initialize full attention map
    full_attn_map = torch.zeros((1, image_width, image_height)) # (1, 3072, 3072)
    masklist = outputs.mask.detach().type(torch.int64) > 0
    print("masklist.shape:", masklist.shape) # (1, 4096)
    # Determine patch contributions
    #patch_contrib = attentions[-1][:, :, 1:, 1:].mean(dim=1) # Take mean attention of last layer (without [CLS] token)
    patch_contrib = torch.zeros(attentions[0].shape[-1]-1) # -1 for [CLS] token
    check_cls_token = False
    if check_cls_token:
        for layer_attn in attentions[:-1]:
            # Only check [CLS] token
            attn = layer_attn[:, :, 1, 1:].detach().cpu() # [1, 16, 1024]
            # Average over heads
            attn = attn.mean(dim=1) # [1, 1024]
            # Average over batch (if batch_size > 1)
            attn = attn.mean(dim=0) # [1, 1024]
            patch_contrib += attn
    else:
        for layer_attn in attentions[:-1]:
            # Remove [CLS] token
            attn = layer_attn[:, :, 1:, 1:].detach().cpu() # [1, 16, 1024, 1024]
            # Average over heads
            attn = attn.mean(dim=1) # [1, 1024, 1024]
            # Average over batch (if batch_size > 1)
            attn = attn.mean(dim=0) # [1024, 1024]
            # Average contribution across all other tokens
            attn = attn.mean(dim=1) # [1024,]
            patch_contrib += attn
    # Normalize
    patch_contrib -= patch_contrib.min()
    patch_contrib /= patch_contrib.max()
    print("patch_contrib.shape:", patch_contrib.shape)
    print(pd.DataFrame(patch_contrib).describe())
    # Map attention scores onto full attention map
    attn_iter = iter(patch_contrib)
    print("len(masklist[0]):", len(masklist[0]))
    for i, masked in enumerate(masklist[0]):
        row = i // num_patches_x
        col = i % num_patches_y
        #print("i, row, col, masked:", i, row, col, masked.item())
        attn_val = next(attn_iter) if ~masked else torch.tensor([0])
        full_attn_map[0, row*patch_size:(row+1)*patch_size, col*patch_size:(col+1)*patch_size] = attn_val.expand((patch_size, patch_size))
    
    # make the plt figure larger
    plt.rcParams['figure.figsize'] = [24, 10]

    plt.subplot(1, 5, 1)
    show_image(x[0], "original")

    plt.subplot(1, 5, 2)
    show_image(im_masked[0], "masked")

    plt.subplot(1, 5, 3)
    show_image(y[0], f"reconstruction (loss: {outputs.loss.item():.4f})")

    plt.subplot(1, 5, 4)
    show_image(im_paste[0], "reconstruction + visible")
    
    plt.subplot(1, 5, 5)
    plt.imshow(full_attn_map[0], cmap='gray', interpolation='nearest')
    plt.title("Attention Map")
    plt.axis('off')

    if imgname is not None:
        base_path = f"/home/buehlern/Documents/Masterarbeit/notebooks/Data Exploration Graphics/Model Eval/ViT MAE FT/{mae_name}/"
        Path(base_path).mkdir(exist_ok=True)
        plt.savefig(base_path + '/' + str(imgname) + '.png')
    plt.show()

In [15]:
train_iter = iter(mri_datamodule.train_dataloader())
val_iter = iter(mri_datamodule.val_dataloader())

Generating shape_to_indices dict in CustomBatchSampler...
Done.
Maximum bin shape: (2688, 1920)
DataLoader length 92
Maximum bin shape: (2688, 1152)
Generating shape_to_indices dict in CustomBatchSampler...
Done.
Maximum bin shape: (1920, 1536)
Maximum bin shape: (1536, 1152)


In [None]:
item = next(val_iter)
print(item[0].shape)
visualize(item[0], mae.net, imgname="val_example")

# Finetuning

## Bodypart classification

In [4]:
# Finetuning DataModule
ft_bp_datamodule = MRIDataModule(
    df_name="df_min_ft_pt_1k",
    label="bodypart",
    pad_to_multiple_of=48,
    batch_size=1,
    num_workers=1,
    persistent_workers=True,
    pin_memory=True,
    val_size=0.05,
    test_size=0.15,
    output_channels=1,
    fix_inverted=True)

Using label bodypart as stratification_target
Initializing MRIDatasetBase...
Loading dataframe from /home/buehlern/Documents/Masterarbeit/data/df_min_ft_pt_1k.pkl...
MRIDatasetBase(len=1000) initialized
Getting train indices...
Done. Train len: 822
Getting val indices...
Done. Val len: 38
Getting test indices...
WARN: Including test data
Done. Test len: 140
Initializing train dataset...
Done.
Initializing val dataset...
Done.
Initializing test dataset...
Done.


In [5]:
ft_bp_datamodule.dsbase.df["bodypart"].describe()

count     1000
unique      14
top       knee
freq       170
Name: bodypart, dtype: object

In [10]:
cfg = OmegaConf.load('../models/configs/experiment/vit_mae_probe_bodypart.yaml')
model = MAEFineProbeClassifier(num_classes=cfg.model.num_classes,
                                mae_checkpoint=mae_checkpoint,
                                optimizer=cfg.model.optimizer,
                                scheduler=cfg.model.scheduler,
                                seq_mean=cfg.model.seq_mean,
                                compile=cfg.model.compile)

checkpoint path /home/buehlern/Documents/Masterarbeit/models/checkpoints/ViT-L_MAE_FT-10k_1.ckpt


/home/buehlern/.local/lib/python3.10/site-packages/lightning/fabric/utilities/cloud_io.py:57: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.


In [11]:
trainer = Trainer(max_epochs=10)
trainer.fit(model, train_dataloaders=ft_bp_datamodule.train_dataloader())

Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


DataLoader length 822


TypeError: 'DictConfig' object is not callable