### 싱글모델로 분류

In [9]:

import torch 
import argparse
import yaml
import time
import multiprocessing as mp
import torch.nn.functional as F
from tabulate import tabulate
from tqdm import tqdm
from torch.utils.data import DataLoader
from pathlib import Path
#from torch.utils.tensorboard import SummaryWriter
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DistributedSampler, RandomSampler
from torch import distributed as dist
from nmc.models import *
from nmc.datasets import * 
from nmc.augmentations import get_train_augmentation, get_val_augmentation
from nmc.losses import get_loss
from nmc.schedulers import get_scheduler
from nmc.optimizers import get_optimizer
from nmc.utils.utils import fix_seeds, setup_cudnn, cleanup_ddp, setup_ddp
from tools.val import evaluate_epi
from nmc.utils.episodic_utils import * 
from scipy.cluster import hierarchy
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from torchvision import models
import torch.nn as nn
from torch.optim import lr_scheduler
import numpy as np
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import mutual_info_score
from scipy.cluster import hierarchy
from tqdm import tqdm
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, hamming_loss
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.utils.data import Subset
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import cv2
import random
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

In [10]:
# Load the experiment configuration; SafeLoader restricts parsing to plain
# YAML types, so no arbitrary Python objects are constructed from the file.
with open('../configs/APTOS.yaml') as f:
    cfg = yaml.load(f, Loader=yaml.SafeLoader)
print(cfg)
# fix_seeds / setup_cudnn / setup_ddp / cleanup_ddp are project helpers from
# nmc.utils.utils; their exact effects are not visible in this notebook.
# NOTE(review): presumably these seed the RNGs and configure cuDNN — confirm.
fix_seeds(3407)
setup_cudnn()
gpu = setup_ddp()
# Create the output directory named by the config. Parent directories are not
# created (no parents=True), so a nested SAVE_DIR must already have its parent.
save_dir = Path(cfg['SAVE_DIR'])
save_dir.mkdir(exist_ok=True)
# NOTE(review): DDP is set up and then torn down again within this cell —
# confirm that later cells are meant to run without a process group.
cleanup_ddp()

{'DEVICE': 'cuda:0', 'SAVE_DIR': 'output', 'MODEL': {'NAME': 'EfficientNetV2MModel', 'BACKBONE': 'EfficientNetV2', 'PRETRAINED': '/workspace/jhmoon/nmc_2024/checkpoints/pretrained/tf_efficientnetv2_m_weights.pth', 'UNFREEZE': 'full', 'VERSION': '384_32'}, 'DATASET': {'NAME': 'APTOSDataset', 'ROOT': '/data/public_data/aptos', 'TRAIN_RATIO': 0.7, 'VALID_RATIO': 0.15, 'TEST_RATIO': 0.15}, 'TRAIN': {'IMAGE_SIZE': [384, 384], 'BATCH_SIZE': 32, 'EPOCHS': 100, 'EVAL_INTERVAL': 25, 'AMP': False, 'DDP': False}, 'LOSS': {'NAME': 'CrossEntropy', 'CLS_WEIGHTS': False}, 'OPTIMIZER': {'NAME': 'adamw', 'LR': 0.001, 'WEIGHT_DECAY': 0.01}, 'SCHEDULER': {'NAME': 'warmuppolylr', 'POWER': 0.9, 'WARMUP': 10, 'WARMUP_RATIO': 0.1}, 'EVAL': {'MODEL_PATH': 'checkpoints/pretrained/FGMaxxVit/FGMaxxVit.FGMaxxVit.APTOS.pth', 'IMAGE_SIZE': [384, 384]}, 'TEST': {'MODEL_PATH': 'checkpoints/pretrained/FGMaxxVit/FGMaxxVit.FGMaxxVit.APTOS.pth', 'FILE': 'assests/ade', 'IMAGE_SIZE': [384, 384], 'OVERLAY': True}}


