# TAG CORPUS CONSOLIDATION

## Imports and function definition

In [1]:
import nltk
from nltk.stem import WordNetLemmatizer
from nltk.corpus import stopwords
from nltk.corpus import brown
from nltk.data import find
nltk.download('wordnet')
nltk.download('stopwords')
nltk.download('punkt')

import os, sys, time
import numpy as np
import re
import string
import random
import itertools
from itertools import chain
import torch
from transformers import BertTokenizer, BertModel, BertForMaskedLM
import logging
import collections
from collections import defaultdict
from __future__ import print_function
from __future__ import division
from scipy import stats, optimize
from utils import util, vocabulary
from google.cloud import bigquery
import pandas as pd
from gensim import models


util.require_package("tqdm")  # for nice progress bars
from tqdm import tqdm as ProgressBar

# # Bokeh for plotting.
util.require_package("bokeh")
import bokeh.plotting as bp
from bokeh.models import HoverTool
bp.output_notebook()


[nltk_data] Downloading package wordnet to
[nltk_data]     /home/ejhaselden/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/ejhaselden/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /home/ejhaselden/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [2]:
import w2v_model
get_stats = {}

## PART 1: SIMPLE EMBEDDING

to do:   
use loss and test function to fine-tune embeddings 
add preprocessing notebook or keep it separate?

We define functions to retrieve tag data, preprocess the results, find similar embeddings, and introduce a functional suggestion test (inspired by masked-language-model training)

In [3]:
nb_start = time.time()

lemma_list = []
word_list = []

lemmatizer = WordNetLemmatizer()
w_tokenizer = nltk.tokenize.WhitespaceTokenizer()
stop_words = set(stopwords.words('english'))


def word_preprocessing(word):
    lower = word.lower()
    punct_replacer = str.maketrans(string.punctuation, ' '*len(string.punctuation))
    rem_punct = lower.translate(punct_replacer)
    lemma = [lemmatizer.lemmatize(w) for w in nltk.word_tokenize(rem_punct)]
    rem_stop = [w for w in lemma if not w in stop_words]
    rem_digits = [re.sub('\d', '<dig>', i) for i in rem_stop]
    lemma_list.append(rem_digits)
    word_list.append(word)

In [4]:
def get_top_matches(model, test_list):
    """
    Given an embedding model and list of tags, gets most similar results based on Word2Vec embeddings 
    (model constructed in w2v_model.py).
    Runs on one row (asset) at a time.
    """
#     ref_list = []
    matches = {}
#     not_found = 0
    for lstring in test_list:
        tagset = []
        try:
            match = model.wv.most_similar(lstring)
    #         print(ls, ' : ', match)
#             ref_list.append(lstring)
            for tag in range(len(match)):
                tagset.append(match[tag][0])

#             ref_list.append(tagset)
            matches[lstring] = tagset
        except KeyError:
    #         print(ls, ' : ','NOT_FOUND')
            pass
#     print("Not found", not_found)
    return matches

# search_list = ['blue']
# get_top_matches(model, search_list)

In [5]:
def valid_prediction(test_dictionary):
    """
    Selects one random key from dictionary and determines if any values for that key
    match any other keys in the dictionary (in other words, whether the model's
    suggestion for a given tag matches any existing tags for the same asset).
    """
    
    rand = random.randint(0, len(test_dictionary) - 1)
    keylist = list(test_dictionary.keys())
    key = keylist[rand]
    suggestions = test_dictionary[key]
#     print(key, sugestions)
    matches = 0
    for suggestion in suggestions:
        for key in keylist:
            if suggestion == key:
                print("MATCH!", suggestion,  key, len(keylist))
                return("MATCH")
    return("NO MATCH")
    

We get embedding model and compute loss (for use in hyperparameter tuning)

In [6]:
#get embedding model and compute loss (for use in hyperparameter tuning)
for i in range(1,2):
    print('epoch:', i)
    start = time.time()
    epochs = i
    vec_size = 10
    window = 5
    test_df = w2v_model.retrieve_expanded_query()
#     vec_model = w2v_model.retrieve_model_no_id(epochs, vec_size, window)
        
    vec_model = models.Word2Vec(test_df.values.tolist(), vector_size=vec_size, window=window, min_count=1, workers=4, compute_loss = True, epochs = epochs)
#     end = time.time()
    loss = vec_model.get_latest_training_loss()
    # perplexity = 2**loss
    print('loss:', loss)

epoch: 1
loss: 4390248.0


We compile the set of tags for each asset. For each of those tags, we then get  a list of the most similar tags based on the W2V model.

In [7]:
##to do: Make this a function with a parameter for each type of model

# get lemmatized tag df with 1 row per asset and each tag in a separate column, covert to list of lists
test_df = w2v_model.retrieve_expanded_query()
# test_df = df_for_model
test_vals = test_df.values[0:1000]
# test_vals = test_df[0:1000]


