In [61]:
from utils import *
import pickle
import numpy as np
from coval.coval.conll.reader import get_coref_infos
from coval.coval.eval.evaluator import evaluate_documents as evaluate
from coval.coval.eval.evaluator import muc, b_cubed, ceafe, lea
import torch
from tqdm import tqdm
from utils import cluster
import pandas as pd
import random
import collections


def read(key, response):
    """Load coreference info for a gold key file and a system response file.

    Both arguments are rendered with ``'%s'`` formatting before being handed
    to ``get_coref_infos``.

    NOTE(review): the three positional flags (False, False, True) are
    presumably coval's NP_only / remove_nested / keep_singletons options --
    confirm against the coval reader before relying on that.
    """
    key_path = '%s' % key
    response_path = '%s' % response
    return get_coref_infos(key_path, response_path, False, False, True)


def predict_with_inner_monologue(parallel_model, dev_ab,  device, batch_size):
    """Score every row of ``dev_ab`` in mini-batches, without gradients.

    Args:
        parallel_model: model object forwarded untouched to ``forward_ab``.
        dev_ab: mapping whose ``'input_ids'`` entry exposes ``.shape[0]``
            (the number of rows to score).
        device: device handle forwarded to ``forward_ab``.
        batch_size: number of row indices passed to ``forward_ab`` per call.

    Returns:
        The per-batch outputs of ``forward_ab`` moved to CPU and concatenated
        with ``torch.cat`` in row order.

    NOTE: when ``dev_ab`` has zero rows no batch is run and ``torch.cat``
    receives an empty list, which it rejects.
    """
    n = dev_ab['input_ids'].shape[0]
    all_scores_ab = []
    with torch.no_grad():
        for start in tqdm(range(0, n, batch_size), desc='Predicting'):
            # Row indices of the current batch; the last batch may be short.
            batch_indices = list(range(start, min(start + batch_size, n)))
            scores_ab = forward_ab(parallel_model, dev_ab, device, batch_indices,
                                   ann_attn_hidden_logits=False)
            all_scores_ab.append(scores_ab.detach().cpu())

    return torch.cat(all_scores_ab)

def get_coreference_scores(dataset_folder, evt_mention_map, all_mention_pairs, dataset, split, heu, similarities, dpos_score_map, out_name, threshold):
    """Cluster mention pairs and score the clustering against the gold key.

    For every pair in ``all_mention_pairs`` the similarity used for clustering
    is taken from ``dpos_score_map`` when the pair (in either order) is a key
    of that map; otherwise the matching entry of ``similarities`` is used.

    Args:
        dataset_folder: folder the gold and system key files are written to.
        evt_mention_map: mapping of mention id to the value written to the
            gold key file (presumably the gold cluster id -- confirm against
            ``generate_key_file``).
        all_mention_pairs: sequence of 2-element mention-id pairs.
        dataset: dataset name, only printed.
        split: split name, used in the gold key file name and printed.
        heu: heuristic name, only embedded in the result string.
        similarities: fallback similarity per pair, aligned with
            ``all_mention_pairs``.
        dpos_score_map: mapping ``(m1, m2) -> sequence of scores``.
        out_name: suffix of the system key file name.
        threshold: clustering threshold forwarded to ``cluster``.

    Returns:
        ``(conf, result_string, final_frame)`` where ``conf`` is the mean of
        the MUC, B-cubed and CEAF-e F1 values rounded to one decimal,
        ``result_string`` is the ``&``-separated row that is also printed, and
        ``final_frame`` is the flat list of the 13 reported numbers.
    """
    curr_mentions = sorted(evt_mention_map.keys())
    curr_gold_cluster_map = [(men, evt_mention_map[men]) for men in curr_mentions]
    gold_key_file = dataset_folder + f'/evt_gold_{split}.keyfile'
    generate_key_file(curr_gold_cluster_map, 'evt', dataset_folder, gold_key_file)

    w_dpos_sims = []
    for p, sim in zip(all_mention_pairs, similarities):
        # Look the pair up as a tuple so list-typed pairs work, and use the
        # very key that was tested for membership when reading the score.
        pair = tuple(p)
        reversed_pair = (p[1], p[0])
        if pair in dpos_score_map:
            w_dpos_sims.append(dpos_score_map[pair][0])
        elif reversed_pair in dpos_score_map:
            w_dpos_sims.append(np.mean(dpos_score_map[reversed_pair]))
        else:
            w_dpos_sims.append(sim)

    mid2cluster = cluster(curr_mentions, all_mention_pairs, w_dpos_sims, threshold)
    system_key_file = dataset_folder + f'/evt_gold_dpos_{out_name}.keyfile'
    generate_key_file(mid2cluster.items(), 'evt', dataset_folder, system_key_file)
    doc = read(gold_key_file, system_key_file)

    def as_percentages(metric):
        # Round to 3 decimals first, then express as a percentage with one
        # decimal (same two-step rounding for every metric).
        return np.round(np.round(evaluate(doc, metric), 3) * 100, 1)

    mr, mp, mf = as_percentages(muc)
    br, bp, bf = as_percentages(b_cubed)
    cr, cp, cf = as_percentages(ceafe)
    lr, lp, lf = as_percentages(lea)

    # CoNLL-style average: LEA is reported but not part of the mean.
    conf = np.round((mf + bf + cf) / 3, 1)
    print(dataset, split)
    final_frame = [mr, mp, mf, br, bp, bf, cr, cp, cf, lr, lp, lf, conf]
    result_string = f'& {heu} && {mr}  & {mp} & {mf} && {br} & {bp} & {bf} && {cr} & {cp} & {cf} && {lr} & {lp} & {lf} && {conf} \\'

    print(result_string)
    return conf, result_string, final_frame



