# 4.2 - visualization

In this notebook I explore [BertViz](https://github.com/jessevig/bertviz), a visualization tool that shows attention in NLP models.

I will use the trained model to show its cross-attention as it translates from toxic to non-toxic "language".


In [18]:
# Necessary imports
import warnings

from bertviz import model_view
from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM

# Silence every warning category globally so library notices do not clutter
# the cell outputs.
warnings.simplefilter("ignore")


def translate(model, inference_request, tokenizer):
    """
    Run a single inference and return the decoded text.

    The request is prepended with the task prefix, tokenized, passed through
    ``model.generate`` and the first generated sequence is decoded.

    Args:
        model (transformers.modeling_utils.PreTrainedModel): model to use for inference
        inference_request (str): input string to transform
        tokenizer (transformers.tokenization_utils.PreTrainedTokenizer): tokenizer to use for inference

    Returns:
        str: decoded text of the first generated sequence, with special tokens skipped
    """
    # The prefix is concatenated without a separating space; presumably this
    # matches the format the model was trained with -- TODO confirm before changing.
    prefix = "paraphrase toxic sentences:"
    input_ids = tokenizer(prefix + inference_request, return_tensors="pt").input_ids
    outputs = model.generate(input_ids=input_ids)
    # `temperature` is a generation (sampling) option, not a decoding option,
    # so it is intentionally not passed to `decode`.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


# The trained model and its tokenizer are both read from the same
# checkpoint directory.
checkpoint_dir = "./../models/best"
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_dir)
tokenizer = AutoTokenizer.from_pretrained(checkpoint_dir)

In [19]:
# Example toxic sentence and the model's paraphrase of it.
question = "This is my fucking house, okay?"
answer = translate(model=model, inference_request=question, tokenizer=tokenizer)

# Bare last expression so the notebook displays the result.
answer

'this is my house, okay?'

In [20]:
# Tokenize the source sentence (encoder side) and the model's translation
# (decoder side) so attention can be inspected for this sentence pair.
# NOTE(review): translate() prepends a task prefix before tokenizing, while the
# encoder input here is the bare question -- confirm this difference is intended.
encoder_input_ids = tokenizer(question, return_tensors="pt", add_special_tokens=True).input_ids
# NOTE(review): as_target_tokenizer() is reported as deprecated in recent
# transformers releases in favour of the `text_target=` argument -- verify
# against the installed version before migrating.
with tokenizer.as_target_tokenizer():
    decoder_input_ids = tokenizer(answer, return_tensors="pt", add_special_tokens=True).input_ids

# Load the attention-returning model under its own name. Rebinding `model`
# here would replace the AutoModelForSeq2SeqLM instance that the translation
# cell passes to translate(), so re-running that cell would use the wrong
# model class.
attention_model = AutoModel.from_pretrained("./../models/best", output_attentions=True)
outputs = attention_model(input_ids=encoder_input_ids, decoder_input_ids=decoder_input_ids)

# Token strings used as labels in the visualization.
encoder_text = tokenizer.convert_ids_to_tokens(encoder_input_ids[0])
decoder_text = tokenizer.convert_ids_to_tokens(decoder_input_ids[0])

In [21]:
# Render the BertViz model view from the encoder, decoder and cross attentions
# returned by the forward pass, labelled with the corresponding token strings.
# (`model_view` is imported in the notebook's import cell.)
model_view(
    encoder_attention=outputs.encoder_attentions,
    decoder_attention=outputs.decoder_attentions,
    cross_attention=outputs.cross_attentions,
    encoder_tokens=encoder_text,
    decoder_tokens=decoder_text,
)

<IPython.core.display.Javascript object>