#### 환경설정

##### 1. Wandb

In [None]:
import wandb

# Authenticate with Weights & Biases (IPython shell escape; only works inside a notebook kernel)
!wandb login

##### 2. 라이브러리 로드

In [None]:
import torch

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
import os
import math
import random
import pickle
import wandb
from tqdm import tqdm
from datetime import datetime
from zoneinfo import ZoneInfo

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import librosa
import librosa.display

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import torchaudio.transforms as T
import torchvision
import torchvision.models as models
from torch import Tensor
from torchsummary import summary
from torch.hub import load_state_dict_from_url
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim.lr_scheduler import CosineAnnealingLR

from sklearn.metrics import confusion_matrix, f1_score
from sklearn.manifold import TSNE

##### 3. 경로 설정

In [None]:
# Directory holding the raw recordings (<name>.wav) and their annotation files (<name>.txt).
ROOT = "/home/sbw/BOAZ-Chungzins/data/raw"
# Directory where training checkpoints are written.
CHECKPOINT_PATH = "/home/sbw/boaz/notebook/note_ckp"
# Directory where the preprocessed datasets are cached as pickle files.
PICKLE_PATH = "/home/sbw/boaz/notebook/pickle"
# Tab-separated file listing every recording name with its train/test assignment.
text = "/home/sbw/BOAZ-Chungzins/data/metadata/train_test_split.txt"


##### 4. Seed 설정

In [None]:
def seed_everything(seed: int = 42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and CUDA), exports
    ``PYTHONHASHSEED``, and puts cuDNN into deterministic mode with
    autotuning disabled.

    Args:
        seed: Value used for all generators.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True  # type: ignore
    cudnn.benchmark = False  # type: ignore


seed_everything(42)  # fix the seed for the whole notebook

## 1. Data Load

#### 1.1 Data Load

In [None]:
# Recordings (.wav) and their annotation files (.txt) live in the same directory.
data_dir = txt_dir = ROOT

# Split file: one "<recording name>\t<train|test>" row per recording, no header row.
df = pd.read_csv(text, sep='\t', header=None)
df.columns = ['filename', 'set']

# Partition the recordings according to their assigned split.
train_df = df.loc[df['set'] == 'train']
test_df = df.loc[df['set'] == 'test']

# Sorted recording names per split.
train_list = sorted(train_df['filename'].tolist())
test_list = sorted(test_df['filename'].tolist())

print(f'Train :{len(train_list)}, Test: {len(test_list)}, Total: {len(train_list) + len(test_list)}')

#### 1.2 Pretext-Finetune Split

In [None]:
# Shuffle the training recordings reproducibly, then cut 80% / 20%.
# NOTE: the cut is made per recording, so one patient can end up in both splits.
df_shuffled = train_df.sample(frac=1, random_state=42)
train_size = int(0.8 * len(df_shuffled))

pretrain_df = df_shuffled.iloc[:train_size]
finetune_df = df_shuffled.iloc[train_size:]

# Sorted recording names of each split.
pretrain_list = sorted(pretrain_df['filename'].tolist())
finetune_list = sorted(finetune_df['filename'].tolist())

# The patient id is the integer prefix of a recording name (text before the first '_').
pretrain_patient_list = [int(name.split('_')[0]) for name in pretrain_list]
finetune_patient_list = [int(name.split('_')[0]) for name in finetune_list]

# Number of recordings per patient in each split.
pretrain_patient_counts = pd.Series(pretrain_patient_list).value_counts()
finetune_patient_counts = pd.Series(finetune_patient_list).value_counts()

print(f"[Pretrain] 환자 수: {len(pretrain_patient_counts.index)}, 샘플 수: {pretrain_patient_counts.sum()}")
print(f"[Finetune] 환자 수: {len(finetune_patient_counts.index)}, 샘플 수: {finetune_patient_counts.sum()}")

## 2. Data Preprocessing

#### 2.1 Args

        K: queue size; number of negative keys (MoCo default: 65536; this notebook sets K = 1024)
        m: MoCo momentum for updating the key encoder (default: 0.999)
        T: softmax temperature (default: 0.07)

In [None]:
class Args:
    """Hyperparameters and settings for the notebook, kept as class attributes."""
    # Audio & Spectrogram
    target_sr = 4000    # sample rate every recording is resampled to
    frame_size = 256    # passed as n_fft to the mel spectrogram
    hop_length = 128    # half of frame_size
    n_mels = 128
    target_sec = 8      # every breathing cycle is normalised to target_sec * target_sr samples

    # Augmentation
    # NOTE(review): these two values are not referenced in the visible part of the notebook.
    time_mask_param = 0.5
    freq_mask_param = 0.5

    # Train
    lr = 0.03
    warm = True                     # whether to use learning-rate warm-up
    warm_epochs = 10                # number of initial epochs the warm-up covers
    warmup_from = lr * 0.1          # learning rate the warm-up starts from (10% of lr)
    warmup_to = lr

    batch_size = 128
    workers = 4
    epochs = 300                    # number of pretraining epochs
    weight_decay = 1e-4

    resume = None
    schedule=[120, 160] # schedule

    # MLS
    K = 1024            # queue size of the MoCo model
    momentum = 0.999    # momentum of the key-encoder update
    T = 0.07            # softmax temperature
    dim_prj = 128       # dimension of the projected embeddings
    top_k = 15          # number of queue entries marked as positives per sample
    lambda_bce = 0.5    # weight of the BCE term in the total loss
    out_dim = 2048      # dimension of the encoder output features

    # Linear Evaluation
    ft_epochs = 100

    # etc
    gpu = 0             # CUDA device index the training batches are moved to
    data = "./data_path"
    seed=42

args = Args()

#### 2.2 Utils (func)

In [None]:
import torch.nn.functional as F
import random

# Map the crackle/wheeze flags of a breathing cycle to a single class id.
def get_class(cr, wh):
    """Return 0 (normal), 1 (crackle), 2 (wheeze) or 3 (both); -1 if a flag is neither 0 nor 1."""
    has_crackle = cr == 1
    has_wheeze = wh == 1
    flags_valid = (has_crackle or cr == 0) and (has_wheeze or wh == 0)
    if not flags_valid:
        return -1
    return int(has_crackle) + 2 * int(has_wheeze)

# Build a dB-scaled mel spectrogram (the notebook's settings come from Args).
def generate_mel_spectrogram(waveform, sample_rate, frame_size, hop_length, n_mels):
    """Convert a waveform into a mel spectrogram expressed in decibels.

    Args:
        waveform: Audio tensor handed to ``T.MelSpectrogram``.
        sample_rate: Sample rate of ``waveform``.
        frame_size: FFT size (``n_fft``).
        hop_length: Hop between frames; ``None`` means ``frame_size // 2``.
        n_mels: Number of mel bins.

    Returns:
        The output of ``T.AmplitudeToDB`` applied to the mel spectrogram.
    """
    hop = frame_size // 2 if hop_length is None else hop_length
    to_mel = T.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=frame_size,
        hop_length=hop,
        n_mels=n_mels
    )
    to_db = T.AmplitudeToDB()
    return to_db(to_mel(waveform))

# Repeat or crop a mel segment along its last (time) axis.
def repeat_or_truncate_segment(mel_segment, target_frames):
    """Return ``mel_segment`` with exactly ``target_frames`` frames on its last axis.

    Shorter segments are tiled along the last axis until they are long
    enough; the result is then cut to ``target_frames``. The slicing assumes
    a 3-D tensor.
    """
    n_frames = mel_segment.shape[-1]
    if n_frames < target_frames:
        copies = math.ceil(target_frames / n_frames)
        mel_segment = mel_segment.repeat(1, 1, copies)
    return mel_segment[:, :, :target_frames]

