In [1]:
import os
import json
import pickle
import numpy as np
from tqdm import tqdm
from collections import defaultdict
import random

In [2]:
repo_root = '../../../'

# Load WebNLG dataset

In [3]:
from datasets import load_dataset

dataset_name = "webnlg-challenge/web_nlg"
data_all = load_dataset(dataset_name, 'release_v3.0_en')

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
data_all['test'][987]

{'category': 'Artist',
 'size': 1,
 'eid': 'Id988',
 'original_triple_sets': {'otriple_set': [['Petah_Tikva | country | Israel']]},
 'modified_triple_sets': {'mtriple_set': [['Petah_Tikva | country | Israel']]},
 'shape': '(X (X))',
 'shape_type': 'NA',
 'lex': {'comment': ['', '', ''],
  'lid': ['Id1', 'Id2', 'Id3'],
  'text': ['Petah Tikva is a city in Israel.',
   'Petah Tikva is in Israel.',
   'Petah Tikva is in the country of Israel.'],
  'lang': ['', '', '']},
 'test_category': 'rdf-to-text-generation-test-data-with-refs-en',
 'dbpedia_links': [],
 'links': []}

In [10]:
def extract(data):
    
    triple_set = set()
    concept_set = set()
    relation_set = set()
    
    for i in tqdm(range(len(data))):
        entry = data[i]
        for row in entry['modified_triple_sets']['mtriple_set']:
            for triple in row:
                # split triple by '|' and strip
                triple = tuple([x.strip() for x in triple.split('|')])
                triple_set.add(triple)
                concept_set.update([triple[0], triple[2]])
                relation_set.add(triple[1])
        
    return triple_set, concept_set, relation_set


triple_set_train, concept_set_train, relation_set_train = extract(data_all['train'])
triple_set_dev, concept_set_dev, relation_set_dev = extract(data_all['dev'])
triple_set_test, concept_set_test, relation_set_test = extract(data_all['test'])

# remove relations in test set that are not in train set
relation_set_test = relation_set_test.intersection(relation_set_train)

# remove triples in test set that have relations not in train set
triple_set_test = set([x for x in triple_set_test if x[1] in relation_set_train])

# save concept.txt, relation.txt

concept_set = concept_set_train.union(concept_set_dev).union(concept_set_test)
relation_set = relation_set_train.union(relation_set_dev).union(relation_set_test)


# only keep top 20 most frequent relations
relation_freq = defaultdict(int)
for triple in triple_set_train.union(triple_set_dev).union(triple_set_test):
    relation_freq[triple[1]] += 1
relation_freq = sorted(relation_freq.items(), key=lambda x: x[1], reverse=True)
relation_freq = relation_freq[:]
relation_set = set([x[0] for x in relation_freq])

with open(os.path.join(repo_root, 'data/dbpedia/concept.txt'), 'w') as f:
    for concept in concept_set:
        f.write(concept + '\n')

with open(os.path.join(repo_root, 'data/dbpedia/relation.txt'), 'w') as f:
    for relation in relation_set:
        f.write(relation + '\n')

triple_set_test = set([x for x in triple_set_test if x[1] in relation_set])
triple_set_dev = set([x for x in triple_set_dev if x[1] in relation_set])
triple_set_train = set([x for x in triple_set_train if x[1] in relation_set])



# # remove data in test that the relation is not in relation_set
# data_all_new = {'train': [], 'dev': [], 'test': []}

# for frame in ['train', 'dev', 'test']:
#     data_ori = data_all[frame]
#     for i in tqdm(range(len(data_ori))):
#         entry = data_ori[i]
#         # obtain the triples
#         triples = entry['modified_triple_sets']['mtriple_set']
#         triples = [[tuple([x.strip() for x in triple.split('|')]) for triple in row] for row in triples]
#         # if any of the triples' relation is not in relation_set, skip, otherwise add to new data
#         if any([triple[1] not in relation_set for row in triples for triple in row]):
#             continue
#         data_all_new[frame].append(entry)

# data_all = data_all_new

100%|██████████| 13211/13211 [00:01<00:00, 8183.97it/s]
100%|██████████| 1667/1667 [00:00<00:00, 8390.71it/s]
100%|██████████| 5713/5713 [00:00<00:00, 8561.50it/s]


