In [None]:
# Reload imported local modules (e.g. utils.viz) automatically before each cell
# executes, so edits to them take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import sys
from os.path import join

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
from fast_pytorch_kmeans import KMeans
from transformers import AutoTokenizer, AutoModel

# Make the parent directory importable so the local `utils` package resolves.
sys.path.append("../")
from utils.viz import bokeh_2d_scatter, bokeh_2d_scatter_new

In [None]:
# Root of the NTU dataset. Set the NTU_DATA_DIR environment variable to point
# elsewhere instead of editing this machine-specific default path.
DATA_DIR = os.environ.get("NTU_DATA_DIR", "/Users/piyush/datasets/NTU/")
labels_file = join(DATA_DIR, "annotations/action-clf/class_labels.txt")

In [None]:
def get_class_labels(fpath):
    """Parse a class-label file into a ``{class_id: phrase}`` dict.

    Each non-blank line is expected to hold a class id, a period, and the
    action phrase, optionally followed by a trailing period
    (presumably of the form ``"A1. drink water."`` — confirm against the file).

    Args:
        fpath: Path to the UTF-8 encoded label file.

    Returns:
        dict mapping the stripped class id to the stripped phrase, in file order.

    Raises:
        ValueError: if a non-blank line contains no ``.`` separator.
    """
    class_label_dict = dict()
    # Text mode gives universal-newline handling (\r\n files work);
    # utf-8-sig strips a leading BOM that would otherwise end up in the first id.
    with open(fpath, "r", encoding="utf-8-sig") as f:
        for lineno, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue
            # Split on the FIRST period only, so phrases that themselves contain
            # a "." do not break the parse.
            class_id, sep, class_phrase = line.partition(".")
            if not sep:
                raise ValueError(
                    f"{fpath}:{lineno}: expected '<id>. <phrase>', got {line!r}"
                )
            class_phrase = class_phrase.strip()
            # Drop the single conventional trailing period, if present.
            if class_phrase.endswith("."):
                class_phrase = class_phrase[:-1].rstrip()
            class_label_dict[class_id.strip()] = class_phrase

    return class_label_dict

In [None]:
# Map of class id -> action phrase parsed from the annotation file.
class_label_dict = get_class_labels(fpath=labels_file)

In [None]:
# Hidden-state indices (negative = counted from the last layer) to combine
# into each token vector.
layers = [-4, -3, -2, -1]

# Single source of truth for the checkpoint so tokenizer and model always match.
MODEL_NAME = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# output_hidden_states=True makes the forward pass return every layer's
# activations; .eval() makes inference mode (dropout off) explicit.
model = AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True).eval()

In [None]:
def get_phrase_embedding(phrase, layers=(-4, -3, -2, -1), agg_method="mean"):
    """Embed a text phrase using the notebook-level ``tokenizer`` and ``model``.

    The requested hidden-state layers are summed per token. The first and last
    token positions are dropped (the special tokens the tokenizer adds, e.g.
    [CLS]/[SEP] for BERT). The remaining token vectors are then pooled into
    a single vector.

    Args:
        phrase: Text to embed.
        layers: Indices into ``model(...).hidden_states`` whose activations are
            summed. Negative indices count from the last layer.
        agg_method: Name of a ``torch`` reduction applied over tokens with
            ``dim=0`` (e.g. ``"mean"``, ``"sum"``, ``"max"``, ``"min"``,
            ``"median"``). Reductions that return ``(values, indices)`` are
            unwrapped to their values.

    Returns:
        1-D tensor: the pooled phrase embedding (length = hidden size, assuming
        hidden states are shaped (batch, seq, hidden) as HF models return).

    Raises:
        ValueError: if ``layers`` is empty, or the phrase yields no tokens
            between the special tokens (e.g. an empty string). The latter
            would otherwise give a NaN mean.
    """
    layers = list(layers)
    if not layers:
        raise ValueError("layers must contain at least one hidden-state index")

    # __call__ is the recommended tokenizer entry point (encode_plus is deprecated).
    encoded = tokenizer(phrase, return_tensors="pt")

    with torch.no_grad():
        model_output = model(**encoded)

    hidden_states = model_output.hidden_states

    # Sum the requested layers. [0] removes the batch dim of this single phrase;
    # unlike squeeze(), it cannot accidentally drop other size-1 dims.
    token_embeddings = torch.stack([hidden_states[i] for i in layers]).sum(dim=0)[0]

    word_embeddings = token_embeddings[1:-1]
    if word_embeddings.shape[0] == 0:
        raise ValueError(f"phrase {phrase!r} produced no tokens to aggregate")

    pooled = getattr(torch, agg_method)(word_embeddings, dim=0)
    # torch.max/min/median with dim= return a (values, indices) named tuple.
    if not isinstance(pooled, torch.Tensor):
        pooled = pooled[0]
    return pooled

