import gensim.downloader as gensim_api
import numpy as np
import pandas as pd
from scipy.stats import pearsonr
import huggingface_hub
from transformers import AutoTokenizer
from transformers import AutoModelForMaskedLM
import torch
from transformers import BertModel
import sklearn.cluster
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import random

In [2]:
#This function is chatgpt generated
def read_file_to_list(filename):
    with open(filename, 'r') as file:
        # Read the entire file content
        content = file.read()

    # Split the content by double newlines
    lines = content.split('\n\n')

    # Strip any extra newlines or whitespace from each line
    lines = [line.strip() for line in lines]

    return lines

# Example usage
filename = 'problem3_data.txt'
lines_list = read_file_to_list(filename)
print("Loaded Data")

In [3]:
checkpoint = 'bert-base-uncased'

model = AutoModelForMaskedLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [4]:
lines_list.remove("") #remove the last line, which is empty

In [5]:
# and adapting the other function
def combined_vector_for(targetword, text, bert_input, bert_output,
                        layer = -1, word_occurrence = 0):
    """
    calculate a word vector
    based on the mean of the WordPiece vectors in the given layer.
    targetword is a word appearing in text.
    bert_input is the whole dictionary returned by the tokenizer.
    bert_output is the last-layer output obtained
    on this text.
    """
    tokenized_text = [w for w, n in tokenizer.backend_tokenizer.pre_tokenizer.pre_tokenize_str(text)]

    # where in the tokenized text (split into words but not split into word pieces)
    # do we find the target word?
    target_word_indices = [i for i, x in enumerate(tokenized_text) if x == targetword]

    # sanity check
    if len(target_word_indices) < word_occurrence + 1:
        # no occurrences found, or not enough to match the required one
        return None

    # if multiple occurrences of the target word, we use this one
    usethis_targetword_index = target_word_indices[ word_occurrence]

    # use word_to_tokens to determine the word piece span of the target word
    word_start, word_end = bert_input.word_to_tokens( usethis_targetword_index )

    # extract the embeddings from the right layer, and the target vectors from that layer
    embeddings= bert_output["hidden_states"][layer]
    # print(embeddings.shape)
    target_vectors = embeddings[0, word_start:word_end, :]

    # if we have multiple word pieces for this word, average over them
    avg_target_vector = target_vectors.mean(dim = 0)

    return avg_target_vector.detach().numpy()



In [6]:
embeddings = []
# checkpoint = 'bert-base-uncased'
for line in lines_list:

    text = line

    bert_inputs = tokenizer(text, padding=True, truncation=True,  return_tensors="pt")

    bert_output = model(bert_inputs["input_ids"], output_hidden_states = True)

    charge_output = combined_vector_for("charge", text, bert_inputs, bert_output, layer = 7, word_occurrence = 0)
    embeddings.append(charge_output)

print(len(embeddings))
# print(embeddings)

In [7]:
kmeans = sklearn.cluster.KMeans(n_clusters=5, random_state=2048).fit(embeddings)
#random_state makes clustering deterministic - it's a seed value

In [73]:
cluster_list = kmeans.predict(embeddings)
print(list(cluster_list))

[1, 1, 3, 3, 1, 0, 2, 1, 3, 0, 4, 3, 0, 2, 1, 4, 1, 2, 2, 3, 2, 3, 4, 3, 3, 2, 2, 2, 3, 3, 2, 1, 4, 0, 3, 3, 4, 2, 3, 2, 0, 3, 3, 3, 3, 0, 3, 0, 3, 4, 0, 0, 3, 1, 2, 0, 4, 1, 4, 2, 4, 2, 3, 0, 2, 4, 3, 0, 4, 3, 3, 1, 0, 0, 3, 3, 3, 3, 1, 0, 4, 0, 2, 4, 2, 1, 1, 3, 4, 3, 3, 2, 0, 0, 4, 2, 2, 0, 3, 4, 0, 2, 2, 3, 0, 0, 0, 0, 4, 2, 2, 3, 0, 0, 3, 2, 2, 0, 3, 4, 0, 0, 2, 3, 3, 3, 4, 1, 4, 4, 4, 2, 3, 1, 3, 4, 1, 1, 0, 0, 3, 1, 1, 4, 2, 3, 3, 0, 4, 4, 4, 1, 4, 2, 4, 1, 0, 2, 1, 0, 2, 3, 2, 3, 2, 2, 3, 1, 0, 3, 4, 2, 3, 2, 0, 1, 0, 2, 1, 1, 0, 2, 4, 3, 3, 0, 3, 3, 0, 3, 4, 4, 0, 1, 2, 0, 0, 0, 2, 4]


In [9]:
cluster_dict = {0:[],1:[],2:[],3:[],4:[]}
for i in range(len(cluster_list)):
  cluster_dict[cluster_list[i]].append(i)

In [68]:
outs = 3
random.seed(50)

for cluster in cluster_dict.keys():
  sents = random.sample(cluster_dict[cluster],outs)
  for sent in sents:
    print(lines_list[sent])

  print()

Gov-WHITMAN : Gabe , this is a man who has been with the FBI for 20-some-odd years , who was left in  	 charge 	  of that investigation , who was then given charge of the investigation of the TWA crash subsequent to the World Trade Center bombing , who was promoted by the FBI after the World Trade Center bombing .
SCHIEFFER : That 's Senator John Warner of Virginia who was accompanying the president there . He is in  	 charge 	  of all this as the chairman of that Senate Rules Committee . RATHER : So for Vice President Gore , as he waves ;
JOHN STOSSEL : He was scared . I would be scared . Paul Pfingst is San Diego 's district attorney . He was in  	 charge 	  of the prosecution . interviewing I could see that I in that situation , if I had a gun , I might say I 'm going to go protect my family .

