In [None]:
import matplotlib.pyplot as plt
from src.models.clip import get_model
from tsnecuda import TSNE
import numpy as np
from src.datasets.transform import load_transform
from src.datasets.utils import get_dataloader, build_iter_dataloader
import torch
from omegaconf import OmegaConf
import torch.nn.functional as F
from scipy.spatial.distance import cdist

# Run settings: `model` names the ViT variant and the checkpoint file to load,
# `data` names the dataset and its root directory.
# NOTE(review): both paths are absolute and machine-specific.
config = OmegaConf.create(
    dict(
        model=dict(
            vit_base="ViT-B-16",
            pretrained="/home/chuyu/vllab/clip/outputs/ViT-B-16/fgvc-aircraft/latest/checkpoint_10.pth",
        ),
        data=dict(
            name="fgvc-aircraft",
            root="/mnt/data/classification",
        ),
    )
)

def plot_features(feat, labels=None):
    """Scatter the first two columns of `feat` on an 8x8-inch figure and show it.

    Args:
        feat: 2-D indexable (e.g. a NumPy array); column 0 goes on the x axis,
            column 1 on the y axis.
        labels: optional per-point values forwarded as scatter's `c` argument
            and coloured through the "tab10" colormap.
    """
    _, ax = plt.subplots(figsize=(8, 8))
    x_coords = feat[:, 0]
    y_coords = feat[:, 1]
    ax.scatter(x_coords, y_coords, c=labels, cmap="tab10", s=2)
    ax.set_title('t-SNE Visualization of Features')
    ax.set_xlabel('t-SNE Dimension 1')
    ax.set_ylabel('t-SNE Dimension 2')
    plt.show()

def prepare_dataloader(dataset, batch_size=32, shuffle=True, drop_last=False, mode="train", sample_num=-1,
                       root="/mnt/data/classification/"):
    """Build a dataloader for `dataset` through the project's `get_dataloader`.

    Args:
        dataset: dataset name, forwarded as the first positional argument.
        batch_size, shuffle, drop_last, sample_num: forwarded unchanged as
            keyword arguments to `get_dataloader`.
        mode: split name forwarded to `get_dataloader`; "train" selects the
            training transform, any other value selects the eval transform.
        root: dataset root directory. Defaults to the path that was
            previously hard-coded in the body, so existing calls are unchanged.

    Returns:
        Whatever `get_dataloader` returns.
    """
    train_trans, eval_trans = load_transform()
    transform = train_trans if mode == "train" else eval_trans
    # Named `loader_kwargs` (not `config`) so the notebook-level `config`
    # object is not shadowed inside this function.
    loader_kwargs = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "drop_last": drop_last,
        "sample_num": sample_num,
    }
    return get_dataloader(dataset, root, mode, transform, **loader_kwargs)

def get_tsne_features(feats, perplexity=15, learning_rate=10):
    """Embed `feats` into two dimensions with `TSNE` and return the result.

    Args:
        feats: feature matrix handed to `TSNE.fit_transform`.
        perplexity: forwarded to the `TSNE` constructor.
        learning_rate: forwarded to the `TSNE` constructor.
    """
    reducer = TSNE(
        n_components=2,
        perplexity=perplexity,
        learning_rate=learning_rate,
    )
    embedding = reducer.fit_transform(feats)
    return embedding


In [None]:
# Build the model from the notebook-level `config` via the project's get_model helper.
# NOTE(review): device placement and eval mode are presumably handled inside get_model — confirm.
fine_model = get_model(config)

In [None]:
# Unshuffled train-mode loader over "imagenet" with sample_num=100_000.
# NOTE(review): `config.data.name` is "fgvc-aircraft" while this loads "imagenet" — confirm the mismatch is intended.
dataloader = prepare_dataloader("imagenet", mode="train", shuffle=False, sample_num=100_000)

In [None]:

# Run every batch of `dataloader` through the model (gradients disabled) and
# collect the features and labels as NumPy arrays.
feature_batches = []
label_batches = []

# init() is called before every pass over a loader in this notebook —
# presumably it (re)starts iteration; confirm against the loader class.
dataloader.init()
with torch.no_grad():
    for x, y in dataloader:
        # Move each batch to the CPU as soon as it is produced so the whole
        # feature set is never held on the model's device at once.
        feature_batches.append(fine_model.get_features(x).detach().cpu())
        label_batches.append(y.detach().cpu())