In [27]:
# data_all_new = {'train': [], 'dev': [], 'test': []}

# for frame in ['train', 'dev', 'test']:
#     data_ori = data_all[frame]
#     for i in tqdm(range(len(data_ori))):
#         entry = data_ori[i]
#         # obtain the triples
#         triples = entry['modified_triple_sets']['mtriple_set']
#         # replace relation with new relation
#         triples_new = []
#         SKIP = False
#         for row in triples:
#             for triple in row:
#                 # split triple by '|' and strip
#                 triple = tuple([x.strip() for x in triple.split('|')])
#                 if triple[1] not in relation_mapping:
#                     SKIP = True
#                     break
#                 if triple[1] in relation_mapping:
#                     triple = [triple[0], relation_mapping[triple[1]], triple[2]]
#                     triple = triple[0] + ' | ' + triple[1] + ' | ' + triple[2]
#                     triples_new.append(triple)
        
#         if SKIP:
#             continue
#         triples = triples_new

#         entry['modified_triple_sets']['mtriple_set'] = triples
#         data_all_new[frame].append(entry)

# # data_all = data_all_new

100%|██████████| 13211/13211 [00:01<00:00, 8258.26it/s]
100%|██████████| 1667/1667 [00:00<00:00, 8657.17it/s]
100%|██████████| 5713/5713 [00:00<00:00, 8952.84it/s]


In [39]:
data_all['train'][121]['modified_triple_sets']['mtriple_set']

['Ardmore_Airport_(New_Zealand) | TransportationCharacteristics | "07/25"']

# Data Curation by Merging different graph-text pairs

In [5]:
import random

def combine_samples_once(dataset, min_size=6):
    combined_triples = []
    combined_text = []
    current_size = 0

    dataset = dataset.shuffle()
    
    for sample in dataset:
        # Extract all triples from the sample
        triples = sample['modified_triple_sets']['mtriple_set']
        text = random.choice(sample['lex']['text'])  # Randomly sample a text
        size = sample['size']

        SKIP = False
        for row in triples:
            for triple in row:
                triple = tuple([x.strip() for x in triple.split('|')])
                rel = triple[1]
                if rel not in relation_set:
                    SKIP = True
                    break
        
        if SKIP:
            continue

        # Add all triples to the combined results
        combined_triples.extend(triples)
        combined_text.append(text)
        current_size += size

        # If the combined size is greater than or equal to the threshold, stop combining
        if current_size >= min_size:
            break
        
    # convert combined triples [[], [], []] to []
    combined_triples = [item for sublist in combined_triples for item in sublist]
    # Return the combined results
    return {
        'combined_triples': tuple(combined_triples),  # Use tuple to ensure hashability
        'combined_text': " ".join(combined_text),
        'combined_size': current_size
    }

def generate_unique_combinations(dataset, num_combinations=10000, min_size=10):
    unique_combinations = set()
    results = []

    # Initialize tqdm progress bar
    with tqdm(total=num_combinations) as pbar:
        # Keep generating until we reach the required number of unique combinations
        while len(unique_combinations) < num_combinations:
            try:
                combined_sample = combine_samples_once(dataset, min_size)
            except:
                continue

            # Sort the combined triples to avoid duplicates due to different order
            sorted_triples = tuple(sorted(combined_sample['combined_triples']))
            
            # Ensure the sorted combination of triples is unique
            if sorted_triples not in unique_combinations:
                unique_combinations.add(sorted_triples)
                combined_sample['combined_triples'] = sorted_triples  # Ensure the result keeps sorted triples
                results.append(combined_sample)
                
                # Update progress bar
                pbar.update(1)

    return results

# combined_sample = combine_samples_once(data_all['train'], min_size=10)

def generate_new_data(split, num_combinations=10000, min_size=4):
    unique_combinations = generate_unique_combinations(data_all[split], num_combinations=num_combinations, min_size=min_size)
    # save the unique combinations
    with open(os.path.join(repo_root, f'data/webnlg/webnlg_{split}_combinations.json'), 'w') as f:
        json.dump(unique_combinations, f)


In [6]:
num_dict = {
    'train': 5000,
    'dev': 2000,
    'test': 1000
}