#use top_matches method to create a dictionary of related tags suggested by embedding model
asset_dicts = []
start = time.time()
for i in range(len(test_vals)):
#     print("remaining:", len(test_vals) - i)
    test_list = test_vals[i][test_vals[i] != None]
#     rate.append(test_list)
# #     print(test_list)
    top_matches = get_top_matches(vec_model, test_list)
    asset_dicts.append(top_matches)
# #     str(test_list)
end = time.time()
print("elapsed:", end - start)
# asset_dicts

elapsed: 8.012233257293701


Now we have a list of dictionaries where each key is a tag for that asset and each set of values is a list of potential suggestions based on the W2V embeddings.


In [8]:
# asset_dicts[0].keys()

In [9]:
# asset_dicts[0].values()

### Functional Suggestion Test    
We test the effectiveness of this suggestion set by selecting a random tag from each asset and seeing if it matches any other tag assigned to that asset. In other words, if one key matches one of another key's values.    
    
to do: (Consider averaging this over a few iterations)

In [10]:
#use valid_prediction method to determine useful suggestions
asset_results = []
for i in asset_dicts:
#     print(valid_prediction(i), i.keys())
    asset_results.append(valid_prediction(i))
fst_rate = asset_results.count("MATCH")/len(asset_results)
print("rate of valid suggestions:", fst_rate)
get_stats['w2v'] = fst_rate

MATCH! forehead forehead 25
MATCH! led display led display 15
MATCH! rugby rugby 28
MATCH! public event public event 16
MATCH! head head 19
MATCH! award ceremony award ceremony 19
MATCH! light light 20
MATCH! music artist music artist 17
MATCH! accordion accordion 24
MATCH! art culture entertainment art culture entertainment 8
MATCH!   4
MATCH! musician musician 7
MATCH! beard beard 7
MATCH! music artist music artist 20
MATCH! performance art performance art 21
MATCH!   20
MATCH! ball ball 26
MATCH! celebrity celebrity 9
MATCH! bicycle frame bicycle frame 30
MATCH! neck neck 15
MATCH! mountain classic mountain classic 5
MATCH! art culture entertainment art culture entertainment 4
MATCH! idiophone idiophone 39
MATCH! art culture entertainment art culture entertainment 18
MATCH! video still video still 19
MATCH! topix topix 8
MATCH! award ceremony award ceremony 17
MATCH! art culture entertainment art culture entertainment 5
MATCH! shoulder shoulder 17
MATCH! <dig><dig><dig><dig><dig><di

MATCH! music music 10
MATCH! suit suit 10
MATCH! <dig><dig><dig><dig> <dig><dig> <dig><dig> isa <dig><dig><dig><dig> <dig><dig><dig><dig> <dig><dig> <dig><dig> isa <dig><dig><dig><dig> 21
MATCH! art culture entertainment art culture entertainment 11
MATCH! art culture entertainment art culture entertainment 3
MATCH! entertainment entertainment 26
MATCH! automotive wheel system automotive wheel system 24
MATCH! sport uniform sport uniform 24
MATCH! music music 13
MATCH! <dig><dig><dig><dig>s <dig><dig><dig><dig>s 17
MATCH! sportswear sportswear 30
MATCH! sport uniform sport uniform 26
MATCH!   8
MATCH! award ceremony award ceremony 10
MATCH! audio equipment audio equipment 18
MATCH!   19
MATCH! plate plate 15
MATCH! light light 11
MATCH! art culture entertainment art culture entertainment 13
MATCH! public event public event 17
MATCH! violet violet 9
MATCH! art culture entertainment art culture entertainment 24
MATCH! art culture entertainment art culture entertainment 5
MATCH! <dig><dig

MATCH! bfselects ftp bfselects ftp 20
MATCH! spring summer <dig><dig><dig><dig> spring summer <dig><dig><dig><dig> 10
MATCH! rock concert rock concert 18
MATCH! trouser trouser 18
MATCH! event event 15
MATCH! celebrity celebrity 23
MATCH! music music 19
MATCH! stationary stationary 23
MATCH! color image color image 7
MATCH! <dig><dig><dig><dig><dig><dig><dig><dig><dig> <dig><dig><dig><dig><dig><dig><dig><dig><dig> 16
MATCH! wheel wheel 26
MATCH! nbcu photo bank nbcu photo bank 21
MATCH! fillmore plaza fillmore plaza 16
MATCH! entertainment entertainment 14
MATCH! ozteam ozteam 6
MATCH! rugby union rugby union 25
MATCH! <dig><dig><dig><dig><dig><dig><dig><dig><dig> <dig><dig><dig><dig><dig><dig><dig><dig><dig> 10
MATCH! shoulder shoulder 12
MATCH!   22
MATCH! art culture entertainment art culture entertainment 17
MATCH! dress dress 16
MATCH! music artist music artist 18
MATCH! sunglass sunglass 21
MATCH! art culture entertainment art culture entertainment 10
MATCH! celebrity celebrity 1

