In [None]:
import os
# Runtime switches, set before the project modules below are imported:
#   TOKENIZERS_PARALLELISM       - Hugging Face tokenizers parallelism, turned off here
#   TORCHINDUCTOR_FX_GRAPH_CACHE - TorchInductor FX-graph compile cache, turned on here
#   TORCHINDUCTOR_AUTOGRAD_CACHE - TorchInductor autograd compile cache, turned on here
os.environ.update(
    {
        "TOKENIZERS_PARALLELISM": "false",
        "TORCHINDUCTOR_FX_GRAPH_CACHE": "1",
        "TORCHINDUCTOR_AUTOGRAD_CACHE": "1",
    }
)

from bert_experiments import experiments
from compute_batch_bert_gradients import learn_embeddings

# Run a single gradient descent step and collect the gradients.
# The registry entry is copied first, so `args` is a separate object from
# `experiments["one-step"]`. `args` and `result` are read by the cells that follow.
one_step_config = experiments["one-step"]
args = one_step_config.copy()
result = learn_embeddings(args)

In [None]:
import matplotlib.pyplot as plt
from bert_util import get_token_distribution
from bert_viz import plot_distribution

# Token distributions for two prompts that differ only in the animal sound.
HORSE_DISTRIBUTION = get_token_distribution(
    ["The animal that says neigh is a [MASK]"]
)
DOG_DISTRIBUTION = get_token_distribution(
    ["The animal that says bark is a [MASK]"]
)

# Two panels on one row with a shared y-axis, so the heights are directly comparable.
# The (0, 6) arguments are forwarded as-is - presumably a rank range; confirm in bert_viz.
fig, axes = plt.subplots(
    nrows=1,
    ncols=2,
    figsize=(10, 3),
    sharey=True,
    constrained_layout=True,
)
horse_axis, dog_axis = axes
plot_distribution(HORSE_DISTRIBUTION, 0, 6, horse_axis)
plot_distribution(DOG_DISTRIBUTION, 0, 6, dog_axis)


In [None]:
# Display sentences by gradient L2 norm
import numpy as np
from bert_viz import generate_saliency_html, viz_sentences_for_input_embed
# Saliency of each position = Euclidean (L2) norm of its gradient vector, taken along axis 2.
saliencies = np.linalg.norm(result.last_gradient, ord=2, axis=2)
# Sentences are rendered from the first entry of the embedding list.
initial_inputs_embeds = result.inputs_embeds_list[0]
sentences = viz_sentences_for_input_embed(args, input_embeds=initial_inputs_embeds)
generate_saliency_html(args, sentences, saliencies)

In [None]:
# Display tokens that have the highest cosine similarity to the gradient
import numpy as np
from bert_models import tokenizer, full_vocab_embedding_normalized
from bert_viz import generate_saliency_html, viz_sentences_for_input_embed
# Scale every gradient vector to unit length; with a unit-norm vocabulary matrix
# (assumed from the name `full_vocab_embedding_normalized`) the dot products below are cosine similarities.
gradient_norms = np.linalg.norm(result.last_gradient, axis=2, keepdims=True)
normalized_gradient = result.last_gradient / gradient_norms  # [B, N, D]
cosine_sims = np.einsum(
    'bnd,vd->bnv', normalized_gradient, full_vocab_embedding_normalized
)  # [B, N, V]

# Best-matching vocabulary entry per position, as ids and as token strings.
closest_cosine_sim_vocab_token_idxs = np.argmax(cosine_sims, axis=-1)  # [B, N]
closest_cosine_sim_vocab_tokens = [
    tokenizer.convert_ids_to_tokens(row_ids)
    for row_ids in closest_cosine_sim_vocab_token_idxs
]

# The similarity score of each winning token becomes its saliency (trailing axis of length 1 kept).
saliencies = np.take_along_axis(
    cosine_sims, closest_cosine_sim_vocab_token_idxs[..., np.newaxis], axis=-1
)
generate_saliency_html(args, closest_cosine_sim_vocab_tokens, saliencies, mask_token="___")

