<a href="https://colab.research.google.com/github/finardi/tutos/blob/master/RAW_CROSSBERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!nvidia-smi

Sun Aug 30 21:30:10 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.66       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   55C    P8    10W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

# Modeling utilities: splitting long documents into sub-batches

In [None]:
import math
import torch

In [None]:
def subbatch(toks, maxlen):
    """Split a (BATCH, DLEN, ...) tensor along dim 1 into equal-width chunks stacked along dim 0.

    The number of chunks is ``SUBBATCH = ceil(DLEN / maxlen)``. The chunk width is
    ``ceil(DLEN / SUBBATCH)``, the smallest width that still needs only SUBBATCH
    chunks. The last chunk is right-padded with zeros.

    Args:
        toks: tensor with at least 2 dims; dim 0 is the batch, dim 1 is split.
        maxlen: maximum chunk width (must be positive).

    Returns:
        ``(stacked, SUBBATCH)``. ``stacked`` has SUBBATCH * BATCH rows laid out
        chunk-major (rows ``b*BATCH:(b+1)*BATCH`` hold chunk ``b``). When no split
        is needed (including ``DLEN == 0``), ``toks`` itself is returned with count 1.

    Raises:
        ValueError: if ``maxlen`` is not positive.
    """
    if maxlen <= 0:
        raise ValueError(f'maxlen must be positive, got {maxlen}')
    _, DLEN = toks.shape[:2]
    # max(1, ...): an empty sequence is one (empty) sub-batch. Previously this was 0
    # sub-batches, which ended in torch.cat([]) raising.
    SUBBATCH = max(1, math.ceil(DLEN / maxlen))
    if SUBBATCH == 1:
        return toks, SUBBATCH

    S = math.ceil(DLEN / SUBBATCH)  # minimize the chunk width given the number of sub-batches
    stack = []
    for s in range(SUBBATCH):
        chunk = toks[:, s * S:(s + 1) * S]
        if chunk.shape[1] != S:
            # zero-pad the short final chunk to the common width
            nulls = torch.zeros_like(toks[:, :S - chunk.shape[1]])
            chunk = torch.cat([chunk, nulls], dim=1)
        stack.append(chunk)
    return torch.cat(stack, dim=0), SUBBATCH


def un_subbatch(embed, toks, maxlen):
    """Undo `subbatch`: move the chunks stacked along dim 0 back along dim 1.

    Args:
        embed: tensor whose dim 0 holds SUBBATCH * BATCH rows, chunk-major (as produced by `subbatch`).
        toks: the original, un-split tensor; only its (BATCH, DLEN) leading shape is used.
        maxlen: the same ``maxlen`` that was passed to `subbatch`.

    Returns:
        ``embed`` re-assembled to BATCH rows and cropped to DLEN along dim 1.
        Returns ``embed`` unchanged when no split happened (including ``DLEN == 0``).

    Raises:
        ValueError: if ``maxlen`` is not positive.
    """
    if maxlen <= 0:
        raise ValueError(f'maxlen must be positive, got {maxlen}')
    BATCH, DLEN = toks.shape[:2]
    # Mirror `subbatch`: an empty sequence counts as a single sub-batch.
    SUBBATCH = max(1, math.ceil(DLEN / maxlen))
    if SUBBATCH == 1:
        return embed
    pieces = [embed[b * BATCH:(b + 1) * BATCH] for b in range(SUBBATCH)]
    # drop the zero padding added to the last chunk
    return torch.cat(pieces, dim=1)[:, :DLEN]

# Modeling

In [None]:
!pip install -q pytools
!pip install -q pytorch_pretrained_bert
from pytools import memoize_method
import torch
import torch.nn.functional as F
import pytorch_pretrained_bert

