In [None]:
import os
import time

import numpy as np
np.set_printoptions(precision=2, suppress=True, floatmode='fixed', sign=' ')

from pydantic import BaseModel
from typing import List, Optional, Dict, Tuple, Set, Union

from collections import defaultdict

import dataset

In [None]:
# Load every statement once, then index the statements by uid for direct lookup
regenerate = False
statements = dataset.load_statements(regenerate=regenerate)
statements_by_uid = dict( (stmt.uid, stmt) for stmt in statements )

In [None]:
# Three pseudo-tables (question / right answer / wrong answer) come first,
# followed by one name per line of the index file with '.tsv' removed.
with open("../tg2020task/tableindex.txt", "rt") as f:
    table_names:List[str] = ['Q', 'A-right', 'A-wrong', ]
    table_names.extend( row.strip().replace('.tsv', '') for row in f )
name_to_table_idx:Dict[str,int] = dict( (name, pos) for pos, name in enumerate(table_names) )
table_names[:6]

In [None]:
qanda = []  # All questions, gathered fold by fold (train, dev, test)
for fold in ('train', 'dev', 'test'):
    # Train set has 1 question without explanations: Mercury_7221305
    #   -> drop questions with no gold explanation, except in the test fold
    qanda.extend( qa for qa in dataset.load_qanda(fold, regenerate=regenerate)
                  if fold=='test' or len(qa.explanation_gold)>0 )

In [None]:
class Node(BaseModel):
    """One vertex of the graph: a statement, a question, or an answer option."""
    id:Union[str, dataset.UID]

    # Flags recording what kind of node this is
    is_statement:bool = False
    is_question:bool = False
    n_ans:int = 0            # number of answer options (filled in for question nodes)
    is_ansY:bool = False
    is_ansN:bool = False

    # Content attached to the node
    raw_txt:str
    keywords:dataset.Keywords
    table_idx:int

In [None]:
graph_nodes:List[Node] = []

In [None]:
# One graph node per base statement; a uid seen before is reported and skipped
statements_existing = set()
for s in statements:
    if len(s.uid)!=19:
        continue # Only do base statements (not combos)  FIXME
    if s.uid in statements_existing:
        print(f"Duplicate statement ignored : {s.uid}")
        continue
    graph_nodes.append( Node(id=s.uid, is_statement=True,
                             raw_txt=s.raw_txt, keywords=s.keywords,
                             table_idx=name_to_table_idx[s.table], ) )
    statements_existing.add(s.uid)

In [None]:
# Each question adds one question node followed by one node per answer option
for qa in qanda:
    question_node = Node(id=qa.question_id, is_question=True, n_ans=len(qa.answers),
                         raw_txt=qa.question.raw_txt, keywords=qa.question.keywords,
                         table_idx=name_to_table_idx['Q'], )
    graph_nodes.append(question_node)
    for i,ans in enumerate(qa.answers):
        # The answer at position 0 is flagged as the right one, the rest as wrong
        answer_table = 'A-right' if i==0 else 'A-wrong'
        graph_nodes.append( Node(id=f"{qa.question_id}_A{i}",
                                 is_ansY=(i==0), is_ansN=(i>0),
                                 raw_txt=ans.raw_txt, keywords=ans.keywords,
                                 table_idx=name_to_table_idx[answer_table], ) )

In [None]:
# Count how often each node id occurs, and display the ids that occur more than once
dups = defaultdict(int)
for node in graph_nodes:
    dups[node.id] += 1
[ node_id for node_id, count in dups.items() if count>1 ]

In [None]:
# Quick look-up from statement/question/answer id to its position in graph_nodes
id_to_graph_node_idx = dict( (node.id, pos) for pos, node in enumerate(graph_nodes) )

# The two counts agree only when every node id is unique
print(f"{len(graph_nodes):,} == {len(id_to_graph_node_idx):,}") # 30,856

In [None]:
# Inverted index : keyword -> positions of all nodes carrying it (edges are built from this)
kw_to_graph_idx = defaultdict(list)
for node_pos, node in enumerate(graph_nodes):
    for keyword in node.keywords:
        kw_to_graph_idx[keyword].append(node_pos)
print(len(kw_to_graph_idx)) # 6527

In [None]:
', '.join(f"{kw}={len(arr)}" for kw, arr in kw_to_graph_idx.items() if len(arr)>500)

In [None]:
# Link every ordered pair of distinct nodes that share a keyword
#   (pairs sharing several keywords are repeated here)
graph_edges = []
for arr in kw_to_graph_idx.values():
    graph_edges.extend( (i,j) for i in arr for j in arr if i!=j )
