In [None]:
import sys
from collections import defaultdict, Counter
import random

from sklearn.metrics import confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from  torch.utils.data import DataLoader
from torchvision import datasets, transforms
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.manifold import TSNE
import umap
from torchvision.models.resnet import BasicBlock, Bottleneck

import utils
import vision_transformer as vits
from vision_transformer import DINOHead, VisionTransformer

from my_utils import myResNet, ReturnEmbWrapper
if "../" not in sys.path:
    sys.path.append("../")
from dataloaders import load

from importlib import import_module, reload

In [None]:
def get_datasets(name="MNIST", normal_class=1, seed=0, batch_size=100, num_workers=8):
    """Build the PU train/validation data and the labelled test subset for one dataset.

    Args:
        name: one of "MNIST", "CIFAR10", "FashionMNIST", "SVHN".
        normal_class: class index treated as normal (passed to ``load`` as a
            one-element list); class 0 is always passed as ``unseen_anomaly``.
        seed: seed forwarded to ``load`` for subset selection.
        batch_size: batch size of the two returned DataLoaders.
        num_workers: worker processes of the two returned DataLoaders.

    Returns:
        (pu_dataset, pu_val_dataset, pu_loader, pu_val_loader, test_dataset);
        both loaders are unshuffled and keep the last partial batch.

    Raises:
        ValueError: if ``name`` is not a supported dataset (previously this
            surfaced as an UnboundLocalError on ``stats``).
    """
    # Per-channel (mean, std) used to normalise inputs for each dataset.
    stats_by_name = {
        "MNIST": ((0.1307,), (0.3081,)),
        "CIFAR10": ((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        "FashionMNIST": ((0.2860,), (0.3530,)),
        "SVHN": ((0.4377, 0.4438, 0.4728), (0.198, 0.201, 0.197)),
    }
    if name not in stats_by_name:
        raise ValueError(f"unsupported dataset {name!r}; expected one of {sorted(stats_by_name)}")
    stats = stats_by_name[name]

    val_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*stats),
    ])
    # NOTE(review): the test subset is built without `transform=val_transform`, so it
    # uses whatever default transform `load` applies — confirm this is intended.
    test_dataset = load(name, normal_class=[normal_class], unseen_anomaly=[0], return_test_subset=True, seed=seed, return_id=True)
    pu_dataset, pu_val_dataset = load(
        name=name,
        batch_size=False,
        normal_class=[normal_class],
        unseen_anomaly=[0],
        labeled_anomaly_class=False,
        n_train=4500,
        n_valid=500,
        n_unlabeled_normal=4500,
        n_unlabeled_anomaly=250,
        n_labeled_anomaly=250,
        return_extra_test_loader=False,
        return_subset=True,
        return_unl_pos_subset=False,
        transform=val_transform,
        seed=seed,
        return_id=False,
    )
    loader_kwargs = dict(batch_size=batch_size, shuffle=False, num_workers=num_workers,
                         drop_last=False, pin_memory=True)
    pu_loader = DataLoader(dataset=pu_dataset, **loader_kwargs)
    pu_val_loader = DataLoader(dataset=pu_val_dataset, **loader_kwargs)
    return pu_dataset, pu_val_dataset, pu_loader, pu_val_loader, test_dataset

