In [None]:
from pydantic import BaseModel
from typing import List, Optional, Dict, Tuple, Set, Union

from collections import defaultdict

import dataset

In [None]:
regenerate = False  # forwarded to the dataset loaders (presumably forces rebuilding cached data when True — confirm in dataset.py)
statements = dataset.load_statements(regenerate=regenerate)
# Look-up: statement uid -> statement object (a repeated uid keeps its last occurrence)
statements_by_uid = dict((stmt.uid, stmt) for stmt in statements)

In [None]:
with open("../tg2020task/tableindex.txt", "rt") as table_index_file:
    # Three pseudo-tables for question / answer nodes come first, then one entry per fact table
    table_names: List[str] = ['Q', 'A-right', 'A-wrong'] + [
        line.strip().replace('.tsv', '') for line in table_index_file
    ]
# Table name -> position in table_names (used as Node.table_idx and as the embedding index)
name_to_table_idx: Dict[str, int] = {name: pos for pos, name in enumerate(table_names)}
table_names[:6]

In [None]:
# Gather the questions of all folds into one list
qanda = []
for fold in ('train', 'dev', 'test'):
    # Train set has 1 question without explanations: Mercury_7221305 -> it is dropped.
    # Test questions are kept unconditionally (presumably they carry no gold explanations).
    qanda.extend(
        qa for qa in dataset.load_qanda(fold, regenerate=regenerate)
        if fold == 'test' or len(qa.explanation_gold) > 0
    )

In [None]:
class Node(BaseModel):
    """One vertex of the keyword graph: a fact-table statement, a question, or one answer option."""
    id: Union[str, dataset.UID]     # statement uid, question_id, or "<question_id>_A<i>" for answers
    is_statement: bool = False
    is_question: bool = False
    n_ans: int = 0                  # number of answer options (filled in on question nodes)
    is_ansY: bool = False           # answer at index 0 (stored in the 'A-right' pseudo-table)
    is_ansN: bool = False           # answer at index > 0 (stored in the 'A-wrong' pseudo-table)
    raw_txt: str
    keywords: dataset.Keywords
    table_idx: int                  # position in table_names

In [None]:
# All graph vertices; a node's position in this list is its node index (see id_to_graph_node_idx)
graph_nodes: List[Node] = list()

In [None]:
# Add one node per base statement, skipping (and reporting) repeated uids
statements_existing = set()
for s in statements:
    if len(s.uid) != 19:
        continue  # Only do base statements (not combos)  FIXME
    if s.uid in statements_existing:
        print(f"Duplicate statement ignored : {s.uid}")
        continue
    graph_nodes.append(Node(
        id=s.uid,
        is_statement=True,
        keywords=s.keywords,
        raw_txt=s.raw_txt,
        table_idx=name_to_table_idx[s.table],
    ))
    statements_existing.add(s.uid)

In [None]:
# Add a node for each question, followed by one node per answer option
for qa in qanda:
    graph_nodes.append(Node(
        id=qa.question_id, is_question=True, n_ans=len(qa.answers),
        keywords=qa.question.keywords, raw_txt=qa.question.raw_txt,
        table_idx=name_to_table_idx['Q'],
    ))
    for i, ans in enumerate(qa.answers):
        answer_is_first = (i == 0)  # index 0 is treated as the right answer ('A-right')
        graph_nodes.append(Node(
            id=f"{qa.question_id}_A{i}",
            is_ansY=answer_is_first, is_ansN=not answer_is_first,
            keywords=ans.keywords, raw_txt=ans.raw_txt,
            table_idx=name_to_table_idx['A-right' if answer_is_first else 'A-wrong'],
        ))

In [None]:
# Sanity check: node ids that occur more than once (expected to display an empty list)
dups = defaultdict(int)
for graph_node in graph_nodes:
    dups[graph_node.id] += 1
[node_id for node_id, count in dups.items() if count > 1]