print(f"{len(graph_edges):,}") # 27,271,684    # 6secs

In [None]:
# Collapse repeated (i,j) pairs; sorting gives the edges a fixed order
graph_edges = sorted(set(graph_edges))
n_edges_unique, n_nodes_total = len(graph_edges), len(graph_nodes)
print(f"n_edges={n_edges_unique:,}  "+
      f"edge_fraction={n_edges_unique/n_nodes_total/n_nodes_total*100.:.2f}%")
# n_edges=25,051,930  edge_fraction=2.63%      # 36secs

In [None]:
# Get embedding for all the raw texts
raw_txt_arr = [ n.raw_txt for n in graph_nodes ]

In [None]:
# Fingerprint the full list of raw texts (in node order); the hex digest is
# used below as the file name of the on-disk embedding cache.
from hashlib import blake2b
h = blake2b(digest_size=20)
for txt in raw_txt_arr:
    # NOTE(review): texts are fed back-to-back with no separator, so lists that
    #   differ only in where one text ends and the next begins hash the same.
    #   Adding a separator would change the digest and orphan existing caches.
    h.update(txt.encode('utf-8'))
raw_txt_hash = h.hexdigest()
raw_txt_hash # '18c2325ea80990539f32ab5f97173541007ddef0'

In [None]:
# The embeddings live in a cache file named after the hash of the texts.
# Only when that file is missing is the sentence model loaded and run.
embedding_path="../data/cache/embedding" 
raw_txt_embed_file=f"{embedding_path}/{raw_txt_hash}.npz"

embedding_cached = os.path.isfile(raw_txt_embed_file)
if not embedding_cached:  # Avoid loading model if possible
    os.makedirs(embedding_path, exist_ok=True)

    model_file='distilbert-base-nli-stsb-mean-tokens'

    # Time the model initialisation (including the import) ...
    t0=time.time()
    from sentence_transformers import SentenceTransformer
    model = SentenceTransformer(f"../data/{model_file}")  # <2 secs
    t1=time.time()

    # ... and the embedding pass separately
    raw_txt_embed_arr = model.encode(raw_txt_arr)   #  15secs
    t2=time.time()
    init_secs, embed_secs = t1-t0, t2-t1
    print(f"DONE embedding from scratch : initialisation={init_secs:.2f}secs, embedding={embed_secs:.2f}secs")

    np.savez(raw_txt_embed_file, emb=np.array(raw_txt_embed_arr))
    raw_txt_embed_arr=None  # Drop the in-memory copy; a later cell reloads it from the file

In [None]:
import torch

from torch.nn import functional as F
from torch.utils.data import Dataset

import pytorch_lightning as pl
pl.seed_everything(seed=42)

In [None]:
RANK_MAX=512
#BATCH_SIZE=64  # Set below

In [None]:
from torch_geometric.data import Data, DataLoader
import torch_geometric.utils

In [None]:
# Per-node input column, exposed as `.x` of the full graph :
#   This will be updated for each different question/prediction input
ranker = torch.tensor( [ [0.] for _ in graph_nodes ], dtype=torch.float32, requires_grad=False)

# Auxiliary per-node data : table index and the four node-kind flags
table_idx = torch.tensor( [ [node.table_idx] for node in graph_nodes ],
                          dtype=torch.long, requires_grad=False)
flag_rows = [ [node.is_statement, node.is_question, node.is_ansY, node.is_ansN, ]
              for node in graph_nodes ]
bools     = torch.tensor( flag_rows, dtype=torch.int32, requires_grad=False)

# graph_edges is a list of (i,j) pairs, i.e. one edge per row;
#   it is transposed below so that edge_index has one edge per column
edge_index_t = torch.tensor(graph_edges, dtype=torch.long)

graph_data = Data(x=ranker, table_idx=table_idx, bools=bools, 
                  edge_index=edge_index_t.t().contiguous()) # Like suggested in the intro docs
graph_data.num_nodes, graph_data.num_edges # (30856, 25051930)

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
graph_data = graph_data.to(device)

In [None]:
# Read the cached embeddings; the context manager closes the .npz archive
# (np.load keeps the file open until the NpzFile is closed).
# torch.tensor() copies the array, so nothing refers to the archive afterwards.
with np.load(raw_txt_embed_file) as embed_npz:
    raw_txt_embedding = torch.tensor(embed_npz['emb'], dtype=torch.float32).to(device)
raw_txt_embedding.shape # torch.Size([30856, 768])