def get_final_scores(dataset, split, dpos_score_map, heu='lh_llama', threshold=0.5):
    """Load the gold key and candidate pairs of a dataset and score them.

    The candidate pairs are the true positives followed by the false positives
    stored in the heuristic's pickle; their fallback similarity is 1 for a
    true positive and 0 for a false positive.

    Args:
        dataset: dataset name under ``./datasets``. For ``'ldc'`` the fixed
            ``mp_mp_t_test_0.03_new.pkl`` file is read regardless of ``split``.
        split: split name; selects ``mp_mp_t_{split}.pkl`` for other datasets.
        dpos_score_map: mapping of mention pair to scores, forwarded to
            ``get_coreference_scores``.
        heu: heuristic sub-folder name, also used as the output name.
        threshold: clustering threshold.

    Returns:
        ``(conf, final_scores, final_frame)`` as returned by
        ``get_coreference_scores``.

    NOTE: the inputs are unpickled, so they must come from a trusted source.
    """
    dataset_folder = f'./datasets/{dataset}/'
    is_ldc = dataset == 'ldc'
    if is_ldc:
        print("loading dataset", dataset)
        pairs_path = dataset_folder + f'/{heu}/mp_mp_t_test_0.03_new.pkl'
    else:
        pairs_path = f'./datasets/{dataset}/{heu}/mp_mp_t_{split}.pkl'

    # Context managers close the files deterministically.
    with open(dataset_folder + '/mention_gold_key.pkl', 'rb') as gold_file:
        evt_mention_map = pickle.load(gold_file)
    with open(pairs_path, 'rb') as pairs_file:
        mps, mps_trans = pickle.load(pairs_file)

    _, _, _, fns = mps_trans
    tps, fps, _, _ = mps
    if not is_ldc:
        print(len(tps), len(fps), len(fns))

    all_mention_pairs = tps + fps
    heu_predictions = np.array([1] * len(tps) + [0] * len(fps))

    conf, final_scores, final_frame = get_coreference_scores(
        dataset_folder, evt_mention_map, all_mention_pairs, dataset, split, heu,
        heu_predictions, dpos_score_map, out_name=heu, threshold=threshold)
    return conf, final_scores, final_frame



def get_random_scores(k):
    """Return a list of ``k`` random binary scores.

    Each entry is drawn independently with ``random.randint(0, 1)``, so it is
    either 0 or 1. The draw no longer rebinds the ``k`` parameter inside the
    loop.
    """
    return [random.randint(0, 1) for _ in range(k)]