In [None]:
# Form a quick look-up from statement/question/answer id to node index
id_to_graph_node_idx = { n.id:i for i,n in enumerate(graph_nodes) }

print(f"{len(graph_nodes):,} == {len(id_to_graph_node_idx):,}") # 33,872
# A repeated id would silently keep only its last index, leaving the earlier node
#   unreachable by id; fail loudly instead of mis-indexing later.
assert len(id_to_graph_node_idx) == len(graph_nodes), "duplicate node ids in graph_nodes"

In [None]:
# Inverted index: keyword -> [indices of nodes containing it]; the edges are derived from it
kw_to_graph_idx = defaultdict(list)
for node_pos, graph_node in enumerate(graph_nodes):
    for keyword in graph_node.keywords:
        kw_to_graph_idx[keyword].append(node_pos)
print(len(kw_to_graph_idx))  # 6540 distinct keywords

In [None]:
# Show the very frequent keywords: each contributes about len(arr)**2 edges in the next cell
for kw, arr in kw_to_graph_idx.items():
    if len(arr) <= 500:
        continue
    print(kw, len(arr))

In [None]:
# Connect every ordered pair of distinct nodes sharing a keyword.
#   A pair sharing several keywords is appended once per shared keyword (deduplicated next).
graph_edges = [
    (i, j)
    for arr in kw_to_graph_idx.values()
    for i in arr
    for j in arr
    if i != j
]
print(f"{len(graph_edges):,}")  # 34,518,168

In [None]:
# Remove duplicate links. The resulting order is arbitrary but reproducible across runs:
#   in CPython, int/tuple hashes are not salted by PYTHONHASHSEED.
graph_edges = list(set(graph_edges))
print(
    f"n_edges={len(graph_edges):,}  "
    f"edge_fraction={len(graph_edges)/len(graph_nodes)/len(graph_nodes)*100.:.2f}%"
)
# n_edges=31,687,626  edge_fraction=2.76%

In [None]:
import os

import numpy as np
import torch
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

# Compact fixed-width numpy printing, handy when eyeballing score arrays
np.set_printoptions(precision=2, suppress=True, floatmode='fixed', sign=' ')

# Seed the python / numpy / torch RNGs before any model weights are created
pl.seed_everything(seed=42)

In [None]:
# Number of top retrieved statements per question that the re-ranker looks at
RANK_MAX = 512

In [None]:
from torch_geometric.data import Data

In [None]:
# Node features + edges of the full keyword graph, held in one torch_geometric Data object.
# Only table_idx, bools and edge_index are gathered from it per question (see make_subgraph).

# Per-node re-ranker score: placeholder zeros here, each question's subgraph gets its own.
#   Kept 1-D (num_nodes,) like the per-question 'ranker', because the models' forward()
#   unsqueeze x to (num_nodes, 1) themselves; the previous (num_nodes, 1) shape would not concat.
ranker = torch.zeros(len(graph_nodes), dtype=torch.float32)

# This is auxiliary data (embeddings, etc):
#   table_idx : (num_nodes, 1) long  -> index into the models' table embedding
#   bools     : (num_nodes, 4)       -> [is_statement, is_question, is_ansY, is_ansN]
#     stored as float32 so concatenating with float features needs no dtype promotion.
table_idx = torch.tensor([[n.table_idx] for n in graph_nodes], dtype=torch.long)
bools     = torch.tensor([[n.is_statement, n.is_question, n.is_ansY, n.is_ansN]
                          for n in graph_nodes], dtype=torch.float32)

# graph_edges holds (src, dst) pairs -> (num_edges, 2); view() keeps 2 columns even if there are no edges
edge_index_t = torch.tensor(graph_edges, dtype=torch.long).view(-1, 2)

graph_data = Data(x=ranker, table_idx=table_idx, bools=bools,
                  edge_index=edge_index_t.t().contiguous())  # (2, num_edges), as suggested in the intro docs
graph_data.num_nodes, graph_data.num_edges