In [None]:
# Display tokens that are pointed to by the gradient from the initial input embedding
from bert_viz import viz_sentences_for_input_embed

# NOTE(review): this subtracts the initial embedding from the raw gradient; confirm that this is the
# intended "pointed to" vector (as opposed to e.g. a stepped point such as embedding - lr * gradient).
gradient_diff = result.last_gradient - result.inputs_embeds_list[0]  # [B, N, D]
gradient_diff_normalized = gradient_diff / np.linalg.norm(gradient_diff, axis=2, keepdims=True)  # [B, N, D]

# Compare the vocabulary against the vector computed in THIS cell. Previously the einsum read
# `normalized_gradient` from the preceding cell, so `gradient_diff_normalized` was never used and
# the output just repeated that cell's result.
cosine_sims = np.einsum('bnd,vd->bnv', gradient_diff_normalized, full_vocab_embedding_normalized)  # [B, N, V]
nearest_token_ids = np.argmax(cosine_sims, axis=-1)  # [B, N]
nearest_tokens = [tokenizer.convert_ids_to_tokens(token_ids) for token_ids in nearest_token_ids]

# Gather each displayed token's own score; previously this indexed with the preceding cell's
# `closest_cosine_sim_vocab_token_idxs`, so scores and tokens could refer to different vocabulary entries.
saliencies = np.take_along_axis(cosine_sims, np.expand_dims(nearest_token_ids, axis=-1), axis=-1)
generate_saliency_html(args, nearest_tokens, saliencies, mask_token="___")


In [None]:
from bert_experiments import experiments
from compute_batch_bert_gradients import learn_embeddings
from bert_util import collect_probability_path


# Gradient descent dog->horse, 250 steps
# `args` and `result` are rebound here, replacing the values from the one-step run.
horse_config = experiments["horse"]
args = horse_config.copy()
result = learn_embeddings(args)
# Probabilities computed from the embedding list of the run; read by the animation cells below.
probs_path = collect_probability_path(args, result.inputs_embeds_list)

In [None]:
from bert_animate import animate_kl_divergences
# Animate from the probability path of the current experiment
# (presumably the KL divergence per step, going by the helper's name - confirm in bert_animate).
animate_kl_divergences(args, probs_path)

In [None]:
# The helper is imported from bert_animate; an earlier version of this cell imported it from bert_viz.
from bert_animate import animate_sentence_level_L2_distances
# Passes the full list of embeddings recorded by the current run
# (presumably sentence-level L2 distances per step, going by the helper's name - confirm in bert_animate).
animate_sentence_level_L2_distances(args, result.inputs_embeds_list)


In [None]:
from bert_animate import animate_token_level_L2_distances
# Same embedding list as the sentence-level animation, restricted to sentence index 0
# (presumably per-token L2 distances, going by the helper's name - confirm in bert_animate).
animate_token_level_L2_distances(args, result.inputs_embeds_list, sentence_idx = 0)

In [None]:
from bert_experiments import experiments
from compute_batch_bert_gradients import learn_embeddings
# Try 'eng-random-embedding' to help validate our "close to everywhere global minima" hypothesis
# `args`, `result` and `probs_path` are rebound here; every later cell sees the 'eng' run.
eng_config = experiments['eng']
args = eng_config.copy()
result = learn_embeddings(args)
# `collect_probability_path` comes from the import in the earlier horse-experiment cell.
probs_path = collect_probability_path(args, result.inputs_embeds_list)

In [None]:
from bert_viz import display_gradient_displacement
# Third argument 0 - presumably the index of the sentence to display; confirm against
# bert_viz.display_gradient_displacement.
display_gradient_displacement(args, result.inputs_embeds_list, 0)