def preprocess_waveform_segment(waveform, unit_length):
    """Normalise a waveform to exactly ``unit_length`` samples.

    * Shorter by less than half of ``unit_length``: zero-pad both sides.
    * Shorter by half or more: tile the waveform as many whole times as fit,
      then zero-pad both sides with whatever is still missing.
    * Equal or longer: crop ``unit_length`` samples, starting at a random
      offset inside the first quarter of the surplus.

    Args:
        waveform: Tensor of shape ``(1, L)`` (a 1-D tensor also works).
        unit_length: Target number of samples.

    Returns:
        Tensor of shape ``(1, unit_length)``.

    Raises:
        ValueError: If ``waveform`` contains no samples. Previously this
            surfaced as an opaque ``ZeroDivisionError`` in the tiling branch.
    """
    waveform = waveform.squeeze(0)  # (1, L) -> (L,)
    n_samples = len(waveform)
    if n_samples == 0:
        raise ValueError("preprocess_waveform_segment received an empty waveform")

    missing = unit_length - n_samples

    if missing > 0:
        if missing >= unit_length // 2:
            # Large gap: tile first, then pad the remainder.
            repeat_factor = unit_length // n_samples
            waveform = waveform.repeat(repeat_factor)[:unit_length]
            missing = unit_length - len(waveform)
        left = missing // 2
        waveform = F.pad(waveform, (left, missing - left))
    else:
        # Too long (or exact): random crop starting within the first quarter of the surplus.
        surplus = n_samples - unit_length
        start = random.randint(0, surplus // 4)
        waveform = waveform[start:start + unit_length]

    return waveform.unsqueeze(0)  # back to (1, L)

# SpecAugment views with mask lengths of up to 80% of each axis.
def apply_spec_augment(mel_segment):
    """Return three masked copies of ``mel_segment``.

    Returns:
        ``(frequency-masked, time-masked, time-then-frequency-masked)``;
        the input itself is left untouched because every view starts from a clone.
    """
    n_frames = mel_segment.shape[-1]
    n_bins = mel_segment.shape[-2]

    # torchaudio draws each mask length uniformly between 0 and mask_param.
    mask_time = T.TimeMasking(time_mask_param=int(n_frames * 0.8))
    mask_freq = T.FrequencyMasking(freq_mask_param=int(n_bins * 0.8))

    freq_only = mask_freq(mel_segment.clone())
    time_only = mask_time(mel_segment.clone())
    time_then_freq = mask_freq(mask_time(mel_segment.clone()))

    return freq_only, time_only, time_then_freq

# Resample a waveform to the target sample rate.
def resample_waveform(waveform, orig_sr, target_sr=args.target_sr):
    """Resample ``waveform`` from ``orig_sr`` to ``target_sr``.

    Returns:
        ``(waveform, sample_rate)``. When the rates already match, the input
        is returned unchanged together with ``orig_sr``.
    """
    if orig_sr == target_sr:
        return waveform, orig_sr

    resample = torchaudio.transforms.Resample(orig_freq=orig_sr, new_freq=target_sr)
    return resample(waveform), target_sr

In [None]:
##############################################
import torch
import torch.nn.functional as F
import torchaudio.transforms as T
import numpy as np
import random

# -------------------- Augmentation functions (ICBHI 멜스펙트로그램에 최적화) --------------------

def spec_augment(mel, time_mask_ratio=0.15, freq_mask_ratio=0.15):
    """Apply one frequency mask followed by one time mask to a copy of ``mel``.

    The maximum mask width on each axis is ``ratio * axis length``
    (truncated to an int, but at least 1); e.g. 128 mel bins with the
    default ratio give a maximum width of 19 bins.

    Args:
        mel: Spectrogram whose last two axes are (frequency, time).
        time_mask_ratio: Upper bound of the time mask as a fraction of the frames.
        freq_mask_ratio: Upper bound of the frequency mask as a fraction of the bins.

    Returns:
        The masked copy; ``mel`` itself is not modified.
    """
    n_frames, n_bins = mel.shape[-1], mel.shape[-2]

    max_time_width = max(1, int(n_frames * time_mask_ratio))
    max_freq_width = max(1, int(n_bins * freq_mask_ratio))

    mask_time = T.TimeMasking(time_mask_param=max_time_width)
    mask_freq = T.FrequencyMasking(freq_mask_param=max_freq_width)

    return mask_time(mask_freq(mel.clone()))

def add_noise(mel, noise_level=0.001):
    """Return ``mel`` plus standard-normal noise scaled by ``noise_level``.

    Keep ``noise_level`` modest; large values distort the spectrogram.
    """
    return mel + torch.randn_like(mel) * noise_level

def pitch_shift(mel, n_steps=2):
    """Circularly shift ``mel`` along dim 2 by a random number of bins.

    The shift is drawn uniformly from ``[-n_steps, n_steps]``; the shape is
    preserved. Expects a 4-D tensor whose dim 2 is the mel axis. A shift of 0
    returns the input object itself.
    """
    shift = random.randint(-n_steps, n_steps)
    if shift == 0:
        return mel

    # Python slicing covers both directions: a negative shift counts from the end.
    leading = mel[:, :, shift:, :]
    trailing = mel[:, :, :shift, :]
    return torch.cat([leading, trailing], dim=2)

def time_stretch(mel, min_rate=0.95, max_rate=1.05):
    """Stretch or squeeze ``mel`` along the time axis while keeping its shape.

    A rate is drawn uniformly from ``[min_rate, max_rate]``; the spectrogram
    is bilinearly resized to ``int(frames * rate)`` frames and then cropped
    (rate > 1) or right-padded with zeros (rate < 1) back to the original
    length. Expects a 4-D tensor (bilinear interpolation). A rate of exactly
    1.0 returns the input object itself.
    """
    rate = random.uniform(min_rate, max_rate)
    if rate == 1.0:
        return mel

    n_frames = mel.shape[-1]
    new_frames = int(n_frames * rate)

    resized = F.interpolate(
        mel,
        size=(mel.shape[-2], new_frames),  # (mel_bins, time)
        mode='bilinear',
        align_corners=False
    )

    if new_frames > n_frames:
        return resized[..., :n_frames]
    return F.pad(resized, (0, n_frames - new_frames))

# -------------------- Dispatcher --------------------

# Maps the augmentation names accepted by apply_augmentations_torch to their implementations.
AUGMENTATION_FUNCTIONS_TORCH = {
    "spec_augment": spec_augment,
    "add_noise": add_noise,
    "pitch_shift": pitch_shift,
    "time_stretch": time_stretch
}

def apply_augmentations_torch(x, methods=(), **kwargs):
    """Apply the named augmentations to ``x`` in the given order.

    Args:
        x: Input passed to the first augmentation.
        methods: Iterable of keys of ``AUGMENTATION_FUNCTIONS_TORCH``.
            Defaults to an immutable empty tuple rather than a shared
            mutable list.
        **kwargs: Per-method keyword arguments, keyed by method name, e.g.
            ``add_noise={"noise_level": 0.005}``.

    Returns:
        The result after all augmentations; ``x`` unchanged if ``methods`` is empty.

    Raises:
        ValueError: If a method name is not registered.
    """
    for method in methods:
        func = AUGMENTATION_FUNCTIONS_TORCH.get(method)
        if func is None:
            raise ValueError(f"Unknown augmentation: {method}")
        x = func(x, **kwargs.get(method, {}))
    return x

In [None]:
def aug(repeat_mel):
    """Build the two augmented views used as query and key.

    View 1: additive noise (level 0.005), then SpecAugment.
    View 2: time stretch (rate 0.8-1.2), then SpecAugment.

    Returns:
        ``(view_1, view_2, None)``; the third slot is always ``None``
        (a third, pitch-shifted view is currently disabled).
    """
    mask_cfg = {"time_mask_ratio": 0.6, "freq_mask_ratio": 0.4}

    # Each view starts from its own copy of the input.
    noisy = apply_augmentations_torch(
        repeat_mel.clone(),
        methods=["add_noise"],
        add_noise={"noise_level": 0.005},
    )
    stretched = apply_augmentations_torch(
        repeat_mel.clone(),
        methods=["time_stretch"],
        time_stretch={"min_rate": 0.8, "max_rate": 1.2},
    )

    # Masking is applied on top of both views, view 1 first.
    view_1 = spec_augment(noisy, **mask_cfg)
    view_2 = spec_augment(stretched, **mask_cfg)

    return view_1, view_2, None


def get_timestamp():
    """Return the current Asia/Seoul time formatted as YYMMDDHHMM, e.g. 2404070830."""
    return datetime.now(ZoneInfo("Asia/Seoul")).strftime('%y%m%d%H%M')

# Origin
# def aug(repeat_mel):
#     aug1, aug2, aug3 = apply_spec_augment(repeat_mel)
#     return aug1, aug2, aug3

#### 2.3 CycleDataset

In [None]:
import os
import torch
import torchaudio
import numpy as np
import torch.nn.functional as F
from torch.utils.data import Dataset
from tqdm import tqdm

class CycleDataset(Dataset):
    """Breathing cycles of a list of recordings, preprocessed eagerly in ``__init__``.

    For each recording ``<name>`` the dataset reads ``<txt_dir>/<name>.txt``
    (columns: start time, end time, crackle flag, wheeze flag) and
    ``<wav_dir>/<name>.wav``. The audio is mixed down to mono and resampled
    to ``target_sr``; every annotated cycle is cut out, length-normalised to
    ``target_sec * target_sr`` samples and converted to a dB mel spectrogram.

    Each item is ``(mel, multi_label, (filename, duration_in_seconds))`` where
    ``multi_label`` is a float tensor ``[crackle, wheeze]``.

    Recordings with a missing .txt or .wav file are reported and skipped.
    Cycles with a non-positive length, or lying entirely past the end of the
    audio, are skipped.
    """

    def __init__(self, filename_list, wav_dir, txt_dir, target_sec=args.target_sec, target_sr=args.target_sr, frame_size=args.frame_size, hop_length=args.hop_length, n_mels=args.n_mels):
        self.filename_list = filename_list
        self.wav_dir = wav_dir
        self.txt_dir = txt_dir
        self.target_sec = target_sec
        self.target_sr = target_sr
        self.frame_size = frame_size
        self.hop_length = hop_length
        self.n_mels = n_mels

        self.cycle_list = []
        unit_length = int(self.target_sec * self.target_sr)

        print("[INFO] Preprocessing cycles...")
        for filename in tqdm(self.filename_list):
            txt_path = os.path.join(self.txt_dir, filename + '.txt')
            wav_path = os.path.join(self.wav_dir, filename + '.wav')

            # Warn about missing inputs and skip the recording instead of
            # crashing in the loaders below.
            missing_paths = [p for p in (txt_path, wav_path) if not os.path.exists(p)]
            if missing_paths:
                for path in missing_paths:
                    print(f"[WARNING] Missing file: {path}")
                continue

            # Load the annotation once. ndmin=2 keeps a single-cycle file
            # two-dimensional so row/column indexing below stays valid.
            annotations = np.loadtxt(txt_path, usecols=(0, 1, 2, 3), ndmin=2)
            cycle_data = annotations[:, :2]
            lung_label = annotations[:, 2:]

            # Load waveform
            waveform, orig_sr = torchaudio.load(wav_path)
            if waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0, keepdim=True)  # stereo -> mono

            # Resample to the target sample rate
            waveform, sample_rate = resample_waveform(waveform, orig_sr, self.target_sr)

            for idx in range(len(cycle_data)):
                # Start and end of the breathing cycle
                start_sample = int(cycle_data[idx, 0] * sample_rate)
                end_sample = int(cycle_data[idx, 1] * sample_rate)
                lung_duration = cycle_data[idx, 1] - cycle_data[idx, 0]

                if end_sample <= start_sample:
                    continue  # invalid interval

                cycle_wave = waveform[:, start_sample:end_sample]
                if cycle_wave.shape[-1] == 0:
                    continue  # annotation starts past the end of the audio

                # Length-normalise the waveform, then compute the dB mel spectrogram
                normed_wave = preprocess_waveform_segment(cycle_wave, unit_length=unit_length)
                mel = generate_mel_spectrogram(normed_wave, sample_rate, frame_size=self.frame_size, hop_length=self.hop_length, n_mels=self.n_mels)

                # crackle / wheeze flags -> class id -> multi-label vector
                cr = int(lung_label[idx, 0])
                wh = int(lung_label[idx, 1])
                label = get_class(cr, wh)

                multi_label = torch.tensor([
                    float(label in [1, 3]),
                    float(label in [2, 3])
                ])

                meta_data = (filename, lung_duration)

                self.cycle_list.append((mel, multi_label, meta_data))

        print(f"[INFO] Total cycles collected: {len(self.cycle_list)}")

    def __len__(self):
        """Number of preprocessed cycles."""
        return len(self.cycle_list)

    def __getitem__(self, idx):
        """Return ``(mel, multi_label, meta_data)`` of cycle ``idx``."""
        mel, label, meta_data = self.cycle_list[idx]
        return mel, label, meta_data