In [11]:
# Early Stopping
class EarlyStopping:
    """Track a validation score and flag when it has stopped improving.

    The score is treated as "higher is better". A call counts as a stall when
    the new score is strictly below ``best_score + min_delta``; after
    ``patience`` consecutive stalls ``early_stop`` is set to ``True``. Any
    non-stalling call replaces ``best_score`` and resets the stall counter
    (``early_stop`` itself is never reset).
    """

    def __init__(self, patience=7, min_delta=0):
        # Configuration
        self.patience, self.min_delta = patience, min_delta
        # Running state
        self.best_score = None
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_score):
        """Update the tracker with the latest validation score."""
        # The very first score simply becomes the baseline.
        if self.best_score is None:
            self.best_score = val_score
            return

        stalled = val_score < self.best_score + self.min_delta
        if not stalled:
            # Sufficient improvement: adopt the new best and start over.
            self.best_score = val_score
            self.counter = 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

In [12]:
def get_train_augmentation(size):
    """Build the training-time transform pipeline for images resized to ``size``.

    The pipeline resizes, applies a random horizontal flip, a random rotation
    of up to 10 degrees and a mild colour jitter, then converts uint8 input to
    float, rescales by 1/255 when the maximum value exceeds 1.0, and finally
    normalises with the mean/std constants given below.

    Note: this definition replaces the ``get_train_augmentation`` imported
    from ``nmc.augmentations`` at the top of the notebook.
    """
    def _uint8_to_float(img):
        # Only uint8 input is cast; other dtypes pass through untouched.
        if img.dtype == torch.uint8:
            return img.float()
        return img

    def _scale_to_unit_range(img):
        # Heuristic: values above 1.0 are taken to mean a 0-255 range.
        if img.max() > 1.0:
            return img / 255.0
        return img

    steps = [
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.Lambda(_uint8_to_float),
        transforms.Lambda(_scale_to_unit_range),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)

def get_val_test_transform(size):
    """Build the deterministic validation/test transform for images of ``size``.

    Steps: resize, cast uint8 input to float, rescale by 1/255 when the
    maximum value exceeds 1.0, then normalise with the mean/std constants
    given below. No random augmentation is applied.
    """
    def _uint8_to_float(img):
        # Only uint8 input is cast; other dtypes pass through untouched.
        if img.dtype == torch.uint8:
            return img.float()
        return img

    def _scale_to_unit_range(img):
        # Heuristic: values above 1.0 are taken to mean a 0-255 range.
        if img.max() > 1.0:
            return img / 255.0
        return img

    steps = [
        transforms.Resize(size),
        transforms.Lambda(_uint8_to_float),
        transforms.Lambda(_scale_to_unit_range),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)


