# Zero-Shot Comics Classification with LLaMA.

### Libraries

In [1]:
import ast
import torch
import numpy as np
import pandas as pd
import torch.nn.functional as F


from tqdm.notebook import tqdm
from operator import itemgetter
from sklearn.metrics import classification_report
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM

### Model and Tokenizer

In [2]:
model_id = "unsloth/llama-3-8b-Instruct-bnb-4bit"

In [3]:
inference_tokenizer = AutoTokenizer.from_pretrained(model_id, padding='left', padding_side='left')
inference_tokenizer.pad_token = inference_tokenizer.eos_token
terminators = [
    inference_tokenizer.eos_token_id,
    inference_tokenizer.convert_tokens_to_ids("<|eot_id|>")
]

In [4]:
generation_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    #cache_dir = '/home/umushtaq/scratch/am_work/in_context_learning/model_downloads',
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


### Dataset

In [5]:
all_labels = ["anger", "surprise", "fear", "disgust", "sadness", "joy", "neutral"]

In [6]:
emotion_map = {
    'AN': 'anger',
    'DI': 'disgust',
    'FE': 'fear',
    'SA': 'sadness',
    'SU': 'surprise',
    'JO': 'joy'
}

In [7]:
df = pd.read_csv("/Utilisateurs/umushtaq/emotion_analysis_comics/zeroshot/datasets/comics_data_processed.csv")

In [8]:
def extract_emotions(row):

    emotion_str = row.emotion

    if emotion_str == 'Neutral':
        return ['neutral']

    emotions = emotion_str.split('-')
    tags = []

    for emotion in emotions:
        abbrev = emotion[:2]  # Get the abbreviation
        value_part = emotion[2:]  # Get the value part
        
        # Ensure that the value part is a valid integer and abbrev is in the emotion_map
        if abbrev in emotion_map and value_part.isdigit():
            value = int(value_part)
            if value > 0:
                tags.append(emotion_map[abbrev].lower())
        else:
            print(f"Warning: Skipping invalid emotion entry: '{emotion}'")
    return tags  

In [9]:
df['emotions_list'] = df.apply(lambda row: extract_emotions(row), axis=1)

In [10]:
df = df.drop(columns=[df.columns[0], df.columns[1]]).reset_index(drop=True)

In [11]:
# df = df[:100]

In [12]:
df.shape

(5282, 12)

### Create Messages and Prompts 

In [13]:
sys_msg_l = []
task_msg_l = []

In [14]:
for row in df.iterrows():

    sys_msg = {"role":"system", "content": "### Task description: You are an expert sentiment analysis assistant that takes an utterance from a comic book and must classify the utterance into appropriate emotion class(s): anger, surprise, fear, disgust, sadness, joy, neutral. You must absolutely not generate any text or explanation other than the following JSON format {\"utterance_emotion\": <predicted emotion classes for the utterance (str)>}\n\n"}
    task_msg = {"role":"user", "content": f"# Utterance:\n{row[1].utterance}\n\n# Result:\n"}

    sys_msg_l.append(sys_msg)
    task_msg_l.append(task_msg)

In [15]:
len(sys_msg_l), len(task_msg_l)


(5282, 5282)

In [16]:
print(sys_msg_l[0]['content'])

### Task description: You are an expert sentiment analysis assistant that takes an utterance from a comic book and must classify the utterance into appropriate emotion class(s): anger, surprise, fear, disgust, sadness, joy, neutral. You must absolutely not generate any text or explanation other than the following JSON format {"utterance_emotion": <predicted emotion classes for the utterance (str)>}




In [17]:
print(task_msg_l[0]['content'])

# Utterance:
DID YOU HAVE TO ELECTROCUTE HER SO HARD?

# Result:



In [18]:
prepared_sys_task_msg_l = []

for i in range(len(sys_msg_l)):
    prepared_sys_task_msg_l.append([sys_msg_l[i], task_msg_l[i]])

In [19]:
prepared_sys_task_msg_l[0]