def get_gpt_scores_aida():
    """Turn the stored one-word LLM answers for the 'ldc' test pairs into 0/1.

    The test pairs are the true positives followed by the false positives of
    ``lh/mp_mp_t_test_0.03_new.pkl``; pairs whose position is listed in
    ``bad_test_indices_ldc.pkl`` (indices with faulty mention triggers) are
    dropped. Pairs that have no entry in ``ldc_1word.pickle`` are skipped, so
    the result can be shorter than the list of remaining pairs.

    Returns:
        A list with 1 where the stored answer is one of ``'Yes'``, ``'Yes.'``,
        ``'yes'`` and 0 otherwise, in pair order.

    NOTE: the inputs are unpickled, so they must come from a trusted source.
    """
    dataset = 'ldc'
    dataset_folder = f'./datasets/{dataset}/'

    with open(dataset_folder + '/lh/mp_mp_t_test_0.03_new.pkl', 'rb') as pairs_file:
        test_mp_mpt, _ = pickle.load(pairs_file)
    tps_test, fps_test, _, _ = test_mp_mpt

    with open(dataset_folder + f"/bad_test_indices_{dataset}.pkl", 'rb') as bad_file:
        # A set keeps the per-pair membership test cheap.
        bad_idx = set(pickle.load(bad_file))

    test_pairs = [pair for idx, pair in enumerate(list(tps_test + fps_test))
                  if idx not in bad_idx]

    with open(dataset_folder + f"/ldc_1word.pickle", 'rb') as answers_file:
        llama_eval_map = pickle.load(answers_file)

    yes_answers = ('Yes', 'Yes.', 'yes')
    return [1 if llama_eval_map[pair] in yes_answers else 0
            for pair in test_pairs if pair in llama_eval_map]

def save_pair_info(pairs, mention_map, file_name):
    """Write one CSV row per mention pair with its clusters and sentences.

    Args:
        pairs: iterable of ``(m1, m2)`` mention ids.
        mention_map: mapping of mention id to a record providing
            ``'gold_cluster'`` and ``'bert_sentence'``.
        file_name: destination path of the CSV file.

    The CSV has the columns ``m1, m2, c1, c2, first, second``. An empty
    ``pairs`` now produces a header-only file instead of raising while
    unpacking an empty ``zip``.
    """
    columns = ['m1', 'm2', 'c1', 'c2', 'first', 'second']
    rows = []
    for m1, m2 in pairs:
        mention1 = mention_map[m1]
        mention2 = mention_map[m2]
        rows.append((m1, m2,
                     mention1['gold_cluster'], mention2['gold_cluster'],
                     mention1['bert_sentence'], mention2['bert_sentence']))

    pd.DataFrame(rows, columns=columns).to_csv(file_name)


def mention_pair_analysis(dataset, split, heu):
    """Dump the hard false-positive and false-negative pairs of a heuristic.

    The candidate pairs are the true positives followed by the false positives
    returned by ``lh_split``. A pair is predicted positive when the mean of
    its ``get_dpos`` scores exceeds 0.5; pairs without scores count as 0.

    Two CSV files are written below ``./datasets/{dataset}/analysis/``:
    ``hard_fps_{dataset}.csv`` (predicted positive, labelled negative) and
    ``hard_fns_{dataset}.csv`` (predicted negative, labelled positive).

    Args:
        dataset: dataset name under ``./datasets``.
        split: split name forwarded to ``get_dpos`` and ``lh_split``.
        heu: heuristic name forwarded to ``get_dpos`` and ``lh_split``.

    NOTE: ``mention_map.pkl`` is unpickled, so it must come from a trusted
    source.
    """
    dataset_folder = f'./datasets/{dataset}/'
    with open(dataset_folder + "/mention_map.pkl", 'rb') as map_file:
        mention_map = pickle.load(map_file)
    dpos_map = get_dpos(dataset, heu, split)
    (tps, fps, _, _), _ = lh_split(heu, dataset, split, 0.05)

    p_pos = tps + fps

    # Membership is tested against the score map: the previous test against
    # p_pos itself was always true and quadratic in the number of pairs.
    similarities = np.array([np.mean(dpos_map[p]) if p in dpos_map else 0 for p in p_pos])

    true_predictions = np.array([1] * len(tps) + [0] * len(fps))
    predictions = similarities > 0.5

    hard_fp_idx = np.logical_and(predictions, np.logical_not(true_predictions)).nonzero()[0]
    hard_fps = [p_pos[i] for i in hard_fp_idx]
    print('hard_fps', len(hard_fps))
    save_pair_info(hard_fps, mention_map, f'./datasets/{dataset}/analysis/hard_fps_{dataset}.csv')

    hard_fn_idx = np.logical_and(np.logical_not(predictions), true_predictions).nonzero()[0]
    hard_fns = [p_pos[i] for i in hard_fn_idx]
    # Report the false-negative count (previously the false-positive count
    # was printed under this label).
    print('hard_fns', len(hard_fns))
    save_pair_info(hard_fns, mention_map, f'./datasets/{dataset}/analysis/hard_fns_{dataset}.csv')


