<a href="https://colab.research.google.com/github/jx-dohwan/Aiffel_Exp_Project/blob/main/Document_search/Document_search.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [18]:
# Use the %pip magic (not !pip) so packages are installed into the environment
# of the running kernel rather than whatever `pip` is first on the shell PATH.
%pip install torch
%pip install transformers
%pip install datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


In [20]:
from datasets import load_dataset

# Query/passage pairs: per the recorded output, `train` and `dev` splits with the
# columns query_id, query, positive_passages and negative_passages.
qa_data = load_dataset('Tevatron/msmarco-passage')
# Passage collection: per the recorded output, a single `train` split with the
# columns docid, title and text.
corpus = load_dataset('Tevatron/msmarco-passage-corpus')



  0%|          | 0/2 [00:00<?, ?it/s]



  0%|          | 0/1 [00:00<?, ?it/s]

In [21]:
qa_data

DatasetDict({
    train: Dataset({
        features: ['query_id', 'query', 'positive_passages', 'negative_passages'],
        num_rows: 400782
    })
    dev: Dataset({
        features: ['query_id', 'query', 'positive_passages', 'negative_passages'],
        num_rows: 6980
    })
})

In [22]:
corpus

DatasetDict({
    train: Dataset({
        features: ['docid', 'title', 'text'],
        num_rows: 8841823
    })
})

In [23]:
from transformers import BertTokenizer, BertModel

In [24]:
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

In [25]:
# model -> dual encoder

from transformers import BertModel
from torch import nn

class Retriever(nn.Module):
    """Dual-encoder dense retriever: one BERT tower for queries, one for passages.

    Each tower takes the hidden state of the first ([CLS]) token, projects it from
    ``opts.bert_hidden_dim`` to ``opts.proj_dim`` and applies a shared LayerNorm.

    Args:
        opts: object exposing ``bert_hidden_dim``, ``proj_dim`` and ``dropout``.
    """

    def __init__(self, opts):
        super(Retriever, self).__init__()
        self.q_encoder = BertModel.from_pretrained('bert-base-uncased') # 768
        self.ctx_encoder = BertModel.from_pretrained('bert-base-uncased')
        self.proj_q = nn.Linear(opts.bert_hidden_dim, opts.proj_dim) # MIPS 256
        self.proj_ctx = nn.Linear(opts.bert_hidden_dim, opts.proj_dim) # 256
        self.layer_norm = nn.LayerNorm(opts.proj_dim)
        self.dropout = nn.Dropout(opts.dropout)

    @staticmethod
    def _cls_vector(encoder_output):
        """Return the first-token hidden state, shape (batch, hidden).

        ``encoder_output[0]`` is the sequence of hidden states both when the encoder
        returns a ``ModelOutput`` and when it returns a plain tuple; slicing the
        output object itself with ``[:, 0, :]`` is not supported.
        """
        return encoder_output[0][:, 0, :]

    def forward(self, q_ids, q_attention_mask, q_type_ids, ctx_ids, ctx_attention_mask, ctx_type_ids): # title [SEP] passage
        """Embed a batch of queries and a batch of passages.

        Returns:
            Tuple ``(q_vector, ctx_vector)`` of projected, layer-normalised
            embeddings, each of shape (batch, proj_dim). These are the same
            vectors produced by ``get_q_emb`` / ``get_ctx_emb``, so training and
            indexing/retrieval operate in the same space.
        """
        # The getters already select [CLS], project and normalise; re-slicing their
        # 2-D result (as the previous implementation did) raises IndexError.
        q_vector_proj = self.get_q_emb(q_ids, q_attention_mask, q_type_ids)
        ctx_vector_proj = self.get_ctx_emb(ctx_ids, ctx_attention_mask, ctx_type_ids)
        return q_vector_proj, ctx_vector_proj

    def get_q_emb(self, q_ids, q_attention_mask, q_type_ids):
        """Return projected query embeddings of shape (batch, proj_dim)."""
        encoder_output = self.q_encoder(input_ids=q_ids, attention_mask=q_attention_mask, token_type_ids=q_type_ids)
        q_vector = self._cls_vector(encoder_output)
        q_vector_proj = self.proj_q(q_vector)
        # Query tower: dropout is applied after the projection.
        q_vector_proj = self.layer_norm(self.dropout(q_vector_proj))
        return q_vector_proj

    def get_ctx_emb(self, ctx_ids, ctx_attention_mask, ctx_type_ids):
        """Return projected passage embeddings of shape (batch, proj_dim)."""
        encoder_output = self.ctx_encoder(input_ids=ctx_ids, attention_mask=ctx_attention_mask, token_type_ids=ctx_type_ids)
        ctx_vector = self._cls_vector(encoder_output)
        # Passage tower: dropout is applied before the projection.
        ctx_vector_proj = self.proj_ctx(self.dropout(ctx_vector))
        ctx_vector_proj = self.layer_norm(ctx_vector_proj)
        return ctx_vector_proj