In [11]:
get_stats['w2v']

0.743

## PART 2: LANGUAGE MODEL    
     
code credit: https://github.com/datasci-w266/2021-summer-main/tree/master/materials/simple_lm   

We use a simple trigram model to see if that offers increased suggestion quality, on the assumption that tags will frequently in close context with simialr tags (ie, attached to the same asset).

In reality, we found a much lower rate of useful suggestions (0.45) as compared to the simple W2V embedding model (0.75). We attempted to improve our trigram model by alphabetizing each tag list prior to training, in the hope that this would further emphasize relationships between related tags. This approach yielded an even lower valid suggestion rate (0.05). {why?}

In [12]:
get_query = w2v_model.lm_retrieve_query()
wordlist = get_query['cn'].tolist()

In [13]:
def normalize_counter(c):
    """Given a dictionary of <item, counts>, return <item, fraction>."""
    total = sum(c.values())
    return {w:float(c[w])/total for w in c}

class SimpleTrigramLM(object):
    def __init__(self, words):
        """Build our simple trigram model."""
        # Raw trigram counts over the corpus. 
        # c(w | w_1 w_2) = self.counts[(w_2,w_1)][w]
        self.counts = defaultdict(lambda: defaultdict(lambda: 0.0))
    
        # Iterate through the word stream once.
        w_1, w_2 = None, None
        for word in words:
            if w_1 is not None and w_2 is not None:
                # Increment trigram count.
                self.counts[(w_2,w_1)][word] += 1
            # Shift context along the stream of words.
            w_2 = w_1
            w_1 = word
            
        # Normalize so that for each context we have a valid probability
        # distribution (i.e. adds up to 1.0) of possible next tokens.
        self.probas = defaultdict(lambda: defaultdict(lambda: 0.0))
        for context, ctr in self.counts.items():
            self.probas[context] = normalize_counter(ctr)
            
    def next_word_proba(self, word, seq):
        """Compute p(word | seq)"""
        context = tuple(seq[-2:])  # last two words
        return self.probas[context].get(word, 0.0)
    
    def predict_next(self, seq):
        """Sample a word from the conditional distribution."""
        context = tuple(seq[-2:])  # last two words
        pc = self.probas[context]  # conditional distribution
        words, probs = zip(*pc.items())  # convert to list
        return np.random.choice(words, p=probs)
    
    def score_seq(self, seq, verbose=False):
        """Compute log probability (base 2) of the given sequence."""
        score = 0.0
        count = 0
        # Start at third word, since we need a full context.
        for i in range(2, len(seq)):
            if (seq[i] == "<s>" or seq[i] == "</s>"):
                continue  # Don't count special tokens in score.
            s = np.log2(self.next_word_proba(seq[i], seq[i-2:i]))
            score += s
            count += 1
            # DEBUG
            if verbose:
                print("log P({:s} | {:s}) = {.03f}".format(seq[i], " ".join(seq[i-2:i]), s))
        return score, count

In [14]:
import re
# Word processing functions
def canonicalize_digits(word):
    if any([c.isalpha() for c in word]): return word
    word = re.sub("\d", "DG", word)
    if word.startswith("DG"):
        word = word.replace(",", "") # remove thousands separator
    return word

def canonicalize_word(word, wordset=None, digits=True):
    word = word.lower()
    if digits:
        if (wordset != None) and (word in wordset): return word
        word = canonicalize_digits(word) # try to canonicalize numbers
    if (wordset == None) or (word in wordset):
        return word
    else:
        return constants.UNK_TOKEN

def canonicalize_words(words, **kw):
    return [canonicalize_word(word, **kw) for word in words]

In [15]:
get_query = w2v_model.lm_retrieve_query()
get_query

Unnamed: 0,asset_id,cn
0,assetnum45248253,"film industry, arts culture and entertainment,..."
1,assetnum49642909,"775637829, Basketball, Sport, Basketball Moves..."
2,assetnum46930785,"775635494, Audio Equipment, Spokesperson, Publ..."
3,assetnum31904237,"775578330, 10.21.2020 CMT Awards Fan Viewing P..."
4,assetnum50434474,"775641976, Green, Eyelash, Hair, Font, Dress, ..."
...,...,...
779397,assetnum40965793,BFfulltakes_FTP
779398,assetnum36318105,BFfulltakes_FTP
779399,assetnum33632646,BFfulltakes_FTP
779400,assetnum36318332,BFfulltakes_FTP


In [16]:
#alphabetized version
import pandas as pd
alpha_get_query = get_query[0:1000]
v = np.sort(alpha_get_query.cn.str.split(',', expand=True).fillna(''), axis=1)
df = pd.DataFrame(v).agg(','.join, 1).str.strip(',').str.lstrip()

# wordlist = df.tolist()


In [17]:
split=0.8
sentences = np.array(list(wordlist), dtype=object)
fmt = (len(sentences), sum(map(len, sentences)))
print("Loaded {:,} sentences ({:g} tokens)".format(*fmt))