def get_result(model, cp_path, dataset="MNIST", prototype_vec_weight=None, normal_class=1, seed=0, varbose=True):
    """Evaluate a loaded teacher on one normal class and return the test AUROC.

    Pipeline: build the PU loaders and test subset (``get_datasets``), rank the
    prototypes on the PU data (``get_prototype_vec_ids``), score the test subset
    (``calc_y_score``) and compute ROC/AUROC. With ``varbose`` it also prints
    intermediate values and draws the ROC curve and diagnostic plots.

    Args:
        model: teacher whose weights are already loaded.
        cp_path: not read here; kept for backward compatibility (load weights with
            ``load_pretrained_backbone_head`` beforehand).
        dataset: dataset name understood by ``get_datasets``.
        prototype_vec_weight: prototype layer, or None to use the model head.
        normal_class: label treated as normal (y_true = 0).
        seed: forwarded to ``get_datasets``.
        varbose: print diagnostics and draw plots.

    Returns:
        float: ROC AUC of the anomaly scores (label 1 = anomaly).
    """
    print("preparing dataset")
    _, _, pu_loader, pu_val_loader, test_dataset = get_datasets(dataset, normal_class, seed=seed)
    print("obtaining prototype vecs")
    pu_unlabel_vec_dict, pu_positive_vec_dict, pu_normal_prototype_dict, pu_normal_prototype_dict_id = get_prototype_vec_ids(model, pu_loader, pu_val_loader, prototype_vec_weight)

    if varbose:
        def _preview(vec_dict):
            # First five (prototype id, rounded mean activation) pairs; float() keeps
            # the printed values plain numbers rather than numpy scalar reprs.
            return list(vec_dict.keys())[:5], [round(float(v), 4) for v in list(vec_dict.values())[:5]]

        ids, dots = _preview(pu_unlabel_vec_dict)
        print(f"unlabel_prototype_vec_ids = {ids}, total dot = {dots}")
        ids, dots = _preview(pu_positive_vec_dict)
        print(f"positive_prototype_vec_ids = {ids}, total dot = {dots}")
        ids, dots = _preview(pu_normal_prototype_dict)
        print(f"noraml_prototype_vec_ids = {ids}, total dot = {dots}")
        print(f"noraml_prototype_vec_ids = {pu_normal_prototype_dict_id}")
        print("calculating score")

    y_true, y_score, y_each_class, dot_each_class, dot_pos = calc_y_score(
        model, test_dataset, pu_normal_prototype_dict_id, pu_unlabel_vec_dict, pu_positive_vec_dict,
        pu_normal_prototype_dict, prototype_vec_weight, normal_class, varbose)

    if varbose:
        print("y_true, y_score example ")
        for t, s in list(zip(y_true, y_score))[:10]:
            print(t, s)

    fpr, tpr, _ = roc_curve(y_true, y_score)
    auroc = roc_auc_score(y_true, y_score)
    if varbose:
        print("showing ROC curve")
        plt.figure(figsize=(16, 9))
        plt.plot(fpr, tpr, marker='o')
        plt.xlabel('FPR: False positive rate')
        plt.ylabel('TPR: True positive rate')
        plt.grid()
        plt.title(f"AUROC = {auroc:.4f}")
    print(f"AUROC: {auroc}")

    if varbose:
        # Binarise the [0, 1] anomaly scores at 0.5 explicitly instead of relying on
        # float -> int truncation of the sub-threshold values.
        y_pred = (np.asarray(y_score) > 0.5).astype(int)
        print("confusion_matrix, thresholds=0.5")
        print(confusion_matrix(y_true, y_pred.tolist()))
        plot_class_scores(y_each_class, normal_class)
        plot_class_prototype_dot(dot_each_class, dot_pos)
    return auroc

def get_prototype_vec(cp_path, embed_dim=64, num_prototypes=100, device="cuda"):
    """Rebuild the teacher's weight-normalised prototype layer from a checkpoint.

    The layer is ``nn.utils.weight_norm(nn.Linear(embed_dim, num_prototypes, bias=False))``
    whose ``weight_g``/``weight_v`` are loaded from ``checkpoint["t_prototype_vec"]``.

    Args:
        cp_path: path of the checkpoint file (read with ``torch.load``).
        embed_dim: input dimension of the layer (embedding size).
        num_prototypes: number of prototype rows.
        device: device the layer is moved to (default keeps the old CUDA behaviour).

    Returns:
        The layer in eval mode, with ``.weight`` already materialised from the
        loaded ``weight_g``/``weight_v``.
    """
    prototype_vec_weight = nn.utils.weight_norm(nn.Linear(embed_dim, num_prototypes, bias=False)).to(device)
    prototype_vec_weight.weight_g.data.fill_(1)
    prototype_vec_weight.weight_g.requires_grad = False
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_prototype_vec = torch.load(cp_path, map_location="cpu")["t_prototype_vec"]
    prototype_vec_weight.load_state_dict(state_prototype_vec)
    prototype_vec_weight.eval()
    # weight_norm only recomputes the `.weight` attribute in its forward pre-hook, so
    # run one dummy forward to make `.weight` reflect the freshly loaded g and v
    # (callers read `.weight.data` directly).
    with torch.no_grad():
        prototype_vec_weight(torch.eye(embed_dim, device=device))
    return prototype_vec_weight