In [13]:
class MultiTargetBalancedBatchSampler(Sampler):
    """Batch sampler drawing equally from each target class and from "the rest".

    Every batch contains ``batch_size // (len(target_classes) + 1)`` indices
    (sampled with replacement) from each target class, plus the same number
    from the samples that belong to none of the target classes. Groups that
    contain no samples are skipped instead of raising.

    Labels are read from ``dataset.labels`` or ``dataset.targets`` when
    present (tensor, NumPy array or plain sequence); otherwise they are
    collected by iterating the dataset and taking element ``[1]`` of each
    sample. 1-D labels are treated as class ids; labels with more than one
    dimension are treated as multi-hot rows where ``labels[:, c] == 1`` marks
    membership of class ``c``.

    Args:
        dataset: Sized dataset exposing labels as described above.
        batch_size: Nominal batch size; also determines the number of batches
            per epoch, ``ceil(len(dataset) / batch_size)``.
        target_classes: Class ids (or multi-hot column indices) to balance.

    Raises:
        ValueError: If labels cannot be obtained from the dataset, or if
            ``batch_size`` is smaller than the number of groups (which would
            produce empty batches).
    """

    def __init__(self, dataset, batch_size, target_classes):
        self.dataset = dataset
        self.batch_size = batch_size
        self.target_classes = target_classes

        self.labels = self._extract_labels(dataset)
        multi_hot = self.labels.dim() > 1

        # Indices of the samples belonging to each target class.
        self.target_indices = {}
        for target in target_classes:
            if multi_hot:
                member = self.labels[:, target] == 1
            else:
                member = self.labels == target
            self.target_indices[target] = torch.where(member)[0]

        # Indices of the samples belonging to none of the target classes.
        if multi_hot:
            rest = torch.sum(self.labels[:, list(target_classes)], dim=1) == 0
        else:
            rest = torch.ones_like(self.labels, dtype=torch.bool)
            for target in target_classes:
                rest &= (self.labels != target)
        self.other_indices = torch.where(rest)[0]

        # Number of samples drawn per group: target classes + "the rest".
        n_groups = len(target_classes) + 1
        self.samples_per_group = batch_size // n_groups
        if self.samples_per_group == 0:
            raise ValueError(
                f"batch_size ({batch_size}) must be at least the number of "
                f"groups ({n_groups})"
            )

        # Ceil division: a partial trailing batch still counts as a batch.
        self.n_batches = (len(self.dataset) + batch_size - 1) // batch_size

    @staticmethod
    def _to_tensor(labels):
        """Convert a tensor, NumPy array or sequence of labels to a tensor."""
        if isinstance(labels, torch.Tensor):
            return labels
        if isinstance(labels, np.ndarray):
            return torch.from_numpy(labels)
        labels = list(labels)
        if labels and isinstance(labels[0], np.ndarray):
            return torch.from_numpy(np.array(labels))
        return torch.tensor(labels)

    @classmethod
    def _extract_labels(cls, dataset):
        """Return the dataset's labels as a tensor (see the class docstring)."""
        for attr in ('labels', 'targets'):
            if hasattr(dataset, attr):
                return cls._to_tensor(getattr(dataset, attr))
        try:
            return cls._to_tensor([sample[1] for sample in dataset])
        except Exception as err:
            raise ValueError("Cannot access labels from dataset") from err

    def __iter__(self):
        # Sampling pools in a fixed order: each target class, then the rest.
        # Empty pools are dropped because torch.randint(0, ...) would raise.
        pools = [self.target_indices[target] for target in self.target_classes]
        pools.append(self.other_indices)
        pools = [pool for pool in pools if len(pool) > 0]

        for _ in range(self.n_batches):
            batch_indices = []
            for pool in pools:
                picks = pool[torch.randint(len(pool), (self.samples_per_group,))]
                batch_indices.extend(picks.tolist())

            # Mix the groups so the batch is not ordered by class.
            random.shuffle(batch_indices)

            # Defensive cap; the per-group quota already keeps us within it.
            yield batch_indices[:self.batch_size]

    def __len__(self):
        return self.n_batches

In [14]:
# Notebook setup: unpack the config sections, build the evaluation transform,
# and create the dataset plus a batch-size-1 loader used by the Grad-CAM cells.
start = time.time()
best_mf1 = 0.0
device = torch.device(cfg['DEVICE'])
# NOTE(review): the configured device is immediately overridden by a
# hard-coded "cuda:1"; the assignment above has no lasting effect.
device = "cuda:1"
print("device : ", device)
num_workers = mp.cpu_count()
train_cfg, eval_cfg = cfg['TRAIN'], cfg['EVAL']
dataset_cfg, model_cfg = cfg['DATASET'], cfg['MODEL']
loss_cfg, optim_cfg, sched_cfg = cfg['LOSS'], cfg['OPTIMIZER'], cfg['SCHEDULER']
epochs, lr = train_cfg['EPOCHS'], optim_cfg['LR']

# NOTE(review): image size and batch size are hard-coded here rather than
# read from train_cfg — confirm this is intentional.
image_size = [256,256]
image_dir = Path(dataset_cfg['ROOT']) / 'train_images'
train_transform = get_train_augmentation(image_size)
val_test_transform = get_val_test_transform(image_size)
batch_size = 32


# The dataset class is resolved by name: cfg DATASET.NAME + 'Test'.
# NOTE(review): eval() executes whatever string the config holds; this is only
# safe while the YAML file is trusted. A globals() lookup would avoid eval.
dataset = eval(dataset_cfg['NAME']+'Test')(
    dataset_cfg['ROOT'] + '/combined_images',
    transform=None,
    target_label=None,
)
# The transform is attached after construction instead of via the constructor.
dataset.transform = val_test_transform
# trainset, valset, testset = dataset.get_splits()
# valset.transform = val_test_transform
# testset.transform = val_test_transform

# trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=num_workers, drop_last=True, pin_memory=True)
# valloader = DataLoader(valset, batch_size=1, num_workers=1, pin_memory=True)
# The whole dataset is served one sample at a time.
testloader = DataLoader(dataset, batch_size=1, num_workers=1, pin_memory=True)
    

device :  cuda:1
/data/public_data/aptos/combined_images


In [15]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
import cv2
import os
def denormalize(tensor):
    """Undo the mean/std normalisation applied by the transforms in this file.

    Args:
        tensor: Channel-first image tensor whose channel axis has size 3 and
            is third from last, e.g. ``(3, H, W)`` or ``(B, 3, H, W)``.

    Returns:
        ``tensor * std + mean`` with per-channel constants broadcast over the
        spatial axes.
    """
    # Create the constants on the input's device so CUDA tensors work too
    # (CPU-only constants would raise a device-mismatch error).
    mean = torch.tensor([0.485, 0.456, 0.406], device=tensor.device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=tensor.device).view(3, 1, 1)
    return tensor * std + mean

def get_gradcam(model, image, target_label_idx, device):
    """Compute a Grad-CAM heatmap for one image using the model's last feature stage.

    Args:
        model: Network exposing a ``features`` sequence; its last element is
            used as the Grad-CAM target layer.
        image: Single image tensor without a batch axis (one is added here).
            It is not moved, so it must already be on the model's device.
        target_label_idx: Accepted for call-site compatibility but not used:
            the heatmap always explains output unit 0 of ``model``.
        device: Accepted for call-site compatibility; not used here.

    Returns:
        The Grad-CAM map for the single input image (first element of the
        batch result returned by ``GradCAM``).
    """
    target_layer = model.features[-1]

    # Output index 0, i.e. the first logit of the model's head — this is an
    # index into the model output, not a dataset label id.
    targets = [ClassifierOutputTarget(0)]

    # Use the context manager so the hooks GradCAM registers on the model are
    # released after each call instead of accumulating across calls.
    with GradCAM(model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(input_tensor=image.unsqueeze(0), targets=targets)

    return grayscale_cam[0]

def create_comparison_image(original_img, nmc_heatmap, aptos_heatmap, label_info, nmc_correct, aptos_correct, save_path):
    """Save a figure comparing the Grad-CAM maps of the NMC and APTOS models.

    Layout: a white banner with the label text and a per-model verdict on top,
    and below it, side by side, the original image, the NMC overlay and the
    APTOS overlay.

    Args:
        original_img: Normalised channel-first image array ``(3, H, W)``; it
            is denormalised and transposed to ``(H, W, 3)`` here.
        nmc_heatmap: Grad-CAM map for the NMC model; expected to be an
            ``(H, W)`` array with values in [0, 1].
        aptos_heatmap: Grad-CAM map for the APTOS model, same convention.
        label_info: Iterable of text lines drawn centred in the banner.
        nmc_correct: Truthy if the NMC model's prediction counts as correct.
        aptos_correct: Truthy if the APTOS model's prediction counts as correct.
        save_path: Destination image path.

    Raises:
        IOError: If OpenCV reports that the image could not be written.
    """
    # Everything is composed in RGB and converted to BGR once, when saving.
    green_rgb = (0, 255, 0)
    red_rgb = (255, 0, 0)
    black = (0, 0, 0)

    def _overlay(base_rgb, heatmap):
        """Blend a [0, 1] heatmap onto an RGB uint8 image with the JET colormap."""
        colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
        # applyColorMap produces BGR; convert so the channels match the RGB
        # base image (otherwise hot regions are rendered blue).
        colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
        return cv2.addWeighted(base_rgb, 0.6, colored, 0.4, 0)

    # Denormalise the original image and convert to contiguous HWC uint8.
    orig_img = denormalize(torch.from_numpy(original_img)).numpy()
    orig_img = np.clip(orig_img.transpose(1, 2, 0), 0, 1)
    orig_img = np.ascontiguousarray(np.uint8(orig_img * 255))

    nmc_superimposed = _overlay(orig_img, nmc_heatmap)
    aptos_superimposed = _overlay(orig_img, aptos_heatmap)

    # White banner: half the image height, as wide as the three panels.
    height, width = orig_img.shape[:2]
    info_height = height // 2
    info_bg = np.ones((info_height, width * 3, 3), dtype=np.uint8) * 255

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.5
    thickness = 1

    # Label lines, horizontally centred, 20 px apart.
    y_offset = 20
    for line in label_info:
        text_size = cv2.getTextSize(line, font, font_scale, thickness)[0]
        x = (info_bg.shape[1] - text_size[0]) // 2
        cv2.putText(info_bg, line, (x, y_offset), font, font_scale, black, thickness)
        y_offset += 20

    # Per-model verdicts, centred above the second and third panels.
    nmc_result = "NMC: Correct" if nmc_correct else "NMC: Incorrect"
    aptos_result = "APTOS: Correct" if aptos_correct else "APTOS: Incorrect"
    y = height // 2 - 20
    verdicts = [
        (nmc_result, nmc_correct, width),
        (aptos_result, aptos_correct, 2 * width),
    ]
    for text, is_correct, panel_left in verdicts:
        text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]
        x = panel_left + (width - text_size[0]) // 2
        color = green_rgb if is_correct else red_rgb
        cv2.putText(info_bg, text, (x, y), font, font_scale, color, thickness)

    # Panels side by side, banner stacked on top.
    images_row = np.concatenate([orig_img, nmc_superimposed, aptos_superimposed], axis=1)
    final_image = np.concatenate([info_bg, images_row], axis=0)

    # cv2.imwrite expects BGR and signals failure through its return value.
    if not cv2.imwrite(save_path, cv2.cvtColor(final_image, cv2.COLOR_RGB2BGR)):
        raise IOError(f"Could not write image to {save_path}")

