In [4]:
import os
import glob
import torch
import random
import numpy as np
import seaborn as sns
from PIL import Image
from tqdm import tqdm
from easydict import EasyDict
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torchmetrics import Accuracy
from torchvision import transforms
import matplotlib.patches as mpatches
from configs.utils import load_config_yaml, update_domain_sequence
from dl_toolbox.callbacks import *
import dl_toolbox.inference as dl_inf
from argparse import ArgumentParser, Namespace
from model.segmenter_adapt import SegmenterAdapt
from datasets.utils import *
%matplotlib inline

In [3]:
# Pick the continual-learning domain sequence from the FLAIR-1 data folder and
# record it via `update_domain_sequence`.
# NOTE(review): `update_domain_sequence` receives the config *path*, so it
# presumably rewrites the YAML on disk; later cells that reload the config would
# then see the sequence chosen here — confirm before re-running this cell.
config_file = "/d/maboum/css-peft/configs/config.yml"
config = load_config_yaml(file_path=config_file)
data_config = config["dataset"]["flair1"]
binary = data_config["binary"]
directory_path = data_config["data_path"]

seed = config["seed"]
random.seed(seed)

# NOTE(review): os.listdir order is filesystem-dependent, so this seeded sample is
# only reproducible on the same filesystem. Sorting would fix that but would also
# change the recorded sequence; check existing checkpoints before changing it.
domain_dirs = os.listdir(directory_path)
selected_elements = random.sample(domain_dirs, 20)
update_domain_sequence(config_file, selected_elements)

# Every domain that was not sampled, in listing order.
chosen = set(selected_elements)
remaining_elements = [name for name in domain_dirs if name not in chosen]
update_domain_sequence(config_file, remaining_elements, 'task_name')

In [None]:
# Stand-in for the training script's argparse CLI: helpers such as
# `create_test_dataloader` take an `args` object, so the same options are
# rebuilt here as a plain Namespace.
args = Namespace(**{
    # Learning-rate schedule
    "initial_lr": 0.01,
    "final_lr": 0.005,
    "lr_milestones": (20, 80),
    # Data loading / augmentation / training length
    "epoch_len": 100,
    "sup_batch_size": 4,
    "crop_size": 256,
    "workers": 6,
    "img_aug": 'd4_rot90_rot270_rot180_d1flip',
    "max_epochs": 200,
    "sequence_path": "",
    "train_split_coef": 0.85,
    # Experiment bookkeeping
    "strategy": 'continual_{}',
    "commit": None,
    "train_type": "adaptmlp",
    # Boolean switches. The original comments describe these as argparse
    # `store_true` flags, whose CLI default is False; note that `ffn_adapt`
    # is explicitly enabled here.
    "replay": False,
    "config_file": "/d/maboum/css-peft/configs/config.yml",
    "ffn_adapt": True,
    "ffn_num": 64,
    "vpt": False,
    "vpt_num": 1,
    "fulltune": False,
})

In [None]:
def visualize_attention(image, attention_map, save_path=None, cmap='viridis'):
    """Show an image next to its attention map, optionally saving the figure.

    Args:
        image: channels-first (C, H, W) torch tensor. It is detached and moved to
            CPU before plotting. When it has more than 3 channels only the first
            three are drawn; a single channel is drawn as a 2-D image. This is
            because matplotlib cannot render other channel counts.
        attention_map: 2-D map as a numpy array or torch tensor. Tensors are
            detached and moved to CPU first, since GPU or grad-tracking tensors
            cannot be converted with ``.numpy()``.
        save_path: if truthy, the figure is saved there before being shown.
        cmap: matplotlib colormap for the attention map.
    """
    def _to_numpy(x):
        # GPU tensors and tensors that require grad both break a bare .numpy().
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    img = image.detach()
    if img.shape[0] > 3:
        img = img[:3]
    img_np = img.permute(1, 2, 0).cpu().numpy()
    if img_np.shape[-1] == 1:
        img_np = img_np[..., 0]  # imshow rejects (H, W, 1)

    fig, (ax_img, ax_att) = plt.subplots(1, 2, figsize=(10, 5))

    ax_img.imshow(img_np)
    ax_img.set_title("Original Image")
    ax_img.axis("off")

    ax_att.imshow(_to_numpy(attention_map), cmap=cmap)
    ax_att.set_title("Attention Map")
    ax_att.axis("off")

    if save_path:
        fig.savefig(save_path)
    plt.show()
    plt.close(fig)  # avoid accumulating open figures when called in a loop