##### Pickle.dump

CycleDataset 객체 생성

In [None]:
import random
import matplotlib.pyplot as plt
import librosa.display

# Audio and annotation files share one directory.
wav_dir = ROOT
txt_dir = ROOT

# 1. Build the datasets (all cycles are preprocessed eagerly in the constructor)
train_dataset = CycleDataset(train_list, wav_dir, txt_dir)
test_dataset = CycleDataset(test_list, wav_dir, txt_dir)

pickle로 train_dataset, test_dataset 외부 저장

In [None]:
# Cache file stem encoding the preprocessing settings: sample rate, window, hop, mel bins, clip length.
pickle_name = f'MLS_{args.target_sr//1000}kHz_{args.frame_size}win_{args.hop_length}hop_{args.n_mels}mel_{args.target_sec}s'

In [None]:
# Bundle both datasets so they can be restored together.
pickle_dict = {
    'train_dataset': train_dataset,
    'test_dataset': test_dataset
}

# Make sure the cache directory exists, so the preprocessing work is not
# lost to a FileNotFoundError at save time.
os.makedirs(PICKLE_PATH, exist_ok=True)

save_path = os.path.join(PICKLE_PATH, pickle_name + '.pkl')
with open(save_path, 'wb') as f:
    pickle.dump(pickle_dict, f)

In [None]:
# # 2. 간단 통계
# print(f"Total cycles: {len(train_dataset)}")

# label_counter = [0] * 4  # normal, crackle, wheeze, both
# for _, multi_label,_ in train_dataset:
#     if torch.equal(multi_label, torch.tensor([0., 0.])):
#         label_counter[0] += 1
#     elif torch.equal(multi_label, torch.tensor([1., 0.])):
#         label_counter[1] += 1
#     elif torch.equal(multi_label, torch.tensor([0., 1.])):
#         label_counter[2] += 1
#     elif torch.equal(multi_label, torch.tensor([1., 1.])):
#         label_counter[3] += 1

# for idx, count in enumerate(label_counter):
#     print(f"Class {idx}: {count} cycles")

##### Pickle.load
저장된 train_dataset, test_dataset을 로드  
(> Aug 는 Moco 모델에서 사용)

In [None]:
pickle_name

In [None]:
# Restore the datasets cached by the dump cell above.
# NOTE: pickle.load executes code from the file; only load caches you created yourself.
# The pickled objects are CycleDataset instances, so the CycleDataset class must be defined first.
save_path = os.path.join(PICKLE_PATH, pickle_name + '.pkl')
with open(save_path, 'rb') as f:
    pickle_dict = pickle.load(f)

