In [1]:
import re
import json
import pickle

from sklearn.metrics import classification_report

In [2]:
with open("/Utilisateurs/umushtaq/am_reasoning/saved_models/pe_pipeline_Meta-Llama-3.1-8B-Instruct-bnb-4bit/pe_pipeline_results_5.pickle", 'rb') as f:
    
    results = pickle.load(f)

In [3]:
grounds = results['grounds']
predictions = results['predictions']

### Process Grounds

In [4]:
def process_grounds(grounds):
    
    arg_types = []
    rel_pairs = []
    rel_types = []
    
    for ground in grounds:
        
        args = ground.split("\n")[1].replace("'", '"')
        arg_types.append(json.loads(args)["argument_types"])
        
        rels = re.sub(r'\((\d+), (\d+)\)', r'[\1, \2]', ground.split("\n")[2])
        related_args = [tuple(pair) for pair in json.loads(rels)["related_arguments"]]
        rel_pairs.append(related_args)
        
        
        rel_t = ground.split("\n")[3].replace("'", '"')
        rel_types.append(json.loads(rel_t)["relation_types"])
        
    return arg_types, rel_pairs, rel_types

In [5]:
grounds_acc, grounds_ari, grounds_arc  = process_grounds(grounds)

### Process Predictions

In [6]:
def process_predictions(predictions):
    
    arg_types = []
    rel_pairs = []
    rel_types = []
    
    for pred in predictions:
        
        args = pred.split("\n")[1].replace("'", '"')
        arg_types.append(json.loads(args)["argument_types"])
        
        rels = re.sub(r'\((\d+), (\d+)\)', r'[\1, \2]', pred.split("\n")[2])
        related_args = [tuple(pair) for pair in json.loads(rels)["related_arguments"]]
        rel_pairs.append(related_args)
        
        
        rel_t = pred.split("\n")[3].replace("'", '"')
        rel_types.append(json.loads(rel_t)["relation_types"])
        
    return arg_types, rel_pairs, rel_types

In [7]:
predictions_acc, predictions_ari, predictions_arc  = process_predictions(predictions)

### Compute ACC metrics

In [8]:
len(grounds_acc), len(predictions_acc)

(80, 80)

In [9]:
grounds = [elem for sublist in grounds_acc for elem in sublist]
predictions = [elem for sublist in predictions_acc for elem in sublist]

In [10]:
print(classification_report(grounds, predictions, digits=3))

              precision    recall  f1-score   support

           C      0.793     0.766     0.779       304
           M      0.903     0.908     0.906       153
           P      0.932     0.942     0.937       809

    accuracy                          0.896      1266
   macro avg      0.876     0.872     0.874      1266
weighted avg      0.895     0.896     0.895      1266



### Compute ARI metrics

In [13]:
len(grounds_ari), len(predictions_ari)

(80, 80)

In [16]:
nr_acs = [len(elem) for elem in grounds_acc]

In [19]:
all_pairs_l = [[(i, j) for i in range(1, n+1) for j in range(1, n+1) if i != j] for n in nr_acs]

In [20]:
len(all_pairs_l)

80

In [None]:
def process_pairs(all_pairs_l, grounds_ari):
    
    results = []
    
    for all_pairs, reference_pairs in zip(all_pairs_l, grounds_ari):
        
        result = ["R" if pair in reference_pairs else "NR" for pair in all_pairs]
        results.append(result)
    
    return results

In [39]:
grounds_l = process_pairs(all_pairs_l, grounds_ari)

In [40]:
len(grounds_l)

80

In [41]:
predictions_l = process_pairs(all_pairs_l, predictions_ari)

In [42]:
len(predictions_l)

80

In [43]:
grounds = [elem for sublist in grounds_l for elem in sublist]
predictions = [elem for sublist in predictions_l for elem in sublist]

In [44]:
print(classification_report(grounds, predictions, digits=3))

              precision    recall  f1-score   support

          NR      0.980     0.980     0.980     19386
           R      0.514     0.507     0.511       806

    accuracy                          0.961     20192
   macro avg      0.747     0.744     0.745     20192
weighted avg      0.961     0.961     0.961     20192



### Compute ARC metrics

In [105]:
grounds_arc