for split in ['train', 'dev', 'test']:
    generate_new_data(split, num_combinations=num_dict[split], min_size=4)


100%|██████████| 5000/5000 [01:20<00:00, 62.38it/s]
100%|██████████| 2000/2000 [00:23<00:00, 83.52it/s]
100%|██████████| 1000/1000 [00:36<00:00, 27.71it/s]


In [7]:
def mix_combinations(positive_samples, num_mixed_samples):
    unique_mixed_samples = set()  # To store unique mixed samples
    mixed_results = []

    # Initialize tqdm progress bar
    with tqdm(total=num_mixed_samples) as pbar:
        while len(unique_mixed_samples) < num_mixed_samples:
            # Step 1: Randomly select two different positive samples
            sample_a = random.choice(positive_samples)
            sample_b = random.choice(positive_samples)
            
            # Ensure we are not mixing the same sample
            if sample_a == sample_b:
                continue
            
            # Step 2: Create a new mixed sample (a1, b2) - a1 (triples), b2 (text)
            mixed_triples = sample_a['combined_triples']
            mixed_text = sample_b['combined_text']
            
            # Step 3: Sort the combined triples to avoid duplicates due to different order
            sorted_triples = tuple(sorted(mixed_triples))

            # Step 4: Ensure the mixed sample is unique
            if (sorted_triples, mixed_text) not in unique_mixed_samples:
                # Add to the set of unique samples
                unique_mixed_samples.add((sorted_triples, mixed_text))

                # Add the mixed sample to results
                mixed_results.append({
                    'combined_triples': sorted_triples,
                    'combined_text': mixed_text,
                    'combined_size': len(sorted_triples)
                })
                
                # Update the progress bar
                pbar.update(1)

    return mixed_results


def generate_negative_data(split, num_combinations=10000, min_size=8):
    positive = generate_unique_combinations(data_all[split], num_combinations=num_combinations, min_size=min_size)
    negative = mix_combinations(positive, num_mixed_samples=num_combinations)

    # save the unique combinations
    with open(os.path.join(repo_root, f'data/webnlg/webnlg_{split}_negative.json'), 'w') as f:
        json.dump(negative, f)


for split in ['train', 'dev', 'test']:
    generate_negative_data(split, num_combinations=num_dict[split])

100%|██████████| 5000/5000 [03:04<00:00, 27.11it/s]
100%|██████████| 5000/5000 [00:00<00:00, 238440.08it/s]
100%|██████████| 2000/2000 [00:44<00:00, 44.62it/s] 
100%|██████████| 2000/2000 [00:00<00:00, 236518.68it/s]
100%|██████████| 1000/1000 [01:11<00:00, 14.04it/s]
100%|██████████| 1000/1000 [00:00<00:00, 344983.06it/s]


In [14]:
# Prepare `statement` data following CommonsenseQA, OpenBookQA
webnlg_root = f'{repo_root}/data/webnlg'
os.system(f'mkdir -p {webnlg_root}/statement')

for fname in ["train", 'dev', "test"]:
    # read repo_root/data/webnlg/webnlg_{frame}_combinations.json as positive
    with open(f'{webnlg_root}/webnlg_{fname}_combinations.json') as f:
        positive = json.load(f)
    with open(f'{webnlg_root}/webnlg_{fname}_negative.json') as f:
        negative = json.load(f)
    
    def process_data(data, label=0):
        examples = []
        
        for i in tqdm(range(len(data))):
            line = data[i]
            _id  = f"{fname}-{i:05d}-{label}"

            # if label is 0, then answerKey is A, otherwise B
            answerKey = 'A' if label == 0 else 'B'
            stem = line['combined_text']
            triples = line['combined_triples']
            stmts = stem
            ex_obj    = {"id": _id, 
                        "question": {"stem": stem, "choices": [{'text': ""}], 'triples': triples}, 
                        "answerKey": answerKey, 
                        "statements": stmts
                        }
            examples.append(ex_obj)
        
        return examples

    pos_examples = process_data(positive, label=1)
    neg_examples = process_data(negative, label=0)

    # combine positive and negative examples
    all_examples = pos_examples + neg_examples
    random.shuffle(all_examples)
    
    with open(f'{webnlg_root}/statement/{fname}.statement.jsonl', 'w') as fout:
        for dic in all_examples:
            print(json.dumps(dic), file=fout)