rng = np.random.RandomState()
rng.shuffle(sentences)  # in-place
split_idx = int(split * len(sentences))
train_sents = sentences[:split_idx]
test_sents = sentences[split_idx:]

for l in range(len(train_sents)):
    train_sents[l] = train_sents[l].split(", ")
for l in range(len(test_sents)):
    test_sents[l] = test_sents[l].split(", ")
# train_sents = train_sents.split(",")
# test_sents = test_sents.split(",")

fmt = (len(train_sents), sum(map(len, train_sents)))
print("Training set: {:,} sentences ({:,} tokens)".format(*fmt))
fmt = (len(test_sents), sum(map(len, test_sents)))
print("Test set: {:,} sentences ({:,} tokens)".format(*fmt))


Loaded 779,402 sentences (1.43624e+08 tokens)
Training set: 623,521 sentences (9,223,017 tokens)
Test set: 155,881 sentences (2,313,182 tokens)


In [18]:
# Word processing functions
def canonicalize_digits(word):
    if any([c.isalpha() for c in word]): return word
    word = re.sub("\d", "DG", word)
    if word.startswith("DG"):
        word = word.replace(",", "") # remove thousands separator
    return word

def canonicalize_word(word, wordset=None, digits=True):
    word = word.lower()
    if digits:
        if (wordset != None) and (word in wordset): return word
        word = canonicalize_digits(word) # try to canonicalize numbers
    if (wordset == None) or (word in wordset):
        return word
    else:
        return constants.UNK_TOKEN

def canonicalize_words(words, **kw):
    return [canonicalize_word(word, **kw) for word in words]

In [19]:
vocab = vocabulary.Vocabulary(canonicalize_word(w) for w in ProgressBar(util.flatten(train_sents)))
print("Train set vocabulary: %d words" % vocab.size)

100%|██████████| 9223017/9223017 [00:28<00:00, 323456.39it/s]

Train set vocabulary: 39535 words





In [20]:
def sents_to_tokens(sents):
    """Returns an flattened list of the words in the sentences, with padding for a trigram model."""
    padded_sentences = (["<s>", "<s>"] + s + ["</s>"] for s in sents)
    # This will canonicalize words, and replace anything not in vocab with <unk>
    return np.array([util.canonicalize_word(w, wordset=vocab.wordset) 
                     for w in ProgressBar(util.flatten(padded_sentences))], dtype=object)

train_tokens = sents_to_tokens(train_sents)
test_tokens = sents_to_tokens(test_sents)

t0 = time.time()
print("Building trigram LM...",)
lm = SimpleTrigramLM(train_tokens)
print("done in %.02f s" % (time.time() - t0))

100%|██████████| 11093580/11093580 [00:10<00:00, 1051887.43it/s]
100%|██████████| 2780825/2780825 [00:02<00:00, 1039938.93it/s]


Building trigram LM...
done in 23.26 s


In [21]:
train_tokens[0:50]

array(['<s>', '<s>', 'DGDGDGDGDGDGDGDGDG', 'talent show', 'song', 'event',
       'music', 'stage', 'concert', 'fashion', 'singing', 'musician',
       'performance', 'music artist', 'public event', 'entertainment',
       'bffulltakes_ftp', 'performance art', 'performing arts',
       'DGDGDGDGDGDGDGDGDG', 'jeans', 'yellow', 'purple', '</s>', '<s>',
       '<s>', 'singing', 'musician', 'music artist', 'entertainment',
       'DGDGDGDGDGDGDGDGDG', 'pop music', 'song', 'singer', 'stage',
       'purple', 'blue', 'event', 'music', 'concert', 'performance',
       'public event', 'bffulltakes_ftp', 'performing arts', '</s>',
       '<s>', '<s>', 'DGDGDGDGDGDGDGDGDG', 'community', 'bffulltakes_ftp'],
      dtype=object)

### Generating Sample Predictions    
When we task our model with generating predictions, we do see some relevance in the results. We quantify this later using our Functional Suggestion Test.

In [22]:
def lm_predictions(l_model, tag, max_length):
        seq = ["<s>", tag]
        for i in range(max_length):
            try:
                seq.append(l_model.predict_next(seq))
            except ValueError:
                seq.append('nodata_nodata')
                
        ## dedupe list of suggested tags
        seq = set(seq)
        seq = list(seq)
        seq = [i for i in seq if i not in ['<s>','</s>', tag, 'nodata_nodata']]
        
        ## n prevents an infinite loop in the next section
        n=0
        
        ## take length of deduped list and use it to return 15 suggestions
        while len(seq) < max_length+2 and n < 50:
            try:
                seq.append(l_model.predict_next(seq))
            except ValueError:
                seq.append('nodata_nodata')
            seq = set(seq)
            seq = list(seq)
            seq = [i for i in seq if i not in ['<s>','</s>', tag, 'nodata_nodata']]
            n+=1
        seq = seq[2:]
        return seq