In [None]:
# Display tokens that are pointed to by the gradient from the initial input embedding
import importlib, bert_viz as viz
importlib.reload(viz)
#from bert_viz import viz_sentences_for_input_embed, viz_sentences_for_input_embed
# NOTE(review): this subtracts the initial embedding from the raw gradient; confirm that this is the
# intended "pointed to" vector (as opposed to e.g. a stepped point such as embedding - lr * gradient).
gradient_diff = result.last_gradient - result.inputs_embeds_list[0]  # [B, N, D]
gradient_diff_normalized = gradient_diff / np.linalg.norm(gradient_diff, axis=2, keepdims=True)  # [B, N, D]

# Use the vector derived from the CURRENT `result`. Previously the einsum read `normalized_gradient`,
# which an earlier cell computed from a different experiment's result, so this cell showed stale data.
cosine_sims = np.einsum('bnd,vd->bnv', gradient_diff_normalized, full_vocab_embedding_normalized)  # [B, N, V]
nearest_token_ids = np.argmax(cosine_sims, axis=-1)  # [B, N]
nearest_tokens = [tokenizer.convert_ids_to_tokens(token_ids) for token_ids in nearest_token_ids]

# Gather each displayed token's own score; previously this indexed with the stale
# `closest_cosine_sim_vocab_token_idxs` from an earlier cell.
saliencies = np.take_along_axis(cosine_sims, np.expand_dims(nearest_token_ids, axis=-1), axis=-1)
print(nearest_tokens)
# NOTE(review): mask_token is four underscores here but three ("___") in the earlier saliency cells -
# confirm which one the renderer expects.
viz.generate_saliency_html(args, nearest_tokens, saliencies, mask_token="____")


In [None]:
# Render the BERT attention visualization for the most recent run:
# `args` and `result` are whatever the latest learn_embeddings cell left in the kernel.
from bert_viz import viz_bert
viz_bert(args, result.inputs_embeds_list)

In [None]:
# Index of the single sentence/row that the rest of this cell extracts and animates.
SELECTED_IDX = 0
# NOTE(review): an earlier cell imports `tokenizer` from bert_models; this import rebinds the name
# from compute_batch_bert_gradients - confirm both modules expose the same tokenizer.
from compute_batch_bert_gradients import tokenizer

def clean_token(token):
    """Return ``token`` with every "Ġ" swapped for a plain space.

    "Ġ" is presumably the tokenizer's marker for a leading space; all other
    characters are passed through untouched.
    """
    return " ".join(token.split("Ġ"))


# Structural tokens that are dropped entirely when a sentence is rendered as text.
special_tokens = {'[CLS]', '[SEP]', '[PAD]'}
def clean_sentence(sentence):
    """Join a sequence of token strings into one display string.

    Tokens listed in ``special_tokens`` are skipped, ``'[MASK]'`` is rendered as
    ``' ___'``, and every "Ġ" in the remaining tokens becomes a space. The pieces
    are concatenated without any extra separator.
    """
    pieces = []
    for token in sentence:
        if token in special_tokens:
            continue
        shown = ' ___' if token == '[MASK]' else token
        pieces.append(shown.replace("Ġ", " "))
    return "".join(pieces)

# ---- Collect the data the probability animation needs for sentence SELECTED_IDX ----

# Embeddings recorded by the most recent run. This cell previously read a bare `input_embeds_list`
# that no visible cell assigns (NameError on a fresh kernel); bind it from `result`, which is how
# every other cell reaches the same data.
input_embeds_list = result.inputs_embeds_list

# NOTE(review): `invert_embeddings` is not imported in any visible cell either - import it from the
# project module that defines it before running this cell. Element [1] of its return value is
# consumed below as one token sequence per sentence.
start_inputs_embeds = input_embeds_list[0]
start_sentences = [clean_sentence(tokens) for tokens in invert_embeddings(start_inputs_embeds)[1]]
final_inputs_embeds = input_embeds_list[-1]
end_sentences = [clean_sentence(tokens) for tokens in invert_embeddings(final_inputs_embeds)[1]]