[?25l[K     |████▉                           | 10kB 29.4MB/s eta 0:00:01[K     |█████████▊                      | 20kB 2.9MB/s eta 0:00:01[K     |██████████████▋                 | 30kB 3.9MB/s eta 0:00:01[K     |███████████████████▌            | 40kB 4.2MB/s eta 0:00:01[K     |████████████████████████▍       | 51kB 3.5MB/s eta 0:00:01[K     |█████████████████████████████▎  | 61kB 3.8MB/s eta 0:00:01[K     |████████████████████████████████| 71kB 2.9MB/s 
[?25h  Building wheel for pytools (setup.py) ... [?25l[?25hdone
[K     |████████████████████████████████| 133kB 4.8MB/s 
[?25h

In [None]:
# Run on the GPU when one is visible to torch, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class CustomBertModel(pytorch_pretrained_bert.BertModel):
    """`pytorch_pretrained_bert.BertModel` whose forward also exposes the embedding layer.

    ``forward`` returns a list: the (un-contextualized) embedding output first,
    followed by the output of every encoder layer.
    """

    def forward(self, input_ids, token_type_ids, attention_mask):
        """Mirror of BertModel.forward without the pooler.

        Args:
            input_ids: token ids.
            token_type_ids: segment ids.
            attention_mask: 1 for real tokens, 0 for padding. Presumably (batch, seq);
                it is broadcast to (batch, 1, 1, seq) below.

        Returns:
            ``[embedding_output] + encoded_layers``.
        """
        embeddings = self.embeddings(input_ids, token_type_ids)

        # Turn the 0/1 mask into an additive mask: 0 for kept positions, -10000 for padding.
        # Cast to the parameter dtype so the model also works in fp16.
        param_dtype = next(self.parameters()).dtype
        additive_mask = attention_mask[:, None, None].to(dtype=param_dtype)
        additive_mask = (1.0 - additive_mask) * -10000.0

        layers = self.encoder(embeddings, additive_mask, output_all_encoded_layers=True)
        return [embeddings] + layers

class BertRanker(torch.nn.Module):
    """Base class for BERT rankers.

    Holds the `CustomBertModel` encoder and its tokenizer. Provides save/load of
    trainable weights, cached tokenization, and `encode_bert`, which runs BERT on
    a `[CLS] doc [SEP] query [SEP]` sequence.
    """

    def __init__(self):
        super().__init__()
        self.BERT_MODEL = 'bert-base-uncased'
        self.CHANNELS = 12 + 1  # from bert-base-uncased: embeddings + 12 encoder layers
        self.BERT_SIZE = 768  # from bert-base-uncased: hidden size
        self.bert = CustomBertModel.from_pretrained(self.BERT_MODEL)
        self.tokenizer = pytorch_pretrained_bert.BertTokenizer.from_pretrained(self.BERT_MODEL)

    def forward(self, **inputs):
        raise NotImplementedError

    def save(self, path):
        """Save only the trainable tensors (requires_grad=True) to ``path``; buffers and frozen params are skipped."""
        state = self.state_dict(keep_vars=True)
        for key in list(state):
            if state[key].requires_grad:
                state[key] = state[key].data
            else:
                del state[key]
        torch.save(state, path)

    def load(self, path):
        """Load a (possibly partial) state dict written by `save`.

        ``map_location='cpu'`` lets checkpoints saved on a GPU load on CPU-only
        machines. ``load_state_dict`` then copies the values onto each parameter's
        own device, so GPU models are unaffected.
        """
        self.load_state_dict(torch.load(path, map_location='cpu'), strict=False)

    @memoize_method
    def tokenize(self, text):
        """Return the WordPiece vocabulary ids of ``text`` (results cached by pytools' memoize_method)."""
        toks = self.tokenizer.tokenize(text)
        toks = [self.tokenizer.vocab[t] for t in toks]
        return toks

    def encode_bert(self, query_tok, query_mask, doc_tok, doc_mask, customBert=None):
        """Encode ``[CLS] doc [SEP] query [SEP]`` with BERT.

        Docs longer than the position budget are split with `subbatch`. The query
        is repeated for every sub-batch.

        Args:
            query_tok: (BATCH, QLEN) token ids; -1 marks padding (zeroed before BERT, masked via ``query_mask``).
            query_mask: (BATCH, QLEN) 1/0 mask.
            doc_tok: (BATCH, DLEN) token ids, padded like ``query_tok``.
            doc_mask: (BATCH, DLEN) 1/0 mask.
            customBert: optional replacement encoder with the same call signature as ``self.bert``.

        Returns:
            ``(cls_results, query_results, doc_results)``, each a list with one entry
            per element of the encoder output (embeddings + each encoder layer):
              * cls_results: [CLS] vectors averaged over doc sub-batches, (BATCH, hidden).
              * query_results: per-token query outputs, taken from the first sub-batch.
              * doc_results: per-token doc outputs, re-assembled with `un_subbatch`.
        """
        BATCH, QLEN = query_tok.shape
        DIFF = 3  # = [CLS] and 2x[SEP]
        maxlen = self.bert.config.max_position_embeddings
        MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF

        doc_toks, sbcount = subbatch(doc_tok, MAX_DOC_TOK_LEN)
        doc_mask, _ = subbatch(doc_mask, MAX_DOC_TOK_LEN)
        SB_DLEN = doc_toks.shape[1]  # width of one doc sub-batch

        query_toks = torch.cat([query_tok] * sbcount, dim=0)
        query_mask = torch.cat([query_mask] * sbcount, dim=0)

        CLSS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[CLS]'])
        SEPS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[SEP]'])
        ONES = torch.ones_like(query_mask[:, :1])
        NILS = torch.zeros_like(query_mask[:, :1])

        # build BERT input sequences: [CLS] doc [SEP] query [SEP]
        toks = torch.cat([CLSS, doc_toks, SEPS, query_toks, SEPS], dim=1)
        mask = torch.cat([ONES, doc_mask, ONES, query_mask, ONES], dim=1)
        # segment A = [CLS] doc [SEP], segment B = query [SEP]
        segment_ids = torch.cat([NILS] * (2 + SB_DLEN) + [ONES] * (1 + QLEN), dim=1)
        toks[toks == -1] = 0  # remove padding (will be masked anyway)

        # execute BERT model
        bert = self.bert if customBert is None else customBert
        result = bert(toks, segment_ids.long(), mask)

        # Extract the query and doc spans. The doc comes first (positions 1..SB_DLEN)
        # and the query follows the first [SEP]. The previous slices assumed a
        # query-first layout and returned the wrong tokens.
        q_start = SB_DLEN + 2
        query_results = [r[:BATCH, q_start:q_start + QLEN] for r in result]
        doc_results = [r[:, 1:SB_DLEN + 1] for r in result]
        doc_results = [un_subbatch(r, doc_tok, MAX_DOC_TOK_LEN) for r in doc_results]

        # build CLS representation: average the [CLS] vector over doc sub-batches
        cls_results = []
        for layer in result:
            cls_output = layer[:, 0]
            cls_result = []
            for i in range(cls_output.shape[0] // BATCH):
                cls_result.append(cls_output[i * BATCH:(i + 1) * BATCH])
            cls_result = torch.stack(cls_result, dim=2).mean(dim=2)
            cls_results.append(cls_result)

        return cls_results, query_results, doc_results



class CrossBert(BertRanker):
    """Ranker that encodes each pair in both orders and multiplies the two final-layer [CLS] vectors.

    ``args.mode`` selects the scoring head:
      * 1: score from the (query, doc) pair only, via ``cls``.
      * 2: additionally encode the (wiki, doc) pair in both orders. Score it with
        ``cls2`` and combine both scores with ``clsAll``.
    """

    def __init__(self, args):
        super().__init__()

        self.args = args

        self.dropout = torch.nn.Dropout(0.1)
        self.cls = torch.nn.Linear(self.BERT_SIZE, 1)
        self.cls2 = torch.nn.Linear(self.BERT_SIZE, 1)
        self.clsAll = torch.nn.Linear(2, 1)

    def forward(self, query_tok, query_mask, doc_tok, doc_mask, wiki_tok, wiki_mask, question_tok, question_mask):
        """Return one relevance score per row (a (BATCH, 1) tensor from the final Linear head).

        ``question_tok`` / ``question_mask`` are accepted for interface compatibility but are not used.

        Raises:
            ValueError: if ``args.mode`` is not 1 or 2. Previously such modes returned
                None (after running needless BERT passes), which crashed later in the caller.
        """
        mode = self.args.mode
        if mode not in (1, 2):
            raise ValueError(f'CrossBert supports args.mode 1 or 2, got {mode!r}')

        # Encode in both orders; only the [CLS] outputs are used.
        cls_query_tok, _, _ = self.encode_bert(query_tok, query_mask, doc_tok, doc_mask)
        cls_doc_tok, _, _ = self.encode_bert(doc_tok, doc_mask, query_tok, query_mask)
        if mode == 2:
            cls_wiki_doc_tok, _, _ = self.encode_bert(wiki_tok, wiki_mask, doc_tok, doc_mask)
            cls_doc_wiki_tok, _, _ = self.encode_bert(doc_tok, doc_mask, wiki_tok, wiki_mask)

        # element-wise interaction of the last-layer [CLS] vectors
        mul = torch.mul(cls_query_tok[-1], cls_doc_tok[-1])
        if mode == 1:
            return self.cls(self.dropout(mul))

        mul_wiki = torch.mul(cls_wiki_doc_tok[-1], cls_doc_wiki_tok[-1])
        cat = self.cls(self.dropout(mul))
        cat_wiki = self.cls2(self.dropout(mul_wiki))
        return self.clsAll(torch.cat([cat, cat_wiki], dim=1))

# Data

In [None]:
import random
from tqdm import tqdm
import torch
import numpy as np

In [None]:
# Same selection as in the Modeling cell, repeated so this Data cell works on its own.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

def read_datafiles(files):
    """Read the query / doc / wiki / question TSV streams into dictionaries.

    Each line is ``type<TAB>id<TAB>text[<TAB>qtype]``. The role of a file is
    given by its position in ``files``: 0 = queries (which also fill ``qtypes``),
    1 = docs, 2 = wikis, 3 = questions. Any further files are read but ignored.

    Blank lines are skipped. Lines that do not have 3 or 4 columns are skipped
    with a warning. Previously such lines silently reused the previous line's
    values, or raised NameError on the first line. A query line without a
    fourth column records ``None`` as its qtype instead of a stale value.

    Returns:
        ``(queries, docs, wikis, questions, qtypes)``, each a dict keyed by id.
    """
    queries, wikis, questions, docs, qtypes = {}, {}, {}, {}, {}
    for idx, file in enumerate(files):
        lines = tqdm(file, desc='loading datafile (by line)', leave=False)
        for line_no, line in enumerate(lines, start=1):
            cols = line.rstrip().split('\t')
            if cols == ['']:
                continue  # blank line
            if len(cols) == 3:
                c_type, c_id, c_text = cols
                c_qtype = None
            elif len(cols) == 4:
                c_type, c_id, c_text, c_qtype = cols
            else:
                tqdm.write(f'skipping malformed line {line_no} of datafile {idx}: '
                           f'expected 3 or 4 tab-separated columns, got {len(cols)}')
                continue

            if idx == 0:
                queries[c_id] = c_text
                qtypes[c_id] = c_qtype
            elif idx == 1:
                docs[c_id] = c_text
            elif idx == 2:
                wikis[c_id] = c_text
            elif idx == 3:
                questions[c_id] = c_text

    return queries, docs, wikis, questions, qtypes


def read_qrels_dict(file):
    """Build ``{qid: {docid: int grade}}`` from a qrels stream.

    Each line has five whitespace-separated fields; only the 1st (qid),
    3rd (docid) and 4th (grade) are used.
    """
    judgments = {}
    for row in tqdm(file, desc='loading qrels (by line)', leave=False):
        query_id, _unused, doc_id, grade, _extra = row.split()
        if query_id not in judgments:
            judgments[query_id] = {}
        judgments[query_id][doc_id] = int(grade)
    return judgments


def read_run_dict(file):
    """Build ``{qid: {docid: float score}}`` from a run stream.

    Each line has six whitespace-separated fields (qid, <ignored>, docid, rank,
    score, <ignored>); the rank is not kept.
    """
    runs = {}
    for row in tqdm(file, desc='loading run (by line)', leave=False):
        query_id, _q0, doc_id, _rank, score, _tag = row.split()
        if query_id not in runs:
            runs[query_id] = {}
        runs[query_id][doc_id] = float(score)
    return runs


def read_pairs_dict(file):
    """Build ``{qid: {docid: 1}}`` from a stream of whitespace-separated ``qid docid`` lines."""
    pairs = {}
    for row in tqdm(file, desc='loading pairs (by line)', leave=False):
        query_id, doc_id = row.split()
        if query_id not in pairs:
            pairs[query_id] = {}
        pairs[query_id][doc_id] = 1
    return pairs


def iter_train_pairs(model, dataset, train_pairs, qrels, batch_size, args, data=None):
    """Group `_iter_train_pairs` records into packed training batches.

    Each yielded batch holds ``batch_size`` (positive, negative) pairs, i.e.
    ``2 * batch_size`` records, packed by `_pack_n_ship`. The record stream
    repeats over the training queries, so callers stop consuming when done.

    Args:
        data: forwarded to `_pack_n_ship`. Added as an optional trailing
            parameter: the original body referenced an undefined name ``data``
            and raised NameError when the first batch was full.
    """
    fields = ('query_id', 'doc_id', 'query_tok', 'doc_tok', 'wiki_tok', 'question_tok',
              'label', 'query_raw', 'doc_raw', 'wiki_raw')

    def new_batch():
        # every batch carries the same keys (the reset used to drop 'label')
        return {field: [] for field in fields}

    batch = new_batch()
    for qid, did, query_tok, doc_tok, wiki_tok, question_tok, query, doc, wiki in _iter_train_pairs(
            model, dataset, train_pairs, qrels, args):
        batch['query_id'].append(qid)
        batch['doc_id'].append(did)
        batch['query_tok'].append(query_tok)
        batch['doc_tok'].append(doc_tok)
        batch['wiki_tok'].append(wiki_tok)
        batch['question_tok'].append(question_tok)
        batch['query_raw'].append(query)
        batch['doc_raw'].append(doc)
        batch['wiki_raw'].append(wiki)

        # records arrive as (positive, negative) pairs
        if len(batch['query_id']) // 2 == batch_size:
            yield _pack_n_ship(batch, data, args)
            batch = new_batch()


def _iter_train_pairs(model, dataset, train_pairs, qrels, args):
    ds_queries, ds_docs, ds_wikis, ds_questions, ds_qtypes = dataset
    while True:
        qids = list(train_pairs.keys())
        random.shuffle(qids)
        for qid in qids:
            pos_ids = [did for did in train_pairs[qid] if qrels.get(qid, {}).get(did, 0) > 0]
            if len(pos_ids) == 0:
                continue

            pos_id = random.choice(pos_ids)
            neg_ids = [did for did in train_pairs[qid] if qrels.get(qid, {}).get(did, 0) == 0]

            if len(neg_ids) == 0:
                print("No neg instances", qid)
                continue

            neg_id = random.choice(neg_ids)
            query_tok = model.tokenize(ds_queries[qid])
            wiki_tok = model.tokenize(ds_wikis[qid])
            question_tok = model.tokenize(ds_questions[qid])

            pos_doc = ds_docs.get(pos_id)
            neg_doc = ds_docs.get(neg_id)
            if pos_doc is None:
                tqdm.write(f'missing doc {pos_id}! Skipping')
                continue
            if neg_doc is None:
                tqdm.write(f'missing doc {neg_id}! Skipping')
                continue

            yield qid, pos_id, query_tok, model.tokenize(pos_doc), wiki_tok, question_tok, ds_queries[qid], pos_doc, ds_wikis[qid]
            yield qid, neg_id, query_tok, model.tokenize(neg_doc), wiki_tok, question_tok, ds_queries[qid], neg_doc, ds_wikis[qid]

        # break


def iter_valid_records(model, dataset, run, batch_size, data, args):
    """Group `_iter_valid_records` records into packed batches of ``batch_size``.

    Yields the dicts produced by `_pack_n_ship`. The final, possibly smaller,
    batch is yielded too.
    """
    fields = ('query_id', 'doc_id', 'query_tok', 'doc_tok', 'wiki_tok', 'question_tok',
              'label', 'query_raw', 'doc_raw', 'wiki_raw')

    def new_batch():
        # every batch carries the same keys (the reset used to drop 'label')
        return {field: [] for field in fields}

    batch = new_batch()
    for qid, did, query_tok, doc_tok, wiki_tok, question_tok, query, doc, wiki in _iter_valid_records(model, dataset, run, args):
        batch['query_id'].append(qid)
        batch['doc_id'].append(did)
        batch['query_tok'].append(query_tok)
        batch['doc_tok'].append(doc_tok)
        batch['wiki_tok'].append(wiki_tok)
        batch['question_tok'].append(question_tok)
        batch['query_raw'].append(query)
        batch['doc_raw'].append(doc)
        batch['wiki_raw'].append(wiki)

        if len(batch['query_id']) == batch_size:
            yield _pack_n_ship(batch, data, args)
            batch = new_batch()

    # final (partial) batch
    if len(batch['query_id']) > 0:
        yield _pack_n_ship(batch, data, args)


def _iter_valid_records(model, dataset, run, args):
    ds_queries, ds_docs, ds_wikis, ds_questions, ds_qtypes = dataset
    for qid in run:
        query_tok = model.tokenize(ds_queries[qid])
        wiki_tok = model.tokenize(ds_wikis[qid])
        question_tok = model.tokenize(ds_questions[qid])

        for did in run[qid]:
            doc = ds_docs.get(did)
            if doc is None:
                tqdm.write(f'missing doc {did}! Skipping')
                continue
            doc_tok = model.tokenize(doc)
            yield qid, did, query_tok, doc_tok, wiki_tok, question_tok, ds_queries[qid], doc, ds_wikis[qid]


def _pack_n_ship(batch, data, args):
    """Convert a batch of raw token-id lists into padded id tensors and 0/1 masks.

    Every token field is padded (with -1, by `_pad_crop`) or cropped to
    ``min(args.maxlen, longest sequence of that field in the batch)``.
    ``data`` is accepted but unused.
    """
    def width(field):
        return min(args.maxlen, max(len(seq) for seq in batch[field]))

    # (token field, mask field) pairs, in output order
    layout = (('query_tok', 'query_mask'), ('doc_tok', 'doc_mask'),
              ('wiki_tok', 'wiki_mask'), ('question_tok', 'question_mask'))
    widths = {tok_key: width(tok_key) for tok_key, _ in layout}

    packed = {'query_id': batch['query_id'], 'doc_id': batch['doc_id']}
    for tok_key, _ in layout:
        packed[tok_key] = _pad_crop(batch[tok_key], widths[tok_key])
    for tok_key, mask_key in layout:
        packed[mask_key] = _mask(batch[tok_key], widths[tok_key])
    return packed

def toTensor(x):
    """Convert ``x`` to a float32 tensor, moved to CUDA when the global ``device`` is CUDA."""
    as_float = torch.tensor(x).float()
    if device.type == 'cuda':
        return as_float.cuda()
    return as_float


def _pad_crop_np(items, l):
    results = []
    for item in items:
        if len(item) < l:
            while len(item) != l:
                item.append([0] * 100)
        if len(item) > l:
            item = item[:l]
        results.append(item)
    return torch.tensor(results).float().cuda() if device.type == 'cuda' else torch.tensor(results).float()


def _pad_crop(items, l, val=-1):
    result = []
    for item in items:
        if len(item) < l:
            item = item + [val] * (l - len(item))
        if len(item) > l:
            item = item[:l]
        result.append(item)
    return torch.tensor(result).long().cuda() if device.type == 'cuda' else torch.tensor(result).long()


def _mask(items, l):
    result = []
    for item in items:
        # needs padding (masked)
        if len(item) < l:
            mask = [1. for _ in item] + ([0.] * (l - len(item)))
        # no padding (possible crop)
        if len(item) >= l:
            mask = [1. for _ in item[:l]]
        result.append(mask)
    return torch.tensor(result).float().cuda() if device.type == 'cuda' else torch.tensor(result).float()

# Train 

In [None]:
!pip install -q pyNTCIREVAL

In [None]:
import sys
sys.path.append('/content/drive/My Drive/Colab Notebooks/BERT/CrossBERT/')

!unzip -q '/content/drive/My Drive/Colab Notebooks/BERT/CrossBERT/data_CrossBERT.zip'

In [None]:
import Data
import os, io
import argparse
import subprocess
from time import strftime, localtime
import time
import pandas as pd
import numpy as np
import random, pickle
from tqdm import tqdm
import torch
from pyNTCIREVAL import Labeler
from pyNTCIREVAL.metrics import MSnDCG, nERR, nDCG, AP, RR
import collections

In [None]:
SEED = 42
# Seed every RNG the pipeline may draw from: torch (CPU and all GPUs), Python's
# `random` (used to shuffle queries and sample training pairs) and NumPy (not
# seeded before; the external Data module or later code may use it).
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
random.seed(SEED)
np.random.seed(SEED)

In [None]:
# --model name -> model class (the argparse `choices` of main_cli come from these keys)
MODEL_MAP = dict(crossbert=CrossBert)

def main(model, dataset, train_pairs, qrels, valid_run, test_run, model_out_dir, qrelDict, modelName, fold,
         metricKeys, MAX_EPOCH, data, args):
    """Train and evaluate ``model`` on one fold, then write that fold's best test outputs.

    The model is trained for up to MAX_EPOCH epochs with Adam (LR 1e-3 for
    non-BERT parameters, 2e-5 for ``bert.*``). After each epoch it is scored on
    ``valid_run`` (mean ndcg@10). Whenever that score improves, the model is
    also scored on ``test_run`` and those test results are kept. With
    ``args.earlystop`` set, training stops at the first epoch that does not
    improve the validation score.

    Returns:
        A dict mapping each metric in ``metricKeys`` to per-query test values.
        The lists are empty if no test evaluation happened (e.g. MAX_EPOCH == 0);
        previously that case crashed with NameError/KeyError.
    """
    LR = 0.001
    BERT_LR = 2e-5

    # The pretrained encoder (`bert.*`) gets a much smaller LR than the new heads.
    params = [(k, v) for k, v in model.named_parameters() if v.requires_grad]
    non_bert_params = {'params': [v for k, v in params if not k.startswith('bert.')]}
    bert_params = {'params': [v for k, v in params if k.startswith('bert.')], 'lr': BERT_LR}
    optimizer = torch.optim.Adam([non_bert_params, bert_params], lr=LR)

    top_valid_score = None
    bestResults = {k: [] for k in metricKeys}
    bestPredictions = []
    bestQids = []

    print("Fold: %d" % fold)

    if args.model in ["unsup"]:
        # no training: evaluate the model as-is on the test split
        test_qids, test_results, test_predictions = validate(model, dataset, test_run, qrelDict, 0,
                                                             model_out_dir, data, args, "test")

        print(test_results["ndcg@15"])
        txt = 'new top validation score, %.4f' % np.mean(test_results["ndcg@10"])
        print2file(args.out_dir, modelName, ".txt", txt, fold)

        bestResults = test_results
        bestPredictions = test_predictions
        bestQids = test_qids
    else:
        for epoch in range(MAX_EPOCH):
            t2 = time.time()
            loss = train_iteration(model, optimizer, dataset, train_pairs, qrels, data, args)
            txt = f'train epoch={epoch} loss={loss}'
            print2file(args.out_dir, modelName, ".txt", txt, fold)

            valid_qids, valid_results, valid_predictions = validate(model, dataset, valid_run, qrelDict, epoch,
                                                                    model_out_dir, data, args, "valid")

            valid_score = np.mean(valid_results["ndcg@10"])
            elapsed_time = time.time() - t2
            txt = f'validation epoch={epoch} score={valid_score} : {time.strftime("%H:%M:%S", time.gmtime(elapsed_time))}'
            print2file(args.out_dir, modelName, ".txt", txt, fold)
            if top_valid_score is None or valid_score > top_valid_score:
                top_valid_score = valid_score
                test_qids, test_results, test_predictions = validate(model, dataset, test_run, qrelDict, epoch,
                                                                     model_out_dir, data, args, "test")

                txt = 'new top validation score, %.4f' % np.mean(test_results["ndcg@10"])
                print2file(args.out_dir, modelName, ".txt", txt, fold)

                bestResults = test_results
                bestPredictions = test_predictions
                bestQids = test_qids

            elif args.earlystop:
                # patience of one epoch: stop at the first non-improving epoch
                break

    # Save the best test outputs. The former trailing `print2file(txt)` re-logged
    # the last message a second time and was removed.
    for k in metricKeys:
        result2file(args.out_dir, modelName, "." + k, bestResults[k], bestQids, fold)

    prediction2file(args.out_dir, modelName, ".out", bestPredictions, fold)
    return bestResults




In [None]:
def train_iteration(model, optimizer, dataset, train_pairs, qrels, data, args):
    """Run one training epoch with a pairwise-softmax loss and gradient accumulation.

    Mini-batches of GRAD_ACC_SIZE pairs come from ``Data.iter_train_pairs``.
    Each is back-propagated, and an optimizer step is taken once BATCH_SIZE
    pairs have accumulated. The epoch ends after BATCH_SIZE * BATCHES_PER_EPOCH
    pairs (32 batches for "eai" datasets, 256 otherwise) or when the generator
    is exhausted.

    Returns:
        The sum (not the mean) of the per-mini-batch losses. Previously this was
        None if the generator ended early.
    """
    BATCH_SIZE = 16
    BATCHES_PER_EPOCH = 32 if "eai" in args.data else 256
    GRAD_ACC_SIZE = 2
    target = BATCH_SIZE * BATCHES_PER_EPOCH
    total = 0
    pending = 0  # pairs whose gradients are accumulated but not yet applied
    model.train()
    total_loss = 0.
    # tqdm's first positional parameter is the iterable; the string 'training'
    # used to be passed there by mistake. The bar is advanced manually.
    with tqdm(total=target, ncols=80, desc='train', leave=False) as pbar:
        for record in Data.iter_train_pairs(model, dataset, train_pairs, qrels, GRAD_ACC_SIZE, data, args):
            scores = model(record['query_tok'],
                           record['query_mask'],
                           record['doc_tok'],
                           record['doc_mask'],
                           record['wiki_tok'],
                           record['wiki_mask'],
                           record['question_tok'],
                           record['question_mask'])

            # Presumably records alternate positive/negative docs (as the notebook's
            # _iter_train_pairs yields them), so each row is (pos, neg).
            count = len(record['query_id']) // 2
            scores = scores.reshape(count, 2)
            # pairwise softmax: minimize 1 - P(positive beats negative)
            loss = torch.mean(1. - scores.softmax(dim=1)[:, 0])
            loss.backward()

            total_loss += loss.item()
            total += count
            pending += count
            # Unlike `total % BATCH_SIZE == 0`, this still steps when count does not divide BATCH_SIZE.
            if pending >= BATCH_SIZE:
                optimizer.step()
                optimizer.zero_grad()
                pending = 0
            pbar.update(count)
            if total >= target:
                break

    # apply leftover accumulated gradients so they do not leak into the next epoch
    if pending:
        optimizer.step()
        optimizer.zero_grad()
    return total_loss


def validate(model, dataset, run, qrel, epoch, model_out_dir, data, args, desc):
    """Score ``run`` with ``model`` and evaluate it against ``qrel`` (delegates to `run_model`).

    The per-epoch path ``<model_out_dir>/<epoch>.run`` is passed through as ``runf``.
    """
    run_path = os.path.join(model_out_dir, '{}.run'.format(epoch))
    return run_model(model, dataset, run, run_path, qrel, data, args, desc)


def run_model(model, dataset, run, runf, qrels, data, args, desc='valid'):
    """Re-rank every query in ``run`` with ``model`` and compute per-query metrics.

    ``runf`` is accepted but nothing is written to it here.

    Returns:
        ``(qids, res, predictions)``:
          * qids: evaluated query ids, in evaluation order.
          * res: metric name -> list of per-query values aligned with ``qids``.
          * predictions: ``(qid, docid, score)`` triples, best-first within each query.
    """
    BATCH_SIZE = 16
    input_fields = ('query_tok', 'query_mask', 'doc_tok', 'doc_mask',
                    'wiki_tok', 'wiki_mask', 'question_tok', 'question_mask')
    scores_by_query = {}
    with torch.no_grad(), tqdm(total=sum(len(r) for r in run.values()), ncols=80, desc=desc, leave=False) as pbar:
        model.eval()
        for records in Data.iter_valid_records(model, dataset, run, BATCH_SIZE, data, args):
            batch_scores = model(*[records[field] for field in input_fields])
            for qid, did, score in zip(records['query_id'], records['doc_id'], batch_scores):
                scores_by_query.setdefault(qid, {})[did] = score.item()
            pbar.update(len(records['query_id']))

    metric_names = ["%s@%d" % (i, j) for i in ["p", "r", "ndcg", "nerr"] for j in [5, 10, 15, 20]]
    metric_names += ['map', 'mrr', 'rp']
    res = {name: [] for name in metric_names}
    predictions = []
    qids = []

    for qid, doc_scores in scores_by_query.items():
        ranked = sorted(doc_scores.items(), key=lambda pair: pair[1], reverse=True)
        predictions.extend((qid, pid, score) for pid, score in ranked)
        result = eval(qrels[qid], [pid for pid, _ in ranked])
        for name in res:
            res[name].append(result[name])
        qids.append(qid)
    return qids, res, predictions


def eval(qrels, ranked_list):
    """Compute ranking metrics for a single query.

    Args:
        qrels: ``{docid: relevance grade}`` for this query (grades presumably 0..4, see rel_level_num).
        ranked_list: doc ids, best first.

    Returns:
        A dict with ndcg@k, nerr@k, p@k and r@k for k in (5, 10, 15, 20), plus rp
        (R-precision), map and mrr.

    Only docs with grade > 0 count as relevant for p/r/rp, consistent with the
    graded metrics. Previously every judged doc, including grade-0 ones, counted.
    Recall and R-precision are 0.0 when the query has no relevant doc; this
    previously raised ZeroDivisionError.
    """
    grades = [1, 2, 3, 4]  # gain values for relevance levels 1..4 (level 0 is non-relevant)
    labeler = Labeler(qrels)
    labeled_ranked_list = labeler.label(ranked_list)
    rel_level_num = 5  # number of relevance levels, 0..4
    xrelnum = labeler.compute_per_level_doc_num(rel_level_num)
    result = {}

    relevant = {docid for docid, grade in qrels.items() if grade > 0}
    num_relevant = len(relevant)

    for i in [5, 10, 15, 20]:
        result["ndcg@%d" % i] = MSnDCG(xrelnum, grades, cutoff=i).compute(labeled_ranked_list)
        result["nerr@%d" % i] = nERR(xrelnum, grades, cutoff=i).compute(labeled_ranked_list)

        top = ranked_list[:i]
        hits = len(relevant.intersection(top))
        # denominator is the number of retrieved docs (may be < i for short lists)
        result["p@%d" % i] = hits / len(top) if top else 0.0
        result["r@%d" % i] = hits / num_relevant if num_relevant else 0.0

    r_hits = len(relevant.intersection(ranked_list[:num_relevant]))
    result["rp"] = r_hits / num_relevant if num_relevant else 0.0

    result["map"] = AP(xrelnum, grades).compute(labeled_ranked_list)
    result["mrr"] = RR().compute(labeled_ranked_list)

    return result


def write2file(path, name, format, output):
    """Print ``output`` and append it as one line to the file ``name + format`` inside directory ``path``.

    The directory is created if needed. The file is closed even if the write
    fails. ``os.path.join`` is used so a ``path`` without a trailing separator
    still writes inside that directory; plain concatenation wrote next to it.
    """
    print(output)
    if path:
        os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, name + format), 'a') as thefile:
        thefile.write("%s\n" % output)


def prediction2file(path, name, format, preds, fold):
    """Append ``fold<TAB>qid<TAB>docid<TAB>score`` lines for each ``(qid, docid, score)`` in ``preds``.

    The target is the file ``name + format`` inside directory ``path``, which is
    created if needed. The file is closed even if a write fails.
    """
    if path:
        os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, name + format), 'a') as thefile:
        for (qid, pid, score) in preds:
            thefile.write("%d\t%s\t%s\t%f\n" % (fold, qid, pid, score))

def print2file(path, name, format, printout, fold):
    """Print ``printout`` and append ``<fold>-<printout>`` to the file ``name + format`` inside directory ``path``.

    The directory is created if needed. The file is closed even if the write fails.
    """
    print(printout)
    if path:
        os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, name + format), 'a') as thefile:
        thefile.write("%d-%s\n" % (fold, printout))

def result2file(path, name, format, res, qids, fold):
    """Append ``fold<TAB>qid<TAB>value`` lines, pairing ``qids`` with the per-query values in ``res``.

    The target is the file ``name + format`` inside directory ``path``, which is
    created if needed. The file is closed even if a write fails.
    """
    if path:
        os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, name + format), 'a') as thefile:
        for q, r in zip(qids, res):
            thefile.write("%d\t%s\t%f\n" % (fold, q, r))

In [None]:
def main_cli():
    """Entry point: parse arguments, load the data, and run cross-validated training.

    Arguments are parsed from an empty list, so the notebook always runs with
    the defaults declared here. Corpus files are read from ``args.path`` as
    ``<corpus>-query.tsv``, ``-doc.tsv``, ``-<wikifile>.tsv``,
    ``-<questionfile>.tsv`` and ``-qrel.tsv``, where ``<corpus>`` is the part of
    ``--data`` before the first ``-``. Per-fold splits are read from
    ``<data>-train<k>.tsv`` / ``-valid<k>.tsv`` / ``-test<k>.tsv``.

    Every fold starts from the same initial weights. Previously a single model
    kept training across folds, so later folds began from weights fine-tuned on
    earlier folds' data. The per-metric means over all folds are appended to
    ``<out_dir><modelName>.res``.
    """
    # argument
    parser = argparse.ArgumentParser('CEDR model training and validation')
    parser.add_argument('--model', choices=MODEL_MAP.keys(), default='crossbert')
    parser.add_argument('--data', default='akgg')
    parser.add_argument('--path', default="data/")
    parser.add_argument('--wikifile', default="wikihow")
    parser.add_argument('--questionfile', default="question-qq")
    parser.add_argument('--initial_bert_weights', type=argparse.FileType('rb'))
    parser.add_argument('--model_out_dir', default="models/vbert")
    parser.add_argument('--epoch', type=int, default=1)
    parser.add_argument('--fold', type=int, default=2)
    parser.add_argument('--out_dir', default="out/")
    parser.add_argument('--evalMode', default="all")
    parser.add_argument('--mode', type=int, default=2)
    parser.add_argument('--maxlen', type=int, default=16)
    parser.add_argument('--earlystop', type=int, default=1)

    args = parser.parse_args([])

    corpus = args.data.split("-")[0]
    query_path = "%s%s-query.tsv" % (args.path, corpus)
    doc_path = "%s%s-doc.tsv" % (args.path, corpus)
    wiki_path = "%s%s-%s.tsv" % (args.path, corpus, args.wikifile)
    question_path = "%s%s-%s.tsv" % (args.path, corpus, args.questionfile)
    qrel_path = "%s%s-qrel.tsv" % (args.path, corpus)

    # The file objects stay on `args` as before; they are closed once consumed.
    args.queryfile = io.TextIOWrapper(io.open(query_path, 'rb'), 'UTF-8')
    args.docfile = io.TextIOWrapper(io.open(doc_path, 'rb'), 'UTF-8')
    args.wikifile = io.TextIOWrapper(io.open(wiki_path, 'rb'), 'UTF-8')
    args.questionfile = io.TextIOWrapper(io.open(question_path, 'rb'), 'UTF-8')

    args.train_pairs = "%s%s-train" % (args.path, args.data)
    args.valid_run = "%s%s-valid" % (args.path, args.data)
    args.test_run = "%s%s-test" % (args.path, args.data)

    args.qrels = io.TextIOWrapper(io.open(qrel_path, 'rb'), 'UTF-8')

    datafiles = [args.queryfile, args.docfile, args.wikifile, args.questionfile]
    try:
        dataset = read_datafiles(datafiles)
    finally:
        for handle in datafiles:
            handle.close()
    args.dataset = dataset
    model = MODEL_MAP[args.model](args).to(device)

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(pytorch_total_params)

    try:
        qrels = read_qrels_dict(args.qrels)
    finally:
        args.qrels.close()

    MAX_EPOCH = args.epoch

    train_pairs = []
    valid_run = []
    test_run = []

    foldNum = args.fold
    for fold in range(foldNum):
        with open(args.train_pairs + "%d.tsv" % fold, "r") as f:
            train_pairs.append(read_pairs_dict(f))
        with open(args.valid_run + "%d.tsv" % fold, "r") as f:
            valid_run.append(read_run_dict(f))
        with open(args.test_run + "%d.tsv" % fold, "r") as f:
            test_run.append(read_run_dict(f))

    if args.initial_bert_weights is not None:
        model.load(args.initial_bert_weights.name)
    os.makedirs(args.model_out_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)

    timestamp = strftime('%Y_%m_%d_%H_%M_%S', localtime())
    # same strings as args.wikifile.name / args.questionfile.name (the opened paths)
    wikiName = wiki_path.split("/")[-1].replace(".tsv", "")
    questionName = question_path.split("/")[-1].replace(".tsv", "")
    if "birch" in args.model:
        additionName = []
        if args.mode in [1, 3, 5, 6]:
            additionName.append(wikiName)
        if args.mode in [2, 4, 5, 6]:
            additionName.append(questionName)

        modelName = "%s_m%d_%s_%s_%s_e%d_es%d_%s" % (
            args.model, args.mode, args.data, "_".join(additionName), args.evalMode, args.epoch, args.earlystop, timestamp)
    else:
        modelName = "%s_%s_m%d_ml%d_%s_%s_%s_e%d_es%d_%s" % (
            args.data, args.model, args.mode, args.maxlen, wikiName, questionName, args.evalMode, args.epoch,
            args.earlystop, timestamp)

    print(modelName)

    df = pd.read_csv(qrel_path, sep="\t", names=["qid", "empty", "pid", "rele_label", "etype"])
    qrelDict = collections.defaultdict(dict)
    type2pids = collections.defaultdict(set)
    for qid, prop, label, etype in df[['qid', 'pid', 'rele_label', 'etype']].values:
        qrelDict[str(qid)][str(prop)] = int(label)
        type2pids[str(etype)].add(prop)
    args.type2pids = type2pids

    metricKeys = {"%s@%d" % (i, j): [] for i in ["p", "r", "ndcg", "nerr"] for j in [5, 10, 15, 20]}
    metricKeys["rp"] = []
    metricKeys["mrr"] = []
    metricKeys["map"] = []

    results = []

    t1 = time.time()

    args.isUnsupervised = args.model in ["sen_emb"]

    # Snapshot the starting weights (on CPU, to avoid doubling GPU memory) so every
    # fold trains from the same initialization instead of the previous fold's model.
    initial_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}

    for fold in range(len(train_pairs)):
        model.load_state_dict(initial_state)
        results.append(
            main(model, dataset, train_pairs[fold], qrels, valid_run[fold], test_run[fold], args.model_out_dir,
                 qrelDict, modelName, fold, metricKeys, MAX_EPOCH, Data, args))
    elapsed_time = time.time() - t1
    txt = f'total : {time.strftime("%H:%M:%S", time.gmtime(elapsed_time))}'
    print2file(args.out_dir, modelName, ".txt", txt, max(len(train_pairs) - 1, 0))

    # average results across all folds
    output = []
    for k in metricKeys:
        tmp = []
        for fold_result in results:
            tmp.extend(fold_result[k])
        _res = np.mean(tmp)
        output.append("%.4f" % _res)
    write2file(args.out_dir, modelName, ".res", ",".join(output))

In [None]:
if __name__ == "__main__":
    # In a Jupyter kernel __name__ is "__main__", so executing this cell runs the full pipeline.
    main_cli()

train:   0%|                                           | 0/4096 [00:00<?, ?it/s]

109483781
akgg_crossbert_m2_ml16_akgg-wikihow_akgg-question-qq_all_e1_es1_2020_08_30_22_35_57
Fold: 0


valid:   1%|▍                                | 16/1316 [00:00<00:08, 155.72it/s]

train epoch=0 loss=706.056735008955


test:   2%|▋                                 | 32/1587 [00:00<00:08, 176.09it/s]

validation epoch=0 score=0.5206618202494671 : 00:06:39


train:   0%|                                   | 2/4096 [00:00<05:55, 11.52it/s]

new top validation score, 0.5109
new top validation score, 0.5109
Fold: 1




KeyboardInterrupt: ignored