In [None]:
# Use the GPU when present; the whole graph (including its edge_index) is moved there once
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
graph_data = graph_data.to(device)

In [None]:
class RerankGraphDataset(Dataset):
    """Per-question re-ranking examples over the global keyword graph.

    Item ``idx`` describes one question as ``RANK_MAX + 1 + n_ans`` slots of graph node indices:
      * slots ``[0, RANK_MAX)``   : the top retrieved statements, in predicted order
      * slot ``RANK_MAX``         : the question itself
      * slots ``RANK_MAX+1 ...``  : its answer options, in order
    together with a prior score (``ranker``) and a {0,1} gold label for every slot.

    Relies on the module-level ``graph_nodes``, ``id_to_graph_node_idx`` and ``RANK_MAX``.
    """

    def __init__(self, fold='dev', preds_file='../predictions/predict.FOLD.baseline-retrieval.txt',
                ):
        """
        Args:
            fold: fold passed to ``dataset.load_qanda`` ('train', 'dev' or 'test').
            preds_file: retrieval prediction file; the literal 'FOLD' is replaced by ``fold``.
                Each non-blank line is ``<question_id>\\t<statement_uid>``.
        """
        self.fold  = fold

        regenerate=False
        # Train set has 1 question without explanations: Mercury_7221305
        self.qanda = [qa for qa in dataset.load_qanda(fold, regenerate=regenerate)
                         if fold=='test' or len(qa.explanation_gold)>0]

        # Load up prediction set : qa_id -> [statement uids, in file (predicted) order]
        self.preds = self._read_preds(preds_file.replace('FOLD', self.fold))

    @staticmethod
    def _read_preds(path):
        """Parse a prediction file into a ``defaultdict(list)``: question_id -> [uids in file order].

        Blank lines (e.g. trailing ones) are skipped instead of crashing the tab split.
        Raises ValueError for a non-blank line without exactly two tab-separated fields.
        """
        preds = defaultdict(list)
        with open(path, 'rt') as f:
            for line in f:  # stream the file rather than materialising readlines()
                line = line.strip()
                if not line:
                    continue
                qid, uid = line.split('\t')
                preds[qid].append(uid)
        return preds

    def __len__(self):
        return len(self.qanda)

    def __getitem__(self, idx):
        """Build node slots, prior scores and gold labels for question ``idx``.

        Returns:
            dict with ``idx``, ``q_id`` and three numpy arrays of length ``RANK_MAX+1+n_ans``:
            ``node_idx`` (int32 indices into graph_nodes), ``ranker`` (float32: 1.0 for the top
            prediction, decreasing by 1/RANK_MAX per rank, 0 elsewhere) and ``labels``
            (float32: 1. where the predicted statement is in the gold explanation).
        """
        qa = self.qanda[idx]
        q_id=qa.question_id

        # .get() instead of [] so a question without predictions does not insert an
        #   empty entry into the defaultdict as a side effect
        pred = self.preds.get(q_id, [])[:RANK_MAX]
        pred_uid_to_idx = { uid:i for i, uid in enumerate(pred) }  # Needed for explanation_gold

        # Edges among the returned nodes are computed later (make_subgraph) with
        #     https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.subgraph

        q_node_idx = id_to_graph_node_idx[q_id]  # KeyError if the question has no graph node
        n_ans = graph_nodes[q_node_idx].n_ans
        if n_ans<=0:  # question node without answer options
            print(f"Question not found : {q_id}")
            n_ans=0

        n_nodes  = RANK_MAX+1+n_ans
        node_idx = np.zeros( (n_nodes,), dtype=np.int32 )    # These are the nodes of interest
        ranker   = np.zeros( (n_nodes,), dtype=np.float32 )  # This is the previous output
        labels   = np.zeros( (n_nodes,), dtype=np.float32 )  # This is the {0,1} target  (BCE likes floats)
        # NOTE(review): with fewer than RANK_MAX predictions the unused slots keep node index 0,
        #   so graph node 0 can appear several times in node_idx.

        for i,uid in enumerate(pred):  # For all the predictions
            node_idx[i]= id_to_graph_node_idx[uid]
            ranker[i]  = 1.0 - (float(i)/RANK_MAX)

        # Here's the question
        node_idx[RANK_MAX] = q_node_idx
        # And the answers
        for a in range(n_ans):
            node_idx[RANK_MAX+1+a] = id_to_graph_node_idx[f"{q_id}_A{a}"]

        for ex in qa.explanation_gold:
            if ex.uid in pred_uid_to_idx:  # These are in same positions as node_idx
                labels[ pred_uid_to_idx[ex.uid] ] = 1.

        return dict( idx=idx, q_id=qa.question_id,
                        node_idx = node_idx,
                        ranker = ranker,
                        labels = labels,
                    )

