In [None]:
# Hugging Face model id from which the BERT tokenizers in this notebook are loaded.
model_checkpoint = 'klue/bert-base'

In [11]:
import time
import tqdm
import string
import pickle
import os.path as p

import torch
import numpy as np
import torch.nn as nn
from datasets import load_from_disk
from transformers import AdamW, TrainingArguments
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast, BertConfig


class QueryTokenizer:
    """Build ColBERT query inputs: ``[CLS] [Q] tokens... [SEP]`` padded with ``[MASK]``.

    The ``[Q]`` marker reuses the vocabulary slot ``[unused0]``. Padding positions
    are replaced by ``[MASK]`` (ColBERT query augmentation).
    """

    def __init__(self):
        self.tok = BertTokenizerFast.from_pretrained(model_checkpoint)

        # "[unused0]" is repurposed as the [Q] marker id.
        self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.tok.convert_tokens_to_ids("[unused0]")
        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
        self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id
        self.query_maxlen = self.tok.model_max_length

    def tokenize(self, batch_text, add_special_tokens=False):
        """Return the wordpiece strings of each text.

        With ``add_special_tokens`` each list is wrapped as ``[CLS] [Q] ... [SEP]``
        and right-padded with ``[MASK]`` up to ``query_maxlen``. Longer queries are
        neither padded nor truncated.
        """
        assert isinstance(batch_text, (list, tuple)), type(batch_text)

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token, self.Q_marker_token], [self.sep_token]
        # 3 == len(prefix) + len(suffix)
        return [
            prefix + lst + suffix + [self.mask_token] * (self.query_maxlen - (len(lst) + 3))
            for lst in tokens
        ]

    def encode(self, batch_text, add_special_tokens=False):
        """Same as ``tokenize`` but returns token ids instead of token strings."""
        assert isinstance(batch_text, (list, tuple)), type(batch_text)

        ids = self.tok(batch_text, add_special_tokens=False)["input_ids"]

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id, self.Q_marker_token_id], [self.sep_token_id]
        return [
            prefix + lst + suffix + [self.mask_token_id] * (self.query_maxlen - (len(lst) + 3))
            for lst in ids
        ]

    def tensorize(self, batch_text, bsize=None):
        """Tokenize a batch of queries into padded tensors.

        Returns ``(input_ids, attention_mask)`` tensors, or, when ``bsize`` is given,
        a list of ``(input_ids, attention_mask)`` chunks of at most ``bsize`` rows.
        """
        assert isinstance(batch_text, (list, tuple)), type(batch_text)

        # Prepend a throwaway ". " so the position right after [CLS] holds a token
        # that is then overwritten by the [Q] marker.
        batch_text = [". " + x for x in batch_text]

        obj = self.tok(
            batch_text, padding="longest", truncation=True, return_tensors="pt", max_length=self.query_maxlen
        )

        ids, mask = obj["input_ids"], obj["attention_mask"]

        ids[:, 1] = self.Q_marker_token_id
        # [MASK] augmentation: padding ids become [MASK]. Use the tokenizer's real pad
        # id rather than assuming it is 0. The attention mask is left unchanged.
        ids[ids == self.tok.pad_token_id] = self.mask_token_id

        if bsize:
            return _split_into_batches(ids, mask, bsize)

        return ids, mask