100%|██████████| 5000/5000 [00:00<00:00, 684292.75it/s]


100%|██████████| 5000/5000 [00:00<00:00, 70021.07it/s]
100%|██████████| 2000/2000 [00:00<00:00, 660936.65it/s]
100%|██████████| 2000/2000 [00:00<00:00, 651592.98it/s]
100%|██████████| 1000/1000 [00:00<00:00, 559912.43it/s]
100%|██████████| 1000/1000 [00:00<00:00, 972705.01it/s]


# Filter Statements
Only keep more common relations

# Load KG

In [16]:
import networkx as nx
import os

# read concept.txt, relation.txt
with open(os.path.join(repo_root, 'data/dbpedia/concept.txt'), 'r') as f:
    concept_list = f.read().splitlines()

with open(os.path.join(repo_root, 'data/dbpedia/relation.txt'), 'r') as f:
    relation_list = f.read().splitlines()

id2concept = concept_list
id2relation = relation_list

concept2id = {concept: i for i, concept in enumerate(concept_list)}
relation2id = {relation: i for i, relation in enumerate(relation_list)}

def construct_graph(triple_set):
    graph = nx.MultiDiGraph()
    attrs = set()
    
    for triple in triple_set:
        subj = concept2id[triple[0]]
        obj = concept2id[triple[2]]
        rel = relation2id[triple[1]]
        weight = 1.
        graph.add_edge(subj, obj, rel=rel, weight=weight)
        attrs.add((subj, obj, rel))
        graph.add_edge(obj, subj, rel=rel + len(relation2id), weight=weight)
        attrs.add((obj, subj, rel + len(relation2id)))

    output_path = f"{repo_root}/data/dbpedia/dbpedia.graph"
    nx.write_gpickle(graph, output_path)
    
    return graph


KG = construct_graph(triple_set_train.union(triple_set_dev).union(triple_set_test))

In [None]:
# # for each example in test set, check 
# with open(f'{webnlg_root}/statement/test.statement_ori.jsonl') as f:
#     data = [json.loads(line) for line in f]

# print(len(data))

# def check_statement_test(example):
#     triples = example['question']['choices']
#     for triple in triples:
#         triple = tuple([x.strip() for x in triple.split('|')]) 
#         rel = triple[1]
#         if rel not in relation_list:
#             return False
#     return True

# data = [x for x in data if check_statement_test(x)]

# print(len(data))

# # save the filtered data
# with open(f'{webnlg_root}/statement/test.statement.jsonl', 'w') as f:
#     for dic in data:
#         print(json.dumps(dic), file=f)

        

In [17]:
def process(frame):
    with open (f'{webnlg_root}/statement/{frame}.statement.jsonl') as f:
        stmts = [json.loads(line) for line in f]
    with open(f"{webnlg_root}/grounded/{frame}.grounded.jsonl", 'w') as fout:
        for stmt in tqdm(stmts):
            sent = stmt['question']['stem']
            qc = []
            qc_names = []
            triples = stmt['question']['triples']
            # obtain the entity names, split triples by '|' and strip, choose the first and last element
            for triple in triples:
                triple = [x.strip() for x in triple.split('|')]
                qc_names.extend([triple[0], triple[2]])
                qc.extend([concept2id[triple[0]], concept2id[triple[2]]])
            
            ans = stmt['answerKey']
            out = {'sent': sent, 'ans': ans, 'qc': qc, 'qc_names': qc_names, 'ac': [], 'ac_names': [], 'triples': triples}
            print (json.dumps(out), file=fout)


os.system(f'mkdir -p {webnlg_root}/grounded')
for frame in ['train', 'dev', 'test']:
    process(frame)

100%|██████████| 10000/10000 [00:00<00:00, 52812.82it/s]
100%|██████████| 4000/4000 [00:00<00:00, 52799.05it/s]
100%|██████████| 2000/2000 [00:00<00:00, 48411.27it/s]


## Get KG subgraph