# The highest-probability target tokens of every row, most probable first.
NUM_TOP_TARGET_TOKENS = 10
target_probs = args.target_probabilities  # [N,V]
top_target_prob_idxs = np.argsort(-target_probs, axis=-1)[:, :NUM_TOP_TARGET_TOKENS]
top_target_tokens = [
    clean_token(token)
    for token in tokenizer.convert_ids_to_tokens(top_target_prob_idxs[SELECTED_IDX, :])
]
top_target_probs = np.take_along_axis(target_probs, top_target_prob_idxs, axis=-1)

np_mask_positions = np.asarray(args.mask_positions)

# Probability of those same tokens at every entry of `probs_path`; the index array gets a leading
# axis of length 1 so it broadcasts over the first axis of `probs_path`.
masked_token_indexed_prob_paths = np.take_along_axis(
    probs_path, np.expand_dims(top_target_prob_idxs, axis=0), axis=2
)
selected_masked_token_indexed_prob_paths = masked_token_indexed_prob_paths[:, SELECTED_IDX, :]
selected_top_target_probs = top_target_probs[SELECTED_IDX, :]
selected_source_sentence = start_sentences[SELECTED_IDX]
# Hard-coded caption (duplicated "is is" removed); the dynamic alternative would be
# start_sentences[args.sentence_permutation[SELECTED_IDX]].
selected_target_sentence = "The animal that says neigh is a [ horse]"

# Most probable token at every entry of `probs_path` for the selected row: ids first, then display strings.
selected_sentence_top_token_path = np.argsort(-probs_path[:, SELECTED_IDX, :], axis=1)[:, 0]
selected_sentence_top_token_path = [
    clean_token(token)
    for token in tokenizer.convert_ids_to_tokens(selected_sentence_top_token_path)
]

# Display text of the selected sentence for every recorded embedding; indexing with the list
# [SELECTED_IDX] keeps the leading axis that invert_embeddings is given elsewhere in this cell.
sentences_path = [
    clean_sentence(invert_embeddings(step_embeds[[SELECTED_IDX], ...])[1][0])
    for step_embeds in input_embeds_list
]

# First entry of the already-cleaned top-token list (same value the original recomputed from ids).
selected_top_target_token = top_target_tokens[0]


import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from IPython.display import HTML

# One figure/axes pair shared by the init() and animate() callbacks defined below.
fig, ax = plt.subplots(figsize=(8, 4))

def init():
    """Blank the shared axes before the first frame and report no artists to redraw."""
    ax.clear()
    return list()

def animate(i):
    """Draw frame ``i`` of the token-probability animation on the shared axes.

    Gray bars show the target probabilities of the top target tokens; blue bars show
    the probabilities of those same tokens at step ``i``. The title is the sentence at
    step ``i`` with the "___" placeholder replaced by the current most probable token.

    Returns the container of blue bars (the artists handed back to FuncAnimation).
    """
    ax.clear()
    probs = selected_masked_token_indexed_prob_paths[i]
    positions = range(len(probs))
    ax.bar(positions, selected_top_target_probs, color='gray', alpha = 0.25)
    bars = ax.bar(positions, probs, color='skyblue')
    # Size the y-axis for BOTH bar sets. Using only the path maximum clipped the gray
    # target bars whenever a target probability exceeded every value on the path.
    y_top = max(
        selected_masked_token_indexed_prob_paths.max(),
        selected_top_target_probs.max(),
    )
    ax.set_ylim(0, y_top * 1.1)
    ax.set_xlabel('Tokens')
    ax.set_ylabel('Probability')
    predicted_token = f"[{selected_sentence_top_token_path[i]}]"
    ax.set_title(f'"{sentences_path[i].replace("___", predicted_token)}"')
    ax.set_xticks(range(len(top_target_tokens)))
    ax.set_xticklabels(labels=top_target_tokens)
    return bars

# One frame per entry along the first axis of the selected probability path.
frame_count = selected_masked_token_indexed_prob_paths.shape[0]
ani = FuncAnimation(
    fig,
    animate,
    frames=frame_count,
    init_func=init,
    blit=True,
    interval=50,
)
plt.close()  # Prevents duplicate display in Jupyter
# Last expression of the cell: the animation is embedded as interactive JS/HTML output.
HTML(ani.to_jshtml())