class DocTokenizer:
    """Build ColBERT document inputs of the form ``[CLS] [D] tokens... [SEP]``.

    The ``[D]`` marker reuses the vocabulary slot ``[unused1]``.
    """

    def __init__(self):
        self.tok = BertTokenizerFast.from_pretrained(model_checkpoint)

        self.D_marker_token = "[D]"
        self.D_marker_token_id = self.tok.convert_tokens_to_ids("[unused1]")
        self.cls_token = self.tok.cls_token
        self.cls_token_id = self.tok.cls_token_id
        self.sep_token = self.tok.sep_token
        self.sep_token_id = self.tok.sep_token_id

    def tokenize(self, batch_text, add_special_tokens=False):
        """Return wordpiece strings per text, optionally wrapped as [CLS] [D] ... [SEP]."""
        assert type(batch_text) in [list, tuple], type(batch_text)

        pieces = [self.tok.tokenize(text, add_special_tokens=False) for text in batch_text]
        if add_special_tokens:
            head, tail = [self.cls_token, self.D_marker_token], [self.sep_token]
            pieces = [head + seq + tail for seq in pieces]
        return pieces

    def encode(self, batch_text, add_special_tokens=False):
        """Return token ids per text, optionally wrapped with the [CLS] [D] ... [SEP] ids."""
        assert type(batch_text) in [list, tuple], type(batch_text)

        id_lists = self.tok(batch_text, add_special_tokens=False)["input_ids"]
        if add_special_tokens:
            head, tail = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id]
            id_lists = [head + seq + tail for seq in id_lists]
        return id_lists

    def tensorize(self, batch_text, bsize=None):
        """Tokenize documents, padded to the tokenizer's ``model_max_length``.

        Returns ``(input_ids, attention_mask)``. When ``bsize`` is given, it instead
        returns ``(batches, reverse_indices)``: length-sorted chunks, plus the
        permutation that restores the input order.
        """
        assert type(batch_text) in [list, tuple], type(batch_text)

        # A throwaway ". " prefix puts a token right after [CLS]; that position is
        # then overwritten with the [D] marker id.
        encoded = self.tok(
            [". " + text for text in batch_text],
            padding="max_length",
            truncation=True,
            return_tensors="pt",
            max_length=self.tok.model_max_length,
        )
        ids = encoded["input_ids"]
        mask = encoded["attention_mask"]

        ids[:, 1] = self.D_marker_token_id

        if not bsize:
            return ids, mask

        ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
        return _split_into_batches(ids, mask, bsize), reverse_indices


def tensorize_triples(query_tokenizer, doc_tokenizer, queries, positives, negatives, bsize):
    """Tokenize (query, positive, negative) triples into length-sorted training batches.

    Args:
        query_tokenizer: object exposing ``tensorize(list_of_str) -> (ids, mask)``.
        doc_tokenizer: object exposing ``tensorize(list_of_str) -> (ids, mask)``.
        queries, positives, negatives: equally long sequences of strings
            (pandas Series, lists or tuples).
        bsize: number of triples per batch.

    Returns:
        A list of ``(Q, D)`` pairs. ``Q = (ids, mask)`` holds each query twice (once
        for its positive, once for its negative). ``D = (ids, mask)`` stacks the
        batch's positives followed by its negatives.

    Raises:
        ValueError: if the three inputs differ in length.
    """

    def _as_list(seq):
        # pandas Series -> to_list() (native Python values); any other sequence -> list().
        return seq.to_list() if hasattr(seq, "to_list") else list(seq)

    queries, positives, negatives = _as_list(queries), _as_list(positives), _as_list(negatives)

    N = len(queries)
    if len(positives) != N or len(negatives) != N:
        # Without this check, positives/negatives of unequal length whose total is 2N
        # would be silently mis-paired by the view(2, N, -1) below.
        raise ValueError(
            f"queries/positives/negatives must have equal length, got "
            f"{N}/{len(positives)}/{len(negatives)}"
        )

    Q_ids, Q_mask = query_tokenizer.tensorize(queries)

    D_ids, D_mask = doc_tokenizer.tensorize(positives + negatives)
    D_ids, D_mask = D_ids.view(2, N, -1), D_mask.view(2, N, -1)

    # For each i, the longer of (i-th positive, i-th negative), in real tokens.
    maxlens = D_mask.sum(-1).max(0).values

    # Sort triples by that length so batches group similarly sized documents.
    indices = maxlens.sort().indices
    Q_ids, Q_mask = Q_ids[indices], Q_mask[indices]
    D_ids, D_mask = D_ids[:, indices], D_mask[:, indices]

    (positive_ids, negative_ids), (positive_mask, negative_mask) = D_ids, D_mask

    query_batches = _split_into_batches(Q_ids, Q_mask, bsize)
    positive_batches = _split_into_batches(positive_ids, positive_mask, bsize)
    negative_batches = _split_into_batches(negative_ids, negative_mask, bsize)

    batches = []
    for (q_ids, q_mask), (p_ids, p_mask), (n_ids, n_mask) in zip(query_batches, positive_batches, negative_batches):
        Q = (torch.cat((q_ids, q_ids)), torch.cat((q_mask, q_mask)))
        D = (torch.cat((p_ids, n_ids)), torch.cat((p_mask, n_mask)))
        batches.append((Q, D))

    return batches


