#### 환경설정

##### 1. Wandb

In [None]:
import wandb

# Log in to wandb through its command-line tool (IPython shell escape).
!wandb login

##### 2. 라이브러리 로드

In [None]:

import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import os
import math
import random
import pickle
import wandb
from tqdm import tqdm
from datetime import datetime
from zoneinfo import ZoneInfo

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import librosa
import librosa.display

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import torchvision
import torchvision.models as models
from torch import Tensor
from torchsummary import summary
from torch.hub import load_state_dict_from_url
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim.lr_scheduler import CosineAnnealingLR

from sklearn.metrics import confusion_matrix, f1_score
from sklearn.manifold import TSNE

##### 3. 경로 설정

In [None]:
# NOTE(review): these are machine-specific absolute paths; adjust them before running elsewhere.
ROOT = "/home/sbw/BOAZ-Chungzins/data/raw"  # used below as both the WAV and the annotation (.txt) directory
CHECKPOINT_PATH = "/home/sbw/boaz/notebook/note_ckp"
PICKLE_PATH = "/home/sbw/boaz/notebook/pickle"  # directory the pickled datasets are loaded from
text = "/home/sbw/BOAZ-Chungzins/data/metadata/train_test_split.txt"  # tab-separated (filename, set) split file

##### 4. Seed 설정

