# Load models

In [10]:
import os
import sys
from tqdm import tqdm
from time import time
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from sklearn import datasets
from sklearn.manifold import TSNE

# Make the project root (the parent of the notebook's working directory)
# importable so the project-local packages imported after this resolve.
cwd = os.getcwd()
# os.path.dirname handles the platform's path separator and a top-level cwd;
# the manual split('/')/join produced "" in those cases.
module_path = os.path.dirname(cwd)
if module_path not in sys.path:
    sys.path.append(module_path)

# Network architectures
from net.resnet import resnet50

from data_utils.ood_detection import cifar10
import metrics.uncertainty_confidence as uncertainty_confidence
from utils.gmm_utils import get_embeddings, gmm_fit, gmm_evaluate

In [15]:
# Use the first CUDA device when PyTorch reports one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [16]:
# Load the ResNet-50 model (the constructor and the checkpoint name are both
# resnet50) and move it to `device`.
model = resnet50(spectral_normalization=True,
                 mod=True,
                 mnist=False).to(device)
# map_location remaps the saved tensors onto `device`, so a checkpoint written
# from a CUDA run still loads when `device` has fallen back to "cpu".
model.load_state_dict(
    torch.load(
        "../saved_models/runs2/2024_03_12_14_45_25/resnet50_sn_3.0_mod_seed_1_best.model",
        map_location=device,
    )
)
# Filled in by forward hooks: maps a hook name to the tensor it captured.
activation = {}

def get_activation(name):
    """Build a forward hook that records a module's input under ``name``.

    The returned hook stores the first positional input of the hooked module
    in the module-level ``activation`` dict, keyed by ``name``.
    """
    def hook(module, inputs, output):
        # `inputs` is the tuple of positional inputs given to the module;
        # keep the first one.
        activation[name] = inputs[0]
    return hook


# Register at module level. Placed after `return hook` inside get_activation,
# this statement was unreachable and the hook was never attached.
# The handle is kept so the hook can later be detached with .remove().
embedding_hook_handle = model.fc.register_forward_hook(get_activation('embedding'))

<All keys matched successfully>

In [19]:
num_classes = 10
batch_size = 128

# CIFAR-10 data loader
# NOTE(review): the loader is built with the literal 64 as its first argument,
# not `batch_size` — confirm which value is intended.
train_loader, val_loader = cifar10.get_train_valid_loader(64, True, 1, 0.1, root="../data/")

# Inference only: put layers such as batch norm / dropout into evaluation mode
# and skip autograd bookkeeping while collecting outputs.
model.eval()
X = []
y = []
with torch.no_grad():
    for images, labels in train_loader:
        images = images.to(device)
        # NOTE(review): this is the model's forward output; the tensor captured
        # in `activation` by the forward hook is not read here — confirm which
        # representation should be visualized.
        embeddings = model(images)
        X.append(embeddings.cpu().numpy())
        y.append(labels.cpu().numpy())

# Merge the per-batch arrays along the first axis (assumed to be the batch
# axis — TODO confirm) so downstream code gets one feature array and one
# label array with an entry per sample.
X = np.concatenate(X, axis=0)
y = np.concatenate(y, axis=0)


In [None]:
# fit_transform expects a single 2-D (n_samples, n_features) array. If X is
# still a list of per-batch arrays, stack it first; an ndarray is used as is.
tsne_features = X if isinstance(X, np.ndarray) else np.concatenate(X, axis=0)
# Three-component t-SNE with PCA initialization and a fixed seed.
tsne = TSNE(n_components=3, init='pca', random_state=0)
X_tsne = tsne.fit_transform(tsne_features)
def plot_embedding_3d(X, y, title=None):
    """Draw every embedded point as its label text inside a 3D axes.

    Args:
        X: array whose first three columns are the point coordinates; it is
            min-max rescaled per column before plotting.
        y: per-point labels, indexed like the rows of ``X``. Each label is
            rendered with ``str`` and colored via ``plt.cm.Set1(label / 10.)``.
        title: optional axes title; nothing is set when ``None``.
    """
    # Rescale each coordinate axis to the [0, 1] interval.
    x_min, x_max = np.min(X, axis=0), np.max(X, axis=0)
    span = x_max - x_min
    # An axis with zero extent (e.g. a single point) would divide by zero and
    # yield NaN coordinates; divide by 1 there so those values become 0.
    span = np.where(span == 0, 1.0, span)
    X = (X - x_min) / span
    # The reduced coordinates are (X[i, 0], X[i, 1], X[i, 2]); write the
    # corresponding label at that position.
    fig = plt.figure()
    ax = fig.add_subplot(1, 1, 1, projection='3d')
    for i in range(X.shape[0]):
        ax.text(X[i, 0], X[i, 1], X[i, 2], str(y[i]),
                color=plt.cm.Set1(y[i] / 10.),
                fontdict={'weight': 'bold', 'size': 9})
    if title is not None:
        ax.set_title(title)

# The function signature is (X, y, title): pass the labels as the second
# argument instead of the title string. If y is still a list of per-batch
# label arrays, flatten it so y[i] is the label of point i.
labels_3d = y if isinstance(y, np.ndarray) else np.concatenate(y, axis=0)
plot_embedding_3d(X_tsne, labels_3d, "t-SNE 3D ")

In [None]:
def plot_embedding_2d(data, y, title):
    """Draw every embedded point as its label text in a 2D axes.

    Args:
        data: array whose first two columns are the point coordinates; it is
            min-max rescaled per column before plotting.
        y: per-point labels, indexed like the rows of ``data``. Each label is
            rendered with ``str`` and colored via ``plt.cm.Set1(label / 10.)``.
        title: axes title.

    Returns:
        The matplotlib figure that was created.
    """
    # Rescale each coordinate axis to the [0, 1] interval.
    x_min, x_max = np.min(data, 0), np.max(data, 0)
    span = x_max - x_min
    # An axis with zero extent (e.g. a single point) would divide by zero and
    # yield NaN coordinates; divide by 1 there so those values become 0.
    span = np.where(span == 0, 1.0, span)
    data = (data - x_min) / span

    fig = plt.figure()
    ax = plt.subplot(111)
    for i in range(data.shape[0]):
        ax.text(data[i, 0], data[i, 1], str(y[i]),
                color=plt.cm.Set1(y[i] / 10.),
                fontdict={'weight': 'bold', 'size': 9})
    # Tick marks carry no meaning after rescaling, so hide them.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    return fig

# If y is still a list of per-batch label arrays, flatten it so y[i] is the
# label of point i (indexing the list would give a whole batch).
# NOTE(review): X_tsne comes from a TSNE(n_components=3) fit, so this draws
# only its first two columns — confirm whether a separate 2-component fit
# is intended for the "2D" figure.
labels_2d = y if isinstance(y, np.ndarray) else np.concatenate(y, axis=0)
plot_embedding_2d(X_tsne, labels_2d, "t-SNE 2D")