In [26]:
# NOTE(review): `touch` only creates an empty model.py. The training cell's
# `from model import Retriever` will fail unless the Retriever class defined in
# the previous cell is actually written into this file — TODO confirm how the
# file is meant to be populated.
!touch model.py

In [27]:
import torch
import random
import json
import numpy as np

class RetrieverDataset(torch.utils.data.Dataset):
    """Tokenised question/passage examples for the retriever.

    Every record must provide a ``'question'``; records used outside the
    ``'test'`` split must also provide a ``'ctx'``. Questions are padded or
    truncated to 40 tokens and passages to 200 tokens.
    """

    _FIELDS = ('input_ids', 'attention_mask', 'token_type_ids')

    def __init__(self, data, tokenizer, split='training'):
        self.tokenizer = tokenizer
        self.split = split
        self.data = data

    def __len__(self):
        return len(self.data)

    def _encode(self, text, max_length):
        """Tokenise ``text`` and return (input_ids, attention_mask, token_type_ids) as 1-D tensors."""
        encoded = self.tokenizer(text, max_length=max_length, padding="max_length", truncation=True, return_tensors='pt')
        return tuple(encoded[name].squeeze(0) for name in self._FIELDS)

    def __getitem__(self, index):
        """Return ``(index, q_ids, q_mask, q_types)`` for the test split, otherwise
        ``(q_ids, q_mask, q_types, ctx_ids, ctx_mask, ctx_types)``."""
        example = self.data[index]
        question = example['question']
        if self.split != 'test':
            ctx_fields = self._encode(example['ctx'], 200)
            return self._encode(question, 40) + ctx_fields
        return (index,) + self._encode(question, 40)

    def get_example(self, index):
        """Return the raw, untokenised record at ``index``."""
        return self.data[index]


