In [None]:
import sys
sys.path.append('..')  # make modules in the parent directory importable by their top-level name

# A notebook runs as `__main__` with no parent package, so a relative import
# (`from .tsne import ...`) raises "attempted relative import with no known
# parent package". Import the module by its top-level name instead.
# NOTE(review): presumably tsne.py lives next to this notebook or in `..` — confirm.
from tsne import visualize_tsne, visualize_tsne_raw
import torch
import numpy as np

from torchvision.datasets import CIFAR10
import torchvision.transforms as T
from easydict import EasyDict

In [None]:
def load_embedding(path, label=False, map_location=None):
    """Load saved embeddings (and optionally their labels) from a torch checkpoint.

    The file at ``path`` is read with ``torch.load`` and is expected to contain
    a mapping with an ``'embedding'`` entry and, when ``label`` is true, a
    ``'labels'`` entry.

    Args:
        path: Location of the checkpoint file.
        label: When true, return ``(embedding, labels)`` instead of only the
            embedding.
        map_location: Forwarded to ``torch.load``. The default ``None`` keeps
            the original behaviour (tensors are restored to the device they
            were saved from); pass ``'cpu'`` to load GPU-saved embeddings on a
            machine without that GPU.

    Returns:
        The stored embedding object, or a ``(embedding, labels)`` tuple when
        ``label`` is true.

    Raises:
        KeyError: If a required entry is missing from the checkpoint.
    """
    data = torch.load(path, map_location=map_location)
    emb = data['embedding']
    if label:
        return emb, data['labels']
    return emb

In [None]:
def get_embedding(src_embeddings):
    """Collapse a sequence of embeddings into one unit-length mean embedding.

    The tensors are stacked along a new leading axis and moved to the CPU,
    averaged over that axis (keeping it as a size-1 dimension), divided by the
    norm taken over the last axis, and cast to ``float32``. The resulting shape
    is printed before the tensor is returned.

    Args:
        src_embeddings: Sequence of equally shaped tensors accepted by
            ``torch.stack``.

    Returns:
        The normalised mean embedding as a ``float32`` CPU tensor.
    """
    stacked = torch.stack(src_embeddings).cpu()
    centroid = stacked.mean(dim=0, keepdim=True)
    unit_centroid = centroid / centroid.norm(dim=-1, keepdim=True)
    # repeat(1, 1) keeps a single copy; it is retained so the output shape is unchanged.
    result = unit_centroid.repeat(1, 1).to(torch.float32)
    print(result.shape)
    return result.cpu()

In [None]:
def norm_embeddings(src_embeddings):
    """Return the embeddings stacked into one tensor, each scaled to unit norm.

    The previous implementation could not succeed for any input: it called
    ``.clone()`` (which a list does not have) and then ``torch.stack`` on the
    result (which rejects a bare tensor). This version accepts both forms.

    Args:
        src_embeddings: Either a sequence (list/tuple) of equally shaped
            tensors, or a single tensor whose first axis indexes the
            embeddings.

    Returns:
        A new tensor with the embeddings along the first axis, each divided by
        its norm over the last axis. The input is never modified.
    """
    if not torch.is_tensor(src_embeddings):
        src_embeddings = torch.stack(list(src_embeddings))
    # Out-of-place division: the caller's tensors stay untouched.
    return src_embeddings / src_embeddings.norm(dim=-1, keepdim=True)

In [None]:
def normalize(embeddings):
    """Scale every embedding in a list of lists to unit norm.

    Args:
        embeddings: An iterable of iterables of tensors.

    Returns:
        A list of lists with the same nesting, where each tensor is a new
        tensor equal to the input divided by its norm over the last axis.
        The input tensors are left unmodified.
    """
    return [
        [vec / vec.norm(dim=-1, keepdim=True) for vec in group]
        for group in embeddings
    ]

In [None]:
# Compute device handle: CUDA device with index 2.
device = torch.device('cuda', 2)