lm_predictions(lm, 'blue', 15)

['red',
 'fashion',
 'electric blue',
 'fedora',
 'bffulltakes_ftp',
 'visual effect lighting',
 'cap',
 'font',
 'footwear',
 'arts culture and entertainment',
 'fashion design',
 'artist',
 'awards ceremony',
 'chair',
 'idiophone']

### Scoring  
We check the perplexity and then employ the same functional suggestion test that we used for the W2V embeddings.    
We see that the rate of valid predictions for the LM is actually much lower than that of the simple embeddings.

In [23]:
log_p_data, num_real_tokens = lm.score_seq(train_tokens)
print("Train perplexity: {:.02f}".format(2**(-1*log_p_data/num_real_tokens)))

Train perplexity: 21.24


In [24]:
lengths = []
for i in asset_dicts:
    lengths.append(len(i.values()))
np.average(lengths)

14.98

In [25]:
def lm_get_top_matches(l_model, test_list):
    """
    Given a language model and list of tags, gets most similar results.
    Runs on one row (asset) at a time.
    """
#     ref_list = []
    matches = {}
#     not_found = 0
    for lstring in test_list:
        tagset = []
        try:
            match = lm_predictions(lm, lstring, 15)
            for tag in range(len(match)):
#             for tag in range(2):
                tagset.append(match[tag])

# #             ref_list.append(tagset)
            matches[lstring] = tagset
        except KeyError:
    #         print(ls, ' : ','NOT_FOUND')
            pass
#     print("Not found", not_found)
    return matches

In [26]:
##to do: Make this a function with a parameter for each type of model

# get lemmatized tag df with 1 row per asset and each tag in a separate column, covert to list of lists
test_df = w2v_model.retrieve_expanded_query()
lm_test_vals = test_df.values[0:100]
  
#use top_matches method to create a dictionary of related tags suggested by embedding model
lm_asset_dicts = []
start = time.time()
for i in range(len(lm_test_vals)):
#     print("remaining:", len(test_vals) - i)
    lm_test_list = lm_test_vals[i][lm_test_vals[i] != None]
#     rate.append(test_list)
# #     print(test_list)
    lm_top_matches = lm_get_top_matches(lm, lm_test_list)
    lm_asset_dicts.append(lm_top_matches)
# #     str(test_list)
end = time.time()
print("elapsed:", end - start)
# lm_asset_dicts

elapsed: 35.375869274139404


In [27]:
#use valid_prediction method to determine useful suggestions
lm_asset_results = []
for i in lm_asset_dicts:
#     print(valid_prediction(i), i.keys())
    lm_asset_results.append(valid_prediction(i))
    lm_fst_rate = lm_asset_results.count("MATCH")/len(lm_asset_results)
print("Rate of valid suggestions:", lm_fst_rate)
get_stats['lm'] = lm_fst_rate
# asset_results.count("MATCH")

MATCH! font font 25
MATCH! lil corky lil corky 2
MATCH! nbcu photo bank nbcu photo bank 19
MATCH! music venue music venue 20
MATCH! entertainment entertainment 24
MATCH! leisure leisure 20
MATCH! building building 11
MATCH! wheel wheel 30
MATCH! smile smile 18
MATCH! shoulder shoulder 17
MATCH! fun fun 24
MATCH! eyelash eyelash 21
MATCH! sleeve sleeve 12
MATCH! punta cana punta cana 7
MATCH! concert concert 18
MATCH! artist artist 25
MATCH! performance performance 23
MATCH! room room 13
MATCH! street fashion street fashion 24
MATCH! flash photography flash photography 22
MATCH! music music 38
MATCH! performance performance 11
MATCH! player player 25
MATCH! sleeve sleeve 17
MATCH! fashion fashion 20
MATCH! hairstyle hairstyle 9
MATCH! fashion fashion 33
MATCH! music music 16
MATCH! purple purple 24
MATCH! drinkware drinkware 16
MATCH! font font 14
MATCH! interior design interior design 12
MATCH! smile smile 17
MATCH! property property 7
MATCH! audio equipment audio equipment 10
Rate of 

## PART 3. BERT ATTEMPT  
Given that our corpus is full of unusual terms and that our "sentences" are order-agnostic, BERT's pre-trained bi-directional nature makes it a counterintuitive choice. We propose a novel application, however, in which BERT is fine-tuned on our tag corpus. As in our LM test, we order tags in our corpus alphabetically to impose a sense of word order significance.
        
BERT plan:    
1. create version of reconstructed_assets with rare tags removed, each set of tags alphabetized   
2. fine-tune BERT on that <<---- this is where I'm stuck :( 
3. give BERT an unedited tag list for a given asset, with rare tags changed to [MASK]



In [28]:
#code credit: https://gist.github.com/yuchenlin/a2f42d3c4378ed7b83de65c7a2222eb2
# !pip install torchvision 


Vanilla BERT 