In [None]:
class RerankGraphDataset(Dataset):
    """Per-question re-ranking examples cut out of the global keyword graph.

    Item `idx` belongs to question `idx` of the fold.  Each item has
    RANK_MAX slots for the top predicted statements, then one slot for the
    question and one slot per answer option.

    Relies on notebook globals: RANK_MAX, graph_nodes, id_to_graph_node_idx
    and graph_data.
    """
    def __init__(self, fold='dev', preds_file='../predictions/predict.FOLD.baseline-retrieval.txt',
                ):
        """Load the questions of `fold` and the ranked predictions for them.

        `preds_file` is a path template: the text 'FOLD' is replaced by `fold`.
        """
        self.fold  = fold
        
        regenerate=False
        # Train set has 1 question without explanations: Mercury_7221305
        self.qanda = [qa for qa in dataset.load_qanda(fold, regenerate=regenerate)
                         if fold=='test' or len(qa.explanation_gold)>0]
        
        # Load up prediction set
        self.preds = self._load_predictions(preds_file.replace('FOLD', self.fold))
        
        # Stand-in for the previous ranker's output : falls linearly with rank position
        self.ranker_base = np.linspace( 0.95, 0.05, num=RANK_MAX)
        self.return_Data=True

    @staticmethod
    def _load_predictions(preds_path):
        """Read 'question_id<TAB>statement_uid' lines into qa_id -> [uids in file order]."""
        preds=defaultdict(list) # qa_id -> [statements in order]
        with open(preds_path, 'rt') as f:
            for l in f:  # Stream the file rather than loading it all with readlines()
                entry = l.strip()
                if not entry:
                    continue  # A blank (e.g. trailing) line would break the unpack below
                qid, uid = entry.split('\t')
                preds[qid].append(uid)
        return preds
        
    def __len__(self):
        """Number of questions in this fold."""
        return len(self.qanda)
    
    def question_id(self, idx):
        """Return the question id of item `idx`."""
        qa = self.qanda[idx]
        return qa.question_id
    
    def __getitem__(self, idx):  # This corresponds to a specific question
        """Build the example for question `idx`.

        Returns a plain dict (idx, q_id, node_idx, ranker, labels) when
        `self.return_Data` is False, otherwise a `Data` object with
        x=ranker, subset_indices, edge_index (relabelled subgraph) and y=labels.
        """
        qa = self.qanda[idx]
        q_id=qa.question_id
        
        # Want to return a list of nodes in graph_data be retained
        #   Can calculate edges using
        #     https://pytorch-geometric.readthedocs.io/en/latest/modules/utils.html#torch_geometric.utils.subgraph
            
        # And return corresponding 'x' and 'y' values for each of these slots too
        q_node_idx = id_to_graph_node_idx[q_id]
        n_ans = graph_nodes[q_node_idx].n_ans  # This must be a question...
        if n_ans<=0:
            print(f"Question not found : {q_id}")
            n_ans=0

        # Set up the nodes in 'condensed form'
        n_nodes  = RANK_MAX+1+n_ans
        node_idx = np.zeros( (n_nodes,), dtype=np.int32 )    # These are the nodes of interest
        ranker   = np.zeros( (n_nodes,), dtype=np.float32 )  # This is the previous output
        labels   = np.zeros( (n_nodes,), dtype=np.float32 )  # This is the {0,1} target  (BCE likes floats)
        
        ranker[:RANK_MAX] = self.ranker_base[:RANK_MAX]
        
        # NOTE(review): with fewer than RANK_MAX predictions the unused slots keep
        #   node_idx==0 and a non-zero ranker value - confirm this cannot happen
        pred, pred_uid_to_idx=self.preds[q_id][:RANK_MAX], {}
        for i,uid in enumerate(pred):  # For all the predictions
            node_idx[i]= id_to_graph_node_idx[uid]
            pred_uid_to_idx[uid] = i # Needed for explanation_gold population of labels
        
        # Here's the question
        node_idx[RANK_MAX] = q_node_idx
        # And the answers
        for i in range(n_ans):
            node_idx[RANK_MAX+1+i] = id_to_graph_node_idx[f"{q_id}_A{i}"]
        
        # Gold explanations outside the retained predictions are silently ignored
        for ex in qa.explanation_gold:
            if ex.uid in pred_uid_to_idx:  # These are in same positions as node_idx
                labels[ pred_uid_to_idx[ex.uid] ] = 1.
        
        if not self.return_Data:
            return dict( idx=idx, q_id=qa.question_id,
                            node_idx = node_idx, 
                            ranker = ranker,
                            labels = labels,
                        )
        subset = torch.tensor(node_idx, dtype=torch.long)
        edge_index_subset, _ = torch_geometric.utils.subgraph(subset, graph_data.edge_index, 
                                                              relabel_nodes=True)
        ranker    = torch.tensor(ranker, requires_grad=False)
        labels    = torch.tensor(labels, requires_grad=False)

        # table_idx / bools are not stored : they can be recovered later from
        #   graph_data via subset_indices
        return Data(x=ranker, 
                    subset_indices=subset,
                    edge_index=edge_index_subset, 
                    y=labels)

