# Finetune PE dataset for ACC

## Libraries

In [1]:
# Run this cell only once to install LLaMA-Factory

# %cd ..
# %rm -rf LLaMA-Factory
# !git clone https://github.com/hiyouga/LLaMA-Factory.git
# %cd LLaMA-Factory
# %ls
# !pip install -e .[torch,bitsandbytes]

In [2]:
# !pip uninstall -y pydantic
# !pip install pydantic==1.10.9 # 

# !pip uninstall -y gradio
# !pip install gradio==3.48.0

# !pip uninstall -y bitsandbytes
# !pip install --upgrade bitsandbytes

# !pip install tqdm
# !pip install ipywidgets
# !pip install scikit-learn

# Restart kernel afterwards.

In [45]:
import os
import ast
import sys
import json
import torch
import pickle
import subprocess

sys.path.append('../')

import pandas as pd

from tqdm.notebook import tqdm
from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc
from sklearn.metrics import classification_report
from utils.post_processing import post_process_acc

In [4]:
try:    
    assert torch.cuda.is_available() is True
    
except AssertionError:
    
    print("Please set up a GPU before using LLaMA Factory...")

## Parameters

In [5]:
ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

In [6]:
DATASET_DIR = os.path.join(ROOT_DIR, "datasets")

In [7]:
LLAMA_FACTORY_DIR = os.path.join(ROOT_DIR, "LLaMA-Factory")

In [8]:
BASE_MODEL = "unsloth/llama-3-8b-Instruct-bnb-4bit"

In [9]:
TASK = "acc"

In [10]:
TAGS = 1
TAGS = "wtags" if TAGS == 1 else "wotags"

In [11]:
CONTEXT = "essay" # essay or paragraph

In [12]:
OUTPUT_DIR = os.path.join(ROOT_DIR, "finetuned_models", f"""PE_{TASK}_{CONTEXT}_{TAGS}_{BASE_MODEL.split("/")[1]}""")

## Load Dataset

In [13]:
# *** TRAIN DATASET NAME *** #

train_dataset_name = f"""PE_{TASK}_{CONTEXT}_{TAGS}_train.json"""
test_dataset_name = f"""PE_{TASK}_{CONTEXT}_{TAGS}_test.json"""

#train_dataset_name = f"""PE_{TASK}_{CONTEXT}_train.json"""
train_dataset_file = os.path.join(DATASET_DIR, train_dataset_name)

# *** TEST DATASET NAME *** #

#test_dataset_name = f"""PE_{TASK}_{CONTEXT}_test.json"""
test_dataset_file = os.path.join(DATASET_DIR, test_dataset_name)

In [14]:
train_dataset_file, test_dataset_file

('/Utilisateurs/umushtaq/am_work/coling_2025/datasets/PE_acc_essay_wtags_train.json',
 '/Utilisateurs/umushtaq/am_work/coling_2025/datasets/PE_acc_essay_wtags_test.json')

## Fine-tune Model

In [15]:
if not os.path.exists(os.path.join(ROOT_DIR, "ft_arg_files")):
    os.mkdir(os.path.join(ROOT_DIR, "ft_arg_files"))

In [16]:
# *** TRAIN FILE ***

# model_name = f"""{train_dataset_name.split(".")[0].split("train")[0]}{BASE_MODEL.split("/")[1]}"""

train_file = os.path.join(ROOT_DIR, "ft_arg_files", f"""{train_dataset_name.split(".")[0].split("train")[0]}{BASE_MODEL.split("/")[1]}.json""")

In [17]:
dataset_info_line =  {
  "file_name": f"{train_dataset_file}",
  "columns": {
    "prompt": "instruction",
    "query": "input",
    "response": "output"
  }
}

In [18]:
dataset_info_line

{'file_name': '/Utilisateurs/umushtaq/am_work/coling_2025/datasets/PE_acc_essay_wtags_train.json',
 'columns': {'prompt': 'instruction', 'query': 'input', 'response': 'output'}}

In [19]:
with open(os.path.join(LLAMA_FACTORY_DIR, "data/dataset_info.json"), "r") as jsonFile:
    data = json.load(jsonFile)

