In [None]:
import dataset

In [None]:
# Load all statements once and build a uid -> statement lookup.
# `regenerate` is simply forwarded to the loader (presumably it forces a
# cache rebuild when True -- confirm in dataset.py).
regenerate = False
statements = dataset.load_statements(regenerate=regenerate)
statements_by_uid = dict((stmt.uid, stmt) for stmt in statements)

In [None]:
# One table name per line of the index file, with the '.tsv' text removed.
# The position of a name in this list is later used as its integer label.
with open("../tg2020task/tableindex.txt", "rt") as f:
    table_names = [
        entry.strip().replace('.tsv', '')
        for entry in f
    ]
table_names[:4]

In [None]:
#regenerate=False
#qanda_train = dataset.load_qanda('train', regenerate=regenerate) # 1.8MB
#qanda_dev   = dataset.load_qanda('dev', regenerate=regenerate)   # 400k in 496 lines
#qanda_test  = dataset.load_qanda('test', regenerate=regenerate)  # 800k

In [None]:
import os

import numpy as np
np.set_printoptions(precision=2, suppress=True, floatmode='fixed', sign=' ')

import torch

from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl

In [None]:
# Number of top-ranked candidate statements kept per question; also the
# length of the label vector and of the model's positional prior.
RANK_MAX = 512
# Default mini-batch size used by the data module's DataLoaders.
BATCH_SIZE = 64

In [None]:
class RerankDataset(Dataset):
    """Per-question re-ranking examples built from a ranked predictions file.

    Each item pairs the top-ranked candidate statements of one question
    (each represented by the integer label of the table it belongs to) with
    a 0/1 vector marking which of those candidates are gold explanations.

    Relies on the notebook-level globals ``statements`` and ``table_names``.

    Args:
        fold: Name of the split to load; it is passed to
            ``dataset.load_qanda`` and substituted into ``preds_file``.
        preds_file: Path template of the predictions file; the literal text
            'FOLD' is replaced by ``fold``. Every non-blank line must hold a
            question id and a statement uid separated by a tab, in rank order.
        rank_max: Number of top-ranked candidates kept per question.
            ``None`` (the default) uses the notebook-level ``RANK_MAX``.
    """

    def __init__(self, fold='dev', preds_file='../predictions/predict.FOLD.baseline-retrieval.txt',
                 rank_max=None):
        self.fold = fold
        # Resolved lazily so the global only has to exist at construction time.
        self.rank_max = RANK_MAX if rank_max is None else rank_max

        regenerate = False

        # Questions without any gold explanation are dropped (the train set
        # has one: Mercury_7221305).  The test fold is kept unfiltered: the
        # previous condition `fold!='test' and ...` threw away every test
        # question, leaving the test dataset empty.
        # NOTE(review): assumes test questions carry no usable gold labels -- confirm.
        self.qanda = [qa for qa in dataset.load_qanda(fold, regenerate=regenerate)
                      if fold == 'test' or len(qa.explanation_gold) > 0]

        # Load up prediction set: qa_id -> [statement uids in rank order]
        preds = dict()
        with open(preds_file.replace('FOLD', self.fold), 'rt') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # tolerate blank / trailing empty lines
                qid, uid = line.split('\t')
                preds.setdefault(qid, []).append(uid)
        self.preds = preds

        # Integer label of each table name = its position in `table_names`
        table_label = {t: i for i, t in enumerate(table_names)}
        self.statement_tables = {s.uid: table_label[s.table] for s in statements}

        # Load in other embedding points, etc here

    def __len__(self):
        """Number of questions kept for this fold."""
        return len(self.qanda)

    def __getitem__(self, idx):
        """Build the example for question number ``idx``.

        Returns:
            dict with keys ``idx``, ``q_id``, ``tables`` (int64 array of table
            labels for the kept candidates, in rank order) and ``labels``
            (int32 array of length ``rank_max``; 1 where the candidate at that
            rank is a gold explanation).

        Raises:
            KeyError: if the question has no entry in the predictions file, or
                a predicted uid is not a known statement.
        """
        qa = self.qanda[idx]
        q_id = qa.question_id
        rank_max = self.rank_max

        # NOTE(review): if a question has fewer than `rank_max` predictions,
        # `tables` is shorter than `labels` -- confirm the predictions file
        # always ranks at least `rank_max` statements per question.
        pred = self.preds[q_id][:rank_max]
        pred_uid_to_idx = {uid: i for i, uid in enumerate(pred)}

        # Explicit int64: `np.long` was removed in NumPy 1.24 and, where it
        # exists, its width is platform-dependent.
        pred_tables = np.array([self.statement_tables[uid] for uid in pred], dtype=np.int64)

        labels = np.zeros((rank_max,), dtype=np.int32)
        # Records without gold explanations simply contribute no positives.
        for ex in getattr(qa, 'explanation_gold', None) or ():
            if ex.uid in pred_uid_to_idx:
                labels[pred_uid_to_idx[ex.uid]] = 1
        if labels.sum() == 0:
            labels[0] = 1  # Prevent NAN if nothing is 1

        return dict(idx=idx, q_id=q_id,
                    tables=pred_tables,
                    labels=labels,
                    )