In [29]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
def predict_masked_sent(text, top_k=5):
    # Tokenize input
    text = "[CLS] %s [SEP]"%text
    tokenized_text = tokenizer.tokenize(text)
    masked_index = tokenized_text.index("[MASK]")
    indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
    tokens_tensor = torch.tensor([indexed_tokens])
    # tokens_tensor = tokens_tensor.to('cuda')    # if you have gpu

    # Predict all tokens
    with torch.no_grad():
        outputs = model(tokens_tensor)
        predictions = outputs[0]

    probs = torch.nn.functional.softmax(predictions[0, masked_index], dim=-1)
    top_k_weights, top_k_indices = torch.topk(probs, top_k, sorted=True)
    preds = []
    
    for i, pred_idx in enumerate(top_k_indices):
        predicted_token = tokenizer.convert_ids_to_tokens([pred_idx])[0]
        token_weight = top_k_weights[i]
        preds.append(predicted_token)
#         print("[MASK]: '%s'"%predicted_token, " | weights:", float(token_weight))

    return preds
        

predict_masked_sent("'white', '[MASK]'")    
# predict_masked_sent("'Lighting', 'Arts Culture and Entertainment', [MASK], 'Water', 'Chandelier', 'Ceiling Fixture', 'Ceiling'", 5)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


['black', 'white', 'red', '.', 'blue']

Vanilla BERT given a test tag sequence

In [30]:
w2v_model.is_not_rare('bffulltakes ftp')

False

In [31]:
# start = time.time()
bert_df = w2v_model.retrieve_expanded_query()
bert_lite = bert_df[0:100]
bert_values = bert_lite.values


In [32]:

# def rand_mask(mask_string):
#     mask_string = mask_string[mask_string != None]
#     r = len(mask_string)
#     rand = random.randint(0,len(mask_string) - 1)
#     mask_string[rand] = '[MASK]'
#     return mask_string


# for i in range(10):
#     print(rand_mask(bert_values[i]))
        

In [33]:
# predict_masked_sent(['lil corky', '[MASK]'], 5)

In [34]:
# tdf.applymap(w2v_model.mask_rare)

In [35]:
def bert_get_top_matches(test_list):
    """
    Given a bert model and list of tags, gets most similar results.
    Runs on one row (asset) at a time.
    """
#     ref_list = []
    matches = {}
    test_list = test_list[test_list != None]
#     not_found = 0
    for itag in test_list:
        lstring = [itag, ', [MASK]']
#         print(itag, lstring)
        tagset = []
        try:
            match = predict_masked_sent(lstring, 15)
#             print(match)
            for otag in range(len(match)):
# #             for tag in range(2):
                tagset.append(match[otag])

# #             ref_list.append(tagset)
            matches[itag] = tagset
        except KeyError:
#     #         print(ls, ' : ','NOT_FOUND')
            pass
# #     print("Not found", not_found)
    return matches
bert_get_top_matches(bert_values[4])

