# Import libraries

In [None]:
import os, yaml

from datetime import datetime
from easydict import EasyDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets import CIFAR10
import torchvision.models as models

from pytorch_lightning import seed_everything, Trainer

from model.byol import ModelBase
from model.litmodel import LitModelLinear
from utils.setup_utils import get_device

# Configs

In [None]:
# Load the linear-evaluation settings into an attribute-style dict (args.lr, args.GPU_NUM, ...).
with open('configs/linear_config.yaml') as file:
    config = yaml.load(file, Loader=yaml.FullLoader)
    args = EasyDict(config)

# Date stamp (YYYYMMDD) for this run.
args.current_time = datetime.now().strftime('%Y%m%d')

### Set Device ###
if torch.cuda.is_available():
    # os.environ only accepts str values; GPU_NUM may be parsed from YAML as an int.
    # NOTE(review): torch.cuda.is_available() usually initialises CUDA already, in which case
    # this assignment no longer changes which GPUs this process sees (later cells index GPUs
    # with int(args.GPU_NUM) directly) - confirm the intent.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.GPU_NUM)

args['device'] = get_device(args.GPU_NUM)
# deterministic=True asks cuDNN for deterministic algorithms; benchmark=True lets it auto-tune.
# (PyTorch's reproducibility notes recommend benchmark=False for bit-exact reruns.)
# The former `cudnn.fastest = True` was dropped: it is not a torch.backends.cudnn flag and had no effect.
cudnn.benchmark = True
cudnn.deterministic = True

# PyYAML's YAML 1.1 resolver reads values such as `1e-3` (no dot) as strings, so cast explicitly.
args.lr = float(args.lr)
args.weight_decay = float(args.weight_decay)

### Set SEED ###
seed_everything(args.SEED)

# Load data

In [None]:
# Per-channel normalisation statistics: the values commonly used for CIFAR-10 when the
# backbone is the CIFAR variant, ImageNet statistics otherwise.
if args.cifar:
    channel_mean, channel_std = [0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]
else:
    channel_mean, channel_std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=channel_mean, std=channel_std)

# Evaluation-time preprocessing: no augmentation, only tensor conversion and normalisation.
test_transform = transforms.Compose([transforms.ToTensor(), normalize])

In [None]:
# CIFAR-10 test split with the evaluation transform; shuffle=False keeps a fixed order.
test_data = CIFAR10(root=args.DATA_PATH, train=False, transform=test_transform, download=True)
# Worker count comes from the config, consistent with the train-set loader used for features.
test_dataloader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False, pin_memory=True,
                             num_workers=args.num_workers)

# Size the classifier head from the dataset's class list.
args.num_classes = len(test_data.classes)

# Load pre-trained model

In [None]:
# Build the backbone: for CIFAR runs the torchvision constructor itself is handed to the
# project's ModelBase; otherwise the torchvision model is instantiated directly.
if args.cifar:
    model = models.__dict__[args.arch]
    model = ModelBase(model)
else:
    model = models.__dict__[args.arch]()
# Replace the classification head so its output size matches this dataset's classes.
model.fc = nn.Linear(model.fc.in_features, args.num_classes, bias=True)

# load pre-trained model
# NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources.
checkpoint = torch.load(args.MODEL_PATH, map_location='cpu')

# Keep only the wrapped network's weights (keys prefixed 'model.') and strip the prefix so the
# names match the bare module built above; all other checkpoint entries are dropped.
# Building a new dict avoids the order-dependent in-place rename of the old loop.
prefix = 'model.'
state_dict = {
    key[len(prefix):]: value
    for key, value in checkpoint['state_dict'].items()
    if key.startswith(prefix)
}
if not state_dict:
    # Without this check, nothing would be loaded and randomly initialised weights evaluated.
    raise ValueError(f"No '{prefix}*' weights found in checkpoint {args.MODEL_PATH}")

# strict=False silently tolerates name mismatches, so report any parameters left unloaded.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f'WARNING: not found in checkpoint: {load_result.missing_keys}')
load_result

In [None]:
# Wrap the network in the project's LightningModule so Trainer.test can drive evaluation.
# From here on `model` is the LitModelLinear wrapper; the bare network is reachable as `model.model`.
model = LitModelLinear(model, args)

# Evaluation

In [None]:
# Lightning trainer pinned to the single GPU index given by args.GPU_NUM.
# NOTE(review): `gpus=` is deprecated in newer PyTorch Lightning releases (replaced by
# accelerator='gpu', devices=[...]); confirm against the installed version.
trainer = Trainer(gpus=[int(args.GPU_NUM)])

In [None]:
# Run the LightningModule's test loop over the CIFAR-10 test split; the reported metrics are
# whatever LitModelLinear logs in its test step.
trainer.test(model, dataloaders=test_dataloader)

# Feature visualization

In [None]:
import numpy as np
from sklearn.manifold import TSNE
from umap import UMAP
import matplotlib.pyplot as plt
import matplotlib.cm as cm
%matplotlib inline

In [None]:
# CIFAR-10 train split, loaded with the evaluation transform (no augmentation) for feature extraction.
memory_data = CIFAR10(
    root=args.DATA_PATH,
    train=True,
    download=True,
    transform=test_transform,
)
memory_dataloader = DataLoader(
    memory_data,
    batch_size=args.batch_size,
    shuffle=False,  # dataset order is kept so features line up with dataset.targets
    num_workers=args.num_workers,
    pin_memory=True,
)