def _collect_comparison_samples(nmc_model, aptos_model, dataloader, aptos_label_idx, device, limit):
    """Run both models over the loader and bucket up to ``limit`` samples per outcome."""
    buckets = {
        'both_correct': [],
        'only_nmc_correct': [],
        'only_aptos_correct': [],
        'both_wrong': [],
    }
    # (aptos_correct, nmc_correct) -> bucket key
    bucket_for = {
        (True, True): 'both_correct',
        (False, True): 'only_nmc_correct',
        (True, False): 'only_aptos_correct',
        (False, False): 'both_wrong',
    }

    with torch.no_grad():
        for batch_idx, (images, labels, img_name) in enumerate(dataloader):
            images = images.to(device)
            labels = labels.to(device).reshape(-1)

            # Only samples whose label equals the APTOS label under evaluation
            # are considered; this works for any batch size.
            wanted = (labels == aptos_label_idx).nonzero(as_tuple=True)[0].tolist()
            if not wanted:
                continue

            # APTOS model: a single sigmoid output per sample.
            aptos_raw_preds = torch.sigmoid(aptos_model(images)).reshape(-1)
            # NMC model: one sigmoid output per corresponding NMC label.
            nmc_raw_preds = torch.sigmoid(nmc_model(images))

            for i in wanted:
                try:
                    # APTOS is correct when it predicts the positive class.
                    aptos_correct = bool(aptos_raw_preds[i] > 0.5)
                    # NMC is correct only if every corresponding label fires.
                    nmc_is_correct = bool(torch.all(nmc_raw_preds[i] > 0.5))

                    print(f"\nSample {i} in batch {batch_idx}:")
                    print(f"APTOS prediction: {aptos_raw_preds[i].item():.3f}, Correct: {aptos_correct}")
                    print(f"NMC predictions: {nmc_raw_preds[i].cpu().numpy()}, Correct: {nmc_is_correct}")

                    sample_info = {
                        'image': images[i],
                        'aptos_target': 1,  # the target is always 1 for the label under evaluation
                        'nmc_preds': nmc_raw_preds[i].cpu().numpy(),
                        'aptos_pred': aptos_raw_preds[i].item(),
                        'img_name': img_name[i],
                    }

                    key = bucket_for[(aptos_correct, nmc_is_correct)]
                    if len(buckets[key]) < limit:
                        print(f"Adding to {key}")
                        buckets[key].append(sample_info)

                except Exception as e:
                    # Best effort: a bad sample must not abort the whole scan.
                    print(f"Error processing sample {i} in batch {batch_idx}: {e}")
                    continue

                print(f"\nCurrent samples collected:")
                print(f"Both correct: {len(buckets['both_correct'])}")
                print(f"Only NMC correct: {len(buckets['only_nmc_correct'])}")
                print(f"Only APTOS correct: {len(buckets['only_aptos_correct'])}")
                print(f"Both wrong: {len(buckets['both_wrong'])}")

                # Stop scanning as soon as every bucket has reached its quota.
                if all(len(items) >= limit for items in buckets.values()):
                    return buckets

    return buckets


