### 싱글모델로 분류

In [1]:

import torch 
import argparse
import yaml
import time
import multiprocessing as mp
import torch.nn.functional as F
from tabulate import tabulate
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
#from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler, RandomSampler
from torch import distributed as dist
from nmc.models import *
from nmc.datasets import * 
from nmc.augmentations import get_train_augmentation, get_val_augmentation
from nmc.losses import get_loss
from nmc.schedulers import get_scheduler
from nmc.optimizers import get_optimizer
from nmc.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp
from tools.val import evaluate_epi
from nmc.utils.episodic_utils import * 
from scipy.cluster import hierarchy
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from torchvision import models
import torch.nn as nn
from torch.optim import lr_scheduler
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mutual_info_score
from scipy.cluster import hierarchy
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, hamming_loss
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.data import Subset
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import cv2
import random
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

In [2]:
# Load the APTOS experiment config, seed the RNGs, configure cuDNN and make
# sure the output directory exists.
with open('../configs/APTOS.yaml', encoding='utf-8') as f:
    cfg = yaml.load(f, Loader=yaml.SafeLoader)
print(cfg)
fix_seeds(3407)
setup_cudnn()
gpu = setup_ddp()
save_dir = Path(cfg['SAVE_DIR'])
# parents=True so a nested SAVE_DIR (e.g. "runs/aptos") can be created as well.
save_dir.mkdir(parents=True, exist_ok=True)
# NOTE(review): DDP is torn down at the end of this same cell, right after
# setup_ddp(); confirm this setup/cleanup pair is still needed here.
cleanup_ddp()

{'DEVICE': 'cuda:0', 'SAVE_DIR': 'output', 'MODEL': {'NAME': 'EfficientNetV2MModel', 'BACKBONE': 'EfficientNetV2', 'PRETRAINED': '/workspace/jhmoon/nmc_2024/checkpoints/pretrained/tf_efficientnetv2_m_weights.pth', 'UNFREEZE': 'full', 'VERSION': '384_32'}, 'DATASET': {'NAME': 'APTOSDataset', 'ROOT': '/data/public_data/aptos', 'TRAIN_RATIO': 0.7, 'VALID_RATIO': 0.15, 'TEST_RATIO': 0.15}, 'TRAIN': {'IMAGE_SIZE': [384, 384], 'BATCH_SIZE': 32, 'EPOCHS': 100, 'EVAL_INTERVAL': 25, 'AMP': False, 'DDP': False}, 'LOSS': {'NAME': 'CrossEntropy', 'CLS_WEIGHTS': False}, 'OPTIMIZER': {'NAME': 'adamw', 'LR': 0.001, 'WEIGHT_DECAY': 0.01}, 'SCHEDULER': {'NAME': 'warmuppolylr', 'POWER': 0.9, 'WARMUP': 10, 'WARMUP_RATIO': 0.1}, 'EVAL': {'MODEL_PATH': 'checkpoints/pretrained/FGMaxxVit/FGMaxxVit.FGMaxxVit.APTOS.pth', 'IMAGE_SIZE': [384, 384]}, 'TEST': {'MODEL_PATH': 'checkpoints/pretrained/FGMaxxVit/FGMaxxVit.FGMaxxVit.APTOS.pth', 'FILE': 'assests/ade', 'IMAGE_SIZE': [384, 384], 'OVERLAY': True}}


