# abstRCT dataset post-processing

## Libraries

In [1]:
# Run this cell only once to install LLaMA-Factory

# %cd ..
# %rm -rf LLaMA-Factory
# !git clone https://github.com/hiyouga/LLaMA-Factory.git
# %cd LLaMA-Factory
# %ls
# !pip install -e .[torch,bitsandbytes]

In [2]:
# !pip uninstall -y pydantic
# !pip install pydantic==1.10.9 # 

# !pip uninstall -y gradio
# !pip install gradio==3.48.0

# !pip uninstall -y bitsandbytes
# !pip install --upgrade bitsandbytes

# !pip install tqdm
# !pip install ipywidgets
# !pip install scikit-learn

# Restart kernel afterwards.

In [1]:
import os
import ast
import sys
import json
import torch
import pickle
import subprocess

sys.path.append('../')

import pandas as pd

from tqdm.notebook import tqdm
from llamafactory.chat import ChatModel
from llamafactory.extras.misc import torch_gc
from sklearn.metrics import classification_report
from utils.post_processing import post_process_acc

In [2]:
# Warn (without aborting the notebook) when no CUDA device is visible.
# An explicit condition is used instead of `assert` + `except AssertionError`,
# because assert statements are removed when Python runs with -O.
if not torch.cuda.is_available():
    print("Please set up a GPU before using LLaMA Factory...")

## Parameters

In [3]:
abst_rct_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

In [4]:
abst_rct_dir

'/Utilisateurs/umushtaq/am_work/AMwithLLMs/abstRCT'

In [5]:
BASE_MODEL = "unsloth/gemma-2-9b-it-bnb-4bit"

In [161]:
TASK = "acc"

In [162]:
# Directory under <abst_rct_dir>/finetuned_models_run3 for the current TASK and model;
# the model name is the part of BASE_MODEL after the "/".
OUTPUT_DIR = os.path.join(abst_rct_dir, "finetuned_models_run3", f"""abstRCT_{TASK}_{BASE_MODEL.split("/")[1]}""")

In [163]:
OUTPUT_DIR

'/Utilisateurs/umushtaq/am_work/AMwithLLMs/abstRCT/finetuned_models_run3/abstRCT_acc_gemma-2-9b-it-bnb-4bit'

In [164]:
os.listdir(OUTPUT_DIR)

['README.md',
 'abstRCT_acc_results_5_gla.pickle',
 'abstRCT_acc_results_5_mix.pickle',
 'abstRCT_acc_results_5_neo.pickle',
 'adapter_config.json',
 'all_results.json',
 'checkpoint-110',
 'special_tokens_map.json',
 'tokenizer.json',
 'tokenizer.model',
 'tokenizer_config.json',
 'train_results.json',
 'trainer_log.jsonl',
 'trainer_state.json',
 'training_args.bin']

In [165]:
NB_EPOCHS = 5

In [197]:
test_set = 'mix'

In [198]:
test_set

'mix'

## Post-processing

In [199]:
# Load the pickled results for the current TASK / NB_EPOCHS / test_set from OUTPUT_DIR.
# NOTE: pickle.load can execute arbitrary code; only load files produced by this project.
with open(os.path.join(OUTPUT_DIR, f"""abstRCT_{TASK}_results_{NB_EPOCHS}_{test_set}.pickle"""), "rb") as fh:
        
        results = pickle.load(fh)

In [200]:
results