def get_cluster_size_error(dataset,evt_mention_map_test, clus_map):
    """Count score agreements on coreferent pairs of the selected clusters.

    Reads ``./datasets/{dataset}/lh_prev_scores_full.csv`` and keeps the rows
    with ``coref_label == 1`` whose first mention belongs to a gold cluster
    that is a key of ``clus_map``.

    Args:
        dataset: dataset name under ``./datasets``.
        evt_mention_map_test: mapping of mention id to a record providing
            ``'gold_cluster'``.
        clus_map: container of the gold cluster ids to keep.

    Returns:
        ``(kd_correct, dpos_correct, kd_only, dpos_only)``, the number of kept
        rows where respectively
        ``scores_lh_prev == 0 and scores_r1_k1 == 1``,
        ``scores_lh_prev == 1 and scores_r1_k1 == 0``,
        ``scores_r1_k1 == 1`` and ``scores_lh_prev == 1``.
    """
    # Local import keeps this block self-contained.
    import ast

    dataset_folder = f'./datasets/{dataset}'
    full_path = f"{dataset_folder}/lh_prev_scores_full.csv"
    prev_data = pd.read_csv(full_path)

    df_pos = prev_data.loc[prev_data['coref_label'] == 1]

    values = []
    # Scores are read from the filtered frame itself: using the enumerate
    # position as a label into the unfiltered frame only lines up when every
    # positive row precedes every negative one.
    for pair_repr, lh_score, kd_score in zip(
            df_pos['pairs'], df_pos['scores_lh_prev'], df_pos['scores_r1_k1']):
        # The pair is stored as its Python literal; literal_eval parses it
        # without executing arbitrary code from the CSV (eval would).
        pair = ast.literal_eval(pair_repr)
        gold_cluster = evt_mention_map_test[pair[0]]['gold_cluster']
        if gold_cluster in clus_map:
            values.append((pair, lh_score, kd_score))

    df_large_cluster_anal = pd.DataFrame(
        values, columns=['mention_pair', 'scores_lh_prev', 'scores_r1_k1'])
    lh_positive = df_large_cluster_anal['scores_lh_prev'] == 1
    lh_negative = df_large_cluster_anal['scores_lh_prev'] == 0
    kd_positive = df_large_cluster_anal['scores_r1_k1'] == 1
    kd_negative = df_large_cluster_anal['scores_r1_k1'] == 0

    kd_correct = len(df_large_cluster_anal.loc[lh_negative & kd_positive])
    dpos_correct = len(df_large_cluster_anal.loc[lh_positive & kd_negative])
    kd_only = len(df_large_cluster_anal.loc[kd_positive])
    dpos_only = len(df_large_cluster_anal.loc[lh_positive])
    return kd_correct, dpos_correct, kd_only, dpos_only
    