In [None]:
def test_function(model, test_dataloader, device, loss_fn, accuracy_fn, data_config, run=None, eval_freq=5,
                  log_dir="/d/maboum/css-peft/imgs/logs", save_figures=False):
    """Evaluate a segmentation model on a test dataloader and log per-sample diagnostics.

    Every batch is run without gradient tracking. Loss and accuracy are accumulated,
    and a confusion matrix is computed with ``compute_conf_mat``. Every
    ``eval_freq`` batches, the first and last samples of the batch get a
    prediction / ground-truth overlay figure and a per-class metrics table
    written to ``log_dir``.

    Args:
        model: segmentation network called as ``model(image)``. Its output is
            reduced with ``argmax(dim=1)``, so presumably per-class logits on dim 1.
        test_dataloader: yields dicts with ``'image'`` and ``'mask'`` tensors and,
            optionally, ``'id'`` (used to name logged artefacts).
        device: device the batches are moved to.
        loss_fn: loss applied to the raw model output and integer targets, e.g.
            ``torch.nn.CrossEntropyLoss``, which applies log-softmax itself.
        accuracy_fn: metric called as ``accuracy_fn(preds, target)``.
        data_config: dict with ``'n_channels'``, ``'classnames'`` (class index ->
            name) and ``'n_cls'``.
        run: optional experiment-tracker handle (``run[key].upload/append/log``).
        eval_freq: log sample visualisations every ``eval_freq`` batches.
        log_dir: directory receiving the per-sample metrics tables.
        save_figures: also save each overlay figure as PNG in ``log_dir``.

    Returns:
        ``(val_loss, val_acc, val_iou, test_metrics_df, mean_cm)``:
        - ``val_loss``: mean batch loss.
        - ``val_acc``: mean batch accuracy (float).
        - ``val_iou``: mean over batches of the class-averaged IoU (0-dim tensor).
        - ``test_metrics_df``: metrics of the batch-averaged confusion matrix.
        - ``mean_cm``: the batch-averaged confusion matrix itself.

    Raises:
        ValueError: if ``test_dataloader`` is empty.
    """
    n_batches = len(test_dataloader)
    if n_batches == 0:
        raise ValueError("test_dataloader is empty; nothing to evaluate")

    n_channels = data_config['n_channels']
    class_labels = data_config['classnames']
    n_class = data_config['n_cls']

    loss_sum, acc_sum = 0.0, 0.0
    iou_metrics = torch.zeros(n_class)  # per-class IoU summed over batches
    cm_sum = None  # running sum of per-batch confusion matrices

    # Evaluation only: no autograd graph is needed.
    with torch.no_grad():
        for i, batch in tqdm(enumerate(test_dataloader), total=n_batches):
            image = (batch['image'][:, :n_channels, :, :] / 255.).to(device)
            target = batch['mask'].to(device)

            output = model(image)
            preds = output.argmax(dim=1)

            # CrossEntropyLoss expects raw logits (it applies log-softmax
            # internally); passing softmax probabilities applied it twice.
            loss = loss_fn(output, target.squeeze(1).long())
            acc = accuracy_fn(preds.unsqueeze(1), target)

            loss_sum += loss.item()
            acc_sum += float(acc)

            cm = compute_conf_mat(
                target.contiguous().view(-1).cpu(),
                preds.contiguous().view(-1).cpu().long(),
                n_class,
            )
            cm_np = cm.numpy()
            cm_sum = cm_np if cm_sum is None else cm_sum + cm_np

            metrics_per_class_df, _, _ = dl_inf.cm2metrics(cm_np)
            iou_metrics += torch.tensor(metrics_per_class_df.IoU.values)

            if i % eval_freq == 0:
                _log_sample_predictions(image, target, preds, batch, i, class_labels, n_class,
                                        run, log_dir, save_figures)

    val_loss = loss_sum / n_batches
    val_acc = acc_sum / n_batches
    # Class-average of the summed per-class IoU, divided by the number of batches.
    val_iou = torch.mean(iou_metrics) / n_batches

    mean_cm = cm_sum / n_batches
    test_metrics_df, _, _ = dl_inf.cm2metrics(mean_cm)

    if run:
        run['test/loss'].log(val_loss)
        run['test/accuracy'].log(val_acc)
        run['test/mean_iou'].log(val_iou)

    print(f"Test Loss: {val_loss:.4f}, Test Accuracy: {val_acc:.4f}, Test Mean IoU: {val_iou:.4f}")
    print(test_metrics_df.round(4))

    return val_loss, val_acc, val_iou, test_metrics_df, mean_cm


