In [None]:
# uncomment the line below to run on drive + colab

# ! pip install transformers_interpret

In [None]:
# from google.colab import drive
# 
# drive.mount('/content/drive/')

In [None]:
import pandas as pd
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers_interpret import MultiLabelClassificationExplainer
import torch
from transformers import pipeline
from datasets import load_from_disk

import warnings
# Suppress every Python warning category for all cells that run after this one.
warnings.filterwarnings('ignore')

In [None]:
import gc

# Run a full garbage-collection pass first so unreachable Python objects are released.
gc.collect()

# Then ask PyTorch to drop its cached CUDA memory blocks.
# NOTE(review): presumably intended to free GPU memory left over from earlier runs
# in the same kernel - confirm.
torch.cuda.empty_cache()

In [None]:
# To run on drive + colab: comment out the local filepath and uncomment the drive one.
filepath = "../models/scenario_level/anomic_wernicke_control"
# filepath = "/content/drive/Shareddrives/AphasiaProject/models/scenario_level/anomic_wernicke_control"

# Classifier weights and tokenizer are stored in sibling sub-folders of `filepath`.
model = AutoModelForSequenceClassification.from_pretrained(filepath+"/model")
tokenizer = AutoTokenizer.from_pretrained(filepath+"/tokenizer", padding=True, 
                                          truncation=True, return_tensors="pt", 
                                          add_special_tokens=True, max_length=512)

# Use the GPU when one is available; moving to the CPU device is a no-op,
# so no conditional is needed around the transfer.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Reuse the already-loaded model and tokenizer instead of reading the same
# checkpoint from disk a second time (which would also place a duplicate copy
# of the weights on the device).
pipe = pipeline("text-classification", model=model, tokenizer=tokenizer, device=device, truncation=True, padding=True)

# Attribution explainer used by the interpretation cells below.
cls_explainer = MultiLabelClassificationExplainer(model, tokenizer)

In [None]:
# Display the configuration object of the loaded model.
model.config

In [None]:
# Display the model's mapping from class index to label name.
model.config.id2label

In [None]:
def interpret(data):
    """Print the model prediction and show word attributions for each text.

    For every (text, label) pair in ``data`` the text is tokenized and run
    through the notebook-level ``model``; the true label, the predicted label
    and the text are printed, and ``cls_explainer`` is then called on the text
    and its visualization rendered.

    Texts whose tokenized length reaches 512 tokens are skipped without any
    output, because transformers interpret raises an error on inputs of that
    length.

    Parameters
    ----------
    data : pandas.DataFrame (or any mapping of equal-length sequences)
        Must provide the entries ``"new_preprocessed_text"`` (input strings)
        and ``"label"`` (true label shown next to the prediction).

    Returns
    -------
    None
        All results are printed / displayed.
    """
    max_tokens = 512

    for text, label in zip(data["new_preprocessed_text"], data["label"]):
        # One code path for CPU and GPU: moving the whole encoding to `device`
        # keeps input ids and attention mask together on either device.
        encoded = tokenizer(text, padding=True, truncation=True, return_tensors="pt",
                            add_special_tokens=True, max_length=max_tokens).to(device)

        # tranformers interpret gives error for input of 512 tokens
        # input_ids has shape (1, sequence_length) for a single text.
        if encoded["input_ids"].shape[1] >= max_tokens:
            continue

        with torch.no_grad():
            logits = model(**encoded).logits

        predicted_class_id = logits.argmax().item()
        print("True: ", label, "Pred: ", model.config.id2label[predicted_class_id])
        print("Text: ", text)
        cls_explainer(text)
        cls_explainer.visualize()
        print("----------------------------------------------------------------------")

In [None]:
# Scenario names; the sections below index into this list in order.
scenarios = ["Speech", "Important_Event", "Cinderella", "Stroke", "Cat"]

# make sure not to interpret on trained scenarios
# Load the test split stored next to the model, drop the columns that are not
# needed for interpretation, and convert the result to a pandas DataFrame.
test_split = load_from_disk(filepath + "/dataset")["test"]
texts = test_split.remove_columns(["Unnamed: 0", "input_ids", "attention_mask"]).to_pandas()

# Replace the integer class ids with the model's label names.
texts["label"] = [model.config.id2label[class_id] for class_id in texts["label"]]

# Number of examples sampled per label in each scenario.
n = 3

# Speech scenario interpretation
`n` randomly sampled test examples for each class label of the model (`n` is defined in the cell above).

In [None]:
# Sample n random Speech-scenario rows for every label the model knows.
scenario_rows = texts[texts["scenario"] == scenarios[0]]

# Start from an empty frame so `data` has the expected columns even if every label is skipped.
samples = [pd.DataFrame(columns=['label','new_preprocessed_text','scenario'])]

for label_name in model.config.label2id.keys():
    label_rows = scenario_rows[scenario_rows["label"] == label_name]
    # .sample(n=n) raises ValueError when fewer than n rows exist, so skip
    # such labels (same rule as the Stroke cell) and say so.
    if len(label_rows) >= n:
        samples.append(label_rows.sample(n=n))
    else:
        print(f"Skipping label {label_name!r}: only {len(label_rows)} {scenarios[0]} example(s), need {n}")

# Concatenate once instead of growing the frame inside the loop.
data = pd.concat(samples)

In [None]:
# Print predictions and render attribution visualizations for the sampled Speech rows.
interpret(data)

# Important event scenario interpretation
`n` randomly sampled test examples for each class label of the model (`n` is defined earlier in the notebook).

In [None]:
# Sample n random Important_Event-scenario rows for every label the model knows.
scenario_rows = texts[texts["scenario"] == scenarios[1]]

# Start from an empty frame so `data` has the expected columns even if every label is skipped.
samples = [pd.DataFrame(columns=['label','new_preprocessed_text','scenario'])]

for label_name in model.config.label2id.keys():
    label_rows = scenario_rows[scenario_rows["label"] == label_name]
    # .sample(n=n) raises ValueError when fewer than n rows exist, so skip
    # such labels (same rule as the Stroke cell) and say so.
    if len(label_rows) >= n:
        samples.append(label_rows.sample(n=n))
    else:
        print(f"Skipping label {label_name!r}: only {len(label_rows)} {scenarios[1]} example(s), need {n}")

# Concatenate once instead of growing the frame inside the loop.
data = pd.concat(samples)

In [None]:
# Print predictions and render attribution visualizations for the sampled Important_Event rows.
interpret(data)

# Cinderella scenario interpretation
`n` randomly sampled test examples for each class label of the model (`n` is defined earlier in the notebook).

In [None]:
# Sample n random Cinderella-scenario rows for every label the model knows.
scenario_rows = texts[texts["scenario"] == scenarios[2]]

# Start from an empty frame so `data` has the expected columns even if every label is skipped.
samples = [pd.DataFrame(columns=['label','new_preprocessed_text','scenario'])]

for label_name in model.config.label2id.keys():
    label_rows = scenario_rows[scenario_rows["label"] == label_name]
    # .sample(n=n) raises ValueError when fewer than n rows exist, so skip
    # such labels (same rule as the Stroke cell) and say so.
    if len(label_rows) >= n:
        samples.append(label_rows.sample(n=n))
    else:
        print(f"Skipping label {label_name!r}: only {len(label_rows)} {scenarios[2]} example(s), need {n}")

# Concatenate once instead of growing the frame inside the loop.
data = pd.concat(samples)

In [None]:
# Print predictions and render attribution visualizations for the sampled Cinderella rows.
interpret(data)

# Stroke scenario interpretation
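`n` randomly sampled test examples for each class label of the model (`n` is defined earlier in the notebook). Labels with fewer than `n` Stroke examples (e.g. controls, which have none in this scenario) are skipped.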

In [None]:
# Sample n random Stroke-scenario rows for every label that has enough of them.
scenario_rows = texts[texts["scenario"] == scenarios[3]]

# Start from an empty frame so `data` has the expected columns even if every label is skipped.
samples = [pd.DataFrame(columns=['label','new_preprocessed_text','scenario'])]

for label_name in model.config.label2id.keys():
    label_rows = scenario_rows[scenario_rows["label"] == label_name]
    # no controls in stroke scenario
    if len(label_rows) >= n:
        # Sample from the rows already selected instead of re-running the same filter.
        samples.append(label_rows.sample(n=n))
    else:
        print(f"Skipping label {label_name!r}: only {len(label_rows)} {scenarios[3]} example(s), need {n}")

# Concatenate once instead of growing the frame inside the loop.
data = pd.concat(samples)

In [None]:
# Print predictions and render attribution visualizations for the sampled Stroke rows.
interpret(data)

# Cat scenario interpretation
`n` randomly sampled test examples for each class label of the model (`n` is defined earlier in the notebook).

In [None]:
# Sample n random Cat-scenario rows for every label the model knows.
scenario_rows = texts[texts["scenario"] == scenarios[4]]

# Start from an empty frame so `data` has the expected columns even if every label is skipped.
samples = [pd.DataFrame(columns=['label','new_preprocessed_text','scenario'])]

for label_name in model.config.label2id.keys():
    label_rows = scenario_rows[scenario_rows["label"] == label_name]
    # .sample(n=n) raises ValueError when fewer than n rows exist, so skip
    # such labels (same rule as the Stroke cell) and say so.
    if len(label_rows) >= n:
        samples.append(label_rows.sample(n=n))
    else:
        print(f"Skipping label {label_name!r}: only {len(label_rows)} {scenarios[4]} example(s), need {n}")

# Concatenate once instead of growing the frame inside the loop.
data = pd.concat(samples)

In [None]:
# Print predictions and render attribution visualizations for the sampled Cat rows.
interpret(data)