data["persuasive_essays"] = dataset_info_line

with open(os.path.join(LLAMA_FACTORY_DIR, "data/dataset_info.json"), "w") as jsonFile:
    json.dump(data, jsonFile)

### Training Args

In [20]:
NB_EPOCHS = 0.2

In [21]:
args = dict(
  stage="sft",                           # do supervised fine-tuning
  do_train=True,
  model_name_or_path=BASE_MODEL,         # use bnb-4bit-quantized Llama-3-8B-Instruct model
  dataset="persuasive_essays",           # use alpaca and identity datasets
  template="llama3",                     # use llama3 prompt template
  finetuning_type="lora",                # use LoRA adapters to save memory
  lora_target="all",                     # attach LoRA adapters to all linear layers
  output_dir=OUTPUT_DIR,                 # the path to save LoRA adapters
  overwrite_output_dir=True,             # overrides existing output contents
  per_device_train_batch_size=2,         # the batch size
  gradient_accumulation_steps=4,         # the gradient accumulation steps
  lr_scheduler_type="cosine",            # use cosine learning rate scheduler
  logging_steps=10,                      # log every 10 steps
  warmup_ratio=0.1,                      # use warmup scheduler
  save_steps=3000,                       # save checkpoint every 1000 steps
  learning_rate=5e-5,                    # the learning rate
  num_train_epochs=NB_EPOCHS,            # the epochs of training
  max_samples=2000,                       # use 500 examples in each dataset
  max_grad_norm=1.0,                     # clip gradient norm to 1.0
  quantization_bit=4,                    # use 4-bit QLoRA
  loraplus_lr_ratio=16.0,                # use LoRA+ algorithm with lambda=16.0
  fp16=True,                             # use float16 mixed precision training
  report_to="none"                       # discards wandb
)

In [22]:
json.dump(args, open(train_file, "w", encoding="utf-8"), indent=2)

In [23]:
train_file

'/Utilisateurs/umushtaq/am_work/coling_2025/ft_arg_files/PE_acc_essay_wtags_llama-3-8b-Instruct-bnb-4bit.json'

In [24]:
p = subprocess.Popen(["llamafactory-cli", "train", train_file], cwd=LLAMA_FACTORY_DIR)

In [25]:
p.wait()

08/30/2024 16:14:54 - INFO - llamafactory.cli - Initializing distributed tasks at: 127.0.0.1:24976


W0830 16:14:55.671000 140156545742144 torch/distributed/run.py:757] 
W0830 16:14:55.671000 140156545742144 torch/distributed/run.py:757] *****************************************
W0830 16:14:55.671000 140156545742144 torch/distributed/run.py:757] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed. 
W0830 16:14:55.671000 140156545742144 torch/distributed/run.py:757] *****************************************


08/30/2024 16:15:04 - INFO - llamafactory.hparams.parser - Process rank: 0, device: cuda:0, n_gpu: 1, distributed training: True, compute dtype: torch.float16
08/30/2024 16:15:04 - INFO - llamafactory.hparams.parser - Process rank: 1, device: cuda:1, n_gpu: 1, distributed training: True, compute dtype: torch.float16


[INFO|tokenization_utils_base.py:2108] 2024-08-30 16:15:04,594 >> loading file tokenizer.json from cache at /Utilisateurs/umushtaq/.cache/huggingface/hub/models--unsloth--llama-3-8b-Instruct-bnb-4bit/snapshots/65e42616f7908d202462119a2749377133801581/tokenizer.json
[INFO|tokenization_utils_base.py:2108] 2024-08-30 16:15:04,595 >> loading file added_tokens.json from cache at None
[INFO|tokenization_utils_base.py:2108] 2024-08-30 16:15:04,595 >> loading file special_tokens_map.json from cache at /Utilisateurs/umushtaq/.cache/huggingface/hub/models--unsloth--llama-3-8b-Instruct-bnb-4bit/snapshots/65e42616f7908d202462119a2749377133801581/special_tokens_map.json
[INFO|tokenization_utils_base.py:2108] 2024-08-30 16:15:04,595 >> loading file tokenizer_config.json from cache at /Utilisateurs/umushtaq/.cache/huggingface/hub/models--unsloth--llama-3-8b-Instruct-bnb-4bit/snapshots/65e42616f7908d202462119a2749377133801581/tokenizer_config.json
Special tokens have been added in the vocabulary, make

