In [None]:
import os
import pickle

import fasttext
import fasttext.util
import spacy
import torch
from transformers import AutoTokenizer, AutoModel
# Load the large English spaCy pipeline.
# NOTE(review): `nlp` is not referenced by any cell visible here — confirm it is still needed.
nlp = spacy.load('en_core_web_lg')
# Fetch the English fastText model (if_exists='ignore' is presumably a no-op when the
# file is already on disk — confirm) and load it. Loading takes a while: the model is ~7GB.
fasttext.util.download_model('en', if_exists='ignore')  # English
ft = fasttext.load_model('cc.en.300.bin')

# Load any Hugging Face model: change `model_name` to switch checkpoints.
# `tokenizer` and `model` are the module-level objects the transformer mapping functions use.
model_name='distilbert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
config = model.config

In [None]:
def mapLabels(str_labels, id_length, dict_name):
    """Pickle a fastText vector for every label, keyed by list position.

    Args:
        str_labels: Iterable of label strings; the position of a label in
            this iterable becomes its key in the saved dictionary.
        id_length: Not used by this function; kept so existing callers
            continue to work.
        dict_name: Prefix of the output file name.

    Side effects:
        Writes ``./<dict_name>_embeddings.pickle`` containing
        ``{position: vector}``, where each vector comes from the
        module-level fastText model ``ft``.
    """
    # Build the whole position -> vector mapping in one pass.
    embeddings_by_position = {
        position: ft.get_word_vector(name)
        for position, name in enumerate(str_labels)
    }

    out_path = './' + dict_name + '_embeddings.pickle'
    with open(out_path, 'wb') as out_file:
        pickle.dump(embeddings_by_position, out_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
def mapTransformerLabels(str_labels, id_length, dict_name):
    """Pickle a transformer embedding for every label, keyed by list position.

    Each label is encoded with the module-level ``tokenizer``, run through
    the module-level ``model`` as a batch of one, and the vector at
    ``output[0][0][0]`` is kept: first output element, first batch item,
    first token position (presumably the [CLS] hidden state — confirm for
    the chosen checkpoint).

    Args:
        str_labels: Iterable of label strings; the position of a label in
            this iterable becomes its key in the saved dictionary.
        id_length: Not used by this function; kept so existing callers
            continue to work.
        dict_name: Prefix of the output file name.

    Side effects:
        Creates ``./transformers/`` if it does not exist and writes
        ``./transformers/<dict_name>_embeddings.pickle`` containing
        ``{position: numpy array}``.
    """
    out_dir = './transformers/'

    # get embeddings for labels
    dictionary = {}
    # Inference only: disable autograd so no graph is built for each forward pass.
    with torch.no_grad():
        for idx, label in enumerate(str_labels):
            input_ids = torch.tensor(tokenizer.encode(label)).unsqueeze(0)  # Batch size 1
            embedding = model(input_ids)[0][0][0].detach().numpy()
            dictionary[idx] = embedding

    # open() does not create missing directories; make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    with open(out_dir + dict_name + '_embeddings.pickle', 'wb') as handle:
        pickle.dump(dictionary, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
def mapCocoLabels(str_labels, id_length, dict_name, embedding_type):
    """Pickle embeddings for several captions per class as a nested dictionary.

    Args:
        str_labels: Iterable of classes, each itself an iterable of caption
            strings. The saved structure is
            ``{class_position: {caption_position: vector}}``.
        id_length: Not used by this function; kept so existing callers
            continue to work.
        dict_name: Prefix of the output file name.
        embedding_type: ``"transformers"`` embeds each caption with the
            module-level ``tokenizer``/``model`` and writes to
            ``../transformers/``; any other value embeds with the
            module-level fastText model ``ft`` and writes to ``./``.

    Side effects:
        Creates the output directory if it does not exist and writes
        ``<out_dir><dict_name>_embeddings.pickle``.
    """
    if embedding_type == "transformers":
        def embed(label):
            # Keep the vector at output[0][0][0]: first output element, first
            # batch item, first token position (presumably [CLS] — confirm).
            input_ids = torch.tensor(tokenizer.encode(label)).unsqueeze(0)  # Batch size 1
            return model(input_ids)[0][0][0].detach().numpy()

        # NOTE(review): mapTransformerLabels writes under './transformers/' while
        # this path is '../transformers/' — confirm which location is intended.
        out_dir = '../transformers/'
    else:
        embed = ft.get_word_vector
        out_dir = './'

    # Both embedding types share the same nested traversal; only `embed` differs.
    # Inference only: disable autograd so no graph is built for each forward pass.
    with torch.no_grad():
        dictionary = {
            idx1: {idx2: embed(label) for idx2, label in enumerate(cl)}
            for idx1, cl in enumerate(str_labels)
        }

    # open() does not create missing directories; make sure the target exists.
    os.makedirs(out_dir, exist_ok=True)
    with open(out_dir + dict_name + '_embeddings.pickle', 'wb') as handle:
        pickle.dump(dictionary, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
import numpy as np
# Label names per dataset. The mapping functions key each embedding by the
# label's position in its list, so position must equal the dataset's integer
# class id. MNIST class ids are the digits themselves, hence 'zero' comes first
# (it was previously last, which shifted every digit's embedding by one class).
labels_mnist = ['zero', 'one', 'two', 'three', 'four', 'five', 'six', 'seven', 'eight', 'nine']
labels_cifar = ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
labels_fmnist = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']

# Select the label set to embed and write its transformer embeddings to disk.
str_labels = labels_cifar
id_length = len(str_labels)
mapTransformerLabels(str_labels, id_length, "cifar10")

In [None]:
# Two caption sentences per class. The rows appear to follow the same class order
# as `labels_cifar` (plane, car, bird, cat, deer, dog, frog, horse, ship, truck);
# mapCocoLabels keys the outer dictionary by row position and the inner one by
# caption position.
# NOTE(review): the first caption of the last row mentions a motorcycle rather
# than a truck — confirm this is intended.
labels_coco = [
    ['A red, white, and blue plane is in the sky.', 'A yellow propellor airplane is on a grassy runway.'], 
    ['Two men in orange vests are next to a black car.', 'An old car stand next to a tree.'],
    ['Black and yellow bird with colorful beak sitting on a branch.', 'A brown bird perched on top of a metal fence.'],
    ['A close shot of a cat laying on purple sheets. ', 'A black and white cat laying next to a remote control.'],
    ['A deer is crossing a street', 'Between trees, there is a deer standing'],
    ['A small white dog stands on a wooden bench.', 'A dog laying on a red couch.'],
    ['A frog is sitting on a leaf', 'In a sea there is a frog swimming'],    
    ['There is only one horse in the grassy field.', 'a white horse that is standing next to a fence'],
    ['A speed boat is docked underneath a dark, shadowy bridge.', 'A man standing on top of an orange boat on a river.'],
    ['This black and white photo shows a motorcycle', 'A green truck is driving on a street']
]

# Embed every caption with the transformer model and pickle the nested result.
str_labels = labels_coco
id_length = len(str_labels)
mapCocoLabels(str_labels, id_length, "cifar10_extended", "transformers")