class PassageDataset(torch.utils.data.Dataset):
    """Tokenised passages for index construction.

    Records are indexed positionally: element 0 is the passage id, element 2 is
    encoded as the first segment and element 1 as the second segment of the
    tokenizer pair (the indexing script builds records as (id, text, title), so
    this yields "title [SEP] text").
    """

    def __init__(self, data, tokenizer):
        self.tokenizer = tokenizer
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(id - 1, input_ids, attention_mask, token_type_ids)`` with 1-D tensors of length 200."""
        record = self.data[index]
        encoded = self.tokenizer(record[2], record[1], max_length=200, padding="max_length", truncation=True,
                                 return_tensors='pt')
        # Presumably converts a 1-based passage id to a 0-based row number — verify against the data.
        row_id = int(record[0]) - 1
        tensors = [encoded[name].squeeze(0) for name in ('input_ids', 'attention_mask', 'token_type_ids')]
        return row_id, tensors[0], tensors[1], tensors[2]

    def get_example(self, index):
        """Return the raw record at ``index``."""
        return self.data[index]

class ReaderDataset(torch.utils.data.Dataset):
    """Text-only examples for the reader: a prefixed question, a target string
    and up to ``n_context`` formatted passages."""

    def __init__(self, data, n_context=5):
        self.data = data
        self.n_context = n_context
        self.question_prefix = "question:"
        self.title_prefix = "title:"
        self.passage_prefix = "context:"

    def __len__(self):
        return len(self.data)

    def get_target(self, example):
        """Return the answer string terminated by ``' </s>'``.

        Uses ``example['target']`` when present, otherwise a random element of
        ``example['answers']``; returns ``None`` when neither key exists.
        """
        if 'target' in example:
            return example['target'] + ' </s>'
        if 'answers' in example:
            return random.choice(example['answers']) + ' </s>'
        return None

    def __getitem__(self, index):
        """Return a dict with keys ``index``, ``question``, ``target`` and ``passages``."""
        example = self.data[index]
        question = self.question_prefix + " " + example['question']
        target = self.get_target(example)

        passages = None
        if 'ctxs' in example and self.n_context is not None:
            # Template of the form "<title_prefix> {} <passage_prefix> {}".
            template = " ".join([self.title_prefix, "{}", self.passage_prefix, "{}"])
            passages = []
            for ctx in example['ctxs'][:self.n_context]:
                passages.append(template.format(ctx['title'], ctx['text']))

        return dict(index=index, question=question, target=target, passages=passages)

class BertQADataset(torch.utils.data.Dataset):
    """Unfinished dataset stub: stores ``data`` and ``tokenizer`` only.

    ``__len__`` and ``__getitem__`` are not implemented, so instances cannot yet
    be consumed by a ``DataLoader``.
    """
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    # TODO: implement __len__(self)

    # TODO: implement __getitem__(self, idx)


class ICTDataset(torch.utils.data.Dataset):
    """Unfinished dataset stub (presumably for Inverse Cloze Task pre-training —
    TODO confirm): stores ``data`` and ``tokenizer`` only.

    ``__len__`` and ``__getitem__`` are not implemented, so instances cannot yet
    be consumed by a ``DataLoader``.
    """
    def __init__(self, data, tokenizer):
        self.data = data
        self.tokenizer = tokenizer

    # TODO: implement __len__(self)

    # TODO: implement __getitem__(self, idx)

def encode_passages(batch_text_passages, tokenizer, max_length):
    """Tokenise a batch of passage lists into stacked id and mask tensors.

    Args:
        batch_text_passages: iterable with one list of passage strings per
            example; every list must contain the same number of passages so the
            per-example tensors can be stacked.
        tokenizer: object exposing ``batch_encode_plus``.
        max_length: every passage is padded/truncated to this many tokens.

    Returns:
        ``(passage_ids, passage_masks)`` of shape
        (batch, passages_per_example, max_length); the mask is boolean.
    """
    passage_ids, passage_masks = [], []
    for text_passages in batch_text_passages:
        encoded = tokenizer.batch_encode_plus(
            text_passages,
            max_length=max_length,
            # `pad_to_max_length=True` is deprecated; this is its replacement.
            padding='max_length',
            return_tensors='pt',
            truncation=True,
        )
        passage_ids.append(encoded['input_ids'])
        passage_masks.append(encoded['attention_mask'])

    # Stacking adds the leading batch dimension.
    passage_ids = torch.stack(passage_ids, dim=0)
    passage_masks = torch.stack(passage_masks, dim=0)
    return passage_ids, passage_masks.bool()

class Collator(object):
    """Collate ``ReaderDataset``-style dicts into tensors for the reader.

    Args:
        text_maxlength: token budget for each "question + passage" string.
        tokenizer: object exposing ``batch_encode_plus``.
        answer_maxlength: token budget for targets; a value <= 0 disables
            truncation and passes ``max_length=None`` to the tokenizer.
    """

    def __init__(self, text_maxlength, tokenizer, answer_maxlength=20):
        self.tokenizer = tokenizer
        self.text_maxlength = text_maxlength
        self.answer_maxlength = answer_maxlength

    def __call__(self, batch):
        """Return ``(index, target_ids, target_mask, passage_ids, passage_masks)``.

        Padding positions in ``target_ids`` are replaced by -100.
        """
        assert batch[0]['target'] is not None
        index = torch.tensor([ex['index'] for ex in batch])

        limit_answers = self.answer_maxlength > 0
        target = self.tokenizer.batch_encode_plus(
            [ex['target'] for ex in batch],
            max_length=self.answer_maxlength if limit_answers else None,
            # `pad_to_max_length=True` is deprecated; this is its replacement.
            padding='max_length',
            return_tensors='pt',
            truncation=limit_answers,
        )
        target_mask = target["attention_mask"].bool()
        target_ids = target["input_ids"].masked_fill(~target_mask, -100)

        def append_question(example):
            # Without retrieved passages the question alone is the only input.
            if example['passages'] is None:
                return [example['question']]
            return [example['question'] + " " + t for t in example['passages']]

        text_passages = [append_question(example) for example in batch]
        passage_ids, passage_masks = encode_passages(text_passages,
                                                     self.tokenizer,
                                                     self.text_maxlength)

        return (index, target_ids, target_mask, passage_ids, passage_masks)


def load_data(data_path=None, global_rank=-1, world_size=-1):
    """Load examples from a ``.json`` (one list) or ``.jsonl`` (one object per line) file.

    Args:
        data_path: path to the file; must end in ``.json`` or ``.jsonl``.
        global_rank: when > -1, only examples whose position ``k`` satisfies
            ``k % world_size == global_rank`` are kept (sharding across workers).
        world_size: number of shards used together with ``global_rank``.

    Returns:
        List of example dicts. Examples without an ``'id'`` get their position in
        the file; contexts without a ``'score'`` get ``1.0 / (k + 1)``.

    Raises:
        ValueError: if the file extension is neither ``.json`` nor ``.jsonl``.
    """
    assert data_path
    is_jsonl = data_path.endswith('.jsonl')
    if not is_jsonl and not data_path.endswith('.json'):
        raise ValueError("Unsupported data file (expected .json or .jsonl): " + str(data_path))

    examples = []
    # The context manager closes the file even if parsing raises part-way through.
    with open(data_path, 'r') as fin:
        records = fin if is_jsonl else json.load(fin)
        for k, example in enumerate(records):
            if global_rank > -1 and k % world_size != global_rank:
                continue
            if is_jsonl:
                if not example.strip():
                    # Tolerate blank lines, e.g. a trailing newline at end of file.
                    continue
                example = json.loads(example)
            if 'id' not in example:
                example['id'] = k
            # 'ctxs' is optional, matching ReaderDataset which also accepts examples without it.
            for c in example.get('ctxs', ()):
                if 'score' not in c:
                    # NOTE(review): k is the example position, not the context's rank
                    # within the example — confirm this default is what is wanted.
                    c['score'] = 1.0 / (k + 1)
            examples.append(example)

    return examples
    

class RetrieverCollator(object):
    """Collate dicts with keys ``index``, ``question``, ``scores`` and ``passages``.

    Args:
        tokenizer: object exposing ``batch_encode_plus``.
        passage_maxlength: token budget per passage.
        question_maxlength: token budget per question.
    """

    def __init__(self, tokenizer, passage_maxlength=200, question_maxlength=40):
        self.tokenizer = tokenizer
        self.passage_maxlength = passage_maxlength
        self.question_maxlength = question_maxlength

    def __call__(self, batch):
        """Return ``(index, question_ids, question_mask, passage_ids, passage_masks, scores)``.

        The last three entries are ``None`` when the first example of the batch
        has no scores or no passages.
        """
        index = torch.tensor([ex['index'] for ex in batch])

        question = self.tokenizer.batch_encode_plus(
            [ex['question'] for ex in batch],
            # `pad_to_max_length=True` is deprecated; this is its replacement.
            padding='max_length',
            return_tensors="pt",
            max_length=self.question_maxlength,
            truncation=True,
        )
        question_ids = question['input_ids']
        question_mask = question['attention_mask'].bool()

        # Only the first example is inspected; the batch is assumed homogeneous.
        if batch[0]['scores'] is None or batch[0]['passages'] is None:
            return index, question_ids, question_mask, None, None, None

        scores = torch.stack([ex['scores'] for ex in batch], dim=0)

        passage_ids, passage_masks = encode_passages(
            [ex['passages'] for ex in batch],
            self.tokenizer,
            self.passage_maxlength
        )

        return (index, question_ids, question_mask, passage_ids, passage_masks, scores)


class TextDataset(torch.utils.data.Dataset):
    """Passages rendered as one prefixed string per record.

    Records are indexed positionally: element 0 is returned unchanged as the
    identifier, element 2 follows the title prefix and element 1 follows the
    passage prefix.
    """

    def __init__(self,
                 data,
                 title_prefix='title:',
                 passage_prefix='context:'):
        self.passage_prefix = passage_prefix
        self.title_prefix = title_prefix
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(record[0], "<title_prefix> <record[2]> <passage_prefix> <record[1]>")``."""
        record = self.data[index]
        pieces = [self.title_prefix, record[2], self.passage_prefix, record[1]]
        return record[0], " ".join(pieces)