def _save_comparison_images(nmc_model, aptos_model, buckets, nmc_label_idx, aptos_label_idx, device, save_path):
    """Render and save one Grad-CAM comparison figure per collected sample."""
    titles = [
        ('both_correct', "Both Correct"),
        ('only_nmc_correct', "Only NMC Correct"),
        ('only_aptos_correct', "Only APTOS Correct"),
        ('both_wrong', "Both Wrong"),
    ]
    for key, category_name in titles:
        for idx, sample in enumerate(buckets[key]):
            try:
                # Grad-CAM needs gradients, so this runs outside torch.no_grad().
                nmc_heatmap = get_gradcam(nmc_model, sample['image'], nmc_label_idx, device)
                aptos_heatmap = get_gradcam(aptos_model, sample['image'], [0], device)  # APTOS has a single output

                nmc_pred_str = np.array2string(sample['nmc_preds'], precision=3)
                aptos_pred_str = f"{sample['aptos_pred']:.3f}"
                aptos_target_str = f"{int(sample['aptos_target'])}"

                nmc_is_correct = np.all(sample['nmc_preds'] > 0.5)
                aptos_is_correct = (sample['aptos_pred'] > 0.5) == sample['aptos_target']

                label_info = [
                    f"Category: {category_name}",
                    f"Image: {sample['img_name']}",
                    f"NMC Labels: {nmc_label_idx}, APTOS Label: {aptos_label_idx}",
                    f"APTOS - True: {aptos_target_str}, Pred: {aptos_pred_str}",
                    f"NMC Predictions: {nmc_pred_str}"
                ]

                save_name = os.path.join(save_path, f'{category_name.lower().replace(" ", "_")}_{idx}.png')
                create_comparison_image(
                    sample['image'].cpu().numpy(),
                    nmc_heatmap,
                    aptos_heatmap,
                    label_info,
                    nmc_is_correct,
                    aptos_is_correct,
                    save_name
                )
                print(f"Saved image: {save_name}")

            except Exception as e:
                # Best effort: report the failure and move on to the next sample.
                print(f"Error processing {category_name} sample {idx}: {e}")
                print(f"Sample info: {sample}")
                continue


def compare_and_save_gradcam(nmc_model, aptos_model, dataloader, nmc_label_idx, aptos_label_idx, device, save_dir='grad_results', samples_per_category=3):
    """Compare an NMC model and an APTOS model on one APTOS label and save Grad-CAM figures.

    Samples whose label equals ``aptos_label_idx`` are run through both models
    and sorted into four outcomes (both correct, only NMC correct, only APTOS
    correct, both wrong). The APTOS model counts as correct when its sigmoid
    output exceeds 0.5; the NMC model counts as correct when all of its
    sigmoid outputs exceed 0.5. Up to ``samples_per_category`` samples are
    kept per outcome, and a comparison figure is written for each under
    ``{save_dir}/comparison/label_<nmc labels>_vs_<aptos label>``.

    Args:
        nmc_model: Model with one output per entry of ``nmc_label_idx``.
        aptos_model: Model with a single output for ``aptos_label_idx``.
        dataloader: Iterable yielding ``(images, labels, img_name)`` batches
            of any batch size.
        nmc_label_idx: List of NMC label ids (used for naming and captions).
        aptos_label_idx: APTOS label to evaluate.
        device: Device the batches are moved to; the models must already be
            on it.
        save_dir: Root directory for the output figures.
        samples_per_category: Maximum number of samples kept per outcome.

    Returns:
        None. Progress is printed and figures are written to disk.
    """
    nmc_model.eval()
    aptos_model.eval()

    save_path = f"{save_dir}/comparison/label_{'-'.join(map(str, nmc_label_idx))}_vs_{aptos_label_idx}"
    os.makedirs(save_path, exist_ok=True)

    print(f"\nProcessing APTOS label {aptos_label_idx} with corresponding NMC labels {nmc_label_idx}")

    buckets = _collect_comparison_samples(
        nmc_model, aptos_model, dataloader, aptos_label_idx, device, samples_per_category
    )

    print(f"\nFinal samples collected:")
    print(f"Both correct: {len(buckets['both_correct'])}")
    print(f"Only NMC correct: {len(buckets['only_nmc_correct'])}")
    print(f"Only APTOS correct: {len(buckets['only_aptos_correct'])}")
    print(f"Both wrong: {len(buckets['both_wrong'])}")

    _save_comparison_images(
        nmc_model, aptos_model, buckets, nmc_label_idx, aptos_label_idx, device, save_path
    )