In [None]:
def seed_everything(seed: int = 42):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic mode with
    benchmarking disabled.

    Args:
        seed: Seed value applied to every generator.
    """
    import os
    import random

    import numpy as np
    import torch

    os.environ["PYTHONHASHSEED"] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,  # same seed on every GPU
    )
    for apply_seed in seeders:
        apply_seed(seed)

    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # Additionally: pass a worker_init_fn to the DataLoader when worker processes are used.


## 1. Data Load

#### 1.1 Data Load

In [None]:
# Directory holding the WAV files; the annotation .txt files are read from the same directory
data_dir = ROOT
txt_dir = ROOT

# Read the tab-separated split file and name its two columns
df = pd.read_csv(text, sep='\t', header=None)
df.columns = ['filename', 'set']

# Separate the rows by their 'set' value
split_frames = {split: df[df['set'] == split] for split in ('train', 'test')}
train_df = split_frames['train']
test_df = split_frames['test']

# Sorted file-name lists of each split
train_list = sorted(train_df['filename'].tolist())
test_list = sorted(test_df['filename'].tolist())

print(f'Train :{len(train_list)}, Test: {len(test_list)}, Total: {len(train_list) + len(test_list)}')

#### 1.2 Pretext-Finetune Split

In [None]:
# Shuffle the training rows reproducibly
df_shuffled = train_df.sample(frac=1, random_state=42)

# The pretraining subset currently spans every shuffled row (nothing is held out)
train_size = len(df_shuffled)
pretrain_df = df_shuffled.iloc[:train_size]

# Sorted file names of the pretraining subset
pretrain_list = sorted(pretrain_df['filename'].tolist())

# Patient id = the integer before the first underscore of each file name
pretrain_patient_list = [int(name.split('_')[0]) for name in pretrain_list]

# Number of recordings per patient
pretrain_patient_counts = pd.Series(pretrain_patient_list).value_counts()

print(f"[Pretrain] 환자 수: {len(pretrain_patient_counts.index)}, 샘플 수: {pretrain_patient_counts.sum()}")

## 2. Data Preprocessing

#### 2.1 Args

        K: queue size; number of negative keys (MoCo default: 65536; this notebook sets K = 512 in Args)
        m: MoCo momentum for updating the key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)

In [None]:
class Args:
    """Hyper-parameters and settings shared by the cells of this notebook."""

    # Audio & Spectrogram
    target_sr = 16000    # target sample rate in Hz (16 kHz)
    frame_size = 1024
    hop_length = 512    # half of frame_size
    n_mels = 64
    target_sec = 8

    # Augmentation
    time_mask_param = 0.5
    freq_mask_param = 0.5

    # Train
    lr = 1e-3 # adamw - 0.03
    warm = True                     # whether learning-rate warm-up is used
    warm_epochs = 10                # number of initial epochs that use warm-up
    warmup_from = lr * 0.1          # learning rate at the start of warm-up (10% of lr)
    warmup_to = lr

    batch_size = 128
    workers = 2
    epochs = 300
    weight_decay = 0.0

    resume = None
    schedule=[120, 160] # schedule

    # MLS
    K = 512
    momentum = 0.999
    T = 0.07
    dim_prj = 128
    top_k = 20
    lambda_bce = 0.3
    out_dim = 512

    # Linear Evaluation
    # ft_epochs = 3

    # etc
    gpu = 0
    data = "./data_path"
    seed=42
    num_classes = 2

    # update
    ma_update = False
    ma_beta = 0.5
    target_type = 'grad_flow'
    alpha = 0.3


# Shared settings instance; later definitions read their default arguments from it.
args = Args()

#### 2.2 Utils (func)

In [None]:
import torch.nn.functional as F
import random

# Derive the class id of a respiratory cycle from its crackle / wheeze flags
def get_class(cr, wh):
    """Map the (crackle, wheeze) flags of a cycle to one class id.

    Returns:
        3 for crackle + wheeze, 2 for wheeze only, 1 for crackle only,
        0 for neither, and -1 for any other combination of values.
    """
    class_by_flags = (
        ((1, 1), 3),
        ((0, 1), 2),
        ((1, 0), 1),
        ((0, 0), 0),
    )
    for (crackle, wheeze), class_id in class_by_flags:
        if cr == crackle and wh == wheeze:
            return class_id
    return -1

# Build a dB-scaled mel spectrogram with a fixed number of frames (256 by default)
def generate_mel_spectrogram(waveform, sample_rate, frame_size, hop_length, n_mels, target_frames=256):
    """Convert a waveform into a dB mel spectrogram of fixed width.

    Args:
        waveform: Audio tensor passed to ``T.MelSpectrogram``.
        sample_rate: Sample rate of ``waveform``.
        frame_size: FFT size (``n_fft``).
        hop_length: Hop between frames; ``None`` selects ``frame_size // 2``.
        n_mels: Number of mel bands.
        target_frames: Width of the returned spectrogram along the last axis.
            Defaults to 256, the value that used to be hard-coded.

    Returns:
        The dB spectrogram, zero-padded on both sides or centre-cropped along
        the last axis so that it has exactly ``target_frames`` frames.
    """
    if hop_length is None:
        hop_length = frame_size // 2

    mel_spec_transform = T.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=frame_size,
        hop_length=hop_length,
        n_mels=n_mels,
        f_min=50,
        f_max=2000
    )
    mel_spectrogram = mel_spec_transform(waveform)
    mel_db = T.AmplitudeToDB()(mel_spectrogram)

    # Values at or below -100 dB are overwritten with 0.0
    mel_db[mel_db <= -100.0] = 0.0

    current_frames = mel_db.shape[-1]
    if current_frames < target_frames:
        # Centre padding: split the missing frames between both sides (extra frame on the right)
        pad_total = target_frames - current_frames
        pad_left = pad_total // 2
        pad_right = pad_total - pad_left
        mel_db = F.pad(mel_db, (pad_left, pad_right))
    elif current_frames > target_frames:
        # Centre crop; the ellipsis keeps this valid for any number of leading axes
        start = (current_frames - target_frames) // 2
        mel_db = mel_db[..., start:start + target_frames]

    return mel_db

# Repeat or crop a cycle so that it has a fixed number of frames
def repeat_or_truncate_segment(mel_segment, target_frames):
    """Tile ``mel_segment`` along its last axis until it is long enough, then cut it.

    Args:
        mel_segment: 3-D tensor whose last axis is the frame axis.
        target_frames: Number of frames to return.

    Returns:
        The first ``target_frames`` frames of the (possibly tiled) segment.
    """
    n_frames = mel_segment.shape[-1]
    if n_frames < target_frames:
        copies = math.ceil(target_frames / n_frames)
        mel_segment = mel_segment.repeat(1, 1, copies)
    return mel_segment[:, :, :target_frames]

def preprocess_waveform_segment(waveform, unit_length):
    """Bring a waveform to exactly ``unit_length`` samples.

    * Long enough: crop ``unit_length`` samples, starting at a random offset
      drawn from the first quarter of the surplus.
    * Short by less than half of ``unit_length``: zero-pad on both sides.
    * Shorter than that: repeat the waveform a whole number of times, then
      zero-pad whatever is still missing on both sides.

    Args:
        waveform: Tensor whose leading axis is removed with ``squeeze(0)``.
        unit_length: Target length in samples.

    Returns:
        The adjusted waveform with the leading axis restored.
    """
    samples = waveform.squeeze(0)
    shortfall = unit_length - len(samples)

    if shortfall <= 0:
        surplus = len(samples) - unit_length
        offset = random.randint(0, surplus // 4)
        return samples[offset:offset + unit_length].unsqueeze(0)

    if shortfall >= unit_length // 2:
        # Far too short: tile first, then pad only the remainder
        copies = unit_length // len(samples)
        samples = samples.repeat(copies)[:unit_length]
        shortfall = unit_length - len(samples)

    pad_left = shortfall // 2
    padded = F.pad(samples, (pad_left, shortfall - pad_left))
    return padded.unsqueeze(0)


# SpecAugment: random masking of up to 80% of an axis
def apply_spec_augment(mel_segment):
    """Create three masked copies of a mel spectrogram.

    torchaudio's masking transforms draw the mask length uniformly between 0
    and the given parameter; here that parameter is 80% of the axis length.

    Returns:
        ``(freq_masked, time_masked, time_and_freq_masked)``; the input itself
        is not modified because every transform runs on a clone.
    """
    n_freq, n_time = mel_segment.shape[-2], mel_segment.shape[-1]

    mask_time = T.TimeMasking(time_mask_param=int(n_time * 0.8))
    mask_freq = T.FrequencyMasking(freq_mask_param=int(n_freq * 0.8))

    freq_only = mask_freq(mel_segment.clone())
    time_only = mask_time(mel_segment.clone())
    time_then_freq = mask_freq(mask_time(mel_segment.clone()))

    return freq_only, time_only, time_then_freq

# Resample a waveform to the target sample rate
def resample_waveform(waveform, orig_sr, target_sr=args.target_sr):
    """Resample ``waveform`` from ``orig_sr`` to ``target_sr`` when they differ.

    Returns:
        ``(waveform, sample_rate)``: the input and ``orig_sr`` unchanged when
        the rates already match, otherwise the resampled waveform and
        ``target_sr``.
    """
    if orig_sr == target_sr:
        return waveform, orig_sr

    to_target = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
    return to_target(waveform), target_sr

# Normalize - Mean/Std
# def get_mean_and_std(dataset):
#     """ 전체 mel-spectrogram에서 mean과 std 계산 """
#     dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

#     cnt = 0
#     fst_moment = torch.zeros(1)
#     snd_moment = torch.zeros(1)
#     for inputs, _, _ in tqdm(dataloader, desc="[Calculating Mean/Std]"):
#         b, c, h, w = inputs.shape  # inputs: [1, 1, n_mels, time]
#         nb_pixels = b * h * w

#         fst_moment += torch.sum(inputs, dim=[0, 2, 3])
#         snd_moment += torch.sum(inputs**2, dim=[0, 2, 3])
#         cnt += nb_pixels

#     mean = fst_moment / cnt
#     std = torch.sqrt(snd_moment / cnt - mean**2)
#     return mean.item(), std.item()

def get_mean_and_std(dataset, mask_threshold=-99.0):
    """Compute mean and standard deviation over all values above ``mask_threshold``.

    Args:
        dataset: Dataset whose items are 3-tuples with the spectrogram first.
        mask_threshold: Only values strictly greater than this are counted.

    Returns:
        ``(mean, std)`` of the counted values.

    Raises:
        ValueError: If no value in the dataset exceeds ``mask_threshold``
            (previously this surfaced as a ZeroDivisionError).
    """
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    cnt = 0
    fst_moment = 0.0
    snd_moment = 0.0

    for inputs, _, _ in tqdm(dataloader, desc="[Calculating Mean/Std]"):
        # Keep only the values above the threshold (flattened to 1-D)
        valid = inputs[inputs > mask_threshold]

        fst_moment += valid.sum().item()
        snd_moment += (valid ** 2).sum().item()
        cnt += valid.numel()

    if cnt == 0:
        raise ValueError(
            f"No values above mask_threshold={mask_threshold}; cannot compute mean/std."
        )

    mean = fst_moment / cnt
    # E[x^2] - E[x]^2 can dip just below zero through rounding; clamp so sqrt never yields NaN
    variance = max(snd_moment / cnt - mean ** 2, 0.0)
    std = np.sqrt(variance)
    return mean, std

In [None]:
def get_timestamp():
    """Return the current Asia/Seoul time formatted as YYMMDDHHMM, e.g. 2404070830."""
    return datetime.now(ZoneInfo("Asia/Seoul")).strftime('%y%m%d%H%M')

#### 2.3 CycleDataset

In [None]:
import os
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm

class CycleDataset(Dataset):
    """Dataset of respiratory cycles, fully preprocessed in ``__init__``.

    For every file name, ``<name>.txt`` (one row per cycle: start, end,
    crackle flag, wheeze flag) and ``<name>.wav`` are read. Each cycle is cut
    from the mono, resampled waveform, brought to ``target_sec`` seconds and
    converted to a mel spectrogram.

    Each item is ``(mel, multi_label, meta_data)`` where ``multi_label`` is a
    2-element float tensor ``[crackle, wheeze]`` and ``meta_data`` is
    ``(filename, cycle_duration)``.

    Args:
        filename_list: File names without extension.
        wav_dir: Directory containing the ``.wav`` files.
        txt_dir: Directory containing the ``.txt`` annotation files.
        target_sec: Length every cycle is normalised to, in seconds.
        target_sr: Sample rate the audio is resampled to.
        frame_size: FFT size of the mel spectrogram.
        hop_length: Hop length of the mel spectrogram.
        n_mels: Number of mel bands.
        mean: Optional mean for normalisation (used only together with ``std``).
        std: Optional standard deviation for normalisation.
    """

    def __init__(self, filename_list, wav_dir, txt_dir, target_sec=args.target_sec, target_sr=args.target_sr, frame_size=args.frame_size, hop_length=args.hop_length, n_mels=args.n_mels, mean=None, std=None):
        self.filename_list = filename_list
        self.wav_dir = wav_dir
        self.txt_dir = txt_dir
        self.target_sec = target_sec
        self.target_sr = target_sr
        self.frame_size = frame_size
        self.hop_length = hop_length
        self.n_mels = n_mels
        self.mean = mean
        self.std = std

        self.cycle_list = []

        # Every cycle is normalised to this many samples
        unit_length = int(self.target_sec * self.target_sr)

        print("[INFO] Preprocessing cycles...")
        for filename in tqdm(self.filename_list):
            txt_path = os.path.join(self.txt_dir, filename + '.txt')
            wav_path = os.path.join(self.wav_dir, filename + '.wav')

            # Warn about missing files and skip the recording instead of crashing on load
            missing_paths = [path for path in (txt_path, wav_path) if not os.path.exists(path)]
            if missing_paths:
                for path in missing_paths:
                    print(f"[WARNING] Missing file: {path}")
                continue

            # Load the annotation once; ndmin=2 keeps a single-cycle file two-dimensional
            annotation = np.loadtxt(txt_path, usecols=(0, 1, 2, 3), ndmin=2)
            cycle_data = annotation[:, :2]    # start / end
            lung_label = annotation[:, 2:4]   # crackle / wheeze flags

            # Load waveform
            waveform, orig_sr = torchaudio.load(wav_path)
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)  # Stereo to mono
                print(' waveform.shape[0] > 1:')

            # Resample to the target sample rate
            waveform, sample_rate = resample_waveform(waveform, orig_sr, self.target_sr)

            for idx in range(len(cycle_data)):
                # Start / end of the breathing cycle, converted to sample indices
                start_sample = int(cycle_data[idx, 0] * sample_rate)
                end_sample = int(cycle_data[idx, 1] * sample_rate)
                lung_duration = cycle_data[idx, 1] - cycle_data[idx, 0]

                if end_sample <= start_sample:
                    print('end_sample <= start_sample:')
                    continue  # skip an invalid interval

                # Length-normalise the waveform, then compute the dB mel spectrogram
                cycle_wave = waveform[:, start_sample:end_sample]
                seg_wave = preprocess_waveform_segment(cycle_wave, unit_length=unit_length)
                mel = generate_mel_spectrogram(seg_wave, sample_rate, frame_size=self.frame_size, hop_length=self.hop_length, n_mels=self.n_mels)

                # Normalisation
                if self.mean is not None and self.std is not None:
                    # NOTE(review): generate_mel_spectrogram already overwrites values <= -100 dB
                    # with 0.0, so this mask looks like it can never match — confirm the intent.
                    mask_value = -100.0
                    mask = (mel == mask_value)
                    mel = (mel - self.mean) / self.std
                    mel[mask] = 0.0

                # crackle, wheeze -> class
                cr = int(lung_label[idx, 0])
                wh = int(lung_label[idx, 1])
                label = get_class(cr, wh)

                # Multi-label target: [crackle present, wheeze present]
                multi_label = torch.tensor([
                    float(label in [1, 3]),
                    float(label in [2, 3])
                ])

                # meta_data
                meta_data = (filename, lung_duration)

                self.cycle_list.append((mel, multi_label, meta_data))

        print(f"[INFO] Total cycles collected: {len(self.cycle_list)}")

    def __len__(self):
        """Return the number of collected cycles."""
        return len(self.cycle_list)

    def __getitem__(self, idx):
        """Return ``(mel, multi_label, meta_data)`` of the cycle at ``idx``."""
        mel, label, meta_data = self.cycle_list[idx]
        return mel, label, meta_data

##### Pickle.dump

CycleDataset 객체 생성

In [None]:
len(train_list)

In [None]:
# # # import random
# # # import matplotlib.pyplot as plt
# # # import librosa.display

# # # wav_dir = ROOT
# # # txt_dir = ROOT

# # # # 1. Dataset 로드
# # # train_dataset = CycleDataset(train_list, wav_dir, txt_dir)
# # # test_dataset = CycleDataset(test_list, wav_dir, txt_dir)

# # ################################################################

# import random
# import matplotlib.pyplot as plt
# import librosa.display

# wav_dir = ROOT
# txt_dir = ROOT

# # # mean, std 먼저 계산
# # normless_dataset = CycleDataset(train_list, wav_dir, txt_dir)
# # mean, std = get_mean_and_std(normless_dataset)

# # 정규화 적용된 데이터셋 생성
# train_dataset = CycleDataset(train_list, wav_dir, txt_dir)
# test_dataset = CycleDataset(test_list, wav_dir, txt_dir)

# pickle_dict = {
#     'train_dataset': train_dataset,
#     'test_dataset': test_dataset
# }

# save_path = os.path.join(PICKLE_PATH, '0721_MLS_datasets.pkl')
# with open(save_path, 'wb') as f:
#     pickle.dump(pickle_dict, f)



pickle로 train_dataset, test_dataset 외부 저장

In [None]:
# pickle_name = f'Aug_Moco_MLS_MelSpec_{args.target_sr//1000}kHz_{args.frame_size}win_{args.hop_length}hop_{args.n_mels}mel_{args.target_sec}s'

In [None]:
# pickle_dict = {
#     'train_dataset': train_dataset,
#     'test_dataset': test_dataset
# }

# save_path = os.path.join(PICKLE_PATH, '3:7_saved_datasets_multilabel.pkl')
# with open(save_path, 'wb') as f:
#     pickle.dump(pickle_dict, f)

# # #####

# # 🔹 mean, std 함께 저장
# pickle_dict = {
#     'train_dataset': train_dataset,
#     'test_dataset': test_dataset,
#     'mean': mean,
#     'std': std
# }
# with open(os.path.join(PICKLE_PATH, 'pad0_norm_saved_datasets_multilabel.pkl'), 'wb') as f:
#     pickle.dump(pickle_dict, f)

# print(f'mean: {mean}, std: {std}')

In [None]:
# # 2. 간단 통계
# print(f"Total cycles: {len(train_dataset)}")

# label_counter = [0] * 4  # normal, crackle, wheeze, both
# for _, multi_label,_ in train_dataset:
#     if torch.equal(multi_label, torch.tensor([0., 0.])):
#         label_counter[0] += 1
#     elif torch.equal(multi_label, torch.tensor([1., 0.])):
#         label_counter[1] += 1
#     elif torch.equal(multi_label, torch.tensor([0., 1.])):
#         label_counter[2] += 1
#     elif torch.equal(multi_label, torch.tensor([1., 1.])):
#         label_counter[3] += 1

# for idx, count in enumerate(label_counter):
#     print(f"Class {idx}: {count} cycles")

##### Pickle.load
저장된 train_dataset, test_dataset을 로드  
(> Aug 는 Moco 모델에서 사용)

In [None]:
# Load the pickled train / test datasets from PICKLE_PATH.
# NOTE(review): pickle.load can execute arbitrary code; only load files from a trusted source.
save_path = os.path.join(PICKLE_PATH, 'Mix_MLATT_datasets.pkl')
with open(save_path, 'rb') as f:
    pickle_dict = pickle.load(f)

train_dataset = pickle_dict['train_dataset']
test_dataset = pickle_dict['test_dataset']

print(f"[Train] Cycles: {len(train_dataset)}")
print(f"[Test] Cycles: {len(test_dataset)}")

###################

# save_path = os.path.join(PICKLE_PATH, 'pad0_norm_saved_datasets_multilabel.pkl')
# # 🔹 load with normalization values
# with open(save_path, 'rb') as f:
#     pickle_dict = pickle.load(f)

# train_dataset = pickle_dict['train_dataset']
# test_dataset = pickle_dict['test_dataset']
# mean = pickle_dict['mean']
# std = pickle_dict['std']

# print(f"[Train] Cycles: {len(train_dataset)}")
# print(f"[Test] Cycles: {len(test_dataset)}")
# print(f"[INFO] Loaded mean={mean:.4f}, std={std:.4f}")

In [None]:
train_dataset[0][0].shape

In [None]:
import matplotlib.pyplot as plt

# Take the mel spectrogram of the first training cycle
mel = train_dataset[0][0]  # expected shape (1, 64, 256) — TODO confirm

# Apply the augmentation. No function named `aug` is defined before this cell, so the
# three-output helper apply_spec_augment is called; its first output is the
# frequency-masked view.
aug_speconly, _, _ = apply_spec_augment(mel)

# Plotting helper
def show_mel(mel_tensor, title):
    """Display a mel spectrogram as an image.

    Args:
        mel_tensor: Tensor; a 4-D input has its leading axis squeezed first,
            then the leading axis is squeezed once more before plotting.
        title: Title of the figure.
    """
    spec = mel_tensor.squeeze(0) if mel_tensor.ndim == 4 else mel_tensor
    image = spec.squeeze(0).cpu().numpy()

    plt.figure(figsize=(8, 4))
    plt.imshow(image, origin='lower', aspect='auto', cmap='magma')
    plt.colorbar()
    plt.title(title)
    plt.xlabel("Time")
    plt.ylabel("Mel Frequency")
    plt.tight_layout()
    plt.show()

# Plot the original and the augmented spectrogram
show_mel(mel, "Original Mel")
show_mel(aug_speconly, "Augmented Mel (Spec Only)")


#### 2.4 DataLoader

In [None]:
# ---------------- Build the pretraining index list (seeded) ----------------
seed_everything(args.seed)

# Indices of the train_dataset cycles whose source file is listed in pretrain_list.
# Membership is tested against a set built once instead of scanning the list for every cycle.
pretrain_files = set(pretrain_list)
pretrain_idx = [
    i
    for i in range(len(train_dataset))
    if train_dataset[i][2][0] in pretrain_files
]

# Shuffle the index order once
random.shuffle(pretrain_idx)

print(f"Pretrain set size: {len(pretrain_idx)}")

코드 실행 환경에 따라 num_workers를 적절한 값으로 지정해주세요!

In [None]:
# Create the dataset (Subset)
pretrain_dataset = Subset(train_dataset, pretrain_idx)

# Create the DataLoaders
# With shuffle=True the order would change every epoch, which hurts reproducibility.
# pretrain_dataset is built from indices that were already shuffled, so it is fed in as-is.
pretrain_loader = DataLoader(
    pretrain_dataset,
    batch_size=args.batch_size,
    num_workers=0,
    drop_last=True,
    pin_memory=True,
    shuffle=False
)

test_loader = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    num_workers=0,
    pin_memory=True,
    shuffle=False
)

label 분포 확인 (단순 참고용, 실제 환경에서는 pretrain set의 label 분포가 어떤지 알 수 없음)

In [None]:
from collections import Counter

# Stack the multi-label targets of every training cycle
labels = torch.stack([multi_label for _, multi_label, _ in train_dataset])

# Label distribution of the pretraining subset
pretrain_labels = labels[pretrain_idx]
pretrain_labels_class = (
    pretrain_labels[:, 0].long() * 1 +  # crackle bit → *1
    pretrain_labels[:, 1].long() * 2    # wheeze bit  → *2
)  # [N] shape, values in {0, 1, 2, 3}


# Label distribution of the test dataset
test_labels = torch.stack([multi_label for _, multi_label, _ in test_dataset])
test_labels_class = (
    test_labels[:, 0].long() * 1 +  # crackle bit → *1
    test_labels[:, 1].long() * 2    # wheeze bit  → *2
)  # [N] shape, values in {0, 1, 2, 3}

print(f"Pretrain sample: {len(pretrain_labels_class)}")
print("Pretrain label distribution:", Counter(pretrain_labels_class.tolist()))
print(f"Test sample: {len(test_labels_class)}")
print("Test label distribution:", Counter(test_labels_class.tolist()))

## 3. Modeling

#### 3.1 Pre-trained ResNet50

In [None]:
from torchvision.models import resnet50, ResNet50_Weights

class ResNet50(torchvision.models.resnet.ResNet):
    """ResNet-50 feature extractor for single-channel input.

    Differences from the torchvision model built here: the ``fc`` layer is
    removed, ``conv1`` is replaced by a 1-input-channel convolution, and the
    forward pass returns the flattened pooled features (``final_feat_dim`` =
    2048).

    Args:
        track_bn: Forwarded as ``track_running_stats`` to every BatchNorm2d.
    """

    def __init__(self, track_bn=True):
        def make_batch_norm(*bn_args, **bn_kwargs):
            return nn.BatchNorm2d(*bn_args, **bn_kwargs, track_running_stats=track_bn)

        super().__init__(
            torchvision.models.resnet.Bottleneck,
            [3, 4, 6, 3],
            norm_layer=make_batch_norm,
        )
        del self.fc
        self.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.final_feat_dim = 2048

    def load_sl_official_weights(self, progress=True):
        """Load torchvision's default ResNet-50 weights, except for ``conv1``.

        ``conv1.weight`` is dropped from the downloaded state dict (the local
        ``conv1`` takes one input channel) and the rest is loaded non-strictly.
        """
        state_dict = ResNet50_Weights.DEFAULT.get_state_dict(progress=progress)
        state_dict.pop('conv1.weight')
        self.load_state_dict(state_dict, strict=False)

    def _forward_impl(self, x: Tensor) -> Tensor:
        """Run stem, the four residual stages and pooling; return flattened features."""
        stem = (self.conv1, self.bn1, self.relu, self.maxpool)
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for module in (*stem, *stages, self.avgpool):
            x = module(x)
        # The classification layer is intentionally not applied
        return torch.flatten(x, 1)

In [None]:
def backbone_resnet50_patch():
    """Return a freshly constructed ``ResNet50`` backbone (no weights are loaded here)."""
    return ResNet50()

#### 3.2 Pre-trained CNN6

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F


def init_layer(layer):
    """Xavier-uniform initialise a layer's weight and zero its bias, if it has one."""
    nn.init.xavier_uniform_(layer.weight)
    bias = getattr(layer, 'bias', None)
    if bias is not None:
        bias.data.fill_(0.)
            