You know , the -- the wrap , of course , against Phil Gramm is that he 's basically been hiding his very talented , articulate , smart wife every time he 's run for election because Americans

In [11]:
twodim = PCA().fit_transform(embeddings)[:,:2]

# set up the canvas
fig, ax =  plt.subplots()

# add a scatter plot of the two-D embeddings
scatter = ax.scatter(twodim[:,0], twodim[:,1], edgecolors='k', c=cluster_list)

# and show the canvas
plt.show()

In [29]:
replacements = []
replacements_words = []
# checkpoint = 'bert-base-uncased'
for line in lines_list:

    text = line.replace("charge", "[MASK]")

    bert_inputs = tokenizer(text, padding=True, truncation=True,  return_tensors="pt")

    # running the embedding model on the input sentence
    token_logits = model(**bert_inputs).logits
    
    # Find the location of <mask>
    mask_token_index = torch.where(bert_inputs["input_ids"] == tokenizer.mask_token_id)[1]
    mask_token_logits = token_logits[0, mask_token_index, :]

    top_token = torch.topk(mask_token_logits, 20, dim=1).indices[0].tolist()

    
    replacements_words.append(tokenizer.decode(top_token).split(' '))
    replacements.append(mask_token_logits)

In [33]:
# replacements_words

In [51]:
idx = 0
idx_to_word = {}
word_to_idx = {}

for word_list in replacements_words:
    for word in word_list:
        if not word in idx_to_word.values():
            idx_to_word[idx] = word
            word_to_idx[word] = idx
            idx += 1

In [53]:
# idx_to_word
# word_to_idx

In [76]:
replacements_vectors = []
for word_list in replacements_words:
    vector = np.zeros((len(idx_to_word)))
    for word in word_list:
        idx = word_to_idx[word]
        vector[idx] = 1
    replacements_vectors.append(vector)

In [77]:
len(embeddings)

200

In [78]:
len(replacements_vectors)

200

In [79]:
kmeans_replacements = sklearn.cluster.KMeans(n_clusters=5, random_state=2048).fit(replacements_vectors)




In [80]:
cluster_list_replacements = kmeans_replacements.predict(replacements_vectors)
print(list(cluster_list_replacements))

[4, 4, 4, 2, 4, 1, 4, 4, 4, 1, 4, 4, 0, 1, 4, 3, 4, 4, 1, 2, 4, 4, 4, 4, 4, 0, 4, 4, 2, 4, 1, 4, 2, 0, 2, 4, 4, 4, 4, 0, 0, 4, 4, 2, 2, 1, 2, 1, 4, 3, 1, 1, 4, 3, 1, 0, 4, 4, 4, 1, 3, 1, 4, 1, 0, 3, 4, 1, 3, 4, 4, 4, 1, 0, 4, 2, 2, 2, 4, 1, 3, 0, 1, 4, 4, 4, 4, 2, 2, 2, 2, 4, 0, 1, 3, 4, 0, 1, 2, 3, 0, 4, 1, 4, 1, 0, 1, 1, 4, 0, 4, 2, 1, 4, 2, 4, 4, 0, 2, 2, 1, 1, 0, 2, 4, 2, 3, 4, 3, 3, 3, 1, 4, 4, 2, 4, 4, 4, 1, 0, 2, 4, 4, 3, 4, 4, 4, 1, 2, 3, 2, 4, 4, 0, 4, 4, 1, 1, 4, 0, 0, 4, 1, 4, 4, 4, 4, 4, 0, 4, 3, 4, 2, 4, 1, 4, 0, 1, 3, 4, 1, 4, 4, 2, 4, 4, 2, 2, 4, 2, 3, 3, 0, 4, 1, 1, 1, 1, 1, 3]


In [81]:
cluster_dict_replacements = {0:[],1:[],2:[],3:[],4:[]}
for i in range(len(cluster_list_replacements)):
  cluster_dict_replacements[cluster_list_replacements[i]].append(i)

In [82]:
outs = 3
random.seed(50)

for cluster in cluster_dict_replacements.keys():
  sents = random.sample(cluster_dict_replacements[cluster],outs)
  for sent in sents:
    print(lines_list[sent])

  print()

Right now , he says people just do n't know what to do . Mr. BYRD : I mean , I do n't know . If I was in  	 charge 	  of the whole thing , I guess I could n't look around here and say , ' You guys do n't go to sea today , do n't catch any more pollack .
Now you 're talking about how choreographed this event was . It certainly came across that way . Who choreographs two super power leaders ? Who 's in  	 charge 	  of saying if you get this kind of question this is the line we have to take ? MALVEAUX : Well , sure .
GIBSON : In the BACK OF THE BOOK segment tonight , a Utah state investigation found that half the cops in the town of Hildale , Utah are polygamists , and some are even married to underage girls . Polygamy , having more than one spouse , is of course illegal . So why are the men in  	 charge 	  of enforcing the law being allowed to break it ? Joining us now from Salt Lake City is Mark Shurtleff , the attorney general of the State of Utah .

He is sponsoring anti-terrorism leg