def construct_model(dataset_name="MNIST", embed_dim=64):
    """Build the ResNet-18-style teacher wrapped with a small DINO projection head.

    Args:
        dataset_name: any name containing "MNIST" (MNIST, FashionMNIST) selects a
            1-channel input stem; everything else gets 3 channels.
        embed_dim: output size of the backbone's final ``fc`` layer, which is also
            the head's input size.

    Returns:
        ``ReturnEmbWrapper(backbone, DINOHead(...))`` — untrained; load weights with
        ``load_pretrained_backbone_head``.
    """
    in_channels = 1 if "MNIST" in dataset_name else 3

    backbone = myResNet(BasicBlock, [2, 2, 2, 2], normalize=True)
    # Replace the stem so it accepts `in_channels` channels, with a stride-1 7x7 conv.
    backbone.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=1, padding=3, bias=False)
    # Project the pooled features down to the embedding size.
    backbone.fc = nn.Linear(backbone.fc.in_features, out_features=embed_dim, bias=True)

    head = DINOHead(
        embed_dim,  # must match backbone.fc's output size
        256,
        use_bn=False,
        norm_last_layer=False,
        nlayers=2,
        hidden_dim=128,
        bottleneck_dim=64,
    )
    return ReturnEmbWrapper(backbone, head)


def load_pretrained_backbone_head(model, cp_path, checkpoint_key="Teacher"):
    """Load the "teacher" backbone and head weights from a checkpoint into ``model``.

    The model is moved to CUDA, the weights are loaded non-strictly, and the model
    is left in eval mode.

    Args:
        model: object exposing ``.backbone`` and ``.head`` modules.
        cp_path: checkpoint path; its ``"teacher"`` entry holds the state dict.
        checkpoint_key: currently unused — the ``"teacher"`` entry is always read.
            Kept for backward compatibility.

    Notes:
        Loading stays ``strict=False``, but any missing or unexpected keys are now
        printed instead of being silently ignored.
    """
    def _strip_prefix(key, prefix):
        # Remove `prefix` only at the start of the key (str.replace hit it anywhere).
        return key[len(prefix):] if key.startswith(prefix) else key

    model.cuda()
    # NOTE: torch.load unpickles the file — only load checkpoints you trust.
    state_dict = torch.load(cp_path, map_location="cpu")["teacher"]
    # Drop a DistributedDataParallel-style "module." prefix if present.
    state_dict = {_strip_prefix(k, "module."): v for k, v in state_dict.items()}
    backbone_state = {_strip_prefix(k, "backbone."): v for k, v in state_dict.items() if k.startswith("backbone.")}
    head_state = {_strip_prefix(k, "head."): v for k, v in state_dict.items() if k.startswith("head.")}

    for part_name, module, part_state in (("backbone", model.backbone, backbone_state),
                                          ("head", model.head, head_state)):
        result = module.load_state_dict(part_state, strict=False)
        if result.missing_keys or result.unexpected_keys:
            print(f"[{part_name}] missing keys: {result.missing_keys}, unexpected keys: {result.unexpected_keys}")
    model.eval()

