In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
from torchvision import transforms
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
from copy import deepcopy

from dataset import DataManager

import resnet50 as resnet_models

import argparse

import os
import random

In [None]:
def get_args():
    """Parse the command line (``sys.argv``) into the experiment configuration.

    Data options describe the crop setup (RandomResizedCrop arguments) and the
    positive class prior; train options describe the architecture, head sizes,
    epochs, batch size and seed. Returns the ``argparse.Namespace`` produced by
    ``parse_args()``.
    """
    parser = argparse.ArgumentParser(description="Implementation of SwAV")

    # Each entry: (flag, keyword arguments for add_argument), registered in order.
    data_options = [
        ("--data_set", dict(type=str, default="cifar",
                            help="dataset")),
        ("--nmb_crops", dict(type=int, default=[2], nargs="+",
                             help="list of number of crops (example: [2, 6])")),
        ("--size_crops", dict(type=int, default=[32], nargs="+",
                              help="crops resolutions (example: [224, 96])")),
        ("--min_scale_crops", dict(type=float, default=[0.14], nargs="+",
                                   help="argument in RandomResizedCrop (example: [0.14, 0.05])")),
        ("--max_scale_crops", dict(type=float, default=[1], nargs="+",
                                   help="argument in RandomResizedCrop (example: [1., 0.14])")),
        ("--prior", dict(type=float, default=0.5,
                         help="positive prior")),
    ]
    train_options = [
        ("--epochs", dict(type=int, default=100)),
        ("--arch", dict(type=str, default='resnet50')),
        ("--hidden_mlp", dict(type=int, default=2048)),
        ("--feat_dim", dict(type=int, default=128)),
        ("--batch_size", dict(type=int, default=128)),
        ("--seed", dict(type=int, default=42)),
    ]

    for flag, kwargs in data_options + train_options:
        parser.add_argument(flag, **kwargs)

    return parser.parse_args()