train_dataset = pickle_dict['train_dataset']
test_dataset = pickle_dict['test_dataset']

print(f"[Train] Cycles: {len(train_dataset)}")
print(f"[Test] Cycles: {len(test_dataset)}")

#### 2.4 DataLoader

In [None]:
# ---------------- Build the training index lists (seeded) ----------------
seed_everything(args.seed)

# Sets give O(1) membership tests; the name lists would be scanned once per cycle.
pretrain_files = set(pretrain_list)
finetune_files = set(finetune_list)

# Assign every cycle of train_dataset to the split of the recording it came from.
pretrain_idx = []
finetune_idx = []

for i in range(len(train_dataset)):
    filename = train_dataset[i][2][0]  # meta_data = (filename, duration)

    if filename in pretrain_files:
        pretrain_idx.append(i)
    elif filename in finetune_files:
        finetune_idx.append(i)

# Shuffle the index order once (reproducible thanks to the seed above)
random.shuffle(pretrain_idx)
random.shuffle(finetune_idx)

print(f"Pretrain set size: {len(pretrain_idx)}, Finetune set size: {len(finetune_idx)}")

코드 실행 환경에 따라 num_workers를 적절한 값으로 지정해주세요!

In [None]:
# Datasets restricted to the pre-shuffled index lists
pretrain_dataset = Subset(train_dataset, pretrain_idx)
finetune_dataset = Subset(train_dataset, finetune_idx)

# DataLoaders. shuffle=False: the sample order is the one fixed by the
# one-off shuffle of the index lists, so it is identical in every epoch.
# drop_last=True discards the final incomplete batch.
pretrain_loader = DataLoader(
    pretrain_dataset,
    batch_size=args.batch_size,
    num_workers=0,
    drop_last=True,
    pin_memory=True,
    shuffle=False
)

finetune_loader = DataLoader(
    finetune_dataset,
    batch_size=args.batch_size,
    num_workers=0,
    drop_last=True,
    pin_memory=True,
    shuffle=False
)

# The test loader keeps every sample (no drop_last).
test_loader = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    num_workers=0,
    pin_memory=True,
    shuffle=False
)

In [None]:
# Re-seed so the shuffling of this loader is reproducible.
seed_everything(42)

# Loader over the full training set (pretrain + finetune cycles), reshuffled every epoch.
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    num_workers=4,
    drop_last=True,
    pin_memory=True,
    shuffle=True
)

label 분포 확인 (단순 참고용, 실제 환경에서는 pretrain set의 label 분포가 어떤지 알 수 없음)

In [None]:
from collections import Counter


def _multilabel_to_class(multi_labels):
    """Map [crackle, wheeze] rows to class ids: 0 normal, 1 crackle, 2 wheeze, 3 both."""
    crackle_bit = multi_labels[:, 0].long() * 1
    wheeze_bit = multi_labels[:, 1].long() * 2
    return crackle_bit + wheeze_bit  # shape [N], values in {0, 1, 2, 3}


# Multi-label vectors of every training cycle, stacked to [N, 2]
labels = torch.stack([sample[1] for sample in train_dataset])

# Pretrain / finetune subsets
pretrain_labels = labels[pretrain_idx]
pretrain_labels_class = _multilabel_to_class(pretrain_labels)
finetune_labels = labels[finetune_idx]
finetune_labels_class = _multilabel_to_class(finetune_labels)

# Test set
test_labels = torch.stack([sample[1] for sample in test_dataset])
test_labels_class = _multilabel_to_class(test_labels)

print(f"Pretrain sample: {len(pretrain_labels_class)}")
print("Pretrain label distribution:", Counter(pretrain_labels_class.tolist()))
print(f"\nFinetune sample: {len(finetune_labels_class)}")
print("Finetune label distribution:", Counter(finetune_labels_class.tolist()))
print(f"\nTest sample: {len(test_labels_class)}")
print("Test label distribution:", Counter(test_labels_class.tolist()))

## 3. Modeling

#### 3.1 Pre-trained ResNet50

In [None]:
def backbone_resnet():
    """Build a single-channel ResNet50 feature extractor.

    The first convolution is replaced by a freshly initialised 1-channel
    layer and the classification head by ``nn.Identity``. The remaining
    layers are loaded from the downloaded checkpoint; its ``conv1.weight``
    is discarded, and ``strict=False`` tolerates the leftover ``fc`` keys.

    Returns:
        The modified ``torchvision`` ResNet50 module.
    """
    net = models.resnet50(pretrained=False)

    # Spectrograms have one channel, so the 3-channel stem conv is swapped out.
    net.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    # Drop the classifier so the network returns pooled features.
    net.fc = nn.Identity()

    pretrained_weights = load_state_dict_from_url(
        'https://download.pytorch.org/models/resnet50-19c8e357.pth',
        progress=True
    )
    # The downloaded conv1 kernel does not fit the new 1-channel layer.
    pretrained_weights.pop('conv1.weight', None)
    net.load_state_dict(pretrained_weights, strict=False)

    return net

ResNet18

In [None]:
# from torchvision import models
# from torch.hub import load_state_dict_from_url
# import torch.nn as nn

# def backbone_resnet():
#     # 1. 기본 ResNet18 생성
#     resnet = models.resnet18(pretrained=False)

#     # 2. 첫 번째 conv 레이어를 1채널용으로 수정
#     resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

#     # fc 제거
#     resnet.fc = nn.Identity()

#     # 3. ImageNet 가중치 로드 (conv1 제외)
#     state_dict = load_state_dict_from_url(
#         'https://download.pytorch.org/models/resnet18-f37072fd.pth',
#         progress=True
#     )
#     if 'conv1.weight' in state_dict:
#         del state_dict['conv1.weight']
#     resnet.load_state_dict(state_dict, strict=False)

#     return resnet

In [None]:
# Print a layer-by-layer summary of the backbone; input_size is (channels, height, width).
summary(backbone_resnet().to(device), input_size=(1, 224, 64))

