In [None]:
import argparse
import os 
import csv
import wandb
# os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import warnings
warnings.filterwarnings("ignore")

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from timm import optim
import torch.nn.functional as F
from torch.optim import lr_scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau

import torchvision
from torchvision import models
import timm
from torch.utils.data import DataLoader, WeightedRandomSampler, default_collate
import random
from tqdm import tqdm
import time

from Dataset_ML import *
from utils_ML import *

from models_ML3_v1 import *
from models_ML3_v2 import *
from models_ML3_v3 import *
from models_ML3_v4__ import *
from models_ML3_v4___ import *
from models_ML3_v4____ import *

from Transformer_USVN import Transformer_USVN
from BiLSTM_USVN import BiLSTM_USVN
from cnnlstm import CNNLSTM
from cnntransformer import CNNTransformer
from C3D_model import C3D
from R2Plus1D_model import R2Plus1DClassifier

import vidaug.augmentors as va

from sklearn.metrics import classification_report, confusion_matrix, accuracy_score
from sklearn.metrics import roc_auc_score
from sklearn import metrics
from sklearn.metrics import roc_curve

import math
from sklearn.preprocessing import OneHotEncoder

In [None]:
def set_all_seeds(SEED):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, the current CUDA device and all CUDA devices), and configures cuDNN
    with ``deterministic=True`` and ``benchmark=False``.

    Args:
        SEED: Integer seed passed unchanged to every generator.
    """
    # Python and NumPy generators.
    random.seed(SEED)
    np.random.seed(SEED)

    # PyTorch generators (CPU, current CUDA device, all CUDA devices).
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(SEED)

    # cuDNN flags requested for reproducibility.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

In [None]:
def num_parameters(module):
    """Return the number of scalar parameters in ``module`` with ``requires_grad=True``."""
    total = 0
    for param in module.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def collate_video(batch_list):
    """Collate function for a DataLoader whose samples are ``(clip, label, path)``.

    The clips (element 0 of every sample) are concatenated along their first
    dimension into one tensor, so clips of different lengths can share a batch.
    Labels (element 1) and paths (element 2) are returned as plain Python
    lists in batch order, without being stacked.

    Args:
        batch_list: Sequence of samples; each sample is indexable with at
            least three elements.

    Returns:
        Tuple ``(vids, labels, paths)``.
    """
    clips, labels, paths = [], [], []
    for sample in batch_list:
        clips.append(sample[0])
        labels.append(sample[1])
        paths.append(sample[2])

    vids = torch.cat(clips)
    return vids, labels, paths

In [None]:
class attention_video_dataset(Dataset):
    """Video-clip dataset with four artifact labels per clip.

    Each row of the CSV at ``csv_path`` describes one clip. The clip location
    is read from the ``{img_size}_clip_path`` column and the labels from the
    four columns listed in ``PRED_LABEL``.
    """

    def __init__(self, csv_path, transforms, img_size, is_train=True):
        """Read the CSV and cache the clip paths.

        Args:
            csv_path: Path of the CSV file describing the clips.
            transforms: Callable invoked as ``transforms(image=frame)`` whose
                result is indexed with ``['image']``.
            img_size: Used to build the clip-path column name
                ``f'{img_size}_clip_path'``.
            is_train: When true, random video augmentation is applied in
                ``__getitem__``.
        """
        self.csv_path = csv_path
        self.video_df = pd.read_csv(self.csv_path)

        self.transforms = transforms
        self.is_train = is_train

        path_column = self.video_df[f'{img_size}_clip_path']
        self.video_path_list = [str(entry) for entry in path_column]

        # The four artifact label columns, in output order.
        self.PRED_LABEL = [
            'A-line_lbl',
            'total-B-line_lbl',
            'Consolidation_lbl',
            'Pleural effusion_lbl'
            ]

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.video_df)

    def __getitem__(self, idx):
        """Load, optionally augment, and transform the clip at row ``idx``.

        Returns:
            Tuple ``(clip, label, clip_path)``: the per-frame tensors stacked
            along a new leading dimension, a float tensor with one entry per
            ``PRED_LABEL`` column, and the clip path string.
        """
        clip_path = self.video_path_list[idx]
        sampled_clip = load_video(clip_path)

        if self.is_train:
            # Each augmentation is applied with probability 0.5.
            def sometimes(aug):
                return va.Sometimes(0.5, aug)

            augmenter = va.Sequential([
                sometimes(va.HorizontalFlip()),
                # rotation angle drawn from [-10, 10] degrees (per the original author's note)
                sometimes(va.RandomRotate(degrees=10)),
            ])
            sampled_clip = np.array(augmenter(sampled_clip))

        # Per-frame transform; permute(2, 0, 1) moves the last axis first
        # (assumes the transform returns an H x W x C array — TODO confirm).
        frame_tensors = [
            torch.from_numpy(self.transforms(image=frame)['image']).permute(2, 0, 1)
            for frame in sampled_clip
        ]
        torch_auged_clip = torch.stack(frame_tensors)

        # Labels start at zero; a positive CSV value is copied through as-is.
        label = torch.zeros(len(self.PRED_LABEL), dtype=torch.float32)
        for position, column in enumerate(self.PRED_LABEL):
            value = self.video_df[column.strip()].iloc[idx].astype('float')
            if value > 0:
                label[position] = value

        return torch_auged_clip, label, clip_path

In [None]:
# Seed all RNGs before anything is constructed.
set_all_seeds(1234)

# Set up model
# Supported `model_version` values (see the if/elif chain below) and the
# pooling method each was configured with in this notebook:
#   'v1'                                                   -> 'attn_multilabel'
#   'v2', 'v3', 'v4', 'v4_', 'v4__', 'v4___', 'v4____',
#   'v5', 'v6'                                             -> 'attn_multilabel_conv'
# Other pooling_method values left as commented-out options: 'attn', 'max', 'avg'.
model_version = 'v4____'
pooling_method = 'attn_multilabel_conv'

num_heads = 8
k_size = 13

batch_size=1

# Encoder backbone; num_classes=0 is passed so timm builds it without a classifier head.
encoder = timm.create_model('densenet161', pretrained=False, num_classes=0)

# One entry per clip in a batch; passed as the second argument of every model call.
# NOTE(review): assumes every clip has exactly 30 frames — confirm against the data.
num_frames = [30]*batch_size

# The chain is kept (rather than a lookup table) so that only the selected
# class name has to be resolvable from the star imports.
if model_version == 'v1':
    model = MedVidNet_multi_attn(encoder, num_heads, pooling_method = pooling_method)
elif model_version == 'v2':
    model = MedVidNet_multi_attn_conv(encoder, num_heads, pooling_method = pooling_method)
elif model_version == 'v3':
    model = MedVidNet_multi_attn_conv2(encoder, num_heads, pooling_method = pooling_method)
elif model_version == 'v4':
    model = MedVidNet_multi_attn_conv3(encoder, num_heads, pooling_method = pooling_method, kernel_width= k_size)
elif model_version == 'v4_':
    model = MedVidNet_multi_attn_conv3_(encoder, num_heads, pooling_method = pooling_method, kernel_width= k_size)
elif model_version == 'v4__':
    model = MedVidNet_multi_attn_conv3__(encoder, num_heads, pooling_method = pooling_method, kernel_width= k_size)
elif model_version == 'v4___':
    model = MedVidNet_multi_attn_conv3___(encoder, num_heads, pooling_method = pooling_method, kernel_width= k_size)
elif model_version == 'v4____':
    model = MedVidNet_multi_attn_conv3____(encoder, num_heads, pooling_method = pooling_method, kernel_width= k_size)
elif model_version == 'v5':
    model = MedVidNet_multi_attn_conv4(encoder, num_heads, pooling_method = pooling_method, kernel_width= k_size)
elif model_version == 'v6':
    model = MedVidNet_multi_attn_conv5(encoder, num_heads, pooling_method = pooling_method)
else:
    # Fail loudly instead of leaving `model` undefined (or stale from an
    # earlier run of this cell) when the version string is mistyped.
    raise ValueError(f"Unknown model_version: {model_version!r}")

In [None]:
################################
# load weight
# The settings below are only used to rebuild the checkpoint file name and the
# CSV / output paths; they must match the run that produced the checkpoint.
fold_num = 3

chk_std = "loss"

lr = '1e-06'

version = 'version_1'
train_layer = "all"

# model_name = 'LUVM'
model_name = 'LUV_Net'

# model_name = 'USVN'
# model_name = 'C3D'
# model_name = 'R2Plus1D'
# model_name = 'Transformer_USVN'
# model_name = 'CNNLSTM'
# model_name = 'CNNTransformer'

# encoder_name = 'densenet161'
# encoder_name = 'mae_densenet161'
encoder_name = 'imgnet_init_densenet161'

model_test_rate = "0.2"

data_type = "before_all_data"

encoder_batch_size = 32
video_batch_size = 4

model_output_class = 5
img_size = 256

gpu_index = 0
device = torch.device(f"cuda:{gpu_index}" if torch.cuda.is_available() else "cpu")
# weight_path = f'/data2/hoon2/Results/video_model2/seed234_test{model_test_rate}_std_{chk_std}_{data_type}_{version}_{train_layer}_{model_output_class}_artifacts_duplicate_batch{video_batch_size}_256_30frame_{model_name}_{model_version}_{encoder_name}_{encoder_batch_size}_{pooling_method}_{num_heads}head_{k_size}ksize_fold{fold_num}_lr{lr}_checkpoint'
weight_path = f'/data2/hoon2/Results/video_model2/seed234_test{model_test_rate}_std_{chk_std}_{data_type}_{version}_{train_layer}_{model_output_class}_artifacts_duplicate_batch{video_batch_size}_256_30frame_{model_name}_{encoder_name}_{encoder_batch_size}_{pooling_method}_{num_heads}head_{k_size}ksize_fold{fold_num}_lr{lr}_checkpoint'

# NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
check_point = torch.load(weight_path, map_location=device)

# Checkpoints saved from a torch.nn.DataParallel-wrapped model carry a
# 'module.' prefix on every key. Strip it per key, and only when it really is
# a prefix: a substring test on the first key followed by an unconditional
# 7-character cut would corrupt keys that merely contain the word 'module'.
dp_prefix = 'module.'
state_dict = {
    (key[len(dp_prefix):] if key.startswith(dp_prefix) else key): value
    for key, value in check_point['model'].items()
}
model.load_state_dict(state_dict)

model = model.to(device)
# Decision threshold(s) stored alongside the weights; used to binarise sigmoid outputs.
best_val_thres = check_point['best_valid_thres']

num_pars = num_parameters(model)
num_pars_encoder = num_parameters(encoder)
print(f"Number of trainable params: {num_pars} ({num_pars - num_pars_encoder} excluding encoder).")
print(f"Number of encoder params: {num_pars_encoder}")
print(f"Number of excluding encoder: {num_pars - num_pars_encoder}")

In [None]:
# Root directory of the CSV split files.
base_path = '/data2/hoon2/LUS_Dataset/csv_files/clip_multilabel_classification'

# CSV listing the test clips for the selected data type / version / fold.
test_csv_path = os.path.join(base_path, f'{data_type}/{version}/{model_output_class}_artifacts/test_{model_test_rate}/fold_{fold_num}/test.csv')

# dataset (is_train=False disables the random video augmentation)
test_dataset = attention_video_dataset(test_csv_path, transforms = apply_transforms(mode=None), img_size = img_size, is_train = False)

# dataloader (no shuffling so the clip order matches the CSV row order)
test_dataloader =  torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle = False, collate_fn=collate_video, drop_last=False)

# Report the number of clips and the number of batches. The second print
# previously repeated the dataset length instead of the loader length.
len_test_dataset = len(test_dataloader.dataset)
print("Test dataset size:", len_test_dataset)
print("Test data loader size:", len(test_dataloader))

In [None]:
# Load the test-split CSV for inspection; index_col=False asks pandas not to
# use any column as the index. The bare `df` displays the frame.
df = pd.read_csv(test_csv_path, index_col = False)
df

In [None]:
# Display the threshold value(s) read from the checkpoint's 'best_valid_thres' entry.
best_val_thres

In [None]:
# Fetch sample 10 once: every dataset[...] access reloads and re-transforms
# the whole clip, so indexing three times tripled the work.
sample_clip, sample_label, sample_path = test_dataloader.dataset[10]
print(sample_clip.shape)
print(sample_label)
print(sample_path)

In [None]:
def get_video_label(test_label):
    """Return the names of the labels whose value is exactly 1, comma-separated.

    Args:
        test_label: Tensor or sequence of per-label values, ordered as
            A-line, B-lines, Consolidation, Pleural effusion.

    Returns:
        A string such as ``'A-line, Consolidation'``; empty when no entry equals 1.
    """
    # Name of the label at each index.
    label_names = ['A-line', 'B-lines', 'Consolidation', 'Pleural effusion']

    # Tensors are converted to a plain list first.
    if isinstance(test_label, torch.Tensor):
        test_label = test_label.tolist()

    present = []
    for position, flag in enumerate(test_label):
        if flag == 1.0:
            present.append(label_names[position])

    return ', '.join(present)

In [None]:
# Fetch sample 10 once: every dataset[...] access reloads and re-transforms
# the whole clip, so indexing three times tripled the work.
sample_clip, sample_label, sample_path = test_dataloader.dataset[10]
print(sample_clip.shape)
print(sample_label)
print(sample_path)

In [None]:
# For every test clip: run the model, pick the top-k frames per label by
# head-averaged attention, and save a label x top-k grid of those frames.

# Output directory for the figures; created once up front.
save_dir = f'/home/work/LUS/Results/attention_output/internal_test/{model_name}_{model_version}_{video_batch_size}batch_{encoder_name}_{num_heads}head_{k_size}ksize_fold{fold_num}'
os.makedirs(save_dir, exist_ok=True)

# Loop-invariant settings, hoisted out of the per-clip loop.
top_k = 5
num_labels = 4
label_names = ['A-line', 'B-lines', 'Consolidation', 'Pleural effusion']
sigmoid = nn.Sigmoid()

model.eval() 
with torch.no_grad():
    for data in tqdm(test_dataloader, desc="Testing", unit="batch"):    
        test_img, test_label, test_path = data

        # Only element 0 of the batch is used below (the loader is built with batch_size=1).
        file_name = test_path[0].split('/')[-1].split('.')[0]
        test_img = test_img.float().to(device)

        test_output, attentions = model(test_img, num_frames)

        # Sigmoid scores of the first clip, binarised with the validation threshold(s).
        test_output_ = sigmoid(test_output)
        test_preds_np = test_output_[0].data.cpu().numpy()
        test_preds_np = np.where(test_preds_np >= best_val_thres, 1, 0)

        fig, axs = plt.subplots(num_labels, top_k, figsize=(15, 12))

        for label_idx in range(num_labels):
            # Attention for this label: sum over dim 1 of the [:, 0] slice,
            # divided by the head count (the original author notes the result
            # has shape [30], i.e. one score per frame).
            frame_attention_sum = attentions[label_idx][:, 0].sum(dim=1)
            frame_attention_mean = frame_attention_sum / num_heads

            # Top-k frame indices, re-sorted into ascending (temporal) order.
            top_frame_indices = torch.topk(frame_attention_mean, k=top_k).indices
            top_frame_indices = top_frame_indices[torch.argsort(top_frame_indices)]

            # Show channel 0 of each selected frame.
            for i, idx in enumerate(top_frame_indices):
                ax = axs[label_idx, i]  
                ax.imshow(test_img[idx, 0].cpu(), cmap='gray')  
                ax.set_title(f"Frame {idx.item()} ({frame_attention_mean[idx]:.2f})", fontsize=11)
                ax.axis('off')  

            # Row caption: label name to the left of the first column.
            axs[label_idx, 0].text(
                -0.3, 0.5, label_names[label_idx], transform=axs[label_idx, 0].transAxes,
                fontsize=14, fontweight='bold', ha='center', va='center', rotation=90, color='black'
            )

        # Ground-truth labels, as ints for the file name and as names for the title.
        int_test_label_lst = test_label[0].int().tolist()
        result = get_video_label(test_label[0])

        # Figure-level title.
        fig.suptitle(f"{file_name}'s Top-5 Frames for Each Label(video labels : {result})", fontsize=16)

        plt.subplots_adjust(wspace=0.3, hspace=0.5)  
        plt.tight_layout()

        # Save the plot. `.tolist()` yields plain Python ints so the file name
        # stays "[1, 0, 0, 0]" regardless of how NumPy prints its scalars.
        save_path = os.path.join(save_dir, f'{file_name}_{int_test_label_lst}_{test_preds_np.tolist()}.png')
        fig.savefig(save_path, bbox_inches='tight', dpi=300)

        # Release the figure; otherwise one large figure per test clip stays
        # alive for the whole loop. (Closed figures are not rendered inline;
        # the saved PNG is the output of this cell.)
        plt.close(fig)

## test

In [None]:
# Run the model on the first test sample taken straight from the dataset
# (bypassing the DataLoader and its collate function).
model.eval()
with torch.no_grad():
    test_img, test_label, test_path = test_dataloader.dataset[0]

    # Only the clip is moved to the device; the label is left where it is.
    test_img = test_img.float().to(device)

    test_output, attentions = model(test_img, num_frames)

In [None]:
# Shape of the first element of the attention output returned by the model.
attentions[0].shape

In [None]:
# Raw model output (before any sigmoid) next to the sample's ground-truth label.
print(test_output.shape)
print(test_output)
print(test_label)

In [None]:
# Turn the raw output into sigmoid scores, then binarise the first row with
# the validation threshold(s).
sigmoid = nn.Sigmoid()
test_output2 = sigmoid(test_output)
print(test_output2)

test_preds_np = np.where(test_output2[0].data.cpu().numpy() >= best_val_thres, 1, 0)
print(test_preds_np)

In [None]:
# One panel per label, one curve per attention head, frame index on the x axis.
label_names = ['A-line', 'B-lines', 'Consolidation', 'Pleural effusion']

# Figure layout: 4 rows x 1 column, 8 wide by 8 high.
fig, axes = plt.subplots(4, 1, figsize=(8, 8))

# Per-label attention as NumPy arrays (the [:, 0] slice of each entry).
np_attentions = [att[:, 0].cpu().numpy() for att in attentions]

for i, ax in enumerate(axes):
    for jx in range(num_heads):
        # Each curve is labelled: ax.legend() below had no labelled artists
        # before and therefore drew an empty legend.
        ax.plot(np_attentions[i][:, jx], label=f'Head {jx+1}')
    
    ax.set_title(f"Attention for {label_names[i]}")
    ax.grid()
    ax.legend(fontsize=8, loc='upper right')  # small font to keep the legend compact
    ax.set_xlabel("Frame Index")
    ax.set_ylabel("Attention Score")

# Tidy the overall layout.
plt.tight_layout()
plt.show()

In [None]:
# Comma-separated names of the labels set to 1 for the current sample.
result = get_video_label(test_label)
result

In [None]:
# Inspect the attention output for label index 0 and its sum over dim 1.
print(attentions[0].shape)
print(attentions[0][:, 0].shape)
frame_attention_sum = attentions[0][:, 0].sum(dim=1)
frame_attention_sum.shape  # bare expression that is not the cell's last line
frame_attention_sum.sum()

In [None]:
# For each label, rank the frames by attention averaged over the heads and
# show the top-k frames of the current sample.
top_k = 5
num_labels = 4
label_names = ['A-line', 'B-lines', 'Consolidation', 'Pleural effusion']

fig, axs = plt.subplots(4, 5, figsize=(15, 12))

for label_idx in range(num_labels):
    # Attention for the current label (the original author notes shape [30]).
    frame_attention_sum = attentions[label_idx][:, 0].sum(dim=1)
    frame_attention_mean = frame_attention_sum / num_heads

    # Indices of the top-k frames, sorted into ascending order.
    top_frame_indices = torch.sort(torch.topk(frame_attention_mean, k=top_k).indices).values

    # Draw channel 0 of every selected frame in this label's row.
    for i, idx in enumerate(top_frame_indices):
        ax = axs[label_idx, i]
        ax.imshow(test_img[idx, 0].cpu(), cmap='gray')
        ax.set_title(f"Frame {idx.item()} ({frame_attention_mean[idx]:.2f})", fontsize=11)
        ax.axis('off')

    # Label name as a vertical caption left of the row's first panel.
    row_head = axs[label_idx, 0]
    row_head.text(
        -0.3, 0.5, label_names[label_idx], transform=row_head.transAxes,
        fontsize=14, fontweight='bold', ha='center', va='center', rotation=90, color='black'
    )

# Figure-level title.
fig.suptitle(f"Top-5 Frames for Each Label(video labels : {result})", fontsize=16)

plt.subplots_adjust(wspace=0.3, hspace=0.5)
plt.tight_layout()
plt.show()

# attention plot

In [None]:
# Plot the head-averaged attention curve of every label for the one clip
# whose file name contains `plot_file_name`.
plot_file_name = '30625107_00021_25_54'

model.eval()
with torch.no_grad():
    for data in tqdm(test_dataloader, desc="Testing", unit="batch"):
        test_img, test_label, test_path = data

        # Skip every clip except the requested one.
        file_name = test_path[0].split('/')[-1].split('.')[0]
        if plot_file_name not in file_name:
            continue

        print(f"Processing file: {file_name}")

        test_img = test_img.float().to(device)

        # Forward pass: class scores plus the attention output.
        test_output, attentions = model(test_img, num_frames)

        sigmoid = nn.Sigmoid()
        test_output_ = sigmoid(test_output)
        test_preds_np = np.where(test_output_[0].data.cpu().numpy() >= best_val_thres, 1, 0)

        # Plot settings.
        top_k = 3
        num_labels = 4
        label_names = ['A-line', 'B-lines', 'Consolidation', 'Pleural effusion']

        # 2 x 2 grid, flattened so a panel can be addressed by label index.
        fig, axs = plt.subplots(2, 2, figsize=(16, 10))
        axs = axs.flatten()

        for label_idx in range(num_labels):
            # Attention summed over dim 1, divided by the number of heads.
            frame_attention_sum = attentions[label_idx][:, 0].sum(dim=1).cpu().numpy()
            frame_attention_mean = frame_attention_sum / num_heads

            # The k largest scores (argsort is ascending, so take the tail).
            top_indices = np.argsort(frame_attention_mean)[-top_k:]

            ax = axs[label_idx]
            ax.plot(frame_attention_mean, label='Attention Score', color='blue', linewidth=2)

            # Mark the top-k frames; only the first marker gets a legend entry.
            for idx in top_indices:
                ax.scatter(
                    idx, frame_attention_mean[idx], color='red', s=100, edgecolors='black', linewidth=2,
                    label='Top-k' if idx == top_indices[0] else ""
                )

            ax.set_title(f"Attention for {label_names[label_idx]}", fontsize=14, fontweight='bold')
            ax.set_xlabel("Frame Index", fontsize=12)
            ax.set_ylabel("Attention Score", fontsize=12)
            ax.grid(True)

            # Legend on the first panel only.
            if label_idx == 0:
                ax.legend(loc='upper right', fontsize=10)

        # Leave headroom for the figure-level title.
        plt.tight_layout()
        plt.subplots_adjust(top=0.9)
        fig.suptitle(f"Attention Scores for {plot_file_name}", fontsize=16, fontweight='bold')

        plt.show()
        break  # the requested clip has been handled

In [None]:
# Value of frame_attention_sum as left by the most recently executed cell that assigned it.
frame_attention_sum

In [None]:
# Value of frame_attention_mean as left by the most recently executed cell that assigned it.
frame_attention_mean

In [None]:
# Total of the current frame_attention_mean values.
frame_attention_mean.sum()

In [None]:
# Ground-truth label(s) of the sample or batch most recently assigned to test_label.
test_label

In [None]:
# For the clip whose file name contains `plot_file_name`: plot the
# head-averaged attention curve of every label, then show each label's
# top-k frames titled with that label's own attention score.

# single pattern example
# plot_file_name = '59681665_00010_49_78' # A-line
# plot_file_name = "30625107_00012_1_30" # B-line

# multi patterns example
plot_file_name = "30625107_00021_25_54" # B-line / consolidation

model.eval()
with torch.no_grad():
    for data in tqdm(test_dataloader, desc="Testing", unit="batch"):
        test_img, test_label, test_path = data

        # Check if the file name matches `plot_file_name`
        file_name = test_path[0].split('/')[-1].split('.')[0]
        if plot_file_name in file_name:
            print(f"Processing file: {file_name}")

            test_img = test_img.float().to(device)

            # Get model outputs and attentions
            test_output, attentions = model(test_img, num_frames)

            sigmoid = nn.Sigmoid()
            test_output_ = sigmoid(test_output)
            test_preds_np = test_output_[0].data.cpu().numpy()
            test_preds_np = np.where(test_preds_np >= best_val_thres, 1, 0)

            # Parameters
            top_k = 3  # Number of top frames
            num_labels = 4
            label_names = ['A-line', 'B-lines', 'Consolidation', 'Pleural effusion']

            # Create figure for attention maps (2 rows x 2 columns)
            fig, axs = plt.subplots(2, 2, figsize=(16, 10))
            axs = axs.flatten()  # Flatten for easier indexing

            # Per-label results kept for the frame grid below. The scores are
            # stored per label: reusing `frame_attention_mean` after this loop
            # would show the last label's scores on every row.
            top_k_frames = {}
            attention_means = {}

            for label_idx in range(num_labels):
                # Extract attention scores
                frame_attention_sum = attentions[label_idx][:, 0].sum(dim=1).cpu().numpy()
                frame_attention_mean = frame_attention_sum / num_heads
                attention_means[label_idx] = frame_attention_mean

                # Find top-k indices (argsort is ascending, so take the tail)
                top_indices = np.argsort(frame_attention_mean)[-top_k:]
                top_k_frames[label_idx] = top_indices  # Save for later use

                # Plot attention scores
                ax = axs[label_idx]
                ax.plot(frame_attention_mean, label='Attention Score', color='blue', linewidth=2)

                # Highlight top-k indices; only the first marker gets a legend entry
                for idx in top_indices:
                    ax.scatter(
                        idx, frame_attention_mean[idx], color='red', s=100, edgecolors='black', linewidth=2,
                        label='Top-k' if idx == top_indices[0] else ""
                    )

                # Customize the plot
                ax.set_title(f"Attention Scores for {label_names[label_idx]}", fontsize=14, fontweight='bold')
                ax.set_xlabel("Frame Index", fontsize=12)
                ax.set_ylabel("Attention Score", fontsize=12)
                ax.grid(True)
                
                # Legend on every panel
                ax.legend(loc='lower right', fontsize=10)

            # Adjust layout and add a main title
            plt.tight_layout()
            plt.subplots_adjust(top=0.9)
            fig.suptitle(f"Attention Scores for {plot_file_name}", fontsize=16, fontweight='bold')

            # Show attention map plots
            plt.show()

            # Plot top-k frames for each label (num_labels rows x top_k columns)
            fig, axs = plt.subplots(num_labels, top_k, figsize=(16, 4 * num_labels))

            for label_idx, indices in top_k_frames.items():
                label_scores = attention_means[label_idx]
                for i, frame_idx in enumerate(indices):
                    ax = axs[label_idx, i] if num_labels > 1 else axs[i]  # Handle single-row case
                    frame = test_img[frame_idx, 0].cpu().numpy()  # Channel 0 of the frame as a numpy array

                    # Plot the frame
                    ax.imshow(frame, cmap='gray')

                    # Title with the frame index and this label's attention score
                    ax.set_title(f"Label: {label_names[label_idx]}\nFrame {frame_idx} (Score: {label_scores[frame_idx]:.3f})",
                                 fontsize=10)
                    ax.axis('off')

            # Add main title for the frames plot
            plt.tight_layout()
            plt.subplots_adjust(top=0.9)
            fig.suptitle(f"Top-{top_k} Frames for Each Label ({plot_file_name})", fontsize=16, fontweight='bold')

            # Show top-k frame plots
            plt.show()

            break  # Exit after processing the matching file

In [None]:
# Value of frame_attention_mean as left by the most recently executed cell that assigned it.
frame_attention_mean

In [None]:
# Total of the current frame_attention_mean values.
frame_attention_mean.sum()

In [None]:
# Overlay the min-max scaled attention of one label on hand-entered
# ground-truth frame ranges. Ranges are 1-indexed and inclusive, and the plot
# uses 1-indexed frame numbers on the x axis.

# single pattern example
# gt_ranges = [(1, 14)] # A-line
# gt_ranges = [(18, 19)] # B-line

# multi patterns example
# gt_ranges = [(6, 15)] # B-line
gt_ranges = [(1, 10), (16, 29)] # consolidation

label_idx = 2

# Parameters
top_k = 3  # Number of top frames
num_labels = 4
label_names = ['A-line', 'B-lines', 'Consolidation', 'Pleural effusion']

# Attention for the chosen label, rescaled to the 0-1 range.
frame_attention_sum = attentions[label_idx][:, 0].sum(dim=1).cpu().numpy()
attention_low = frame_attention_sum.min()
attention_high = frame_attention_sum.max()
frame_attention_scaled = (frame_attention_sum - attention_low) / (attention_high - attention_low)

# Ground-truth indicator: 1 inside any range, 0 elsewhere.
gt_line = np.zeros(len(frame_attention_scaled))
for start, end in gt_ranges:
    gt_line[start-1:end] = 1

# The k frames with the largest scaled attention.
top_indices = np.argsort(frame_attention_scaled)[-top_k:]

fig, ax = plt.subplots(figsize=(15, 6))

ax.plot(
    range(1, len(frame_attention_scaled) + 1), frame_attention_scaled,
    label='Scaled Attention Score', color='orange', linewidth=2,
)
ax.plot(
    range(1, len(gt_line) + 1), gt_line,
    label='GT (Ground Truth)', color='green', linestyle='--', linewidth=2,
)

# Mark the top-k frames; only the first marker gets a legend entry.
for idx in top_indices:
    marker_label = 'Top-k' if idx == top_indices[0] else ""
    ax.scatter(idx + 1, frame_attention_scaled[idx], color='red', s=100,
               edgecolors='black', linewidth=2,
               label=marker_label)

ax.set_title(f"Scaled Attention Scores for {label_names[label_idx]}", fontsize=25, fontweight='bold')
ax.set_xlabel("Frame Index", fontsize=25)
ax.set_ylabel("Scaled Attention Score", fontsize=25)
ax.grid(True)
ax.legend(loc='upper right', fontsize=20)

# Layout adjustments.
plt.tight_layout()
plt.subplots_adjust(top=0.85)

plt.show()

In [None]:
import torch
# gaussian_filter1d was used below without a visible import.
from scipy.ndimage import gaussian_filter1d

# Original binary per-frame label (a single positive frame at index 3).
frame_labels = torch.tensor([0, 0, 0, 1, 0, 0, 0], dtype=torch.float32)

# Spread the positive frame over its neighbours with a Gaussian kernel.
smoothed = gaussian_filter1d(frame_labels.numpy(), sigma=1.0)
smoothed = torch.tensor(smoothed)

# normalize to form distribution
label_dist = smoothed / smoothed.sum()