# Dev-fold dataset; it is the one iterated by the training loop below
ds_dev = RerankGraphDataset(fold='dev')

In [None]:
from torch_geometric.nn import GCNConv

class Net_GCN(torch.nn.Module):
    """Two-layer GCN scoring every node of a question subgraph with a probability in [0, 1].

    Node features: prior ranker score (1) + learned table embedding (8) + node-type flags (4).
    """
    def __init__(self):
        # Previously `super(Net_Simple, self).__init__()`: Net_Simple is not defined,
        #   so constructing this class raised NameError.
        super().__init__()

        table_emb_size, hidden_size=8, 64
        self.table_embedding = torch.nn.Embedding(len(table_names), table_emb_size)

        # ranker, table.embedding, bools
        n_features = 1+table_emb_size+4

        self.conv1 = GCNConv(n_features, hidden_size)
        self.conv2 = GCNConv(hidden_size, 1)
        self.output = torch.nn.Sigmoid()

        self.bx_loss = torch.nn.BCELoss()

    def forward(self, data, edge_index):
        """
        Args:
            data: Data-like object with ``x`` (ranker scores, e.g. torch.Size([517])),
                ``table_idx`` (e.g. [517, 1], long) and ``bools`` (e.g. [517, 4]).
            edge_index: (2, num_edges) edge list over the same node numbering.
        Returns:
            (num_nodes, 1) tensor of sigmoid probabilities.
        """
        ranker, table_idx, bools = data.x, data.table_idx, data.bools

        x_table_emb = self.table_embedding(table_idx).squeeze(1)  # [N, 1, 8] -> [N, 8]

        # Cast the flags explicitly: concatenating int and float tensors otherwise depends
        #   on torch.cat type promotion, which not every PyTorch version performs.
        x = torch.cat([ranker.unsqueeze(-1), x_table_emb, bools.to(x_table_emb.dtype)], dim=-1)

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return self.output(x)

In [None]:
from torch_geometric.nn.conv import GATConv
from torch_geometric.nn.conv import GatedGraphConv

