In [1]:
import os
import copy
import pickle
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModel
from sentence_transformers import SentenceTransformer

In [2]:
# Directory whose entries are used as the class labels (one entry per class).
hmdb_path = '../../dataset/hmdb51_vid/'

# Keep two independent lists: `original_names` holds the names exactly as they
# appear on disk, `class_names` is a separate copy that later cells may rewrite.
original_names = sorted(os.listdir(hmdb_path))
class_names = list(original_names)

print(class_names)
print(original_names)

['brush_hair', 'cartwheel', 'catch', 'chew', 'clap', 'climb', 'climb_stairs', 'dive', 'draw_sword', 'dribble', 'drink', 'eat', 'fall_floor', 'fencing', 'flic_flac', 'golf', 'handstand', 'hit', 'hug', 'jump', 'kick', 'kick_ball', 'kiss', 'laugh', 'pick', 'pour', 'pullup', 'punch', 'push', 'pushup', 'ride_bike', 'ride_horse', 'run', 'shake_hands', 'shoot_ball', 'shoot_bow', 'shoot_gun', 'sit', 'situp', 'smile', 'smoke', 'somersault', 'stand', 'swing_baseball', 'sword', 'sword_exercise', 'talk', 'throw', 'turn', 'walk', 'wave']
['brush_hair', 'cartwheel', 'catch', 'chew', 'clap', 'climb', 'climb_stairs', 'dive', 'draw_sword', 'dribble', 'drink', 'eat', 'fall_floor', 'fencing', 'flic_flac', 'golf', 'handstand', 'hit', 'hug', 'jump', 'kick', 'kick_ball', 'kiss', 'laugh', 'pick', 'pour', 'pullup', 'punch', 'push', 'pushup', 'ride_bike', 'ride_horse', 'run', 'shake_hands', 'shoot_ball', 'shoot_bow', 'shoot_gun', 'sit', 'situp', 'smile', 'smoke', 'somersault', 'stand', 'swing_baseball', 'sword

In [3]:
# Turn the on-disk names into plain text by swapping underscores for spaces.
# Slice assignment rewrites the existing list object in place.
class_names[:] = [label.replace('_', ' ') for label in class_names]
print(class_names)

# Containers for the embeddings; they are filled by later cells.
bert_embed = {}
sbert_embed = {}

['brush hair', 'cartwheel', 'catch', 'chew', 'clap', 'climb', 'climb stairs', 'dive', 'draw sword', 'dribble', 'drink', 'eat', 'fall floor', 'fencing', 'flic flac', 'golf', 'handstand', 'hit', 'hug', 'jump', 'kick', 'kick ball', 'kiss', 'laugh', 'pick', 'pour', 'pullup', 'punch', 'push', 'pushup', 'ride bike', 'ride horse', 'run', 'shake hands', 'shoot ball', 'shoot bow', 'shoot gun', 'sit', 'situp', 'smile', 'smoke', 'somersault', 'stand', 'swing baseball', 'sword', 'sword exercise', 'talk', 'throw', 'turn', 'walk', 'wave']


In [4]:
# Load the tokenizer and the encoder from the same pretrained checkpoint so
# that the vocabulary and the weights are guaranteed to match.
bert_checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(bert_checkpoint)
bert_model = AutoModel.from_pretrained(bert_checkpoint)

In [5]:
# Embed every class name with BERT and store the vector under the original
# (underscore) name, so lookups use the same keys as the dataset directory.
# This is inference only: no_grad() avoids building an autograd graph for
# every forward pass, which saves memory and time without changing the values.
with torch.no_grad():
    for name, original_name in zip(class_names, original_names):
        curr_inp = tokenizer(name, return_tensors="pt")
        curr_out = bert_model(**curr_inp)
        # curr_out[0] is the first model output (presumably the last hidden
        # state); [0] selects the only sequence in the batch and [-1] the
        # final token position of that sequence.
        # NOTE(review): with the tokenizer's default special tokens the final
        # position is normally the [SEP] token rather than a word piece --
        # confirm this pooling is intended (vs. [CLS] at index 0 or a mean).
        bert_embed[original_name] = curr_out[0][0][-1].detach().numpy()
        print(name, original_name)

brush hair brush_hair
cartwheel cartwheel
catch catch
chew chew
clap clap
climb climb
climb stairs climb_stairs
dive dive
draw sword draw_sword
dribble dribble
drink drink
eat eat
fall floor fall_floor
fencing fencing
flic flac flic_flac
golf golf
handstand handstand
hit hit
hug hug
jump jump
kick kick
kick ball kick_ball
kiss kiss
laugh laugh
pick pick
pour pour
pullup pullup
punch punch
push push
pushup pushup
ride bike ride_bike
ride horse ride_horse
run run
shake hands shake_hands
shoot ball shoot_ball
shoot bow shoot_bow
shoot gun shoot_gun
sit sit
situp situp
smile smile
smoke smoke
somersault somersault
stand stand
swing baseball swing_baseball
sword sword
sword exercise sword_exercise
talk talk
throw throw
turn turn
walk walk
wave wave


In [6]:
# Encode all human-readable class names in a single batch with Sentence-BERT.
sbert_model = SentenceTransformer('distilbert-base-nli-mean-tokens')
class_embeddings_sbert = sbert_model.encode(class_names)

# Row i of the batch belongs to the i-th class; key it by the original
# (underscore) name so both embedding dictionaries share the same keys.
sbert_embed.update({
    original_name: class_embeddings_sbert[row]
    for row, original_name in enumerate(original_names)
})

In [8]:
# Sanity check: show the embedding shape of one sample class from each
# dictionary (BERT first, then SBERT).
for embed_table in (bert_embed, sbert_embed):
    print(embed_table['brush_hair'].shape)

(768,)
(768,)


In [9]:
# Persist both name -> embedding dictionaries as pickle files.
# The output directory is created first so a fresh checkout does not fail, and
# each file is written inside a `with` block so it is flushed and closed
# deterministically instead of leaking an open handle.
os.makedirs('./metas', exist_ok=True)
with open('./metas/class_embed_bert.pkl', 'wb') as bert_file:
    pickle.dump(bert_embed, bert_file)
with open('./metas/class_embed_sbert.pkl', 'wb') as sbert_file:
    pickle.dump(sbert_embed, sbert_file)