class TextCollator(object):
    """Collate ``(identifier, text)`` pairs into padded id and mask tensors.

    Args:
        tokenizer: object exposing ``batch_encode_plus``.
        maxlength: every text is padded/truncated to this many tokens.
    """

    def __init__(self, tokenizer, maxlength=200):
        self.tokenizer = tokenizer
        self.maxlength = maxlength

    def __call__(self, batch):
        """Return ``(index, text_ids, text_mask)``; ``index`` is a plain list and the mask is boolean."""
        index = [x[0] for x in batch]
        encoded_batch = self.tokenizer.batch_encode_plus(
            [x[1] for x in batch],
            # `pad_to_max_length=True` is deprecated; this is its replacement.
            padding='max_length',
            return_tensors="pt",
            max_length=self.maxlength,
            truncation=True,
        )
        text_ids = encoded_batch['input_ids']
        text_mask = encoded_batch['attention_mask'].bool()

        return index, text_ids, text_mask

In [28]:
# NOTE(review): `touch` only creates an empty dataset.py. The later
# `from dataset import RetrieverDataset` / `PassageDataset` imports will fail
# unless the dataset classes from the previous cell are actually written into
# this file — TODO confirm how the file is meant to be populated.
!touch dataset.py

In [36]:
!pip install KorOpenQA