# Stand-alone dev-fold dataset, handy for inspecting individual examples.
ds_dev=RerankDataset(fold='dev')

In [None]:
# Number of dev questions kept by the dataset (the trailing comma makes the
# cell display a 1-tuple; more fields can be re-enabled from the comment).
len(ds_dev), #ds_dev[20], table_names[74], table_names[21], table_names[11]

In [None]:
# pip install https://github.com/PytorchLightning/pytorch-lightning/archive/master.zip --upgrade
class RerankDataModule(pl.LightningDataModule):
    """Holds one RerankDataset per fold and hands out their DataLoaders.

    Args:
        batch_size: Mini-batch size shared by all loaders.
    """

    def __init__(self, batch_size=BATCH_SIZE):
        super().__init__()
        self.batch_size = batch_size

        # Build ds_train, ds_dev and ds_test eagerly, in that order.
        for fold_name in ('train', 'dev', 'test'):
            setattr(self, f'ds_{fold_name}', RerankDataset(fold=fold_name))

    def train_dataloader(self):               # REQUIRED
        """Shuffled loader over the training fold."""
        loader_kwargs = dict(batch_size=self.batch_size, num_workers=8, shuffle=True)
        return DataLoader(self.ds_train, **loader_kwargs)

    def val_dataloader(self):
        """Loader over the dev fold, in dataset order."""
        loader_kwargs = dict(batch_size=self.batch_size, num_workers=4, shuffle=False)
        return DataLoader(self.ds_dev, **loader_kwargs)

    # The test loader is intentionally left disabled:
    #def test_dataloader(self):
    #    return DataLoader(self.ds_test,  batch_size=self.batch_size, num_workers=4, shuffle=False)
    
# Data module with the default batch size; constructing it loads all three folds.
dm = RerankDataModule()

In [None]:
import losses