def plot_embedding(val_loader, model, prototype_vec_weight):
    """Embed loader samples with the backbone and project them plus the prototypes to 2-D with UMAP.

    Args:
        val_loader: iterable of (images, keys) batches; each per-sample key (e.g. a
            class or PU label) is used to group the embeddings.
        model: object exposing ``.backbone``; inputs are moved to CUDA.
        prototype_vec_weight: layer whose ``.weight`` rows are appended as extra points.

    Returns:
        (emb_2d, emb_list): ``emb_2d`` holds the 2-D coordinates — first every sample
        embedding in key-sorted order (insertion order within a key), then one row per
        prototype; ``emb_list`` maps key -> list of CPU embedding tensors, sorted by key.
    """
    model.eval()
    embeddings_by_key = defaultdict(list)
    with torch.no_grad():
        for images, keys in val_loader:
            output = model.backbone(images.cuda(non_blocking=True))
            for out, key in zip(output, keys):
                embeddings_by_key[key.item()].append(out.cpu())

    emb_list = dict(sorted(embeddings_by_key.items(), key=lambda kv: kv[0]))
    # Flatten the per-key lists in one pass (sum(lists, []) was quadratic in the sample count).
    emb_numpy = torch.vstack([emb for embs in emb_list.values() for emb in embs]).numpy()
    prototype_vec_numpy = prototype_vec_weight.weight.data.detach().cpu().numpy()

    emb_prototype_numpy = np.vstack([emb_numpy, prototype_vec_numpy])
    # NOTE: no random_state is set, so the layout differs between runs.
    reducer = umap.UMAP(n_neighbors=15, n_components=2, metric="cosine", min_dist=0.7, spread=0.7)
    emb_2d = reducer.fit_transform(emb_prototype_numpy)
    return emb_2d, emb_list

def select_unique_random_indices(arr, num=1):
    """Randomly pick up to ``num`` positions for every distinct value in ``arr``.

    Args:
        arr: iterable of hashable values (e.g. class labels). One-shot iterators are
            supported — the previous unused ``set(arr)`` call exhausted them and made
            the function return ``{}``.
        num: maximum number of positions sampled per value; groups with fewer members
            return all of their positions (in random order).

    Returns:
        dict mapping each distinct value (first-occurrence order) to a list of indices
        into ``arr``, drawn with ``random.sample`` from the global RNG.
    """
    positions_by_value = defaultdict(list)
    for position, value in enumerate(arr):
        positions_by_value[value].append(position)
    return {
        value: random.sample(positions, min(num, len(positions)))
        for value, positions in positions_by_value.items()
    }

def plot_class_scores(y_each_class, normal_class=1):
    """Plot raw per-class scores and a histogram of normal vs. anomalous scores.

    Figure 1: one bar panel per class (2x5 grid, so at most 10 classes) with every
    sample's raw score and the class mean in the legend.
    Figure 2: histogram of all scores after joint min-max scaling to [0, 1], split
    into ``normal_class`` samples and samples of every other class; the scaled
    scores are also printed.

    Args:
        y_each_class: {class label: list of raw scores}.
        normal_class: label treated as normal; every other label is anomalous.
    """
    plt.figure(figsize=(12, 9))
    for panel, (label, scores) in enumerate(y_each_class.items(), start=1):
        plt.subplot(2, 5, panel)
        plt.bar(np.arange(len(scores)), scores, label=f"class id{label}\nmean:{np.mean(scores):.4f}")
        plt.ylim(-0.1, 0.5)
        plt.legend(handlelength=0, frameon=False)
    plt.suptitle("normal prototype vec cos similarity")
    plt.tight_layout()

    plt.figure(figsize=(12, 9))
    # Start from empty lists so a missing normal class (or no anomalies) no longer
    # raises NameError / ValueError below.
    anomaly_scores, normal_scores = [], []
    for label, scores in y_each_class.items():
        if label != normal_class:
            anomaly_scores.extend(scores)
        else:
            normal_scores = list(scores)

    combined = anomaly_scores + normal_scores
    if not combined:
        print("no scores to plot")
        return
    lo, hi = min(combined), max(combined)
    # Avoid division by zero when every score is identical.
    span = hi - lo if hi > lo else 1.0
    anomaly_scores = (np.array(anomaly_scores, dtype=float) - lo) / span
    normal_scores = (np.array(normal_scores, dtype=float) - lo) / span

    bins = np.arange(11) / 10  # ten equal-width bins covering [0, 1]
    plt.hist(anomaly_scores, label="normlity score from anomarly samples", bins=bins, color="red", ec="black", alpha=0.7)
    plt.hist(normal_scores, label="normlity score from normal samples", bins=bins, color="blue", ec="black", alpha=0.7)
    plt.legend()
    plt.title("normality score hist ano vs norm")
    print(f"anomaly scores: {[round(float(s), 4) for s in anomaly_scores]}")
    print(f"normal scores: {[round(float(s), 4) for s in normal_scores]}")