def init_bn(bn):
    """Reset a batch-norm layer to bias 0 and weight 1."""
    for param, value in ((bn.bias, 0.), (bn.weight, 1.)):
        param.data.fill_(value)


class ConvBlock5x5(nn.Module): #for CNN6
    """5x5 convolution + batch norm + ReLU, followed by a selectable 2-D pooling."""

    def __init__(self, in_channels, out_channels, stride=(1,1)):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(5, 5),
            stride=stride,
            padding=(2, 2),
            bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Initialise the convolution and the batch-norm layer."""
        init_layer(self.conv1)
        init_bn(self.bn1)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply conv → BN → ReLU and pool the result.

        Args:
            input: Input feature map.
            pool_size: Pooling kernel size.
            pool_type: ``'avg'``, ``'max'`` or ``'avg+max'`` (sum of both).

        Raises:
            Exception: If ``pool_type`` is none of the supported values.
        """
        activated = F.relu_(self.bn1(self.conv1(input)))

        if pool_type == 'avg':
            return F.avg_pool2d(activated, kernel_size=pool_size)
        if pool_type == 'max':
            return F.max_pool2d(activated, kernel_size=pool_size)
        if pool_type == 'avg+max':
            avg_pooled = F.avg_pool2d(activated, kernel_size=pool_size)
            max_pooled = F.max_pool2d(activated, kernel_size=pool_size)
            return avg_pooled + max_pooled
        raise Exception('Incorrect argument!')


