In [23]:
import os
import ast
import sys
import json
import torch
import pickle
import inspect
import argparse
import subprocess

# sys.path.append('../')

import pandas as pd
from tqdm import tqdm
from pathlib import Path

from sklearn.metrics import classification_report

In [13]:
# A notebook has no real __file__, so the kernel's working directory is used
# as the notebook location; start the kernel from the notebook's own folder.
current_dir = Path.cwd().as_posix()
# The project root is taken to be the directory one level above.
joint_dir = Path(current_dir).parent.absolute().as_posix()
# parent_dir = Path(cdcp_dir).parent.absolute().as_posix()

In [14]:
# Echo the resolved notebook directory as a sanity check.
current_dir

'/nfs/scratch/umushtaq/coling_2025/joint/notebooks'

In [15]:
# Echo the resolved parent directory of current_dir.
joint_dir

'/nfs/scratch/umushtaq/coling_2025/joint'

In [16]:
# Model identifier; the part after "/" becomes part of the results directory name.
BASE_MODEL = "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit"
# Task tag embedded in the results directory name and in every result file name.
TASK = "mega"

In [33]:
# Results directory: <joint_dir>/finetuned_models/<TASK>_<model name after "/">.
OUTPUT_DIR = os.path.join(joint_dir, "finetuned_models", f"""{TASK}_{BASE_MODEL.split("/")[1]}""")
# Epoch count that appears in the names of the result pickles loaded below.
NB_EPOCHS = 5
# Test-set tag; only the AbstRCT result file name includes it.
test_set = "mix"

In [34]:
# Echo the directory the result pickles are read from.
OUTPUT_DIR

'/nfs/scratch/umushtaq/coling_2025/joint/finetuned_models/mega_Meta-Llama-3.1-8B-Instruct-bnb-4bit'

### AbstRCT Dataset post-processing 

In [35]:
# Load the AbstRCT inference results.
# NOTE: pickle.load can execute arbitrary code -- only open files produced by
# our own runs.
abstrct_results_path = os.path.join(
    OUTPUT_DIR, f"abstRCT_{TASK}_results_{NB_EPOCHS}_{test_set}.pickle"
)
with open(abstrct_results_path, "rb") as results_file:
    results = pickle.load(results_file)

In [None]:
# Inspect the loaded object.
# NOTE(review): this echoes the whole structure; a short summary may be easier to read.
results

In [37]:
def opposite_acc(component_type):
    """Return the other AbstRCT component label.

    "Claim" maps to "Premise" and "Premise" maps to "Claim". Any other value
    yields None.
    """
    swapped_label = {
        "Claim": "Premise",
        "Premise": "Claim",
    }
    return swapped_label.get(component_type)

def harmonize_preds_acc(grounds, preds):
    """Make `preds` exactly as long as `grounds`.

    When there are fewer predictions than gold labels, every missing position
    is filled with opposite_acc(gold label), so it never matches the gold
    label. Otherwise the predictions are cut down to len(grounds).

    opposite_acc is looked up when this function runs; this notebook defines
    it once per dataset section, so the most recently executed definition is
    the one applied.

    Returns a new sequence; neither argument is modified.
    """
    n_pred, n_gold = len(preds), len(grounds)

    # Enough (or too many) predictions: keep only the first n_gold.
    if n_pred >= n_gold:
        return preds[:n_gold]

    # Too few predictions: pad with a deliberately wrong label per missing slot.
    padding = [opposite_acc(gold_label) for gold_label in grounds[n_pred:]]
    return preds + padding

def post_process_acc(results):
    """Convert raw inference results into flat gold / predicted label lists.

    `results["ground_truths"]` holds JSON strings and `results["predictions"]`
    holds mappings whose "content" entry is a JSON string; each JSON document
    must contain a "component_types" entry. Whenever a sample's predicted list
    differs in length from its gold list, it is replaced by
    harmonize_preds_acc(gold, predicted).

    Returns:
        (task_grounds, task_preds): the per-sample label lists concatenated
        in sample order.
    """
    raw_grounds, raw_preds = results["ground_truths"], results["predictions"]

    gold_rows = [json.loads(raw)["component_types"] for raw in raw_grounds]

    pred_texts = [message["content"] for message in raw_preds]
    pred_rows = [json.loads(text)["component_types"] for text in pred_texts]

    # Only samples present in both lists are length-checked (zip stops at the
    # shorter one); mismatching rows are replaced in place.
    for idx, (gold_row, pred_row) in enumerate(zip(gold_rows, pred_rows)):
        if len(gold_row) == len(pred_row):
            continue
        pred_rows[idx] = harmonize_preds_acc(gold_row, pred_row)

    task_preds = []
    for row in pred_rows:
        task_preds.extend(row)

    task_grounds = []
    for row in gold_rows:
        task_grounds.extend(row)

    return task_grounds, task_preds