ds_dev  =RerankGraphDataset(fold='dev')
ds_train=RerankGraphDataset(fold='train')
ds_test =RerankGraphDataset(fold='test')
#ds_dev[1]

In [None]:
def RerankGraphDataset_to_cache(ds, fold, regenerate=False):
    """Write every item of `ds` to '../data/cache/{fold}/{idx:05d}.npz'.

    Each file stores the item's x, subset_indices, edge_index and y as numpy
    arrays.  A file that already exists is left alone unless `regenerate` is
    True.  The folder is created if needed.
    """
    folder=f"../data/cache/{fold}"
    os.makedirs(folder, exist_ok=True)
    def as_numpy(p): return p.detach().cpu().numpy()
    cache_file = None  # Stays None for an empty dataset (the final print used to crash)
    for idx in range(len(ds)):
        cache_file = os.path.join(folder, f"{idx:05d}.npz")
        if (idx+1)%100==0: print(cache_file)  # Progress marker every 100 items
        if os.path.isfile(cache_file) and not regenerate:
            continue # Skip this, it's there already
        data = ds[idx]
        np.savez(cache_file, 
                 x=as_numpy(data.x), 
                 subset_indices=as_numpy(data.subset_indices),
                 edge_index=as_numpy(data.edge_index),
                 y=as_numpy(data.y),
        )
    print(f"DONE : {fold} - {cache_file}")

regenerate=False
RerankGraphDataset_to_cache(ds_train, 'train', regenerate=regenerate) # 22x 30sec
RerankGraphDataset_to_cache(ds_dev,   'dev'  , regenerate=regenerate) # 5x  30sec
RerankGraphDataset_to_cache(ds_test,  'test' , regenerate=regenerate) # 16x 30sec

In [None]:
class RerankGraphDatasetFromCache(Dataset):
    """Dataset over the '.npz' files found in '../data/cache/{fold}'.

    Files are taken in sorted name order, so item `idx` is the idx-th file.
    """
    def __init__(self, fold='dev'):
        """List the cache files of `fold`; nothing is loaded until indexed."""
        self.folder=f"../data/cache/{fold}"
        self.files=sorted(f for f in os.listdir(self.folder) if f.endswith('.npz'))
        
    def __len__(self):
        """Number of cached items."""
        return len(self.files)
    
    def __getitem__(self, idx):  # This corresponds to a specific question
        """Load cached item `idx` and rebuild it as a `Data` object."""
        cache_file = os.path.join(self.folder, self.files[idx])
        # Close the archive as soon as the arrays are copied : np.load keeps the
        #   file open otherwise, once per item fetched.  torch.tensor() copies.
        with np.load(cache_file) as data:
            return Data(x=torch.tensor(data['x'], dtype=torch.float),  # , requires_grad=False
                        subset_indices=torch.tensor(data['subset_indices'], dtype=torch.long),
                        edge_index=torch.tensor(data['edge_index'], dtype=torch.long),
                        y=torch.tensor(data['y'], dtype=torch.float),
                    )

ds_train_cached = RerankGraphDatasetFromCache(fold='train')
ds_dev_cached   = RerankGraphDatasetFromCache(fold='dev')
ds_test_cached  = RerankGraphDatasetFromCache(fold='test')

len(ds_train_cached), len(ds_dev_cached), len(ds_test_cached) # (2206, 496, 1664)

In [None]:
ds_train_cached[1] # Data(edge_index=[2, 50730], subset_indices=[517], x=[517], y=[517])

In [None]:
idx=413
ds_dev.return_Data=False
d_plain=ds_dev[idx]

ds_dev.return_Data=True
d_data =ds_dev[idx]

# 0.7083 0.7083 1.0000 : XX...X.....X.................................................... : 
#   # 413 = Mercury_7264023 :: [ CONTAINS.CENTRAL KINDOF.GROUNDING MADEOF.NEG SYNONYMY.LEXGLUE ]

In [None]:
# This uses just the 'plain' data
print(d_plain['node_idx'].shape)