class Net(torch.nn.Module):
    """GatedGraphConv (4 steps) followed by two GAT layers, scoring each node in [0, 1].

    Node features: prior ranker score (1) + learned table embedding (8) + node-type flags (4).
    """
    def __init__(self):
        super(Net, self).__init__()

        table_emb_size, hidden_size=8, 64
        # Modules are created in this fixed order so seeded weight initialisation stays reproducible
        self.table_embedding = torch.nn.Embedding(len(table_names), table_emb_size)

        # ranker, table.embedding, bools
        n_features = 1+table_emb_size+4

        # As long as n_features<hidden_size, this will just be padded out...
        self.gru = GatedGraphConv(hidden_size, num_layers=4, aggr='max')

        # 8 heads x (hidden_size//8) channels, concatenated back to hidden_size
        self.conv1 = GATConv(hidden_size, hidden_size//8, heads=8, )  # dropout=0.6 was tried
        self.conv2 = GATConv(hidden_size, 1, heads=1, concat=False, ) # dropout=0.6 was tried

        self.output = torch.nn.Sigmoid()

        self.bx_loss = torch.nn.BCELoss()

    def forward(self, data, edge_index):
        """
        Args:
            data: Data-like object with ``x`` (ranker scores, shape (num_nodes,)),
                ``table_idx`` ((num_nodes, 1) long) and ``bools`` ((num_nodes, 4)).
            edge_index: (2, num_edges) edge list over the same node numbering.
        Returns:
            (num_nodes, 1) tensor of sigmoid probabilities.
        """
        ranker, table_idx, bools = data.x, data.table_idx, data.bools
        x_table_emb = self.table_embedding(table_idx).squeeze(1)  # (N, 1, 8) -> (N, 8)

        # Explicit cast so the concat does not depend on torch.cat int/float type promotion
        x = torch.cat([ranker.unsqueeze(-1), x_table_emb, bools.to(x_table_emb.dtype)], dim=-1)

        x = self.gru(x, edge_index)

        x = F.elu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return self.output(x)

In [None]:
# Build the model on the target device, its optimizer, and switch to training mode
model = Net()
model = model.to(device)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()

In [None]:
import torch_geometric.utils

def make_subgraph(d):
    """Extract the induced subgraph of ``graph_data`` for one dataset item.

    Args:
        d: item from RerankGraphDataset -- dict with numpy arrays 'node_idx', 'ranker', 'labels'.
    Returns:
        (Data, edge_index): Data holds x (= ranker), table_idx, bools and y (= labels) for the
        nodes listed in d['node_idx'], in that order; edge_index is the edge list among those
        nodes, relabelled by torch_geometric.utils.subgraph to positions in d['node_idx'].
    """
    subset = torch.tensor(d['node_idx'], dtype=torch.long).to(device)
    edge_index_subset, _ = torch_geometric.utils.subgraph(
        subset, graph_data.edge_index, relabel_nodes=True)
    # NOTE(review): if node_idx repeats an index (padding slots), only one of the repeated
    #   positions can receive that node's relabelled edges -- confirm against PyG's subgraph.

    # Gathering with `subset` lines these rows up with the relabelled edge_index
    table_idx = graph_data.table_idx[subset]
    bools = graph_data.bools[subset]

    ranker = torch.tensor(d['ranker'], requires_grad=False).to(device)
    labels = torch.tensor(d['labels'], requires_grad=False).to(device)

    sub_data = Data(x=ranker, table_idx=table_idx, bools=bools, y=labels)
    return sub_data, edge_index_subset

In [None]:
from tqdm.notebook import tqdm

def run_training(d, steps=10):
    """Take ``steps`` optimizer steps of the global ``model`` on one question's subgraph.

    Args:
        d: one RerankGraphDataset item (see make_subgraph).
        steps: number of gradient steps on this item; must be >= 1.
    Returns:
        float BCE loss of the last step (computed before that step's parameter update).
    Raises:
        ValueError: if steps < 1 (previously this surfaced as UnboundLocalError on `loss`).
    """
    if steps < 1:
        raise ValueError(f"steps must be >= 1, got {steps}")

    subgraph, edge_index_subgraph = make_subgraph(d)
    targets = subgraph.y.unsqueeze(-1)  # (N,) -> (N, 1) to match the model's output; loop-invariant
    for _ in range(steps):
        optimizer.zero_grad()
        ranks_pred = model(subgraph, edge_index_subgraph)
        loss = model.bx_loss(ranks_pred, targets)
        loss.backward()
        optimizer.step()
    return loss.item()  # Final loss

In [None]:
# Train: each epoch visits every dev question once, taking 10 optimizer steps on each
for epoch in range(10):
    item_losses = []
    for i in tqdm(range(len(ds_dev)), desc=f"epoch[{epoch}]"):
        item_losses.append(run_training(ds_dev[i], steps=10))
    print(f"epoch[{epoch}] : loss_bx.mean()={sum(item_losses)/len(item_losses):.4f}")