In [None]:
from pydantic import BaseModel
from typing import List, Optional, Dict, Tuple, Set, Union

from collections import defaultdict

import dataset

In [None]:
# Flag forwarded to the dataset loaders below.
# NOTE(review): presumably forces cached data to be rebuilt -- confirm in the dataset module.
regenerate = False

statements = dataset.load_statements(regenerate=regenerate)

# Index the statements by uid so a statement can be fetched directly from its id.
statements_by_uid = dict((s.uid, s) for s in statements)

In [None]:
# Build the list of table names: three pseudo-tables for questions / right answers /
# wrong answers take the first three slots, followed by the entries of the table
# index file (one per line, '.tsv' removed).
with open("../tg2020task/tableindex.txt", "rt") as f:
    table_names = ['Q', 'A-right', 'A-wrong']
    table_names.extend(entry.strip().replace('.tsv', '') for entry in f)

# Reverse lookup: table name -> position in table_names.
name_to_table_idx = dict(zip(table_names, range(len(table_names))))

# Show the first few names as the cell output.
table_names[:6]

In [None]:
# Gather the questions of every fold into a single list.
qanda = []
for fold in ('train', 'dev', 'test'):
    fold_qas = dataset.load_qanda(fold, regenerate=regenerate)
    if fold != 'test':
        # Train set has 1 question without explanations: Mercury_7221305.
        # Questions with an empty gold explanation are dropped; the test fold
        # is kept unfiltered.
        fold_qas = [qa for qa in fold_qas if len(qa.explanation_gold) > 0]
    qanda.extend(fold_qas)

In [None]:
class Node(BaseModel):
    """A single graph node: a statement, a question, or one answer to a question.

    Exactly which kind a node is gets recorded by the ``is_*`` flags, all of
    which default to False.
    """
    # Identifier of the node (plain string or dataset.UID).
    id: Union[str, dataset.UID]

    # Kind flags.
    is_statement: bool = False
    is_question: bool = False
    # Count attached to question nodes; stays 0 unless given.
    n_q: int = 0
    is_ansY: bool = False
    is_ansN: bool = False

    # Original text and the keywords extracted from it.
    raw_txt: str
    keywords: dataset.Keywords

    # Table the node is associated with, kept as a string.
    table: str

In [None]:
# All graph nodes, in insertion order; a node's position in this list is its graph index.
graph_nodes: List[Node] = list()

In [None]:
# Add one graph node per statement.
for s in statements:
    graph_nodes.append(
        Node(
            id=s.uid,
            is_statement=True,
            keywords=s.keywords,
            raw_txt=s.raw_txt,
            # Node.table is declared as `str` while name_to_table_idx yields an int.
            # Convert explicitly instead of relying on implicit int -> str coercion,
            # which pydantic v1 performs silently but pydantic v2 rejects by default.
            table=str(name_to_table_idx[s.table]),
        )
    )

In [None]:
# Add one node per question, immediately followed by one node per answer option.
# Node.table is declared as `str` while name_to_table_idx yields an int, so the
# index is converted explicitly: pydantic v1 would coerce it silently, but
# pydantic v2 rejects int -> str by default.
for qa in qanda:
    graph_nodes.append(
        Node(
            id=qa.question_id,
            is_question=True,
            n_q=len(qa.answers),
            keywords=qa.question.keywords,
            raw_txt=qa.question.raw_txt,
            table=str(name_to_table_idx['Q']),
        )
    )
    for i, ans in enumerate(qa.answers):
        # The answer at position 0 is flagged as the right one (is_ansY);
        # every other position is flagged as wrong (is_ansN).
        answer_table = 'A-right' if i == 0 else 'A-wrong'
        graph_nodes.append(
            Node(
                id=f"{qa.question_id}_A{i}",
                is_ansY=(i == 0),
                is_ansN=(i > 0),
                keywords=ans.keywords,
                raw_txt=ans.raw_txt,
                table=str(name_to_table_idx[answer_table]),
            )
        )

In [None]:
# Total number of graph nodes, with thousands separators.
print("{:,}".format(len(graph_nodes)))  # 33,872

In [None]:
# Inverted index: keyword -> positions (in graph_nodes) of the nodes carrying it.
# The edges are derived from this mapping afterwards.
kw_to_graph_idx = defaultdict(list)
keyword_hits = (
    (kw, idx)
    for idx, node in enumerate(graph_nodes)
    for kw in node.keywords
)
for kw, idx in keyword_hits:
    kw_to_graph_idx[kw].append(idx)
print(len(kw_to_graph_idx))  # 6540

In [None]:
# Report the keywords attached to more than 500 nodes, with their node counts.
for kw, arr in kw_to_graph_idx.items():
    n_hits = len(arr)
    if n_hits <= 500:
        continue
    print(kw, n_hits)

In [None]:
# Link every ordered pair of distinct nodes that share a keyword. A pair sharing
# several keywords is emitted once per keyword, so this list contains duplicates.
graph_edges = []
for kw, arr in kw_to_graph_idx.items():
    graph_edges.extend((i, j) for i in arr for j in arr if i != j)
print(f"{len(graph_edges):,}")  # 34,527,162

In [None]:
# Collapse repeated (i, j) pairs into a single edge each.
graph_edges = set(graph_edges)

# Edge density: unique directed edges as a percentage of n_nodes * n_nodes.
n_edges = len(graph_edges)
n_nodes = len(graph_nodes)
edge_pct = n_edges / n_nodes / n_nodes * 100.
print(f"n_edges={n_edges:,}, edge_fraction={edge_pct:.2f}%")
# n_edges=31,695,592, edge_fraction=2.76%