In [3]:
# Early Stopping
class EarlyStopping:
    """Track a validation score and flag when it stops improving.

    Higher scores are better. The first call only records a baseline. After
    that, a call where ``val_score < best_score + min_delta`` counts as
    "no improvement" and increments ``counter``. Once ``counter`` reaches
    ``patience``, ``early_stop`` becomes True and stays True. Any other call
    records a new ``best_score`` and resets ``counter`` to 0.
    """

    def __init__(self, patience=7, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0          # consecutive calls without improvement
        self.best_score = None    # None until the first score arrives
        self.early_stop = False   # latched True once counter >= patience

    def __call__(self, val_score):
        # First observation: just remember it.
        if self.best_score is None:
            self.best_score = val_score
            return

        # Not enough of a gain over the best score seen so far.
        if val_score < self.best_score + self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
            return

        # Improvement: new best, restart the patience window.
        self.best_score = val_score
        self.counter = 0

In [4]:
def get_train_augmentation(size):
    return transforms.Compose([
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.Lambda(lambda x: x.float() if x.dtype == torch.uint8 else x),
        transforms.Lambda(lambda x: x / 255.0 if x.max() > 1.0 else x),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

def get_val_test_transform(size):
    return transforms.Compose([
        transforms.Resize(size),
        transforms.Lambda(lambda x: x.float() if x.dtype == torch.uint8 else x),
        transforms.Lambda(lambda x: x / 255.0 if x.max() > 1.0 else x),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])


In [5]:
class MultiTargetBalancedBatchSampler(Sampler):
    """Batch sampler that draws each batch evenly from several label groups.

    There is one group per class in ``target_classes`` plus one "other" group
    for samples belonging to none of them. Each batch is filled by drawing
    indices *with replacement* (``torch.randint``) from every group and then
    shuffling. ``batch_size`` is split as evenly as possible across the
    ``len(target_classes) + 1`` groups. The first ``batch_size % n_groups``
    groups (targets in order, then "other") receive one extra sample, so every
    batch has exactly ``batch_size`` indices.

    Labels are read from ``dataset.labels``, else ``dataset.targets``, else by
    iterating the dataset and taking ``sample[1]``. 1-D labels are treated as
    class ids. 2-D labels are treated as multi-hot rows, where column ``c`` == 1
    means the sample has class ``c``.

    Args:
        dataset: dataset whose indices are sampled; ``len(dataset)`` sets the
            number of batches per epoch (``ceil(len(dataset) / batch_size)``).
        batch_size: number of indices per yielded batch.
        target_classes: class ids that each get their own group.

    Raises:
        ValueError: if ``batch_size`` is not positive, labels cannot be
            obtained, or a group that must contribute samples is empty.
    """

    def __init__(self, dataset, batch_size, target_classes):
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer")
        self.dataset = dataset
        self.batch_size = batch_size
        self.target_classes = target_classes

        self.labels = self._extract_labels(dataset)
        multi_hot = self.labels.dim() > 1

        # Indices of every sample belonging to each target class.
        self.target_indices = {}
        for target in target_classes:
            if multi_hot:
                self.target_indices[target] = torch.where(self.labels[:, target] == 1)[0]
            else:
                self.target_indices[target] = torch.where(self.labels == target)[0]

        # Indices of samples that belong to none of the target classes.
        if multi_hot:
            self.other_indices = torch.where(
                torch.sum(self.labels[:, target_classes], dim=1) == 0)[0]
        else:
            mask = torch.ones_like(self.labels, dtype=torch.bool)
            for target in target_classes:
                mask &= (self.labels != target)
            self.other_indices = torch.where(mask)[0]

        # Split batch_size over the groups. The remainder goes to the first
        # groups, so batches are exactly batch_size long instead of silently
        # dropping batch_size % n_groups samples.
        group_names = [f"class {t}" for t in target_classes] + ["other"]
        group_indices = [self.target_indices[t] for t in target_classes] + [self.other_indices]
        self.samples_per_group, remainder = divmod(batch_size, len(group_indices))
        self._groups = [
            (indices, self.samples_per_group + (1 if g < remainder else 0))
            for g, indices in enumerate(group_indices)
        ]
        # Fail fast here rather than with an opaque torch.randint error later.
        for name, (indices, size) in zip(group_names, self._groups):
            if size > 0 and len(indices) == 0:
                raise ValueError(
                    f"Cannot build balanced batches: group '{name}' has no samples")

        # ceil(len(dataset) / batch_size) batches per epoch.
        self.n_batches = (len(self.dataset) + batch_size - 1) // batch_size

    @staticmethod
    def _extract_labels(dataset):
        """Return the dataset's labels as a torch tensor (lookup order in class docstring)."""
        for attr in ('labels', 'targets'):
            if hasattr(dataset, attr):
                labels = getattr(dataset, attr)
                if isinstance(labels, np.ndarray):
                    return torch.from_numpy(labels)
                # Tensors pass through unchanged; plain Python lists are converted.
                return torch.as_tensor(labels)
        try:
            labels = [sample[1] for sample in dataset]
            if isinstance(labels[0], np.ndarray):
                return torch.from_numpy(np.array(labels))
            return torch.tensor(labels)
        except Exception as exc:
            raise ValueError("Cannot access labels from dataset") from exc

    def __iter__(self):
        for _ in range(self.n_batches):
            batch_indices = []
            # Draw each group's share of the batch, with replacement.
            for indices, size in self._groups:
                if size == 0:
                    continue
                picks = indices[torch.randint(len(indices), (size,))]
                batch_indices.extend(picks.tolist())
            # Interleave the groups so they are not laid out in contiguous blocks.
            random.shuffle(batch_indices)
            yield batch_indices

    def __len__(self):
        return self.n_batches

In [6]:
# Run setup: device, config sections and the evaluation DataLoader that the
# visualisation cells below iterate over.
start = time.time()
best_mf1 = 0.0
device = torch.device(cfg['DEVICE'])
# NOTE: hard-coded override of cfg['DEVICE'] (the printed config says 'cuda:0').
device = "cuda:1"
print("device : ", device)
num_workers = mp.cpu_count()
train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL']
dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL']
loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER']
epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR']

# Hard-coded; the config's TRAIN.IMAGE_SIZE / BATCH_SIZE are not used here.
image_size = [256, 256]
image_dir = Path(dataset_cfg['ROOT']) / 'train_images'
train_transform = get_train_augmentation(image_size)
val_test_transform = get_val_test_transform(image_size)
batch_size = 32

# Resolve the dataset class by name (e.g. 'APTOSDataset' -> 'APTOSDatasetTest')
# from the notebook namespace instead of eval()-ing a config string: the
# config can then only select an existing object, never execute code.
dataset_cls_name = dataset_cfg['NAME'] + 'Test'
if dataset_cls_name not in globals():
    raise NameError(f"name {dataset_cls_name!r} is not defined")
dataset = globals()[dataset_cls_name](
    dataset_cfg['ROOT'] + '/combined_images',
    transform=None,
    target_label=None,
)
dataset.transform = val_test_transform
# trainset, valset, testset = dataset.get_splits()
# valset.transform = val_test_transform
# testset.transform = val_test_transform

# trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, drop_last=True, pin_memory=True)
# valloader = DataLoader(valset, batch_size=1, num_workers=1, pin_memory=True)
# batch_size=1: compare_and_save_visualizations calls labels.item() per batch.
testloader = DataLoader(dataset, batch_size=1, num_workers=1, pin_memory=True)
    

device :  cuda:1
/data/public_data/aptos/combined_images


In [42]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
import cv2
import os
from captum.attr import IntegratedGradients, LayerLRP
from torch.nn.functional import softmax

def denormalize(tensor):
    """Undo ImageNet mean/std normalisation of a channel-first RGB tensor.

    Computes ``tensor * std + mean`` with per-channel statistics shaped
    ``(3, 1, 1)``, so the channel axis must be third from the end (e.g. a
    ``(3, H, W)`` image). mean/std are created on the input's device, so CUDA
    tensors work as well as CPU ones.
    """
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(3, 1, 1)
    return tensor * std + mean

def get_gradcam(model, image, target_label_idx, device):
    """Compute a Grad-CAM map for a single image.

    Args:
        model: network whose ``features[-1]`` module is used as the CAM layer.
        image: one image tensor without a batch dimension; it is passed as
            ``image.unsqueeze(0)``, so it must already be on the model's device.
        target_label_idx: accepted for a uniform explainer signature but not
            used. The CAM always targets output index 0, which for a
            multi-output model is only its first output.
        device: unused; kept for the same reason.

    Returns:
        The Grad-CAM map for the image (``grayscale_cam[0]``).
    """
    target_layer = model.features[-1]
    # The context manager releases GradCAM's hooks on `model` as soon as the
    # map is computed, instead of relying on garbage collection.
    with GradCAM(model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(input_tensor=image.unsqueeze(0),
                            targets=[ClassifierOutputTarget(0)])
    return grayscale_cam[0]

def get_integrated_gradients(model, image, target_label_idx, device):
    """Compute an Integrated Gradients overlay for a single image.

    Attributions for output index 0 are computed with captum's
    ``IntegratedGradients`` (all-zeros baseline, 50 steps). They are then
    rendered by ``visualize`` (defined in a later cell) as positive
    attributions in red over a grayscale copy of the denormalised image.

    Args:
        model: classifier to explain; output 0 is always targeted.
        image: one normalised, channel-first image tensor without a batch dim.
        target_label_idx: unused (kept for a uniform explainer signature).
        device: device the input and the baseline are moved to.

    Returns:
        The ``visualize(...)`` result divided by 255, an (H, W, 3) array that
        is expected to lie in [0, 1]. On any error the error is printed and
        zeros of shape (H, W, 3) are returned.
    """
    ig = IntegratedGradients(model)

    image = image.clone().detach().requires_grad_(True).to(device)
    image = image.unsqueeze(0)
    baseline = torch.zeros_like(image)

    try:
        attributions = ig.attribute(
            image,
            baseline,
            target=0,
            n_steps=50
        )

        attribution_map = attributions.squeeze().permute(1, 2, 0).cpu().detach().numpy()
        denorm_image = denormalize(image.squeeze().cpu().detach()).permute(1, 2, 0).numpy()

        # Denormalised pixels can fall outside [0, 1]; clip before the uint8
        # cast so out-of-range values do not wrap around (e.g. -0.01 -> 253).
        orig_uint8 = np.uint8(np.clip(denorm_image, 0, 1) * 255)
        # Grayscale copy (kept 3-channel) as the background for the overlay.
        gray_image = cv2.cvtColor(orig_uint8, cv2.COLOR_RGB2GRAY)
        gray_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2RGB)

        visualization = visualize(
            attributions=attribution_map,
            image=gray_image,  # uint8 grayscale background
            positive_channel=[255, 0, 0],
            negative_channel=[0, 0, 0],
            polarity='positive',
            clip_above_percentile=99,
            clip_below_percentile=70,
            overlay=True,
            mask_mode=False  # overlay mode
        )

        # Rescale the 0-255 overlay to 0-1.
        visualization = visualization / 255.0

        return visualization

    except Exception as e:
        print(f"Error in Integrated Gradients: {e}")
        return np.zeros((image.shape[2], image.shape[3], 3))

def get_lrp(model, image, target_label_idx, device):
    """Compute a Layer-LRP relevance map for a single image.

    Relevance is taken at ``model.features[-1]`` for output index 0. Absolute
    values are summed over dim 1 and min-max normalised to [0, 1]; a constant
    map becomes all zeros.

    Args:
        model: network whose ``features[-1]`` is the target layer.
        image: one image tensor without a batch dimension.
        target_label_idx: unused (kept for a uniform explainer signature).
        device: device the input is moved to before attribution.

    Returns:
        A 2-D numpy array. NOTE(review): on success its size presumably
        follows the target layer's spatial resolution, while the error
        fallback uses the input's (H, W); the two may differ.
    """
    target_layer = model.features[-1]
    lrp = LayerLRP(model, target_layer)

    # Move to `device` before enabling grad so the tensor stays a leaf on the
    # model's device (the `device` argument was previously ignored).
    image = image.clone().detach().to(device).requires_grad_(True)
    image = image.unsqueeze(0)  # add batch dimension

    try:
        attributions = lrp.attribute(
            image,  # batch dimension already added
            target=0
        )

        # Collapse channels into a single 2-D magnitude map.
        attribution_map = torch.sum(torch.abs(attributions), dim=1).squeeze().cpu().detach().numpy()

        # Min-max normalisation (guarding against a constant map).
        lo, hi = attribution_map.min(), attribution_map.max()
        if hi != lo:
            attribution_map = (attribution_map - lo) / (hi - lo)
        else:
            attribution_map = np.zeros_like(attribution_map)

        return attribution_map
    except Exception as e:
        print(f"Error in LRP: {e}")
        # Best effort: return a zero map on failure.
        return np.zeros((image.shape[2], image.shape[3]))
    
def _load_label_image(label_img_path, width, height, fallback):
    """Load ``label_img_path`` as RGB resized to (width, height), else return a copy of ``fallback``."""
    if not label_img_path:
        return fallback.copy()
    try:
        label_img = cv2.imread(label_img_path)
        if label_img is None:  # cv2.imread signals unreadable files with None
            return fallback.copy()
        label_img = cv2.cvtColor(label_img, cv2.COLOR_BGR2RGB)
        return cv2.resize(label_img, (width, height))
    except Exception as e:
        print(f"Error loading label image: {e}")
        return fallback.copy()


def _overlay_gradcam(cam_output, original_image, threshold=0.7):
    """Blend a thresholded Grad-CAM map (red channel) over a grayscale copy of an RGB uint8 image."""
    # np.where returns a new array, so the caller's CAM map is left untouched
    # (the previous in-place assignment mutated methods_results).
    cam = np.where(cam_output < threshold, 0, cam_output)
    img_gray = cv2.cvtColor(original_image, cv2.COLOR_RGB2GRAY)
    img_gray = cv2.cvtColor(img_gray, cv2.COLOR_GRAY2RGB)
    heatmap = np.zeros((cam.shape[0], cam.shape[1], 3), dtype=np.uint8)
    heatmap[..., 0] = np.uint8(255 * cam)
    return cv2.addWeighted(img_gray, 0.7, heatmap, 0.3, 0)


def _draw_centered_text(canvas, text, cell_x, cell_width, y, font, font_scale, thickness):
    """Draw ``text`` in black, horizontally centred in the cell starting at ``cell_x``."""
    text_width = cv2.getTextSize(text, font, font_scale, thickness)[0][0]
    x = cell_x + (cell_width - text_width) // 2
    cv2.putText(canvas, text, (x, y), font, font_scale, (0, 0, 0), thickness)


def create_visualization_comparison(original_img, methods_results, label_info, is_correct, save_path):
    """Save a 2x4 grid comparing Grad-CAM and Integrated Gradients across three models.

    The layout, top to bottom, is:
      - a caption strip showing ``label_info[0]``;
      - column titles;
      - the original image, then Grad-CAM overlays for NMC / APTOS / FT-APTOS;
      - column titles;
      - the label image, then Integrated Gradients maps for the same models.

    Args:
        original_img: normalised channel-first image as a numpy array
            (3, H, W); it is denormalised and transposed here.
        methods_results: ``{'GradCAM': {...}, 'Integrated Gradients': {...}}``,
            each with ``'nmc'``, ``'aptos'`` and ``'ft_aptos'`` maps.
        label_info: list of strings. ``label_info[0]`` is drawn as the caption.
            The entry containing ``"Image: <name>"`` is used to find the label
            image via ``find_label_image``.
        is_correct: unused.
        save_path: output path for ``cv2.imwrite``.
    """
    # Denormalise and convert to an RGB uint8 (H, W, 3) array.
    orig_img = denormalize(torch.from_numpy(original_img)).numpy()
    orig_img = np.uint8(np.clip(orig_img.transpose(1, 2, 0), 0, 1) * 255)

    height, width = orig_img.shape[:2]
    info_height = height // 8
    grid_width = width * 4

    # Locate the label image from the "Image: <name>" entry (name without extension).
    img_name = [line for line in label_info if "Image: " in line][0].split("Image: ")[1]
    img_base_name = img_name.split('.')[0]
    label_img = _load_label_image(find_label_image(img_base_name), width, height, orig_img)

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.7
    thickness = 1

    # Caption strip: only label_info[0] is displayed.
    info_bg = np.full((info_height, grid_width, 3), 255, dtype=np.uint8)
    _draw_centered_text(info_bg, label_info[0], 0, grid_width, info_height // 2 + 5,
                        font, font_scale, thickness)

    # Column titles for both rows.
    top_labels = ['Original', 'GradCAM (NMC)', 'GradCAM (APTOS)', 'GradCAM (FT-APTOS)']
    bottom_labels = ['Original with Label', 'IG (NMC)', 'IG (APTOS)', 'IG (FT-APTOS)']
    label_height = 30
    top_label_bg = np.full((label_height, grid_width, 3), 255, dtype=np.uint8)
    bottom_label_bg = np.full((label_height, grid_width, 3), 255, dtype=np.uint8)
    for idx, (top_label, bottom_label) in enumerate(zip(top_labels, bottom_labels)):
        _draw_centered_text(top_label_bg, top_label, width * idx, width, 20,
                            font, font_scale, thickness)
        _draw_centered_text(bottom_label_bg, bottom_label, width * idx, width, 20,
                            font, font_scale, thickness)

    model_keys = ('nmc', 'aptos', 'ft_aptos')
    gradcam = methods_results['GradCAM']
    ig = methods_results['Integrated Gradients']

    top_row = np.concatenate(
        [orig_img] + [_overlay_gradcam(gradcam[k], orig_img) for k in model_keys],
        axis=1)
    # IG maps are floats expected in [0, 1]; clip so the uint8 cast cannot wrap.
    bottom_row = np.concatenate(
        [label_img] + [np.uint8(255 * np.clip(ig[k], 0, 1)) for k in model_keys],
        axis=1)

    final_image = np.vstack([
        info_bg,
        top_label_bg,
        top_row,
        bottom_label_bg,
        bottom_row,
    ])

    # The grid is RGB; OpenCV writes BGR.
    cv2.imwrite(save_path, cv2.cvtColor(final_image, cv2.COLOR_RGB2BGR))


def find_label_image(base_name, label_dir='nmc_labeling'):
    """Return the path of the first file in ``label_dir`` whose name contains ``base_name``.

    Files are checked in sorted name order, so the result is deterministic
    when several files match (``os.listdir`` order is arbitrary). Returns
    ``None`` when nothing matches or when ``label_dir`` cannot be listed
    (missing, not a directory, no permission).
    """
    try:
        filenames = sorted(os.listdir(label_dir))
    except OSError:
        return None
    for filename in filenames:
        if base_name in filename:
            return os.path.join(label_dir, filename)
    return None

def _binary_preds(outputs, n):
    """Threshold ``sigmoid(outputs)`` at 0.5 and reshape to ``(n, num_outputs)``."""
    return (torch.sigmoid(outputs) > 0.5).reshape(n, -1)


def _category_for(is_nmc_correct, is_aptos_correct):
    """Map the (NMC correct, APTOS correct) pair to one of the four category names."""
    if is_nmc_correct:
        return 'both_correct' if is_aptos_correct else 'only_nmc_correct'
    return 'only_aptos_correct' if is_aptos_correct else 'both_wrong'


def compare_and_save_visualizations(nmc_model, aptos_model, ft_aptos_model, dataloader,
                                    nmc_label_idx, aptos_label_idx, device,
                                    save_dir='visualization_results', max_per_category=3):
    """Collect example images per correctness category and save comparison grids.

    Only samples whose dataloader label equals ``aptos_label_idx`` are used.
    Each model's sigmoid output is thresholded at 0.5, and a positive
    prediction counts as "correct". For NMC with several labels, all outputs
    must be positive. Samples are bucketed by NMC/APTOS correctness into
    ``both_correct``, ``only_nmc_correct``, ``only_aptos_correct`` and
    ``both_wrong``, keeping up to ``max_per_category`` each. Each kept sample
    gets Grad-CAM and Integrated Gradients maps from all three models, and a
    grid is written under
    ``{save_dir}/comparison/label_<nmc labels>_vs_<aptos label>/``.

    Args:
        nmc_model, aptos_model, ft_aptos_model: models to compare (set to eval mode).
        dataloader: yields ``(images, labels, img_name)``. ``labels.item()`` is
            used, so it assumes batch_size == 1.
        nmc_label_idx: list of NMC label ids the NMC model predicts.
        aptos_label_idx: APTOS label the samples must have.
        device: device the images are moved to.
        save_dir: root output directory.
        max_per_category: samples kept per category (default 3, as before).
    """
    categories = {
        'both_correct': [], 'only_nmc_correct': [],
        'only_aptos_correct': [], 'both_wrong': []
    }

    for model in (nmc_model, aptos_model, ft_aptos_model):
        model.eval()

    save_path = f"{save_dir}/comparison/label_{'-'.join(map(str, nmc_label_idx))}_vs_{aptos_label_idx}"
    os.makedirs(save_path, exist_ok=True)

    for batch_idx, (images, labels, img_name) in enumerate(dataloader):
        if labels.item() != aptos_label_idx:
            continue

        images = images.to(device)
        n = len(images)

        with torch.no_grad():
            # (n, num_outputs) boolean predictions for every model.
            nmc_preds = _binary_preds(nmc_model(images), n)
            aptos_preds = _binary_preds(aptos_model(images), n)
            ft_aptos_preds = _binary_preds(ft_aptos_model(images), n)

        for i in range(n):
            try:
                if len(nmc_label_idx) > 1:
                    nmc_pred_list = nmc_preds[i].cpu().numpy()
                    nmc_pred_str = f"({','.join(map(str, nmc_pred_list.astype(int)))})"
                    is_nmc_correct = bool(nmc_pred_list.all())
                else:
                    nmc_pred_val = int(nmc_preds[i].item())
                    nmc_pred_str = str(nmc_pred_val)
                    is_nmc_correct = bool(nmc_pred_val)

                aptos_pred_val = int(aptos_preds[i].item())
                ft_aptos_pred_val = int(ft_aptos_preds[i].item())
                is_aptos_correct = bool(aptos_pred_val)
                is_ft_aptos_correct = bool(ft_aptos_pred_val)

                category = _category_for(is_nmc_correct, is_aptos_correct)
                if len(categories[category]) >= max_per_category:
                    # Category already full: skip the expensive Grad-CAM / IG passes.
                    continue

                methods_results = {
                    'GradCAM': {
                        'nmc': get_gradcam(nmc_model, images[i], nmc_label_idx, device),
                        'aptos': get_gradcam(aptos_model, images[i], [0], device),
                        'ft_aptos': get_gradcam(ft_aptos_model, images[i], [0], device)
                    },
                    'Integrated Gradients': {
                        'nmc': get_integrated_gradients(nmc_model, images[i], nmc_label_idx, device),
                        'aptos': get_integrated_gradients(aptos_model, images[i], [0], device),
                        'ft_aptos': get_integrated_gradients(ft_aptos_model, images[i], [0], device)
                    }
                }

                categories[category].append({
                    'image': images[i].cpu().numpy(),
                    'methods_results': methods_results,
                    'label_info': [
                        f"GT: 1    NMC: {nmc_pred_str}    APTOS: {aptos_pred_val}    FT-APTOS: {ft_aptos_pred_val}",
                        f"Image: {img_name[i]}"  # used to find the label image, not drawn
                    ],
                    'is_correct': {
                        'nmc': is_nmc_correct,
                        'aptos': is_aptos_correct,
                        'ft_aptos': is_ft_aptos_correct
                    }
                })

            except Exception as e:
                # Best effort: one failing sample must not abort the whole run.
                print(f"Error processing sample {i} in batch {batch_idx}: {e}")
                continue

        if all(len(cat) >= max_per_category for cat in categories.values()):
            break

    for category_name, samples in categories.items():
        for idx, sample in enumerate(samples):
            try:
                save_name = os.path.join(save_path, f'{category_name}_{idx}.png')
                create_visualization_comparison(
                    sample['image'],
                    sample['methods_results'],
                    sample['label_info'],
                    sample['is_correct'],
                    save_name
                )
                print(f"Saved visualization: {save_name}")
            except Exception as e:
                print(f"Error creating visualization for {category_name} sample {idx}: {e}")

In [43]:
# RGB colour triples used as overlay channels by visualize(): green and red.
G, R = [0, 255, 0], [255, 0, 0]

def convert_to_gray_scale(attributions):
    """Average an (H, W, C) attribution array over axis 2 (the channel axis)."""
    return np.mean(attributions, axis=2)

def linear_transform(attributions, clip_above_percentile=99.9, clip_below_percentile=70.0, low=0.2, plot_distribution=False):
    """Linearly rescale attribution magnitudes between two mass thresholds.

    ``m`` and ``e`` are the thresholds from ``compute_threshold_by_top_percentage``
    for the top ``100 - clip_above_percentile`` and
    ``100 - clip_below_percentile`` percent of attribution mass. The mapping
    sends ``|a| == e`` to ``low`` and ``|a| == m`` to 1. The original sign is
    then re-applied, values below ``low`` are zeroed, and the result is
    clipped to [0, 1].
    """
    m = compute_threshold_by_top_percentage(attributions, percentage=100-clip_above_percentile, plot_distribution=plot_distribution)
    e = compute_threshold_by_top_percentage(attributions, percentage=100-clip_below_percentile, plot_distribution=plot_distribution)
    abs_attr = np.abs(attributions)
    if m == e:
        # Degenerate range (e.g. constant attributions): the linear map would
        # divide by zero and fill the result with NaN/inf. Use its limit:
        # above the threshold -> 1, exactly at it -> low, below it -> 0.
        transformed = np.where(abs_attr > e, 1.0, np.where(abs_attr == e, low, 0.0))
    else:
        transformed = (1 - low) * (abs_attr - e) / (m - e) + low
    transformed *= np.sign(attributions)
    transformed *= (transformed >= low)
    transformed = np.clip(transformed, 0.0, 1.0)
    return transformed

def compute_threshold_by_top_percentage(attributions, percentage=60, plot_distribution=True):
    """Return the magnitude at which the top ``percentage`` % of attribution mass is reached.

    Absolute values are sorted in descending order and accumulated. The
    threshold is the first magnitude at which the running sum reaches
    ``percentage`` percent of the total *absolute* attribution.
    ``percentage == 100`` returns ``np.min(attributions)``. All-zero input
    yields a threshold of 0.

    Raises:
        ValueError: if ``percentage`` is outside [0, 100].
        NotImplementedError: if ``plot_distribution`` is true (plotting is not
            implemented). The default is True, so callers must pass False.
    """
    if percentage < 0 or percentage > 100:
        raise ValueError('percentage must be in [0, 100]')
    if percentage == 100:
        return np.min(attributions)
    sorted_attributions = np.sort(np.abs(attributions.flatten()))[::-1]
    # Normalise by the total *absolute* mass so the cumulative curve reaches
    # 100 % for negative or mixed-sign attributions too (a signed sum could be
    # negative or zero, leaving no index that satisfies the threshold).
    total = np.sum(sorted_attributions)
    if total == 0:
        threshold = sorted_attributions[0]
    else:
        cum_sum = 100.0 * np.cumsum(sorted_attributions) / total
        # First index whose running total reaches `percentage`. Clamp to the
        # last element in case float rounding leaves cum_sum[-1] just below it.
        threshold_idx = min(int(np.searchsorted(cum_sum, percentage, side='left')),
                            len(cum_sum) - 1)
        threshold = sorted_attributions[threshold_idx]
    if plot_distribution:
        raise NotImplementedError
    return threshold

def polarity_function(attributions, polarity):
    """Keep only one sign of the attributions.

    'positive' clips to [0, 1]; 'negative' clips to [-1, 0]. Any other value
    raises NotImplementedError.
    """
    for name, (lower, upper) in (('positive', (0, 1)), ('negative', (-1, 0))):
        if polarity == name:
            return np.clip(attributions, lower, upper)
    raise NotImplementedError

def overlay_function(attributions, image):
    """Blend attributions onto an image as ``0.7*image + 0.5*attributions``, clipped to [0, 255]."""
    blended = image * 0.7 + attributions * 0.5
    return np.clip(blended, 0, 255)

def visualize(attributions, image, positive_channel=G, negative_channel=R, polarity='positive',
              clip_above_percentile=99.9, clip_below_percentile=0, morphological_cleanup=False,
              structure=np.ones((3, 3)), outlines=False, outlines_component_percentage=90, overlay=True,
              mask_mode=False, plot_distribution=False):
    """Render (H, W, C) attributions as a coloured map, optionally overlaid on ``image``.

    Steps:
      - Keep one sign of the attributions (``polarity``).
      - Average the channels to a grayscale map.
      - Rescale with ``linear_transform`` between the two percentile thresholds.
      - Colour the result with ``positive_channel`` or ``negative_channel``.

    With ``overlay=True`` and ``mask_mode=False``, the coloured map is
    blended onto ``image`` via ``overlay_function``. With ``mask_mode=True``,
    ``image`` is scaled by the grayscale mask and its channel order reversed.

    Raises:
        NotImplementedError: for ``polarity`` other than 'positive'/'negative'
            (including 'both'), or when ``morphological_cleanup`` or
            ``outlines`` is requested.
    """
    if polarity == 'positive':
        attributions = polarity_function(attributions, polarity=polarity)
        channel = positive_channel
    elif polarity == 'negative':
        # Keep the negative attributions but use their magnitude. Otherwise
        # linear_transform (which re-applies the sign and zeroes values below
        # `low`) would discard all of them.
        attributions = np.abs(polarity_function(attributions, polarity=polarity))
        channel = negative_channel
    else:
        # 'both' and unknown values previously raised or left `channel` unbound.
        raise NotImplementedError

    # Convert the attributions to grayscale.
    attributions = convert_to_gray_scale(attributions)
    attributions = linear_transform(attributions, clip_above_percentile, clip_below_percentile, 0.0, plot_distribution=plot_distribution)
    attributions_mask = attributions.copy()
    if morphological_cleanup:
        raise NotImplementedError
    if outlines:
        raise NotImplementedError
    attributions = np.expand_dims(attributions, 2) * channel
    if overlay:
        if not mask_mode:
            attributions = overlay_function(attributions, image)
        else:
            attributions = np.expand_dims(attributions_mask, 2)
            attributions = np.clip(attributions * image, 0, 255)
            attributions = attributions[:, :, (2, 1, 0)]
    return attributions

In [44]:
# Main execution: for each (NMC label set, APTOS label) pair, load the three
# classifiers and save side-by-side Grad-CAM / Integrated Gradients grids.

def _load_efficientnet_v2_m(num_outputs, checkpoint_path, device):
    """Build an EfficientNetV2-M with a BatchNorm1d + Linear head and load a checkpoint.

    ``weights=None``: the strict ``load_state_dict`` below overwrites every
    parameter and buffer, so downloading ImageNet weights first (the
    deprecated ``pretrained=True``) is wasted work.
    """
    model = models.efficientnet_v2_m(weights=None)
    num_ftrs = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.BatchNorm1d(num_ftrs),
        nn.Linear(num_ftrs, num_outputs)
    )
    model = model.to(device)
    # map_location lets checkpoints saved on another GPU load onto `device`.
    model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model


nmc_labels = [[0], [2], [1], [1, 2], [5, 6]]
aptos_labels = [0, 1, 2, 3, 4]  # APTOS label paired with each entry of nmc_labels

for nmc_label_idx, aptos_label_idx in zip(nmc_labels, aptos_labels):
    print(f"\nProcessing NMC label {nmc_label_idx} and APTOS label {aptos_label_idx}")

    # NMC model: one output per NMC label in the set.
    if len(nmc_label_idx) == 1:
        nmc_ckpt = f'model/singlelabel/best_model_label_{nmc_label_idx[0]}_nmc_cnn.pth'
    else:
        nmc_ckpt = f'model/singlelabel/best_model_labels_{"-".join(map(str, nmc_label_idx))}_nmc_cnn.pth'
    nmc_model = _load_efficientnet_v2_m(len(nmc_label_idx), nmc_ckpt, device)

    # APTOS model and its fine-tuned counterpart: single output each.
    aptos_model = _load_efficientnet_v2_m(
        1, f'model/singlelabel/best_model_label_{aptos_label_idx}_aptos_cnn.pth', device)
    ft_aptos_model = _load_efficientnet_v2_m(
        1, f'model/singlelabel_finetuning/best_model_label_{aptos_label_idx}_aptos_cnn.pth', device)

    # Compare and save visualisations for all three models.
    compare_and_save_visualizations(
        nmc_model,
        aptos_model,
        ft_aptos_model,
        testloader,
        nmc_label_idx,
        aptos_label_idx,
        device,
        save_dir='visualization_results'
    )


Processing NMC label [0] and APTOS label 0
Saved visualization: visualization_results/comparison/label_0_vs_0/both_correct_0.png
Saved visualization: visualization_results/comparison/label_0_vs_0/both_correct_1.png
Saved visualization: visualization_results/comparison/label_0_vs_0/both_correct_2.png

Processing NMC label [2] and APTOS label 1
Saved visualization: visualization_results/comparison/label_2_vs_1/both_correct_0.png
Saved visualization: visualization_results/comparison/label_2_vs_1/both_correct_1.png
Saved visualization: visualization_results/comparison/label_2_vs_1/only_aptos_correct_0.png
Saved visualization: visualization_results/comparison/label_2_vs_1/both_wrong_0.png

Processing NMC label [1] and APTOS label 2
Saved visualization: visualization_results/comparison/label_1_vs_2/both_correct_0.png
Saved visualization: visualization_results/comparison/label_1_vs_2/both_correct_1.png
Saved visualization: visualization_results/comparison/label_1_vs_2/only_aptos_correct_0.pn