def _log_sample_predictions(image, target, preds, batch, batch_idx, class_labels, n_class,
                            run, log_dir, save_figures):
    """Plot, save and log prediction/ground-truth overlays for the first and last sample of a batch."""
    # With a one-sample batch, 0 and -1 are the same image: log it only once.
    idx_list = [0] if image.shape[0] == 1 else [0, -1]
    legend_patches = [
        mpatches.Patch(color=plt.cm.tab20(j / len(class_labels)), label=class_labels[j])
        for j in class_labels
    ]

    for img_idx in idx_list:
        domain_id = batch['id'][img_idx] if 'id' in batch else batch_idx  # fall back to the batch index

        img_cm = compute_conf_mat(
            target[img_idx].contiguous().view(-1).cpu(),
            preds[img_idx].contiguous().view(-1).cpu().long(),
            n_class,
        )
        img_metrics_df, _, _ = dl_inf.cm2metrics(img_cm.numpy())

        sample_img = image[img_idx].permute(1, 2, 0).cpu().numpy()
        predictions = overlay_segmentation(sample_img, preds[img_idx].cpu().numpy(), class_labels)
        ground_truth = overlay_segmentation(sample_img, target[img_idx, 0].cpu().numpy(), class_labels)

        fig, axs = plt.subplots(1, 2, figsize=(15, 7.5))
        for ax, panel, title in ((axs[0], predictions, "Predictions"), (axs[1], ground_truth, "Ground Truth")):
            ax.imshow(panel)
            ax.set_title(title)
            ax.axis('off')
        fig.legend(handles=legend_patches, loc='upper center', ncol=4, bbox_to_anchor=(0.5, 0.11), fontsize='medium')

        stem = os.path.join(log_dir, f"_{img_idx}_batch{batch_idx}_{domain_id}")
        # Rename only the written copy: the IoU lookup below indexes by class id,
        # which an in-place rename to class names would break (KeyError).
        img_metrics_df.rename(index=class_labels).round(3).to_csv(f"{stem}.txt", sep='\t')
        if save_figures:
            fig.savefig(f"{stem}.png")

        if run:
            # Upload while the figure is still open; it is closed afterwards.
            run[f'test/batch_{batch_idx}/domain_{domain_id}'].upload(fig)
            for cls_idx, cls_name in class_labels.items():
                run[f'test/metrics/{domain_id}_{batch_idx}_{cls_name}_iou'].append(
                    img_metrics_df.IoU.loc[cls_idx].round(2)
                )

        plt.close(fig)  # free memory