# Names of the corruption types, in the order used by this notebook.
CORRUPTIONS = [
    "gaussian_noise",
    "shot_noise",
    "impulse_noise",
    "defocus_blur",
    "glass_blur",
    "motion_blur",
    "zoom_blur",
    "snow",
    "frost",
    "fog",
    "brightness",
    "contrast",
    "elastic_transform",
    "pixelate",
    "jpeg_compression",
]

In [None]:
# Experiment options: model name, checkpoint path and dataset root directories.
opts = EasyDict(
    model='wideresnet28',
    pretrained='data/cifar10/Standard.pt',
    datasets={
        'CIFAR10': {'path': 'datasets'},
        'CIFAR10C': {'path': 'corruptions'},
    },
)

# Only convert images to tensors; no further preprocessing is applied here.
transform = T.Compose([T.ToTensor()])

# NOTE(review): no `train=` argument is passed, so torchvision's default split
# is used — confirm it matches the split the saved embeddings were computed on.
dataset = CIFAR10(opts.datasets.CIFAR10.path, transform=transform)

In [None]:
# Load the source-domain embeddings and pair them with the dataset labels.
src_embeddings = load_embedding('../data/embeddings/embedding_cifar10_wideresnet28_nonorm.pth')
print(f'{len(src_embeddings)=}, {src_embeddings[0].shape=}')

features = np.stack(src_embeddings)
# Read the labels directly from the dataset instead of iterating it: iterating
# decodes and transforms every image only to discard it.
labels = np.array(dataset.targets)
print(features.shape, labels.shape, np.unique(labels))

# Embeddings and labels are paired purely by position, so a length mismatch
# (e.g. embeddings from a different split) would silently mislabel points.
assert len(features) == len(labels), (
    f'embedding/label count mismatch: {len(features)} vs {len(labels)}'
)

# Seeded local generator so the sub-sample (and the resulting plot) is
# reproducible on Restart & Run All.
src_rng = np.random.default_rng(0)
idxs = src_rng.choice(len(features), size=10000, replace=False)

features = features[idxs].squeeze()
labels = labels[idxs].squeeze()
print(features.shape, labels.shape, np.unique(labels))

In [None]:
# Corruption type and severity level selecting which target embedding file is loaded.
corruption, severity = "fog", 5

In [None]:
# Load the embeddings and labels for the selected corruption / severity.
# NOTE(review): this path starts with './data' while the source embeddings are
# read from '../data' — confirm which prefix is correct for this notebook.
target_embeddings, target_labels = load_embedding(f'./data/embeddings/corruptions/embedding_cifar10c_wideresnet28_{corruption}{severity}.pth', label=True)
print(f'{len(target_embeddings)=}, {target_embeddings[0].shape=}')

ftt = np.stack(target_embeddings)
lbt = np.array(target_labels)
print(ftt.shape, lbt.shape, np.unique(lbt))

# Seeded local generator so the sub-sample is reproducible across re-runs.
tgt_rng = np.random.default_rng(1)
idxt = tgt_rng.choice(len(target_embeddings), size=5000, replace=False)

# Squeeze exactly like the source features so both arrays have the same rank
# when they are concatenated for the t-SNE plot.
ftt = ftt[idxt].squeeze()
lbt = lbt[idxt].squeeze()

In [None]:
# Joint t-SNE of source and target features. Target labels are shifted by 10
# so they map onto the second half of the label names ('class{i}-C').
tsne_features = np.concatenate((features, ftt))
tsne_labels = np.concatenate((labels, lbt + 10))
source_names = [f'class{i}' for i in range(10)]
target_names = [f'class{i}-C' for i in range(10)]

visualize_tsne_raw(
    features=tsne_features,
    labels=tsne_labels,
    label_names=source_names + target_names,
    figsize=(15, 10),
    dimension=2,
    perplexity=30,
)