# Loop-local names are used on purpose : reusing `idx` / `table_idx` here would
#   overwrite the question index and the per-node tensor defined in earlier cells
plain_node_idx = list(d_plain['node_idx'])
for graph_idx in [1684, 4411, 5149, 7828,   # These are the gold explanation statements 
                  22123, 22124, 22125, 22126, 22127  # These are the question + answer nodes
                 ]:
    node = graph_nodes[graph_idx]
    # slot within the condensed node list, node id, then table index=name
    print(f"{plain_node_idx.index(graph_idx):3d} {node.id:>25s} : {node.table_idx:2d}={table_names[node.table_idx]}")

In [None]:
def make_subgraph(d): # This is 'outside' the main ds_data - for testing/debugging
    """Turn a 'plain' dataset item (dict with node_idx, ranker, labels) into a Data object.

    Edges are cut out of the global graph_data and relabelled; table_idx and
    bools are gathered from graph_data for the selected nodes.  Tensors built
    here are moved to `device`.
    """
    node_ids = torch.tensor(d['node_idx'], dtype=torch.long).to(device)
    sub_edges = torch_geometric.utils.subgraph(node_ids, graph_data.edge_index, relabel_nodes=True)[0]

    return Data(
        x=torch.tensor(d['ranker'], requires_grad=False).to(device),
        # 'Unwrapped' representation : these become 'relabelled' in the same way
        table_idx=graph_data.table_idx[node_ids],
        bools=graph_data.bools[node_ids],
        # Kept for deferred unwrapping
        subset_indices=node_ids,
        edge_index=sub_edges,
        y=torch.tensor(d['labels'], requires_grad=False).to(device),
    )

In [None]:
subgraph = make_subgraph(d_plain)

# 'Deferred' route : gather the attributes from the full graph via subset_indices.
#   Done once, outside the loop, and under new names so that the notebook-level
#   `table_idx` / `bools` tensors and `idx` are not overwritten.
sub_table_idx = graph_data.table_idx[subgraph.subset_indices]
sub_bools     = graph_data.bools[subgraph.subset_indices]

for slot in [0,1,5,11,   # These are the gold explanation statements 
             512, 513, 514, 515, 516  # These are the question + answer nodes
            ]:
    # First line : values stored on the subgraph itself
    print(f"{slot:3d} -> {d_plain['node_idx'][slot]:5d}, " # Agrees with above
          +f"table_idx={subgraph['table_idx'][slot].item():2d}, "
          +f"bools={list(subgraph['bools'][slot].cpu().numpy())}"
         )
    # Second line : the same values via the deferred route (should match)
    print(f"{' '*21}{sub_table_idx[slot].item():5d}, bools={list(sub_bools[slot].cpu().numpy())}")

In [None]:
edge_index_subgraph_list = list(subgraph.edge_index.t().cpu().numpy())
print(sorted([ list(pair) 
               for pair in edge_index_subgraph_list 
               if pair[0]==5]))

In [None]:
# Density of edges within this subgraph
subgraph.edge_index, subgraph.edge_index.shape[1]/(subgraph.subset_indices.shape[0]**2) * 100.

In [None]:
from torch_geometric.nn import GCNConv
from torch_geometric.nn.conv import GATConv, GatedGraphConv