{'ground_truths': ['{"component_types": ["Premise", "Premise", "Premise", "Premise", "Premise", "Claim"]}',
  '{"component_types": ["Claim", "Premise", "Premise", "Premise", "Claim", "Claim", "Claim", "Claim"]}',
  '{"component_types": ["Premise", "Premise", "Premise", "Premise", "Premise", "Premise", "Premise", "Claim", "Claim"]}',
  '{"component_types": ["Claim", "Claim", "Premise", "Premise", "Premise", "Premise", "Claim", "Claim"]}',
  '{"component_types": ["Premise", "Premise", "Premise", "Premise", "Premise", "Claim", "Claim"]}',
  '{"component_types": ["Premise", "Premise", "Premise", "Claim", "Premise", "Premise", "Claim", "Claim"]}',
  '{"component_types": ["Claim", "Premise", "Premise", "Premise", "Claim", "Claim"]}',
  '{"component_types": ["Claim", "Premise", "Premise", "Premise", "Premise", "Premise", "Claim", "Claim", "Premise", "Claim"]}',
  '{"component_types": ["Premise", "Premise", "Premise", "Premise", "Claim"]}',
  '{"component_types": ["Premise", "Premise", "Premis

In [201]:
def opposite_acc(component_type):
    """Return the other component label: "Claim" <-> "Premise".

    Any other value yields None.
    """
    if component_type == "Claim":
        return "Premise"
    if component_type == "Premise":
        return "Claim"
    return None

def harmonize_preds_acc(grounds, preds):
    """Return a copy of `preds` with exactly as many labels as `grounds`.

    Surplus predictions are cut off. Missing predictions are filled with
    opposite_acc() of the ground-truth label at the same position, so every
    padded position counts as a misclassification.
    """
    n_preds = len(preds)
    n_grounds = len(grounds)

    if n_preds >= n_grounds:
        return preds[:n_grounds]

    padding = []
    for missing_truth in grounds[n_preds:]:
        padding.append(opposite_acc(missing_truth))
    return preds + padding

def post_process_acc(results):
    """Flatten ground-truth and predicted component types into two label lists.

    `results` must provide "ground_truths" (JSON strings) and "predictions"
    (dicts whose "content" is a JSON string); each JSON object carries a
    "component_types" list. When a prediction row and its ground-truth row
    differ in length, the prediction row is replaced by
    harmonize_preds_acc(ground_row, pred_row).

    Returns (task_grounds, task_preds).

    Note: this definition rebinds the `post_process_acc` name imported from
    utils.post_processing at the top of the notebook.
    """
    raw_grounds = results["ground_truths"]
    raw_preds = results["predictions"]

    gold_rows = [json.loads(raw)["component_types"] for raw in raw_grounds]

    pred_contents = [message["content"] for message in raw_preds]
    pred_rows = [json.loads(content)["component_types"] for content in pred_contents]

    # Only rows that exist on both sides are compared.
    for row_idx in range(min(len(gold_rows), len(pred_rows))):
        gold_row = gold_rows[row_idx]
        pred_row = pred_rows[row_idx]
        if len(gold_row) != len(pred_row):
            pred_rows[row_idx] = harmonize_preds_acc(gold_row, pred_row)

    task_preds = []
    for pred_row in pred_rows:
        task_preds.extend(pred_row)

    task_grounds = []
    for gold_row in gold_rows:
        task_grounds.extend(gold_row)

    return task_grounds, task_preds


In [202]:
task_grounds, task_preds = post_process_acc(results)

In [203]:
print(classification_report(task_grounds, task_preds, digits=3))

              precision    recall  f1-score   support

       Claim      0.927     0.962     0.944       212
     Premise      0.979     0.960     0.969       397

    accuracy                          0.961       609
   macro avg      0.953     0.961     0.957       609
weighted avg      0.961     0.961     0.961       609



In [204]:
test_set

'mix'

In [205]:
# Persist the classification report (as a dict) for the current test_set in OUTPUT_DIR.
with open(f"""{OUTPUT_DIR}/classification_report_{test_set}.pickle""", 'wb') as fh:
        
        pickle.dump(classification_report(task_grounds, task_preds, output_dict=True), fh)

In [126]:
test_set = 'mix'

In [127]:
# Load the three pickled test-split DataFrames (neo, gla, mix).
# The paths are relative to the notebook's working directory.
neo_test_df = pd.read_pickle("../datasets/neo_test_df.pkl")
gla_test_df = pd.read_pickle("../datasets/gla_test_df.pkl")
mix_test_df = pd.read_pickle("../datasets/mix_test_df.pkl")

In [128]:
# Length of each `acs_list` entry for the split named by `test_set`.
# The DataFrame is chosen from `test_set` rather than hard-coded, so these counts
# always belong to the same split as the results loaded below; an unknown split
# name fails loudly with a KeyError.
test_dfs = {"neo": neo_test_df, "gla": gla_test_df, "mix": mix_test_df}
nr_acs_l = [len(element) for element in test_dfs[test_set].acs_list]

In [129]:
len(nr_acs_l)

100

In [130]:
import numpy as np

In [131]:
nr_acs_l

[7,
 11,
 10,
 10,
 9,
 11,
 7,
 16,
 7,
 9,
 7,
 12,
 7,
 9,
 9,
 6,
 10,
 8,
 11,
 10,
 7,
 7,
 8,
 7,
 10,
 7,
 9,
 7,
 7,
 7,
 11,
 4,
 7,
 10,
 6,
 7,
 8,
 11,
 6,
 7,
 7,
 8,
 7,
 7,
 5,
 8,
 6,
 5,
 6,
 8,
 7,
 8,
 7,
 7,
 8,
 7,
 8,
 6,
 8,
 6,
 5,
 8,
 8,
 7,
 10,
 6,
 8,
 8,
 6,
 5,
 5,
 6,
 6,
 6,
 5,
 5,
 9,
 4,
 16,
 5,
 8,
 10,
 6,
 8,
 6,
 7,
 6,
 6,
 7,
 3,
 6,
 9,
 5,
 9,
 6,
 8,
 8,
 11,
 10,
 6]

In [132]:
# Load the pickled results for the current TASK / NB_EPOCHS / test_set from OUTPUT_DIR.
# NOTE(review): the following cells read the "list_argument_relation_types" key, while
# TASK is set to "acc" in the Parameters section and the "acc" results shown above hold
# "component_types". Presumably TASK / OUTPUT_DIR were switched to the relation task
# before this section was run -- confirm and set them explicitly here.
with open(os.path.join(OUTPUT_DIR, f"""abstRCT_{TASK}_results_{NB_EPOCHS}_{test_set}.pickle"""), "rb") as fh:
        
        results = pickle.load(fh)

In [133]:
results

{'ground_truths': ['{"list_argument_relation_types": [[1, 6, "support"], [4, 6, "support"]]}',
  '{"list_argument_relation_types": [[2, 5, "support"], [3, 5, "support"], [5, 7, "support"], [6, 8, "support"]]}',
  '{"list_argument_relation_types": [[1, 8, "support"], [2, 8, "support"], [3, 8, "support"], [5, 9, "support"], [6, 8, "support"], [7, 9, "support"]]}',
  '{"list_argument_relation_types": [[3, 8, "support"], [4, 8, "support"], [5, 8, "support"], [6, 8, "support"]]}',
  '{"list_argument_relation_types": [[1, 6, "support"], [2, 6, "support"], [3, 7, "support"], [4, 6, "support"], [5, 6, "support"]]}',
  '{"list_argument_relation_types": [[1, 8, "support"], [2, 8, "support"], [3, 7, "support"], [5, 4, "attack"], [6, 4, "support"]]}',
  '{"list_argument_relation_types": [[2, 6, "support"], [3, 6, "support"]]}',
  '{"list_argument_relation_types": [[3, 7, "support"], [4, 8, "support"], [5, 10, "support"], [6, 10, "support"], [9, 10, "support"]]}',
  '{"list_argument_relation_types"

In [134]:
grounds = results["ground_truths"]

In [135]:
grounds = [json.loads(x)["list_argument_relation_types"] for x in grounds]

In [136]:
preds = results["predictions"]

In [137]:
preds = [x["content"] for x in preds]

In [138]:
# preds[31] = '{"list_argument_relation_types": [[1, 6, "support"], [2, 1, "support"], [3, 6, "support"], [4, 6, "support"], [5, "none", ""]]}'

In [139]:
preds

['{"list_argument_relation_types": [[1, 6, "support"], [2, 6, "support"], [3, 6, "support"], [4, 6, "support"]]}',
 '{"list_argument_relation_types": [[2, 7, "support"], [3, 7, "support"], [4, 7, "support"], [5, 7, "support"], [6, 8, "support"]]}',
 '{"list_argument_relation_types": [[5, 9, "support"], [6, 8, "support"], [7, 8, "support"]]}',
 '{"list_argument_relation_types": [[3, 8, "support"], [4, 8, "support"], [5, 8, "support"], [6, 8, "support"], [7, 8, "support"]]}',
 '{"list_argument_relation_types": [[1, 6, "support"], [2, 6, "support"], [3, 7, "support"], [4, 7, "support"], [5, 6, "support"], [7, 6, "attack"]]}',
 '{"list_argument_relation_types": [[1, 7, "support"], [2, 8, "support"], [3, 8, "support"], [4, 7, "support"], [5, 4, "attack"], [6, 7, "support"]]}',
 '{"list_argument_relation_types": [[2, 6, "support"], [3, 6, "support"]]}',
 '{"list_argument_relation_types": [[3, 7, "support"], [4, 8, "support"], [5, 9, "support"], [6, 10, "support"], [7, 11, "support"], [8, 11,

In [140]:
# Debugging aid (kept commented out): prints the index of every prediction that
# cannot be parsed or lacks the "list_argument_relation_types" key.
# for i in range(len(preds)):
#     try:
#         json.loads(preds[i])["list_argument_relation_types"]
#     except Exception:
#         print(i)


# Parse each prediction string and keep its "list_argument_relation_types" value.
# NOTE: this rebinds `preds` from strings to parsed lists, so re-running this cell
# without first re-running the cell that extracts the "content" strings will fail.
preds = [json.loads(x)["list_argument_relation_types"] for x in preds]

In [141]:
len(grounds), len(preds)

(100, 100)

In [142]:
# Print the index of every example whose ground-truth and predicted lists differ in length.
for i,(x,y) in enumerate(zip(grounds, preds)):
    
    if len(x) != len(y):
            
        print(i)

0
1
2
3
4
5
7
8
9
10
11
12
13
14
17
18
19
21
22
30
33
34
36
37
38
39
40
41
44
47
49
51
53
54
56
58
60
64
65
69
70
71
72
74
75
76
78
80
81
83
85
86
88
89
93
94
97


In [143]:
grounds

[[[1, 6, 'support'], [4, 6, 'support']],
 [[2, 5, 'support'], [3, 5, 'support'], [5, 7, 'support'], [6, 8, 'support']],
 [[1, 8, 'support'],
  [2, 8, 'support'],
  [3, 8, 'support'],
  [5, 9, 'support'],
  [6, 8, 'support'],
  [7, 9, 'support']],
 [[3, 8, 'support'], [4, 8, 'support'], [5, 8, 'support'], [6, 8, 'support']],
 [[1, 6, 'support'],
  [2, 6, 'support'],
  [3, 7, 'support'],
  [4, 6, 'support'],
  [5, 6, 'support']],
 [[1, 8, 'support'],
  [2, 8, 'support'],
  [3, 7, 'support'],
  [5, 4, 'attack'],
  [6, 4, 'support']],
 [[2, 6, 'support'], [3, 6, 'support']],
 [[3, 7, 'support'],
  [4, 8, 'support'],
  [5, 10, 'support'],
  [6, 10, 'support'],
  [9, 10, 'support']],
 [[1, 5, 'support'], [3, 5, 'support']],
 [[1, 5, 'support'], [2, 5, 'support'], [4, 8, 'support'], [8, 7, 'attack']],
 [[1, 6, 'support'], [2, 6, 'support'], [4, 3, 'attack']],
 [[1, 9, 'support'],
  [2, 10, 'support'],
  [3, 10, 'support'],
  [7, 11, 'support'],
  [8, 7, 'attack'],
  [9, 11, 'support'],
  [10,

In [144]:
grounds[7]

[[3, 7, 'support'],
 [4, 8, 'support'],
 [5, 10, 'support'],
 [6, 10, 'support'],
 [9, 10, 'support']]

In [145]:
preds

[[[1, 6, 'support'], [2, 6, 'support'], [3, 6, 'support'], [4, 6, 'support']],
 [[2, 7, 'support'],
  [3, 7, 'support'],
  [4, 7, 'support'],
  [5, 7, 'support'],
  [6, 8, 'support']],
 [[5, 9, 'support'], [6, 8, 'support'], [7, 8, 'support']],
 [[3, 8, 'support'],
  [4, 8, 'support'],
  [5, 8, 'support'],
  [6, 8, 'support'],
  [7, 8, 'support']],
 [[1, 6, 'support'],
  [2, 6, 'support'],
  [3, 7, 'support'],
  [4, 7, 'support'],
  [5, 6, 'support'],
  [7, 6, 'attack']],
 [[1, 7, 'support'],
  [2, 8, 'support'],
  [3, 8, 'support'],
  [4, 7, 'support'],
  [5, 4, 'attack'],
  [6, 7, 'support']],
 [[2, 6, 'support'], [3, 6, 'support']],
 [[3, 7, 'support'],
  [4, 8, 'support'],
  [5, 9, 'support'],
  [6, 10, 'support'],
  [7, 11, 'support'],
  [8, 11, 'support']],
 [[1, 5, 'support'], [2, 5, 'support'], [3, 5, 'support'], [4, 3, 'support']],
 [[1, 7, 'support'],
  [2, 5, 'support'],
  [3, 8, 'support'],
  [4, 8, 'support'],
  [5, 7, 'support'],
  [6, 5, 'attack'],
  [8, 7, 'attack']],
 

In [146]:
# Shallow copies: the outer lists are new, but the inner per-example lists are
# still the same objects as in `grounds` / `preds`.
grounds_tmp = grounds[:]
preds_tmp = preds[:]

In [147]:
def process_lists(l):
    """Remove duplicate items from every inner list, keeping first occurrences in order.

    Membership is tested with `in` (equality), so unhashable items such as
    lists are supported. The input is not modified; the returned inner lists
    are new, but the items themselves are the original objects.
    """
    def _unique_in_order(items):
        seen = []
        for candidate in items:
            if candidate in seen:
                continue
            seen.append(candidate)
        return seen

    return [_unique_in_order(group) for group in l]

In [148]:
# # l = [[1,2,3,3], [1,2,3,5]]
# grounds_arg_new = [] 
# for ll in grounds_arg:
#     ll_tmp = []
#     for item in ll:
#         if item not in ll_tmp:
#             ll_tmp.append(item)
#     # ll = [list(set(x)) for x in ll]
#     grounds_arg_new.append(ll_tmp)
#     #[grounds_arg_new.append(item) for item in l if item not in l_new]

In [149]:
type(1) == int

True

In [150]:
def clean_preds(nr_acs, current_preds_arg):
    """Return the well-formed predicted relations for one example.

    A prediction is kept only when all of the following hold:
      * it is a list/tuple of exactly three elements ``[first, second, label]``;
      * ``first`` and ``second`` are plain ints (bools and floats are rejected);
      * ``first != second``;
      * both lie in ``1..nr_acs``;
      * ``label`` is "support" or "attack".

    Args:
        nr_acs: largest valid index for this example.
        current_preds_arg: list of predicted relations; it is not modified.

    Returns:
        A new list with the valid predictions in their original order.

    The previous implementation raised NameError (undefined `current_pred`)
    whenever a prediction had a non-int endpoint; such predictions are now
    simply dropped, like every other malformed prediction.
    """
    def _is_valid(pred):
        if not isinstance(pred, (list, tuple)) or len(pred) != 3:
            return False
        first, second, label = pred
        # type(...) is int (not isinstance) so that True/False are rejected too.
        if type(first) is not int or type(second) is not int:
            return False
        if first == second:
            return False
        if not (1 <= first <= nr_acs and 1 <= second <= nr_acs):
            return False
        return label == "support" or label == "attack"

    return [pred for pred in current_preds_arg if _is_valid(pred)]

In [151]:
def _first_label_per_pair(relations):
    """Map each (first, second) pair to the label of its first occurrence."""
    labels = {}
    for relation in relations:
        labels.setdefault((relation[0], relation[1]), relation[2])
    return labels


def get_all_relations(nr_acs_l, grounds_arg, preds_arg):
    """Expand sparse relation lists into one label per ordered index pair.

    For example ``idx`` with ``nr_acs`` items, every ordered pair ``(i, j)``
    with ``i != j`` and ``1 <= i, j <= nr_acs`` gets one ``[i, j, label]``
    entry, where ``label`` is the annotated/predicted relation or "None" when
    the pair is absent. Predictions are first filtered with clean_preds().

    Both outputs are generated in the same (i, j) order and always contain
    exactly ``nr_acs * (nr_acs - 1)`` entries per example, so ground truth and
    predictions stay positionally aligned even if a pair occurs twice with
    different labels (the first label wins) or a ground-truth relation refers
    to an out-of-range or self pair (it is ignored).

    Args:
        nr_acs_l: number of items per example.
        grounds_arg: per-example lists of ``[i, j, label]`` ground-truth relations.
        preds_arg: per-example lists of predicted relations.

    Returns:
        (final_grounds, final_preds): two lists with one list of
        ``[i, j, label]`` entries per example, sorted by (i, j).
    """
    grounds_arg = process_lists(grounds_arg)
    preds_arg = process_lists(preds_arg)

    final_grounds = []
    final_preds = []

    for idx, nr_acs in enumerate(nr_acs_l):

        # Dict lookups replace the repeated list-membership scans.
        ground_labels = _first_label_per_pair(grounds_arg[idx])
        pred_labels = _first_label_per_pair(clean_preds(nr_acs, preds_arg[idx]))

        current_grounds = []
        current_preds = []

        for i in range(1, nr_acs + 1):
            for j in range(1, nr_acs + 1):

                if i == j:
                    continue

                pair = (i, j)
                current_grounds.append([i, j, ground_labels.get(pair, "None")])
                current_preds.append([i, j, pred_labels.get(pair, "None")])

        final_grounds.append(current_grounds)
        final_preds.append(current_preds)

    return final_grounds, final_preds

In [152]:
final_grounds, final_preds = get_all_relations(nr_acs_l, grounds_tmp, preds_tmp)

In [153]:
len(final_grounds), len(final_preds)

(100, 100)

In [154]:
# Sanity check: print the index of any example whose expanded ground-truth and
# prediction lists differ in length (no output means they all match).
for i,(x,y) in enumerate(zip(final_grounds, final_preds)):
    
    if len(x) != len(y):
            
        print(i)

In [155]:
final_grounds = [x for xs in final_grounds for x in xs]

In [156]:
final_preds = [x for xs in final_preds for x in xs]

In [157]:
len(final_grounds), len(final_preds)

(5496, 5496)

In [158]:
# Keep only the label (third element) of every flattened entry.
# NOTE: both names are rebound here, so this cell cannot be re-run on its own output.
final_grounds = [x[2] for x in final_grounds]
final_preds = [x[2] for x in final_preds]

In [159]:
print(classification_report(final_grounds, final_preds, digits=3))

              precision    recall  f1-score   support

        None      0.986     0.971     0.979      5176
      attack      0.364     0.500     0.421        24
     support      0.631     0.787     0.701       296

    accuracy                          0.959      5496
   macro avg      0.661     0.753     0.700      5496
weighted avg      0.965     0.959     0.961      5496



In [160]:
# Persist the relation classification report (as a dict) for the current test_set.
# NOTE(review): this uses the same file-name pattern as the report dump in the
# component-classification section above; if OUTPUT_DIR and test_set are the same
# in both cells, one report overwrites the other -- confirm OUTPUT_DIR differs.
with open(f"""{OUTPUT_DIR}/classification_report_{test_set}.pickle""", 'wb') as fh:
        
        pickle.dump(classification_report(final_grounds, final_preds, output_dict=True), fh)

In [48]:
def opposite_acc(component_type):
    """Return a different label for the given five-class component type.

    Mapping: fact -> value, value -> policy, policy -> value,
    testimony -> fact, reference -> policy. Any other value yields None.

    NOTE(review): this redefines the Claim/Premise `opposite_acc` declared
    earlier in the notebook; once this cell runs, the earlier version is shadowed.
    """
    replacements = (
        ("fact", "value"),
        ("value", "policy"),
        ("policy", "value"),
        ("testimony", "fact"),
        ("reference", "policy"),
    )
    for label, replacement in replacements:
        if component_type == label:
            return replacement
    return None


In [49]:
def harmonize_preds_acc(grounds, preds):
    """Return a copy of `preds` trimmed or padded to the length of `grounds`.

    Extra predictions are dropped; each missing prediction is replaced by
    opposite_acc() of the ground-truth label at that position.

    NOTE(review): same logic as the `harmonize_preds_acc` defined earlier in
    the notebook; this cell redefines it.
    """
    n_pred, n_gold = len(preds), len(grounds)

    if n_pred < n_gold:
        return preds + list(map(opposite_acc, grounds[n_pred:]))

    return preds[:n_gold]

In [50]:
# Replace every prediction row whose length differs from its ground-truth row
# with the harmonized (trimmed / padded) version.
# NOTE(review): `grounds` and `preds` are whatever earlier cells left in the kernel;
# the execution counts of this section are lower than those of the cells above and
# the report below shows five-class labels, so this looks like a leftover from a
# different run -- confirm before relying on it.
for i,(x,y) in enumerate(zip(grounds, preds)):
    
    if len(x) != len(y):
            
        preds[i] = harmonize_preds_acc(x, y)

In [51]:
# Flatten the per-example rows into single flat label lists.
task_preds = [item for row in preds for item in row]
task_grounds = [item for row in grounds for item in row]

In [52]:
# sanity check: 
len(task_preds) == len(task_grounds)

True

## Results

In [53]:
print(classification_report(task_grounds, task_preds, digits=3))

              precision    recall  f1-score   support

        fact      0.596     0.750     0.664       132
      policy      0.883     0.889     0.886       153
   reference      1.000     1.000     1.000         1
   testimony      0.922     0.869     0.895       244
       value      0.872     0.835     0.853       496

    accuracy                          0.840      1026
   macro avg      0.855     0.868     0.860      1026
weighted avg      0.850     0.840     0.844      1026



In [54]:
# Persist the classification report (as a dict) to OUTPUT_DIR.
# Unlike the dumps above, this file name does not include test_set.
with open(f"""{OUTPUT_DIR}/classification_report.pickle""", 'wb') as fh:
    
    pickle.dump(classification_report(task_grounds, task_preds, output_dict=True), fh)