In [1]:
import re
import json
import pickle

from sklearn.metrics import classification_report

In [15]:
with open("/Utilisateurs/umushtaq/am_reasoning/saved_models/pe_pipeline_prompt3_Meta-Llama-3.1-70B-Instruct-bnb-4bit/pe_pipeline_results_5.pickle", 'rb') as f:
    
    results = pickle.load(f)

In [16]:
grounds = results['grounds']
predictions = results['predictions']

### Process Grounds

In [17]:
def process_grounds(grounds):
    
    arg_types = []
    #rel_pairs = []
    rel_types = []
    
    for grounds in grounds:
        
        args = grounds.split("\n")[1].replace("'", '"')
        arg_types.append(json.loads(args)["argument_types"])
        
        rels = re.sub(r'\((\d+), (\d+)\)', r'[\1, \2]', grounds.split("\n")[2])
        rels = eval(rels)["relation_types"]
        rels = [(tuple(pair), label) for pair, label in rels]
        rel_types.append(rels)
        
    return arg_types, rel_types

In [18]:
grounds_acc, grounds_aric  = process_grounds(grounds)

In [19]:
len(grounds_acc), len(grounds_aric)

(80, 80)

### Process Predictions

In [20]:
def process_predictions(predictions):
    
    arg_types = []
    rel_types = []
    
    for pred in predictions:
        
        args = pred.split("\n")[1].replace("'", '"')
        arg_types.append(json.loads(args)["argument_types"])
        
        rels = re.sub(r'\((\d+), (\d+)\)', r'[\1, \2]', pred.split("\n")[2])
        rels = eval(rels)["relation_types"]
        rels = [(tuple(pair), label) for pair, label in rels]
        rel_types.append(rels)
        
    return arg_types, rel_types

In [21]:
predictions_acc, predictions_aric  = process_predictions(predictions)

In [22]:
len(predictions_acc), len(predictions_aric)

(80, 80)

### Compute ACC metrics

In [23]:
bad_idx = []

for idx, (g, p) in enumerate(zip(grounds_acc, predictions_acc)):
    
    if len(p) != len(g):
        print(idx)
        bad_idx.append(idx)
        #predictions_acc[idx] = p[:len(g)]

In [24]:
grounds_acc = [elem for idx, elem in enumerate(grounds_acc) if idx not in bad_idx]
predictions_acc = [elem for idx, elem in enumerate(predictions_acc) if idx not in bad_idx]

In [25]:
grounds = [elem for sublist in grounds_acc for elem in sublist]
predictions = [elem for sublist in predictions_acc for elem in sublist]

In [26]:
len(grounds), len(predictions)

(1266, 1266)

In [27]:
print(classification_report(grounds, predictions, digits=3))

              precision    recall  f1-score   support

           C      0.852     0.832     0.842       304
           M      0.980     0.980     0.980       153
           P      0.941     0.949     0.945       809

    accuracy                          0.925      1266
   macro avg      0.924     0.921     0.923      1266
weighted avg      0.924     0.925     0.925      1266



### Compute ARIC metrics

In [15]:
len(grounds_aric), len(predictions_aric)

(80, 80)

In [16]:
grounds_aric[70]

[((3, 9), 'S'),
 ((4, 5), 'S'),
 ((5, 9), 'S'),
 ((6, 9), 'S'),
 ((7, 6), 'A'),
 ((8, 7), 'A'),
 ((11, 10), 'S'),
 ((12, 11), 'A'),
 ((13, 12), 'A'),
 ((14, 10), 'S'),
 ((16, 17), 'S'),
 ((17, 10), 'S'),
 ((15, 10), 'S')]

In [17]:
predictions_aric[0]

[((8, 4), 'S'),
 ((7, 4), 'S'),
 ((6, 4), 'S'),
 ((10, 9), 'S'),
 ((9, 5), 'S'),
 ((11, 5), 'S')]

In [18]:
nr_acs = [len(elem) for elem in grounds_acc]