class Net(torch.nn.Module):
    """Graph re-ranker : scores every node of a per-question subgraph.

    Node features are the re-scaled ranker value, a learned table embedding and
    the four node-kind flags.  They pass through a GatedGraphConv, a GATConv and
    a linear layer; the re-scaled ranker value is added back before the sigmoid.
    Reads table_idx / bools from the notebook-global `graph_data`.
    """
    def __init__(self):
        """Create the layers and the BCE loss (`bx_loss`) used by the training code."""
        super().__init__()
        
        table_emb_size, hidden_size=8, 64
        
        # One embedding row per entry of table_names (including the Q / A pseudo-tables)
        self.table_embedding = torch.nn.Embedding(
            len(table_names), table_emb_size, 
            max_norm=1.0,
        ) # , requires_grad=True is the default::
        
        # ranker, table.embedding, bools
        #   (n_features is only referenced by the commented-out GCN layers below)
        n_features = 1+table_emb_size+4
        
        #self.conv1 = GCNConv(n_features, hidden_size)
        #self.conv2 = GCNConv(hidden_size, 1)
        
        # As long as n_features<hidden_size, this will just be padded out...
        self.gru = GatedGraphConv(hidden_size, num_layers=4, aggr='max')

        # 8 heads of hidden_size//8 channels each
        self.att_conv1 = GATConv(hidden_size, hidden_size//8, heads=8, )  # dropout=0.6
        #self.att_conv2 = GATConv(hidden_size, 1, heads=1, concat=False, ) # dropout=0.6
        
        self.final  = torch.nn.Linear(hidden_size, 1)
        self.output = torch.nn.Sigmoid()
        
        # Applied to the sigmoid output by the training / validation functions
        self.bx_loss = torch.nn.BCELoss()

    def forward(self, data):
        """Return one sigmoid score per node of `data`.

        `data` must provide x (ranker values), subset_indices (rows of the
        global graph_data) and edge_index.  The output has a trailing
        dimension of 1 (torch.Size([517, 1]) in the recorded debug prints).
        """
        subset = data.subset_indices   # This is sent over 'per qa'
        ranker = data.x                # This is sent over 'per qa'
        ranker = ranker*4.0-2.0        # Scale to be the approx same range as the LSTM one (+/-2)
        
        table_idx = graph_data.table_idx[subset]  # This uses the on-device full graph
        bools     = graph_data.bools[subset]      # This uses the on-device full graph

        edge_index = data.edge_index
        
        #print(f"ranker.size() : {ranker.size()}")        # torch.Size([517])
        #print(f"table_idx.size() : {table_idx.size()}")  # table_idx.size() : torch.Size([517, 1])
        #print(f"bools.size() : {bools.size()}")          # bools.size() : torch.Size([517, 4])

        x_table_emb = self.table_embedding(table_idx)
        #print(f"x_table_emb.size() : {x_table_emb.size()}")  # x_table_emb.size() : torch.Size([517, 1, 8])

        # raw_txt_embedding
        #   NOTE(review): the text embeddings are loaded in the notebook but not used here
        
        # Per-node feature row : [ranker, table embedding, flags]
        x = torch.cat( [ranker.unsqueeze(-1), x_table_emb.squeeze(1), bools.float() ], axis=-1)

        # Debug print, permanently disabled
        if False:
            for idx in [0,1,256,512,513,514]:
                print(idx, x[idx,:])
        
        #x = self.conv1(x, edge_index)
        #x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        #x = self.conv2(x, edge_index)

        x = self.gru(x, edge_index)
        
        x = F.elu(self.att_conv1(x, edge_index))
        #x = F.dropout(x, p=0.6, training=self.training)
        #x = self.conv2(x, edge_index)
        
        #print(f"x.size() : {x.size()}")          # x.size() : torch.Size([517, 64])
        x = self.final(x)
        #print(f"final.size() : {x.size()}")      # final.size() : torch.Size([517, 1])
        
        # The network output is a correction added to the re-scaled ranker value
        x = x+ranker.unsqueeze(-1)
        return self.output(x)

In [None]:
model = Net().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01 ) #, weight_decay=5e-4)
#list(model.parameters())
model.train()

In [None]:
BATCH_SIZE=32
#loader_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True) #, num_workers=4)
##loader_dev   = DataLoader(ds_dev, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

loader_train = DataLoader(ds_train_cached, batch_size=BATCH_SIZE, shuffle=True) #, num_workers=4)
loader_dev   = DataLoader(ds_dev_cached, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

In [None]:
from tqdm.notebook import tqdm

def run_training_OLD(d, steps=10):
    """Legacy trainer : take `steps` optimiser steps on one 'plain' item `d`.

    Returns the loss value of the last step.  Uses the notebook-global
    `model` and `optimizer`.
    """
    # make_subgraph returns a single Data object (edges included) and the
    #   model takes just that object - the old two-value form no longer runs
    subgraph = make_subgraph(d)
    for step in range(steps):
        optimizer.zero_grad()
        ranks_pred = model(subgraph)
        
        # Targets get a trailing dimension to line up with the model output
        loss = model.bx_loss(ranks_pred, subgraph.y.unsqueeze(-1))
        
        loss.backward()
        
        clipping_value = 1. # arbitrary value of your choosing
        torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)
        
        optimizer.step()
    return loss.item()  # Final loss

def run_training(epoch, loader, loader_len, steps=1):
    """Train for one pass over `loader`, taking `steps` optimiser steps per batch.

    `epoch` and `loader_len` only label / size the progress bar.  Returns the
    sum over batches of (last-step loss * number of graphs in the batch).
    Uses the notebook-global `model`, `optimizer` and `device`.
    """
    clipping_value = 1. # arbitrary value of your choosing
    total_loss = 0.
    progress = tqdm(loader, total=loader_len, desc=f"epoch[{epoch}]")
    for batch in progress:
        batch = batch.to(device)
        for _ in range(steps):
            optimizer.zero_grad()
            loss = model.bx_loss(model(batch), batch.y.unsqueeze(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), clipping_value)
            optimizer.step()
        total_loss += loss.item() * batch.num_graphs  # Just at the last step
    return total_loss  # total loss