def _sort_by_length(ids, mask, bsize):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices

    return ids[indices], mask[indices], reverse_indices


def _split_into_batches(ids, mask, bsize):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batches.append((ids[offset : offset + bsize], mask[offset : offset + bsize]))

    return batches


import os
import tqdm
import torch
import datetime
import itertools

from multiprocessing import Pool
from collections import OrderedDict, defaultdict


def print_message(*s, condition=True):
    """Format ``s`` as a timestamped message; print it (flushed) if ``condition``; return it."""
    body = " ".join(map(str, s))
    stamp = datetime.datetime.now().strftime("%b %d, %H:%M:%S")
    msg = f"[{stamp}] {body}"

    if condition:
        print(msg, flush=True)

    return msg


def timestamp():
    """Return the current local time as ``YYYY-MM-DD_HH.MM.SS``."""
    return datetime.datetime.now().strftime("%Y-%m-%d_%H.%M.%S")


def file_tqdm(file):
    """Yield the lines of an open ``file`` while showing a tqdm progress bar in MiB.

    The bar's total is the on-disk size of ``file.name``. Each line advances the bar
    by its encoded byte length (not its character count), so the bar tracks the
    file size correctly for multi-byte text such as UTF-8 Korean.
    """
    print(f"#> Reading {file.name}")

    encoding = getattr(file, "encoding", None) or "utf-8"

    # The context manager closes the bar; no explicit close() is needed.
    with tqdm.tqdm(total=os.path.getsize(file.name) / 1024.0 / 1024.0, unit="MiB") as pbar:
        for line in file:
            yield line
            if isinstance(line, bytes):
                nbytes = len(line)
            else:
                nbytes = len(line.encode(encoding, errors="replace"))
            pbar.update(nbytes / 1024.0 / 1024.0)


def save_checkpoint(path, epoch_idx, mb_idx, model, optimizer, arguments=None):
    """Save model/optimizer state plus epoch, batch index and ``arguments`` to ``path`` via ``torch.save``."""
    print(f"#> Saving a checkpoint to {path} ..")

    # Unwrap DataParallel / DistributedDataParallel-style wrappers that expose ``.module``.
    inner = getattr(model, 'module', model)

    torch.save(
        {
            'epoch': epoch_idx,
            'batch': mb_idx,
            'model_state_dict': inner.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'arguments': arguments,
        },
        path,
    )


def load_checkpoint(path, model, optimizer=None, do_print=True):
    """Load a checkpoint written by ``save_checkpoint`` into ``model`` (and ``optimizer``).

    ``path`` may be a local file or an ``http(s)`` URL; tensors are mapped to CPU.
    A leading ``module.`` prefix (added by data-parallel wrappers) is stripped from
    parameter names. If strict loading fails with a ``RuntimeError`` (e.g. missing or
    unexpected keys), loading is retried with ``strict=False`` and a warning.

    Returns:
        The checkpoint dict, with ``model_state_dict`` replaced by the prefix-stripped copy.
    """
    if do_print:
        print_message("#> Loading checkpoint", path, "..")

    if path.startswith(("http:", "https:")):
        checkpoint = torch.hub.load_state_dict_from_url(path, map_location='cpu')
    else:
        checkpoint = torch.load(path, map_location='cpu')

    prefix = 'module.'
    checkpoint['model_state_dict'] = OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in checkpoint['model_state_dict'].items()
    )

    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except RuntimeError:
        # Only fall back on state-dict mismatches; a bare except also hid
        # KeyboardInterrupt and unrelated bugs.
        print_message("[WARNING] Loading checkpoint with strict=False")
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    if optimizer is not None:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if do_print:
        print_message("#> checkpoint['epoch'] =", checkpoint['epoch'])
        print_message("#> checkpoint['batch'] =", checkpoint['batch'])

    return checkpoint