class CNN6(nn.Module):
    """Four-block 5x5 CNN backbone producing 512-dimensional features.

    ``forward`` returns either the feature map after the fourth block or a
    pooled vector (mean over the last axis, then max + mean over the
    remaining spatial axis). Dropout is only applied when ``do_dropout`` is
    set to True.
    """

    def __init__(self):
        super().__init__()
        self.final_feat_dim = 512

        self.do_dropout = False
        self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64, stride=(1,1))
        self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128, stride=(1,1))
        self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256, stride=(1,1))
        self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512, stride=(1,1))
        self.dropout = nn.Dropout(0.2)
        # self.linear = nn.Linear(512, num_classes, bias=True)

    def load_sl_official_weights(self, weights_path='/home/sbw/boaz/notebook/Cnn6_mAP=0.343.pth'):
        """Load the AudioSet-pretrained CNN6 weights.

        The checkpoint can be downloaded from
        https://zenodo.org/record/3960586#.Y8dz8y_kEiY

        Args:
            weights_path: Location of the checkpoint file; defaults to the
                path that used to be hard-coded.
        """
        # map_location='cpu' lets a checkpoint saved on a GPU load on a CPU-only host;
        # load_state_dict then copies the values into this module's own parameters.
        # NOTE(review): torch.load unpickles the file — only use checkpoints from a trusted source.
        weights = torch.load(weights_path, map_location='cpu')['model']
        own_keys = set(self.state_dict().keys())  # built once, not once per checkpoint key
        state_dict = {k: v for k, v in weights.items() if k in own_keys}
        self.load_state_dict(state_dict, strict=False)

    def forward(self, x, return_feature_map=False):
        """Encode ``x``.

        Args:
            x: Input batch with one channel.
            return_feature_map: When True, return the feature map after the
                last block instead of the pooled vector.

        Returns:
            The feature map (for a (B, 1, 64, 256) input: (B, 512, 4, 16)),
            or the pooled (B, 512) feature vector.
        """
        blocks = (self.conv_block1, self.conv_block2, self.conv_block3, self.conv_block4)
        for block in blocks:
            x = block(x, pool_size=(2, 2), pool_type='avg')
            if self.do_dropout:
                x = self.dropout(x)

        if return_feature_map:
            return x

        x = torch.mean(x, dim=3)  # mean over time dim
        (x1, _) = torch.max(x, dim=2)  # max over freq dim
        x2 = torch.mean(x, dim=2)  # mean over freq dim (after mean over time)
        return x1 + x2

In [None]:
def backbone_cnn6():
    """Build the CNN6 backbone used in place of ResNet50.

    The backbone's output feature dimension is 512 (``CNN6.final_feat_dim``).
    The pretrained CNN6 weights are loaded before the model is returned;
    remove the ``load_sl_official_weights`` call to start from the random
    initialisation instead.
    """
    encoder = CNN6()
    encoder.load_sl_official_weights()
    return encoder


##### 3.3 Multilabel Attention

In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F


# class MultilabelAttention(nn.Module):
#     def __init__(self, backbone, num_classes=2, lambda_attn=0.5, attention_heads=[1, float('inf')]):
#         super(MultilabelAttention, self).__init__()
#         self.backbone = backbone()
#         self.num_classes = num_classes
#         self.lambda_attn = lambda_attn
#         self.attention_heads = attention_heads  # e.g., [1, inf] for H=2

#         self.class_weights = nn.Parameter(torch.randn(len(attention_heads), num_classes, 512))

#         self.output_layer = nn.ModuleList([
#             nn.Linear(512, 1) for _ in range(num_classes)
#         ])

#     def forward(self, x):
#         # CNN6 백본 통과 → shape: (B, 512, 4, 16)
#         feat_map = self.backbone(x, return_feature_map=True)  # (B, 512, 4, 16)

#         B, C, Freq, Time = feat_map.shape
#         flat_feat = feat_map.view(B, C, Freq * Time).permute(0, 2, 1)  # (B, 64, 512)

#         # Class-specific attention a_i 계산
#         attn_outputs = []
#         for h, T in enumerate(self.attention_heads):
#             Ci = self.class_weights[h]  # (num_classes, 512)
#             logits = torch.einsum("bnc, kc -> bnk", flat_feat, Ci)  # (B, 64, num_classes)
#             logits = logits.permute(0, 2, 1)  # (B, num_classes, 64)
#             if T == float('inf'):
#                 attn_scores = F.one_hot(torch.argmax(logits, dim=2), num_classes=logits.shape[2]).float()
#             else:
#                 attn_scores = F.softmax(T * logits, dim=2)  # (B, num_classes, 64)

#             attn_scores = attn_scores.unsqueeze(-1)  # (B, num_classes, 64, 1)
#             flat_feat_exp = flat_feat.unsqueeze(1)  # (B, 1, 64, 512)
#             attn_feat = torch.sum(attn_scores * flat_feat_exp, dim=2)  # (B, num_classes, 512)
#             attn_outputs.append(attn_feat)

#         # Class-specific global feature g_i 계산
#         feat_avg_t = torch.mean(feat_map, dim=3)  # (B, 512, Freq)
#         gmp = torch.max(feat_avg_t, dim=2)[0]     # (B, 512)
#         gap = torch.mean(feat_avg_t, dim=2)       # (B, 512)
#         g = gmp + gap                             # (B, 512)
#         g = g.unsqueeze(1).repeat(1, self.num_classes, 1)  # (B, num_classes, 512)

#         # Combine: f_i = g_i + lambda * a_i
#         combined = g
#         for attn in attn_outputs:
#             combined = combined + self.lambda_attn * attn  # sum over heads

#         # Output layer for each class
#         out = []
#         for i in range(self.num_classes):
#             cls_feat = combined[:, i, :]  # (B, 512)
#             logit = self.output_layer[i](cls_feat).squeeze(-1)  # (B,)
#             out.append(logit)

#         logits = torch.stack(out, dim=1)  # (B, num_classes)
#         probs = torch.sigmoid(logits)    # (B, num_classes)
#         return combined, logits, probs                 # 마지막 dim: (B, 2, 512) 


# def backbone_mlattention():
#     """
#     Multi-label attention 기반 backbone 정의 함수
#     CNN6 기반 특징 추출기 + CSRA 기반 attention 구조 결합
    
#     Returns:
#         nn.Module: Multi-label attention 기반 분류기
#     """
#     return MultilabelAttention(backbone=backbone_cnn6, num_classes=2, lambda_attn=0.5, attention_heads=[1, float('inf')])


In [None]:
# NOTE(review): backbone_mlattention only exists in the commented-out cell above, so this cell
# raises NameError when the notebook is run top to bottom. The commented-out forward also
# returns a (combined, logits, probs) tuple, so `out.shape` would not work as written either.
x = torch.randn(10, 1, 64, 256) # B=10
model = backbone_mlattention()
out = model(x)  # (B, 2, 512)

print(f"\ntorch.Size : {out.shape}")  # → torch.Size([B=10, 2, 512])