def get_cluster_maps(dataset, thres):
    """Select the gold clusters of the test split that have exactly ``thres`` mentions.

    Args:
        dataset: dataset name under ``./datasets``.
        thres: cluster size to select (matched with ``==``, not ``>=``).

    Returns:
        ``(clus_large, evt_mention_map_test)`` where ``clus_large`` maps each
        selected gold cluster id to its size and ``evt_mention_map_test`` maps
        mention id to record for the mentions with ``men_type == 'evt'`` and
        ``split == 'test'``.

    The unused ``im_map_{dataset}.pkl`` is no longer loaded, so the function
    does not fail when that file is absent.

    NOTE: ``mention_map.pkl`` is unpickled, so it must come from a trusted
    source.
    """
    dataset_folder = f'./datasets/{dataset}/'
    with open(dataset_folder + "/mention_map.pkl", 'rb') as map_file:
        mention_map = pickle.load(map_file)

    evt_mention_map_test = {
        m_id: m for m_id, m in mention_map.items()
        if m['men_type'] == 'evt' and m['split'] == 'test'
    }

    # Number of test event mentions per gold cluster.
    cluster_sizes = collections.Counter(
        m['gold_cluster'] for m in evt_mention_map_test.values())
    clus_large = {clus: size for clus, size in cluster_sizes.items() if size == thres}
    return clus_large, evt_mention_map_test


def save_cluster_plots(df_ecb, df_gvc):
    """Show a scatter plot of per-cluster-size counts for GVC and ECB+.

    Args:
        df_ecb: frame providing ``c_thres``, ``kd_only`` and ``dpos_only``.
        df_gvc: frame providing ``kd_correct`` and ``dpos_correct``.

    The x values are ``df_ecb['c_thres']`` for all four series; the ECB+
    series use only their first 26 rows (the overlapping cluster sizes).
    Despite its name the function only displays the figure: the ``savefig``
    call is disabled.

    NOTE(review): the GVC series plot the ``*_correct`` columns while the
    ECB+ series plot the ``*_only`` columns -- confirm this is intended.
    """
    plt.style.use('default')
    sns.set_style("whitegrid")

    thresholds = df_ecb['c_thres']
    # (values, legend label, colour, marker) for each plotted series.
    series = [
        (df_gvc['kd_correct'], '$Long_{+ROEC,+KD}$ (GVC)', 'red', 'o'),
        (df_gvc['dpos_correct'], '$Long$ (GVC)', 'green', 'o'),
        (df_ecb['kd_only'][0:26], '$Long_{+ROEC,+KD}$ (ECB+)', 'deepskyblue', '^'),
        (df_ecb['dpos_only'][0:26], '$Long$ (ECB+)', 'mediumorchid', '^'),
    ]

    plt.figure(figsize=(6, 4))
    for values, label, color, marker in series:
        plt.scatter(thresholds, values, label=label, color=color, marker=marker, s=20)

    plt.xlabel('Gold Cluster Thresholds')
    plt.ylabel('True Positive Counts')

    plt.legend()
    #plt.savefig('cluster_plot_final.png', bbox_inches='tight')
    plt.show()


def get_plots_for_clustersize(dataset, cluster_range):
    """Tabulate the cluster-size error counts for sizes 1..``cluster_range``.

    For each cluster size the matching gold clusters are selected with
    ``get_cluster_maps`` and the counts are computed with
    ``get_cluster_size_error``.

    Args:
        dataset: dataset name; it is now honoured instead of being
            overwritten with ``'gvc'``.
        cluster_range: largest cluster size to evaluate (inclusive).

    Returns:
        A DataFrame with one row per cluster size and the columns
        ``c_thres, kd_correct, dpos_correct, kd_only, dpos_only``.
    """
    rows = []
    for c_thres in tqdm(range(1, cluster_range + 1), desc="Processing items", unit="item"):
        clus_large, evt_mention_map_test = get_cluster_maps(dataset, c_thres)
        kd_correct, dpos_correct, kd_only, dpos_only = get_cluster_size_error(
            dataset, evt_mention_map_test, clus_large)
        rows.append((c_thres, kd_correct, dpos_correct, kd_only, dpos_only))

    return pd.DataFrame(
        rows, columns=['c_thres', 'kd_correct', 'dpos_correct', 'kd_only', 'dpos_only'])
 