def run_validation_OLD(d):
    """Legacy validation : loss of the notebook-global `model` on one 'plain' item `d`."""
    # make_subgraph returns a single Data object and the model takes just that
    #   object - the old two-value form no longer runs
    subgraph = make_subgraph(d)
    with torch.no_grad():
        ranks_pred = model(subgraph)
        loss = model.bx_loss(ranks_pred, subgraph.y.unsqueeze(-1))
    return loss.item()

def run_validation(epoch, loader, loader_len):
    """Evaluate the notebook-global `model` over `loader` without gradients.

    `epoch` and `loader_len` only label / size the progress bar.  Returns the
    sum over batches of (batch loss * number of graphs in the batch).
    """
    total_loss = 0.
    progress = tqdm(loader, total=loader_len, desc=f"epoch[{epoch}].dev", leave=None)
    for batch in progress:
        batch = batch.to(device)
        with torch.no_grad():
            batch_loss = model.bx_loss(model(batch), batch.y.unsqueeze(-1))
        total_loss += batch_loss.item() * batch.num_graphs
    return total_loss  # total loss

In [None]:
n_steps=1
for i in []: # range(50):
    #loss=run_training(d, steps=n_steps)
    loss=run_validation(d_data)
    print(f"step[{i*n_steps:5d}] : loss_bx={loss:.4f}")

In [None]:
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#save-load-state-dict-recommended
from datetime import datetime
def save_model(model, epoch, loss, stub="model", fmt="./TS_MODEL_EPOCH_LOSS.pth"):
    """Save `model.state_dict()` under a file name built from `fmt`.

    Placeholders in `fmt`: TS (current timestamp), MODEL (`stub`),
    EPOCH (zero-padded to 3 digits) and LOSS (4 decimals).  Returns the file
    name that was written.
    """
    import re  # Local import : only needed here

    fields = {
        "TS":    datetime.now().strftime('%Y-%m-%dT%H.%M.%S%z'),
        "MODEL": stub,
        "EPOCH": f"{epoch:03d}",
        "LOSS":  f"{loss:.4f}",
    }
    # All placeholders are filled in a single pass over `fmt`, so placeholder
    #   words that occur inside `stub` itself are left untouched
    filename = re.sub("TS|MODEL|EPOCH|LOSS", lambda m: fields[m.group(0)], fmt)
    torch.save(model.state_dict(), filename) 
    return filename

In [None]:
# Batches per epoch, as ints taken from the loaders (np.ceil gave floats for tqdm)
loader_train_len = len(loader_train)
loader_dev_len   = len(loader_dev)

# Mean losses are taken over the datasets the loaders really iterate (the cached ones)
n_train_items = len(loader_train.dataset)
n_dev_items   = len(loader_dev.dataset)

model_name='graph-with-tables'

loss_best=999.
for epoch in range(50): # []:
    loss_tot=run_training(epoch, loader_train, loader_train_len, steps=4)
    loss = loss_tot/n_train_items
    
    loss_tot = run_validation(epoch, loader_dev, loader_dev_len)
    loss_dev = loss_tot/n_dev_items
    print(f"epoch[{epoch}] : train.loss_bx.mean()={loss:.4f}, dev.loss_bx.mean()={loss_dev:.4f}")
    
    # Keep a checkpoint whenever the dev loss improves
    if loss_best>loss_dev:
        loss_best=loss_dev
        save_model(model, epoch, loss_dev, stub=model_name)

In [None]:
# Train BX target from LSTM list reranker : ~0.0360 (dev-set)
#   Here : 
#     epoch[28] : train.loss_bx.mean()=0.0338, dev.loss_bx.mean()=0.0347
best_model=model

In [None]:
STOP

In [None]:
#save_model(model, epoch, loss)

In [None]:
best_model = Net()
# Earlier checkpoints, kept for reference :
#best_model.load_state_dict(torch.load('2020-09-21T02.25.19_026_0.0354.pth'))
#best_model.load_state_dict(torch.load('2020-09-21T22.35.23_023_0.0361.pth'))

#best_model.load_state_dict(torch.load('2020-09-22T00.42.12_MODEL_028_0.0347.pth'))  # +/- 2ish
# map_location lets a checkpoint saved on GPU load on a CPU-only machine too
best_model.load_state_dict(torch.load('2020-09-22T00.52.37_MODEL_037_0.0345.pth',
                                      map_location=device))  # +/- 2ish

In [None]:
best_model = best_model.to(device)
best_model.eval()