def plot_class_prototype_dot(dot_each_class, dot_pos, top_k=5):
    """Bar-plot the mean prototype dot products per class and for anomalous samples.

    One panel per entry of ``dot_each_class`` plus a final "positive" panel for
    ``dot_pos`` (3x5 grid, so at most 14 classes). Each legend lists the indices of
    the ``top_k`` largest dot products, largest first.

    Args:
        dot_each_class: {class label: 1-D array, one value per prototype}.
        dot_pos: 1-D array, one value per prototype, for the anomalous samples.
        top_k: number of prototype indices listed in each legend.
    """
    def _top_ids(values):
        # np.argsort sorts ascending; reverse it so the *largest* entries come first
        # (the previous [:5] slice listed the smallest ones under "max_ids").
        return np.argsort(values)[::-1][:top_k]

    plt.figure(figsize=(12, 9))
    for panel, (label, dots) in enumerate(dot_each_class.items(), start=1):
        plt.subplot(3, 5, panel)
        plt.bar(np.arange(len(dots)), dots, label=f"class id{label}\nmax_ids:{_top_ids(dots)}")
        plt.ylim(-0.1, 0.5)
        plt.legend(handlelength=0, frameon=False)
    # The positive panel follows the class panels; computing its index from the dict
    # length also works when there are no class panels.
    plt.subplot(3, 5, len(dot_each_class) + 1)
    plt.bar(np.arange(len(dot_pos)), dot_pos, label=f"positive\nmax_ids:{_top_ids(dot_pos)}")
    plt.ylim(-0.1, 0.5)
    plt.legend(handlelength=0, frameon=False)

