In [None]:
import os
import torch
import numpy as np
from transformers import AutoTokenizer, AutoConfig, AutoModelForMaskedLM

# Locate the fine-tuned checkpoint directory. Its name is built from the first
# two dash-separated parts of the model name, e.g. "roberta-large" ->
# "roberta_large_repo".
model_name = "roberta-large"
name_parts = model_name.split("-")
output_path = f"{name_parts[0]}_{name_parts[1]}_repo"
model_path = os.path.join(
    "/opt/omniai/work/instance1/jupyter/",
    "email-complaints/fine-tune-LM",
    output_path,
)

# Load the configuration, tokenizer and masked-LM model from that directory.
config = AutoConfig.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForMaskedLM.from_pretrained(model_path)

# Only parameters that have requires_grad set are counted.
trainable_param_count = sum(
    p.nelement() for p in model.parameters() if p.requires_grad
)

print()
print(f"The maximal # input tokens : {tokenizer.model_max_length:,}")
print(f"Vocabulary size : {tokenizer.vocab_size:,}")
print(f"The # of parameters to be updated : {trainable_param_count:,}")
print()

In [None]:
from IPython.display import HTML
import re
import textwrap

# Sample e-mail text; it is rendered below as HTML, wrapped at 150 columns.
text = """
I certainly understand your frustration here. We are working expeditiously to settle this dispute 
within the confines of the bank’s policies and procedures related to fraudulent claims filed. Let me coordinate with the 
internal team here to see which of these times work best. This is unacceptable. We were assured that this matter would be resolved when we met on August 4 as were within the 72 hour window of 
these transactions. I would like to meet with you and your superiors who can resolve this for us and not have my team running around in 
circles. I am available on Wednesday, August 17 from 4-5pm and on Thursday, August 18 from 9-10am. Please let me know which of these times 
is convenient for you and your manager to meet with us and I shall set up a Zoom meeting. My work hours may not be yours. 
Please do not feel obligated to respond outside of your normal work hour
"""

wrapper = textwrap.TextWrapper(width=150)
wrapped_text = wrapper.fill(text)
display(HTML(wrapped_text))

In [None]:
# Words to hide; every occurrence in `text` is overwritten with the literal
# placeholder '[MASK]'.
keyword = [
    'frustration', 'expeditiously', 'confines', 'unacceptable', 'withdrawals',
    'dispute', 'resolved', 'resolve', 'fraudulent', 'settle', 'running',
    'obligated',
]
for v in keyword:
    text = text.replace(v, '[MASK]')

# From here on `keyword` holds the placeholder string itself.
keyword = '[MASK]'
color = "green"
style = "font-weight:bold;"

# Wrap each placeholder in a coloured, bold <span>. The pattern is a plain
# literal, so a string replace does the substitution.
highlight_span = f"<span style='color:{color};{style}'>{keyword}</span>"
highlighted_text = text.replace('[MASK]', highlight_span)

wrapper = textwrap.TextWrapper(width=150)
display(HTML(wrapper.fill(highlighted_text)))

In [None]:
# Start again from the unmasked e-mail: `text` was overwritten with '[MASK]'
# placeholders when the keywords were highlighted for display.
text = """
I certainly understand your frustration here. We are working expeditiously to settle this dispute 
within the confines of the bank’s policies and procedures related to fraudulent claims filed. Let me coordinate with the 
internal team here to see which of these times work best. This is unacceptable. We were assured that this matter would be resolved when we met on August 4 as were within the 72 hour window of 
these transactions. I would like to meet with you and your superiors who can resolve this for us and not have my team running around in 
circles. I am available on Wednesday, August 17 from 4-5pm and on Thursday, August 18 from 9-10am. Please let me know which of these times 
is convenient for you and your manager to meet with us and I shall set up a Zoom meeting. My work hours may not be yours. 
Please do not feel obligated to respond outside of your normal work hour
"""

keyword = [
    'frustration', 'expeditiously', 'confines', 'unacceptable', 'withdrawals',
    'dispute', 'resolved', 'resolve', 'fraudulent', 'settle', 'running',
    'obligated',
]

# This time the keywords are replaced by the tokenizer's own mask token.
for v in keyword:
    masked_so_far = text.replace(v, tokenizer.mask_token)
    text = masked_so_far

print(text)

In [None]:
# Turn the masked text into token ids (add_special_tokens=False) and record
# the index of every id equal to the tokenizer's mask token id.
encoded_text = tokenizer.encode(text, add_special_tokens=False)

mask_positions = []
for position, token_id in enumerate(encoded_text):
    if token_id == tokenizer.mask_token_id:
        mask_positions.append(position)

In [None]:
# Fill the masked positions one at a time, in the order they appear in
# `mask_positions`. The model is re-run for every mask, so each prediction
# sees the replacements already chosen for the earlier masks.
TOP_K = 5        # each replacement is sampled among the k highest-scoring tokens
FILL_SEED = 42   # set to None to draw a different completion on every run
rng = np.random.default_rng(FILL_SEED)

# Work on a copy so `encoded_text` keeps its mask tokens and the cell can be
# re-run without re-encoding.
input_ids = encoded_text.copy()
for mask_position in mask_positions:
    with torch.no_grad():
        outputs = model(torch.tensor([input_ids]))
    # outputs[0] holds the per-token vocabulary scores; [0] selects the single
    # sequence in the batch.
    predictions = outputs[0][0][mask_position].topk(TOP_K).indices.tolist()
    # Store a plain Python int (not a tensor) so `input_ids` stays a flat list
    # of ids that can be fed back to the model and to the decoder.
    input_ids[mask_position] = int(rng.choice(predictions))

print(tokenizer.decode(input_ids))