In [None]:
def save_reranked(ds, ds_cached, fold='dev', 
                  preds_file='../predictions/predict.FOLD.baseline-retrieval_plus-graph.txt'):
    """Re-rank the original predictions with `best_model` and write them out.

    `ds` supplies question ids and the original predictions; `ds_cached`
    supplies the model inputs in the same order.  For each question the first
    RANK_MAX predictions are re-ordered by descending model score and the
    remaining ones are appended unchanged.  Output lines are
    'question_id<TAB>statement_uid'; 'FOLD' in `preds_file` is replaced by `fold`.
    Uses the notebook-global `best_model`, `device` and RANK_MAX.
    """
    with open(preds_file.replace('FOLD', fold), 'wt') as f:
        for idx, data in enumerate(ds_cached):
            q_id = ds.question_id(idx)
            print(q_id)
            with torch.no_grad():
                data = data.to(device)
                ranks_pred = best_model(data)
            preds_np = ranks_pred.detach().cpu().numpy()
            print(preds_np.shape)

            preds_original = ds.preds[q_id]

            # Only score slots that hold a real prediction : with fewer than
            #   RANK_MAX predictions the extra slots would index past the list
            n_ranked = min(RANK_MAX, len(preds_original))
            preds_np = preds_np[:n_ranked, 0]
            print(preds_np[:20])
            ranks_np=np.argsort(-preds_np)  # This is the order of the ids we should output
            print(ranks_np[:20])

            reranked = [preds_original[pos] for pos in ranks_np]
            reranked = reranked + preds_original[RANK_MAX:]  # Tail keeps its original order
            print(preds_original[:10]);print(reranked[:10])
            for p in reranked:
                f.write(f"{q_id}\t{p}\n")
            
                
save_reranked(ds_dev, ds_dev_cached, fold='dev')

In [None]:
#MCAS_2009_8_18
#(517, 1)
#[ 0.22 0.13 0.10 0.15 0.07 0.04 0.11 0.16 0.16 0.21 0.02 0.09 0.13 0.05 0.07 0.09 0.10 0.06 0.06 0.02]
#[22 29  0  9 44  8  7  3  1 12  6 35 43 67 16  2 27 11 15 32]
#['357e-a596-31bf-15a2', 'ae9d-5e74-afa3-d031', '6c7e-2cf1-8a3f-b14a', '73e5-00b0-22fd-e34e', 'f685-05f4-46b2-1a83', '7984-5745-044b-f0ab', 'cb9b-7410-93c7-2adf', 'd3b3-ecd0-5b4d-bedf', '0f8b-28e6-6305-a476', 'a0b2-a45f-01e1-4bf4']
#['09c2-3e04-4343-223d', '3f2b-87bb-f73a-3e5d', '357e-a596-31bf-15a2', 'a0b2-a45f-01e1-4bf4', '81b7-0a58-5b0b-0bbf', '0f8b-28e6-6305-a476', 'd3b3-ecd0-5b4d-bedf', '73e5-00b0-22fd-e34e', 'ae9d-5e74-afa3-d031', 'a10a-a44b-4ab6-225f']

#MCAS_2009_8_18
#(517, 1)
#[ 0.28 0.19 0.13 0.24 0.24 0.07 0.21 0.18 0.17 0.12 0.03 0.08 0.16 0.04 0.03 0.06 0.17 0.03 0.05 0.02]
#[22  0  4  3  6  1  7 16  8 12  2  9 11 35  5 44 15 21 23 67]
#['357e-a596-31bf-15a2', 'ae9d-5e74-afa3-d031', '6c7e-2cf1-8a3f-b14a', '73e5-00b0-22fd-e34e', 'f685-05f4-46b2-1a83', '7984-5745-044b-f0ab', 'cb9b-7410-93c7-2adf', 'd3b3-ecd0-5b4d-bedf', '0f8b-28e6-6305-a476', 'a0b2-a45f-01e1-4bf4']
#['09c2-3e04-4343-223d', '357e-a596-31bf-15a2', 'f685-05f4-46b2-1a83', '73e5-00b0-22fd-e34e', 'cb9b-7410-93c7-2adf', 'ae9d-5e74-afa3-d031', 'd3b3-ecd0-5b4d-bedf', 'bb18-f19f-69ba-8bc6', '0f8b-28e6-6305-a476', 'a10a-a44b-4ab6-225f']

In [None]:
!python ../tg2020task/evaluate.py --gold ../tg2020task/questions.dev.tsv \
                                  ../predictions/predict.dev.baseline-retrieval_plus-graph.txt 

In [None]:
# MAP:  0.4196867493738269 # !! WTF?
# MAP:  0.4613060121071838 # Still WTF... (with +/-2 ranker re-scaling)