In [None]:
# Replace the classifier head with an identity so the network returns backbone features.
model.model.fc = nn.Identity()
model = model.model  # unwrap the LightningModule to the bare nn.Module
model.cuda(int(args.GPU_NUM))
# Put layers like BatchNorm/Dropout in inference mode. Lightning's test loop switches the module
# back to train mode when it finishes. In train mode, BatchNorm would normalise with per-batch
# statistics and overwrite its running stats during the feature extraction below.
model.eval()

In [None]:
# Encode every training image and L2-normalise each feature vector.
# Assumes the model returns (batch, feature_dim). The result is stacked and transposed,
# so feature_bank is (feature_dim, num_samples) on the model's GPU.
device_idx = int(args.GPU_NUM)
batch_features = []
with torch.no_grad():
    for images, _labels in memory_dataloader:
        embeddings = model(images.cuda(device_idx, non_blocking=True))
        batch_features.append(F.normalize(embeddings, dim=1))
    feature_bank = torch.cat(batch_features, dim=0).t().contiguous()
    # Labels come from the dataset in its (unshuffled) order, matching the feature columns.
    feature_labels = torch.tensor(memory_dataloader.dataset.targets, device=feature_bank.device)

In [None]:
# Bring features and labels to host memory as NumPy arrays. The transpose undoes the
# feature-major layout of the previous cell, giving one row per sample for sklearn/umap.
host_features = feature_bank.detach().cpu().numpy()
feature_bank = host_features.T
feature_labels = feature_labels.detach().cpu().numpy()

### t-SNE

In [None]:
# 2-D t-SNE projection of the train-set features. Passing random_state=args.SEED makes the layout
# reproducible regardless of how much NumPy global RNG state earlier cells consumed.
# NOTE: t-SNE on the full train set can take a long time.
X_embedded_tsne = TSNE(n_components=2, learning_rate='auto', init='random',
                       random_state=args.SEED).fit_transform(feature_bank)

In [None]:
# Scatter the t-SNE embedding with one colour per class. Points are drawn faint (alpha=0.1)
# because there are many of them. An empty, opaque scatter per class supplies a readable
# legend marker labelled with the class index and name.
colors = cm.rainbow(np.linspace(0, 1, args.num_classes))

plt.figure(figsize=(10, 10))
for class_idx, (class_name, color) in enumerate(zip(memory_data.classes, colors)):
    mask = feature_labels == class_idx  # boolean row selector for this class
    plt.scatter(X_embedded_tsne[mask, 0], X_embedded_tsne[mask, 1], color=color, alpha=0.1)
    plt.scatter([], [], color=color, label=f'{class_idx}: {class_name}')
plt.title('t-SNE of train-set features')
plt.legend()
plt.show()

### UMAP

In [None]:
# 2-D UMAP projection of the train-set features with default UMAP hyperparameters.
# NOTE: no random_state is passed, so the layout varies between runs. umap-learn disables its
# parallel execution when a seed is fixed, which would be slow on the full train set.
umap_reducer = UMAP(n_components=2)
X_embedded_umap = umap_reducer.fit_transform(feature_bank)

In [None]:
# Scatter the UMAP embedding with one colour per class. Points are drawn faint (alpha=0.1)
# because there are many of them. An empty, opaque scatter per class supplies a readable
# legend marker labelled with the class index and name.
colors = cm.rainbow(np.linspace(0, 1, args.num_classes))

plt.figure(figsize=(10, 10))
for class_idx, (class_name, color) in enumerate(zip(memory_data.classes, colors)):
    mask = feature_labels == class_idx  # boolean row selector for this class
    plt.scatter(X_embedded_umap[mask, 0], X_embedded_umap[mask, 1], color=color, alpha=0.1)
    plt.scatter([], [], color=color, label=f'{class_idx}: {class_name}')
plt.title('UMAP of train-set features')
plt.legend()
plt.show()

### Singular value plot

In [None]:
from scipy.linalg import svd

In [None]:
# Feature covariance (feature_dim x feature_dim). Rows of feature_bank are samples, so the
# columns are treated as the variables (equivalent to np.cov(feature_bank.T)).
C = np.cov(feature_bank, rowvar=False)
# C is symmetric positive semi-definite, so its singular values (descending) are its eigenvalues.
# Only the spectrum is needed, so the singular vectors are not computed.
s = svd(C, compute_uv=False)
s.shape

In [None]:
# Plot the covariance spectrum on a linear scale (left) and as natural log (right).
# Exact-zero singular values (rank deficiency, e.g. collapsed dimensions) have no finite log.
# They are left as NaN, so they are simply not drawn, instead of producing -inf and a
# divide-by-zero warning.
log_s = np.full_like(s, np.nan)
np.log(s, out=log_s, where=s > 0)

fig, (ax_lin, ax_log) = plt.subplots(1, 2, figsize=(10, 5))

ax_lin.plot(s, label='original scale')
ax_lin.set_xlabel('singular value index')
ax_lin.legend()

ax_log.plot(log_s, label='log scale')
ax_log.set_xlabel('singular value index')
ax_log.legend()

plt.show()