#### 3.2 MoCo (MLS)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# K: number of entries kept in each queue
# dim_enc: dimension of the encoder features g1/g2 (before the projector)
# dim_prj: dimension of the projected embeddings z1/z2 (after the projector)
class MoCo(nn.Module):
    """Momentum-contrast model with an additional top-k multi-label BCE loss.

    A query branch (``encoder_q`` + ``proj_head_q``) is trained by gradient
    descent; the key branch (``encoder_k`` + ``proj_head_k``) is a frozen
    copy that follows the query branch through a momentum update. Two queues
    hold past key features (``queue_g``) and key embeddings (``queue_z``).

    Args:
        base_encoder: Zero-argument callable returning the backbone module.
        dim_enc: Output dimension of the backbone.
        dim_prj: Output dimension of the projection head.
        K: Queue size; must be divisible by the batch size.
        m: Momentum of the key-branch update.
        T: Softmax temperature.
        top_k: Number of queue entries marked as positives per sample.
        lambda_bce: Weight of the BCE term in the total loss.
    """

    def __init__(self, base_encoder, dim_enc=args.out_dim, dim_prj=128, K=512, m=0.999, T=0.07, top_k=10, lambda_bce=0.5):
        super().__init__()
        self.K = K
        self.m = m
        self.T = T
        self.top_k = top_k
        self.lambda_bce = lambda_bce

        self.encoder_q = base_encoder()
        self.encoder_k = base_encoder()

        self.proj_head_q = self._make_projector(dim_enc, dim_prj)
        self.proj_head_k = self._make_projector(dim_enc, dim_prj)

        # The whole key branch (encoder AND projection head) starts as a
        # frozen copy of the query branch. Previously only the encoder was
        # synchronised, leaving proj_head_k at its random initialisation.
        for param_q, param_k in self._query_key_params():
            param_k.data.copy_(param_q.data)
            param_k.requires_grad = False

        # Queues store L2-normalised vectors column-wise.
        self.register_buffer("queue_g", F.normalize(torch.randn(dim_enc, K), dim=0))  # key features g2
        self.register_buffer("queue_z", F.normalize(torch.randn(dim_prj, K), dim=0))  # key embeddings z2
        self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))           # next write position

    @staticmethod
    def _make_projector(dim_enc, dim_prj):
        """Build the Linear-BatchNorm-GELU-Linear projection head."""
        return nn.Sequential(
            nn.Linear(dim_enc, dim_enc),
            nn.BatchNorm1d(dim_enc),
            nn.GELU(),
            nn.Linear(dim_enc, dim_prj)
        )

    def _query_key_params(self):
        """Yield matching (query, key) parameter pairs of encoder and projection head."""
        yield from zip(self.encoder_q.parameters(), self.encoder_k.parameters())
        yield from zip(self.proj_head_q.parameters(), self.proj_head_k.parameters())

    @torch.no_grad()
    def _momentum_update_key_encoder(self):
        """Move the key branch towards the query branch: k = m * k + (1 - m) * q."""
        for param_q, param_k in self._query_key_params():
            param_k.data = param_k.data * self.m + param_q.data * (1. - self.m)

    @torch.no_grad()
    def _dequeue_and_enqueue(self, g2, z2):
        """Overwrite the oldest queue columns with the current batch of keys."""
        batch_size = g2.shape[0]
        ptr = int(self.queue_ptr)
        assert self.K % batch_size == 0, "queue size K must be divisible by the batch size"
        self.queue_g[:, ptr:ptr+batch_size] = g2.T.detach()
        self.queue_z[:, ptr:ptr+batch_size] = z2.T.detach()
        self.queue_ptr[0] = (ptr + batch_size) % self.K

    def forward(self, im_q, im_k, epoch=None, warmup_epochs=10):
        """Compute the training loss for one batch of query/key views.

        Args:
            im_q: Batch of query views.
            im_k: Batch of key views.
            epoch: Current epoch. When given and smaller than
                ``warmup_epochs``, only the InfoNCE loss is returned.
                ``None`` (the default) disables the warm-up.
            warmup_epochs: Number of InfoNCE-only epochs when ``epoch`` is given.

        Returns:
            ``(loss, logits, labels)``: the scalar loss, the InfoNCE logits of
            shape ``[B, 1 + K]`` and the InfoNCE targets (all zeros).
        """
        # Query branch: feature g1 and embedding z1
        g1 = F.normalize(self.encoder_q(im_q), dim=1)
        z1 = F.normalize(self.proj_head_q(g1), dim=1)

        # Key branch (no gradients)
        with torch.no_grad():
            self._momentum_update_key_encoder()
            g2 = F.normalize(self.encoder_k(im_k), dim=1)
            z2 = F.normalize(self.proj_head_k(g2), dim=1)

        # Top-k mining in feature space: the top_k most similar queue entries
        # of each sample become its positives. The targets carry no gradient.
        with torch.no_grad():
            sim_g = torch.matmul(g1.detach(), self.queue_g.clone().detach())  # [B, K]
            # Ablation (1-1): hard top-k targets
            topk_idx = torch.topk(sim_g, self.top_k, dim=1).indices
            y = torch.zeros_like(sim_g)
            y.scatter_(1, topk_idx, 1.0)
            # Ablation (1-2): soft top-k targets
            # topk_sim, topk_idx = torch.topk(sim_g, self.top_k, dim=1)
            # y = torch.zeros_like(sim_g)
            # y.scatter_(1, topk_idx, F.softmax(topk_sim / self.T, dim=1))

        # Similarity of z1 to the queued embeddings; shared by both losses.
        sim_z = torch.matmul(z1, self.queue_z.clone().detach())  # [B, K]

        # Ablation (2-1): BCE loss against the mined targets
        bce_loss = F.binary_cross_entropy_with_logits(sim_z / self.T, y)
        # Ablation (2-2): weighted BCE loss
        # pos_weight = torch.ones_like(sim_z) * (self.K / self.top_k)
        # bce_loss = F.binary_cross_entropy_with_logits(sim_z / self.T, y, pos_weight=pos_weight)

        # InfoNCE loss: the positive logit sits in column 0
        l_pos = torch.sum(z1 * z2, dim=1, keepdim=True)
        logits = torch.cat([l_pos, sim_z], dim=1) / self.T
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=logits.device)
        info_nce_loss = F.cross_entropy(logits, labels)

        # Total loss. The warm-up only applies when the caller passes `epoch`;
        # the default call model(im_q, im_k) always uses the combined loss.
        if epoch is not None and epoch < warmup_epochs:
            loss = info_nce_loss
        else:
            loss = info_nce_loss + self.lambda_bce * bce_loss

        self._dequeue_and_enqueue(g2, z2)

        return loss, logits, labels

## 4. Pretrain

In [None]:
next(iter(pretrain_loader))[0][0].shape

In [None]:
# Run name for pretraining; also used as the checkpoint file prefix. Ends with a KST timestamp.
pretrain_project_name = f'SBW_aug(T.N)_PT_{args.target_sr}sr_{args.n_mels}mels_{args.batch_size}bs_top{args.top_k}_{args.lambda_bce}ld_{get_timestamp()}'

In [None]:
# Fix the seed before the model is built so weight initialisation is reproducible.
seed_everything(args.seed)

wandb.init(
    project="SBW_ICBHI_MLS",           # project name
    name=f"{pretrain_project_name}",   # run name
    config={
        "epochs": args.epochs,         # pretraining length (previously logged ft_epochs by mistake)
        "batch_size": args.batch_size,
        "lr": args.lr,
        "momentum": args.momentum,
        "weight_decay": args.weight_decay,
    }
)

# Checkpoints are written below; make sure the target directory exists up front.
os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# 1. MoCo model
model = MoCo(
    base_encoder = backbone_resnet,
    dim_enc = args.out_dim,
    dim_prj = args.dim_prj,
    K = args.K,
    m = args.momentum,
    T = args.T,
    top_k = args.top_k,
    lambda_bce = args.lambda_bce
).cuda()

# 2. Optimizer
optimizer = torch.optim.AdamW(model.parameters(), args.lr, weight_decay=args.weight_decay)
# optimizer = torch.optim.SGD(
#     model.parameters(),
#     lr=args.lr,
#     momentum=0.9,
#     weight_decay=args.weight_decay,
#     nesterov=True
# )

# 3. Cosine scheduler
scheduler = CosineAnnealingLR(optimizer, T_max=args.epochs, eta_min=1e-6)

# 4. Train
# Track the best (lowest) average training loss.
best_loss = float('inf')
best_epoch = -1

