In [1]:
import os
import sys

def uppath(_path, n):
    """Return `_path` with its last `n` separator-delimited components removed.

    Args:
        _path: Path string whose components are separated by `os.sep`.
        n: Number of trailing components to drop. `0` returns `_path`
            unchanged.

    Returns:
        The shortened path string.
    """
    if n == 0:
        # `parts[:-0]` is `parts[:0]` (empty), so handle "drop nothing" explicitly.
        return _path
    return os.sep.join(_path.split(os.sep)[:-n])


# Notebooks have no `__file__`, so the current working directory stands in
# for "the directory containing this notebook".
cur_path = os.getcwd()
module_path = uppath(cur_path, 1)  # root_path
# Make the project root importable so `src.*` resolves in the cells below.
if module_path not in sys.path:
    sys.path.append(module_path)

In [2]:
import logging
import os

In [3]:
from src.reranker import Reranker
from src.reranker import RerankerTrainer
from src.reranker.data import PredictionDataset, GroupCollator
from src.reranker.arguments import ModelArguments, DataArguments, RerankerTrainingArguments as TrainingArguments

In [4]:
from transformers import AutoConfig, AutoTokenizer
import torch

In [5]:
# Name the logger after the current module. The bare `__name__` variable is
# required here: the quoted string "__name__" would create a logger that is
# literally called "__name__".
logger = logging.getLogger(__name__)

# Root logger configuration: timestamped INFO-level messages.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)

In [6]:
# Select the GPU for this session through CUDA's environment variables:
# order devices by PCI bus id and expose only device "3" to this process.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "3",
    }
)

In [7]:
# --- Inference settings -------------------------------------------------
# Tokenizer identifier and maximum sequence length used for prediction.
tokenizer_name = 'bert-base-uncased'
max_len = 512

# Locations under the project root: the trained checkpoint to load, and the
# directory holding the inference inputs/outputs.
model_ck_path = os.path.join(module_path, 'DATA/output')
output_dir = os.path.join(module_path, 'DATA/inference_output')

In [8]:
# Model arguments: load weights from the checkpoint directory and pair them
# with the tokenizer named in the settings cell.
model_args = ModelArguments(model_name_or_path=model_ck_path, tokenizer_name=tokenizer_name)

In [9]:
# Data arguments for prediction. All input/output files live under
# `output_dir`; `max_len` reuses the notebook-level setting so the sequence
# length is defined in exactly one place.
data_args = DataArguments(
    pred_path=os.path.join(output_dir, 'q_100.json'),  # all.json
    pred_id_file=os.path.join(output_dir, 'q_100_ids.tsv'),
    rank_score_path=os.path.join(output_dir, 'output_score'),
    max_len=max_len,
)

In [10]:
# Trainer arguments for a prediction-only run. `output_dir` is the same
# inference directory defined in the settings cell, reused here rather than
# rebuilt so the two cannot drift apart.
predict_args = TrainingArguments(
    output_dir=output_dir,
    do_predict=True,
    per_device_eval_batch_size=64,
    dataloader_num_workers=8,
)

In [11]:
# Load the tokenizer by name, explicitly requesting the non-fast implementation.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)

In [12]:
# Fail fast: the id file is needed to label every score, so check for it
# before the dataset is loaded.
assert data_args.pred_id_file is not None, "data_args.pred_id_file must be set"

test_dataset = PredictionDataset(
    data_args.pred_path,
    tokenizer=tokenizer,
    max_len=data_args.max_len,
)

# Read the (query id, passage id) pairs, one whitespace-separated pair per line.
# NOTE(review): rows are presumably in the same order as `test_dataset` — the
# scores are later zipped against these lists; confirm against the data prep.
pred_qids = []
pred_pids = []
with open(data_args.pred_id_file) as f:
    for line_no, line in enumerate(f, start=1):
        fields = line.split()
        if not fields:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            continue
        if len(fields) != 2:
            raise ValueError(
                f"{data_args.pred_id_file}:{line_no}: expected 2 columns "
                f"(qid, pid), got {len(fields)}"
            )
        qid, pid = fields
        pred_qids.append(qid)
        pred_pids.append(pid)



Downloading and preparing dataset json/default to /home/ubuntu/.cache/huggingface/datasets/json/default-853df0cf549532c9/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Dataset json downloaded and prepared to /home/ubuntu/.cache/huggingface/datasets/json/default-853df0cf549532c9/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

In [13]:
# Record the resolved data arguments in the log so the run's inputs are traceable.
logger.info("Data parameters %s", data_args)

03/29/2022 16:17:49 - INFO - Data parameters DataArguments(train_dir=None, train_path=None, train_group_size=8, dev_path=None, pred_path='/home/ubuntu/minung/Reranker/DATA/inference_output/q_100.json', pred_dir=None, pred_id_file='/home/ubuntu/minung/Reranker/DATA/inference_output/q_100_ids.tsv', rank_score_path='/home/ubuntu/minung/Reranker/DATA/inference_output/output_score', max_len=512)


In [14]:
# Build the reranker from the saved checkpoint directory; the three argument
# objects are handed to `Reranker.from_pretrained` alongside the path.
model = Reranker.from_pretrained(model_args, data_args, predict_args, model_ck_path)

In [15]:
# Display the loaded model's module structure as the cell output.
model

Reranker(
  (hf_model): BertForSequenceClassification(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)

In [16]:
# Wrap the model in the project's trainer, using the prediction arguments and
# a GroupCollator built on the same tokenizer as the dataset.
reranker = RerankerTrainer(
    data_collator=GroupCollator(tokenizer), args=predict_args, model=model
)

In [18]:
# Run prediction over the whole test dataset and keep only the `.predictions`
# attribute of the result (the raw model scores).
pred_scores = reranker.predict(test_dataset=test_dataset).predictions

In [19]:
# Write one "<qid>\t<pid>\t<score>" row per scored pair.
# NOTE(review): the `is_world_process_zero()` guard presumably restricts the
# write to the main process in distributed runs — confirm against RerankerTrainer.
if reranker.is_world_process_zero():
    # `pred_scores` is an (N, 1) array, so iterating it yields one-element
    # arrays that format as "[-6.40053]". Reshape to one scalar per row so a
    # bare number is written; reshape raises if there is not exactly one
    # score per row.
    flat_scores = pred_scores.reshape(len(pred_scores))
    assert len(pred_qids) == len(flat_scores), (
        f"id file has {len(pred_qids)} rows but {len(flat_scores)} scores were predicted"
    )
    with open(data_args.rank_score_path, "w") as writer:
        for qid, pid, score in zip(pred_qids, pred_pids, flat_scores):
            writer.write(f'{qid}\t{pid}\t{score}\n')

In [19]:
# Inspect the raw prediction array (the notebook shows a truncated repr).
pred_scores

array([[ -6.40053 ],
       [-14.214392],
       [-17.128588],
       ...,
       [ -9.769784],
       [-16.292362],
       [-11.852869]], dtype=float32)

In [20]:
# Number of scored rows; should equal the number of (qid, pid) pairs read above.
len(pred_scores)

519300