def calc_y_score(model, val_dataset, pu_normal_prototype_dict_id, pu_unlabel_vec_dict, pu_positive_vec_dict, pu_normal_prototype_dict, prototype_vec_weight=None, normal_class=1, varbose=True, device="cuda", top_k=10):
    """Score every sample of ``val_dataset`` and collect per-class diagnostics.

    Scoring modes:
      * ``prototype_vec_weight is None``: cosine similarity between the
        L2-normalised head embedding (``model(x, return_emb=True)[0]``) and row
        ``pu_normal_prototype_dict_id`` of ``model.head.last_layer.weight_v``
        (rows L2-normalised).
      * otherwise: mean dot product of the backbone embedding with the ``top_k``
        first normal prototypes (``pu_normal_prototype_dict`` order) minus the mean
        with the ``top_k`` first positive prototypes (``pu_positive_vec_dict`` order);
        prototype rows come from ``prototype_vec_weight.weight``.

    Raw scores (higher = more normal) are min-max scaled and flipped, so the
    returned ``y_score`` is higher for more anomalous samples.

    Args:
        model: teacher exposing ``.backbone`` (and ``.head`` in head mode).
        val_dataset: Subset-like dataset yielding (sample, label) and exposing
            ``.indices`` and ``.dataset.targets`` / ``.dataset.labels``.
        pu_normal_prototype_dict_id: prototype row used in head mode.
        pu_unlabel_vec_dict: unused; kept for interface compatibility.
        pu_positive_vec_dict: {prototype id: weight}, ranked, for positives.
        pu_normal_prototype_dict: {prototype id: weight}, ranked, for normals.
        prototype_vec_weight: prototype layer, or None for head mode.
        normal_class: label treated as normal (y_true = 0); all others are 1.
        varbose: print per-sample diagnostics and draw a score bar plot.
        device: device the model runs on.
        top_k: number of ranked prototypes averaged in the weighted mode.

    Returns:
        y_true: list of 0/1 labels.
        y_score: np.ndarray of anomaly scores in [0, 1].
        y_each_class: {label: list of raw scores}, sorted by label.
        dot_each_class: {label: mean per-prototype dot-product array}, sorted by
            label (weighted mode only).
        dot_pos: mean per-prototype dot-product array over the anomalous samples
            (weighted mode only; previously this sum was divided by a hard-coded 1000).
    """
    use_head = prototype_vec_weight is None
    if use_head:
        # Row-wise normalisation so that the dot product below is a cosine similarity.
        # (Dividing the whole matrix by its Frobenius norm and calling .item() on the
        # per-prototype score vector, as before, raised for more than one prototype.)
        normal_prototype_vec = nn.functional.normalize(
            model.head.last_layer.weight_v.detach().cpu().clone(), p=2, dim=1)
    else:
        normal_prototype_vec = prototype_vec_weight.weight.data.detach().cpu().clone()

    top_normal_ids = list(pu_normal_prototype_dict.keys())[:top_k]
    top_positive_ids = list(pu_positive_vec_dict.keys())[:top_k]
    top_normal_vecs = normal_prototype_vec[top_normal_ids, :]
    top_positive_vecs = normal_prototype_vec[top_positive_ids, :]
    if varbose:
        print(f"top5正常代表ベクトルのindex: {top_normal_ids}")
        print(f"top5異常代表ベクトルのindex: {top_positive_ids}")
        normal_weights = list(pu_normal_prototype_dict.values())[:top_k]
        positive_weights = list(pu_positive_vec_dict.values())[:top_k]
        print(f"top5_normal_prototype_weights: {normal_weights}")
        # Temperature-0.5 softmax of the ranking weights; reported only, not used in the score.
        normal_softmax = nn.functional.softmax(torch.tensor(normal_weights) / 0.5, dim=0, dtype=torch.float32)
        positive_softmax = nn.functional.softmax(torch.tensor(positive_weights) / 0.5, dim=0, dtype=torch.float32)
        print(f"top5_normal_prototype_softmax_weights: {normal_softmax}")
        print(f"top5_positive_prototype_softmax_weights: {positive_softmax}")

    # Reseeds the global RNG exactly as before (nothing below draws from it, but
    # callers may rely on the resulting global state).
    random.seed(42)

    # Report how many normal / anomalous samples the test subset contains.
    try:
        targets = val_dataset.dataset.targets
    except AttributeError:
        targets = val_dataset.dataset.labels
    if not torch.is_tensor(targets):
        targets = torch.tensor(targets)
    subset_targets = targets[val_dataset.indices]
    print(int((subset_targets == normal_class).sum()), int((subset_targets != normal_class).sum()))

    y_true = []
    y_score = []
    y_each_class = defaultdict(list)
    num_prototypes = normal_prototype_vec.shape[0]
    dot_each_class = defaultdict(lambda: np.zeros(num_prototypes))
    class_sample_num = defaultdict(int)  # any label value, not just 0..9
    dot_pos = np.zeros(num_prototypes)
    num_anomalous = 0
    running_counts = Counter()  # updated incrementally instead of rebuilt per sample

    with torch.no_grad():
        for sample, label in val_dataset:
            is_anomaly = bool(label != normal_class)
            y_true.append(1 if is_anomaly else 0)
            running_counts[y_true[-1]] += 1
            if varbose:
                print(label)
                print(running_counts)
            batch = sample.unsqueeze(0).to(device, non_blocking=True)

            if use_head:
                # NOTE(review): assumes model(..., return_emb=True)[0] is a 1-D head
                # embedding for this single sample — confirm against ReturnEmbWrapper.
                head_emb = model(batch, return_emb=True)[0].cpu()
                head_emb = nn.functional.normalize(head_emb.unsqueeze(0), p=2)[0]
                score = (normal_prototype_vec[pu_normal_prototype_dict_id] @ head_emb).item()
                y_score.append(score)
                y_each_class[label].append(score)
                continue

            backbone_emb = model.backbone(batch)[0].cpu()
            # Similarity to the top normal prototypes minus similarity to the top
            # positive (anomalous) prototypes: larger = more normal.
            score = (top_normal_vecs @ backbone_emb).mean().item() - \
                (top_positive_vecs @ backbone_emb).mean().item()
            y_score.append(score)
            if varbose:
                print((1 - score) / 2)
            y_each_class[label].append(score)
            prototype_dots = (normal_prototype_vec @ backbone_emb).numpy()
            dot_each_class[label] += prototype_dots
            class_sample_num[label] += 1
            if is_anomaly:
                dot_pos += prototype_dots
                num_anomalous += 1

    y_score = np.asarray(y_score, dtype=float)
    if y_score.size:
        lo, hi = y_score.min(), y_score.max()
        # Guard against a constant score vector, which previously produced NaNs.
        y_score = (y_score - lo) / ((hi - lo) if hi > lo else 1.0)
    y_score = 1 - y_score  # flip: higher = more anomalous

    y_each_class = dict(sorted(y_each_class.items(), key=lambda kv: kv[0]))
    dot_each_class = {k: v / class_sample_num[k]
                      for k, v in sorted(dot_each_class.items(), key=lambda kv: kv[0])}
    if varbose:
        plt.figure(figsize=(16, 9))
        plt.bar(list(range(len(y_score))), y_score)
        print(Counter(y_true))
    return y_true, y_score, y_each_class, dot_each_class, dot_pos / max(num_anomalous, 1)

