# In-Context Learning (ICL) for Comics

- Initial config: K = 10, N = 1, embedding_model = bert-base-uncased, no attributes/context

### Libraries

In [3]:
import ast
import json
import torch
import random
import numpy as np
import pandas as pd
import torch.nn.functional as F

from pathlib import Path
from tqdm.notebook import tqdm
from operator import itemgetter
from sklearn.metrics import classification_report
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

### Tokenizer and Model (embedding and inference)

In [4]:
embedding_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
embedding_model = AutoModel.from_pretrained("google-bert/bert-base-uncased")

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

In [5]:
model_id = "unsloth/llama-3-8b-Instruct-bnb-4bit"

In [6]:
inference_tokenizer = AutoTokenizer.from_pretrained(model_id, padding='left', padding_side='left')
inference_tokenizer.pad_token = inference_tokenizer.eos_token
terminators = [
    inference_tokenizer.eos_token_id,
    inference_tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

In [7]:
generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    #cache_dir = '/home/umushtaq/scratch/am_work/in_context_learning/model_downloads',
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


### Read Data

In [12]:
DATASET_DIR = Path(Path.cwd().as_posix()) / "emotion_analysis_comics" / "incontext_learning" / "datasets"

In [13]:
DATASET_DIR

PosixPath('/Utilisateurs/umushtaq/emotion_analysis_comics/incontext_learning/datasets')

In [14]:
all_labels = ["anger", "surprise", "fear", "disgust", "sadness", "joy", "neutral"]

In [15]:
emotion_map = {
    'AN': 'anger',
    'DI': 'disgust',
    'FE': 'fear',
    'SA': 'sadness',
    'SU': 'surprise',
    'JO': 'joy'
}

In [18]:
df = pd.read_csv(DATASET_DIR / "comics_data_processed.csv")
df = df.drop(columns=[df.columns[0], df.columns[1]]).reset_index(drop=True)

In [20]:
def extract_emotions(row):

    emotion_str = row.emotion

    if emotion_str == 'Neutral':
        return ['neutral']

    emotions = emotion_str.split('-')
    tags = []

    for emotion in emotions:
        abbrev = emotion[:2]  # Get the abbreviation
        value_part = emotion[2:]  # Get the value part
        
        # Ensure that the value part is a valid integer and abbrev is in the emotion_map
        if abbrev in emotion_map and value_part.isdigit():
            value = int(value_part)
            if value > 0:
                tags.append(emotion_map[abbrev].lower())
        else:
            print(f"Warning: Skipping invalid emotion entry: '{emotion}'")
    return tags  

In [21]:
df['emotions_list'] = df.apply(lambda row: extract_emotions(row), axis=1)

In [22]:
df.shape

(5282, 12)

In [24]:
df

Unnamed: 0,file_name,page_nr,panel_nr,balloon_nr,utterance,raw_annotation,raw_emotion,raw_speaker_id,emotion,speaker_id,split,emotions_list
0,QC copy - 1500 - 04 Nightwing 19 _Nightwing 95...,1,2,1,DID YOU HAVE TO ELECTROCUTE HER SO HARD?,2024-08-27 - aselermekova20\nFeeling:AN0-DI0-F...,2024-08-27 - aselermekova20\nFeeling:AN0-DI0-F...,2024-09-05 - aidaraliev12345\nSpokenBy:ID-1,AN0-DI0-FE3-SA0-SU5-JO0,ID-1,TRAIN,"[fear, surprise]"
1,QC copy - 1500 - 04 Nightwing 19 _Nightwing 95...,1,2,2,IT'S NOT LIKE I HAVE DIFFERENT SETTINGS.,2024-08-27 - aselermekova20\nFeeling:AN0-DI0-F...,2024-08-27 - aselermekova20\nFeeling:AN0-DI0-F...,2024-09-05 - aidaraliev12345\nSpokenBy:ID-2,AN0-DI0-FE0-SA0-SU5-JO0,ID-2,TRAIN,[surprise]
2,QC copy - 1500 - 04 Nightwing 19 _Nightwing 95...,1,2,3,YOU'RE ELECTROCUTIONER. IT'S YOUR WHOLE THING....,2024-08-27 - aselermekova20\nFeeling:AN0-DI0-F...,2024-08-27 - aselermekova20\nFeeling:AN0-DI0-F...,2024-09-05 - aidaraliev12345\nSpokenBy:ID-1,AN0-DI0-FE2-SA0-SU0-JO0,ID-1,TRAIN,[fear]
3,QC copy - 1500 - 04 Nightwing 19 _Nightwing 95...,1,3,1,"OH, HEY. I THINK SHE'S AWAKE.",2024-08-27 - aselermekova20\nFeeling:AN0-DI0-F...,2024-08-27 - aselermekova20\nFeeling:AN0-DI0-F...,2024-09-05 - aidaraliev12345\nSpokenBy:ID-2,AN0-DI0-FE0-SA0-SU4-JO0,ID-2,TRAIN,[surprise]
4,QC copy - 1500 - 04 Nightwing 19 _Nightwing 95...,1,4,1,"WELCOME BACK, MADAM MAYOR. BLOCKBUSTER IS PRET...",2024-08-27 - aselermekova20\nFeeling:AN3-DI0-F...,2024-08-27 - aselermekova20\nFeeling:AN3-DI0-F...,2024-09-05 - aidaraliev12345\nSpokenBy:ID-1,AN3-DI0-FE0-SA0-SU0-JO0,ID-1,TRAIN,[anger]
...,...,...,...,...,...,...,...,...,...,...,...,...
5277,QC copy - 1499 - 58 ECC Co_mics 50 _The Jurass...,20,1,1,I KNOW THE BEINGS OF THIS WORLD ARE TRYING TO ...,2024-08-27 - aselermekova20\nFeeling:AN5-DI0-F...,2024-08-27 - aselermekova20\nFeeling:AN5-DI0-F...,2024-09-05 - aidaraliev12345\nSpokenBy:BLACKMA...,AN5-DI0-FE0-SA0-SU0-JO0,BLACKMANTASAURUS,TEST,[anger]
5278,QC copy - 1499 - 58 ECC Co_mics 50 _The Jurass...,20,1,2,… BUT I WILL CRUSH THEM IN DUE TIME!,2024-08-27 - aselermekova20\nFeeling:AN5-DI0-F...,2024-08-27 - aselermekova20\nFeeling:AN5-DI0-F...,2024-09-05 - aidaraliev12345\nSpokenBy:BLACKMA...,AN5-DI0-FE0-SA0-SU0-JO0,BLACKMANTASAURUS,TEST,[anger]
5279,QC copy - 1499 - 58 ECC Co_mics 50 _The Jurass...,20,2,1,FOR MY FIRST TASK...,2024-08-27 - aselermekova20\nFeeling:AN5-DI0-F...,2024-08-27 - aselermekova20\nFeeling:AN5-DI0-F...,2024-09-05 - aidaraliev12345\nSpokenBy:BLACKMA...,AN5-DI0-FE0-SA0-SU0-JO0,BLACKMANTASAURUS,TEST,[anger]
5280,QC copy - 1499 - 58 ECC Co_mics 50 _The Jurass...,20,2,2,… I MUST REMOVE THIS WORLD OF THEIR GODS!,2024-08-27 - aselermekova20\nFeeling:AN5-DI0-F...,2024-08-27 - aselermekova20\nFeeling:AN5-DI0-F...,2024-09-05 - aidaraliev12345\nSpokenBy:BLACKMA...,AN5-DI0-FE0-SA0-SU0-JO5,BLACKMANTASAURUS,TEST,"[anger, joy]"


### Get embeddings

In [26]:
utterance_embed_d = {}

for utterance in tqdm(df.utterance):
    # print(utterance)
    while True:
        try:
            inputs = embedding_tokenizer(utterance, return_tensors="pt")
            output = embedding_model(**inputs)
            embedding = output[1][0].squeeze()
            utterance_embed_d[utterance] = embedding.detach().numpy()
            break
        except Exception as e:
            print(e)

  0%|          | 0/5282 [00:00<?, ?it/s]

In [27]:
df['utterance_embedding'] = df.utterance.apply(lambda x: utterance_embed_d[x])

In [28]:
train_df = df[df.split == "TRAIN"].reset_index(drop=True)
test_df = df[df.split == "TEST"].reset_index(drop=True)

### Get K neighbours and prepare prompt

In [32]:
def get_k_neighbours(k, utterance):

    test_utterance_embedding = test_df[test_df.utterance == utterance]["utterance_embedding"].values[0]

    utterance_embed_d = {}
    for e in train_df.iterrows():
        if e[1].utterance not in utterance_embed_d:
            utterance_embed_d[e[1].utterance] = e[1].utterance_embedding

    # train_titles = set(df[df.split == 'TRAIN'].title.unique())
    train_utterances = set(train_df.utterance)

    dist_l = []
    for t, v in utterance_embed_d.items():
        if t in train_utterances:
            # d = cos_sim(title_embed_d[title], v)
            d = F.cosine_similarity(torch.tensor(test_utterance_embedding), torch.tensor(v), dim=0)
            dist_l.append((t, d.item()))

    sorted_dist_l = sorted(dist_l, key=itemgetter(1), reverse=True)
    
    return sorted_dist_l[0: k]

In [53]:
get_k_neighbours(3, test_df.iloc[0]["utterance"])

[("WHAT'S IT GOING TO BE, SWEET?", 0.996392011642456),
 ('WHAT ARE YOU DOING UP HERE?', 0.9960034489631653),
 ('WHAT DO YOU WANT?', 0.9958155155181885)]

In [40]:
def prepare_similar_example_prompts(utterance, k=3, seed=33):
    """
    Create a part of prompt made of k examples in the train set, whose topic is most similar to a given title.
    """

    random.seed(seed)

    neighbours_l = get_k_neighbours(2*k, utterance) # Fetch the 2*k closest neighbors
    # print(neighbours_l)
    sampled_neighbours_l = random.sample(neighbours_l, k) # Only keep k of them
    # bprint(sampled_neighbours_l)

    prompt = ''
    cnt = 0
    for i, (utterance, dist) in enumerate(sampled_neighbours_l):
        prompt += f'## Example {i+1}\n'

        example_df = train_df[train_df.utterance == utterance]
        # example_df = example_df[example_df.aty != 'none'].reset_index()
        
        class_l = []
        for k in example_df.iterrows():
            
            if k[0] == 0:

                prompt += f'# Abstract:\n{example_df.iloc[0].utterance}\n\n# Arguments:\n'
                cnt = 0
                
            # prompt += f'Argument {cnt + 1}={k[1].text} - Class={k[1].aty}\n'
            prompt += f'Utterance {cnt + 1}={k[1].utterance}\n'
            class_l.append(k[1].emotions_list)
            cnt += 1
            
        prompt += '\n# Result:\n'
        prompt += '{' + ', '.join([f'"utterance_emotions": "{class_l[i]}"' for i in range(len(class_l))]) + '}'
        prompt += '\n\n'

    return prompt

In [41]:
print(prepare_similar_example_prompts(test_df.iloc[14]["utterance"], k=10))

## Example 1
Utterance 1=THAT SO? WELL, YOU CAN JUST SUCK MY--

# Result:
{"utterance_emotions": "['anger', 'sadness', 'surprise']"}

## Example 2
Utterance 2=OH, I'LL GIVE YOU A SHELL WEDGIE!

# Result:
{"utterance_emotions": "['anger']"}

## Example 3
Utterance 3=WHAT ARE YOU DOING?! LEAVE US ALONE!

# Result:
{"utterance_emotions": "['anger', 'surprise']"}

## Example 4
Utterance 4=I'LL STILL LIVE FOREVER!

# Result:
{"utterance_emotions": "['anger']"}

## Example 5
Utterance 5=OH YEAH? MOP THE FLOOR WITH THIS!

# Result:
{"utterance_emotions": "['anger', 'surprise']"}

## Example 6
Utterance 6=OH, DO SHUT UP!

# Result:
{"utterance_emotions": "['anger', 'disgust']"}

## Example 7
Utterance 7=YOU BETTER BE HUSTLING, TRICK!

# Result:
{"utterance_emotions": "['anger', 'fear', 'sadness']"}

## Example 8
Utterance 8=HEY, SHIT FOR BRAINS!

# Result:
{"utterance_emotions": "['anger', 'disgust']"}

## Example 9
Utterance 9=… @ % # @ $ YOU TOO.

# Result:
{"utterance_emotions": "['anger']"

### Prepare test set prompts

In [70]:
experiment_df = test_df

sys_msg_l = []
task_msg_l = []

for row in tqdm(test_df.iterrows(), total=len(test_df)):
    
    #row[0] is index, row[1] is the data
    sys_msg = {"role": "system", "content": "### Task description: You are an expert sentiment analysis assistant that takes an utterance from a comic book and must classify the utterance into appropriate emotion class(s): anger, surprise, fear, disgust, sadness, joy, neutral. You are given one utterance to classify and 3 example utterances to help you. You must absolutely not generate any text or explanation other than the following JSON format: {\"utterance_emotion\": \"<predicted emotion classes for the utterance (str)>}\"\n\n" + "### Examples:\n\n" + prepare_similar_example_prompts(row[1].utterance)}
    #sys_msg = {"role":"system", "content": "### Task description: You are an expert biomedical assistant that takes 1) an abstract text, 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. " + proportion_desc + " You must absolutely not generate any text or explanation other than the following JSON format {\"Argument 1\": <predicted class for Argument 1 (str)>, ..., \"Argument n\": <predicted class for Argument n (str)>}\n\n### Class definitions:" + " Claim = " + claim_fulldesc + " Premise = " + premise_fulldesc + "\n\n### Examples:\n\n" + prepare_similar_example_prompts(title_l[i], experiment_df, k=3, seed=seed)}  # Sample by similar title
    task_msg = {"role":"user", "content": f"# Utterance:\n{row[1].utterance}\n\n# Result:\n"}
    
    sys_msg_l.append(sys_msg)
    task_msg_l.append(task_msg)
    

  0%|          | 0/1776 [00:00<?, ?it/s]

In [71]:
len(sys_msg_l)

1776

In [72]:
print(sys_msg_l[0]['content'])

### Task description: You are an expert sentiment analysis assistant that takes an utterance from a comic book and must classify the utterance into appropriate emotion class(s): anger, surprise, fear, disgust, sadness, joy, neutral. You are given one utterance to classify and 3 example utterances to help you. You must absolutely not generate any text or explanation other than the following JSON format: {"utterance_emotion": "<predicted emotion classes for the utterance (str)>}"

### Examples:

## Example 1
Utterance 1=WHAT'S THE ANGLE, DOLL?

# Result:
{"utterance_emotions": "['anger', 'fear', 'surprise']"}

## Example 2
Utterance 2=WHAT ARE YOU DOING UP HERE?

# Result:
{"utterance_emotions": "['surprise', 'joy']"}

## Example 3
Utterance 3=WHAT THE HELL HAPPENED TO YOU?

# Result:
{"utterance_emotions": "['anger', 'surprise']"}




In [73]:
print(task_msg_l[0]["content"])

# Utterance:
HOW'S IT GOING?

# Result:



In [74]:
prepared_sys_task_msg_l = []

for i in range(len(sys_msg_l)):
    prepared_sys_task_msg_l.append([sys_msg_l[i], task_msg_l[i]])

In [75]:
len(prepared_sys_task_msg_l)

1776

### Run Inferences

In [85]:
inputs = inference_tokenizer.apply_chat_template(
            prepared_sys_task_msg_l,
            #tools=tools,
            # pad_token = tokenizer.eos_token,
            padding=True,
            truncation=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
)

In [86]:
def batch_tensor(tensor, batch_size):
    return [tensor[i:i+batch_size] for i in range(0, tensor.size(0), batch_size)]

In [87]:
BATCH_SIZE = 128

In [88]:
input_ids_batches = batch_tensor(inputs['input_ids'], BATCH_SIZE)
attention_mask_batches = batch_tensor(inputs['attention_mask'], BATCH_SIZE)

In [110]:
generated_outputs = []

for i, (input_ids_batch, attention_mask_batch) in tqdm(enumerate(zip(input_ids_batches, attention_mask_batches))):
    
    print(f"Processing batch {i + 1}")
    
    inputs = {
        'input_ids': input_ids_batch.to(generation_model.device),
        'attention_mask': attention_mask_batch.to(generation_model.device)
    }

    outputs = generation_model.generate(
    **inputs,
    max_new_tokens=64,
    pad_token_id=inference_tokenizer.eos_token_id,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.1,
    top_p=0.9,
    )
    #outputs = [inference_tokenizer.decode(output[input_ids.shape[-1]:], skip_special_tokens=True) for output in outputs]
    generated_outputs.append(outputs)


0it [00:00, ?it/s]

Processing batch 1
Processing batch 2
Processing batch 3
Processing batch 4
Processing batch 5
Processing batch 6
Processing batch 7
Processing batch 8
Processing batch 9
Processing batch 10
Processing batch 11
Processing batch 12
Processing batch 13
Processing batch 14


In [126]:
print(len(generated_outputs))

14


In [127]:
inputs['input_ids'].shape[1]

449

In [128]:
decoded_outputs = []

for batch in generated_outputs:

    for prediction in batch:
        # print(prediction)
        #print(prediction.shape)
        decoded_outputs.append(inference_tokenizer.decode(prediction[inputs['input_ids'].shape[1]:], skip_special_tokens=True))
        #break
        # decoded_outputs.append(inference_tokenizer.decode(prediction, skip_special_tokens=True))
        #inference_tokenizer.decode(prediction[0][inputs_ids.shape[-1]:], skip_special_tokens=True)
        #inference_tokenizer.decode(prediction.shape[[-1]:], skip_special_tokens=True)

In [129]:
len(decoded_outputs)

1776

In [130]:
decoded_outputs

['{"utterance_emotions": "[\'neutral\', \'joy\']"}',
 '{"utterance_emotions": "[\'surprise\']"}',
 '{"utterance_emotions": "[\'joy\', \'neutral\']"}',
 '{"utterance_emotions": "[\'joy\', \'neutral\']"}',
 '{"utterance_emotions": "[\'neutral\', \'joy\']"}',
 '{"utterance_emotions": "[\'joy\']"}',
 '{"utterance_emotions": "[\'surprise\', \'fear\']"}',
 '{"utterance_emotions": "[\'joy\', \'neutral\']"}',
 '{"utterance_emotions": "[\'joy\']"}',
 '{"utterance_emotions": "[\'surprise\', \'joy\']"}',
 '{"utterance_emotions": "[\'neutral\']"}',
 '{"utterance_emotions": "[\'joy\',\'surprise\']"}',
 '{"utterance_emotions": "[\'sadness\']"}',
 '{"utterance_emotions": "[\'anger\',\'surprise\']"}',
 '{"utterance_emotions": "[\'anger\',\'surprise\']"}',
 '{"utterance_emotions": "[\'neutral\', \'joy\']"}',
 '{"utterance_emotions": "[\'fear\', \'joy\']"}',
 '{"utterance_emotions": "[\'anger\', \'disgust\']"}',
 '{"utterance_emotions": "[\'anger\',\'sadness\']"}',
 '{"utterance_emotions": "[\'anger\',\

In [150]:
preds = []

#for output in outputs_l:
for prediction in decoded_outputs:
    try:
        # Use json.loads to safely parse the JSON-like string
        parsed_prediction = json.loads(prediction)
        # Append the values of the parsed prediction to preds
        preds.append(parsed_prediction['utterance_emotions'])
        
    except json.JSONDecodeError as e:
        print(f"Error decoding prediction: {e}")

In [151]:
len(preds)

1776

In [152]:
preds

["['neutral', 'joy']",
 "['surprise']",
 "['joy', 'neutral']",
 "['joy', 'neutral']",
 "['neutral', 'joy']",
 "['joy']",
 "['surprise', 'fear']",
 "['joy', 'neutral']",
 "['joy']",
 "['surprise', 'joy']",
 "['neutral']",
 "['joy','surprise']",
 "['sadness']",
 "['anger','surprise']",
 "['anger','surprise']",
 "['neutral', 'joy']",
 "['fear', 'joy']",
 "['anger', 'disgust']",
 "['anger','sadness']",
 "['anger','surprise', 'disgust']",
 "['anger', 'frustration']",
 "['surprise', 'anger']",
 "['surprise', 'curiosity']",
 "['neutral']",
 "['anger']",
 "['sadness', 'neutral']",
 "['joy','surprise']",
 "['joy']",
 "['anger', 'disgust']",
 "['anger','surprise']",
 "['anger','sadness', 'neutral']",
 "['surprise','sadness']",
 "['joy']",
 "['surprise', 'fear']",
 "['anger', 'fear', 'disgust']",
 "['surprise', 'joy']",
 "['anger','surprise']",
 "['sadness', 'apology']",
 "['sadness', 'neutral']",
 "['neutral']",
 "['anger', 'disgust']",
 "['surprise', 'disgust']",
 "['surprise', 'anger']",
 "['s

In [153]:
preds = [ast.literal_eval(item) for item in preds]

In [154]:
preds

[['neutral', 'joy'],
 ['surprise'],
 ['joy', 'neutral'],
 ['joy', 'neutral'],
 ['neutral', 'joy'],
 ['joy'],
 ['surprise', 'fear'],
 ['joy', 'neutral'],
 ['joy'],
 ['surprise', 'joy'],
 ['neutral'],
 ['joy', 'surprise'],
 ['sadness'],
 ['anger', 'surprise'],
 ['anger', 'surprise'],
 ['neutral', 'joy'],
 ['fear', 'joy'],
 ['anger', 'disgust'],
 ['anger', 'sadness'],
 ['anger', 'surprise', 'disgust'],
 ['anger', 'frustration'],
 ['surprise', 'anger'],
 ['surprise', 'curiosity'],
 ['neutral'],
 ['anger'],
 ['sadness', 'neutral'],
 ['joy', 'surprise'],
 ['joy'],
 ['anger', 'disgust'],
 ['anger', 'surprise'],
 ['anger', 'sadness', 'neutral'],
 ['surprise', 'sadness'],
 ['joy'],
 ['surprise', 'fear'],
 ['anger', 'fear', 'disgust'],
 ['surprise', 'joy'],
 ['anger', 'surprise'],
 ['sadness', 'apology'],
 ['sadness', 'neutral'],
 ['neutral'],
 ['anger', 'disgust'],
 ['surprise', 'disgust'],
 ['surprise', 'anger'],
 ['surprise', 'anger'],
 ['neutral'],
 ['anger', 'sadness'],
 ['sadness', 'disgus

In [155]:
grounds = test_df.emotions_list.tolist()

In [156]:
len(grounds)

1776

In [157]:
grounds

[['surprise', 'joy'],
 ['joy'],
 ['surprise', 'joy'],
 ['joy'],
 ['joy'],
 ['joy'],
 ['surprise'],
 ['joy'],
 ['joy'],
 ['neutral'],
 ['neutral'],
 ['neutral'],
 ['neutral'],
 ['anger', 'disgust'],
 ['anger', 'disgust'],
 ['neutral'],
 ['sadness'],
 ['sadness'],
 ['anger', 'sadness'],
 ['anger', 'sadness'],
 ['anger', 'sadness'],
 ['fear', 'surprise'],
 ['surprise'],
 ['joy'],
 ['anger', 'surprise'],
 ['joy'],
 ['joy'],
 ['joy'],
 ['anger'],
 ['anger'],
 ['surprise', 'joy'],
 ['fear', 'sadness'],
 ['fear', 'sadness'],
 ['fear', 'surprise'],
 ['anger', 'disgust'],
 ['anger', 'disgust'],
 ['anger', 'disgust'],
 ['fear', 'sadness'],
 ['fear', 'sadness', 'surprise'],
 ['sadness'],
 ['sadness'],
 ['fear', 'sadness'],
 ['sadness', 'surprise'],
 ['sadness', 'surprise'],
 ['joy'],
 ['anger'],
 ['anger'],
 ['anger'],
 ['anger', 'disgust'],
 ['joy'],
 ['joy'],
 ['surprise', 'joy'],
 ['surprise', 'joy'],
 ['anger', 'surprise'],
 ['anger', 'surprise'],
 ['neutral'],
 ['joy'],
 ['joy'],
 ['neutral'

In [158]:
import numpy as np

all_labels = ["anger", "surprise", "fear", "disgust", "sadness", "joy", "neutral"]

def labels_to_binary_matrix(label_list, all_labels):
    binary_matrix = np.zeros((len(label_list), len(all_labels)))
    for i, labels in enumerate(label_list):
        for label in labels:
            if label in all_labels:
                binary_matrix[i][all_labels.index(label)] = 1
    return binary_matrix

def opposite(component_type):

    if component_type == "anger":
        return "surprise"
    elif component_type == "disgust":
        return "joy"
    elif component_type == "fear":
        return "sadness"
    elif component_type == "sadness":
        return "anger"
    elif component_type == "surprise":
        return "disgust"
    elif component_type == "joy":
        return "fear"
    elif component_type == "Neutral":
        return "sadness"
    

def harmonize_preds(grounds, preds):

    l1, l2 = len(preds), len(grounds)
    if l1 < l2:
        diff = l2 - l1
        preds = preds + [opposite(x) for x in grounds[l1:]]
    else:
        preds = preds[:l2]
        
    return preds 

def post_process_zs(grounds, preds):

    for i,(x,y) in enumerate(zip(grounds, preds)):
        
        if len(x) != len(y):
            
            preds[i] = harmonize_preds(x, y)

    true_matrix = labels_to_binary_matrix(grounds, all_labels)
    predicted_matrix = labels_to_binary_matrix(preds, all_labels)

    return true_matrix, predicted_matrix

In [159]:
true_matrix, predicted_matrix = post_process_zs(grounds, preds)

In [160]:
print(classification_report(true_matrix, predicted_matrix, target_names=all_labels, digits=3))

              precision    recall  f1-score   support

       anger      0.612     0.546     0.577       614
    surprise      0.543     0.716     0.618       486
        fear      0.459     0.305     0.366       407
     disgust      0.157     0.306     0.207        85
     sadness      0.519     0.308     0.387       347
         joy      0.565     0.557     0.561       429
     neutral      0.168     0.264     0.205       129

   micro avg      0.494     0.486     0.490      2497
   macro avg      0.432     0.429     0.417      2497
weighted avg      0.514     0.486     0.490      2497
 samples avg      0.476     0.469     0.472      2497



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