for epoch in range(args.epochs):
    # ===============================
    # Training
    # ===============================
    model.train()
    total_train_loss = 0.0

    for i, (repeat_mel, label, _) in enumerate(pretrain_loader):  # labels are not used here
        im_q, im_k, _ = aug(repeat_mel)

        # Standardise each view with its own batch-wide mean and std
        im_q = (im_q - im_q.mean() ) / (im_q.std() + 1e-6)
        im_k = (im_k - im_k.mean() ) / (im_k.std() + 1e-6)

        im_q = im_q.cuda(device=args.gpu, non_blocking=True)
        im_k = im_k.cuda(device=args.gpu, non_blocking=True)

        optimizer.zero_grad()
        loss, output, target = model(im_q=im_q, im_k=im_k)
        loss.backward()
        optimizer.step()

        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(pretrain_loader)
    print(f"Epoch {epoch} | Avg Train Loss: {avg_train_loss:.4f}")
    print(f"[Epoch {epoch} | Step {i}] im_q: {im_q.shape}, im_k: {im_k.shape}")

    # =====================================
    # Scheduler
    # =====================================
    scheduler.step()

    # =====================================
    # Logging with wandb
    # =====================================
    current_lr = optimizer.param_groups[0]['lr']
    wandb.log({
        # "epoch": epoch,
        "train_loss": avg_train_loss,
        # "lr": current_lr
    })

    # =====================================
    # Checkpoint (every 100 epochs)
    # =====================================
    if (epoch + 1) % 100 == 0:
        ckpt_path = CHECKPOINT_PATH + f"/{pretrain_project_name}_{epoch:03d}.pth.tar"
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, ckpt_path)
        print(f"💾 Saved checkpoint to {ckpt_path}")

    # ===============================
    # Save best checkpoint (lowest average training loss so far)
    # ===============================
    if avg_train_loss < best_loss:
        best_loss = avg_train_loss
        best_epoch = epoch
        best_ckpt_path = CHECKPOINT_PATH + f"/{pretrain_project_name}_best_checkpoint.pth.tar"
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': best_loss
        }, best_ckpt_path)
        print(f"=> Saved best checkpoint (epoch: {epoch}, loss: {best_loss:.4f})")

## 5. Linear Evaluation

#### validate

In [None]:
len(test_dataset)

In [None]:
def validate(model, val_loader, criterion, device):
    """Evaluate a multi-label classifier on ``val_loader``.

    The model is switched to eval mode (and left in it). Predictions are
    obtained by thresholding the sigmoid of the outputs at 0.5.

    Args:
        model: Module mapping a batch of inputs to logits.
        val_loader: Iterable of ``(inputs, labels, meta)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        device: Device the inputs and labels are moved to.

    Returns:
        ``(avg_loss, all_labels, all_preds)``: the mean of the per-batch
        losses and two NumPy arrays holding all labels and all binary
        predictions, concatenated along the batch axis.
    """
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels, _ in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            running_loss += loss.item()

            preds = (torch.sigmoid(outputs) > 0.5).int()  # threshold = 0.5
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    all_preds = torch.cat(all_preds, dim=0).numpy()
    all_labels = torch.cat(all_labels, dim=0).numpy()

    avg_loss = running_loss / len(val_loader)
    return avg_loss, all_labels, all_preds


### Weighted loss

In [None]:
from collections import Counter
import torch
import numpy as np

# Multi-label targets form a [N, C] binary matrix (C = 2 here: crackle, wheeze).
# The weights are estimated on the finetune split, i.e. the data the classifier
# is trained on. Deriving them from the test set would leak test-label statistics.
label_list = [label for _, label, _ in finetune_dataset]

# Stack all labels
all_labels = torch.stack(label_list, dim=0)  # shape: [N, C]
num_classes = all_labels.size(1)
total_samples = all_labels.size(0)

# Number of positives per class and the smoothed inverse-frequency weights
class_counts = all_labels.sum(dim=0)  # shape: [C]
class_weights = total_samples / (num_classes * class_counts + 1e-6)

# Move to the training device
class_weights_tensor = class_weights.float().to(device)

# Report
for i, count in enumerate(class_counts.tolist()):
    print(f"Class {i} - Positives (1): {int(count)} / {total_samples} samples")
print(f"Class Weights: {class_weights_tensor}")

# Normalised so that the weights sum to 1
alpha_norm = class_weights_tensor / class_weights_tensor.sum()
print(f"alpha_norm: {alpha_norm}")

In [None]:
import torch

# Positives per class, hard-coded from a label distribution (0: 456, 1: 262, 2: 84, 3: 83).
# NOTE(review): these constants go stale if the split or the data changes, and this cell
# overwrites total_samples, class_counts, class_weights and alpha_norm from the previous
# cell. alpha_norm defined here is not moved to `device`.
crackle_pos = 262 + 83  # label 1 or 3
wheeze_pos  = 84 + 83   # label 2 or 3

total_samples = 885
num_classes = 2

# Base class weights: inverse frequency
class_counts = torch.tensor([crackle_pos, wheeze_pos], dtype=torch.float)
class_weights = total_samples / (num_classes * class_counts + 1e-6)

# Normalise so that the weights sum to 1
alpha_norm = class_weights / class_weights.sum()

# Report
print("Raw Class Weights:", class_weights)
print("Normalized Alpha (sum=1):", alpha_norm)


### Multi-label Focal Loss

In [None]:
import torch.nn.functional as F
import torch.nn as nn