In [16]:
# Main execution: for each (NMC label set, APTOS label) pair, load the two
# fine-tuned EfficientNetV2-M classifiers and save Grad-CAM comparisons.
nmc_labels = [[0],[2],[1],[1,2],[5,6]]
aptos_labels = [0,1,2,3,4]  # APTOS label corresponding to each NMC label set


def _load_finetuned_efficientnet(num_outputs, checkpoint_path, device):
    """Build EfficientNetV2-M with a BatchNorm1d+Linear head and load a checkpoint.

    The backbone is created without ImageNet weights because every parameter
    is overwritten by the checkpoint anyway; this avoids a needless download.
    """
    model = models.efficientnet_v2_m(weights=None)
    num_ftrs = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.BatchNorm1d(num_ftrs),
        nn.Linear(num_ftrs, num_outputs)
    )
    # Load onto CPU first so checkpoints saved from another GPU still load,
    # then move the whole model to the requested device.
    model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
    return model.to(device)


for idx, nmc_label_idx in enumerate(nmc_labels):
    aptos_label_idx = aptos_labels[idx]

    print(f"\nProcessing NMC label {nmc_label_idx} and APTOS label {aptos_label_idx}")

    # Single-label and multi-label NMC checkpoints use different file names.
    if len(nmc_label_idx)==1:
        nmc_checkpoint = f'model/singlelabel/best_model_label_{nmc_label_idx[0]}_nmc_cnn.pth'
    else:
        nmc_checkpoint = f'model/singlelabel/best_model_labels_{"-".join(map(str,nmc_label_idx))}_nmc_cnn.pth'
    nmc_model = _load_finetuned_efficientnet(len(nmc_label_idx), nmc_checkpoint, device)

    # The APTOS model is a binary classifier for the current label.
    aptos_model = _load_finetuned_efficientnet(
        1, f'model/singlelabel/best_model_label_{aptos_label_idx}_aptos_cnn.pth', device
    )

    # Compare and save results
    compare_and_save_gradcam(nmc_model, aptos_model, testloader, nmc_label_idx, aptos_label_idx, device,save_dir='grad_label_results')


Processing NMC label [0] and APTOS label 0

Processing APTOS label 0 with corresponding NMC labels [0]

Sample 0 in batch 0:
APTOS prediction: 1.000, Correct: True
NMC predictions: [0.7937736], Correct: True
Adding to both_correct

Current samples collected:
Both correct: 1
Only NMC correct: 0
Only APTOS correct: 0
Both wrong: 0

Sample 0 in batch 1:
APTOS prediction: 1.000, Correct: True
NMC predictions: [0.81145424], Correct: True
Adding to both_correct

Current samples collected:
Both correct: 2
Only NMC correct: 0
Only APTOS correct: 0
Both wrong: 0

Sample 0 in batch 2:
APTOS prediction: 1.000, Correct: True
NMC predictions: [0.9791905], Correct: True
Adding to both_correct

Current samples collected:
Both correct: 3
Only NMC correct: 0
Only APTOS correct: 0
Both wrong: 0

Sample 0 in batch 3:
APTOS prediction: 1.000, Correct: True
NMC predictions: [0.9932122], Correct: True

Current samples collected:
Both correct: 3
Only NMC correct: 0
Only APTOS correct: 0
Both wrong: 0

Sampl