##### 3.4 Mix-MultiLabel Attention

In [None]:
import torch
import torch.nn.functional as F
import numpy as np

def group_mix(group_spec, labels, beta=1.0):
    """Swap a random subset of feature columns between samples of a batch.

    Adapted from CycleGuardian's ``group_mix``:
    https://github.com/chumingqian/CycleGuardian/blob/main/nets/CycleGuardian_v5_1_3.py

    A mixing ratio ``lam`` is drawn from Beta(beta, beta). ``int(D * (1 - lam))``
    indices of the last axis are picked at random, and for every sample those
    columns are replaced by the columns of a randomly chosen partner sample.

    NOTE(review): the indices are drawn along the last (feature, D) axis; the
    earlier comments described mixing along the group (N) axis — confirm
    which one is intended.

    Args:
        group_spec: Tensor of shape [B, N, D].
        labels: Labels indexed by sample, [B] or [B, C].
        beta: Parameter of the symmetric Beta distribution.

    Returns:
        ``(mixed_group_spec, labels, labels[index], lam_tensor, index)`` where
        ``lam_tensor`` has shape [B] filled with ``lam`` and ``index`` is the
        batch permutation that selected each sample's partner.
    """
    batch_size, _, feat_dim = group_spec.shape
    device = group_spec.device

    # Mixing ratio, and how many feature columns get replaced
    lam = np.random.beta(beta, beta)
    n_swap = int(feat_dim * (1. - lam))

    # Columns to replace (shared by the whole batch) and each sample's partner
    swap_cols = torch.randperm(feat_dim)[:n_swap].to(device)
    index = torch.randperm(batch_size).to(device)

    partner_spec = group_spec[index]
    mixed_group_spec = group_spec.clone()
    mixed_group_spec[..., swap_cols] = partner_spec[..., swap_cols]

    # One lam value per sample so it broadcasts in the loss
    lam_tensor = torch.full((batch_size,), lam, device=device)

    return mixed_group_spec, labels, labels[index], lam_tensor, index


In [None]:
class GroupMixConLoss(torch.nn.Module):
    """Contrastive loss between mixed and original projections with soft positives.

    Adapted from CycleGuardian's ``GroupMixConLoss``:
    https://github.com/chumingqian/CycleGuardian/blob/main/nets/CycleGuardian_v5_1_3.py
    """

    def __init__(self, temperature=0.07):
        super().__init__()
        self.temperature = temperature

    def forward(self, proj_orig, proj_mix, labels_a, labels_b, lam, index):
        """Compute the loss.

        For row i, the positives are the original sample i (weight ``lam[i]``)
        and the sample it was mixed with, ``index[i]`` (weight ``1 - lam[i]``).

        Args:
            proj_orig: [B, D] projections of the original samples.
            proj_mix: [B, D] projections of the mixed samples.
            labels_a: Not used by this implementation.
            labels_b: Not used by this implementation.
            lam: [B] mixing ratio of every sample.
            index: [B] index of the sample each one was mixed with.

        Returns:
            Scalar loss tensor.
        """
        batch_size = proj_orig.size(0)
        device = proj_orig.device

        # L2-normalise both sets and compare every mixed row with every original row
        anchors = F.normalize(proj_mix, dim=1)
        targets = F.normalize(proj_orig, dim=1)
        sim_matrix = torch.matmul(anchors, targets.T) / self.temperature  # [B, B]

        # Positive masks: the sample itself, and its mixing partner
        self_mask = torch.eye(batch_size, device=device)
        partner_mask = torch.zeros_like(self_mask)
        partner_mask[torch.arange(batch_size), index] = 1

        # Soft positive mask = lam * self + (1 - lam) * partner
        lam_col = lam.view(-1, 1)
        soft_mask = lam_col * self_mask + (1 - lam_col) * partner_mask

        # Subtract the row maximum for numerical stability, then take the log-softmax
        row_max, _ = torch.max(sim_matrix, dim=1, keepdim=True)
        shifted = sim_matrix - row_max.detach()
        log_prob = shifted - torch.log(torch.exp(shifted).sum(dim=1, keepdim=True))

        # Weighted mean log-probability of the positives of every row
        mean_log_prob_pos = (soft_mask * log_prob).sum(dim=1) / soft_mask.sum(dim=1)

        return -mean_log_prob_pos.mean()


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MixMLATT(nn.Module):
    """
    Backbone followed by class-specific attention pooling, with optional group mixing.

    The backbone feature map is flattened into a sequence of patch features.
    For every attention head, each class scores all patches with its own
    weight vector and pools them into one feature per class. These pooled
    features are added (scaled by ``lambda_attn``) to a global max+mean pooled
    feature that is shared by all classes.

    References:
        Multi-label class-specific method - https://arxiv.org/abs/2407.10828
        CSRA - https://github.com/Kevinz-code/CSRA/blob/master/pipeline/csra.py
    """

    def __init__(self, backbone, num_classes=2, lambda_attn=0.5,
                 attention_heads=(1, float('inf')), projector_dim=128, feat_dim=512):
        """
        Args:
            backbone: zero-argument callable returning the feature extractor.
                Its forward must accept ``return_feature_map=True``.
            num_classes: number of class-specific output features.
            lambda_attn: weight of each attention head's pooled feature.
            attention_heads: one softmax sharpness per head; ``float('inf')``
                selects a hard (argmax) head.
            projector_dim: accepted for backward compatibility; unused here.
            feat_dim: channel count of the backbone feature map.
        """
        super().__init__()
        self.num_classes = num_classes
        self.lambda_attn = lambda_attn
        # Copied so the module never shares a sequence object with the caller.
        self.attention_heads = list(attention_heads)
        self.feat_dim = feat_dim

        # Feature extractor (original note: CNN6 -> [B, 512, 4, 16]).
        self.backbone = backbone()

        # One weight vector per (head, class): [num_heads, num_classes, feat_dim].
        self.class_weights = nn.Parameter(
            torch.randn(len(self.attention_heads), num_classes, feat_dim)
        )

    def forward(self, x, mix_feature=False, patch_mix=False, y=None, lam=None, index=None):
        """
        Compute class-specific features, optionally mixing patch features first.

        Args:
            x: input batch passed to the backbone.
            mix_feature: when mixing is applied, also return a detached copy
                of the patch features (see NOTE below).
            patch_mix: apply ``group_mix`` to the patch features.
            y: labels forwarded to ``group_mix``; required when ``patch_mix``.
            lam, index: accepted for backward compatibility; when mixing is
                applied they are replaced by the values ``group_mix`` returns.

        Returns:
            ``(attn_feat, origin_feat)`` when ``patch_mix`` is False
            (``origin_feat`` is then always None), otherwise
            ``(attn_feat, origin_feat, label_origin, label_mix, lam, index)``.
            ``attn_feat`` has shape [B, num_classes, feat_dim].

        Raises:
            ValueError: if ``patch_mix`` is True but ``y`` is None.
        """
        if patch_mix and y is None:
            # Previously this skipped mixing and then failed at the return
            # statement with an UnboundLocalError.
            raise ValueError("patch_mix=True requires labels `y` for group_mix")

        batch = x.size(0)
        origin_feat = None

        # 1. Backbone feature map (original note: [B, 512, 4, 16] for CNN6).
        feat_map = self.backbone(x, return_feature_map=True)

        # 2. Flatten the spatial axes into a patch axis: [B, N, feat_dim].
        feat_flat = feat_map.view(batch, self.feat_dim, -1).permute(0, 2, 1)

        # 3. Optional patch-wise mixing.
        if patch_mix:
            feat_flat, label_origin, label_mix, lam, index = group_mix(feat_flat, y)

            # NOTE(review): despite its name this holds the features AFTER
            # mixing, because feat_flat was just overwritten - confirm intent.
            origin_feat = feat_flat.detach() if mix_feature else None

        # 4. Class-specific attention, one pooled feature per head.
        attn_outputs = []
        for head_idx, sharpness in enumerate(self.attention_heads):
            class_weight = self.class_weights[head_idx]                        # (num_classes, feat_dim)
            logits = torch.einsum("bnc,kc->bnk", feat_flat, class_weight)      # (B, N, num_classes)
            logits = logits.permute(0, 2, 1)                                   # (B, num_classes, N)

            if sharpness == float('inf'):
                # Hard attention: pick the single best-scoring patch per class.
                attn_scores = F.one_hot(torch.argmax(logits, dim=2), num_classes=logits.shape[2]).float()
            else:
                attn_scores = F.softmax(sharpness * logits, dim=2)

            attn_scores = attn_scores.unsqueeze(-1)                 # (B, num_classes, N, 1)
            feat_exp = feat_flat.unsqueeze(1)                       # (B, 1, N, feat_dim)
            attn_outputs.append(torch.sum(attn_scores * feat_exp, dim=2))  # (B, num_classes, feat_dim)

        # 5. Global feature: average the last axis, then max + mean over the
        #    remaining spatial axis, and share the result across classes.
        feat_avg = torch.mean(feat_map, dim=3)            # (B, feat_dim, H)
        gmp = torch.max(feat_avg, dim=2)[0]               # (B, feat_dim)
        gap = torch.mean(feat_avg, dim=2)                 # (B, feat_dim)
        g = (gmp + gap).unsqueeze(1).repeat(1, self.num_classes, 1)  # (B, num_classes, feat_dim)

        # 6. Combine global and attention features.
        attn_feat = g
        for attn in attn_outputs:
            attn_feat = attn_feat + self.lambda_attn * attn

        # 7. Return
        if not patch_mix:
            return attn_feat, origin_feat
        return attn_feat, origin_feat, label_origin, label_mix, lam, index