def create_directory(path):
    """Create ``path`` (including parents) unless it already exists, logging either outcome."""
    print('\n')
    if not os.path.exists(path):
        print_message("#> Creating directory", path, '\n\n')
        os.makedirs(path)
        return
    print_message("#> Note: Output directory", path, 'already exists\n\n')

# def batch(file, bsize):
#     while True:
#         L = [ujson.loads(file.readline()) for _ in range(bsize)]
#         yield L
#     return


def f7(seq):
    """Return the items of ``seq`` with duplicates removed, keeping first-occurrence order."""
    # dict preserves insertion order, so its keys are an order-stable dedupe.
    return list(dict.fromkeys(seq))


def batch(group, bsize, provide_offset=False):
    """Yield consecutive slices of ``group`` holding at most ``bsize`` items.

    With ``provide_offset`` each slice is yielded as ``(start_offset, slice)``.
    """
    start = 0
    while start < len(group):
        chunk = group[start: start + bsize]
        if provide_offset:
            yield start, chunk
        else:
            yield chunk
        start += len(chunk)


class dotdict(dict):
    """A dict whose keys can also be read, set and deleted as attributes.

    Missing keys raise ``AttributeError`` on attribute access (not ``KeyError``).
    This keeps ``hasattr``, ``getattr(obj, name, default)``, ``copy`` and
    ``pickle`` working on instances.
    Credit: derek73 @ https://stackoverflow.com/questions/2352181
    """

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    __setattr__ = dict.__setitem__

    def __delattr__(self, name):
        try:
            del self[name]
        except KeyError:
            raise AttributeError(name) from None


def flatten(L):
    """Concatenate the inner iterables of ``L`` into a single list."""
    return list(itertools.chain.from_iterable(L))


def zipstar(L, lazy=False):
    """Transpose a list of equal-width rows into columns; a faster ``zip(*L)``.

    Rows narrower than 100 produce a list of lists. Wider rows go through ``zip``
    and produce a list of tuples, or the zip iterator itself when ``lazy``.
    An empty ``L`` is returned unchanged.
    """
    if len(L) == 0:
        return L

    width = len(L[0])

    if width >= 100:
        columns = zip(*L)
        return columns if lazy else list(columns)

    return [[row[col] for row in L] for col in range(width)]


def zip_first(L1, L2):
    """Zip ``L1`` with ``L2`` into a list of pairs.

    When ``L1`` is a list or tuple, asserts that every element of ``L1`` was paired
    (i.e. ``L2`` was not shorter).
    """
    pairs = list(zip(L1, L2))

    if type(L1) in (tuple, list):
        assert len(L1) == len(pairs), "zip_first() failure: length differs!"

    return pairs


def int_or_float(val):
    """Parse the string ``val`` as a float if it contains '.', otherwise as an int."""
    return float(val) if '.' in val else int(val)

def load_ranking(path, types=None, lazy=False):
    """Load a ranking as a list of rows.

    First tries ``torch.load`` (the format written by ``save_ranking``: a list of
    column tensors, transposed back into rows with ``zipstar``). If that fails, the
    file is parsed as tab-separated text. Each field is converted by the matching
    callable in ``types``; by default ``int_or_float`` is used for every field.
    """
    print_message(f"#> Loading the ranked lists from {path} ..")

    try:
        lists = torch.load(path)
        lists = zipstar([l.tolist() for l in tqdm.tqdm(lists)], lazy=lazy)
    except Exception:
        # Not a torch-saved ranking: fall back to the TSV parser. Catching Exception
        # (instead of a bare except) lets KeyboardInterrupt/SystemExit propagate.
        if types is None:
            types = itertools.cycle([int_or_float])

        with open(path) as f:
            lists = [[typ(x) for typ, x in zip_first(types, line.strip().split('\t'))]
                     for line in file_tqdm(f)]

    return lists


def save_ranking(ranking, path):
    """Transpose ``ranking`` rows into column tensors, save them with ``torch.save``, and return them."""
    columns = [torch.tensor(column) for column in zipstar(ranking)]
    torch.save(columns, path)
    return columns


