In [None]:
import dataset

In [None]:
# Load the full statement collection once; other cells (e.g. RerankDataset)
# read the module-level `statements` list.
regenerate = False
statements = dataset.load_statements(regenerate=regenerate)

# Lookup table: statement uid -> statement object (later duplicates win).
statements_by_uid = dict((stmt.uid, stmt) for stmt in statements)

In [None]:
#regenerate=False
#qanda_train = dataset.load_qanda('train', regenerate=regenerate) # 1.8MB
#qanda_dev   = dataset.load_qanda('dev', regenerate=regenerate)   # 400k in 496 lines
#qanda_test  = dataset.load_qanda('test', regenerate=regenerate)  # 800k

In [None]:
import os

import numpy as np
np.set_printoptions(precision=2, suppress=True, floatmode='fixed', sign=' ')

import torch

from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import pytorch_lightning as pl

In [None]:
# RANK_MAX   - number of top-ranked retrieved statements kept per question
#              (also the fixed length of each example's arrays).
# BATCH_SIZE - default number of questions per DataLoader batch.
RANK_MAX, BATCH_SIZE = 512, 8

In [None]:
class RerankDataset(Dataset):
    """Per-question reranking examples built from a baseline retrieval run.

    Each item pairs the top-``RANK_MAX`` predicted statements for one question
    with (a) the integer id of the table each prediction comes from and
    (b) a 0/1 label marking which predictions are in the gold explanation.

    Args:
        fold: split name passed to ``dataset.load_qanda`` (e.g. 'train', 'dev', 'test').
        preds_file: path template for the predictions; the literal substring
            'FOLD' is replaced by ``fold``. Each non-blank line is
            ``question_id<TAB>statement_uid``; per-question order is preserved.
        regenerate: forwarded to ``dataset.load_qanda``.
        tables_file: table index, one table filename per line; line order
            defines each table's integer id.

    Relies on the notebook-level ``statements`` list and ``RANK_MAX`` constant.
    """

    def __init__(self, fold='dev', preds_file='../predictions/predict_baseline-retrieval.FOLD.txt',
                 regenerate=False, tables_file='../tg2020task/tableindex.txt'):
        self.fold = fold

        # Test questions are always kept. In other folds, questions without a
        # gold explanation are dropped (the train set has one: Mercury_7221305).
        self.qanda = [qa for qa in dataset.load_qanda(fold, regenerate=regenerate)
                      if fold == 'test' or len(qa.explanation_gold) > 0]

        # qa_id -> [statement uids, in file order]
        self.preds = self._load_preds(preds_file.replace('FOLD', self.fold))

        # Table name (without '.tsv') -> integer id. The extra id len(tables)
        # is reserved for padding, so an embedding over `x_tables` needs
        # len(tables) + 1 rows.
        with open(tables_file, 'rt') as f:
            tables = [line.strip().replace('.tsv', '') for line in f if line.strip()]
        table_label = {t: i for i, t in enumerate(tables)}
        self.table_pad_idx = len(tables)
        self.statement_tables = {s.uid: table_label[s.table] for s in statements}

        # Load in other embedding points, etc here

    @staticmethod
    def _load_preds(path):
        """Read ``qid<TAB>uid`` lines into ``{qid: [uid, ...]}``, skipping blank lines."""
        preds = {}
        with open(path, 'rt') as f:
            for line in f:
                line = line.strip()
                if not line:  # tolerate blank / trailing lines
                    continue
                qid, uid = line.split('\t')
                preds.setdefault(qid, []).append(uid)
        return preds

    def __len__(self):
        return len(self.qanda)

    def __getitem__(self, idx):
        """Return one question's candidates as fixed-length (RANK_MAX,) arrays.

        Returns a dict with:
            idx, q_id: the item index and the question id.
            x_tables: int64 table id of each kept prediction, padded with
                ``self.table_pad_idx`` up to RANK_MAX so items collate into a batch.
            labels: int32, 1 where the prediction is in the gold explanation.
        Raises KeyError if the question has no entry in the predictions file.
        """
        qa = self.qanda[idx]
        pred = self.preds[qa.question_id][:RANK_MAX]

        # np.long was removed from NumPy; int64 is what torch embeddings expect.
        x_tables = np.full((RANK_MAX,), self.table_pad_idx, dtype=np.int64)
        x_tables[:len(pred)] = [self.statement_tables[uid] for uid in pred]

        pred_uid_to_idx = {uid: i for i, uid in enumerate(pred)}
        labels = np.zeros((RANK_MAX,), dtype=np.int32)
        for ex in qa.explanation_gold:
            pos = pred_uid_to_idx.get(ex.uid)
            if pos is not None:
                labels[pos] = 1
        if labels.sum() == 0:
            labels[0] = 1  # Prevent NaN downstream if nothing is 1

        return dict(idx=idx, q_id=qa.question_id,
                    x_tables=x_tables,
                    labels=labels,
                    )

#ds=RerankDataset()

In [None]:
# pip install https://github.com/PytorchLightning/pytorch-lightning/archive/master.zip --upgrade
class RerankDataModule(pl.LightningDataModule):
    """Lightning data module over the train / dev / test ``RerankDataset`` folds.

    All three datasets are constructed eagerly in ``__init__``. Only training
    and validation loaders are provided; no test loader is defined, so
    ``ds_test`` is built but not served.
    """

    def __init__(self, batch_size=BATCH_SIZE):
        super().__init__()
        self.batch_size = batch_size

        self.ds_train = RerankDataset(fold='train')
        self.ds_dev = RerankDataset(fold='dev')
        self.ds_test = RerankDataset(fold='test')

    def _build_loader(self, ds, num_workers, shuffle):
        """Wrap ``ds`` in a DataLoader using this module's batch size."""
        return DataLoader(ds, batch_size=self.batch_size,
                          num_workers=num_workers, shuffle=shuffle)

    def train_dataloader(self):
        """Shuffled train-fold loader (the one hook Lightning requires)."""
        return self._build_loader(self.ds_train, num_workers=8, shuffle=True)

    def val_dataloader(self):
        """Deterministic (unshuffled) dev-fold loader."""
        return self._build_loader(self.ds_dev, num_workers=4, shuffle=False)


dm = RerankDataModule()

In [None]:
#        self.table_embedding = torch.nn.Embedding(, 8)