In [None]:
def backbone_mixmlatt():
    """
    Build the multi-label attention model with its fixed default configuration.

    Combines the CNN6 feature extractor with CSRA-style class-specific
    attention.

    Returns:
        nn.Module: a MixMLATT instance with two classes and two attention
        heads (sharpness 1 and a hard argmax head).
    """
    mlatt_config = dict(
        backbone=backbone_cnn6,
        num_classes=2,
        lambda_attn=0.5,
        attention_heads=[1, float('inf')],
    )
    return MixMLATT(**mlatt_config)


In [None]:
# Print a torchsummary report of the model for a (channels, height, width) input size.
summary(backbone_mixmlatt().to(device), input_size=(1, 64, 256))

## 4. Training

In [None]:
# Sanity check: shape of the first element of one batch drawn from the pretraining loader.
next(iter(pretrain_loader))[0].shape

In [None]:
def validate(model, classifier, projector_0, projector_1, val_loader, criterion, device, args, threshold=0.4):
    """
    Evaluate the multi-label classifier together with the GroupMix contrastive loss.

    Args:
        model: network returning ``(attn_feat, _)`` when called with
            ``patch_mix=False`` and a 6-tuple when called with ``patch_mix=True``.
        classifier: indexable collection of per-class heads; head ``i`` is
            applied to ``attn_feat[:, i, :]``.
        projector_0, projector_1: projection heads for class 0 and class 1.
        val_loader: iterable yielding ``(inputs, labels, _)`` batches.
        criterion: ``criterion[0]`` is the classification loss,
            ``criterion[1]`` the contrastive loss.
        device: device the batches are moved to.
        args: needs ``target_type`` (one of 'grad_block', 'grad_flow',
            'project_block', 'project_flow') and ``alpha`` (contrastive weight).
        threshold: sigmoid probability above which a label is predicted.

    Returns:
        ``(avg_loss, all_labels, all_preds)``: mean per-batch loss and the
        stacked labels and binary predictions as NumPy arrays.

    Raises:
        ValueError: if ``args.target_type`` is unknown or the loader is empty.
    """
    valid_target_types = ('grad_block', 'grad_flow', 'project_block', 'project_flow')
    if args.target_type not in valid_target_types:
        # Without this check the projections below would simply be unbound.
        raise ValueError(
            f"Unknown target_type {args.target_type!r}; expected one of {valid_target_types}"
        )

    model.eval()
    classifier.eval()
    projector_0.eval()
    projector_1.eval()

    running_loss = 0.0
    num_batches = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels, _ in val_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # 1. Forward pass on the un-mixed input.
            attn_feat, _ = model(inputs, mix_feature=True, patch_mix=False)

            # 2. One logit per class from that class's own head.
            logits = torch.stack(
                [classifier[i](attn_feat[:, i, :]).squeeze(-1) for i in range(attn_feat.shape[1])],
                dim=1,
            )

            # 3. Classification loss.
            loss_ce = criterion[0](logits, labels)

            # 4. Targets for the contrastive loss, depending on target_type.
            feat_0 = attn_feat[:, 0, :]
            feat_1 = attn_feat[:, 1, :]
            if args.target_type == 'grad_block':
                proj1_0 = feat_0.detach().clone()
                proj1_1 = feat_1.detach().clone()
            elif args.target_type == 'grad_flow':
                proj1_0 = feat_0
                proj1_1 = feat_1
            elif args.target_type == 'project_block':
                proj1_0 = projector_0(feat_0).detach()
                proj1_1 = projector_1(feat_1).detach()
            else:  # 'project_flow'
                proj1_0 = projector_0(feat_0)
                proj1_1 = projector_1(feat_1)

            # 5. Forward pass with patch mixing enabled.
            mix_attn_feat, origin_feat, label_origin, label_mix, lam, index = model(
                inputs, y=labels, patch_mix=True, mix_feature=True
            )

            # 6. Mixed features always go through the projectors.
            proj2_0 = projector_0(mix_attn_feat[:, 0, :])
            proj2_1 = projector_1(mix_attn_feat[:, 1, :])

            # 7. Contrastive loss per class.
            loss_con0 = criterion[1](proj1_0, proj2_0, label_origin, label_mix, lam, index)
            loss_con1 = criterion[1](proj1_1, proj2_1, label_origin, label_mix, lam, index)

            # 8. Total loss.
            loss = loss_ce + args.alpha * (loss_con0 + loss_con1)
            running_loss += loss.item()
            num_batches += 1

            # 9. Binary predictions.
            preds = (torch.sigmoid(logits) > threshold).int()
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    if num_batches == 0:
        raise ValueError("val_loader yielded no batches")

    all_preds = torch.cat(all_preds, dim=0).numpy()
    all_labels = torch.cat(all_labels, dim=0).numpy()

    avg_loss = running_loss / num_batches
    return avg_loss, all_labels, all_preds


In [None]:
import os
from torch.utils.data import DataLoader
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score
from copy import deepcopy
from torch.cuda.amp import GradScaler
import torch.nn as nn
import time
import torch
from torch.cuda.amp import GradScaler

"""
Refered by PatchMix - https://github.com/raymin0223/patch-mix_contrastive_learning/blob/main/main.py
(def train.py)
"""

# from utils.meters import AverageMeter
################################
# Per-epoch history; filled by the training loop and read by the plotting cell.
train_losses = []
test_losses = []
train_icbhi_scores = []
test_icbhi_scores = []
test_labels_all = []
test_preds_all = []
epochs = []

# Fix the seed before any model is created so weight initialisation is reproducible.
seed_everything(args.seed)

pretrain_project_name = f'Raw_MultilabelAtt_T_{args.batch_size}bs_{get_timestamp()}'

# -------------------------------------------wan
# Initialise wandb (project name, run name, logged hyper-parameters).
wandb.init(
    project="SBW_ICBHI_MLATT_all",  # project name
    name=f"{pretrain_project_name}",  # run name
    config={
        "epochs": args.epochs,
        "batch_size": args.batch_size,
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay
    }
)
# -------------------------------------------wan

################################
# 1. Model / Classifier
# Everything is placed on the file-level `device` instead of a hard-coded
# .cuda(), so the cell also runs on hosts without a GPU.
model = MixMLATT(backbone=backbone_cnn6,
                 num_classes=2,
                 lambda_attn=0.5,
                 attention_heads=[1, float('inf')]
                 ).to(device)

# One linear head per class, each producing a single logit.
classifier = nn.ModuleList([nn.Linear(args.out_dim, 1) for _ in range(2)]).to(device)


# 2. Projector 0/1
def _make_projector():
    """Return a two-layer MLP mapping args.out_dim features to args.dim_prj."""
    return nn.Sequential(
        nn.Linear(args.out_dim, args.out_dim),
        nn.ReLU(),
        nn.Linear(args.out_dim, args.dim_prj),
    )


projector_0 = _make_projector().to(device)
projector_1 = _make_projector().to(device)

# 3. EMA copies: frozen, eval-mode duplicates of every trainable module.
ema_model = deepcopy(model)
ema_projector_0 = deepcopy(projector_0)
ema_projector_1 = deepcopy(projector_1)
ema_classifier = deepcopy(classifier)
for m in [ema_model, ema_projector_0, ema_projector_1, ema_classifier]:
    m.eval()
    for p in m.parameters():
        p.requires_grad_(False)

# 4. criterion
criterion = [
    nn.BCEWithLogitsLoss().to(device),            # criterion[0]: classification
    GroupMixConLoss(temperature=0.07).to(device)  # criterion[1]: contrastive
]

# 5. optimizer: a single Adam over the model, both projectors and the classifier heads.
optimizer = optim.Adam(
    list(model.parameters()) + list(classifier.parameters()) +
    list(projector_0.parameters()) + list(projector_1.parameters()),
    lr=args.lr, weight_decay=args.weight_decay
)


# 6. EMA (Exponential Moving Average) 설정
@torch.no_grad()
def update_ema(student, ema, beta=0.999):
    """
    Move ``ema`` towards ``student`` by an exponential moving average, in place.

    Parameters and floating-point buffers are updated as
    ``ema = beta * ema + (1 - beta) * student``. Non-floating buffers (for
    example BatchNorm's ``num_batches_tracked``) are copied verbatim.

    Buffers are included because an EMA copy that is evaluated in ``eval()``
    mode relies on them (e.g. BatchNorm running statistics); averaging only
    the parameters would leave those statistics at their initial values.

    Args:
        student: module being trained.
        ema: module with the same structure that receives the average.
        beta: decay factor; larger values make the average move more slowly.
    """
    for param, ema_param in zip(student.parameters(), ema.parameters()):
        ema_param.mul_(beta).add_(param.detach(), alpha=1 - beta)

    for buf, ema_buf in zip(student.buffers(), ema.buffers()):
        if torch.is_floating_point(ema_buf):
            ema_buf.mul_(beta).add_(buf.detach(), alpha=1 - beta)
        else:
            ema_buf.copy_(buf)


