In [1]:
%load_ext autoreload
%autoreload 2

import torch
import json
import numpy as np
import os

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from utils import process_sentences
from utils import serialize, deserialize
from utils import convert_text2graph
import random
import networkx as nx

### Make test graphs with sentences!

In [2]:
test = "/dfs/scratch1/gmachi/datasets/WikiSection/wikisection_en_disease_test.json"
with open(test, 'r') as f:
    test_data = json.load(f)

save_path = "/dfs/scratch1/gmachi/datasets/wikisection_processed/"
# save_path_attn = os.path.join(save_path, "attn_Gs")
# save_path_entail = os.path.join(save_path, "entail_Gs")
save_path_prob = os.path.join(save_path, "prob_Gs")
# save_path_shap = os.path.join(save_path, "shap_Gs")

target = "disease.genetics"  # class-1


In [4]:
def process_Zs_scores(test_data, score_fn, save_path_Gs):
    section_labs = {}
    section_pseudo = {}
    doc_labs = {}
    sal_count = 0
    for i in range(len(test_data)):
        G_file = "doc_"+str(i) + "_graph.obj"
        save_path_G = os.path.join(save_path_Gs, G_file)
        if os.path.isfile(save_path_G) == True:
            print("skipping sample bc already created:", G_file)
            continue

        text = test_data[i]["text"]
        annots = test_data[i]["annotations"]
        labs = []
        pseudos = []
        scores = []

        num_sents = 0
        for annot in annots:
            begin = int(annot["begin"])
            idx = [begin, begin+int(annot["length"])]
            chunk = text[idx[0]:idx[1]]
            sents = [s for s in chunk.split(".")]
            sents = process_sentences(sents)

            # get score for sentence
            es = [[score_fn(s)] for s in sents]
            scores.extend(es)

            ns = len(sents)
            num_sents += ns
            pseudo = [annot["sectionLabel"]] * ns
            pseudos.extend(pseudo)

            lab = [0] * ns
            if pseudo[0] == target: # if any match
                lab = [1] * ns
            labs.extend(lab)
        
        print("sentences in doc:", num_sents)
        sal_count += int(np.sum(labs)) # keep count of targets
        doc_lab = int(np.sum(labs) > 0)
        doc_labs[i] = doc_lab
        
        Z = np.array(scores)
        # save straight as graph
        G = convert_text2graph(Z)
        serialize(G, save_path_G)
        section_pseudo[i] = pseudos
        section_labs[i] = labs

    return section_labs, section_pseudo, doc_labs, sal_count

In [5]:
# def process_Zs_scores_parallel(test_data, score_fn, save_path_G1s, save_path_G2s):
#     section_labs = {}
#     section_pseudo = {}
#     doc_labs = {}
#     sal_count = 0
#     for i in range(len(test_data)):
#         G_file = "doc_"+str(i) + "_graph.obj"
#         save_path_G1 = os.path.join(save_path_G1s, G_file)
#         save_path_G2 = os.path.join(save_path_G2s, G_file)

#         if os.path.isfile(save_path_G1) and os.path.isfile(save_path_G2):
#             print("skipping sample bc already created (for both):", G_file)
#             continue

#         text = test_data[i]["text"]
#         annots = test_data[i]["annotations"]
#         labs = []
#         pseudos = []
#         scores1 = []
#         scores2 = []

#         num_sents = 0
#         for annot in annots:
#             begin = int(annot["begin"])
#             idx = [begin, begin+int(annot["length"])]
#             chunk = text[idx[0]:idx[1]]
#             sents = [s for s in chunk.split(".")]
#             sents = process_sentences(sents)

#             # get score for sentence
#             es = [[score_fn(s)] for s in sents]
#             scores1.extend([el[0] for el in es])
#             scores2.extend([el[1] for el in es])

#             ns = len(sents)
#             num_sents += ns
#             pseudo = [annot["sectionLabel"]] * ns
#             pseudos.extend(pseudo)

#             lab = [0] * ns
#             if pseudo[0] == target: # if any match
#                 lab = [1] * ns
#             labs.extend(lab)
        
#         print("sentences in doc:", num_sents)
#         sal_count += int(np.sum(labs)) # keep count of targets
#         doc_lab = int(np.sum(labs) > 0)
#         doc_labs[i] = doc_lab
        
#         Z1 = np.array(scores1) # attn
#         Z2 = np.array(scores2) # entail
#         # save straight as graph
#         G1 = convert_text2graph(Z1)
#         G2 = convert_text2graph(Z2)
#         serialize(G1, save_path_G1)
#         serialize(G2, save_path_G2)

#         section_pseudo[i] = pseudos
#         section_labs[i] = labs

#     return section_labs, section_pseudo, doc_labs, sal_count

In [7]:
from text_baselines import deberta_zsc
process_Zs_scores(test_data, deberta_zsc, save_path_prob)