In [38]:
# Flatten the currently loaded `results` into gold / predicted label lists.
task_grounds, task_preds = post_process_acc(results)

In [39]:
# Per-class precision / recall / F1 report, printed with three decimals.
print(classification_report(task_grounds, task_preds, digits=3))

              precision    recall  f1-score   support

       Claim      0.934     0.939     0.936       212
     Premise      0.967     0.965     0.966       397

    accuracy                          0.956       609
   macro avg      0.951     0.952     0.951       609
weighted avg      0.956     0.956     0.956       609



### CDCP Dataset post-processing

In [40]:
# Load the CDCP inference results (this rebinds `results`).
# NOTE: pickle.load can execute arbitrary code -- only open files produced by
# our own runs.
cdcp_results_path = os.path.join(
    OUTPUT_DIR, f"CDCP_{TASK}_results_{NB_EPOCHS}.pickle"
)
with open(cdcp_results_path, "rb") as results_file:
    results = pickle.load(results_file)

In [42]:
def opposite_acc(component_type):
    """Return a different CDCP component label for the given one.

    No label maps to itself, so the result never equals the input. Values
    outside the five known labels yield None.
    """
    replacement_label = {
        "fact": "value",
        "value": "policy",
        "policy": "value",
        "testimony": "fact",
        "reference": "policy",
    }
    return replacement_label.get(component_type)

def harmonize_preds_acc(grounds, preds):
    """Make `preds` exactly as long as `grounds`.

    When there are fewer predictions than gold labels, every missing position
    is filled with opposite_acc(gold label), so it never matches the gold
    label. Otherwise the predictions are cut down to len(grounds).

    opposite_acc is looked up when this function runs; this notebook defines
    it once per dataset section, so the most recently executed definition is
    the one applied.

    Returns a new sequence; neither argument is modified.
    """
    n_pred, n_gold = len(preds), len(grounds)

    # Enough (or too many) predictions: keep only the first n_gold.
    if n_pred >= n_gold:
        return preds[:n_gold]

    # Too few predictions: pad with a deliberately wrong label per missing slot.
    padding = [opposite_acc(gold_label) for gold_label in grounds[n_pred:]]
    return preds + padding

def post_process_acc(results):
    """Convert raw inference results into flat gold / predicted label lists.

    `results["ground_truths"]` holds JSON strings and `results["predictions"]`
    holds mappings whose "content" entry is a JSON string; each JSON document
    must contain a "component_types" entry. Whenever a sample's predicted list
    differs in length from its gold list, it is replaced by
    harmonize_preds_acc(gold, predicted).

    Returns:
        (task_grounds, task_preds): the per-sample label lists concatenated
        in sample order.
    """
    raw_grounds, raw_preds = results["ground_truths"], results["predictions"]

    gold_rows = [json.loads(raw)["component_types"] for raw in raw_grounds]

    pred_texts = [message["content"] for message in raw_preds]
    pred_rows = [json.loads(text)["component_types"] for text in pred_texts]

    # Only samples present in both lists are length-checked (zip stops at the
    # shorter one); mismatching rows are replaced in place.
    for idx, (gold_row, pred_row) in enumerate(zip(gold_rows, pred_rows)):
        if len(gold_row) == len(pred_row):
            continue
        pred_rows[idx] = harmonize_preds_acc(gold_row, pred_row)

    task_preds = []
    for row in pred_rows:
        task_preds.extend(row)

    task_grounds = []
    for row in gold_rows:
        task_grounds.extend(row)

    return task_grounds, task_preds