# 7. Train
# Imported once here (instead of on every epoch) so the cell stays self-contained.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from sklearn.metrics import confusion_matrix

# Entry i names class id i as produced by multilabel_to_multiclass().
class_names = ["Normal", "Crackle", "Wheeze", "Both"]

_VALID_TARGET_TYPES = ('grad_block', 'grad_flow', 'project_block', 'project_flow')


def multilabel_to_multiclass(y):
    """
    Collapse [N, 2] multi-hot labels into class ids ``col0 + 2 * col1``.

    Per the original note: None -> 0, Crackle -> 1, Wheeze -> 2, Both -> 3
    (i.e. column 0 is presumably Crackle and column 1 Wheeze).
    """
    y = np.array(y)
    return y[:, 0] + y[:, 1]*2


def evaluate_multiclass_confusion(y_true, y_pred, class_names=("Normal", "Crackle", "Wheeze", "Both")):
    """
    Build the 4-class confusion matrix and derive sensitivity / specificity.

    Args:
        y_true, y_pred: [N, 2] multi-hot labels and predictions.
        class_names: accepted for backward compatibility; not used here.

    Returns:
        ``(cm, SE, SP, y_true_cls, y_pred_cls)`` where SP is the fraction of
        class-0 samples predicted as class 0 and SE is the fraction of
        samples from classes 1-3 predicted as exactly their own class.
    """
    y_true_cls = multilabel_to_multiclass(y_true)
    y_pred_cls = multilabel_to_multiclass(y_pred)

    cm = confusion_matrix(y_true_cls, y_pred_cls, labels=[0, 1, 2, 3])

    normal_correct = cm[0, 0]
    normal_total = cm[0].sum()

    # Only exact matches count as correct for the abnormal classes.
    abnormal_correct = cm[1, 1] + cm[2, 2] + cm[3, 3]
    abnormal_total = cm[1].sum() + cm[2].sum() + cm[3].sum()

    # The small epsilon avoids division by zero when a group is empty.
    SP = normal_correct / (normal_total + 1e-6)
    SE = abnormal_correct / (abnormal_total + 1e-6)

    return cm, SE, SP, y_true_cls, y_pred_cls


def log_multiclass_conf_matrix_wandb(cm, class_names, sens, spec, normalize, tag):
    """
    Draw a confusion-matrix heatmap annotated with the summary scores.

    The figure is only created and returned; the caller is responsible for
    logging it and closing it. ``tag`` is accepted for backward compatibility
    and not used inside this function.

    Args:
        cm: square confusion matrix of counts.
        class_names: tick labels for both axes.
        sens, spec: sensitivity and specificity shown in the annotation box.
        normalize: if True, show row-normalised ratios instead of counts.
        tag: unused.

    Returns:
        The matplotlib Figure.
    """
    if normalize:
        # Rows without any sample stay 0 instead of becoming NaN.
        row_sums = cm.sum(axis=1, keepdims=True)
        cm = np.divide(cm.astype('float'), row_sums,
                       out=np.zeros(cm.shape, dtype=float), where=row_sums != 0)
        fmt = '.2f'
        title = "Confusion Matrix (Normalized %)"
    else:
        fmt = 'd'
        title = "Confusion Matrix (Raw Count)"

    fig, ax = plt.subplots(figsize=(7, 6))
    sns.heatmap(cm, annot=True, fmt=fmt, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)

    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(title)

    icbhi_score = (sens + spec) / 2
    # Score box near the lower-right corner, positioned in this axes' own coordinates.
    ax.text(
        0.99, 0.15,
        f"Sensitivity: {sens*100:.2f}%\nSpecificity: {spec*100:.2f}%\nICBHI Score: {icbhi_score*100:.2f}%",
        ha='right', va='bottom',
        transform=ax.transAxes,
        fontsize=10, bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)
    )

    plt.tight_layout()
    return fig


def _ckpt_path(suffix):
    """Return the checkpoint file path for this run with the given suffix."""
    # NOTE(review): plain concatenation - if CHECKPOINT_PATH has no trailing
    # separator the file lands next to that directory, not inside it. Kept
    # unchanged because other cells may load checkpoints by the same expression.
    return CHECKPOINT_PATH + f"{pretrain_project_name}_{suffix}.pth.tar"


if args.target_type not in _VALID_TARGET_TYPES:
    # Fail before training instead of hitting an unbound variable mid-epoch.
    raise ValueError(
        f"Unknown target_type {args.target_type!r}; expected one of {_VALID_TARGET_TYPES}"
    )

# Best validation loss seen so far.
best_loss = float('inf')
best_epoch = -1