In [19]:
def build_grounds(grounds, n_acs):
    
    ground_pairs = [elem[0] for elem in grounds]
    all_triples = [(i, j, "NR") for i in range(1, n_acs+1) for j in range(1, n_acs+1) if (i != j and (i, j) not in ground_pairs)]
    
    grounds_t = [(elem[0][0], elem[0][1], elem[1]) for elem in grounds]
    
    return all_triples + grounds_t

In [20]:
grounds_triples = []

for ground, n_acs in zip(grounds_aric, nr_acs):
    
    grounds_triples.append(build_grounds(ground, n_acs))

In [21]:
len(grounds_triples)

80

In [22]:
grounds_triples[60]

[(1, 2, 'NR'),
 (1, 3, 'NR'),
 (1, 4, 'NR'),
 (1, 5, 'NR'),
 (1, 6, 'NR'),
 (1, 7, 'NR'),
 (1, 8, 'NR'),
 (1, 9, 'NR'),
 (1, 10, 'NR'),
 (1, 11, 'NR'),
 (1, 12, 'NR'),
 (1, 13, 'NR'),
 (1, 14, 'NR'),
 (1, 15, 'NR'),
 (1, 16, 'NR'),
 (1, 17, 'NR'),
 (1, 18, 'NR'),
 (1, 19, 'NR'),
 (1, 20, 'NR'),
 (1, 21, 'NR'),
 (1, 22, 'NR'),
 (1, 23, 'NR'),
 (1, 24, 'NR'),
 (2, 1, 'NR'),
 (2, 3, 'NR'),
 (2, 4, 'NR'),
 (2, 5, 'NR'),
 (2, 6, 'NR'),
 (2, 7, 'NR'),
 (2, 8, 'NR'),
 (2, 9, 'NR'),
 (2, 10, 'NR'),
 (2, 11, 'NR'),
 (2, 12, 'NR'),
 (2, 13, 'NR'),
 (2, 14, 'NR'),
 (2, 15, 'NR'),
 (2, 16, 'NR'),
 (2, 17, 'NR'),
 (2, 18, 'NR'),
 (2, 19, 'NR'),
 (2, 20, 'NR'),
 (2, 21, 'NR'),
 (2, 22, 'NR'),
 (2, 23, 'NR'),
 (2, 24, 'NR'),
 (3, 1, 'NR'),
 (3, 2, 'NR'),
 (3, 4, 'NR'),
 (3, 5, 'NR'),
 (3, 6, 'NR'),
 (3, 7, 'NR'),
 (3, 8, 'NR'),
 (3, 9, 'NR'),
 (3, 10, 'NR'),
 (3, 11, 'NR'),
 (3, 12, 'NR'),
 (3, 13, 'NR'),
 (3, 14, 'NR'),
 (3, 15, 'NR'),
 (3, 16, 'NR'),
 (3, 17, 'NR'),
 (3, 18, 'NR'),
 (3, 19, 'NR'),


In [23]:
def build_predictions(ground, pred, n_acs):
    
    predictions = []
    
    for p in pred:
        
        if p in ground:
            
            predictions.append(p)
    
    prediction_pairs = [elem[0] for elem in predictions]
    #all_triples = [(i, j, "NR") for i in range(1, n_acs+1) for j in range(1, n_acs+1) if (i != j or (i, j) not in prediction_pairs)]
    all_triples = [(i, j, "NR") for i in range(1, n_acs+1) for j in range(1, n_acs+1) if i != j and (i, j) not in prediction_pairs]

    
    preds_t = [(elem[0][0], elem[0][1], elem[1]) for elem in predictions]
    #print(preds_t)
    
    return all_triples + preds_t

In [24]:
prediction_triples = []

for ground, pred, n_acs in zip(grounds_aric, predictions_aric, nr_acs):
    
    prediction_triples.append(build_predictions(ground, pred, n_acs))
    #break

In [25]:
prediction_triples[0]

[(1, 2, 'NR'),
 (1, 3, 'NR'),
 (1, 4, 'NR'),
 (1, 5, 'NR'),
 (1, 6, 'NR'),
 (1, 7, 'NR'),
 (1, 8, 'NR'),
 (1, 9, 'NR'),
 (1, 10, 'NR'),
 (1, 11, 'NR'),
 (2, 1, 'NR'),
 (2, 3, 'NR'),
 (2, 4, 'NR'),
 (2, 5, 'NR'),
 (2, 6, 'NR'),
 (2, 7, 'NR'),
 (2, 8, 'NR'),
 (2, 9, 'NR'),
 (2, 10, 'NR'),
 (2, 11, 'NR'),
 (3, 1, 'NR'),
 (3, 2, 'NR'),
 (3, 4, 'NR'),
 (3, 5, 'NR'),
 (3, 6, 'NR'),
 (3, 7, 'NR'),
 (3, 8, 'NR'),
 (3, 9, 'NR'),
 (3, 10, 'NR'),
 (3, 11, 'NR'),
 (4, 1, 'NR'),
 (4, 2, 'NR'),
 (4, 3, 'NR'),
 (4, 5, 'NR'),
 (4, 6, 'NR'),
 (4, 7, 'NR'),
 (4, 8, 'NR'),
 (4, 9, 'NR'),
 (4, 10, 'NR'),
 (4, 11, 'NR'),
 (5, 1, 'NR'),
 (5, 2, 'NR'),
 (5, 3, 'NR'),
 (5, 4, 'NR'),
 (5, 6, 'NR'),
 (5, 7, 'NR'),
 (5, 8, 'NR'),
 (5, 9, 'NR'),
 (5, 10, 'NR'),
 (5, 11, 'NR'),
 (6, 1, 'NR'),
 (6, 2, 'NR'),
 (6, 3, 'NR'),
 (6, 5, 'NR'),
 (6, 7, 'NR'),
 (6, 8, 'NR'),
 (6, 9, 'NR'),
 (6, 10, 'NR'),
 (6, 11, 'NR'),
 (7, 1, 'NR'),
 (7, 2, 'NR'),
 (7, 3, 'NR'),
 (7, 5, 'NR'),
 (7, 6, 'NR'),
 (7, 8, 'NR'),
 (7, 9, 'NR')

In [26]:
bad_idx = []

for idx, (g, p) in enumerate(zip(grounds_triples, prediction_triples)):
    
    if len(p) != len(g):
        print(idx, len(g), len(p))
        bad_idx.append(idx)

60 553 552
70 242 240


In [27]:
nr_acs[70]

16

In [28]:
grounds_triples[70]

[(1, 2, 'NR'),
 (1, 3, 'NR'),
 (1, 4, 'NR'),
 (1, 5, 'NR'),
 (1, 6, 'NR'),
 (1, 7, 'NR'),
 (1, 8, 'NR'),
 (1, 9, 'NR'),
 (1, 10, 'NR'),
 (1, 11, 'NR'),
 (1, 12, 'NR'),
 (1, 13, 'NR'),
 (1, 14, 'NR'),
 (1, 15, 'NR'),
 (1, 16, 'NR'),
 (2, 1, 'NR'),
 (2, 3, 'NR'),
 (2, 4, 'NR'),
 (2, 5, 'NR'),
 (2, 6, 'NR'),
 (2, 7, 'NR'),
 (2, 8, 'NR'),
 (2, 9, 'NR'),
 (2, 10, 'NR'),
 (2, 11, 'NR'),
 (2, 12, 'NR'),
 (2, 13, 'NR'),
 (2, 14, 'NR'),
 (2, 15, 'NR'),
 (2, 16, 'NR'),
 (3, 1, 'NR'),
 (3, 2, 'NR'),
 (3, 4, 'NR'),
 (3, 5, 'NR'),
 (3, 6, 'NR'),
 (3, 7, 'NR'),
 (3, 8, 'NR'),
 (3, 10, 'NR'),
 (3, 11, 'NR'),
 (3, 12, 'NR'),
 (3, 13, 'NR'),
 (3, 14, 'NR'),
 (3, 15, 'NR'),
 (3, 16, 'NR'),
 (4, 1, 'NR'),
 (4, 2, 'NR'),
 (4, 3, 'NR'),
 (4, 6, 'NR'),
 (4, 7, 'NR'),
 (4, 8, 'NR'),
 (4, 9, 'NR'),
 (4, 10, 'NR'),
 (4, 11, 'NR'),
 (4, 12, 'NR'),
 (4, 13, 'NR'),
 (4, 14, 'NR'),
 (4, 15, 'NR'),
 (4, 16, 'NR'),
 (5, 1, 'NR'),
 (5, 2, 'NR'),
 (5, 3, 'NR'),
 (5, 4, 'NR'),
 (5, 6, 'NR'),
 (5, 7, 'NR'),
 (5, 8, 'NR'

In [29]:
prediction_triples[70]

[(1, 2, 'NR'),
 (1, 3, 'NR'),
 (1, 4, 'NR'),
 (1, 5, 'NR'),
 (1, 6, 'NR'),
 (1, 7, 'NR'),
 (1, 8, 'NR'),
 (1, 9, 'NR'),
 (1, 10, 'NR'),
 (1, 11, 'NR'),
 (1, 12, 'NR'),
 (1, 13, 'NR'),
 (1, 14, 'NR'),
 (1, 15, 'NR'),
 (1, 16, 'NR'),
 (2, 1, 'NR'),
 (2, 3, 'NR'),
 (2, 4, 'NR'),
 (2, 5, 'NR'),
 (2, 6, 'NR'),
 (2, 7, 'NR'),
 (2, 8, 'NR'),
 (2, 9, 'NR'),
 (2, 10, 'NR'),
 (2, 11, 'NR'),
 (2, 12, 'NR'),
 (2, 13, 'NR'),
 (2, 14, 'NR'),
 (2, 15, 'NR'),
 (2, 16, 'NR'),
 (3, 1, 'NR'),
 (3, 2, 'NR'),
 (3, 4, 'NR'),
 (3, 5, 'NR'),
 (3, 6, 'NR'),
 (3, 7, 'NR'),
 (3, 8, 'NR'),
 (3, 9, 'NR'),
 (3, 10, 'NR'),
 (3, 11, 'NR'),
 (3, 12, 'NR'),
 (3, 13, 'NR'),
 (3, 14, 'NR'),
 (3, 15, 'NR'),
 (3, 16, 'NR'),
 (4, 1, 'NR'),
 (4, 2, 'NR'),
 (4, 3, 'NR'),
 (4, 5, 'NR'),
 (4, 6, 'NR'),
 (4, 7, 'NR'),
 (4, 8, 'NR'),
 (4, 9, 'NR'),
 (4, 10, 'NR'),
 (4, 11, 'NR'),
 (4, 12, 'NR'),
 (4, 13, 'NR'),
 (4, 14, 'NR'),
 (4, 15, 'NR'),
 (4, 16, 'NR'),
 (5, 1, 'NR'),
 (5, 2, 'NR'),
 (5, 3, 'NR'),
 (5, 4, 'NR'),
 (5, 6, 'NR'

In [30]:
bad_idx

[60, 70]

In [31]:
grounds_triples = [val for i, val in enumerate(grounds_triples) if i not in bad_idx]
prediction_triples = [val for i, val in enumerate(prediction_triples) if i not in bad_idx]

In [32]:
grounds_l = [elem for sublist in grounds_triples for elem in sublist]
predictions_l = [elem for sublist in prediction_triples for elem in sublist]

In [33]:
len(grounds_l), len(predictions_l)

(19400, 19400)

In [34]:
grounds = [elem[2] for elem in grounds_l]
predictions = [elem[2] for elem in predictions_l]

In [35]:
print(classification_report(grounds, predictions, digits=3))

              precision    recall  f1-score   support

           A      0.429     0.237     0.305        38
          NR      0.984     1.000     0.992     18620
           S      0.966     0.604     0.743       742

    accuracy                          0.983     19400
   macro avg      0.793     0.614     0.680     19400
weighted avg      0.983     0.983     0.981     19400



### Compute ARC metrics

In [None]:
grounds_ari

[['S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'A', 'A', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'A', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'A', 'S', 'S'],
 ['S', 'S', 'S', 'A'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'A', 'A', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'A'],
 ['S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'A'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'A', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S

In [106]:
grounds_ari

[[(8, 4), (7, 4), (6, 4), (10, 11), (9, 5), (11, 5)],
 [(6, 3), (7, 4), (8, 7), (9, 4), (10, 5), (11, 5)],
 [(7, 4),
  (8, 9),
  (10, 9),
  (9, 4),
  (11, 5),
  (12, 5),
  (13, 6),
  (14, 6),
  (15, 6),
  (16, 6),
  (17, 16),
  (18, 16),
  (19, 16)],
 [(5, 3),
  (4, 3),
  (9, 6),
  (11, 13),
  (12, 13),
  (10, 6),
  (13, 6),
  (14, 7),
  (15, 7),
  (18, 8),
  (17, 16),
  (16, 8)],
 [(6, 4), (7, 4), (8, 7), (9, 7), (10, 5), (11, 5)],
 [(5, 2),
  (6, 2),
  (7, 2),
  (8, 2),
  (9, 2),
  (10, 2),
  (11, 4),
  (4, 3),
  (12, 13),
  (13, 3)],
 [(11, 10),
  (12, 10),
  (7, 4),
  (8, 4),
  (9, 4),
  (10, 4),
  (13, 5),
  (14, 5),
  (15, 5),
  (16, 5),
  (17, 6),
  (18, 6),
  (19, 6),
  (20, 6),
  (21, 6)],
 [(6, 4), (8, 9), (7, 9), (9, 5), (10, 5), (11, 5)],
 [(6, 3), (7, 4), (8, 4), (9, 5)],
 [(5, 4), (6, 4), (7, 4), (8, 4), (11, 9), (12, 13), (9, 13)],
 [(8, 3),
  (9, 3),
  (10, 3),
  (11, 3),
  (12, 13),
  (13, 4),
  (14, 4),
  (6, 5),
  (7, 5)],
 [(4, 3), (6, 7), (7, 5), (9, 10)],
 [(9, 3)

In [107]:
aric_grounds = []
    
for rel_pairs, rel_types in zip(grounds_ari, grounds_arc):
    
    result = list(zip(rel_pairs, rel_types))
    aric_grounds.append(result)

In [108]:
len(aric_grounds)

80

In [109]:
aric_grounds[0]

[((8, 4), 'S'),
 ((7, 4), 'S'),
 ((6, 4), 'S'),
 ((10, 11), 'S'),
 ((9, 5), 'S'),
 ((11, 5), 'S')]

In [110]:
aric_predictions = []
    
for rel_pairs, rel_types in zip(predictions_ari, predictions_arc):
        
    result = list(zip(rel_pairs, rel_types))
    aric_predictions.append(result)

In [111]:
len(aric_predictions)

80

In [112]:
aric_predictions[0]

[((7, 4), 'S'), ((8, 4), 'S'), ((10, 5), 'S'), ((11, 5), 'S'), ((6, 4), 'S')]

In [172]:
len(nr_acs)

80

In [175]:
len(aric_grounds), len(aric_predictions)

(80, 80)

In [176]:
aric_grounds[0]

[((8, 4), 'S'),
 ((7, 4), 'S'),
 ((6, 4), 'S'),
 ((10, 11), 'S'),
 ((9, 5), 'S'),
 ((11, 5), 'S')]

In [177]:
aric_predictions[0]

[((7, 4), 'S'), ((8, 4), 'S'), ((10, 5), 'S'), ((11, 5), 'S'), ((6, 4), 'S')]

In [178]:
g_0 = set(aric_grounds[0])

In [179]:
g_0

{((6, 4), 'S'),
 ((7, 4), 'S'),
 ((8, 4), 'S'),
 ((9, 5), 'S'),
 ((10, 11), 'S'),
 ((11, 5), 'S')}

In [180]:
p_0 = set(aric_predictions[0])

In [181]:
p_0

{((6, 4), 'S'), ((7, 4), 'S'), ((8, 4), 'S'), ((10, 5), 'S'), ((11, 5), 'S')}

In [182]:
z = g_0.intersection(p_0)

In [183]:
z

{((6, 4), 'S'), ((7, 4), 'S'), ((8, 4), 'S'), ((11, 5), 'S')}

In [121]:
correct_predictions = []
correct_grounds = []

for aric_ground, aric_pred in zip(aric_grounds, aric_predictions):
    
    common_elements = [item for item in aric_pred if item in aric_ground]
            
    correct_predictions.append(common_elements)
    correct_grounds.append(aric_ground)

In [122]:
len(correct_predictions), len(correct_grounds)

(80, 80)

In [123]:
correct_predictions[0]

[((7, 4), 'S'), ((8, 4), 'S'), ((11, 5), 'S'), ((6, 4), 'S')]

In [124]:
correct_grounds[0]

[((8, 4), 'S'),
 ((7, 4), 'S'),
 ((6, 4), 'S'),
 ((10, 11), 'S'),
 ((9, 5), 'S'),
 ((11, 5), 'S')]

In [128]:
ground_types = [[item[1] for item in sublist] for sublist in correct_grounds]

In [130]:
len(ground_types)

80

In [131]:
prediction_types = [[item[1] for item in sublist] for sublist in correct_predictions]

In [132]:
len(prediction_types)

80

In [173]:
ground_types

[['S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'A', 'A', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'A', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'A', 'S', 'S'],
 ['S', 'S', 'S', 'A'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'A', 'A', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'A'],
 ['S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'A'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'A', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S

In [174]:
prediction_types

[['S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S'],
 ['S', 'S', 'S'],
 ['S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S'],
 ['S', 'S', 'S', 'A'],
 ['S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S'],
 ['S', 'S'],
 ['S', 'S', 'S', 'S', 'A'],
 ['S', 'S'],
 ['S', 'S', 'S', 'S', 'A'],
 ['S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S'],
 ['S', 'S', 'S'],
 ['S', 'S', 'A'],
 ['S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S'],
 ['S', 'S'],
 ['S'],
 ['S',
  'S',
  'S',
  'S',
  'S',
  'S',
  'S',
  'S',
  'S',
  'S',
  'S',
  'S',
  'S',
  'S',
  'S',
  'S',
  'S',
  'S'],
 ['S', 'S'],
 ['S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'A', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'A'],
 [],
 ['S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S

In [187]:
bad_idx = []
oppo_d = {"S": "A", "A": "S"}

for idx, (i,j) in enumerate(zip(ground_types, prediction_types)):
    
    if len(i) > len(j):
        prediction_types[idx] = prediction_types[idx] + [oppo_d[elem] for elem in ground_types[idx][len(j):]]
        
    elif len(j) > len(i):        
        prediction_types[idx] = prediction_types[idx][:len(i)]
        #print(idx, len(i), len(j))
        #bad_idx.append(idx)

In [188]:
grounds = [elem for sublist in ground_types for elem in sublist]
predictions = [elem for sublist in prediction_types for elem in sublist]

In [189]:
len(grounds), len(predictions)

(809, 809)

In [190]:
print(classification_report(grounds, predictions, digits=3))

              precision    recall  f1-score   support

           A      0.005     0.048     0.009        42
           S      0.904     0.493     0.638       767

    accuracy                          0.470       809
   macro avg      0.455     0.270     0.324       809
weighted avg      0.858     0.470     0.605       809



In [None]:
### Maybe it makes no sense to do ari and arc separately.

In [None]:
{"relation_types": {"Paragraph1" : [],
                "Paragraph2" : [(4, 3, 'S'), (5, 3, 'S'), (6, 3, 'S')],
                "Paragraph_3" : [(8, 7, 'S')],
                "Paragraph_4" : [(10, 11, 'S'), (9, 11, 'S')]}}