[['S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'A', 'A', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'A', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'A', 'S', 'S'],
 ['S', 'S', 'S', 'A'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'A', 'A', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'A'],
 ['S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'A'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'A', 'S', 'S'],
 ['S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S', 'S'],
 ['S', 'S', 'S

In [106]:
grounds_ari

[[(8, 4), (7, 4), (6, 4), (10, 11), (9, 5), (11, 5)],
 [(6, 3), (7, 4), (8, 7), (9, 4), (10, 5), (11, 5)],
 [(7, 4),
  (8, 9),
  (10, 9),
  (9, 4),
  (11, 5),
  (12, 5),
  (13, 6),
  (14, 6),
  (15, 6),
  (16, 6),
  (17, 16),
  (18, 16),
  (19, 16)],
 [(5, 3),
  (4, 3),
  (9, 6),
  (11, 13),
  (12, 13),
  (10, 6),
  (13, 6),
  (14, 7),
  (15, 7),
  (18, 8),
  (17, 16),
  (16, 8)],
 [(6, 4), (7, 4), (8, 7), (9, 7), (10, 5), (11, 5)],
 [(5, 2),
  (6, 2),
  (7, 2),
  (8, 2),
  (9, 2),
  (10, 2),
  (11, 4),
  (4, 3),
  (12, 13),
  (13, 3)],
 [(11, 10),
  (12, 10),
  (7, 4),
  (8, 4),
  (9, 4),
  (10, 4),
  (13, 5),
  (14, 5),
  (15, 5),
  (16, 5),
  (17, 6),
  (18, 6),
  (19, 6),
  (20, 6),
  (21, 6)],
 [(6, 4), (8, 9), (7, 9), (9, 5), (10, 5), (11, 5)],
 [(6, 3), (7, 4), (8, 4), (9, 5)],
 [(5, 4), (6, 4), (7, 4), (8, 4), (11, 9), (12, 13), (9, 13)],
 [(8, 3),
  (9, 3),
  (10, 3),
  (11, 3),
  (12, 13),
  (13, 4),
  (14, 4),
  (6, 5),
  (7, 5)],
 [(4, 3), (6, 7), (7, 5), (9, 10)],
 [(9, 3)

In [107]:
aric_grounds = []
    
for rel_pairs, rel_types in zip(grounds_ari, grounds_arc):
    
    result = list(zip(rel_pairs, rel_types))
    aric_grounds.append(result)

In [108]:
len(aric_grounds)

80

In [109]:
aric_grounds[0]

[((8, 4), 'S'),
 ((7, 4), 'S'),
 ((6, 4), 'S'),
 ((10, 11), 'S'),
 ((9, 5), 'S'),
 ((11, 5), 'S')]

In [110]:
aric_predictions = []
    
for rel_pairs, rel_types in zip(predictions_ari, predictions_arc):
        
    result = list(zip(rel_pairs, rel_types))
    aric_predictions.append(result)

In [111]:
len(aric_predictions)

80

In [112]:
aric_predictions[0]

[((7, 4), 'S'), ((8, 4), 'S'), ((10, 5), 'S'), ((11, 5), 'S'), ((6, 4), 'S')]

In [121]:
correct_predictions = []
correct_grounds = []

for aric_ground, aric_pred in zip(aric_grounds, aric_predictions):
    
    common_elements = [item for item in aric_pred if item in aric_ground]
            
    correct_predictions.append(common_elements)
    correct_grounds.append(aric_ground)

In [122]:
len(correct_predictions), len(correct_grounds)

(80, 80)

In [123]:
correct_predictions[0]

[((7, 4), 'S'), ((8, 4), 'S'), ((11, 5), 'S'), ((6, 4), 'S')]

In [124]:
correct_grounds[0]

[((8, 4), 'S'),
 ((7, 4), 'S'),
 ((6, 4), 'S'),
 ((10, 11), 'S'),
 ((9, 5), 'S'),
 ((11, 5), 'S')]

In [128]:
ground_types = [[item[1] for item in sublist] for sublist in correct_grounds]

In [130]:
len(ground_types)

80

In [131]:
prediction_types = [[item[1] for item in sublist] for sublist in correct_predictions]

In [132]:
len(prediction_types)

80

In [139]:
bad_idx = []

for idx, (i,j) in enumerate(zip(ground_types, prediction_types)):
    
    if len(i) != len(j):
        print(idx, len(i), len(j))
        bad_idx.append(idx)

0 6 4
1 6 4
2 13 8
3 12 2
4 6 3
5 10 4
6 15 13
7 6 3
9 7 5
10 9 8
11 4 1
12 8 7
13 8 2
14 13 2
15 6 5
16 4 2
17 10 5
18 8 5
19 11 7
20 5 1
21 13 3
22 9 3
23 7 6
24 8 5
25 10 2
26 16 1
28 6 2
29 10 4
30 20 5
31 17 13
32 15 0
33 12 6
34 10 6
35 5 4
36 10 0
38 17 8
39 13 10
40 10 5
41 13 11
42 10 2
43 14 7
44 13 8
45 8 5
46 11 5
47 13 5
48 6 4
49 8 4
50 17 5
51 10 6
52 15 11
53 10 5
54 8 1
55 12 10
56 10 0
57 8 1
58 12 10
59 7 5
60 16 2
61 10 5
62 9 8
63 15 5
64 10 2
65 8 0
66 11 5
67 14 9
69 10 5
70 13 0
71 7 6
72 9 0
73 10 4
74 12 8
75 16 8
76 5 1
79 11 6


In [140]:
len(bad_idx)

74

In [None]:
### Maybe it makes no sense to do ari and arc separately.