for epoch in range(args.epochs):
    # ===============================
    # 1. Training
    # ===============================
    model.train()
    projector_0.train()
    projector_1.train()
    classifier.train()

    total_train_loss = 0.0

    all_preds = []
    all_labels = []
    all_outputs = []

    # Iterate the progress bar itself (not the raw loader) so it advances.
    pbar = tqdm(pretrain_loader, desc='Mix_MLATT Trainig only')
    for repeat_mel, labels, _ in pbar:

        repeat_mel = repeat_mel.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        # (1) Forward pass on the un-mixed input; one logit per class head.
        attn_feat, _ = model(repeat_mel, mix_feature=True, patch_mix=False)
        logits = torch.stack(
            [classifier[i](attn_feat[:, i, :]).squeeze(-1) for i in range(args.num_classes)],
            dim=1,
        )

        # (2) Classification loss.
        loss_ce = criterion[0](logits, labels)

        # (3) Contrastive targets per class, depending on target_type.
        feat_0 = attn_feat[:, 0, :]
        feat_1 = attn_feat[:, 1, :]
        if args.target_type == 'grad_block':
            proj1_0 = feat_0.detach().clone()
            proj1_1 = feat_1.detach().clone()
        elif args.target_type == 'grad_flow':
            proj1_0 = feat_0
            proj1_1 = feat_1
        elif args.target_type == 'project_block':
            proj1_0 = projector_0(feat_0).detach()
            proj1_1 = projector_1(feat_1).detach()
        else:  # 'project_flow'
            proj1_0 = projector_0(feat_0)
            proj1_1 = projector_1(feat_1)

        # (4) Forward pass with patch mixing enabled.
        mix_attn_feat, origin_feat, label_origin, label_mix, lam, index = model(
            repeat_mel, y=labels, patch_mix=True, mix_feature=True
        )

        # (5) Project the mixed features.
        proj2_0 = projector_0(mix_attn_feat[:, 0, :])
        proj2_1 = projector_1(mix_attn_feat[:, 1, :])

        # (6) Contrastive losses and the combined objective.
        loss_con0 = criterion[1](proj1_0, proj2_0, label_origin, label_mix, lam, index)
        loss_con1 = criterion[1](proj1_1, proj2_1, label_origin, label_mix, lam, index)
        loss = loss_ce + args.alpha * (loss_con0 + loss_con1)

        # (7) Backpropagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # (8) EMA update of every trained module.
        if args.ma_update:
            update_ema(model, ema_model, beta=args.ma_beta)
            update_ema(projector_0, ema_projector_0, beta=args.ma_beta)
            update_ema(projector_1, ema_projector_1, beta=args.ma_beta)
            update_ema(classifier, ema_classifier, beta=args.ma_beta)

        # (9) Bookkeeping.
        batch_loss = loss.item()
        total_train_loss += batch_loss
        pbar.set_postfix(loss=f"{batch_loss:.4f}")

        # Training predictions use a 0.5 threshold.
        predicted = (torch.sigmoid(logits) > 0.5).float()
        all_preds.append(predicted.detach().cpu())
        all_labels.append(labels.detach().cpu())
        all_outputs.append(logits.detach().cpu())

    # Mean loss per batch.
    train_loss = total_train_loss / len(pretrain_loader)

    all_preds = torch.cat(all_preds, dim=0).numpy()
    all_labels = torch.cat(all_labels, dim=0).numpy()
    all_output = torch.cat(all_outputs, dim=0).numpy()

    print(f"[Epoch {epoch} | Train Loss: {train_loss:.4f}, attn_feat: {attn_feat.shape}")

    # =====================================
    # 2. Training sensitivity / specificity (4-class view)
    # =====================================
    cm_4x4, train_sens, train_spec, y_true_cls, y_pred_cls = evaluate_multiclass_confusion(all_labels, all_preds, class_names)
    icbhi_score = (train_sens + train_spec)/2

    print("4-Class Confusion Matrix:\n", cm_4x4)
    print(f"Sensitivity: {train_sens:.4f}, Specificity: {train_spec:.4f}, ICBHI Score: {icbhi_score:.4f}")

    # ===============================
    # 3. Validation (EMA modules when EMA is enabled)
    # ===============================
    test_loss, test_labels, test_preds = validate(
        model=ema_model if args.ma_update else model,
        classifier=ema_classifier if args.ma_update else classifier,
        projector_0=ema_projector_0 if args.ma_update else projector_0,
        projector_1=ema_projector_1 if args.ma_update else projector_1,
        val_loader=test_loader,
        criterion=criterion,
        device=device,
        args=args
    )

    precision = precision_score(test_labels, test_preds, average='macro')
    recall = recall_score(test_labels, test_preds, average='macro')
    f1 = f1_score(test_labels, test_preds, average='macro')

    test_cm_4x4, test_sens, test_spec, test_y_true_cls, test_y_pred_cls = evaluate_multiclass_confusion(test_labels, test_preds)
    test_icbhi_score = (test_sens+test_spec)/2

    print("[Validation] Confusion Matrix:\n", test_cm_4x4)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"[VALIDATION] Sensitivity: {test_sens:.4f}, Specificity: {test_spec:.4f}, Avg ICBHI Score: {(test_sens+test_spec)/2:.4f}")
    print("##################################################")

    # ===============================
    # 4. Confusion-matrix figures
    # ===============================
    fig_finetune_raw = log_multiclass_conf_matrix_wandb(cm_4x4, class_names, train_sens, train_spec, normalize=False, tag="Training_conf_matrix_raw")
    fig_finetune_norm = log_multiclass_conf_matrix_wandb(cm_4x4, class_names, train_sens, train_spec, normalize=True, tag="Training_conf_matrix_norm")

    fig_test_raw = log_multiclass_conf_matrix_wandb(test_cm_4x4, class_names, test_sens, test_spec, normalize=False, tag="test_conf_matrix_raw")
    fig_test_norm = log_multiclass_conf_matrix_wandb(test_cm_4x4, class_names, test_sens, test_spec, normalize=True, tag="test_conf_matrix_norm")

    # Every figure created this epoch is tracked here and closed after logging.
    epoch_figs = [fig_finetune_raw, fig_finetune_norm, fig_test_raw, fig_test_norm]

    # -------------------------------------------wan
    wandb_log_dict = {
        "finetune_conf_matrix_raw": wandb.Image(fig_finetune_raw),
        "finetune_conf_matrix_norm": wandb.Image(fig_finetune_norm),
        "test_conf_matrix_raw": wandb.Image(fig_test_raw),
        "test_conf_matrix_norm": wandb.Image(fig_test_norm)
    }
    # -------------------------------------------wan

    # =====================================
    # 5. Checkpoint (every 100 epochs)
    # =====================================
    if (epoch + 1) % 100 == 0:
        ckpt_path = _ckpt_path(f"{epoch:03d}")
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            # Heads are stored too so the checkpoint can reproduce predictions,
            # matching what the EMA checkpoint already contains.
            'classifier': classifier.state_dict(),
            'projector_0': projector_0.state_dict(),
            'projector_1': projector_1.state_dict(),
            'optimizer': optimizer.state_dict()
        }, ckpt_path)
        print(f"💾 Saved checkpoint to {ckpt_path}")

        # EMA weights are saved only when EMA is enabled.
        if args.ma_update:
            ema_ckpt_path = _ckpt_path(f"ema_{epoch:03d}")
            torch.save({
                'epoch': epoch + 1,
                'state_dict': ema_model.state_dict(),
                'classifier': ema_classifier.state_dict(),
                'projector_0': ema_projector_0.state_dict(),
                'projector_1': ema_projector_1.state_dict()
            }, ema_ckpt_path)
            print(f"💾 Saved EMA checkpoint to {ema_ckpt_path}")

    # ===============================
    # 6. Save best checkpoint (lowest validation loss)
    # ===============================
    if test_loss < best_loss:
        best_loss = test_loss
        best_epoch = epoch
        best_ckpt_path = _ckpt_path("best")
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'classifier': classifier.state_dict(),
            'projector_0': projector_0.state_dict(),
            'projector_1': projector_1.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': best_loss
        }, best_ckpt_path)
        print(f"=> Saved best checkpoint (epoch: {epoch}, loss: {best_loss:.4f})")

        if args.ma_update:
            best_ema_ckpt_path = _ckpt_path("best_ema")
            torch.save({
                'epoch': epoch + 1,
                'state_dict': ema_model.state_dict(),
                'classifier': ema_classifier.state_dict(),
                'projector_0': ema_projector_0.state_dict(),
                'projector_1': ema_projector_1.state_dict(),
                'loss': best_loss
            }, best_ema_ckpt_path)
            print(f"=> Saved best EMA checkpoint (epoch: {epoch}, loss: {best_loss:.4f})")

        # Confusion matrices for the best epoch: same data as this epoch's validation result.
        cm_best, sens_best, spec_best = test_cm_4x4, test_sens, test_spec
        fig_best_raw = log_multiclass_conf_matrix_wandb(cm_best, class_names, sens_best, spec_best, normalize=False, tag="best_test_conf_matrix_raw")
        fig_best_norm = log_multiclass_conf_matrix_wandb(cm_best, class_names, sens_best, spec_best, normalize=True, tag="best_test_conf_matrix_norm")
        epoch_figs.extend([fig_best_raw, fig_best_norm])

        # -------------------------------------------wan
        wandb_log_dict.update({
            "best_test_conf_matrix_raw": wandb.Image(fig_best_raw),
            "best_test_conf_matrix_norm": wandb.Image(fig_best_norm)
        })
        # -------------------------------------------wan

    if epoch == args.epochs - 1:
        # Confusion matrices for the last epoch.
        cm_last, sens_last, spec_last = test_cm_4x4, test_sens, test_spec
        fig_last_raw = log_multiclass_conf_matrix_wandb(cm_last, class_names, sens_last, spec_last, normalize=False, tag="last_test_conf_matrix_raw")
        fig_last_norm = log_multiclass_conf_matrix_wandb(cm_last, class_names, sens_last, spec_last, normalize=True, tag="last_test_conf_matrix_norm")
        epoch_figs.extend([fig_last_raw, fig_last_norm])

        # -------------------------------------------wan
        wandb_log_dict.update({
            "last_test_conf_matrix_raw": wandb.Image(fig_last_raw),
            "last_test_conf_matrix_norm": wandb.Image(fig_last_norm)
        })
        # -------------------------------------------wan

    # =====================================
    # 7. Logging to wandb
    # =====================================

    # -------------------------------------------wan
    # step 1. scalar metrics
    wandb.log({
        # Train metrics
        "Training/epoch": epoch,
        "Training/train_loss": train_loss,
        "Training/test_loss": test_loss,
        "Training/train_sens": train_sens,
        "Training/train_spec": train_spec,
        "Training/icbhi_score": icbhi_score,

        # Test metrics
        "Test/loss": test_loss,
        "Test/sensitivity": test_sens,
        "Test/specificity": test_spec,
        "Test/icbhi_score": test_icbhi_score
    })

    # step 2. confusion-matrix images
    wandb.log(wandb_log_dict)
    # -------------------------------------------wan

    # Release the figures now that they have been logged.
    for fig in epoch_figs:
        plt.close(fig)

    # ===============================
    # 8. Save per-epoch metrics and validation outputs
    # ===============================
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    train_icbhi_scores.append(icbhi_score)
    test_icbhi_scores.append(test_icbhi_score)
    epochs.append(epoch)

    test_labels_all.append(test_labels)
    test_preds_all.append(test_preds)


# -------------------------------------------wan
wandb.finish()
# -------------------------------------------wan



In [None]:
# --- Post-training curves and best-epoch report ---
import matplotlib.pyplot as plt


def _plot_train_test_curves(train_values, test_values, metric_name):
    """Plot the per-epoch train and test curves of one metric on a new figure."""
    plt.figure(figsize=(10,6))
    plt.plot(epochs, train_values, label=f'Train {metric_name}')
    plt.plot(epochs, test_values, label=f'Test {metric_name}')
    plt.xlabel('Epoch')
    plt.ylabel(metric_name)
    plt.title(f'Train/Test {metric_name} per Epoch')
    plt.legend()
    plt.grid(True)
    plt.show()


_plot_train_test_curves(train_losses, test_losses, 'Loss')
_plot_train_test_curves(train_icbhi_scores, test_icbhi_scores, 'ICBHI Score')

# Select the epoch with the highest test ICBHI score.
# Note that this rebinds best_epoch, which the training cell set by lowest loss.
best_epoch_idx = np.argmax(test_icbhi_scores)
best_epoch = epochs[best_epoch_idx]
best_icbhi_score = test_icbhi_scores[best_epoch_idx]
best_test_loss = test_losses[best_epoch_idx]

# Labels and predictions recorded at that epoch.
best_test_labels = test_labels_all[best_epoch_idx]
best_test_preds = test_preds_all[best_epoch_idx]

(best_cm,
 best_sens,
 best_spec,
 best_y_true_cls,
 best_y_pred_cls) = evaluate_multiclass_confusion(best_test_labels, best_test_preds)

print("\n=== [최고 Test ICBHI Score 시점 정보] ===")
print(f"Best Test ICBHI Score: {best_icbhi_score:.4f} (Epoch {best_epoch})")
print(f"Test Loss at Best: {best_test_loss:.4f}")
print("Confusion Matrix at Best ICBHI Score:")
print(best_cm)
print(f"Sensitivity: {best_sens:.4f}, Specificity: {best_spec:.4f}, ICBHI Score: {(best_sens+best_spec)/2:.4f}")
print(f"Best Epoch: {best_epoch}")