def get_llm_scores(dataset, heu, split,model = None):
    """Build a ``pair -> (score, score)`` map from a model's stored scores.

    The pairs are the true positives followed by the false positives of the
    heuristic's test pickle. For ``'ldc'`` the positions listed in
    ``bad_test_indices_ldc.pkl`` are dropped first. Scores are read from
    ``best_scores_llm_cdcr/scores_{model}.pkl`` and matched to the pairs by
    position; ``zip`` stops at the shorter of the two sequences.

    Args:
        dataset: dataset name under ``./datasets``.
        heu: heuristic sub-folder name.
        split: accepted for interface compatibility; the test files are
            always read.
        model: model tag used in the score file name.

    Returns:
        A dict mapping ``tuple(pair)`` to a 2-tuple holding the same float
        score twice.

    NOTE: the inputs are unpickled, so they must come from a trusted source.
    """
    dataset_folder = f'./datasets/{dataset}/'
    if dataset == 'ldc':
        with open(dataset_folder + f'/{heu}/mp_mp_t_test_0.03_new.pkl', 'rb') as pairs_file:
            mps, _ = pickle.load(pairs_file)
        tps, fps, _, _ = mps
        with open(dataset_folder + f"/bad_test_indices_{dataset}.pkl", 'rb') as bad_file:
            # A set keeps the per-pair membership test cheap.
            bad_idx = set(pickle.load(bad_file))
        pairs = [pair for idx, pair in enumerate(tps + fps) if idx not in bad_idx]
    else:
        with open(dataset_folder + f'/{heu}/mp_mp_t_test.pkl', 'rb') as pairs_file:
            test_mp_mpt, _ = pickle.load(pairs_file)
        tps_test, fps_test, _, _ = test_mp_mpt
        pairs = list(tps_test + fps_test)

    with open(dataset_folder + f"/best_scores_llm_cdcr/scores_{model}.pkl", 'rb') as scores_file:
        scores = pickle.load(scores_file)
    if len(scores) == 1:
        # Scores stored as a single wrapped sequence are unwrapped; atleast_1d
        # keeps a lone scalar score iterable instead of yielding a 0-d array.
        scores = np.atleast_1d(np.array(scores[0]))

    return {tuple(pair): (float(score), float(score))
            for pair, score in zip(pairs, scores)}
    
    
def main():
    """Score every model on every dataset and collect the result rows.

    For each dataset/model combination the model's score map is loaded with
    ``get_llm_scores`` and evaluated with ``get_final_scores`` on the test
    split using the ``'lh_llama'`` heuristic.

    Returns:
        ``(conf, final_scores, final_frame, final_exp_results)`` where the
        first three belong to the last combination evaluated and
        ``final_exp_results`` holds one row per combination, each starting
        with the dataset name and the model's display name.
    """
    datasets = ['ecb', 'gvc', 'ldc']
    split = 'test'
    heu = 'lh_llama'
    # a list of all model types
    model_list = ['r1_k1', 'r1_k0', 'r0_k1', 'paired', 'llama', 'gpt']
    # Display name of each model tag in the results table.
    model_dict = {'r1_k0': 'Long+ROEC-KD', 'r0_k1': 'Long-ROEC+KD',
                  'r1_k1': 'Long+ROEC+KD', 'paired': 'Long_paired',
                  'llama': 'LLaMA2-7B-Chat', 'gpt': 'GPT3.5Turbo'}
    final_exp_results = []
    for data in datasets:
        for model in model_list:
            print("getting model", model)
            score_map = get_llm_scores(data, heu, split, model=model)

            conf, final_scores, final_frame = get_final_scores(data, split, score_map, heu=heu)
            # Prefix the metric row so it reads [dataset, model name, metrics...].
            final_frame.insert(0, model_dict[model])
            final_frame.insert(0, data)
            final_exp_results.append(final_frame)
    return conf, final_scores, final_frame, final_exp_results