In [None]:
class RerankModel(pl.LightningModule):
    """Re-scores a ranked candidate list from the table label of each candidate.

    The model adds a learned correction to a fixed, linearly decreasing
    positional prior (the incoming rank order) and squashes the result with a
    sigmoid, so an untrained correction of zero keeps the original order.

    Relies on the notebook-level globals ``table_names`` and ``RANK_MAX``.

    Args:
        hparams: Namespace (or dict) of hyper-parameters; ``lr`` is read by
            ``configure_optimizers``.
    """

    def __init__(self, hparams):
        super().__init__()
        # Passing the object itself stores its fields directly on
        # `self.hparams` (e.g. `self.hparams.lr`).  A bare
        # `save_hyperparameters()` would nest it as `self.hparams.hparams`,
        # and assigning `self.hparams = ...` is rejected by Lightning
        # releases where `hparams` is a read-only property.
        self.save_hyperparameters(hparams)

        table_emb_size, hidden_size = 8, 64
        self.table_embedding = torch.nn.Embedding(len(table_names), table_emb_size)

        # Input per rank position: 1 prior value + the table embedding.
        self.lstm1 = torch.nn.LSTM(input_size=1+table_emb_size, hidden_size=hidden_size,
                                   num_layers=2,
                                   batch_first=True, bidirectional=True)

        # Bidirectional LSTM emits 2*hidden_size features per position.
        self.gather = torch.nn.Linear(hidden_size*2, hidden_size)
        self.gather_act = torch.nn.LeakyReLU()

        self.final  = torch.nn.Linear(hidden_size, 1)
        self.output = torch.nn.Sigmoid()  # Ensures between 0.0 ... 1.0, maintains order

        self.ranking_loss = losses.APLoss(nq=50)

    def forward(self, x_table):
        """Score every rank position of every list in the batch.

        Args:
            x_table: Integer tensor of table labels, shape (bs, RANK_MAX).

        Returns:
            Float tensor of shape (bs, RANK_MAX) with scores in (0, 1).
        """
        bs = x_table.shape[0]
        # Positional prior, created on the input's device so the model also
        # works when moved to a GPU.  +3..-3 spans most of the sigmoid range.
        base = torch.linspace(+3.00, -3.00, steps=RANK_MAX,
                              device=x_table.device).expand( (bs, -1) )  # (bs, RANK_MAX)

        x_table_emb = self.table_embedding(x_table)   # (bs, RANK_MAX, table_emb_size)

        x = torch.cat([base.unsqueeze(2), x_table_emb, ], dim=2)  # (bs, RANK_MAX, 1+table_emb_size)

        x, _ = self.lstm1(x)  # final hidden/cell states are not needed

        x = self.gather_act( self.gather(x) )         # (bs, RANK_MAX, hidden_size)

        x = self.final(x).squeeze(2)                  # (bs, RANK_MAX)

        x = base + x  # Whole thing was a ResNet: learned correction on top of the prior

        return self.output(x)

    def loss_calc(self, ranks_pred, ranks_target):
        """Return ``(loss, log_dict)`` for predicted scores and 0/1 targets.

        NOTE(review): presumably APLoss yields 1-AP, making this value -AP
        (validation_epoch_end prints the negated loss) -- confirm in losses.py.
        """
        loss_ap = self.ranking_loss(ranks_pred, ranks_target)-1.0
        loss = loss_ap
        return loss, dict(
            loss_ap=loss_ap,
        )

    def training_step(self, batch, batch_idx_within_epoch): # REQUIRED
        """Compute the training loss for one batch and expose it for logging."""
        ranks_pred = self(batch['tables'])
        loss, log = self.loss_calc(ranks_pred, batch['labels'])
        log.update(dict(train_loss=loss))
        return {'loss': loss, 'log': log}

    def configure_optimizers(self):           # REQUIRED
        """Adam over all parameters with learning rate ``hparams.lr``."""
        # https://pytorch-lightning.readthedocs.io/en/latest/optimizers.html#learning-rate-scheduling
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
        return optimizer

    def validation_step(self, batch, batch_idx): # REQUIRED IF val_dataloader defined
        """Compute the loss for one validation batch; returned under 'check_loss'."""
        ranks_pred = self(batch['tables'])
        loss, log = self.loss_calc(ranks_pred, batch['labels'])
        log.update(dict(check_loss=loss))
        return log

    def validation_epoch_end(self, steps):
        """Average 'check_loss' over the validation batches and print its negation."""
        log = dict()
        for agg in ['check_loss', ]:
            log[agg] = torch.stack( [x[agg] for x in steps] ).mean()
        print(f"{-log['check_loss']:.4f}")
        return {'val_loss': log['check_loss'], 'log':log }

In [None]:
from argparse import ArgumentParser, Namespace

# Hyper-parameters handed to the model; only the learning rate is set.
# (earlier experiments also carried: alpha=2.0, beta=1.0, reg=0.05, gap=0.25)
hparams = Namespace(lr=0.001)

rerank_model = RerankModel(hparams)

In [None]:
# Train for 20 epochs; every other Trainer option is left at its default.
# Options tried previously and currently disabled:
#   gpus=1, auto_lr_find=True, fast_dev_run=True,
#   max_epochs=5 if hparams.dilate else 20
trainer = pl.Trainer(max_epochs=20)
trainer.fit(rerank_model, dm)

In [None]:
# Save off the predictions for the given dataset