In [18]:
def load_kg():
    global cpnet, cpnet_simple
    cpnet = KG
    cpnet_simple = nx.Graph()
    for u, v, data in cpnet.edges(data=True):
        w = data['weight'] if 'weight' in data else 1.0
        if cpnet_simple.has_edge(u, v):
            cpnet_simple[u][v]['weight'] += w
        else:
            cpnet_simple.add_edge(u, v, weight=w)

load_kg()

In [19]:
from scipy.sparse import csr_matrix, coo_matrix
from multiprocessing import Pool

def concepts2adj(node_ids):
    global id2relation
    cids = np.array(node_ids, dtype=np.int32)
    n_rel = len(id2relation)
    n_node = cids.shape[0]
    adj = np.zeros((n_rel, n_node, n_node), dtype=np.uint8)
    for s in range(n_node):
        for t in range(n_node):
            s_c, t_c = cids[s], cids[t]
            if cpnet.has_edge(s_c, t_c):
                for e_attr in cpnet[s_c][t_c].values():
                    if e_attr['rel'] >= 0 and e_attr['rel'] < n_rel:
                        adj[e_attr['rel']][s][t] = 1
    adj = coo_matrix(adj.reshape(-1, n_node))
    return adj, cids

def concepts_to_adj_matrices_all_pair(data):
    qc_ids, ac_ids = data
    qa_nodes = set(qc_ids) | set(ac_ids)
    schema_graph = sorted(qc_ids) + sorted(ac_ids)
    arange = np.arange(len(schema_graph))
    qmask = arange < len(qc_ids)
    amask = (arange >= len(qc_ids)) & (arange < (len(qc_ids) + len(ac_ids)))
    adj, concepts = concepts2adj(schema_graph)
    return {'adj': adj, 'concepts': concepts, 'qmask': qmask, 'amask': amask, 'cid2score': None}

In [20]:
def generate_adj_data_from_grounded_concepts(grounded_path, cpnet_graph_path, cpnet_vocab_path, output_path, num_processes):
    qa_data = []
    with open(grounded_path, 'r', encoding='utf-8') as fin:
        for line in fin:
            dic = json.loads(line)
            q_ids = set(concept2id[c] for c in dic['qc_names'])
            if not q_ids:
                q_ids = {concept2id['31770']} 
            a_ids = set(concept2id[c] for c in dic['ac_names'])
            if not a_ids:
                a_ids = {concept2id['325']}
            q_ids = q_ids - a_ids
            qa_data.append((q_ids, a_ids))
    
    with Pool(num_processes) as p:
        res = list(tqdm(p.imap(concepts_to_adj_matrices_all_pair, qa_data), total=len(qa_data)))
    
    lens = [len(e['concepts']) for e in res]
    print ('mean #nodes', int(np.mean(lens)), 'med', int(np.median(lens)), '5th', int(np.percentile(lens, 5)), '95th', int(np.percentile(lens, 95)))

    with open(output_path, 'wb') as fout:
        pickle.dump(res, fout)

    print(f'adj data saved to {output_path}')
    print()


In [21]:
os.system(f'mkdir -p {repo_root}/data/webnlg/graph')

for fname in ['train', 'dev', "test"]:
    grounded_path = f"{repo_root}/data/webnlg/grounded/{fname}.grounded.jsonl"
    kg_path       = f"{repo_root}/data/dbpedia/dbpedia.graph"
    kg_vocab_path = f"{repo_root}/data/dbpedia/concept.txt"
    output_path   = f"{repo_root}/data/webnlg/graph/{fname}.graph.adj.pk"

    generate_adj_data_from_grounded_concepts(grounded_path, kg_path, kg_vocab_path, output_path, 10)

100%|██████████| 10000/10000 [00:00<00:00, 14600.51it/s]


mean #nodes 10 med 11 5th 6 95th 16
adj data saved to ../../..//data/webnlg/graph/train.graph.adj.pk



100%|██████████| 4000/4000 [00:00<00:00, 12190.16it/s]


mean #nodes 10 med 11 5th 6 95th 16
adj data saved to ../../..//data/webnlg/graph/dev.graph.adj.pk



100%|██████████| 2000/2000 [00:00<00:00, 14864.06it/s]


mean #nodes 10 med 11 5th 6 95th 15
adj data saved to ../../..//data/webnlg/graph/test.graph.adj.pk