In [None]:
def seeds(seed=42):
    """
    Seed every random number generator used here and make cuDNN deterministic.

    Covers PYTHONHASHSEED, Python's ``random``, NumPy's global RNG and PyTorch
    (CPU, the current CUDA device, and all CUDA devices).
    """
    # Exported for child processes; the running interpreter's hash seed is fixed at startup.
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seed_fn in (random.seed,
                    np.random.seed,
                    torch.manual_seed,
                    torch.cuda.manual_seed,
                    torch.cuda.manual_seed_all):  # last one covers multi-GPU
        seed_fn(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

# Setup: arguments, data, and seeds

In [None]:
import sys

# Simulated command line consumed by get_args(): flag -> value, in order.
_cli_overrides = {
    '--data_set': 'living17',
    '--hidden_mlp': '2048',
    '--feat_dim': '128',
    '--arch': 'resnet50',
    '--prior': '0.5',
    '--seed': '42',
}
sys.argv = ['script.py'] + [token for pair in _cli_overrides.items() for token in pair]

In [None]:
args = get_args()
# Seed right after parsing so that everything downstream, including the
# DataManager construction in the next cell (in case it samples the PU split
# randomly), is reproducible. Previously seeds() only ran after the data was built.
seeds(args.seed)

In [None]:
# Build the train/test splits; train and test use complementary priors.
manager = DataManager(
    train_dataset='living17',  # NOTE(review): hard-coded, unlike test_dataset — confirm intended
    test_dataset=args.data_set,
    batch_size=64,
    train_prior=args.prior,
    test_prior=1. - args.prior,
)
(mv_dataset, sv_dataset, train_dataset,
 observed_subset, test_dataset) = manager.get_data()
torch.cuda.empty_cache()

In [None]:
# Per-dataset input normalization (one mean/std entry per channel) and channel count.
if args.data_set in ['cifarv2', 'cifar10c', 'cinic']:
    args.mean = [0.4914, 0.4822, 0.4465]
    args.std = [0.2470, 0.2435, 0.2616]
    args.channels = 3
elif args.data_set in ['usps', 'svhn']:
    args.mean = [0.5]
    args.std = [0.5]
    args.channels = 1
elif args.data_set in ['entity13', 'living17']:
    args.mean = [0.485, 0.456, 0.406]
    args.std = [0.229, 0.224, 0.225]
    args.channels = 3
elif args.data_set == 'camelyon17':
    # No normalization statistics are set for camelyon17.
    args.mean = None
    args.std = None
    args.channels = 3
else:
    # Fail with an actionable message. Note that the argparse default
    # ('cifar') is not one of the handled names.
    raise ValueError(
        f"Unsupported --data_set {args.data_set!r}; expected one of: "
        "cifarv2, cifar10c, cinic, usps, svhn, entity13, living17, camelyon17"
    )

In [None]:
# Re-seed all RNGs from the parsed configuration before model loading/training.
seeds(seed=args.seed)

In [None]:
# Sanity check: shape of the raw data array behind train_dataset
# (train_dataset looks like a Subset-style wrapper exposing `.dataset` — TODO confirm).
train_dataset.dataset.data.shape

# Load CL model checkpoints

In [None]:
import glob
import os
import re
import torch

def get_models(base_path):
    """Load every checkpoint that matches the current experiment configuration.

    Looks in ``<base_path>/<args.data_set>`` for files named
    ``{data_set}_{prior}_{seed}*.pth`` (values taken from the global ``args``),
    builds a ``resnet_models.<args.arch>`` model for each and loads the saved
    state_dict into it.

    Args:
        base_path: root directory that contains one sub-folder per dataset.

    Returns:
        ``(ckpt_pth, models)``: the matching checkpoint paths in lexicographic
        order and the corresponding models (on CPU, in eval mode), aligned
        index-by-index. Both lists are empty when nothing matches.
    """
    folder_path = os.path.join(base_path, args.data_set)

    prefix = f"{args.data_set}_{args.prior}_{args.seed}"
    candidates = sorted(glob.glob(os.path.join(folder_path, prefix + "*.pth")))
    # The glob alone would also match a longer seed that shares the same
    # leading digits (seed 4 would pick up seed 42's files), so require that the
    # seed is not immediately followed by another digit.
    seed_boundary = re.compile(re.escape(prefix) + r"(?!\d)")
    ckpt_pth = [p for p in candidates if seed_boundary.match(os.path.basename(p))]

    models = []
    for pth in ckpt_pth:
        # The files hold a bare state_dict, so it can be loaded directly.
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints
        # (on torch>=1.13 consider weights_only=True).
        state_dict = torch.load(pth, map_location="cpu")
        model = resnet_models.__dict__[args.arch](
            normalize=True,
            hidden_mlp=args.hidden_mlp,
            output_dim=args.feat_dim,
            in_channels=args.channels
        )
        model.load_state_dict(state_dict)
        model.eval()
        models.append(model)

    return ckpt_pth, models

In [None]:
CKPT_ROOT = './model_log_l4_con'
ckpt_pth, models = get_models(CKPT_ROOT)
# Fail here with a clear message instead of an IndexError at `models[-1]` later.
if not models:
    raise FileNotFoundError(
        f"No checkpoints for data_set={args.data_set!r}, prior={args.prior}, "
        f"seed={args.seed} under {os.path.join(CKPT_ROOT, args.data_set)}"
    )
len(models)

In [None]:
# Checkpoint paths, in the lexicographically sorted order returned by get_models.
ckpt_pth

# Pretrain a linear classifier on frozen CL features

In [None]:
import copy
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import f1_score, roc_auc_score


# ------------------------------------------------------------
# Metric computation
# ------------------------------------------------------------
@torch.no_grad()
def compute_metrics(loader, cl_model, classifier, device):
    """Evaluate a single-logit classifier head on top of a frozen backbone.

    Args:
        loader: iterable of ``(images, labels, _)`` batches; ``labels[0]`` holds
            the ground-truth targets compared against the predictions.
        cl_model: backbone exposing ``forward_backbone(images)``.
        classifier: module mapping backbone features to one logit per sample.
        device: device that images and targets are moved to.

    Returns:
        ``(acc, macro_f1, pos_count, auc)``. ``pos_count`` is the number of
        samples predicted positive (logit > 0). ``auc`` is 0.0 whenever
        ``roc_auc_score`` raises ValueError (e.g. only one class present).
        An empty loader yields ``(0.0, 0.0, 0, 0.0)``.
    """
    classifier.eval()
    cl_model.eval()

    all_preds, all_targets, all_probs = [], [], []
    for images, labels, _ in loader:
        images = images.to(device, non_blocking=True)
        targets = labels[0].to(device, non_blocking=True)

        feats = cl_model.forward_backbone(images)
        logits = classifier(feats).squeeze(-1)
        probs = torch.sigmoid(logits)

        preds = (logits > 0).long()

        all_preds.append(preds.cpu())
        all_targets.append(targets.cpu())
        all_probs.append(probs.cpu())

    if not all_targets:
        # Same 4-tuple shape as the normal return so callers can always unpack
        # it (this used to return five values).
        return 0.0, 0.0, 0, 0.0

    all_preds = torch.cat(all_preds).numpy()
    all_targets = torch.cat(all_targets).numpy()
    all_probs = torch.cat(all_probs).numpy()

    acc = float((all_preds == all_targets).mean())
    macro_f1 = f1_score(all_targets, all_preds, average='macro', zero_division=0)
    pos_count = int(all_preds.sum())
    try:
        auc = roc_auc_score(all_targets, all_probs)
    except ValueError:
        auc = 0.0
    return acc, macro_f1, pos_count, auc

# ------------------------------------------------------------
# ModelEMA without buffers (EMA over parameters only)
# ------------------------------------------------------------
class ModelEMA(nn.Module):
    """Parameter-only exponential moving average of a model.

    Keeps a frozen deep copy (``ema_model``) that is never optimized directly.
    After each optimizer step, call ``update(base_model)`` to blend the base
    model's trainable parameters into the copy:
    ``ema = decay * ema + (1 - decay) * base``. Buffers are not averaged.
    """

    def __init__(self, model: nn.Module, decay: float = 0.999):
        super().__init__()
        self.decay = float(decay)
        shadow = deepcopy(model)
        shadow.requires_grad_(False)  # the copy is only ever written by update()
        shadow.eval()
        self.ema_model = shadow

    @torch.no_grad()
    def update(self, base_model: nn.Module):
        """Blend ``base_model``'s trainable parameters into the EMA copy (pairs matched by order)."""
        keep = self.decay
        blend = 1.0 - self.decay
        for src, dst in zip(base_model.parameters(), self.ema_model.parameters()):
            if src.requires_grad:
                dst.data.mul_(keep).add_(src.data, alpha=blend)

    def forward(self, x):
        """Run the averaged copy."""
        return self.ema_model(x)

    def to(self, *args, **kwargs):
        """Move the averaged copy and return ``self`` so calls can be chained."""
        self.ema_model.to(*args, **kwargs)
        return self

    def eval(self):
        """Put both the wrapper and the averaged copy in eval mode."""
        self.ema_model.eval()
        return super().eval()

# ------------------------------------------------------------
# Main training loop
# ------------------------------------------------------------
def pre_main(
    cl_model: nn.Module,
    train_loader,
    test_loader,
    *,
    epochs: int = 100,
    lr: float = 1e-3,
    neg_weight: float = 0.1,
    use_ema: bool = True,
    ema_decay: float = 0.999,
    ema_start_epoch: int = 10,
    eval_with_ema_only: bool = False  # True면 출력은 EMA 기준으로만
):
    """
    - 학습: base classifier로만 업데이트
    - step 후: ema_model.update(base)로 EMA 파라미터 갱신
    - 평가: ema_start_epoch 이후엔 ema_model로 평가(옵션에 따라 base도 함께 출력)
    """
    global args
    device = 'cuda'

    # 1) backbone 고정
    cl_model.to(device).eval()
    for p in cl_model.parameters():
        p.requires_grad = False

    # 2) classifier (base)
    feature_dim = 16 if getattr(args, "arch", "") == "lenet" else 2048
    classifier = nn.Linear(feature_dim, 1).to(device)

    # 3) 손실 / 옵티마이저
    pn_loss = nn.BCEWithLogitsLoss(reduction='none').to(device)
    optimizer = optim.Adam(classifier.parameters(), lr=lr)

    # 4) EMA 모델(teacher)
    ema_model = ModelEMA(classifier, decay=ema_decay).to(device) if use_ema else None

    # -------------------------
    # 학습 루프
    # -------------------------
    for epoch in range(1, epochs + 1):
        classifier.train()
        cl_model.eval()

        total_loss = 0.0
        num_batches = 0

        for images, labels, _ in train_loader:
            images = images.to(device, non_blocking=True)
            y_pu = labels[1].to(device, non_blocking=True)   # {+1, -1}
            targets = (y_pu == 1).float()                   # {1, 0}

            with torch.no_grad():
                feats = cl_model.forward_backbone(images)

            logits = classifier(feats).squeeze(-1)
            per_sample_loss = pn_loss(logits, targets)

            # negative 샘플 가중치
            weights = torch.ones_like(per_sample_loss)
            weights = torch.where(targets == 0,
                                  torch.full_like(per_sample_loss, float(neg_weight)),
                                  weights)
            loss = (weights * per_sample_loss).mean()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # EMA 파라미터 갱신
            if use_ema and epoch >= ema_start_epoch:
                ema_model.update(classifier)

            total_loss += loss.item()
            num_batches += 1

        avg_loss = total_loss / max(1, num_batches)

        # -------------------------
        # 평가 & 로그 출력
        # -------------------------
        tag = "EMA-ON" if (use_ema and epoch >= ema_start_epoch) else "EMA-OFF"
        
        if use_ema and epoch >= ema_start_epoch:
            # EMA 메트릭
            tr_acc_e, tr_mf1_e, tr_pos_e, tr_auc_e = compute_metrics(
                train_loader, cl_model, ema_model, device
            )
            te_acc_e, te_mf1_e, te_pos_e, te_auc_e = compute_metrics(
                test_loader, cl_model, ema_model, device
            )
        
        # BASE 메트릭(항상 계산)
        tr_acc_b, tr_mf1_b, tr_pos_b, tr_auc_b = compute_metrics(
            train_loader, cl_model, classifier, device
        )
        te_acc_b, te_mf1_b, te_pos_b, te_auc_b = compute_metrics(
            test_loader, cl_model, classifier, device
        )

        print('='*70, flush=True)
        if use_ema and epoch >= ema_start_epoch and not eval_with_ema_only:
            # ----- BASE + EMA 둘 다 출력 -----
            print(
                f"[Epoch {epoch:3d}/{epochs}] [{tag}]\n"
                f"Loss: {avg_loss:.4f}\n"
                f"(BASE) Train Acc: {tr_acc_b:.4f}  Macro-F1: {tr_mf1_b:.4f}  AUC: {tr_auc_b:.4f}  Pos#: {tr_pos_b} |\n"
                f"(BASE) Test  Acc: {te_acc_b:.4f}  Macro-F1: {te_mf1_b:.4f}  AUC: {te_auc_b:.4f}  Pos#: {te_pos_b} |\n"
                f"(EMA)  Train Acc: {tr_acc_e:.4f}  Macro-F1: {tr_mf1_e:.4f}  AUC: {tr_auc_e:.4f}  Pos#: {tr_pos_e} |\n"
                f"(EMA)  Test  Acc: {te_acc_e:.4f}  Macro-F1: {te_mf1_e:.4f}  AUC: {te_auc_e:.4f}  Pos#: {te_pos_e} |\n",
                flush=True
            )
        elif use_ema and epoch >= ema_start_epoch and eval_with_ema_only:
            # ----- EMA만 출력 -----
            print(
                f"[Epoch {epoch:3d}/{epochs}] [{tag}]\n"
                f"Loss: {avg_loss:.4f}\n"
                f"(EMA) Train Acc: {tr_acc_e:.4f}  Macro-F1: {tr_mf1_e:.4f}  AUC: {tr_auc_e:.4f}  Pos#: {tr_pos_e} |\n"
                f"(EMA) Test  Acc: {te_acc_e:.4f}  Macro-F1: {te_mf1_e:.4f}  AUC: {te_auc_e:.4f}  Pos#: {te_pos_e} |\n",
                flush=True
            )
        else:
            # ----- BASE만 출력 -----
            print(
                f"[Epoch {epoch:3d}/{epochs}] [{tag}]\n"
                f"Loss: {avg_loss:.4f}\n"
                f"(BASE) Train Acc: {tr_acc_b:.4f}  Macro-F1: {tr_mf1_b:.4f}  AUC: {tr_auc_b:.4f}  Pos#: {tr_pos_b} |\n"
                f"(BASE) Test  Acc: {te_acc_b:.4f}  Macro-F1: {te_mf1_b:.4f}  AUC: {te_auc_b:.4f}  Pos#: {te_pos_b} |\n",
                flush=True
            )



    return classifier, ema_model

In [None]:
# Independent copy of the last checkpoint, so that freezing/moving it later
# does not touch the entries in `models`.
# NOTE(review): `models` follows the lexicographic order of the checkpoint paths,
# so models[-1] is the "latest" epoch only if the file names sort that way
# (e.g. zero-padded epoch numbers) — confirm.
cl_model = deepcopy(models[-1])

In [None]:
# Loaders for training/evaluating the linear head (fixed batch size of 512).
# Training shuffles and drops the last partial batch; testing keeps order and every sample.
LOADER_BATCH_SIZE = 512
train_loader = DataLoader(train_dataset, batch_size=LOADER_BATCH_SIZE,
                          shuffle=True, drop_last=True)
test_loader = DataLoader(test_dataset, batch_size=LOADER_BATCH_SIZE)

In [None]:
# Train the linear head with the default schedule; returns (base head, EMA head).
pre_clf, pre_ema = pre_main(
    cl_model,
    train_loader=train_loader,
    test_loader=test_loader,
)

# Misclassification analysis (t-SNE)

In [None]:
import os
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.manifold import TSNE
from PIL import Image
import numpy as np
import torch

def to_numpy(x):
    """Convert a tensor-like object to a NumPy array; return anything else unchanged.

    Objects exposing ``.detach()`` (e.g. torch tensors, including ones that
    require grad or live on the GPU) are detached, moved to CPU and converted.
    Other inputs (NumPy arrays, lists, ...) are returned as-is.
    """
    # Every torch.Tensor has .detach(), so this single branch also covers the
    # former `torch.is_tensor` branch, which was unreachable.
    if hasattr(x, "detach"):
        return x.detach().cpu().numpy()
    return x

def extract_embeddings(model, loader, device='cuda'):
    """Run the backbone over ``loader`` and collect features plus both label views.

    Args:
        model: module exposing ``forward_backbone(x)``; put in eval mode and
            moved to ``device``.
        loader: yields ``(x, (y_pn, y_pu), _)`` batches.
        device: where the forward pass runs (default 'cuda', as before).

    Returns:
        ``(embeddings, labels_pn, labels_pu)`` concatenated over all batches;
        embeddings are returned on CPU.
    """
    model.eval()
    model = model.to(device)

    embeddings, labels_pn, labels_pu = [], [], []
    with torch.no_grad():
        for x, (y_pn, y_pu), _ in loader:
            feat = model.forward_backbone(x.to(device))
            embeddings.append(feat.cpu())
            labels_pn.append(y_pn)
            labels_pu.append(y_pu)

    return torch.cat(embeddings), torch.cat(labels_pn), torch.cat(labels_pu)

def _save_under_mb(fig, base_path, dpi=300, max_mb=4):
    """fig를 base_path.(png/jpg/webp) 중 하나로 저장하되, 최종 파일이 max_mb 이하여야 함."""
    target_bytes = max_mb * 1024 * 1024

    # 1) 우선 PNG로 저장
    png_path = base_path + ".png"
    fig.savefig(png_path, dpi=dpi, bbox_inches='tight', facecolor='white')
    if os.path.getsize(png_path) <= target_bytes:
        return png_path

    # 2) PNG 팔레트 양자화(256색)로 용량 감소
    img = Image.open(png_path).convert("RGB")
    quant = img.quantize(colors=256, method=Image.MEDIANCUT)
    quant.save(png_path, optimize=True)
    if os.path.getsize(png_path) <= target_bytes:
        return png_path

    # 3) JPEG로 품질 낮추며 시도
    for q in [95, 90, 85, 80, 75, 70, 65, 60]:
        jpg_path = base_path + ".jpg"
        img.save(jpg_path, format="JPEG", quality=q, optimize=True, progressive=True)
        if os.path.getsize(jpg_path) <= target_bytes:
            os.remove(png_path)
            return jpg_path

    # 4) WebP로도 시도
    for q in [90, 80, 70, 60]:
        webp_path = base_path + ".webp"
        img.save(webp_path, format="WEBP", quality=q, method=6)
        if os.path.getsize(webp_path) <= target_bytes:
            os.remove(png_path)
            return webp_path

    # 5) 마지막 수단: 다운스케일 + JPEG
    w, h = img.size
    scale = 0.9
    while scale > 0.4:
        w2, h2 = int(w * scale), int(h * scale)
        down = img.resize((w2, h2), Image.LANCZOS)
        down.save(jpg_path, format="JPEG", quality=85, optimize=True, progressive=True)
        if os.path.getsize(jpg_path) <= target_bytes:
            os.remove(png_path)
            return jpg_path
        scale -= 0.1

    # 그래도 안되면 가장 작은 파일을 반환(실무상 여기까지 오기 힘듦)
    sizes = [(p, os.path.getsize(p)) for p in [png_path, jpg_path] if os.path.exists(p)]
    return min(sizes, key=lambda x: x[1])[0] if sizes else png_path

def visualize_tsne_train_test(model, loader_train, loader_test,
                              save_dir='./tsne', save=False,
                              dpi=300, max_mb=4):
    """Plot a joint t-SNE of train/test backbone embeddings in a 3x3 grid.

    Rows: train PN labels (P / N / all), test PN labels (P / N / all), and
    train PU labels (P / U / all). Train and test points come from a single
    t-SNE fit, so their coordinates are comparable, and all panels share the
    same axis limits.

    Args:
        model: backbone exposing ``forward_backbone``.
        loader_train, loader_test: yield ``(x, (y_pn, y_pu), _)`` batches.
        save_dir: output directory used when ``save`` is true (``None`` disables saving).
        save: whether to write the figure via ``_save_under_mb``.
        dpi, max_mb: forwarded to ``_save_under_mb``.

    Returns:
        The saved file path, or ``None`` when nothing was saved.
    """
    emb_tr, pn_tr, pu_tr = extract_embeddings(model, loader_train)
    emb_te, pn_te, pu_te = extract_embeddings(model, loader_test)

    # One t-SNE fit over train+test so both splits live in the same 2-D space.
    all_emb = torch.cat([emb_tr, emb_te])
    emb_2d = TSNE(n_components=2, random_state=42).fit_transform(to_numpy(all_emb))
    n_tr = emb_tr.shape[0]
    emb_tr_2d = emb_2d[:n_tr]
    emb_te_2d = emb_2d[n_tr:]

    fig, axes = plt.subplots(3, 3, figsize=(18, 18))

    # Colour palette for the PN labels (one hsv colour per distinct label).
    unique_labels = np.unique(to_numpy(torch.cat([pn_tr, pn_te])))
    palette_pn = dict(zip(unique_labels, sns.color_palette("hsv", len(unique_labels))))

    # Colour palette for the PU labels (U = -1, labeled P = 1).
    palette_pu = {
        -1: 'tab:orange',   # U
        1: 'tab:blue'       # labeled P
    }

    # Shared x/y limits over all points, padded by 5% on each side.
    all_x = emb_2d[:, 0]
    all_y = emb_2d[:, 1]
    x_min, x_max = np.min(all_x), np.max(all_x)
    y_min, y_max = np.min(all_y), np.max(all_y)
    pad_x = (x_max - x_min) * 0.05
    pad_y = (y_max - y_min) * 0.05
    x_min, x_max = x_min - pad_x, x_max + pad_x
    y_min, y_max = y_min - pad_y, y_max + pad_y

    def plot(ax, emb, labels, title, palette):
        """Scatter ``emb`` coloured by ``labels`` on ``ax`` using the shared limits."""
        labels_np = to_numpy(labels)
        sns.scatterplot(
            x=emb[:, 0], y=emb[:, 1], hue=labels_np,
            palette=palette, alpha=0.6, s=20,
            legend='full', ax=ax
        )
        ax.set_title(title)
        ax.set_xlabel("Dim 1")
        ax.set_ylabel("Dim 2")
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.legend(title="Class", bbox_to_anchor=(1.05, 1), loc='upper left')

    # Row 0: train, PN labels.
    plot(axes[0, 0], emb_tr_2d[pn_tr == 1], pn_tr[pn_tr == 1], "Train (P)", palette_pn)
    plot(axes[0, 1], emb_tr_2d[pn_tr == 0], pn_tr[pn_tr == 0], "Train (N)", palette_pn)
    plot(axes[0, 2], emb_tr_2d, pn_tr, "Train (PN)", palette_pn)

    # Row 1: test, PN labels.
    plot(axes[1, 0], emb_te_2d[pn_te == 1], pn_te[pn_te == 1], "Test (P)", palette_pn)
    plot(axes[1, 1], emb_te_2d[pn_te == 0], pn_te[pn_te == 0], "Test (N)", palette_pn)
    plot(axes[1, 2], emb_te_2d, pn_te, "Test (PN)", palette_pn)

    # Row 2: train, PU labels. The combined panel draws samples in descending
    # index order, so lower-index samples are plotted last (on top).
    idx_sorted = torch.arange(len(pu_tr)).flip(0)
    plot(axes[2, 0], emb_tr_2d[pu_tr == 1], pu_tr[pu_tr == 1], "Train (P)", palette_pu)
    plot(axes[2, 1], emb_tr_2d[pu_tr == -1], pu_tr[pu_tr == -1], "Train (U)", palette_pu)
    plot(axes[2, 2], emb_tr_2d[idx_sorted], pu_tr[idx_sorted], "Train (PU)", palette_pu)

    plt.tight_layout()

    saved = None
    if save and save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)
        prior_str = f"{float(args.prior):g}" if hasattr(args, "prior") else "prior"
        base = f"{args.data_set}_{prior_str}_{args.seed}" if hasattr(args, "data_set") else "tsne"
        saved = _save_under_mb(fig, os.path.join(save_dir, base), dpi=dpi, max_mb=max_mb)

    plt.show()
    return saved

