In [None]:
import os, sys
from libs import *

from data import ImageDataset
from models.models import PretextsCA
from sklearn.manifold import TSNE
from matplotlib import pyplot as plt

In [None]:
# CUDA device selection: order devices by PCI bus id and make only device "0"
# visible to this process.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})

In [None]:
# Load the saved checkpoint object and map its tensors onto the CUDA device.
# NOTE(review): this appears to unpickle a whole model object (it is called
# directly later on); only load checkpoints from a trusted source.
checkpoint_path = "../ckps/HAM/PretextsCA/best.ptl"
model = torch.load(checkpoint_path, map_location = "cuda")

In [None]:
# Build the dataset over the HAM training split.
# NOTE(review): augment = True is kept as-is; if it enables random
# augmentation, the extracted features will vary between runs — confirm this
# is intended for feature extraction.
train_data_dir = "../datasets/HAM/train/"
train_dataset = ImageDataset(data_dir = train_data_dir, augment = True)

In [None]:
# Run every training sample through the model once and collect, per sample,
# the two tensors unpacked from the first element of the model's output,
# together with the sample's label.
features, attn_features = [], []
labels = []

# Inference only: put the model in evaluation mode (so layers such as dropout
# or batch-norm do not run in training mode) and disable autograd so no
# computation graph is built and kept for each forward pass.
model.eval()
with torch.no_grad():
    for i in tqdm.tqdm_notebook(range(len(train_dataset))):
        image, label = train_dataset[i]

        # Move the image to the GPU and add a leading batch axis of size 1.
        feature, attn_feature = model(image.cuda().unsqueeze(0))[0]

        # Drop the batch axis again and store plain NumPy arrays on the host.
        features.append(feature.squeeze(0).detach().cpu().numpy())
        attn_features.append(attn_feature.squeeze(0).detach().cpu().numpy())
        labels.append(label)

# Stack the per-sample results into arrays for the downstream cells.
features, attn_features = np.array(features), np.array(attn_features)
labels = np.array(labels)

In [None]:
# Embed both feature sets into 2 dimensions with t-SNE. The same estimator
# object is refit for each set, using a fixed random_state.
# NOTE(review): newer scikit-learn releases rename `n_iter` to `max_iter` —
# confirm against the installed version.
embedder = TSNE(n_components = 2, n_iter = 1000, random_state = 23)
embedded_features = embedder.fit_transform(features)
embedded_attn_features = embedder.fit_transform(attn_features)

In [None]:
# Scatter plot of the embedded attention features: first embedding column on
# x, second on y, one point per sample, coloured by its label.
attn_x = embedded_attn_features[:, 0]
attn_y = embedded_attn_features[:, 1]
plt.scatter(attn_x, attn_y, c = labels, s = 1)

# Same fixed window on both axes.
axis_limits = (-110, 110)
plt.xlim(axis_limits)
plt.ylim(axis_limits)
plt.show()

In [None]:
# Write the attention features of each class to its own .npy file under the
# checkpoint directory.
attn_features_dir = "../ckps/HAM/PretextsCA/attn_features"
# exist_ok makes this safe to re-run and avoids a separate existence check.
os.makedirs(attn_features_dir, exist_ok = True)
for c in range(7):  # label ids 0..6 are assumed — TODO confirm against the dataset
    # Rows whose label equals the current class id.
    attn_features_c = attn_features[np.where(labels == c)[0]]
    # Save only this class's rows, not the full feature array.
    np.save("{}/attn_features_{}.npy".format(attn_features_dir, c), attn_features_c)