[31mERROR: You must give at least one requirement to install (see "pip help install")[0m[31m
[0m

In [None]:
# train, distributed training

import sys, os
#sys.path.append(os.path.abspath('..'))
import torch
from model import Retriever
import torch.nn as nn
#import util
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import torch.distributed as dist
from dataset import RetrieverDataset
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
import numpy as np
import json
import KorOpenQA.evaluation
from torch.utils.tensorboard import SummaryWriter

from argparse import ArgumentParser

# Command-line options for retriever training.
parser = ArgumentParser(description="Retriever")
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--total_steps', type=int, default=500000)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--save_freq', type=int, default=5000)
parser.add_argument('--eval_freq', type=int, default=500)
parser.add_argument('--gpus', type=int, default=4)
parser.add_argument('--train_data_path', type=str, default='data/train.json')
parser.add_argument('--dev_data_path', type=str, default='data/dev.json')
parser.add_argument('--checkpoint_path', type=str, default='checkpoint')
parser.add_argument('--warmup_steps', type=int, default=1000)
parser.add_argument('--scheduler_steps', type=int, default=None,
                         help='total number of step for the scheduler, if None then scheduler_total_step = total_step')
parser.add_argument('--accumulation_steps', type=int, default=1)
parser.add_argument('--dropout', type=float, default=0.1, help='dropout rate')
parser.add_argument('--lr', type=float, default=1e-5, help='learning rate')
parser.add_argument('--clip', type=float, default=1., help='gradient clipping')
parser.add_argument('--optim', type=str, default='adam')
parser.add_argument('--scheduler', type=str, default='fixed')
parser.add_argument('--weight_decay', type=float, default=0.1)
parser.add_argument('--bert_hidden_dim', type=int, default=768)
parser.add_argument('--proj_dim', type=int, default=256)
parser.add_argument('--node_rank', type=int, default=0)
parser.add_argument('--num_nodes', type=int, default=1)
parser.add_argument('--multi_node', type=int, default=0)
parser.add_argument('--num_gpus', type=int, default=-1)

# parse_known_args() instead of parse_args(): inside a Jupyter kernel sys.argv
# carries the kernel's own `-f <connection file>` argument, which makes
# parse_args() abort with "unrecognized arguments". Unknown arguments are ignored.
args, _unknown_args = parser.parse_known_args()

def evaluate(model, valid_loader, loss_fn, rank):
    """Compute the average in-batch-negative loss per example over ``valid_loader``.

    Args:
        model: callable returning ``(q_vector, ctx_vector)`` for the six tensors
            of a batch; it is switched to eval mode and left there.
        valid_loader: iterable of
            ``(q_ids, q_mask, q_type_ids, ctx_ids, ctx_mask, ctx_type_ids)``.
        loss_fn: loss applied to the (batch, batch) similarity matrix with the
            diagonal as the gold label.
        rank: device the batch tensors and labels are moved to.

    Returns:
        Mean loss per example as a float; ``nan`` when the loader is empty.
    """
    model.eval()
    total_loss = 0.0
    total_examples = 0
    # A mean-reduced loss is already averaged over the batch, so it has to be
    # weighted by the batch size before dividing by the number of examples;
    # only a sum-reduced loss can be accumulated as-is.
    sum_reduced = getattr(loss_fn, 'reduction', 'mean') == 'sum'
    with torch.no_grad():
        for batch_data in valid_loader:
            q_ids, q_attention_mask, q_type_ids, ctx_ids, ctx_attention_mask, ctx_type_ids = batch_data
            bsz = q_ids.size(0)
            q_vector, ctx_vector = model(q_ids.to(rank), q_attention_mask.to(rank), q_type_ids.to(rank),
                                         ctx_ids.to(rank), ctx_attention_mask.to(rank), ctx_type_ids.to(rank))
            # sim[i, j] scores query i against passage j; passage i is the positive for query i.
            sim = torch.matmul(q_vector, ctx_vector.T)
            labels = torch.arange(0, bsz, 1, dtype=torch.int64, device=rank)
            batch_loss = loss_fn(sim, labels).item()
            total_loss += batch_loss if sum_reduced else batch_loss * bsz
            total_examples += bsz

    if total_examples == 0:
        return float('nan')
    return total_loss / total_examples

def train(rank, world_size):
    """Per-process training loop, intended to be launched through ``mp.spawn``.

    Args:
        rank: index of this process on the current node; also used as the CUDA
            device index.
        world_size: number of processes per node.

    Trains the dual encoder with in-batch negatives (the positive passage for
    query i is passage i of the same batch). Every ``args.eval_freq`` steps the
    validation loss is computed and, when it improves, process 0 writes a
    checkpoint; process 0 also writes one every ``args.save_freq`` steps.
    """
    # Local import: the cell only imports ``KorOpenQA.evaluation``, which does not
    # guarantee that the ``KorOpenQA.util`` submodule used below is loaded.
    import KorOpenQA.util

    logger = SummaryWriter()
    # Keep the node-local index for device placement; in multi-node runs the
    # global rank exceeds the number of GPUs on a node and is not a valid device.
    local_rank = rank
    if args.multi_node:
        rank = args.node_rank * world_size + local_rank
        dist.init_process_group(backend='nccl', init_method="env://", rank=rank, world_size=args.num_nodes * world_size)
    else:
        dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    model = Retriever(args).to(local_rank)
    model = DDP(model, find_unused_parameters=True)
    tokenizer = BertTokenizerFast.from_pretrained('./BERT')
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    scheduler =  KorOpenQA.util.WarmupLinearScheduler(optimizer, warmup_steps=args.warmup_steps, scheduler_steps=args.scheduler_steps, min_ratio=0., fixed_lr=True)
    train_examples = KorOpenQA.util.load_data(args.train_data_path)
    valid_examples = KorOpenQA.util.load_data(args.dev_data_path)
    trainset = RetrieverDataset(train_examples, tokenizer)
    validset = RetrieverDataset(valid_examples, tokenizer, split='validation')
    train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
    valid_sampler = torch.utils.data.distributed.DistributedSampler(validset)
    train_loader = DataLoader(dataset=trainset, batch_size=args.batch_size, num_workers=4, sampler=train_sampler)
    valid_loader = DataLoader(dataset=validset, batch_size=args.batch_size, num_workers=4, sampler=valid_sampler)
    tot_step = 0
    best_loss = np.inf
    for epoch in range(1, args.epochs+1):
        # Re-seed the sampler so every epoch uses a different shuffle.
        train_sampler.set_epoch(epoch)
        for i, batch_data in enumerate(train_loader):
            model.train()
            tot_step += 1
            q_ids, q_attention_mask, q_type_ids, ctx_ids, ctx_attention_mask, ctx_type_ids = batch_data
            bsz = q_ids.size(0)
            q_vector, ctx_vector = model(q_ids.to(local_rank), q_attention_mask.to(local_rank), q_type_ids.to(local_rank),
                                         ctx_ids.to(local_rank), ctx_attention_mask.to(local_rank), ctx_type_ids.to(local_rank))
            sim = torch.matmul(q_vector, ctx_vector.T) # bs * dim, dim * bs -> bs*bs [[q1p1, q1p2, q1p3 .... q1pn,],
                                                                                   #   [q2p1, q2p2, q2p3, .... q3pn], ... [qnp1, ...]]
            labels = torch.arange(0, bsz, 1, dtype=torch.int64, device=local_rank)
            loss = loss_fn(sim, labels)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            # CrossEntropyLoss() is mean-reduced, so the value is already per example;
            # dividing by the batch size again would under-report it.
            train_loss = float(loss.item())
            print ('step train loss : ' + str(tot_step) + ' : ' + str(train_loss))
            logger.add_scalar("train_loss", train_loss, tot_step)

            if tot_step % args.eval_freq == 0:
                val_loss = evaluate(model, valid_loader, loss_fn, local_rank)
                print('valid loss : ' + str(tot_step) + ' : ' + str(float(val_loss)))
                # Log the validation loss (previously the training loss was logged here).
                logger.add_scalar("valid_loss", float(val_loss), tot_step)
                if best_loss > val_loss:
                    best_loss = val_loss
                    if dist.get_rank() == 0:
                        torch.save(model.module.state_dict(), args.checkpoint_path+'/'+str(tot_step)+'_retriever.pt')

            if tot_step % args.save_freq == 0:
                if dist.get_rank() == 0:
                    torch.save(model.module.state_dict(), args.checkpoint_path+'/'+str(tot_step)+'_retriever.pt')

if __name__ == '__main__':
    # Entry point: seed, configure the rendezvous address, then start one
    # training process per GPU.
    torch.manual_seed(777)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '16937'
    # makedirs also creates missing parent directories and does not fail when
    # the directory already exists (no check-then-create race).
    os.makedirs(args.checkpoint_path, exist_ok=True)
    if args.num_gpus == -1:
        # -1 means "use every visible GPU".
        world_size = torch.cuda.device_count()
    else:
        world_size = args.num_gpus
    # gpu - 4, process - 4,
    mp.spawn(train, nprocs=world_size, args=(world_size, ))


In [None]:
!conda install -c pytorch faiss-cpu

In [None]:
# indexer
# max inner product search
import faiss
import logging
import numpy as np
import os
import pickle

class DenseFlatIndexer(object):
    """Exact maximum-inner-product index backed by ``faiss.IndexFlatIP``.

    Args:
        buffer_size: maximum number of vectors handed to the index per ``add`` call.
    """

    def __init__(self, buffer_size: int = 50000):
        self.buffer_size = buffer_size

    def init_index(self, vector_sz: int):
        """Create an empty inner-product index for vectors of size ``vector_sz``."""
        self.index = faiss.IndexFlatIP(vector_sz)

    def index_data(self, data):
        """Add embeddings to the index in chunks of ``buffer_size``.

        Args:
            data: pair whose second element holds the embeddings, one row per
                vector. The first element (ids) is not stored: search results are
                positions in insertion order.
        """
        vectors = np.ascontiguousarray(data[1], dtype='float32')
        # indexing in batches is beneficial for many faiss index types
        # Iterate over the vectors themselves; len(data) is the length of the
        # pair, which previously made the whole array go in as a single add().
        for start in range(0, len(vectors), self.buffer_size):
            self.index.add(vectors[start:start + self.buffer_size])

    def search_knn(self, query_vectors: np.ndarray, top_docs: int):
        """Return ``(scores, indexes)`` of the ``top_docs`` best matches per query row."""
        scores, indexes = self.index.search(query_vectors, top_docs)
        return (scores, indexes)

    def get_index_name(self):
        """Return the identifier of this index type."""
        return "flat_index"

In [None]:
# indexing

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os, json
import sys
sys.path.append('..')
import math
import pickle
from pathlib import Path

import torch
from torch.utils.data import DataLoader

import transformers
from KorOpenQA.model.retriever import Retriever
from dataset import PassageDataset
import KorOpenQA.util
import time
import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from argparse import ArgumentParser

# Command-line options for passage-embedding (index) generation.
parser = ArgumentParser(description="generate_index")

parser.add_argument('--batch_size', type=int, default=2048)
parser.add_argument('--data_path', type=str, default='data/psgs_w100.tsv')
parser.add_argument('--output_path', type=str, default='index')
parser.add_argument('--dropout', type=float, default=0.0, help='dropout rate')
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--bert_hidden_dim', type=int, default=768)
parser.add_argument('--proj_dim', type=int, default=256)
# Used as the divisor of the corpus size in main(), i.e. the number of shards.
parser.add_argument('--shard_size', type=int, default=7)
parser.add_argument('--multi', type=int, default=1)

# parse_known_args() instead of parse_args(): inside a Jupyter kernel sys.argv
# carries the kernel's own `-f <connection file>` argument, which makes
# parse_args() abort with "unrecognized arguments". Unknown arguments are ignored.
args, _unknown_args = parser.parse_known_args()

def embed_passages(model, tokenizer, passages, rank=None):
    """Embed ``passages`` with the passage tower of ``model``.

    When ``args.multi`` is set, the wrapped module (``model.module``) is used and
    tensors are moved to ``rank``; otherwise ``model`` itself is used on
    ``args.device``.

    Returns:
        ``(allids, allembeddings)``: a list of ids as produced by
        ``PassageDataset`` and a NumPy array with one embedding row per passage.
    """
    dataloader = DataLoader(PassageDataset(passages, tokenizer), shuffle=False,
                            num_workers=4, batch_size=args.batch_size)
    allids = []
    embedding_chunks = []
    passages_done = 0
    with torch.no_grad():
        # Steps are counted from 1 for the progress messages.
        for step, (ids, ctx_ids, ctx_attention_mask, ctx_type_ids) in enumerate(dataloader, start=1):
            passages_done += ctx_ids.size(0)
            encoder = model.module if args.multi else model
            target = rank if args.multi else args.device
            embeddings = encoder.get_ctx_emb(
                ctx_ids=ctx_ids.to(target),
                ctx_attention_mask=ctx_attention_mask.to(target),
                ctx_type_ids=ctx_type_ids.to(target)
            ).cpu()

            print (f"{step} step done...")
            print (f"{passages_done} passages embeded...")
            allids.extend(passage_id.item() for passage_id in ids)
            embedding_chunks.append(embeddings)

    allembeddings = torch.cat(embedding_chunks, dim=0).numpy()
    return allids, allembeddings


def dist_main(rank, world_size, passages):
    """Embed one contiguous shard of ``passages`` on GPU ``rank`` and pickle it.

    Args:
        rank: index of this process; selects both the shard and the device.
        world_size: total number of processes, i.e. number of shards.
        passages: full list of passage records; each process slices its share.

    Writes ``<output_path>_<rank>.pkl`` containing ``(ids, embeddings)``.
    """
    dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
    shard_size = math.ceil(len(passages)/world_size)
    start_index = shard_size * rank
    end_index = start_index + shard_size
    shard_passages = passages[start_index:end_index]
    tokenizer = transformers.BertTokenizerFast.from_pretrained('./BERT')
    model = Retriever(args).to(rank)
    # Load onto the CPU first: without map_location every process restores the
    # tensors onto the device they were saved from (typically cuda:0).
    # load_state_dict then copies the weights into the model on `rank`.
    model.load_state_dict(torch.load('retriever.pt', map_location='cpu'))
    model = DDP(model, find_unused_parameters=True)
    model.eval()
    allids, allembeddings = embed_passages(model, tokenizer, shard_passages, rank=rank)

    output_path = Path(args.output_path)
    save_file = output_path.parent / (output_path.name + '_'+str(rank)+'.pkl')
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(save_file, mode='wb') as f:
        pickle.dump((allids, allembeddings), f, protocol=4)
    print ("index saved...")

def main(passages):
    """Single-process variant: embed ``passages`` in ``args.shard_size`` shards.

    Each non-empty shard is written to ``<output_path>_<i>.pkl`` as
    ``(ids, embeddings)``.
    """
    tokenizer = transformers.BertTokenizerFast.from_pretrained('bert-base-uncased')
    model = Retriever(args)
    # Load onto the CPU so the checkpoint does not depend on the device it was saved from.
    model.load_state_dict(torch.load('retriever.pt', map_location='cpu'))
    model.to(args.device)
    model.eval()

    # args.shard_size is the NUMBER of shards; passages_per_shard is their length.
    # The loop must run once per shard — iterating over passages_per_shard (as
    # before) walks far past the data and fails on the first empty shard.
    num_shards = args.shard_size
    passages_per_shard = math.ceil(len(passages) / num_shards)
    output_path = Path(args.output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    for i in range(num_shards):
        start_index = passages_per_shard * i
        end_index = start_index + passages_per_shard
        shard_passages = passages[start_index:end_index]
        if not shard_passages:
            # Rounding up can leave the trailing shards empty.
            break
        allids, allembeddings = embed_passages(model, tokenizer, shard_passages)

        save_file = output_path.parent / (output_path.name + '_' + str(i) + '.pkl')
        with open(save_file, mode='wb') as f:
            # Same pickle protocol as dist_main.
            pickle.dump((allids, allembeddings), f, protocol=4)
        print ("index saved...")
        # Release the shard before embedding the next one.
        del allids
        del allembeddings


if __name__ == '__main__':
    # Entry point: read the passages (one JSON object per line), then embed them
    # either with one process per GPU or in a single process.
    os.environ.update({'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'})
    torch.manual_seed(777)
    start = time.time()
    with open('passages.json', 'r') as f:
        records = (json.loads(raw_line) for raw_line in f)
        # Record layout: (id, text, title).
        passages = [(record['id'], record['text'], record['title']) for record in records]
    if not args.multi:
        main(passages)
    else:
        world_size = torch.cuda.device_count()
        mp.spawn(dist_main, nprocs=world_size, args=(world_size, passages))
    print ("time : " + str(time.time()-start))

In [None]:
# retrieval
import sys, os
sys.path.append(os.path.abspath('..'))
import torch
from KorOpenQA.model.retriever import Retriever
import torch.nn as nn
import KorOpenQA.util
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.multiprocessing as mp
import torch.distributed as dist
from dataset import RetrieverDataset
from torch.utils.data import DataLoader
from transformers import BertTokenizerFast
import numpy as np
import json
import KorOpenQA.index
from torch.utils.tensorboard import SummaryWriter
import pickle

from argparse import ArgumentParser

# Command-line options for retrieval over the prebuilt index shards.
parser = ArgumentParser(description="Retriever")
parser.add_argument('--epochs', type=int, default=1000)
parser.add_argument('--total_steps', type=int, default=500000)
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--save_freq', type=int, default=5000)
parser.add_argument('--eval_freq', type=int, default=500)
parser.add_argument('--gpus', type=int, default=4)
parser.add_argument('--test_data_path', type=str, default='data/test.json')
parser.add_argument('--checkpoint_path', type=str, default='checkpoint')
parser.add_argument('--dropout', type=float, default=0.0, help='dropout rate')
parser.add_argument('--scheduler', type=str, default='fixed')
parser.add_argument('--bert_hidden_dim', type=int, default=768)
parser.add_argument('--proj_dim', type=int, default=256)
parser.add_argument('--device', type=str, default='cuda:0')
parser.add_argument('--topk', type=int, default=5)
parser.add_argument('--shard_num', type=int, default=4)

# parse_known_args() instead of parse_args(): inside a Jupyter kernel sys.argv
# carries the kernel's own `-f <connection file>` argument, which makes
# parse_args() abort with "unrecognized arguments". Unknown arguments are ignored.
args, _unknown_args = parser.parse_known_args()

def test():
    """Retrieve the top-k passages for every test question and write them to ``data/result.json``.

    Steps: load the trained retriever, load all ``index_<i>.pkl`` shards into one
    inner-product index, embed each question, search the index and store
    ``{'question', 'answers', 'ctxs'}`` per question.
    """
    model = Retriever(args).to(args.device)
    model.load_state_dict(torch.load('retriever.pt', map_location=args.device))
    # Inference only: disable dropout (the loop previously called model.train()).
    model.eval()
    tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
    test_examples = KorOpenQA.util.load_data(args.test_data_path)
    testset = RetrieverDataset(test_examples, tokenizer, split='test')
    test_loader = DataLoader(dataset=testset, batch_size=args.batch_size, num_workers=4, shuffle=False)

    # Build the complete index first. Retrieval used to be nested inside this
    # loop, so the whole test set was re-run (and the result file rewritten)
    # after every shard, against a partial index for all but the last pass.
    # NOTE: pickle.load executes arbitrary code; only load files you created yourself.
    indexer = KorOpenQA.index.DenseFlatIndexer()
    indexer.init_index(args.proj_dim)
    for shard in range(args.shard_num):
        with open('index_'+str(shard)+'.pkl', 'rb') as f:
            indexer.index_data(pickle.load(f))
    with open('passage_db.pkl', 'rb') as f:
        passage_db = pickle.load(f)

    result = []
    with torch.no_grad():
        for batch_data in test_loader:
            idx, q_ids, q_attention_mask, q_type_ids = batch_data

            ## query
            ## tokenizing
            ## q encoder
            ## faiss -> topk -> search!!
            q_vector = model.get_q_emb(q_ids.to(args.device), q_attention_mask.to(args.device), q_type_ids.to(args.device))
            q_vector = q_vector.cpu().numpy()
            _scores, indexes = indexer.search_knn(q_vector, args.topk)

            # MS-marco -> q-p (gold)
            # q index -> p
            # MRR -> avg(1/rank)
            print ("retrieval...")

            # ODQA -> question & answer (passage ??)
            for example_idx, passage_indexes in zip(idx.tolist(), indexes):
                example = testset.get_example(example_idx)
                ctxs = []
                for passage_idx in passage_indexes:
                    passage_idx = int(passage_idx)
                    if passage_idx < 0:
                        # FAISS pads with -1 when fewer than top-k vectors are available.
                        continue
                    passage = passage_db[passage_idx]
                    ctxs.append({'title': passage[0], 'text': passage[1]})
                result.append({
                    'question': example['question'],
                    'answers': example['answer'],
                    'ctxs': ctxs,
                })

    os.makedirs('data', exist_ok=True)
    with open('data/result.json', 'w') as f:
        json.dump(result, f)

if __name__ == '__main__':
    # Seed PyTorch's RNG, then run retrieval over the test set.
    torch.manual_seed(777)
    test()