In [None]:
# Unshuffled loader over the full training set (no dropped last batch), used for plotting.
tr_loader = DataLoader(train_dataset, batch_size=512, drop_last=False)

In [None]:
# Backbone-only t-SNE (PN / PU labels), not saved to disk.
visualize_tsne_train_test(model=cl_model, loader_train=tr_loader, loader_test=test_loader);

In [None]:
import matplotlib.colors as mcolors

def visualize_tsne_train_test2(model, classifier, loader_train, loader_test,
                              save_dir='./tsne_wpn', save=False, ema=False,
                              dpi=300, max_mb=4, device='cuda'):
    """Joint train/test t-SNE with the classifier's mistakes highlighted.

    3x3 grid. Rows are train PN (P / N / all), test PN (P / N / all) and
    train PU (P / U / all). In the PN rows each point keeps its ground-truth
    colour, and points the classifier gets wrong are drawn in a lightened
    version of that colour. The PU panels show labels only.

    Args:
        model: backbone exposing ``forward_backbone``.
        classifier: head mapping features to one logit per sample (e.g. the
            base head or the EMA wrapper returned by ``pre_main``).
        loader_train, loader_test: yield ``(x, (y_pn, y_pu), _)`` batches.
        save_dir, save: when ``save`` is true, write the figure under ``save_dir``.
        ema: adds ``_EMA`` to the saved file name.
        dpi, max_mb: forwarded to ``_save_under_mb``.
        device: device for the forward passes (default 'cuda', as before).

    Returns:
        The saved file path when ``save`` is true, otherwise ``None``.
    """
    model, classifier = model.to(device), classifier.to(device)
    model.eval()
    classifier.eval()

    def extract_with_pred(loader):
        """Return ``(features, pn_labels, pu_labels, preds)`` over ``loader``; all on CPU."""
        feats, pn_labels, pu_labels, preds = [], [], [], []
        with torch.no_grad():
            for x, y, _ in loader:
                y_pn, y_pu = y
                feat = model.forward_backbone(x.to(device))
                logit = classifier(feat).squeeze(-1)
                # Threshold the logit itself, exactly like compute_metrics.
                # `sigmoid(logit) >= 0.5` also labels logit == 0, and tiny negative
                # logits whose float32 sigmoid rounds to 0.5, as positive.
                pred = (logit > 0).long().cpu()

                feats.append(feat.cpu())
                pn_labels.append(y_pn)
                pu_labels.append(y_pu)
                preds.append(pred)

        return (torch.cat(feats),
                torch.cat(pn_labels),
                torch.cat(pu_labels),
                torch.cat(preds))

    emb_tr, pn_tr, pu_tr, pred_tr = extract_with_pred(loader_train)
    emb_te, pn_te, pu_te, pred_te = extract_with_pred(loader_test)

    # One t-SNE fit over train+test so both splits live in the same 2-D space.
    all_emb = torch.cat([emb_tr, emb_te])
    emb_2d = TSNE(n_components=2, random_state=42).fit_transform(to_numpy(all_emb))
    n_tr = emb_tr.shape[0]
    emb_tr_2d = emb_2d[:n_tr]
    emb_te_2d = emb_2d[n_tr:]

    fig, axes = plt.subplots(3, 3, figsize=(18, 18))

    # PN palette: one "dark" colour per label, assigned in reverse order.
    unique_labels = np.unique(to_numpy(torch.cat([pn_tr, pn_te])))
    palette_list = sns.color_palette("dark", len(unique_labels))
    palette_pn = dict(zip(unique_labels, reversed(palette_list)))

    # PU palette: U in orange, labeled P in the same colour as PN positives.
    palette_pu = {-1: 'tab:orange', 1: palette_pn[1]}

    # Shared x/y limits over all points, padded by 5% on each side.
    all_x, all_y = emb_2d[:, 0], emb_2d[:, 1]
    x_min, x_max = np.min(all_x), np.max(all_x)
    y_min, y_max = np.min(all_y), np.max(all_y)
    pad_x, pad_y = (x_max - x_min) * 0.05, (y_max - y_min) * 0.05
    x_min, x_max = x_min - pad_x, x_max + pad_x
    y_min, y_max = y_min - pad_y, y_max + pad_y

    def lighten_color(color, factor=0.5):
        """Blend ``color`` towards white: factor=0.0 keeps it, factor=1.0 gives white."""
        base = np.array(mcolors.to_rgb(color))
        white = np.array([1, 1, 1])
        return tuple(base + (white - base) * factor)

    def plot(ax, emb, labels, preds, title, palette, gt_mode=True):
        """Scatter ``emb`` by label colour; with ``gt_mode`` and ``preds``, lighten wrong points."""
        labels_np = to_numpy(labels)

        if preds is not None and gt_mode:
            preds_np = to_numpy(preds)
            correct_mask = labels_np == preds_np
        else:
            correct_mask = np.ones_like(labels_np, dtype=bool)

        # Ground-truth colour of every point as an RGB tuple.
        colors = np.array([mcolors.to_rgb(palette[l]) for l in labels_np])

        # Correctly classified points.
        ax.scatter(
            emb[correct_mask, 0], emb[correct_mask, 1],
            c=colors[correct_mask], alpha=0.6, s=20, marker="o", label=None
        )

        # Misclassified points, in a lighter shade of their true colour.
        light_colors = [lighten_color(c, factor=0.6) for c in colors[~correct_mask]]
        ax.scatter(
            emb[~correct_mask, 0], emb[~correct_mask, 1],
            c=light_colors, alpha=0.6, s=20, marker="o",
            label=None
        )

        ax.set_facecolor("gray")
        ax.set_title(title)
        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_xlabel("Dim 1")
        ax.set_ylabel("Dim 2")

        # Legend: one marker per label present in this panel.
        handles = [
            plt.Line2D([], [], marker="o", color=palette[l], linestyle="", label=str(l))
            for l in np.unique(labels_np)
        ]
        ax.legend(handles=handles, title="Class", bbox_to_anchor=(1.05, 1), loc='upper left')

    # Row 0: train, PN labels.
    plot(axes[0, 0], emb_tr_2d[pn_tr == 1], pn_tr[pn_tr == 1], pred_tr[pn_tr == 1], "Train (P)", palette_pn)
    plot(axes[0, 1], emb_tr_2d[pn_tr == 0], pn_tr[pn_tr == 0], pred_tr[pn_tr == 0], "Train (N)", palette_pn)
    plot(axes[0, 2], emb_tr_2d, pn_tr, pred_tr, "Train (PN)", palette_pn)

    # Row 1: test, PN labels.
    plot(axes[1, 0], emb_te_2d[pn_te == 1], pn_te[pn_te == 1], pred_te[pn_te == 1], "Test (P)", palette_pn)
    plot(axes[1, 1], emb_te_2d[pn_te == 0], pn_te[pn_te == 0], pred_te[pn_te == 0], "Test (N)", palette_pn)
    plot(axes[1, 2], emb_te_2d, pn_te, pred_te, "Test (PN)", palette_pn)

    # Row 2: train, PU labels (no correctness shading). The combined panel draws
    # samples in descending index order, so lower-index samples end up on top.
    idx_sorted = torch.arange(len(pu_tr)).flip(0)
    plot(axes[2, 0], emb_tr_2d[pu_tr == 1], pu_tr[pu_tr == 1], None, "Train (P)", palette_pu, gt_mode=False)
    plot(axes[2, 1], emb_tr_2d[pu_tr == -1], pu_tr[pu_tr == -1], None, "Train (U)", palette_pu, gt_mode=False)
    plot(axes[2, 2], emb_tr_2d[idx_sorted], pu_tr[idx_sorted], None, "Train (PU)", palette_pu, gt_mode=False)

    plt.tight_layout()

    saved = None
    if save:
        os.makedirs(save_dir, exist_ok=True)
        prior_str = f"{float(args.prior):g}" if hasattr(args, "prior") else "prior"
        if not ema:
            base = f"{args.data_set}_{prior_str}_{args.seed}" if hasattr(args, "data_set") else "tsne"
        else:
            base = f"{args.data_set}_{prior_str}_EMA_{args.seed}" if hasattr(args, "data_set") else "tsne"
        saved = _save_under_mb(fig, os.path.join(save_dir, base), dpi=dpi, max_mb=max_mb)

    plt.show()
    return saved

In [None]:
# Misclassification t-SNE for the base (non-EMA) head; saved under the default ./tsne_wpn.
visualize_tsne_train_test2(cl_model, pre_clf, tr_loader, test_loader,
                           save=True);

In [None]:
# Same plot for the EMA head; `ema=True` adds "_EMA" to the saved file name.
visualize_tsne_train_test2(cl_model, pre_ema, tr_loader, test_loader,
                           save=True, ema=True);

# run