{'<dig><dig><dig><dig><dig><dig><dig><dig><dig>': ["'",
  'and',
  ',',
  '...',
  'or',
  '&',
  '`',
  'etc',
  '.',
  '*',
  '/',
  '-',
  'x',
  ';',
  '?'],
 '<dig><dig> <dig><dig> <dig><dig><dig><dig> <dig><dig><dig><dig><dig><dig><dig><dig><dig> starz power book<dig> premiere <dig><dig><dig> image': ["'",
  'and',
  ',',
  '...',
  '&',
  'or',
  'etc',
  '*',
  '.',
  '/',
  'x',
  '`',
  '+',
  '-',
  's'],
 'circa grand opening emiller <dig><dig><dig><dig><dig><dig><dig><dig><dig>': ["'",
  'and',
  ',',
  '...',
  '.',
  'etc',
  '&',
  '`',
  '*',
  'or',
  '/',
  '-',
  'x',
  ';',
  '?'],
 'film industry': ["'",
  '...',
  '.',
  'film',
  ',',
  'cinema',
  'etc',
  'and',
  'industry',
  'magazine',
  'movie',
  'culture',
  'video',
  'movies',
  'society'],
 'electronic signage': ["'",
  'etc',
  'and',
  '...',
  ',',
  '.',
  '&',
  '*',
  '-',
  'or',
  '/',
  'text',
  '"',
  '|',
  '?'],
 'electronic device': ["'",
  '...',
  ',',
  'and',
  'etc',
  '.',
  'devi

In [36]:
###TEST TEST TEST ___________________________________


# def lm_predictions(l_model, tag, max_length):
#         seq = ["<s>", tag]
#         for i in range(max_length):
#             try:
#                 seq.append(l_model.predict_next(seq))
#             except ValueError:
#                 seq.append('nodata_nodata')
                
#         ## dedupe list of suggested tags
#         seq = set(seq)
#         seq = list(seq)
#         seq = [i for i in seq if i not in ['<s>','</s>', tag, 'nodata_nodata']]
        
#         ## n prevents an infinite loop in the next section
#         n=0
        
#         ## take length of deduped list and use it to return 15 suggestions
#         while len(seq) < max_length+2 and n < 50:
#             try:
#                 seq.append(l_model.predict_next(seq))
#             except ValueError:
#                 seq.append('nodata_nodata')
#             seq = set(seq)
#             seq = list(seq)
#             seq = [i for i in seq if i not in ['<s>','</s>', tag, 'nodata_nodata']]
#             n+=1
#         seq = seq[2:]
#         return seq

# lm_predictions(lm, 'blue', 15)

In [37]:
#use top_matches method to create a dictionary of related tags suggested by embedding model
bert_asset_dicts = []
start = time.time()
for i in range(100):
#     print("remaining:", len(test_vals) - i)
    bert_test_list = bert_values[i]
#     rate.append(test_list)
# #     print(test_list)
    bert_top_matches = bert_get_top_matches(bert_test_list)
    bert_asset_dicts.append(bert_top_matches)
# #     str(test_list)
end = time.time()
print("elapsed:", end - start)

elapsed: 247.45930361747742


In [38]:
# bert_asset_dicts

In [39]:
#use valid_prediction method to determine useful suggestions
bert_asset_results = []
for i in bert_asset_dicts:
#     print(valid_prediction(i), i.keys())
    bert_asset_results.append(valid_prediction(i))
    bert_fst_rate = bert_asset_results.count("MATCH")/len(bert_asset_results)
print("Rate of valid suggestions:", bert_fst_rate)
get_stats['vanilla_bert'] = bert_fst_rate
# asset_results.count("MATCH")

MATCH! font font 25
MATCH! person person 12
MATCH! shirt shirt 28
MATCH! dress dress 19
MATCH! artist artist 19
MATCH! music music 20
MATCH! music music 17
MATCH! musician musician 7
MATCH! style style 19
MATCH! music music 20
MATCH! fashion fashion 21
MATCH! sport sport 26
MATCH! world world 9
MATCH! vehicle vehicle 30
MATCH! corner corner 13
MATCH! tree tree 25
MATCH! happy happy 15
MATCH! studio studio 5
MATCH! artist artist 39
MATCH! product product 18
MATCH! cap cap 19
MATCH! music music 18
MATCH! temple temple 17
MATCH! speech speech 9
MATCH! transport transport 5
MATCH! basketball basketball 17
MATCH! performance performance 28
MATCH! street street 7
MATCH! music music 23
MATCH! chair chair 10
MATCH! runway runway 8
MATCH! blue blue 25
MATCH! hand hand 24
MATCH! thigh thigh 22
MATCH! artist artist 38
MATCH! orange orange 13
MATCH! stage stage 11
MATCH! finger finger 17
MATCH! fashion fashion 20
MATCH! event event 9
MATCH! hair hair 9
MATCH! player player 21
MATCH! building build

## PART 4. SUGGESTION FUNCTION IMPLEMENTATION   
We loop through the list of assets and attempt to offer alternative tags for any rare tags that we enconter.

Using the simple W2V embeddings for now

In [40]:
to_check = w2v_model.lm_retrieve_expanded_query()
# to_check

In [41]:
#given a list of tags, identify those that are rare and provide suggested replacements
def get_candidates(tag_list):
    start = time.time()
    lemma_dict = w2v_model.lemma_map()
    lemma_common_tags = []
    lemma_rare_tags = []
    candidates = []
#     cand_list = []
    for tag in tag_list:
        try:
            if w2v_model.is_not_rare(lemma_dict[tag]):
                lemma_common_tags.append(lemma_dict[tag])
                continue
            else:
                lemma_rare_tags.append(lemma_dict[tag])
        except KeyError:
            if w2v_model.is_not_rare(tag):
                continue
        try:    
            if tag is not None:
                candidate = get_top_matches(vec_model, ['', lemma_dict[tag]])
    #             print('tag: ', tag, '\nsuggestions:', candidates.values(),'\n')
#                 cand_list.append(candidate)
                candidates.append(candidate.values())
        except KeyError:
            pass
    flat = list(chain(*candidates))
    flatter = list(chain(*flat)) 
#     print('lemma common tags: ', lemma_common_tags, '\nlemma rare tags: ', lemma_rare_tags)      
    return flatter

def suggest_better_tags(list_of_tags):
    tag_candidates = get_candidates(list_of_tags)
    tag_candidates = [w2v_model.delete_rare(tc) for tc in tag_candidates]
    tag_candidates = [tag for tag in tag_candidates if tag != ""]
    
    return tag_candidates
# end = time.time()
# print(end - start)
# list_of_tags = to_check.values[2]
# list_of_tags = ["Saw"]
# zzz = suggest_better_tags(list_of_tags, 1)
# zzz
# # print(lemma_dict(list_of_tags), zzz)
# # yyy = get_candidates(list_of_tags)
# # yyy
# list_of_tags

We gather suggestions for every tag and concatenate those into a
single suggestion list.    
Then we discard uncommon (rare) suggestions and identify suggestions that were already tags for the given asset (duplicates).     
The final list of suggested tags then consists only of common tags that are not already applied to the asset.   
We also track the duplicates so they can be used to validate the usability of our suggestions.

In [42]:
'''
For each asset (row), gather suggestions for every tag and concatenate those into a
single suggestion list. Then discard uncommon (rare) suggestions and identify
suggestions that were already tags for the given asset (duplicates). The final list of suggested tags
is then only common tags that are not already applied to the asset.

This function also tracks the duplicates so they can be used to validate the usability of our suggestions.
'''

def get_real_suggestions(existing_tags):
    #GIVE THIS FUNCTION AT LEAST TWO LISTS OF TAGS OR A LIST WRAPPED IN AN EMPTY LIST
    start = time.time()
    count = len(existing_tags)
    
    all_live_suggestions = []
    all_new_suggestions = []
    all_dupe_suggestions = []
    all_weighted = []

    lemma_dict = w2v_model.lemma_map()
    for i in range(count):
        live_tags = existing_tags[i]
#         print("tags in \n", live_tags, '\n')
        live_suggestions = suggest_better_tags(live_tags)
        all_live_suggestions.append(live_suggestions)
    #     print("all suggestions \n", live_suggestions, '\n')
        new_suggestions = []
        dupe_suggestions = []
        for sug in live_suggestions:
            for tag in live_tags:
                if sug == lemma_dict.get(tag):
                    if sug not in dupe_suggestions:
                        dupe_suggestions.append(sug)
                    continue              
            if sug not in dupe_suggestions:
                new_suggestions.append(sug)
        counts = collections.Counter(new_suggestions)
        weighted = counts.most_common()
        all_weighted.append(weighted)
        
        all_dupe_suggestions.append(dupe_suggestions)
        all_new_suggestions.append(new_suggestions)
#         counts = collections.Counter(all_new_suggestions)
#         weigthed = counts.most_common()
    return [all_live_suggestions, all_dupe_suggestions, all_new_suggestions, all_weighted]
test_check = to_check.values[0:100]

start = time.time()
real_suggestions = get_real_suggestions(test_check)
end = time.time()
print(end - start)

123.96073484420776


Our function returns a batch of new legitimate suggested tags for each asset (ideally), sorted by frequency.  

In [43]:
for i in range(len(real_suggestions[0])):
    print('\n\n Asset', i, '\n\nORIGINAL TAG LIST: \n', test_check[i], '\nSUGGESTIONS: \n', real_suggestions[0][i], '\nDUPLICATES: \n', real_suggestions[1][i], '\nLEGITIMATE SUGGESTIONS:  \n', real_suggestions[2][i],  '\nBY WEIGHT:  \n', real_suggestions[3][i])



 Asset 0 

ORIGINAL TAG LIST: 
 ['film industry' 'arts culture and entertainment' 'Choreography' 'Dance'
 'celebrities' '775621438' 'awards ceremony' 'BFselects_FTP'
 'Entertainment' 'Fashion Design' 'Performing Arts' None None None None
 None None None None None None None None None None None None None None
 None None None None None None None None None None None None None None
 None None None None None None None None] 
SUGGESTIONS: 
 ['fashion', 'music', 'event', 'entertainment', 'film industry', 'fashion', 'film industry', 'music', 'fashion', 'music', 'event', 'entertainment', 'film industry', 'pink', 'fashion', 'music', 'event', 'entertainment', 'film industry', 'pink', 'fashion', 'music', 'event', 'entertainment', 'film industry', 'fashion', 'music', 'event', 'entertainment', 'purple', 'fashion', 'music', 'event', 'entertainment', 'film industry', 'film industry', 'fashion', 'red', 'blue', 'fashion', 'music', 'event', 'entertainment', 'film industry', 'music', 'event', 'fashion', 

To do: scoring/stats, etc   
We see at least one valid, common suggestion (defined as a tag that was actually attached by a user in our initial data set) for approx 80% of assets. (based on 1000 assets)

In [44]:
dupes = real_suggestions[1]
# def condition(x): return len(x) == 0
output = [idx for idx, element in enumerate(dupes) if len(element) > 0]
len(output)/len(real_suggestions[1])


0.83

Our function suggests an average of approx 9 common tags per asset.

In [45]:
new_unique = real_suggestions[3]
new_unique_lengths = []
for nu in new_unique:
    l = len(nu)
    new_unique_lengths.append(l)
sum(new_unique_lengths)/len(new_unique_lengths)


11.24

In [46]:
#random list of common tags

rand_tags = w2v_model.retrieve_rare(3)
rand_tags

['Drink', '775662392', 'Style']

In [47]:
get_stats

{'w2v': 0.743, 'lm': 0.35, 'vanilla_bert': 0.56}

In [48]:
nb_end = time.time()
print(nb_end - nb_start)

570.9700653553009