skipping sample bc already created: doc_0_graph.obj
skipping sample bc already created: doc_1_graph.obj
skipping sample bc already created: doc_2_graph.obj
skipping sample bc already created: doc_3_graph.obj
skipping sample bc already created: doc_4_graph.obj
skipping sample bc already created: doc_5_graph.obj
skipping sample bc already created: doc_6_graph.obj
skipping sample bc already created: doc_7_graph.obj
skipping sample bc already created: doc_8_graph.obj
skipping sample bc already created: doc_9_graph.obj
skipping sample bc already created: doc_10_graph.obj
skipping sample bc already created: doc_11_graph.obj
skipping sample bc already created: doc_12_graph.obj
skipping sample bc already created: doc_13_graph.obj
skipping sample bc already created: doc_14_graph.obj
skipping sample bc already created: doc_15_graph.obj
skipping sample bc already created: doc_16_graph.obj
skipping sample bc already created: doc_17_graph.obj
The history saving thread hit an unexpected error (Datab

({}, {}, {}, 0)

## attn and NLI

In [3]:
test = "/dfs/scratch1/gmachi/datasets/WikiSection/wikisection_en_disease_test.json"
with open(test, 'r') as f:
    test_data = json.load(f)

save_path = "/dfs/scratch1/gmachi/datasets/wikisection_processed/"
save_path_attn = os.path.join(save_path, "attn_Gs")
save_path_entail = os.path.join(save_path, "entail_Gs")
# save_path_prob = os.path.join(save_path, "prob_Gs")

# save_path_shap = os.path.join(save_path, "shap_Gs")

target = "disease.genetics"  # class-1

In [4]:
def process_Zs_scores_parallel(test_data, score_fn, save_path_G1s, save_path_G2s):
    section_labs = {}
    section_pseudo = {}
    doc_labs = {}
    sal_count = 0
    for i in range(len(test_data)):
        G_file = "doc_"+str(i) + "_graph.obj"
        save_path_G1 = os.path.join(save_path_G1s, G_file)
        save_path_G2 = os.path.join(save_path_G2s, G_file)

        if os.path.isfile(save_path_G1) and os.path.isfile(save_path_G2):
            print("skipping sample bc already created (for both):", G_file)
            continue

        text = test_data[i]["text"]
        annots = test_data[i]["annotations"]
        labs = []
        pseudos = []
        scores1 = []
        scores2 = []

        num_sents = 0
        for annot in annots:
            begin = int(annot["begin"])
            idx = [begin, begin+int(annot["length"])]
            chunk = text[idx[0]:idx[1]]
            sents = [s for s in chunk.split(".")]
            sents = process_sentences(sents)

            # get score for sentence
            es = [score_fn(s) for s in sents]
            scores1.extend([[el[0]] for el in es])
            scores2.extend([[el[1]] for el in es])

            ns = len(sents)
            num_sents += ns
            pseudo = [annot["sectionLabel"]] * ns
            pseudos.extend(pseudo)

            lab = [0] * ns
            if pseudo[0] == target: # if any match
                lab = [1] * ns
            labs.extend(lab)
        
        print("sentences in doc:", num_sents)
        sal_count += int(np.sum(labs)) # keep count of targets
        doc_lab = int(np.sum(labs) > 0)
        doc_labs[i] = doc_lab
        
        Z1 = np.array(scores1) # attn
        Z2 = np.array(scores2) # entail
        # save straight as graph
        G1 = convert_text2graph(Z1)
        G2 = convert_text2graph(Z2)
        serialize(G1, save_path_G1)
        serialize(G2, save_path_G2)

        section_pseudo[i] = pseudos
        section_labs[i] = labs

    return section_labs, section_pseudo, doc_labs, sal_count

In [6]:
from text_baselines import deberta_attn_NLI
process_Zs_scores_parallel(test_data, deberta_attn_NLI, save_path_attn, save_path_entail)

skipping sample bc already created (for both): doc_0_graph.obj
skipping sample bc already created (for both): doc_1_graph.obj
skipping sample bc already created (for both): doc_2_graph.obj
skipping sample bc already created (for both): doc_3_graph.obj
skipping sample bc already created (for both): doc_4_graph.obj
skipping sample bc already created (for both): doc_5_graph.obj
skipping sample bc already created (for both): doc_6_graph.obj
skipping sample bc already created (for both): doc_7_graph.obj
skipping sample bc already created (for both): doc_8_graph.obj
skipping sample bc already created (for both): doc_9_graph.obj
skipping sample bc already created (for both): doc_10_graph.obj
skipping sample bc already created (for both): doc_11_graph.obj
skipping sample bc already created (for both): doc_12_graph.obj
skipping sample bc already created (for both): doc_13_graph.obj
skipping sample bc already created (for both): doc_14_graph.obj
skipping sample bc already created (for both): doc

({}, {}, {}, 0)