if __name__ == '__main__':
    # Column names for coref metrics: two identifying columns followed by
    # recall / precision / F1 per metric and the CoNLL F1 average.
    columns = ['dataset', 'model', 'MUC R', 'MUC P', 'MUC F1','B3 R', 'B3 P', 'B3 F1','Ceafe R', 'Ceafe P', 'Ceafe F1','LEA R', 'LEA P', 'LEA F1', 'CoNLL F1' ]

    # Run all experiments; the first three values describe the last run only.
    conf, final_scores, final_frame, final_exp_results = main()

    # One table row per dataset/model combination.
    final_results_table = pd.DataFrame(data=final_exp_results, columns=columns)
  


 

getting model r1_k1
3506 2055 2275
ecb test
& lh_llama && 84.1  & 92.0 & 87.9 && 82.4 & 91.7 & 86.8 && 88.9 & 80.5 & 84.5 && 76.5 & 83.4 & 79.8 && 86.4 \
getting model r1_k0
3506 2055 2275
ecb test
& lh_llama && 79.4  & 92.4 & 85.4 && 79.8 & 93.1 & 85.9 && 89.1 & 76.1 & 82.1 && 73.1 & 81.4 & 77.0 && 84.5 \
getting model r0_k1
3506 2055 2275
ecb test
& lh_llama && 78.2  & 90.6 & 83.9 && 79.4 & 90.2 & 84.4 && 87.9 & 75.4 & 81.2 && 72.4 & 77.2 & 74.7 && 83.2 \
getting model paired
3506 2055 2275
ecb test
& lh_llama && 81.5  & 84.1 & 82.8 && 81.1 & 82.4 & 81.8 && 79.4 & 76.5 & 77.9 && 70.5 & 70.7 & 70.6 && 80.8 \
getting model llama
3506 2055 2275
ecb test
& lh_llama && 84.2  & 76.3 & 80.1 && 82.9 & 73.1 & 77.7 && 67.6 & 77.3 & 72.1 && 67.7 & 62.7 & 65.1 && 76.6 \
getting model gpt
3506 2055 2275
ecb test
& lh_llama && 81.7  & 81.0 & 81.4 && 81.0 & 78.6 & 79.8 && 76.1 & 77.0 & 76.5 && 69.1 & 67.3 & 68.2 && 79.2 \
getting model r1_k1
2804 7551 208
gvc test
& lh_llama && 91.6  & 94.2 & 92.9 

In [62]:
final_results_table

Unnamed: 0,dataset,model,MUC R,MUC P,MUC F1,B3 R,B3 P,B3 F1,Ceafe R,Ceafe P,Ceafe F1,LEA R,LEA P,LEA F1,CoNLL F1
0,ecb,Long+ROEC+KD,84.1,92.0,87.9,82.4,91.7,86.8,88.9,80.5,84.5,76.5,83.4,79.8,86.4
1,ecb,Long+ROEC-KD,79.4,92.4,85.4,79.8,93.1,85.9,89.1,76.1,82.1,73.1,81.4,77.0,84.5
2,ecb,Long-ROEC+KD,78.2,90.6,83.9,79.4,90.2,84.4,87.9,75.4,81.2,72.4,77.2,74.7,83.2
3,ecb,Long_paired,81.5,84.1,82.8,81.1,82.4,81.8,79.4,76.5,77.9,70.5,70.7,70.6,80.8
4,ecb,LLaMA2-7B-Chat,84.2,76.3,80.1,82.9,73.1,77.7,67.6,77.3,72.1,67.7,62.7,65.1,76.6
5,ecb,GPT3.5Turbo,81.7,81.0,81.4,81.0,78.6,79.8,76.1,77.0,76.5,69.1,67.3,68.2,79.2
6,gvc,Long+ROEC+KD,91.6,94.2,92.9,86.7,82.1,84.3,75.8,68.1,71.7,83.4,76.0,79.5,83.0
7,gvc,Long+ROEC-KD,91.9,92.5,92.2,86.8,75.3,80.6,66.9,65.3,66.1,83.3,69.5,75.8,79.6
8,gvc,Long-ROEC+KD,91.3,95.1,93.2,86.0,79.2,82.5,76.6,65.5,70.6,83.1,72.6,77.5,82.1
9,gvc,Long_paired,91.6,90.8,91.2,87.3,64.0,73.9,62.2,64.8,63.5,83.8,57.5,68.2,76.2
