# In-Context Learning (ICL) for Comics

- Initial config: K = 10, N = 1, embedding_model = bert-base-uncased, no attributes/context

### Libraries

In [1]:
import ast
import json
import torch
import random
import pickle
import numpy as np
import pandas as pd
import torch.nn.functional as F

from pathlib import Path
from tqdm.notebook import tqdm
from operator import itemgetter
from sklearn.metrics import classification_report
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

### Tokenizer and Model (embedding and inference)

In [2]:
embedding_tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
embedding_model = AutoModel.from_pretrained("google-bert/bert-base-uncased")



In [3]:
# model_id = "unsloth/llama-3-8b-Instruct-bnb-4bit"
#model_id = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
model_id = "unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit"


In [4]:
inference_tokenizer = AutoTokenizer.from_pretrained(model_id, padding='left', padding_side='left')
inference_tokenizer.pad_token = inference_tokenizer.eos_token
terminators = [
    inference_tokenizer.eos_token_id,
    inference_tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

In [5]:
generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    #cache_dir = '/home/umushtaq/scratch/am_work/in_context_learning/model_downloads',
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


model.safetensors.index.json:   0%|          | 0.00/331k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/6 [00:00<?, ?it/s]

model-00001-of-00006.safetensors:   0%|          | 0.00/7.00G [00:00<?, ?B/s]

model-00002-of-00006.safetensors:   0%|          | 0.00/6.90G [00:00<?, ?B/s]

model-00003-of-00006.safetensors:   0%|          | 0.00/6.94G [00:00<?, ?B/s]

model-00004-of-00006.safetensors:   0%|          | 0.00/6.94G [00:00<?, ?B/s]

model-00005-of-00006.safetensors:   0%|          | 0.00/6.99G [00:00<?, ?B/s]

model-00006-of-00006.safetensors:   0%|          | 0.00/4.75G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/234 [00:00<?, ?B/s]

### Read Data

In [6]:
DATASET_DIR = Path(Path.cwd().as_posix()) / "emotion_analysis_comics" / "incontext_learning" / "datasets"

In [7]:
DATASET_DIR

PosixPath('/Utilisateurs/umushtaq/emotion_analysis_comics/incontext_learning/datasets')

In [8]:
all_labels = ["anger", "surprise", "fear", "disgust", "sadness", "joy", "neutral"]

In [9]:
emotion_map = {
    'AN': 'anger',
    'DI': 'disgust',
    'FE': 'fear',
    'SA': 'sadness',
    'SU': 'surprise',
    'JO': 'joy'
}

In [10]:
df = pd.read_csv(DATASET_DIR / "comics_data_processed.csv")
df = df.drop(columns=[df.columns[0], df.columns[1]]).reset_index(drop=True)

In [11]:
def extract_emotions(row):

    emotion_str = row.emotion

    if emotion_str == 'Neutral':
        return ['neutral']

    emotions = emotion_str.split('-')
    tags = []

    for emotion in emotions:
        abbrev = emotion[:2]  # Get the abbreviation
        value_part = emotion[2:]  # Get the value part
        
        # Ensure that the value part is a valid integer and abbrev is in the emotion_map
        if abbrev in emotion_map and value_part.isdigit():
            value = int(value_part)
            if value > 0:
                tags.append(emotion_map[abbrev].lower())
        else:
            print(f"Warning: Skipping invalid emotion entry: '{emotion}'")
    return tags  

In [12]:
df['emotions_list'] = df.apply(lambda row: extract_emotions(row), axis=1)

In [13]:
df.shape

(5282, 12)

### Get embeddings

In [15]:
utterance_embed_d = {}

for utterance in tqdm(df.utterance):
    # print(utterance)
    while True:
        try:
            inputs = embedding_tokenizer(utterance, return_tensors="pt")
            output = embedding_model(**inputs)
            embedding = output[1][0].squeeze()
            utterance_embed_d[utterance] = embedding.detach().numpy()
            break
        except Exception as e:
            print(e)

  0%|          | 0/5282 [00:00<?, ?it/s]

In [16]:
df['utterance_embedding'] = df.utterance.apply(lambda x: utterance_embed_d[x])

In [14]:
train_df = df[df.split == "TRAIN"].reset_index(drop=True)
test_df = df[df.split == "TEST"].reset_index(drop=True)

### Get K neighbours and prepare prompt

In [18]:
def get_k_neighbours(k, utterance):

    test_utterance_embedding = test_df[test_df.utterance == utterance]["utterance_embedding"].values[0]

    utterance_embed_d = {}
    for e in train_df.iterrows():
        if e[1].utterance not in utterance_embed_d:
            utterance_embed_d[e[1].utterance] = e[1].utterance_embedding

    # train_titles = set(df[df.split == 'TRAIN'].title.unique())
    train_utterances = set(train_df.utterance)

    dist_l = []
    for t, v in utterance_embed_d.items():
        if t in train_utterances:
            # d = cos_sim(title_embed_d[title], v)
            d = F.cosine_similarity(torch.tensor(test_utterance_embedding), torch.tensor(v), dim=0)
            dist_l.append((t, d.item()))

    sorted_dist_l = sorted(dist_l, key=itemgetter(1), reverse=True)
    
    return sorted_dist_l[0: k]

In [19]:
get_k_neighbours(10, test_df.iloc[10]["utterance"])

[('… YES?', 0.9947154521942139),
 ('… POR FAVOR…', 0.9920827150344849),
 ('HEY!', 0.9919130802154541),
 ("NO… NO… YOU CAN'T CONTROL THEM!", 0.9913744330406189),
 ('UM… ACTUALLY…', 0.9905799627304077),
 ('… th-the…', 0.990537703037262),
 ('" DON\'T GO INTO THE LIGHT… "', 0.9902849197387695),
 ('But today…', 0.9900076985359192),
 ("IF HE'D KILLED ANDREA… I DON'T KNOW…", 0.9897928833961487),
 ('QUIET… OH NO…', 0.9897283315658569)]

In [20]:
def prepare_similar_example_prompts(utterance, k=10, seed=33):
    """
    Create a part of prompt made of k examples in the train set, whose topic is most similar to a given title.
    """

    random.seed(seed)

    neighbours_l = get_k_neighbours(2*k, utterance) # Fetch the 2*k closest neighbors
    # print(neighbours_l)
    sampled_neighbours_l = random.sample(neighbours_l, k) # Only keep k of them
    # bprint(sampled_neighbours_l)

    prompt = ''
    cnt = 0
    for i, (utterance, dist) in enumerate(sampled_neighbours_l):
        prompt += f'## Example {i+1}\n'

        example_df = train_df[train_df.utterance == utterance]
        # example_df = example_df[example_df.aty != 'none'].reset_index()
        
        class_l = []
        for k in example_df.iterrows():
            
            # if k[0] == 0:

            #     prompt += f'# Abstract:\n{example_df.iloc[0].utterance}\n\n# Arguments:\n'
            #     cnt = 0
                
            # prompt += f'Argument {cnt + 1}={k[1].text} - Class={k[1].aty}\n'
            prompt += f'Utterance {cnt + 1}={k[1].utterance}\n'
            class_l.append(k[1].emotions_list)
            cnt += 1
            
        prompt += '\n# Result:\n'
        prompt += '{' + ', '.join([f'"utterance_emotions": "{class_l[i]}"' for i in range(len(class_l))]) + '}'
        prompt += '\n\n'

    return prompt

In [21]:
print(prepare_similar_example_prompts(test_df.iloc[10]["utterance"], k=10))

## Example 1
Utterance 1=… LANDED

# Result:
{"utterance_emotions": "['surprise']"}

## Example 2
Utterance 2=… th-the…

# Result:
{"utterance_emotions": "['neutral']"}

## Example 3
Utterance 3=But today…

# Result:
{"utterance_emotions": "['sadness']"}

## Example 4
Utterance 4=IF HE'D KILLED ANDREA… I DON'T KNOW…

# Result:
{"utterance_emotions": "['fear', 'sadness', 'surprise']"}

## Example 5
Utterance 5=SHE… WE COULD STILL…

# Result:
{"utterance_emotions": "['sadness']"}

## Example 6
Utterance 6=BUT FLORIAN… I…

# Result:
{"utterance_emotions": "['fear']"}

## Example 7
Utterance 7=WHAT THE… CAN'T… MOVE MY…

# Result:
{"utterance_emotions": "['fear', 'surprise']"}

## Example 8
Utterance 8=THAT WAS…

# Result:
{"utterance_emotions": "['surprise']"}

## Example 9
Utterance 9=HEY!
Utterance 10=HEY!
Utterance 11=HEY!

# Result:
{"utterance_emotions": "['surprise']", "utterance_emotions": "['anger', 'surprise']", "utterance_emotions": "['anger']"}

## Example 10
Utterance 12=… that

### Prepare test set prompts

In [22]:
experiment_df = test_df

sys_msg_l = []
task_msg_l = []

for row in tqdm(test_df.iterrows(), total=len(test_df)):
    
    #row[0] is index, row[1] is the data
    sys_msg = {"role": "system", "content": "### Task description: You are an expert sentiment analysis assistant that takes an utterance from a comic book and must classify the utterance into appropriate emotion class(s): anger, surprise, fear, disgust, sadness, joy, neutral. You are given one utterance to classify and 3 example utterances to help you. You must absolutely not generate any text or explanation other than the following JSON format: {\"utterance_emotion\": \"<predicted emotion classes for the utterance (str)>}\"\n\n" + "### Examples:\n\n" + prepare_similar_example_prompts(row[1].utterance)}
    #sys_msg = {"role":"system", "content": "### Task description: You are an expert biomedical assistant that takes 1) an abstract text, 2) the list of all arguments from this abstract text, and must classify all arguments into one of two classes: Claim or Premise. " + proportion_desc + " You must absolutely not generate any text or explanation other than the following JSON format {\"Argument 1\": <predicted class for Argument 1 (str)>, ..., \"Argument n\": <predicted class for Argument n (str)>}\n\n### Class definitions:" + " Claim = " + claim_fulldesc + " Premise = " + premise_fulldesc + "\n\n### Examples:\n\n" + prepare_similar_example_prompts(title_l[i], experiment_df, k=3, seed=seed)}  # Sample by similar title
    task_msg = {"role":"user", "content": f"# Utterance:\n{row[1].utterance}\n\n# Result:\n"}
    
    sys_msg_l.append(sys_msg)
    task_msg_l.append(task_msg)
    

  0%|          | 0/1776 [00:00<?, ?it/s]

In [23]:
len(sys_msg_l)

1776

In [24]:
print(sys_msg_l[10]['content'])

### Task description: You are an expert sentiment analysis assistant that takes an utterance from a comic book and must classify the utterance into appropriate emotion class(s): anger, surprise, fear, disgust, sadness, joy, neutral. You are given one utterance to classify and 3 example utterances to help you. You must absolutely not generate any text or explanation other than the following JSON format: {"utterance_emotion": "<predicted emotion classes for the utterance (str)>}"

### Examples:

## Example 1
Utterance 1=… LANDED

# Result:
{"utterance_emotions": "['surprise']"}

## Example 2
Utterance 2=… th-the…

# Result:
{"utterance_emotions": "['neutral']"}

## Example 3
Utterance 3=But today…

# Result:
{"utterance_emotions": "['sadness']"}

## Example 4
Utterance 4=IF HE'D KILLED ANDREA… I DON'T KNOW…

# Result:
{"utterance_emotions": "['fear', 'sadness', 'surprise']"}

## Example 5
Utterance 5=SHE… WE COULD STILL…

# Result:
{"utterance_emotions": "['sadness']"}

## Example 6
Utte

In [25]:
print(task_msg_l[10]["content"])

# Utterance:
@… IN A FAR OFF KINGDOM…

# Result:



In [26]:
prepared_sys_task_msg_l = []

for i in range(len(sys_msg_l)):
    prepared_sys_task_msg_l.append([sys_msg_l[i], task_msg_l[i]])

In [27]:
len(prepared_sys_task_msg_l)

1776

In [14]:
with open(DATASET_DIR / 'tmp_prepared_sys_task_msg_l.pkl', 'wb') as f:
    
    pickle.dump(prepared_sys_task_msg_l, f)

### Run Inferences

In [15]:
# Load the pickled list
with open(DATASET_DIR / 'tmp_prepared_sys_task_msg_l.pkl', 'rb') as f:
    
    prepared_sys_task_msg_l = pickle.load(f)


In [16]:
inputs = inference_tokenizer.apply_chat_template(
            prepared_sys_task_msg_l,
            #tools=tools,
            # pad_token = tokenizer.eos_token,
            padding=True,
            truncation=True,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
)

In [17]:
inputs

{'input_ids': tensor([[128009, 128009, 128009,  ...,  78191, 128007,    271],
        [128009, 128009, 128009,  ...,  78191, 128007,    271],
        [128009, 128009, 128009,  ...,  78191, 128007,    271],
        ...,
        [128009, 128009, 128009,  ...,  78191, 128007,    271],
        [128009, 128009, 128009,  ...,  78191, 128007,    271],
        [128009, 128009, 128009,  ...,  78191, 128007,    271]]), 'attention_mask': tensor([[0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        ...,
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1],
        [0, 0, 0,  ..., 1, 1, 1]])}

In [18]:
def batch_tensor(tensor, batch_size):
    return [tensor[i:i+batch_size] for i in range(0, tensor.size(0), batch_size)]

In [22]:
BATCH_SIZE = 32

In [23]:
input_ids_batches = batch_tensor(inputs['input_ids'], BATCH_SIZE)
attention_mask_batches = batch_tensor(inputs['attention_mask'], BATCH_SIZE)

In [24]:
CUDA_LAUNCH_BLOCKING=0

In [25]:
generated_outputs = []

for i, (input_ids_batch, attention_mask_batch) in tqdm(enumerate(zip(input_ids_batches, attention_mask_batches)), total=len(input_ids_batches)):
    
    print(f"Processing batch {i + 1}")
    
    inputs = {
        'input_ids': input_ids_batch.to(generation_model.device),
        'attention_mask': attention_mask_batch.to(generation_model.device)
    }

    outputs = generation_model.generate(
    **inputs,
    max_new_tokens=64,
    pad_token_id=inference_tokenizer.eos_token_id,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.1,
    top_p=0.9,
    )
    #outputs = [inference_tokenizer.decode(output[input_ids.shape[-1]:], skip_special_tokens=True) for output in outputs]
    generated_outputs.append(outputs)
    # del inputs


  0%|          | 0/56 [00:00<?, ?it/s]

Processing batch 1
Processing batch 2
Processing batch 3
Processing batch 4
Processing batch 5
Processing batch 6
Processing batch 7
Processing batch 8
Processing batch 9
Processing batch 10
Processing batch 11
Processing batch 12
Processing batch 13
Processing batch 14
Processing batch 15
Processing batch 16
Processing batch 17
Processing batch 18
Processing batch 19
Processing batch 20
Processing batch 21
Processing batch 22
Processing batch 23
Processing batch 24
Processing batch 25
Processing batch 26
Processing batch 27
Processing batch 28
Processing batch 29
Processing batch 30
Processing batch 31
Processing batch 32
Processing batch 33
Processing batch 34
Processing batch 35
Processing batch 36
Processing batch 37
Processing batch 38
Processing batch 39
Processing batch 40
Processing batch 41
Processing batch 42
Processing batch 43
Processing batch 44
Processing batch 45
Processing batch 46
Processing batch 47
Processing batch 48
Processing batch 49
Processing batch 50
Processin

In [26]:
print(len(generated_outputs))

56


In [27]:
inputs['input_ids'].shape[1]

963

In [28]:
decoded_outputs = []

for batch in generated_outputs:

    for prediction in batch:
        # print(prediction)
        #print(prediction.shape)
        decoded_outputs.append(inference_tokenizer.decode(prediction[inputs['input_ids'].shape[1]:], skip_special_tokens=True))
        #break
        # decoded_outputs.append(inference_tokenizer.decode(prediction, skip_special_tokens=True))
        #inference_tokenizer.decode(prediction[0][inputs_ids.shape[-1]:], skip_special_tokens=True)
        #inference_tokenizer.decode(prediction.shape[[-1]:], skip_special_tokens=True)

In [29]:
len(decoded_outputs)

1776

In [30]:
decoded_outputs

['{"utterance_emotions": "[\'neutral\']"}',
 '{"utterance_emotions": "[\'joy\', \'neutral\']"}',
 '{"utterance_emotions": "[\'neutral\']"}',
 '{"utterance_emotions": "[\'neutral\']"}',
 '{"utterance_emotions": "[\'neutral\']"}',
 '{"utterance_emotions": "[\'joy\']"}',
 '{"utterance_emotions": "[\'fear\', \'concern\']"}',
 '{"utterance_emotions": "[\'joy\']"}',
 '{"utterance_emotions": "[\'joy\']"}',
 '{"utterance_emotions": "[\'neutral\']"}',
 '{"utterance_emotions": "[\'neutral\']"}',
 '{"utterance_emotions": "[\'joy\']"}',
 '{"utterance_emotions": "[\'anger\']"}',
 '{"utterance_emotions": "[\'anger\',\'surprise\']"}',
 '{"utterance_emotions": "[\'anger\', \'disgust\']"}',
 '{"utterance_emotions": "[\'joy\',\'surprise\']"}',
 '{"utterance_emotions": "[\'joy\',\'surprise\']"}',
 '{"utterance_emotions": "[\'anger\',\'sadness\']"}',
 '{"utterance_emotions": "[\'anger\']"}',
 '{"utterance_emotions": "[\'anger\', \'disgust\']"}',
 '{"utterance_emotions": "[\'anger\']"}',
 '{"utterance_emot

In [31]:
preds = []

#for output in outputs_l:
for i, prediction in enumerate(decoded_outputs):
    try:
        # Use json.loads to safely parse the JSON-like string
        parsed_prediction = json.loads(prediction)
        # Append the values of the parsed prediction to preds
        preds.append(parsed_prediction['utterance_emotions'])
        
    except json.JSONDecodeError as e:
        print(f"Error decoding prediction: {i}")

In [32]:
len(preds)

1776

In [33]:
preds

["['neutral']",
 "['joy', 'neutral']",
 "['neutral']",
 "['neutral']",
 "['neutral']",
 "['joy']",
 "['fear', 'concern']",
 "['joy']",
 "['joy']",
 "['neutral']",
 "['neutral']",
 "['joy']",
 "['anger']",
 "['anger','surprise']",
 "['anger', 'disgust']",
 "['joy','surprise']",
 "['joy','surprise']",
 "['anger','sadness']",
 "['anger']",
 "['anger', 'disgust']",
 "['anger']",
 "['anger']",
 "['surprise']",
 "['surprise']",
 "['surprise']",
 "['sadness', 'neutral']",
 "['joy','surprise']",
 "['joy']",
 "['anger', 'disgust']",
 "['anger']",
 "['anger', 'joy']",
 "['surprise']",
 "['joy']",
 "['surprise']",
 "['anger', 'disgust']",
 "['anger']",
 "['anger', 'joy']",
 "['sadness']",
 "['fear','sadness']",
 "['joy']",
 "['anger', 'disgust']",
 "['anger','surprise']",
 "['anger', 'frustration']",
 "['anger', 'frustration']",
 "['joy']",
 "['anger']",
 "['anger']",
 "['anger']",
 "['anger', 'disgust']",
 "['joy']",
 "['neutral']",
 "['anger', 'fear']",
 "['joy']",
 "['anger','surprise']",
 "['

In [34]:
preds = [ast.literal_eval(item) for item in preds]

In [35]:
preds

[['neutral'],
 ['joy', 'neutral'],
 ['neutral'],
 ['neutral'],
 ['neutral'],
 ['joy'],
 ['fear', 'concern'],
 ['joy'],
 ['joy'],
 ['neutral'],
 ['neutral'],
 ['joy'],
 ['anger'],
 ['anger', 'surprise'],
 ['anger', 'disgust'],
 ['joy', 'surprise'],
 ['joy', 'surprise'],
 ['anger', 'sadness'],
 ['anger'],
 ['anger', 'disgust'],
 ['anger'],
 ['anger'],
 ['surprise'],
 ['surprise'],
 ['surprise'],
 ['sadness', 'neutral'],
 ['joy', 'surprise'],
 ['joy'],
 ['anger', 'disgust'],
 ['anger'],
 ['anger', 'joy'],
 ['surprise'],
 ['joy'],
 ['surprise'],
 ['anger', 'disgust'],
 ['anger'],
 ['anger', 'joy'],
 ['sadness'],
 ['fear', 'sadness'],
 ['joy'],
 ['anger', 'disgust'],
 ['anger', 'surprise'],
 ['anger', 'frustration'],
 ['anger', 'frustration'],
 ['joy'],
 ['anger'],
 ['anger'],
 ['anger'],
 ['anger', 'disgust'],
 ['joy'],
 ['neutral'],
 ['anger', 'fear'],
 ['joy'],
 ['anger', 'surprise'],
 ['anger', 'surprise'],
 ['anger', 'surprise'],
 ['joy'],
 ['joy'],
 ['disgust'],
 ['fear', 'surprise'],

In [36]:
grounds = test_df.emotions_list.tolist()

In [37]:
len(grounds)

1776

In [38]:
grounds

[['surprise', 'joy'],
 ['joy'],
 ['surprise', 'joy'],
 ['joy'],
 ['joy'],
 ['joy'],
 ['surprise'],
 ['joy'],
 ['joy'],
 ['neutral'],
 ['neutral'],
 ['neutral'],
 ['neutral'],
 ['anger', 'disgust'],
 ['anger', 'disgust'],
 ['neutral'],
 ['sadness'],
 ['sadness'],
 ['anger', 'sadness'],
 ['anger', 'sadness'],
 ['anger', 'sadness'],
 ['fear', 'surprise'],
 ['surprise'],
 ['joy'],
 ['anger', 'surprise'],
 ['joy'],
 ['joy'],
 ['joy'],
 ['anger'],
 ['anger'],
 ['surprise', 'joy'],
 ['fear', 'sadness'],
 ['fear', 'sadness'],
 ['fear', 'surprise'],
 ['anger', 'disgust'],
 ['anger', 'disgust'],
 ['anger', 'disgust'],
 ['fear', 'sadness'],
 ['fear', 'sadness', 'surprise'],
 ['sadness'],
 ['sadness'],
 ['fear', 'sadness'],
 ['sadness', 'surprise'],
 ['sadness', 'surprise'],
 ['joy'],
 ['anger'],
 ['anger'],
 ['anger'],
 ['anger', 'disgust'],
 ['joy'],
 ['joy'],
 ['surprise', 'joy'],
 ['surprise', 'joy'],
 ['anger', 'surprise'],
 ['anger', 'surprise'],
 ['neutral'],
 ['joy'],
 ['joy'],
 ['neutral'

In [39]:
len(grounds)

1776

In [40]:
all_labels = ["anger", "surprise", "fear", "disgust", "sadness", "joy", "neutral"]

def labels_to_binary_matrix(label_list, all_labels):
    binary_matrix = np.zeros((len(label_list), len(all_labels)))
    for i, labels in enumerate(label_list):
        for label in labels:
            if label in all_labels:
                binary_matrix[i][all_labels.index(label)] = 1
    return binary_matrix

def opposite(component_type):

    if component_type == "anger":
        return "surprise"
    elif component_type == "disgust":
        return "joy"
    elif component_type == "fear":
        return "sadness"
    elif component_type == "sadness":
        return "anger"
    elif component_type == "surprise":
        return "disgust"
    elif component_type == "joy":
        return "fear"
    elif component_type == "Neutral":
        return "sadness"
    

def harmonize_preds(grounds, preds):

    l1, l2 = len(preds), len(grounds)
    if l1 < l2:
        diff = l2 - l1
        preds = preds + [opposite(x) for x in grounds[l1:]]
    else:
        preds = preds[:l2]
        
    return preds 

def post_process_zs(grounds, preds):

    for i,(x,y) in enumerate(zip(grounds, preds)):
        
        if len(x) != len(y):
            
            preds[i] = harmonize_preds(x, y)

    true_matrix = labels_to_binary_matrix(grounds, all_labels)
    predicted_matrix = labels_to_binary_matrix(preds, all_labels)

    return true_matrix, predicted_matrix

In [41]:
true_matrix, predicted_matrix = post_process_zs(grounds, preds)

In [42]:
print(classification_report(true_matrix, predicted_matrix, target_names=all_labels, digits=3))

              precision    recall  f1-score   support

       anger      0.590     0.616     0.602       614
    surprise      0.671     0.541     0.599       486
        fear      0.436     0.410     0.423       407
     disgust      0.147     0.424     0.218        85
     sadness      0.433     0.291     0.348       347
         joy      0.580     0.522     0.550       429
     neutral      0.195     0.264     0.224       129

   micro avg      0.490     0.482     0.486      2497
   macro avg      0.436     0.438     0.424      2497
weighted avg      0.522     0.482     0.496      2497
 samples avg      0.482     0.476     0.478      2497



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