def get_prototype_vec_ids(model, pu_loader, pu_valloader, prototype_vec_weight=None, top_k=5, device="cuda"):
    """Rank prototypes by how strongly unlabeled vs. positive PU samples activate them.

    For every sample of both loaders, the ``top_k`` largest prototype activations
    are accumulated per PU label (0 = unlabeled, 1 = positive) and averaged over
    the number of samples with that label. A single pass over the data is made
    (previously each loader was iterated three times).

    Args:
        model: teacher; ``model(images)[0][0]`` gives prototype logits when
            ``prototype_vec_weight`` is None, otherwise ``model.backbone`` is used.
        pu_loader, pu_valloader: iterables of (images, pu_labels) batches.
        prototype_vec_weight: prototype layer applied to backbone embeddings, or None.
        top_k: number of top activations accumulated per sample.
        device: device the inputs are moved to.

    Returns:
        (pu_unlabel_vec_dict, pu_positive_vec_dict, pu_normal_prototype_dict,
        pu_normal_prototype_dict_id). The first three map prototype index (int) ->
        mean activation, sorted descending. The normal dict is the unlabeled dict
        minus the positive dict on shared indices. The id is the normal dict's top key.

    Raises:
        NotImplementedError: if a PU label other than 0 or 1 is encountered.
    """
    activation_sums = {0: defaultdict(float), 1: defaultdict(float)}
    label_counts = Counter()
    loader_sizes = []

    for loader in (pu_loader, pu_valloader):
        seen = 0
        for images, unl_pos in loader:
            images = images.to(device, non_blocking=True)
            with torch.no_grad():
                if prototype_vec_weight is None:
                    output = model(images)[0][0]
                else:
                    output = prototype_vec_weight(model.backbone(images))
            sorted_out, sorted_arg = torch.sort(output.detach().cpu(), dim=1, descending=True)
            sorted_out = sorted_out.numpy()[:, :top_k]
            sorted_arg = sorted_arg.numpy()[:, :top_k]
            for label, outs, args in zip(unl_pos.tolist(), sorted_out, sorted_arg):
                if label not in activation_sums:
                    raise NotImplementedError(f"unexpected PU label: {label}")
                label_counts[label] += 1
                for value, proto_id in zip(outs, args):
                    activation_sums[label][int(proto_id)] += float(value)
            seen += len(unl_pos)
        loader_sizes.append(seen)

    num_unlabel, num_positive = label_counts[0], label_counts[1]
    print("num unl pos", num_unlabel, num_positive, *loader_sizes)

    def _mean_sorted_desc(sums, count):
        # Sorting by the sum equals sorting by the mean (count > 0 whenever sums is non-empty).
        ranked = sorted(sums.items(), key=lambda kv: kv[1], reverse=True)
        return {proto_id: total / count for proto_id, total in ranked}

    pu_unlabel_vec_dict = _mean_sorted_desc(activation_sums[0], num_unlabel)
    pu_positive_vec_dict = _mean_sorted_desc(activation_sums[1], num_positive)

    # Normal prototypes: strongly activated by unlabeled data but not by positives.
    pu_normal_prototype_dict = dict(pu_unlabel_vec_dict)
    for proto_id, value in pu_positive_vec_dict.items():
        if proto_id in pu_normal_prototype_dict:
            pu_normal_prototype_dict[proto_id] -= value
    pu_normal_prototype_dict = dict(sorted(pu_normal_prototype_dict.items(), key=lambda kv: kv[1], reverse=True))
    pu_normal_prototype_dict_id = list(pu_normal_prototype_dict.keys())[0]
    return pu_unlabel_vec_dict, pu_positive_vec_dict, pu_normal_prototype_dict, pu_normal_prototype_dict_id

