In [None]:
import pickle

from datasets import load_dataset
import pandas as pd
import numpy as np
import torch
from tqdm.auto import tqdm
from transformers import DistilBertForSequenceClassification, AutoTokenizer, pipeline
import matplotlib.pyplot as plt
import seaborn as sns

from source.emotion import all_emotions
from source.data_utils import load_friends_dataset, load_tennis_dataset
from source.classification_utils import MultiLabelTextClassification, analyze_result

In [None]:
# Directory holding the saved model checkpoint that the notebook loads.
model_path: str = "./results/models/best"
# Checkpoint name from which the tokenizer is loaded.
model_describ: str = "distilbert-base-cased"

In [None]:
# load best model
# The classifier weights are read from the local checkpoint directory
# (`model_path`), while the tokenizer is loaded by checkpoint name
# (`model_describ`).
# NOTE(review): presumably the saved model was trained with this same
# tokenizer -- confirm, since the tokenizer is not loaded from `model_path`.
model2 = DistilBertForSequenceClassification.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_describ)

In [None]:
# Emotion label names, plus a lookup from positional index to label name.
labels = all_emotions
id2label = dict(enumerate(labels))

In [None]:
# Keyword arguments forwarded to the multi-label classification pipeline.
pipeline_config = {
    "return_all_scores": True,
    # Use the first GPU when CUDA is available; otherwise fall back to CPU (-1)
    # instead of failing on machines without a GPU.
    # NOTE(review): assumes MultiLabelTextClassification follows the
    # transformers convention where device=-1 selects the CPU -- confirm.
    "device": 0 if torch.cuda.is_available() else -1,
}
inference_pipeline = MultiLabelTextClassification(model=model2, tokenizer=tokenizer, **pipeline_config)

In [None]:
# Zero-shot classifier; in this notebook it is only referenced by the disabled
# (commented-out) alternative inside the inference loop.
zero_shot_classifier = "typeform/distilbert-base-uncased-mnli"
zero_shot_pipeline = pipeline(
    "zero-shot-classification",
    model=zero_shot_classifier,
    tokenizer=zero_shot_classifier,
    # First GPU when CUDA is available, CPU (-1) otherwise, so the cell also
    # runs on machines without a GPU.
    device=0 if torch.cuda.is_available() else -1,
)

# Read Dataset

In [None]:
# Load the Friends dataset from the raw text file. Later cells read its
# `person` and `line` columns. The tennis loader below is the disabled
# alternative input.
dataset = load_friends_dataset("data/friends-final-raw.txt")
#dataset = load_tennis_dataset()

In [None]:
# Keep only the speakers that have more than 1000 rows in the dataset.
person_counts = dataset["person"].value_counts()
# The frequent speakers are the index labels of the filtered counts. Reading
# them from `.index` avoids `reset_index()["index"]`, which raises KeyError on
# pandas >= 2.0 because the reset column is then named after the original
# column instead of "index".
main_characters = (
    person_counts[person_counts > 1000]
    .index.to_series(name="person")
    .reset_index(drop=True)
)
print(main_characters)

# Restrict the dataset to rows spoken by those characters.
dataset = dataset[dataset["person"].isin(main_characters)]

In [None]:
# One (initially empty) list of per-line predictions for every main character.
predicted_emotions = {character: [] for character in main_characters}

for utterance in tqdm(dataset.itertuples(), total=len(dataset)):
    # Disabled zero-shot alternative, kept for reference:
    #raw_scores = zero_shot_pipeline(utterance.line, labels, multi_label=True)
    #if len(utterance.line) == 1:
    #    raw_scores = [raw_scores]
    #raw_scores = [[{'label' : label, 'score': value} for label, value in zip(sentence['labels'], sentence['scores'])] for sentence in raw_scores]
    #selected = analyze_result(raw_scores, .8)
    raw_scores = inference_pipeline(utterance.line)
    # presumably keeps the labels scoring above the 0.2 threshold -- confirm
    # against source.classification_utils.analyze_result
    selected = analyze_result(raw_scores, .2)

    # Store the first entry of the result as plain (label, score) tuples.
    label_score_pairs = [(hit['label'], hit['score']) for hit in selected[0]]
    predicted_emotions[utterance.person].append(label_score_pairs)

In [None]:
import pickle

In [None]:
# Persist the per-character predictions so they can be reloaded from disk.
with open('data/friends-classification.pickle', 'wb') as pickle_file:
    pickle.dump(predicted_emotions, pickle_file)

In [None]:
# Reload the predictions written by the previous cell. The path has to match
# that file: the aggregation cell indexes its totals by `main_characters`, so
# loading the tennis pickle while the Friends dataset is active would fail
# with KeyError for speakers that are not in `main_characters`.
# The `with` block also closes the file handle, which the bare `open()` did not.
# NOTE: pickle.load can execute arbitrary code; only load locally produced files.
with open('data/friends-classification.pickle', 'rb') as f:
    predicted_emotions = pickle.load(f)

In [None]:
# Per-character totals for every emotion label, all starting at 0.
total_emotions_per_person = {
    character: {emotion: 0 for emotion in labels}
    for character in main_characters
}

# Each predicted score contributes score / (number of lines of that character),
# so the totals are per-line means in which labels that were not returned for a
# line count as 0.
for character, utterances in predicted_emotions.items():
    utterance_count = len(utterances)
    for pairs in utterances:
        for emotion, score in pairs:
            total_emotions_per_person[character][emotion] += (score / utterance_count)

# Plotting

In [None]:
# One bar chart per character showing the aggregated score of every emotion.
bar_positions = range(len(labels))
for character, emotion_scores in total_emotions_per_person.items():
    plt.title(character)
    plt.ylim((0, 0.1))  # same fixed y-range on every chart
    plt.xticks(rotation='vertical')
    plt.bar(bar_positions, emotion_scores.values(), tick_label=labels)
    plt.show()

In [None]:
# Long-format table with one row per (person, emotion) pair, as input for the
# grouped seaborn bar plot.
# All rows are collected first and the frame is built once: appending with
# `df.loc[len(df)] = ...` grows the frame one row at a time, which is slow, and
# building from records lets pandas infer the column dtypes from the full data.
records = [
    (person, emotion, value)
    for person, emotions in total_emotions_per_person.items()
    for emotion, value in emotions.items()
]
df = pd.DataFrame(records, columns=["person", "emotion", "value"])

#df = df[df["emotion"] != "neutral"]
# A single theme call sets both the style and the figure size (`sns.set` is an
# alias of `sns.set_theme`, so the two separate calls were redundant).
sns.set_theme(style="whitegrid", rc={'figure.figsize': (30, 8)})
chart = sns.barplot(x="emotion", y="value", hue="person", data=df)
chart.set_xticklabels(chart.get_xticklabels(), rotation=90)
# Place the legend just outside the right edge of the axes.
chart.legend(bbox_to_anchor=(1.01, 1), loc=2, borderaxespad=0.)