08/30/2024 16:15:05 - INFO - llamafactory.data.template - Replace eos token: <|eot_id|>
08/30/2024 16:15:05 - INFO - llamafactory.data.loader - Loading dataset /Utilisateurs/umushtaq/am_work/coling_2025/datasets/PE_acc_essay_wtags_train.json...
08/30/2024 16:15:05 - INFO - llamafactory.data.template - Replace eos token: <|eot_id|>
08/30/2024 16:15:06 - INFO - llamafactory.data.loader - Loading dataset /Utilisateurs/umushtaq/am_work/coling_2025/datasets/PE_acc_essay_wtags_train.json...
training example:
input_ids:
[128000, 128006, 882, 128007, 271, 14711, 1472, 527, 459, 6335, 304, 14138, 26917, 13, 1472, 527, 2728, 459, 9071, 902, 5727, 49926, 5811, 6956, 44910, 555, 366, 1741, 1500, 1741, 29, 9681, 13, 4718, 3465, 374, 311, 49229, 1855, 5811, 6956, 304, 279, 9071, 439, 3060, 330, 35575, 46644, 498, 330, 46644, 1, 477, 330, 42562, 1082, 3343, 1472, 2011, 471, 264, 1160, 315, 5811, 3777, 4595, 304, 2768, 4823, 3645, 25, 5324, 8739, 9962, 794, 510, 8739, 1857, 320, 496, 705, 3777, 1857, 

[INFO|configuration_utils.py:733] 2024-08-30 16:15:06,577 >> loading configuration file config.json from cache at /Utilisateurs/umushtaq/.cache/huggingface/hub/models--unsloth--llama-3-8b-Instruct-bnb-4bit/snapshots/65e42616f7908d202462119a2749377133801581/config.json
[INFO|configuration_utils.py:796] 2024-08-30 16:15:06,578 >> Model config LlamaConfig {
  "_name_or_path": "unsloth/llama-3-8b-Instruct-bnb-4bit",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128009,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pad_token_id": 128255,
  "pretraining_tp": 1,
  "quantization_config": {
    "_load_in_4bit": true,
    "_load_in_8bit": false,
    "bnb_4bit_compute_dtyp

08/30/2024 16:15:06 - INFO - llamafactory.model.model_utils.quantization - Loading ?-bit BITSANDBYTES-quantized model.
08/30/2024 16:15:06 - INFO - llamafactory.model.model_utils.quantization - Loading ?-bit BITSANDBYTES-quantized model.


[INFO|modeling_utils.py:3474] 2024-08-30 16:15:06,650 >> loading weights file model.safetensors from cache at /Utilisateurs/umushtaq/.cache/huggingface/hub/models--unsloth--llama-3-8b-Instruct-bnb-4bit/snapshots/65e42616f7908d202462119a2749377133801581/model.safetensors
Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.
[INFO|modeling_utils.py:1519] 2024-08-30 16:15:06,702 >> Instantiating LlamaForCausalLM model under default dtype torch.float16.
[INFO|configuration_utils.py:962] 2024-08-30 16:15:06,705 >> Generate config GenerationConfig {
  "bos_token_id": 128000,
  "eos_token_id": 128009,
  "pad_token_id": 128255
}



08/30/2024 16:15:11 - INFO - llamafactory.model.model_utils.checkpointing - Gradient checkpointing enabled.
08/30/2024 16:15:11 - INFO - llamafactory.model.model_utils.attention - Using torch SDPA for faster training and inference.
08/30/2024 16:15:11 - INFO - llamafactory.model.adapter - Upcasting trainable params to float32.
08/30/2024 16:15:11 - INFO - llamafactory.model.adapter - Fine-tuning method: LoRA
08/30/2024 16:15:11 - INFO - llamafactory.model.model_utils.misc - Found linear modules: down_proj,k_proj,up_proj,o_proj,q_proj,v_proj,gate_proj


[INFO|modeling_utils.py:4280] 2024-08-30 16:15:12,282 >> All model checkpoint weights were used when initializing LlamaForCausalLM.

[INFO|modeling_utils.py:4288] 2024-08-30 16:15:12,283 >> All the weights of LlamaForCausalLM were initialized from the model checkpoint at unsloth/llama-3-8b-Instruct-bnb-4bit.
If your task is similar to the task the model of the checkpoint was trained on, you can already use LlamaForCausalLM for predictions without further training.
[INFO|configuration_utils.py:917] 2024-08-30 16:15:12,434 >> loading configuration file generation_config.json from cache at /Utilisateurs/umushtaq/.cache/huggingface/hub/models--unsloth--llama-3-8b-Instruct-bnb-4bit/snapshots/65e42616f7908d202462119a2749377133801581/generation_config.json
[INFO|configuration_utils.py:962] 2024-08-30 16:15:12,434 >> Generate config GenerationConfig {
  "bos_token_id": 128000,
  "do_sample": true,
  "eos_token_id": [
    128001,
    128009
  ],
  "max_length": 8192,
  "pad_token_id": 128255,
 

08/30/2024 16:15:12 - INFO - llamafactory.model.loader - trainable params: 20,971,520 || all params: 8,051,232,768 || trainable%: 0.2605
08/30/2024 16:15:12 - INFO - llamafactory.train.trainer_utils - Using LoRA+ optimizer with loraplus lr ratio 16.00.
08/30/2024 16:15:12 - INFO - llamafactory.model.model_utils.checkpointing - Gradient checkpointing enabled.
08/30/2024 16:15:12 - INFO - llamafactory.model.model_utils.attention - Using torch SDPA for faster training and inference.
08/30/2024 16:15:12 - INFO - llamafactory.model.adapter - Upcasting trainable params to float32.
08/30/2024 16:15:12 - INFO - llamafactory.model.adapter - Fine-tuning method: LoRA
08/30/2024 16:15:12 - INFO - llamafactory.model.model_utils.misc - Found linear modules: gate_proj,up_proj,down_proj,q_proj,k_proj,v_proj,o_proj
08/30/2024 16:15:13 - INFO - llamafactory.model.loader - trainable params: 20,971,520 || all params: 8,051,232,768 || trainable%: 0.2605


[INFO|trainer.py:641] 2024-08-30 16:15:13,332 >> Using auto half precision backend


08/30/2024 16:15:13 - INFO - llamafactory.train.trainer_utils - Using LoRA+ optimizer with loraplus lr ratio 16.00.


[INFO|trainer.py:2078] 2024-08-30 16:15:13,845 >> ***** Running training *****
[INFO|trainer.py:2079] 2024-08-30 16:15:13,858 >>   Num examples = 322
[INFO|trainer.py:2080] 2024-08-30 16:15:13,858 >>   Num Epochs = 1
[INFO|trainer.py:2081] 2024-08-30 16:15:13,858 >>   Instantaneous batch size per device = 2
[INFO|trainer.py:2084] 2024-08-30 16:15:13,858 >>   Total train batch size (w. parallel, distributed & accumulation) = 16
[INFO|trainer.py:2085] 2024-08-30 16:15:13,858 >>   Gradient Accumulation steps = 4
[INFO|trainer.py:2086] 2024-08-30 16:15:13,858 >>   Total optimization steps = 4
[INFO|trainer.py:2087] 2024-08-30 16:15:13,861 >>   Number of trainable parameters = 20,971,520
100%|██████████| 4/4 [00:25<00:00,  6.29s/it][INFO|trainer.py:2329] 2024-08-30 16:15:39,689 >> 

Training completed. Do not forget to share your model on huggingface.co/models =)


100%|██████████| 4/4 [00:25<00:00,  6.39s/it]
[INFO|trainer.py:3410] 2024-08-30 16:15:39,719 >> Saving model checkpoint to /Uti

{'train_runtime': 25.828, 'train_samples_per_second': 2.493, 'train_steps_per_second': 0.155, 'train_loss': 0.44080960750579834, 'epoch': 0.2}


[INFO|configuration_utils.py:733] 2024-08-30 16:15:40,024 >> loading configuration file config.json from cache at /Utilisateurs/umushtaq/.cache/huggingface/hub/models--unsloth--llama-3-8b-Instruct-bnb-4bit/snapshots/65e42616f7908d202462119a2749377133801581/config.json
[INFO|configuration_utils.py:796] 2024-08-30 16:15:40,025 >> Model config LlamaConfig {
  "_name_or_path": "unsloth/llama-3-8b-Instruct",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128009,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pad_token_id": 128255,
  "pretraining_tp": 1,
  "quantization_config": {
    "_load_in_4bit": true,
    "_load_in_8bit": false,
    "bnb_4bit_compute_dtype": "bflo

***** train metrics *****
  epoch                    =     0.1975
  total_flos               =  2005818GF
  train_loss               =     0.4408
  train_runtime            = 0:00:25.82
  train_samples_per_second =      2.493
  train_steps_per_second   =      0.155


[INFO|modelcard.py:450] 2024-08-30 16:15:41,339 >> Dropping the following result as it does not have all the necessary fields:
{'task': {'name': 'Causal Language Modeling', 'type': 'text-generation'}}


0

## Inference on the fine-tuned model

In [26]:
OUTPUT_DIR

'/Utilisateurs/umushtaq/am_work/coling_2025/finetuned_models/PE_acc_essay_wtags_llama-3-8b-Instruct-bnb-4bit'

In [27]:
os.listdir(OUTPUT_DIR)

['README.md',
 'adapter_model.safetensors',
 'adapter_config.json',
 'tokenizer_config.json',
 'special_tokens_map.json',
 'tokenizer.json',
 'training_args.bin',
 'train_results.json',
 'all_results.json',
 'trainer_state.json',
 'PE_acc_results_0.2.pickle',
 'trainer_log.jsonl']

In [28]:
args = dict(
  model_name_or_path=BASE_MODEL, # use bnb-4bit-quantized Llama-3-8B-Instruct model
  adapter_name_or_path=OUTPUT_DIR,            # load the saved LoRA adapters
  template="llama3",                     # same to the one in training
  finetuning_type="lora",                  # same to the one in training
  quantization_bit=4,                    # load 4-bit quantized model
)


In [29]:
model = ChatModel(args)

[INFO|tokenization_utils_base.py:2108] 2024-08-30 16:15:46,859 >> loading file tokenizer.json from cache at /Utilisateurs/umushtaq/.cache/huggingface/hub/models--unsloth--llama-3-8b-Instruct-bnb-4bit/snapshots/65e42616f7908d202462119a2749377133801581/tokenizer.json
[INFO|tokenization_utils_base.py:2108] 2024-08-30 16:15:46,861 >> loading file added_tokens.json from cache at None
[INFO|tokenization_utils_base.py:2108] 2024-08-30 16:15:46,863 >> loading file special_tokens_map.json from cache at /Utilisateurs/umushtaq/.cache/huggingface/hub/models--unsloth--llama-3-8b-Instruct-bnb-4bit/snapshots/65e42616f7908d202462119a2749377133801581/special_tokens_map.json
[INFO|tokenization_utils_base.py:2108] 2024-08-30 16:15:46,865 >> loading file tokenizer_config.json from cache at /Utilisateurs/umushtaq/.cache/huggingface/hub/models--unsloth--llama-3-8b-Instruct-bnb-4bit/snapshots/65e42616f7908d202462119a2749377133801581/tokenizer_config.json


08/30/2024 16:15:47 - INFO - llamafactory.data.template - Replace eos token: <|eot_id|>


[INFO|configuration_utils.py:733] 2024-08-30 16:15:47,324 >> loading configuration file config.json from cache at /Utilisateurs/umushtaq/.cache/huggingface/hub/models--unsloth--llama-3-8b-Instruct-bnb-4bit/snapshots/65e42616f7908d202462119a2749377133801581/config.json
[INFO|configuration_utils.py:796] 2024-08-30 16:15:47,326 >> Model config LlamaConfig {
  "_name_or_path": "unsloth/llama-3-8b-Instruct-bnb-4bit",
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128009,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pad_token_id": 128255,
  "pretraining_tp": 1,
  "quantization_config": {
    "_load_in_4bit": true,
    "_load_in_8bit": false,
    "bnb_4bit_compute_dtyp

08/30/2024 16:15:47 - INFO - llamafactory.model.model_utils.quantization - Loading ?-bit BITSANDBYTES-quantized model.
08/30/2024 16:15:47 - INFO - llamafactory.model.patcher - Using KV cache for faster generation.


[INFO|modeling_utils.py:3474] 2024-08-30 16:15:47,391 >> loading weights file model.safetensors from cache at /Utilisateurs/umushtaq/.cache/huggingface/hub/models--unsloth--llama-3-8b-Instruct-bnb-4bit/snapshots/65e42616f7908d202462119a2749377133801581/model.safetensors
[INFO|modeling_utils.py:1519] 2024-08-30 16:15:47,420 >> Instantiating LlamaForCausalLM model under default dtype torch.bfloat16.
[INFO|configuration_utils.py:962] 2024-08-30 16:15:47,424 >> Generate config GenerationConfig {
  "bos_token_id": 128000,
  "eos_token_id": 128009,
  "pad_token_id": 128255
}

[INFO|quantizer_bnb_4bit.py:105] 2024-08-30 16:15:47,812 >> target_dtype {target_dtype} is replaced by `CustomDtype.INT4` for 4-bit BnB quantization
[INFO|modeling_utils.py:4280] 2024-08-30 16:15:49,666 >> All model checkpoint weights were used when initializing LlamaForCausalLM.

[INFO|modeling_utils.py:4288] 2024-08-30 16:15:49,668 >> All the weights of LlamaForCausalLM were initialized from the model checkpoint at un

08/30/2024 16:15:49 - INFO - llamafactory.model.model_utils.attention - Using torch SDPA for faster training and inference.
08/30/2024 16:15:50 - INFO - llamafactory.model.adapter - Loaded adapter(s): /Utilisateurs/umushtaq/am_work/coling_2025/finetuned_models/PE_acc_essay_wtags_llama-3-8b-Instruct-bnb-4bit
08/30/2024 16:15:50 - INFO - llamafactory.model.loader - all params: 8,051,232,768


In [30]:
with open(test_dataset_file, "r+") as fh:
    test_dataset = json.load(fh)

In [31]:
test_prompts = []
test_grounds = []

for sample in test_dataset:
    test_prompts.append("\nUser:" + sample["instruction"] + sample["input"])
    test_grounds.append(sample["output"])

In [32]:
test_predictions = []

for prompt in tqdm(test_prompts):

    messages = []
    messages.append({"role": "user", "content": prompt})

    response = ""
    
    for new_text in model.stream_chat(messages):
        #print(new_text, end="", flush=True)
        response += new_text
        #print()
    test_predictions.append({"role": "assistant", "content": response})

    torch_gc()

  0%|          | 0/80 [00:00<?, ?it/s]

In [33]:
with open(os.path.join(OUTPUT_DIR, f"""PE_{TASK}_results_{NB_EPOCHS}.pickle"""), 'wb') as fh:
    results_d = {"ground_truths": test_grounds,
                 "predictions": test_predictions    
        
    }
    pickle.dump(results_d, fh)

## Post-processing

In [37]:
with open(os.path.join(OUTPUT_DIR, f"""PE_{TASK}_results_{NB_EPOCHS}.pickle"""), "rb") as fh:
        
        results = pickle.load(fh)

In [38]:
task_grounds, task_preds = post_process_acc(results)

In [42]:
# sanity check: 
len(task_preds) == len(task_grounds)

True

## Results

In [43]:
print(classification_report(task_grounds, task_preds, digits=3))

              precision    recall  f1-score   support

       Claim      0.231     0.032     0.056       283
  MajorClaim      0.724     0.136     0.230       154
     Premise      0.661     0.997     0.795       724

    accuracy                          0.648      1161
   macro avg      0.538     0.388     0.360      1161
weighted avg      0.564     0.648     0.540      1161



In [44]:
with open(f"""{OUTPUT_DIR}/classification_report.pickle""", 'wb') as fh:
    
    pickle.dump(classification_report(task_grounds, task_preds, output_dict=True), fh)