class MultiLabelFocalLoss(nn.Module):
    """Focal loss for multi-label classification on raw logits.

    Per element: ``alpha_t * (1 - p_t) ** gamma * BCE``, where ``p_t`` is the
    predicted probability of the target value and ``alpha_t`` is ``alpha``
    for positive targets and ``1 - alpha`` for negative ones.

    Args:
        alpha: ``None`` (no class balancing), a scalar, or per-class weights
            of shape ``[C]`` (tensor or sequence). It is converted to the
            dtype and device of the logits on every call, so a CPU tensor can
            be used with GPU inputs.
        gamma: Focusing exponent.
        reduction: ``'mean'``, ``'sum'``, or anything else for the
            unreduced ``[B, C]`` loss.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        """Return the focal loss.

        Args:
            logits: ``[B, C]`` raw scores.
            targets: ``[B, C]`` binary or soft labels.
        """
        probs = torch.sigmoid(logits)
        ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')

        pt = probs * targets + (1 - probs) * (1 - targets)  # probability of the target value
        loss = (1 - pt) ** self.gamma * ce_loss              # down-weight easy examples

        if self.alpha is not None:
            # Align alpha with the logits to avoid device / dtype mismatches.
            alpha = torch.as_tensor(self.alpha, dtype=logits.dtype, device=logits.device)
            alpha_factor = alpha * targets + (1 - alpha) * (1 - targets)
            loss = alpha_factor * loss

        if self.reduction == 'mean':
            return loss.mean()
        elif self.reduction == 'sum':
            return loss.sum()
        else:
            return loss

In [None]:
from collections import Counter
import torch

# Label distribution of the fine-tuning split
# (keys 1, 2 and 3 are read below as crackle, wheeze and both).
label_dist = Counter({0: 456, 1: 262, 2: 84, 3: 83})

n_total = sum(label_dist.values())         # 885
# A "Both" sample (label 3) counts as a positive for each of the two classes.
n_crackle = label_dist[1] + label_dist[3]  # 262 + 83
n_wheeze = label_dist[2] + label_dist[3]   # 84 + 83

# Per-class negatives / positives ratio (epsilon keeps the denominator non-zero).
pos_weight = torch.tensor(
    [(n_total - n_pos) / (n_pos + 1e-6) for n_pos in (n_crackle, n_wheeze)],
    device=device,
)

print(pos_weight)

## Linear Evaluation

In [None]:
# Finish the current wandb run (if any) before a new one is initialised below.
wandb.finish()

In [None]:
## Wandb setup

# Run name built from the sample rate, mel-bin count, batch size and a timestamp.
finetune_project_name = (
    f'SBW_aug(T.N)_LE_{args.target_sr}sr_'
    f'{args.n_mels}mels{args.batch_size}bs_'
    f'{get_timestamp()}'
)

wandb.init(
    project="SBW_ICBHI_MLS",     # project name
    name=finetune_project_name,  # experiment (run) name
    config=dict(
        epochs=args.ft_epochs,
        batch_size=args.batch_size,
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
    ),
)

In [None]:
import os
from torch.utils.data import DataLoader
import torch.optim as optim
from sklearn.metrics import precision_score, recall_score, f1_score

# 1. Model Load
# Use this path when the notebook has been run from the top (pretraining done in this session).
load_ckpt_path = CHECKPOINT_PATH + f"/{pretrain_project_name}_best_checkpoint.pth.tar"
# Use this path instead when resuming from an earlier pretraining run.
# load_ckpt_path = CHECKPOINT_PATH + "/SHS_aug(T.N)_PT_128bs_top15_0.5ld_2507110810_best_checkpoint.pth.tar"

# Save path for the linear-evaluation checkpoints.
# NOTE(review): this value is later concatenated directly with the run name
# (no path separator), so files are written as "<CHECKPOINT_PATH>/LE_pth<run>..."
# rather than inside an "LE_pth/" directory - confirm that this is intended.
save_ckpt_path = CHECKPOINT_PATH+"/LE_pth"

# Re-seed for reproducibility.
seed_everything(args.seed)

# Build the MoCo model and load the pretraining checkpoint into it.
model_eval = MoCo(
    base_encoder=backbone_resnet,
    dim_enc = args.out_dim,
    dim_prj=args.dim_prj,
    K=args.K,
    m=args.momentum,
    T=args.T,
    top_k=args.top_k,
    lambda_bce=args.lambda_bce
)

checkpoint = torch.load(load_ckpt_path, map_location=device)
model_eval.load_state_dict(checkpoint["state_dict"])

# Extract the pretrained query encoder.
encoder = model_eval.encoder_q.to(device)

# 2. Dataset definition
# The datasets/loaders are already defined earlier - test_loader
# 3. Classification model for linear evaluation (the dataset is small, so the encoder parameters are frozen)
class FineTuningModel(nn.Module):
    """Frozen encoder followed by a trainable linear classification head.

    Args:
        encoder: Feature extractor; all of its parameters are frozen
            (``requires_grad = False``) in the constructor.
        out_dim: Dimension of the features produced by ``encoder``.
        num_classes: Number of output logits of the linear head.
    """

    def __init__(self, encoder, out_dim=args.out_dim, num_classes=2):
        super().__init__()
        self.encoder = encoder
        # Freeze every parameter of the encoder.
        for param in self.encoder.parameters():
            param.requires_grad = False

        # New classification head on top of the frozen features.
        self.classifier = nn.Linear(out_dim, num_classes)

    def train(self, mode=True):
        """Set the training mode while keeping a fully frozen encoder in eval mode.

        Freezing parameters does not stop normalisation layers from updating
        their running statistics (or dropout from being active) in training
        mode, which would silently change the "frozen" features. If a caller
        re-enables gradients on the encoder, the standard behaviour applies.
        """
        super().train(mode)
        if mode and not any(p.requires_grad for p in self.encoder.parameters()):
            self.encoder.eval()
        return self

    def forward(self, x):
        """Return classification logits for the input batch ``x``."""
        features = self.encoder(x)
        return self.classifier(features)

# Re-seed for reproducibility (done right before the classifier head is initialised).
seed_everything(args.seed)

# 4. Model, loss function and optimizer setup
model = FineTuningModel(encoder, out_dim = args.out_dim).to(device)
##############################

# # Ablation(3-1) LE -> BCE Loss
criterion = nn.BCEWithLogitsLoss()
# criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# # Ablation(3-2) LE -> Multi-label Focal Loss
# criterion = MultiLabelFocalLoss(
#     alpha=alpha_norm.to(device),  # normalised weights
#     gamma=2.0,                    # when using hard labels
#     reduction='mean'
# )

############################
# Only the classifier head's parameters are handed to the optimizer.
optimizer = optim.AdamW(model.classifier.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# optimizer = torch.optim.SGD(
#     model.parameters(),
#     lr=args.lr,
#     momentum=0.9,
#     weight_decay=args.weight_decay,
#     nesterov=True
# )
scheduler = CosineAnnealingLR(optimizer, T_max=args.ft_epochs, eta_min=1e-6)  # Linear Evaluation uses its own epoch count (args.ft_epochs)

# Initialise best-loss tracking.
best_loss = float('inf')
best_epoch = -1


# 5. Linear Evaluation
#
# The imports and helper functions below are set up once, ahead of the epoch
# loop, instead of being re-imported / re-defined on every epoch.
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
from sklearn.metrics import confusion_matrix


def multilabel_to_multiclass(y):
    """Collapse [N, 2] (crackle, wheeze) indicator rows into one class id per row.

    Encoding: None -> 0, Crackle -> 1, Wheeze -> 2, Both -> 3.
    """
    y = np.array(y)
    return y[:, 0] + y[:, 1]*2


def evaluate_multiclass_confusion(y_true, y_pred, class_names=("Normal", "Crackle", "Wheeze", "Both")):
    """Build the 4-class confusion matrix and derive sensitivity / specificity.

    Args:
        y_true: [N, 2] multi-label ground truth.
        y_pred: [N, 2] multi-label predictions.
        class_names: Accepted for call-site compatibility; not used here.

    Returns:
        (cm, SE, SP, y_true_cls, y_pred_cls) where SP is the fraction of
        normal samples predicted as normal and SE is the fraction of abnormal
        samples whose exact class (crackle / wheeze / both) was predicted.
    """
    y_true_cls = multilabel_to_multiclass(y_true)
    y_pred_cls = multilabel_to_multiclass(y_pred)

    cm = confusion_matrix(y_true_cls, y_pred_cls, labels=[0, 1, 2, 3])

    # Row 0 holds the truly-normal samples; rows 1-3 hold the abnormal ones.
    normal_total = cm[0].sum()
    abnormal_total = cm[1:].sum()

    # Only exact matches on the diagonal count as correct abnormal predictions.
    abnormal_correct = cm[1, 1] + cm[2, 2] + cm[3, 3]

    # The epsilon guards against an empty normal / abnormal group.
    SP = cm[0, 0] / (normal_total + 1e-6)           # specificity
    SE = abnormal_correct / (abnormal_total + 1e-6)  # sensitivity

    return cm, SE, SP, y_true_cls, y_pred_cls


def log_multiclass_conf_matrix_wandb(cm, class_names, sens, spec, normalize, tag):
    """Draw a confusion-matrix heatmap annotated with the ICBHI metrics.

    Args:
        cm: Confusion matrix (rows = true class, columns = predicted class).
        class_names: Tick labels for both axes.
        sens: Sensitivity shown in the annotation box.
        spec: Specificity shown in the annotation box.
        normalize: If True, each row is shown as a fraction of its row total.
        tag: Accepted for call-site compatibility; logging is done by the caller.

    Returns:
        The matplotlib figure; the caller is responsible for closing it.
    """
    if normalize:
        # Rows without any true sample would divide by zero and show NaN;
        # clamp the row total so that such rows are displayed as zeros.
        row_totals = np.maximum(cm.sum(axis=1, keepdims=True), 1)
        cm = cm.astype('float') / row_totals
        fmt = '.2f'
        title = "Confusion Matrix (Normalized %)"
    else:
        fmt = 'd'
        title = "Confusion Matrix (Raw Count)"

    fig, ax = plt.subplots(figsize=(7, 6))
    sns.heatmap(cm, annot=True, fmt=fmt, cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)

    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(title)

    icbhi_score = (sens + spec) / 2
    # Metrics box near the lower-right corner, positioned in this axes'
    # coordinates (not those of whatever axes happens to be current).
    ax.text(
        0.99, 0.15,
        f"Sensitivity: {sens*100:.2f}%\nSpecificity: {spec*100:.2f}%\nICBHI Score: {icbhi_score*100:.2f}%",
        ha='right', va='bottom',
        transform=ax.transAxes,
        fontsize=10, bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8)
    )

    plt.tight_layout()
    return fig


# Axis labels follow the class-id encoding of multilabel_to_multiclass.
class_names = ["Normal", "Crackle", "Wheeze", "Both"]

# Make sure the directory that receives the checkpoints exists before training.
# NOTE(review): save_ckpt_path is used as a filename *prefix* below (no path
# separator is inserted). The resulting paths are left unchanged because other
# cells may load the checkpoints using the same expression.
os.makedirs(os.path.dirname(save_ckpt_path) or ".", exist_ok=True)

for epoch in range(args.ft_epochs):

    # ===============================
    # 1. Training
    # ===============================
    model.train()
    total_loss = 0.0

    all_preds = []
    all_labels = []
    all_outputs = []

    pbar = tqdm(finetune_loader, desc='Linear Evaluation')
    for cycle, labels, _ in pbar:
        # Move the batch to the device the model was placed on.
        cycle = cycle.to(device)
        labels = labels.to(device)

        # Forward / backward pass and optimizer update.
        optimizer.zero_grad()
        output = model(cycle)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Store predictions and targets ( Ablation(4-1): decision threshold )
        predicted = (torch.sigmoid(output) > 0.5).float()
        all_preds.append(predicted.detach().cpu())
        all_labels.append(labels.detach().cpu())
        all_outputs.append(output.detach().cpu())

    # Mean training loss over the batches of this epoch.
    train_loss = total_loss / len(finetune_loader)

    # Concatenate the per-batch results.
    all_preds = torch.cat(all_preds, dim=0).numpy()    # shape: [N, 2]
    all_labels = torch.cat(all_labels, dim=0).numpy()  # shape: [N, 2]
    all_output = torch.cat(all_outputs, dim=0).numpy()

    print(f"Epoch: {epoch+1}, Train Loss: {train_loss:.4f}")

    # =====================================
    # 2. Multi-class sensitivity / specificity on the training data
    # =====================================
    cm_4x4, finetune_train_sens, finetune_train_spec, y_true_cls, y_pred_cls = evaluate_multiclass_confusion(all_labels, all_preds, class_names)
    finetune_icbhi_score = (finetune_train_sens + finetune_train_spec)/2

    print("4-Class Confusion Matrix:\n", cm_4x4)
    print(f"Sensitivity: {finetune_train_sens:.4f}, Specificity: {finetune_train_spec:.4f}, ICBHI Score: {finetune_icbhi_score:.4f}")

    # ===============================
    # 3. Validation
    # ===============================
    test_loss, test_labels, test_preds = validate(
        model, test_loader, criterion, device
    )

    precision = precision_score(test_labels, test_preds, average='macro')
    recall = recall_score(test_labels, test_preds, average='macro')
    f1 = f1_score(test_labels, test_preds, average='macro')

    test_cm_4x4, test_sens, test_spec, test_y_true_cls, test_y_pred_cls = evaluate_multiclass_confusion(test_labels, test_preds, class_names)
    test_icbhi_score = (test_sens+test_spec)/2

    print("[Validation] Confusion Matrix:\n", test_cm_4x4)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"[VALIDATION] Sensitivity: {test_sens:.4f}, Specificity: {test_spec:.4f}, Avg ICBHI Score: {(test_sens+test_spec)/2:.4f}")
    print("##################################################")

    # ===============================
    # 4. Confusion Matrix figures
    # ===============================
    fig_finetune_raw = log_multiclass_conf_matrix_wandb(cm_4x4, class_names, finetune_train_sens, finetune_train_spec, normalize=False, tag="finetune_conf_matrix_raw")
    fig_finetune_norm = log_multiclass_conf_matrix_wandb(cm_4x4, class_names, finetune_train_sens, finetune_train_spec, normalize=True, tag="finetune_conf_matrix_norm")

    fig_test_raw = log_multiclass_conf_matrix_wandb(test_cm_4x4, class_names, test_sens, test_spec, normalize=False, tag="test_conf_matrix_raw")
    fig_test_norm = log_multiclass_conf_matrix_wandb(test_cm_4x4, class_names, test_sens, test_spec, normalize=True, tag="test_conf_matrix_norm")

    # Every figure created during this epoch is tracked here and closed after logging.
    epoch_figs = [fig_finetune_raw, fig_finetune_norm, fig_test_raw, fig_test_norm]

    wandb_log_dict = {
        "finetune_conf_matrix_raw": wandb.Image(fig_finetune_raw),
        "finetune_conf_matrix_norm": wandb.Image(fig_finetune_norm),
        "test_conf_matrix_raw": wandb.Image(fig_test_raw),
        "test_conf_matrix_norm": wandb.Image(fig_test_norm)
    }

    # =====================================
    # 5. Checkpoint (Every 50 epochs)
    # =====================================
    if (epoch + 1) % 50 == 0:
        ckpt_path = save_ckpt_path + f"{finetune_project_name}_{epoch:03d}.pth.tar"
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict()
        }, ckpt_path)
        # Report the file that was actually written, not just the path prefix.
        print(f"💾 Saved checkpoint to {ckpt_path}")

    # ===============================
    # 6. Save Best Checkpoint (lowest test loss so far)
    # ===============================
    if test_loss < best_loss:
        best_loss = test_loss
        best_epoch = epoch
        best_ckpt_path = save_ckpt_path + f"{finetune_project_name}_best.pth.tar"
        torch.save({
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'loss': best_loss
        }, best_ckpt_path)
        print(f"=> Saved best checkpoint (epoch: {epoch}, loss: {best_loss:.4f})")

        # The best-epoch metrics are this epoch's test metrics, so reuse them
        # rather than recomputing the confusion matrix from the same inputs.
        cm_best, sens_best, spec_best = test_cm_4x4, test_sens, test_spec
        fig_best_raw = log_multiclass_conf_matrix_wandb(cm_best, class_names, sens_best, spec_best, normalize=False, tag="best_test_conf_matrix_raw")
        fig_best_norm = log_multiclass_conf_matrix_wandb(cm_best, class_names, sens_best, spec_best, normalize=True, tag="best_test_conf_matrix_norm")
        epoch_figs += [fig_best_raw, fig_best_norm]

        wandb_log_dict.update({
            "best_test_conf_matrix_raw": wandb.Image(fig_best_raw),
            "best_test_conf_matrix_norm": wandb.Image(fig_best_norm)
        })

    if epoch == args.ft_epochs - 1:
        # Confusion matrix of the final epoch (same test metrics as above).
        cm_last, sens_last, spec_last = test_cm_4x4, test_sens, test_spec
        fig_last_raw = log_multiclass_conf_matrix_wandb(cm_last, class_names, sens_last, spec_last, normalize=False, tag="last_test_conf_matrix_raw")
        fig_last_norm = log_multiclass_conf_matrix_wandb(cm_last, class_names, sens_last, spec_last, normalize=True, tag="last_test_conf_matrix_norm")
        epoch_figs += [fig_last_raw, fig_last_norm]

        wandb_log_dict.update({
            "last_test_conf_matrix_raw": wandb.Image(fig_last_raw),
            "last_test_conf_matrix_norm": wandb.Image(fig_last_norm)
        })

    # =====================================
    # 7. Logging with wandb
    # =====================================

    # step 1. metrics
    wandb.log({
        # Train metrics
        "Finetune/epoch": epoch,
        "Finetune/train_loss": train_loss,
        "Finetune/test_loss": test_loss,
        "Finetune/train_sens": finetune_train_sens,
        "Finetune/train_spec": finetune_train_spec,
        "Finetune/icbhi_score": finetune_icbhi_score,

        # Test metrics
        "Test/loss": test_loss,
        "Test/sensitivity": test_sens,
        "Test/specificity": test_spec,
        "Test/icbhi_score": test_icbhi_score,
        # Macro-averaged multi-label scores (previously computed but never reported).
        "Test/precision": precision,
        "Test/recall": recall,
        "Test/f1": f1
    })

    # step 2. Confusion matrix images
    wandb.log(wandb_log_dict)

    # Release the figures created in this epoch.
    for fig in epoch_figs:
        plt.close(fig)

    # ===============================
    # 8. Scheduler Step
    # ===============================
    scheduler.step()

wandb.finish()