output_features = torch.cat(feature_batches, dim=0).numpy()
labels = torch.cat(label_batches, dim=0).flatten().numpy()
print(output_features.shape)

In [None]:
# Sanity check: print the first (x, y) batch the loader yields, then stop.
dataloader.init()
for batch in dataloader:
    x, y = batch
    print(x, y)
    break

In [None]:
# One mean feature vector per distinct label, stacked in the (sorted) order
# that np.unique returns the label values.
class_ids = np.unique(labels)
per_class_means = [output_features[labels == class_id].mean(axis=0) for class_id in class_ids]
mean_features = np.stack(per_class_means, axis=0)

In [None]:
# Load the saved learnable inputs and extract their features with the model.
# NOTE(review): depending on the PyTorch version / weights_only setting,
# torch.load can unpickle arbitrary objects — only load files from a trusted source.
learnable_input = torch.load("learnable_input_with_gt_labels.pt")

learnable_input_loader = build_iter_dataloader(
    learnable_input, batch_size=32, num_workers=4, shuffle=False, drop_last=False, device="cuda"
)
learnable_input_loader.init()

learnable_features = []
with torch.no_grad():
    for x in learnable_input_loader:
        batch_features = fine_model.get_features(x)
        learnable_features.append(batch_features)
    stacked_features = torch.cat(learnable_features, dim=0)
    learnable_features = stacked_features.cpu().detach().numpy()
print(learnable_features.shape)

In [None]:
# Classification-head weight matrix as a NumPy array, divided by 100.
# NOTE(review): the 100 looks like a logit-scale factor being undone — confirm where it comes from.
head_features = fine_model.classification_head.weight.data.cpu().detach().numpy() / 100

In [None]:
# Mean of `learnable_features` over axis 0, reshaped into a single row.
learnable_mean_features = learnable_features.mean(axis=0).reshape(1, -1)

In [None]:
# For each row of `output_features`, the smallest distance to any row of
# `head_features` (cdist's default metric is Euclidean).
cdist(output_features, head_features).min(axis=1)

In [None]:
# Labels for the learnable inputs: ids 0..99, each repeated 10 times in a row.
# NOTE(review): assumes the saved inputs are ordered as 100 classes x 10 samples — confirm.
learnable_labels = np.repeat(np.arange(0, 100), 10)

# Embed dataset features and learnable-input features together so both share one t-SNE space.
all_features = np.concatenate([output_features, learnable_features], axis=0)
tsne_features = get_tsne_features(all_features)

# The first rows of the embedding belong to the dataset, the rest to the learnable inputs.
n_original = output_features.shape[0]
original_feats = tsne_features[:n_original, :]
learnable_feats = tsne_features[n_original:, :]

plt.figure(figsize=(8, 8))
# NOTE(review): each scatter call normalises its own `c` values independently,
# so the same label id may get different colours in the two layers.
plt.scatter(original_feats[:, 0], original_feats[:, 1], c=labels, cmap="tab10", s=2)
plt.scatter(learnable_feats[:, 0], learnable_feats[:, 1], c=learnable_labels, cmap="tab10", s=1)
plt.show()

In [None]:
# Softmax (over the last dim) of the model output for the first batch of
# learnable inputs only; despite the name, `train_logits` holds probabilities.
learnable_input_loader.init()
with torch.no_grad():
    for x in learnable_input_loader:
        raw_scores = fine_model(x)
        train_logits = F.softmax(raw_scores, dim=-1)
        break

In [None]:


# Disabled experiment, kept for reference: it compared the softmax outputs for
# the learnable inputs against those for random-noise images. `logits` and
# `random_logits` are NOT defined while these lines stay commented out.
# random_noise = torch.randn(32, 3, 224, 224).cuda()
# with torch.no_grad():
#     logits = F.softmax(fine_model(learnable_input), dim=-1)
#     random_logits = F.softmax(fine_model(random_noise), dim=-1)

# print(logits.max(dim=-1))
# print(random_logits.max(dim=-1))
# Per-sample maximum over the last dim of `train_logits`; Tensor.max(dim=...)
# returns both the max values and their indices.
print(train_logits.max(dim=-1))

In [None]:
# Softmax output for the first learnable input. `logits` was only ever assigned
# in commented-out code, so it raised NameError on a fresh kernel; `train_logits`
# is the live equivalent computed from the learnable-input loader.
train_logits[0]