[{'role': 'system',
  'content': '### Task description: You are an expert sentiment analysis assistant that takes an utterance from a comic book and must classify the utterance into appropriate emotion class(s): anger, surprise, fear, disgust, sadness, joy, neutral. You must absolutely not generate any text or explanation other than the following JSON format {"utterance_emotion": <predicted emotion classes for the utterance (str)>}\n\n'},
 {'role': 'user',
  'content': '# Utterance:\nDID YOU HAVE TO ELECTROCUTE HER SO HARD?\n\n# Result:\n'}]

In [20]:
outputs_l = []

for i in tqdm(range(len(prepared_sys_task_msg_l))):

    messages = prepared_sys_task_msg_l[i]

    input_ids = inference_tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    tokenize=True,
    padding=True,
    truncation=True,
    return_tensors="pt"
).to(generation_model.device)

    outputs = generation_model.generate(
    input_ids = input_ids,
    max_new_tokens=1024,
    pad_token_id=inference_tokenizer.eos_token_id,
    eos_token_id=terminators,
    do_sample=True,
    temperature=0.1,
    top_p=0.9,
    )
    # inference_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True)
    outputs_l.append(inference_tokenizer.decode(outputs[0][input_ids.shape[-1]:], skip_special_tokens=True))

  0%|          | 0/5282 [00:00<?, ?it/s]

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


In [21]:
len(outputs_l)

5282

In [22]:
grounds = df.emotions_list.tolist()

In [23]:
preds = [list(ast.literal_eval(output).values()) for output in outputs_l]

In [24]:
len(grounds), len(preds)

(5282, 5282)

In [25]:
def opposite(component_type):

    if component_type == "anger":
        return "surprise"
    elif component_type == "disgust":
        return "joy"
    elif component_type == "fear":
        return "sadness"
    elif component_type == "sadness":
        return "anger"
    elif component_type == "surprise":
        return "disgust"
    elif component_type == "joy":
        return "fear"
    elif component_type == "Neutral":
        return "sadness"

In [26]:
def harmonize_preds(grounds, preds):

    l1, l2 = len(preds), len(grounds)
    if l1 < l2:
        diff = l2 - l1
        preds = preds + [opposite(x) for x in grounds[l1:]]
    else:
        preds = preds[:l2]
        
    return preds 

In [27]:
for i,(x,y) in enumerate(zip(grounds, preds)):
        
        if len(x) != len(y):
            
            preds[i] = harmonize_preds(x, y)

In [28]:
len(preds), len(grounds)

(5282, 5282)

In [29]:
# grounds = [item for sublist in grounds for item in sublist]
# preds = [item for sublist in preds for item in sublist]

In [30]:
len(grounds), len(preds)

(5282, 5282)

In [31]:
def labels_to_binary_matrix(label_list, all_labels):
    binary_matrix = np.zeros((len(label_list), len(all_labels)))
    for i, labels in enumerate(label_list):
        for label in labels:
            if label in all_labels:
                binary_matrix[i][all_labels.index(label)] = 1
    return binary_matrix

In [32]:
true_matrix = labels_to_binary_matrix(grounds, all_labels)
predicted_matrix = labels_to_binary_matrix(preds, all_labels)

In [33]:
print(classification_report(true_matrix, predicted_matrix, target_names=all_labels, digits=4))

              precision    recall  f1-score   support

       anger     0.5328    0.5762    0.5536      1791
    surprise     0.4764    0.4434    0.4593      1590
        fear     0.2095    0.0867    0.1226      1373
     disgust     0.0592    0.1865    0.0899       311
     sadness     0.3805    0.1761    0.2408      1238
         joy     0.4384    0.4158    0.4268      1104
     neutral     0.1836    0.2478    0.2109       343

   micro avg     0.3797    0.3453    0.3617      7750
   macro avg     0.3258    0.3046    0.3005      7750
weighted avg     0.3917    0.3453    0.3561      7750
 samples avg     0.3707    0.3571    0.3619      7750



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