In [43]:
# Flatten the currently loaded `results` into gold / predicted label lists.
task_grounds, task_preds = post_process_acc(results)

In [44]:
# Per-class precision / recall / F1 report, printed with three decimals.
print(classification_report(task_grounds, task_preds, digits=3))

              precision    recall  f1-score   support

        fact      0.615     0.750     0.676       132
      policy      0.890     0.902     0.896       153
   reference      0.500     1.000     0.667         1
   testimony      0.932     0.836     0.881       244
       value      0.859     0.847     0.853       496

    accuracy                          0.840      1026
   macro avg      0.759     0.867     0.795      1026
weighted avg      0.849     0.840     0.843      1026



### PE Dataset post-processing

In [46]:
# Load the PE inference results (this rebinds `results`).
# NOTE: pickle.load can execute arbitrary code -- only open files produced by
# our own runs.
pe_results_path = os.path.join(
    OUTPUT_DIR, f"PE_{TASK}_results_{NB_EPOCHS}.pickle"
)
with open(pe_results_path, "rb") as results_file:
    results = pickle.load(results_file)

In [47]:
def opposite_acc(component_type):
    """Return a different PE component label for the given one.

    "Premise" and "MajorClaim" both map to "Claim", while "Claim" maps to
    "Premise". Any other value yields None.
    """
    replacement_label = {
        "Premise": "Claim",
        "Claim": "Premise",
        "MajorClaim": "Claim",
    }
    return replacement_label.get(component_type)

def harmonize_preds_acc(grounds, preds):
    """Make `preds` exactly as long as `grounds`.

    When there are fewer predictions than gold labels, every missing position
    is filled with opposite_acc(gold label), so it never matches the gold
    label. Otherwise the predictions are cut down to len(grounds).

    opposite_acc is looked up when this function runs; this notebook defines
    it once per dataset section, so the most recently executed definition is
    the one applied.

    Returns a new sequence; neither argument is modified.
    """
    n_pred, n_gold = len(preds), len(grounds)

    # Enough (or too many) predictions: keep only the first n_gold.
    if n_pred >= n_gold:
        return preds[:n_gold]

    # Too few predictions: pad with a deliberately wrong label per missing slot.
    padding = [opposite_acc(gold_label) for gold_label in grounds[n_pred:]]
    return preds + padding

def post_process_acc(results):
    """Convert raw inference results into flat gold / predicted label lists.

    `results["ground_truths"]` holds JSON strings and `results["predictions"]`
    holds mappings whose "content" entry is a JSON string; each JSON document
    must contain a "component_types" entry. Whenever a sample's predicted list
    differs in length from its gold list, it is replaced by
    harmonize_preds_acc(gold, predicted).

    Returns:
        (task_grounds, task_preds): the per-sample label lists concatenated
        in sample order.
    """
    raw_grounds, raw_preds = results["ground_truths"], results["predictions"]

    gold_rows = [json.loads(raw)["component_types"] for raw in raw_grounds]

    pred_texts = [message["content"] for message in raw_preds]
    pred_rows = [json.loads(text)["component_types"] for text in pred_texts]

    # Only samples present in both lists are length-checked (zip stops at the
    # shorter one); mismatching rows are replaced in place.
    for idx, (gold_row, pred_row) in enumerate(zip(gold_rows, pred_rows)):
        if len(gold_row) == len(pred_row):
            continue
        pred_rows[idx] = harmonize_preds_acc(gold_row, pred_row)

    task_preds = []
    for row in pred_rows:
        task_preds.extend(row)

    task_grounds = []
    for row in gold_rows:
        task_grounds.extend(row)

    return task_grounds, task_preds

In [48]:
# Flatten the currently loaded `results` into gold / predicted label lists.
task_grounds, task_preds = post_process_acc(results)

In [49]:
# Per-class precision / recall / F1 report, printed with three decimals.
print(classification_report(task_grounds, task_preds, digits=3))

              precision    recall  f1-score   support

       Claim      0.771     0.795     0.783       283
  MajorClaim      0.986     0.948     0.967       154
     Premise      0.922     0.919     0.920       724

    accuracy                          0.892      1161
   macro avg      0.893     0.887     0.890      1161
weighted avg      0.894     0.892     0.893      1161