In [None]:
def test_and_visualize(config_file="/d/maboum/css-peft/configs/config.yml",
                       checkpoint_path="/scratcht/FLAIR_1/experiments/checkpoints/continual_260824",
                       n_tasks=5, n_test_domains=1):
    """Rebuild a SegmenterAdapt model from a checkpoint and evaluate it on held-out FLAIR-1 tiles.

    The model is built from ``config_file``. Its adapter ``tuning_config`` covers
    the first ``n_tasks`` entries of the ``task_name`` sequence, and its weights
    are loaded from ``checkpoint_path``. For each of the first ``n_test_domains``
    domains, the tile list is split using the global ``args.train_split_coef``.
    All test tiles are then evaluated with ``test_function``.

    Returns:
        The tuple from ``test_function``:
        ``(val_loss, val_acc, val_iou, test_metrics_df, mean_cm)``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config = load_config_yaml(file_path=config_file)

    seed = config["seed"]
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    random.seed(seed)

    # Model hyper-parameters (same setup as the training code)
    selected_model = "vit_base_patch8_224"
    model_config = config["model"][selected_model]
    data_config = config["dataset"]["flair1"]
    data_sequence = data_config["task_name"]
    im_size = model_config["image_size"]
    patch_size = model_config["patch_size"]
    d_model = model_config["d_model"]
    n_heads = model_config["n_heads"]
    n_layers = model_config["n_layers"]
    d_encoder = model_config["d_model"]
    n_class = data_config["n_cls"]
    tasks = data_sequence[:n_tasks]

    tuning_config = EasyDict(
        ffn_adapt=True,
        ffn_option="parallel",
        ffn_adapter_layernorm_option="none",
        ffn_adapter_init_option="lora",
        ffn_adapter_scalar="0.1",
        ffn_num=64,
        d_model=d_model,
        vpt_on=False,
        vpt_num=1,
        nb_task=len(tasks),
        tasks=tasks,
    )

    model = SegmenterAdapt(
        im_size, n_layers, d_model, d_encoder, 4 * d_model, n_heads, n_class,
        patch_size, selected_model, tuning_config=tuning_config,
        model_name=config["model_name"]
    ).to(device)
    model.load_pretrained_weights(model_path=checkpoint_path)
    model.eval()

    # Follow the selected device; a bare .cuda() crashed on CPU-only hosts.
    loss_fn = torch.nn.CrossEntropyLoss().to(device)
    accuracy_fn = Accuracy(task='multiclass', num_classes=n_class).to(device)

    # Re-create the train/test split. Keep this exact sequence of RNG calls
    # (shuffle the domain's tiles, then shuffle the accumulated train list).
    # NOTE(review): presumably it mirrors the training script's split; changing
    # it could leak training tiles into the test set. TODO confirm.
    directory_path = data_config["data_path"]
    train_imgs, test_imgs = [], []
    for step, domain in enumerate(tasks[:n_test_domains]):
        print(step, domain)
        img = glob.glob(os.path.join(directory_path, '{}/Z*_*/img/IMG_*.tif'.format(domain)))
        random.shuffle(img)
        n_train = int(len(img) * args.train_split_coef)
        train_imgs += img[:n_train]
        test_imgs += img[n_train:]
        random.shuffle(train_imgs)

    test_dataloader = create_test_dataloader(test_imgs, args, data_config, binary=data_config["binary"])
    return test_function(model, test_dataloader, device, loss_fn, accuracy_fn, data_config)


val_loss, val_acc, val_iou, test_metrics_df, mean_cm = test_and_visualize()

In [None]:
# Row-normalised (recall) and column-normalised (precision) views of `mean_cm`,
# the averaged test confusion matrix returned by `test_and_visualize` above.
# Rows are true labels and columns are predicted labels.
config_file = "/d/maboum/css-peft/configs/config.yml"
config = load_config_yaml(file_path=config_file)
label_dict = config["dataset"]["flair1"]["classnames"]
labels = [label_dict[i] for i in sorted(label_dict.keys())]

from sklearn.metrics import confusion_matrix

# Tick labels are assigned positionally, so they must match the matrix size.
assert mean_cm.shape == (len(labels), len(labels)), "class names do not match the confusion-matrix size"


def plot_normalized_cm(cm, labels, normalize_axis, title):
    """Plot a confusion matrix normalised along one axis as an annotated heatmap.

    Args:
        cm: square confusion matrix (rows = true, columns = predicted).
        labels: class names, in matrix row/column order.
        normalize_axis: 1 divides each row by its sum (recall); 0 divides each
            column by its sum (precision).
        title: figure title.

    Returns:
        The normalised matrix. Machine epsilon in the denominator keeps empty
        rows/columns at 0 instead of NaN.
    """
    cm = cm.astype('float')
    cm_norm = cm / (cm.sum(axis=normalize_axis, keepdims=True) + np.finfo(float).eps)
    tick_pos = np.arange(len(labels)) + 0.5  # heatmap cells are centred on .5

    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(cm_norm, annot=True, fmt=".2f", ax=ax)
    ax.set_title(title)
    ax.set_xticks(tick_pos)
    ax.set_xticklabels(labels, rotation=90)
    ax.set_yticks(tick_pos)
    ax.set_yticklabels(labels, rotation=0)
    ax.set_ylabel('True label')
    ax.set_xlabel('Predicted label')
    plt.show()
    return cm_norm


cm_norm = plot_normalized_cm(mean_cm, labels, normalize_axis=1, title='Recall Confusion Matrix')
cm_norm = plot_normalized_cm(mean_cm, labels, normalize_axis=0, title='Precision Confusion Matrix')

In [None]:
# Rebuild a model from a checkpoint and load one held-out tile for the
# attention-map cells below. Those cells use the globals `model` and `image`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
config_file = "/d/maboum/css-peft/configs/config.yml"  # Change this to your actual path
checkpoint_path = "/scratcht/FLAIR_1/experiments/checkpoints/IL_multi_finetuning_w13_4156_3"
config = load_config_yaml(file_path=config_file)

seed = config["seed"]
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)

# Model setup (assuming similar setup from main training code)
selected_model = "vit_base_patch8_224"
model_type = config["model"]
data_sequence = config["dataset"]["flair1"]["task_name"]
model_config = model_type[selected_model]
im_size = model_config["image_size"]
patch_size = model_config["patch_size"]
d_model = model_config["d_model"]
n_heads = model_config["n_heads"]
n_layers = model_config["n_layers"]
d_encoder = model_config["d_model"]
n_class = config["dataset"]["flair1"]["n_cls"]

tuning_config = EasyDict(
    ffn_adapt=True,
    ffn_option="parallel",
    ffn_adapter_layernorm_option="none",
    ffn_adapter_init_option="lora",
    ffn_adapter_scalar="0.1",
    ffn_num=64,
    d_model=d_model,
    vpt_on=False,
    vpt_num=1,
    nb_task=len(data_sequence[:5]),
    tasks=data_sequence[:5]
)

model = SegmenterAdapt(
    im_size, n_layers, d_model, d_encoder, 4 * d_model, n_heads, n_class,
    patch_size, selected_model, tuning_config=tuning_config,
    model_name=config["model_name"]
).to(device)

model.load_pretrained_weights(model_path=checkpoint_path)
model.eval()

data_config = config["dataset"]["flair1"]
binary = data_config["binary"]
directory_path = data_config["data_path"]
train_imgs, test_imgs = [], []
# Follow the selected device; a bare .cuda() crashed on CPU-only hosts.
loss_fn = torch.nn.CrossEntropyLoss().to(device)
accuracy_fn = Accuracy(task='multiclass', num_classes=n_class).to(device)

# Same split procedure (and RNG call order) as the evaluation code above.
for step, domain in enumerate(tuning_config.tasks[:1]):
    print(step, domain)
    img = glob.glob(os.path.join(directory_path, '{}/Z*_*/img/IMG_*.tif'.format(domain)))
    random.shuffle(img)
    train_imgs += img[:int(len(img) * args.train_split_coef)]
    test_imgs += img[int(len(img) * args.train_split_coef):]
    random.shuffle(train_imgs)
test_dataloader = create_test_dataloader(test_imgs[:1], args, data_config, binary=binary)

# At most one tile -> at most one batch; it is kept as the global `image`/`target`.
for i, batch in tqdm(enumerate(test_dataloader), total=len(test_dataloader)):
    # First three bands only, scaled by 1/255.
    # NOTE(review): test_function uses data_config['n_channels'] instead; confirm 3 is intended.
    image = (batch['image'][:, :3, :, :] / 255.).to(device)
    target = batch['mask'].to(device)


In [None]:
import matplotlib.pyplot as plt

# First image of the batch loaded above; matplotlib wants channels-last (H, W, C).
first_image = image[0].permute(1, 2, 0).cpu()
plt.imshow(first_image)
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Encoder self-attention maps for the first image of the batch loaded above:
# one row per transformer layer and one column per head. The grid size comes
# from the `n_layers` / `n_heads` config values the model was built with,
# instead of a hard-coded 12 x 12. squeeze=False keeps `axs` 2-D for any size.
fig, axs = plt.subplots(n_layers, n_heads, figsize=(20, 20), squeeze=False)

attention_images = []
vmin, vmax = float("inf"), float("-inf")
with torch.no_grad():  # visualisation only: don't build an autograd graph
    for layer_id in range(n_layers):
        attention_map_enc = model.get_attention_map_enc(image, layer_id)
        for head_id in range(n_heads):
            attention_image = attention_map_enc[0, head_id].cpu().detach().numpy()
            vmin = min(vmin, float(attention_image.min()))
            vmax = max(vmax, float(attention_image.max()))

            ax = axs[layer_id, head_id]
            attention_images.append(ax.imshow(attention_image, cmap='viridis'))
            ax.set_title(f'L. {layer_id + 1}, Head {head_id + 1}')

# Put every panel on one colour scale so the single colourbar is valid for all
# of them. With per-panel autoscaling it only described the last panel.
for im in attention_images:
    im.set_clim(vmin, vmax)

plt.tight_layout()
fig.colorbar(attention_images[-1], ax=axs, orientation='vertical', fraction=.1)
plt.show()