def groupby_first_item(lst):
    """Group items by their first element into a ``defaultdict(list)``.

    Each group stores the remainder of the item: the single value if exactly one
    element follows the key, otherwise the list of remaining elements.
    """
    grouped = defaultdict(list)
    for head, *tail in lst:
        grouped[head].append(tail[0] if len(tail) == 1 else tail)
    return grouped


def process_grouped_by_first_item(lst):
    """Stream ``(key, values)`` groups from items already grouped by first element.

    Requires items with equal first elements to be contiguous; a key seen again
    after its group ended fails an assertion. Each value is the single trailing
    element, or the list of trailing elements when there are several. Every group,
    including the last one, is yielded. The full ``groups`` dict is also returned
    as the generator's return value.
    """

    groups = defaultdict(list)

    started = False
    last_group = None

    for first, *rest in lst:
        rest = rest[0] if len(rest) == 1 else rest

        if started and first != last_group:
            yield (last_group, groups[last_group])
            assert first not in groups, f"{first} seen earlier --- violates precondition."

        groups[first].append(rest)

        last_group = first
        started = True

    # A group is only emitted when the next one starts, so flush the final group;
    # previously it was lost (only reachable via StopIteration.value).
    if started:
        yield (last_group, groups[last_group])

    return groups


def grouper(iterable, n, fillvalue=None):
    """Chunk ``iterable`` into tuples of length ``n``, padding the last with ``fillvalue``.

    Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    (itertools recipe: https://docs.python.org/3/library/itertools.html#itertools-recipes)
    """
    shared = iter(iterable)
    # The same iterator is repeated n times, so zip_longest pulls n consecutive items per tuple.
    return itertools.zip_longest(*itertools.repeat(shared, n), fillvalue=fillvalue)


# see https://stackoverflow.com/a/45187287
class NullContextManager:
    """No-op context manager: ``with`` binds ``dummy_resource`` and exceptions propagate."""

    def __init__(self, dummy_resource=None):
        self.dummy_resource = dummy_resource

    def __enter__(self):
        return self.dummy_resource

    def __exit__(self, *exc_info):
        # Returning None means any exception raised in the body is not suppressed.
        return None


def load_batch_backgrounds(args, qids):
    """Return one ``' [SEP] '``-joined background string per qid, or None if unavailable.

    ``args.qid2backgrounds[qid]`` lists background passage ids. When the first id
    is an ``int``, the ids index ``args.collection``. Otherwise they are looked up
    in ``args.collectionX``, with ``''`` for missing ids.
    """
    if args.qid2backgrounds is None:
        return None

    def _texts_for(back):
        if len(back) and type(back[0]) == int:
            return [args.collection[pid] for pid in back]
        return [args.collectionX.get(pid, '') for pid in back]

    return [' [SEP] '.join(_texts_for(args.qid2backgrounds[qid])) for qid in qids]