In [None]:
# Experiment configuration — edit these to evaluate a different checkpoint.
OUTPUT_ROOT = "/workspace/angular_dino/dino/outputs"
task_name = "test/fashionmnist/trials/tea-stu-arcface-m0.5-step50/trial1/fmnist_class1_resnet18_ptnum100_ep200_bs128_lambda1.0-warmup100_tea-stu-arcface-m0.5-step50-num10_useweight_smallhead_emb64_lr0.0005-weighted-sampler_test"
dataset_name = "FashionMNIST"  # one of: "MNIST", "FashionMNIST", "CIFAR10", "SVHN"
NORMAL_CLASS = 1
SEED = 0

cp_path = OUTPUT_ROOT + "/" + task_name + "/checkpoints/checkpoint.pth"

teacher = construct_model(dataset_name)
load_pretrained_backbone_head(teacher, cp_path, checkpoint_key="Teacher")
prototype_vec_weight = get_prototype_vec(cp_path)
auroc = get_result(teacher, cp_path, dataset_name, prototype_vec_weight, normal_class=NORMAL_CLASS, seed=SEED, varbose=True)


def run_seed_sweep(task_name_template, sweep_dataset_name, seeds, normal_classes=range(1, 10), stop_above=None):
    """Mean AUROC over ``normal_classes`` for each seed (formerly a commented-out loop).

    Class 0 is excluded by default, presumably because ``get_datasets`` always
    passes it as the unseen anomaly class.

    Args:
        task_name_template: task path relative to OUTPUT_ROOT containing a
            ``{normal_class}`` placeholder.
        sweep_dataset_name: dataset name passed to ``construct_model``/``get_result``.
        seeds: iterable of data seeds.
        normal_classes: classes evaluated as the normal class for every seed.
        stop_above: stop early once the best mean AUROC exceeds this value.

    Returns:
        (best_mean_auroc, {str(seed): mean_auroc}).
    """
    best_score = 0.0
    score_dict = {}
    for seed in seeds:
        scores = []
        for normal_class in normal_classes:
            sweep_cp_path = (OUTPUT_ROOT + "/" + task_name_template.format(normal_class=normal_class)
                             + "/checkpoints/checkpoint.pth")
            sweep_model = construct_model(sweep_dataset_name)
            load_pretrained_backbone_head(sweep_model, sweep_cp_path, checkpoint_key="Teacher")
            sweep_prototypes = get_prototype_vec(sweep_cp_path)
            scores.append(get_result(sweep_model, sweep_cp_path, sweep_dataset_name, sweep_prototypes,
                                     normal_class=normal_class, seed=seed, varbose=False))
        mean_score = float(np.mean(scores))
        print(f"seed:{seed} auroc:{mean_score:.4f}")
        score_dict[str(seed)] = mean_score
        best_score = max(best_score, mean_score)
        if stop_above is not None and best_score > stop_above:
            break
    return best_score, score_dict

# Example (MNIST sweep):
# best, per_seed = run_seed_sweep("test/mnist/trials/tea-stu-arcface-m0.5-step50/trial5/mnist_class{normal_class}_resnet18_ptnum100_ep200_bs128_lambda1.0-warmup100_tea-stu-arcface-m0.5-step50-num10_useweight_smallhead_emb64_lr0.0005-weighted-sampler_test", "MNIST", seeds=range(0, 31), stop_above=0.9966)