In [None]:
# Sanity-check the embedding pipeline on a single action phrase.
phrase_embedding = get_phrase_embedding(phrase='shoot at other person with a gun')

In [None]:
# Expect a single vector (one entry per hidden unit).
phrase_embedding.size()

In [None]:
# Embed every class phrase using the layer selection from the model-setup cell.
# Two parallel views are kept: a dict keyed by class id, and a stacked tensor
# whose i-th row belongs to class_labels[i].
embeddings = dict()
class_labels = []

for class_id, class_phrase in tqdm(class_label_dict.items()):
    embeddings[class_id] = get_phrase_embedding(class_phrase, layers=layers)
    class_labels.append(class_id)

assert class_labels, "no class labels were loaded"
# Row order follows class_labels, which downstream cells rely on.
embeddings_tensor = torch.stack([embeddings[k] for k in class_labels], dim=0)

In [None]:
# One pooled embedding row per action class.
embeddings_tensor.size()

In [None]:
# Cluster the class embeddings. Seeds are fixed so assignments are reproducible
# across runs. Both torch and numpy are seeded because k-means initialisation
# is random. NOTE(review): confirm which RNG fast_pytorch_kmeans draws from.
SEED = 0
N_CLUSTERS = 12
np.random.seed(SEED)
torch.manual_seed(SEED)

assert N_CLUSTERS <= embeddings_tensor.shape[0], "more clusters than samples"
kmeans = KMeans(n_clusters=N_CLUSTERS, mode='euclidean', verbose=1)

labels = kmeans.fit_predict(embeddings_tensor)

In [None]:
# One cluster id per class embedding.
labels.size()

In [None]:
# k-means cluster assignment for each row of embeddings_tensor.
labels

In [None]:
# Project the class embeddings onto their top-K principal components for plotting.
# pca_lowrank is randomized and returns q = min(6, m, n) components by default.
(U, S, V) = torch.pca_lowrank(embeddings_tensor, center=True)

K = 2
assert K <= V.shape[1], f"requested {K} components but only {V.shape[1]} available"
# V was fit on mean-centred data (center=True). Centre before projecting too,
# so Z holds proper PC scores rather than scores offset by mean @ V.
Z = torch.matmul(embeddings_tensor - embeddings_tensor.mean(dim=0, keepdim=True), V[:, :K])

In [None]:
# K projected coordinates per action class.
Z.size()

In [None]:
# One row per action class: 2-D projected coordinates, k-means cluster id,
# class id, and the class's text description.
df = pd.DataFrame({
    "x": Z[:, 0].numpy(),
    "y": Z[:, 1].numpy(),
    "cluster_label": labels.numpy(),
    "class_id": np.array(class_labels),
})
df["class_desc"] = df["class_id"].map(class_label_dict.__getitem__)

In [None]:
# Interactive scatter of the projected classes, coloured by k-means cluster.
# `label` presumably drives per-point annotation/hover text (see utils.viz).
bokeh_2d_scatter_new(
    df=df,
    x="x",
    y="y",
    hue="cluster_label",
    label="class_desc",
    title="BERT-based embeddings for NTU-120 action classes.",
    use_nb=True,
)

In [None]:
# Static version of the projection: one scatter series per k-means cluster.
fig, ax = plt.subplots(1, 1, figsize=(8, 8))
ax.grid()

Z_np = Z.numpy()
labels_np = labels.numpy()
unique_labels = np.unique(labels_np)

# tab20 has 20 distinct colours. The default property cycle repeats after 10,
# which would give some of the 12 clusters identical colours.
cmap = plt.get_cmap("tab20")
for i, cluster_id in enumerate(unique_labels):
    members = labels_np == cluster_id
    ax.scatter(
        Z_np[members, 0],
        Z_np[members, 1],
        color=cmap(i % cmap.N),
        # Readable legend entry instead of an array of class ids.
        label=f"cluster {cluster_id} (n={int(members.sum())})",
    )

ax.set(
    title="BERT-based embeddings for NTU-120 action classes (PCA)",
    xlabel="PC 1",
    ylabel="PC 2",
)
ax.legend(loc="best", fontsize="small")
plt.show()