class ColBERT(BertPreTrainedModel):
    """ColBERT late-interaction encoder on top of a BERT backbone.

    Queries and documents are encoded token by token by ``self.bert``, projected to
    ``dim`` dimensions by ``self.linear`` and L2-normalised. ``score`` compares each
    query token with every document token, keeps the best match, and sums those
    maxima over the query tokens.
    """

    def __init__(self, config, mask_punctuation=string.punctuation, dim=128, similarity_metric="cosine"):
        super(ColBERT, self).__init__(config)

        self.similarity_metric = similarity_metric
        self.dim = dim

        self.mask_punctuation = mask_punctuation
        self.skiplist = {}

        if self.mask_punctuation:
            # Punctuation (raw symbol and its first wordpiece id) is masked out of
            # documents in ``doc``. Only the truthiness of ``mask_punctuation`` is
            # used; the symbol set is always ``string.punctuation``.
            self.tokenizer = BertTokenizerFast.from_pretrained(model_checkpoint)
            self.skiplist = {
                w: True
                for symbol in string.punctuation
                for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]
            }

        self.bert = BertModel(config)
        self.linear = nn.Linear(config.hidden_size, dim, bias=False)

        self.init_weights()

    def forward(self, Q=None, D=None):
        """Score pairs; ``Q``/``D`` are keyword-argument dicts for ``query``/``doc``."""
        return self.score(self.query(**Q), self.doc(**D))

    def query(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return L2-normalised per-token query embeddings (``token_type_ids`` is ignored)."""
        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
        Q = self.linear(Q)
        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def doc(self, input_ids, attention_mask=None, token_type_ids=None):
        """Return L2-normalised per-token document embeddings, with masked tokens zeroed."""
        D = self.bert(input_ids, attention_mask=attention_mask)[0]
        D = self.linear(D)

        # Build the keep-mask on the embeddings' own device; the previous hard-coded
        # device="cuda" failed on CPU and mismatched when the model lived on another GPU.
        mask = torch.tensor(self.mask(input_ids), device=D.device).unsqueeze(2).float()
        D = D * mask

        return torch.nn.functional.normalize(D, p=2, dim=2)

    def score(self, Q, D):
        """Late-interaction score between aligned pairs ``Q[i]`` and ``D[i]``.

        ``Q`` and ``D`` are 3-D with matching size along dim 0; returns one score per pair.
        "cosine" takes dot products (cosine for the normalised embeddings); "l2" uses
        negative squared distance.
        """
        if self.similarity_metric == "cosine":
            return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)

        assert self.similarity_metric == "l2"
        return (-1.0 * ((Q.unsqueeze(2) - D.unsqueeze(1)) ** 2).sum(-1)).max(-1).values.sum(-1)

    def mask(self, input_ids):
        """Per-token keep flags: False for skiplisted (punctuation) ids and for id 0."""
        # NOTE(review): id 0 is assumed to be [PAD] for this tokenizer — confirm.
        return [[(x not in self.skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]

In [None]:
# class ModelInference():
#     def __init__(self, colbert: ColBERT, amp=False):
#         assert colbert.training is False

#         self.colbert = colbert
#         self.query_tokenizer = QueryTokenizer(colbert.query_maxlen)
#         self.doc_tokenizer = DocTokenizer(colbert.doc_maxlen)

#         self.amp_manager = MixedPrecisionManager(amp)

#     def query(self, *args, to_cpu=False, **kw_args):
#         with torch.no_grad():
#             with self.amp_manager.context():
#                 Q = self.colbert.query(*args, **kw_args)
#                 return Q.cpu() if to_cpu else Q

#     def doc(self, *args, to_cpu=False, **kw_args):
#         with torch.no_grad():
#             with self.amp_manager.context():
#                 D = self.colbert.doc(*args, **kw_args)
#                 return D.cpu() if to_cpu else D

#     def queryFromText(self, queries, bsize=None, to_cpu=False):
#         if bsize:
#             batches = self.query_tokenizer.tensorize(queries, bsize=bsize)
#             batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
#             return torch.cat(batches)

#         input_ids, attention_mask = self.query_tokenizer.tensorize(queries)
#         return self.query(input_ids, attention_mask)

#     def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False):
#         if bsize:
#             batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)

#             batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
#                        for input_ids, attention_mask in batches]

#             if keep_dims:
#                 D = _stack_3D_tensors(batches)
#                 return D[reverse_indices]

#             D = [d for batch in batches for d in batch]
#             return [D[idx] for idx in reverse_indices.tolist()]

#         input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
#         return self.doc(input_ids, attention_mask, keep_dims=keep_dims)

#     def score(self, Q, D, mask=None, lengths=None, explain=False):
#         if lengths is not None:
#             assert mask is None, "don't supply both mask and lengths"

#             mask = torch.arange(D.size(1), device=DEVICE) + 1
#             mask = mask.unsqueeze(0) <= lengths.to(DEVICE).unsqueeze(-1)

#         scores = (D @ Q)
#         scores = scores if mask is None else scores * mask.unsqueeze(-1)
#         scores = scores.max(1)

#         if explain:
#             assert False, "TODO"

#         return scores.values.sum(-1).cpu()


# def _stack_3D_tensors(groups):
#     bsize = sum([x.size(0) for x in groups])
#     maxlen = max([x.size(1) for x in groups])
#     hdim = groups[0].size(2)

#     output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)

#     offset = 0
#     for x in groups:
#         endpos = offset + x.size(0)
#         output[offset:endpos, :x.size(1)] = x
#         offset = endpos

#     return output

In [None]:
# def batch_retrieve(args, queries):
#     assert args.retrieve_only, "TODO: Combine batch (multi-query) retrieval with batch re-ranking"

#     # faiss_index = FaissIndex(args.index_path, args.faiss_index_path, args.nprobe, args.part_range)
#     inference = ModelInference(args.colbert, amp=args.amp)

#     # ranking_logger = RankingLogger(Run.path, qrels=None)

#     # with ranking_logger.context('unordered.tsv', also_save_annotations=False) as rlogger:
#     with torch.no_grad() :
#         # queries = args.queries
#         qids_in_order = queries #list(queries.keys())

#         for qoffset, qbatch in batch(qids_in_order, 100_000, provide_offset=True):
#             qbatch_text = [queries[qid] for qid in qbatch]

#             print_message(f"#> Embedding {len(qbatch_text)} queries in parallel...")
#             Q = inference.queryFromText(qbatch_text, bsize=16)

#             print_message("#> Starting batch retrieval...")
#             all_pids = faiss_index.retrieve(args.faiss_depth, Q, verbose=True)

#             # Log the PIDs with rank -1 for all
#             for query_idx, (qid, ranking) in enumerate(zip(qbatch, all_pids)):
#                 query_idx = qoffset + query_idx

#                 if query_idx % 1000 == 0:
#                     print_message(f"#> Logging query #{query_idx} (qid {qid}) now...")

#                 ranking = [(None, pid, None) for pid in ranking]
#                 rlogger.log(qid, ranking, is_ranked=False)

#     print('\n\n')
#     print(ranking_logger.filename)
#     print("#> Done.")
#     print('\n\n')

In [None]:
# Shared tokenizer instances used by the embedding and retrieval helpers below.
q_tokenizer, d_tokenizer = QueryTokenizer(), DocTokenizer()

In [2]:
def get_p_embs(corpus, colbert_encoder, bsize=32):
    """Encode every passage in ``corpus`` with ``colbert_encoder.doc``.

    Args:
        corpus: sequence of passage strings.
        colbert_encoder: a ColBERT model; inputs are sent to the device of its parameters.
        bsize: passages tokenized and encoded per forward pass.

    Returns:
        A CPU tensor with the per-passage ``doc`` outputs concatenated along dim 0.
        Documents are padded to the tokenizer's max length by ``d_tokenizer``, so
        every row has the same token length.
    """
    # NOTE(review): with max-length padding this holds n_passages x max_len x dim floats
    # in memory; consider float16 storage or a disk-backed index for large corpora.
    device = next(colbert_encoder.parameters()).device
    corpus = list(corpus)

    p_embs = []
    with torch.no_grad():
        colbert_encoder.eval()
        # tqdm is imported as a module, so the progress bar is tqdm.tqdm.
        for start in tqdm.tqdm(range(0, len(corpus), bsize)):
            # tensorize expects a list of strings and returns (input_ids, attention_mask).
            ids, mask = d_tokenizer.tensorize(corpus[start:start + bsize])
            p_emb = colbert_encoder.doc(ids.to(device), attention_mask=mask.to(device))
            p_embs.append(p_emb.cpu())

    if not p_embs:
        return torch.empty(0)
    return torch.cat(p_embs, dim=0)


In [None]:
# import pickle

# file_path = '/opt/ml/custom/passage_embedding.bin'
# with open(file_path, 'wb') as file :
#     pickle.dump(p_embs, file)


In [4]:
import json

with open('/opt/ml/data/wikipedia_documents.json', "r", encoding="utf-8") as f:
    wiki = json.load(f)

# Deduplicate passage texts while keeping first-seen order
# (a set would yield a different order on every run).
corpus = list(dict.fromkeys(doc["text"] for doc in wiki.values()))

In [None]:
# import pickle
# with open('/opt/ml/custom/passage_embedding.bin', 'rb') as file :
#     p_embs = pickle.load(file)
# p_embs = p_embs

In [12]:
import torch

# Unpickles the whole ColBERT module (not just a state_dict), so the ColBERT class
# must be defined in this session. Without a GPU, remap tensors to CPU so a
# GPU-saved checkpoint still loads; with a GPU, keep the saved device placement.
# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True, which refuses
# pickled modules; pass weights_only=False there (only for this trusted file).
colbert_encoder = torch.load(
    '/opt/ml/models/colbert_encoder.pt',
    map_location=None if torch.cuda.is_available() else 'cpu',
)


In [None]:
# Pre-compute ColBERT embeddings for every passage in the corpus.
p_embs = get_p_embs(corpus=corpus, colbert_encoder=colbert_encoder)

In [None]:
def get_relavant_doc(queries, colbert_encoder, k=1, passage_embs=None, doc_bsize=256):
    """Rank all passages for each query with ColBERT late-interaction scoring.

    Args:
        queries: sequence of query strings.
        colbert_encoder: ColBERT model providing ``query`` and ``score``.
        k: number of top passages kept per query.
        passage_embs: passage embeddings from ``get_p_embs``; defaults to the
            module-level ``p_embs``.
        doc_bsize: passages scored per chunk; bounds peak device memory.

    Returns:
        ``(scores, indices)``: two lists holding, per query, a 1-D CPU tensor of the
        top ``min(k, n_passages)`` scores and passage indices, in descending score order.
    """
    if passage_embs is None:
        passage_embs = p_embs

    device = next(colbert_encoder.parameters()).device

    with torch.no_grad():
        colbert_encoder.eval()
        # tensorize returns (input_ids, attention_mask); unpack rather than passing the tuple.
        q_ids, q_mask = q_tokenizer.tensorize(list(queries))
        q_emb = colbert_encoder.query(q_ids.to(device), attention_mask=q_mask.to(device))

        n_queries, n_docs = q_emb.size(0), passage_embs.size(0)
        scores = torch.empty(n_queries, n_docs, device=device)

        for start in range(0, n_docs, doc_bsize):
            d_chunk = passage_embs[start:start + doc_bsize].to(device=device, dtype=q_emb.dtype)
            n = d_chunk.size(0)
            for qi in range(n_queries):
                # score() pairs Q[i] with D[i]; broadcast one query over the whole chunk
                # to get that query's score against every passage in it.
                q = q_emb[qi].unsqueeze(0).expand(n, -1, -1)
                scores[qi, start:start + n] = colbert_encoder.score(q, d_chunk)

    top = torch.topk(scores.cpu(), k=min(k, n_docs), dim=1)
    return list(top.values), list(top.indices)

In [None]:
# Evaluate on the validation split of the training data (swap in test_dataset for inference).
dataset = load_from_disk('/opt/ml/data/train_dataset')
# dataset = load_from_disk('/opt/ml/data/test_dataset')

doc_scores, doc_indices = get_relavant_doc(
    dataset['validation']['question'],
    colbert_encoder,
    k=20,
)


In [None]:
import pandas as pd

total = []
# tqdm is imported as a module, so the progress bar is tqdm.tqdm.
for idx, example in enumerate(tqdm.tqdm(dataset['validation'], desc="Dense retrieval: ")):
    # Plain ints instead of a tensor keep the DataFrame column simple to inspect and save.
    pids = doc_indices[idx].tolist()
    tmp = {
        # The query and its id.
        "question": example["question"],
        "id": example["id"],
        # Ids of the retrieved passages and their texts joined into one context.
        "context_id": pids,
        "context": " ".join(corpus[pid] for pid in pids),
    }
    if "context" in example.keys() and "answers" in example.keys():
        # Validation examples also carry the ground-truth context and answers.
        tmp["original_context"] = example["context"]
        tmp["answers"] = example["answers"]
    total.append(tmp)

cqas = pd.DataFrame(total)

In [None]:
# Row indices whose ground-truth passage occurs verbatim in the retrieved context.
correct_length = [
    i
    for i, (gold, retrieved) in enumerate(zip(cqas['original_context'], cqas['context']))
    if gold in retrieved
]

In [None]:
# Retrieval accuracy: fraction of validation questions whose ground-truth passage
# appears (as a substring) in the concatenated retrieved context.
print(len(correct_length) / len(dataset['validation']))
