# colbert
stanford-futuredata

In [2]:
import json
import random
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from tqdm import trange
from tqdm import tqdm

from pprint import pprint

from sklearn.feature_extraction.text import TfidfVectorizer

import torch
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    BertModel, RobertaModel,
    BertPreTrainedModel,
    AdamW, get_linear_schedule_with_warmup,
    TrainingArguments,
)
from datasets import (
    Dataset,
    load_from_disk,
    concatenate_datasets,
)

from typing import List
from torch.utils.data import Sampler

In [3]:
# 난수 고정
def set_seed(random_seed):
    torch.manual_seed(random_seed)
    torch.cuda.manual_seed(random_seed)
    torch.cuda.manual_seed_all(random_seed)  # if use multi-GPU
    random.seed(random_seed)
    np.random.seed(random_seed)
    
set_seed(42) # magic number :)

In [4]:
print ("PyTorch version:[%s]."%(torch.__version__))
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print ("device:[%s]."%(device))

PyTorch version:[1.7.1].
device:[cuda:0].


In [5]:
dataset = load_from_disk('/opt/ml/data/train_dataset')
train_dataset = dataset['train']
model_checkpoint = 'klue/bert-base'

## colbert , tokenizer

In [6]:
import time
import tqdm
import string
import pickle
import os.path as p

import torch
import numpy as np
import torch.nn as nn
from datasets import load_from_disk
from transformers import AdamW, TrainingArguments
from transformers import BertPreTrainedModel, BertModel, BertTokenizerFast, BertConfig


class QueryTokenizer:
    def __init__(self):
        self.tok = BertTokenizerFast.from_pretrained(model_checkpoint)

        self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.tok.convert_tokens_to_ids("[unused0]")
        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id
        self.mask_token, self.mask_token_id = self.tok.mask_token, self.tok.mask_token_id
        self.query_maxlen = self.tok.model_max_length


    def tokenize(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], type(batch_text)

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token, self.Q_marker_token], [self.sep_token]
        tokens = [prefix + lst + suffix + [self.mask_token] * (self.query_maxlen - (len(lst) + 3)) for lst in tokens]

        return tokens

    def encode(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], type(batch_text)

        ids = self.tok(batch_text, add_special_tokens=False)["input_ids"]

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id, self.Q_marker_token_id], [self.sep_token_id]
        ids = [prefix + lst + suffix + [self.mask_token_id] * (self.query_maxlen - (len(lst) + 3)) for lst in ids]

        return ids

    def tensorize(self, batch_text, bsize=None):
        assert type(batch_text) in [list, tuple], type(batch_text)

        # add placehold for the [Q] marker
        batch_text = [". " + x for x in batch_text]

        obj = self.tok(
            batch_text, padding="longest", truncation=True, return_tensors="pt", max_length=self.tok.model_max_length
        )

        ids, mask = obj["input_ids"], obj["attention_mask"]

        # postprocess for the [Q] marker and the [MASK] augmentation
        ids[:, 1] = self.Q_marker_token_id
        ids[ids == 0] = self.mask_token_id

        if bsize:
            batches = _split_into_batches(ids, mask, bsize)
            return batches

        return ids, mask


class DocTokenizer:
    def __init__(self):
        self.tok = BertTokenizerFast.from_pretrained(model_checkpoint)

        self.D_marker_token, self.D_marker_token_id = "[D]", self.tok.convert_tokens_to_ids("[unused1]")
        self.cls_token, self.cls_token_id = self.tok.cls_token, self.tok.cls_token_id
        self.sep_token, self.sep_token_id = self.tok.sep_token, self.tok.sep_token_id

        # assert self.D_marker_token_id == 1

    def tokenize(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], type(batch_text)

        tokens = [self.tok.tokenize(x, add_special_tokens=False) for x in batch_text]

        if not add_special_tokens:
            return tokens

        prefix, suffix = [self.cls_token, self.D_marker_token], [self.sep_token]
        tokens = [prefix + lst + suffix for lst in tokens]

        return tokens

    def encode(self, batch_text, add_special_tokens=False):
        assert type(batch_text) in [list, tuple], type(batch_text)

        ids = self.tok(batch_text, add_special_tokens=False)["input_ids"]

        if not add_special_tokens:
            return ids

        prefix, suffix = [self.cls_token_id, self.D_marker_token_id], [self.sep_token_id]
        ids = [prefix + lst + suffix for lst in ids]

        return ids

    def tensorize(self, batch_text, bsize=None):
        assert type(batch_text) in [list, tuple], type(batch_text)

        # add placehold for the [D] marker
        batch_text = [". " + x for x in batch_text]

        obj = self.tok(
            batch_text, padding="max_length", truncation=True, return_tensors="pt", max_length=self.tok.model_max_length
        )

        ids, mask = obj["input_ids"], obj["attention_mask"]

        # postprocess for the [D] marker
        ids[:, 1] = self.D_marker_token_id

        if bsize:
            ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
            batches = _split_into_batches(ids, mask, bsize)
            return batches, reverse_indices

        return ids, mask


def tensorize_triples(query_tokenizer, doc_tokenizer, queries, positives, negatives, bsize):
    # assert len(queries) == len(positives) == len(negatives)
    # assert bsize is None or len(queries) % bsize == 0

    N = len(queries)
    queries = queries.to_list()
    # print(f'queries-----------------------------------------------------')
    # print(queries)
    Q_ids, Q_mask = query_tokenizer.tensorize(queries)

    positives = positives.to_list()
    negatives = negatives.to_list()

    D_ids, D_mask = doc_tokenizer.tensorize(positives + negatives)
    D_ids, D_mask = D_ids.view(2, N, -1), D_mask.view(2, N, -1)

    # Compute max among {length of i^th positive, length of i^th negative} for i \in N
    maxlens = D_mask.sum(-1).max(0).values

    # Sort by maxlens
    indices = maxlens.sort().indices
    Q_ids, Q_mask = Q_ids[indices], Q_mask[indices]
    D_ids, D_mask = D_ids[:, indices], D_mask[:, indices]

    (positive_ids, negative_ids), (positive_mask, negative_mask) = D_ids, D_mask

    query_batches = _split_into_batches(Q_ids, Q_mask, bsize)
    positive_batches = _split_into_batches(positive_ids, positive_mask, bsize)
    negative_batches = _split_into_batches(negative_ids, negative_mask, bsize)

    batches = []
    for (q_ids, q_mask), (p_ids, p_mask), (n_ids, n_mask) in zip(query_batches, positive_batches, negative_batches):
        Q = (torch.cat((q_ids, q_ids)), torch.cat((q_mask, q_mask)))
        D = (torch.cat((p_ids, n_ids)), torch.cat((p_mask, n_mask)))
        batches.append((Q, D))

    return batches


def _sort_by_length(ids, mask, bsize):
    if ids.size(0) <= bsize:
        return ids, mask, torch.arange(ids.size(0))

    indices = mask.sum(-1).sort().indices
    reverse_indices = indices.sort().indices

    return ids[indices], mask[indices], reverse_indices


def _split_into_batches(ids, mask, bsize):
    batches = []
    for offset in range(0, ids.size(0), bsize):
        batches.append((ids[offset : offset + bsize], mask[offset : offset + bsize]))

    return batches


import os
import tqdm
import torch
import datetime
import itertools

from multiprocessing import Pool
from collections import OrderedDict, defaultdict


def print_message(*s, condition=True):
    s = ' '.join([str(x) for x in s])
    msg = "[{}] {}".format(datetime.datetime.now().strftime("%b %d, %H:%M:%S"), s)

    if condition:
        print(msg, flush=True)

    return msg


def timestamp():
    format_str = "%Y-%m-%d_%H.%M.%S"
    result = datetime.datetime.now().strftime(format_str)
    return result


def file_tqdm(file):
    print(f"#> Reading {file.name}")

    with tqdm.tqdm(total=os.path.getsize(file.name) / 1024.0 / 1024.0, unit="MiB") as pbar:
        for line in file:
            yield line
            pbar.update(len(line) / 1024.0 / 1024.0)

        pbar.close()


def save_checkpoint(path, epoch_idx, mb_idx, model, optimizer, arguments=None):
    print(f"#> Saving a checkpoint to {path} ..")

    if hasattr(model, 'module'):
        model = model.module  # extract model from a distributed/data-parallel wrapper

    checkpoint = {}
    checkpoint['epoch'] = epoch_idx
    checkpoint['batch'] = mb_idx
    checkpoint['model_state_dict'] = model.state_dict()
    checkpoint['optimizer_state_dict'] = optimizer.state_dict()
    checkpoint['arguments'] = arguments

    torch.save(checkpoint, path)


def load_checkpoint(path, model, optimizer=None, do_print=True):
    if do_print:
        print_message("#> Loading checkpoint", path, "..")

    if path.startswith("http:") or path.startswith("https:"):
        checkpoint = torch.hub.load_state_dict_from_url(path, map_location='cpu')
    else:
        checkpoint = torch.load(path, map_location='cpu')

    state_dict = checkpoint['model_state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k
        if k[:7] == 'module.':
            name = k[7:]
        new_state_dict[name] = v

    checkpoint['model_state_dict'] = new_state_dict

    try:
        model.load_state_dict(checkpoint['model_state_dict'])
    except:
        print_message("[WARNING] Loading checkpoint with strict=False")
        model.load_state_dict(checkpoint['model_state_dict'], strict=False)

    if optimizer:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if do_print:
        print_message("#> checkpoint['epoch'] =", checkpoint['epoch'])
        print_message("#> checkpoint['batch'] =", checkpoint['batch'])

    return checkpoint


def create_directory(path):
    if os.path.exists(path):
        print('\n')
        print_message("#> Note: Output directory", path, 'already exists\n\n')
    else:
        print('\n')
        print_message("#> Creating directory", path, '\n\n')
        os.makedirs(path)

# def batch(file, bsize):
#     while True:
#         L = [ujson.loads(file.readline()) for _ in range(bsize)]
#         yield L
#     return


def f7(seq):
    """
    Source: https://stackoverflow.com/a/480227/1493011
    """

    seen = set()
    return [x for x in seq if not (x in seen or seen.add(x))]


def batch(group, bsize, provide_offset=False):
    offset = 0
    while offset < len(group):
        L = group[offset: offset + bsize]
        yield ((offset, L) if provide_offset else L)
        offset += len(L)
    return


class dotdict(dict):
    """
    dot.notation access to dictionary attributes
    Credit: derek73 @ https://stackoverflow.com/questions/2352181
    """
    __getattr__ = dict.__getitem__
    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__


def flatten(L):
    return [x for y in L for x in y]


def zipstar(L, lazy=False):
    """
    A much faster A, B, C = zip(*[(a, b, c), (a, b, c), ...])
    May return lists or tuples.
    """

    if len(L) == 0:
        return L

    width = len(L[0])

    if width < 100:
        return [[elem[idx] for elem in L] for idx in range(width)]

    L = zip(*L)

    return L if lazy else list(L)


def zip_first(L1, L2):
    length = len(L1) if type(L1) in [tuple, list] else None

    L3 = list(zip(L1, L2))

    assert length in [None, len(L3)], "zip_first() failure: length differs!"

    return L3


def int_or_float(val):
    if '.' in val:
        return float(val)
        
    return int(val)

def load_ranking(path, types=None, lazy=False):
    print_message(f"#> Loading the ranked lists from {path} ..")

    try:
        lists = torch.load(path)
        lists = zipstar([l.tolist() for l in tqdm.tqdm(lists)], lazy=lazy)
    except:
        if types is None:
            types = itertools.cycle([int_or_float])

        with open(path) as f:
            lists = [[typ(x) for typ, x in zip_first(types, line.strip().split('\t'))]
                     for line in file_tqdm(f)]

    return lists


def save_ranking(ranking, path):
    lists = zipstar(ranking)
    lists = [torch.tensor(l) for l in lists]

    torch.save(lists, path)

    return lists


def groupby_first_item(lst):
    groups = defaultdict(list)

    for first, *rest in lst:
        rest = rest[0] if len(rest) == 1 else rest
        groups[first].append(rest)

    return groups


def process_grouped_by_first_item(lst):
    """
        Requires items in list to already be grouped by first item.
    """

    groups = defaultdict(list)

    started = False
    last_group = None

    for first, *rest in lst:
        rest = rest[0] if len(rest) == 1 else rest

        if started and first != last_group:
            yield (last_group, groups[last_group])
            assert first not in groups, f"{first} seen earlier --- violates precondition."

        groups[first].append(rest)

        last_group = first
        started = True

    return groups


def grouper(iterable, n, fillvalue=None):
    """
    Collect data into fixed-length chunks or blocks
        Example: grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx"
        Source: https://docs.python.org/3/library/itertools.html#itertools-recipes
    """

    args = [iter(iterable)] * n
    return itertools.zip_longest(*args, fillvalue=fillvalue)


# see https://stackoverflow.com/a/45187287
class NullContextManager(object):
    def __init__(self, dummy_resource=None):
        self.dummy_resource = dummy_resource
    def __enter__(self):
        return self.dummy_resource
    def __exit__(self, *args):
        pass


def load_batch_backgrounds(args, qids):
    if args.qid2backgrounds is None:
        return None

    qbackgrounds = []

    for qid in qids:
        back = args.qid2backgrounds[qid]

        if len(back) and type(back[0]) == int:
            x = [args.collection[pid] for pid in back]
        else:
            x = [args.collectionX.get(pid, '') for pid in back]

        x = ' [SEP] '.join(x)
        qbackgrounds.append(x)
    
    return qbackgrounds


class ColBERT(BertPreTrainedModel):
    def __init__(self, config, mask_punctuation=string.punctuation, dim=128, similarity_metric="cosine"):
        super(ColBERT, self).__init__(config)

        self.similarity_metric = similarity_metric
        self.dim = dim

        self.mask_punctuation = mask_punctuation
        self.skiplist = {}

        if self.mask_punctuation:
            self.tokenizer = BertTokenizerFast.from_pretrained('klue/bert-base')
            self.skiplist = {
                w: True
                for symbol in string.punctuation
                for w in [symbol, self.tokenizer.encode(symbol, add_special_tokens=False)[0]]
            }

        self.bert = BertModel(config)
        self.linear = nn.Linear(config.hidden_size, dim, bias=False)

        self.init_weights()

    def forward(self, Q=None, D=None):
        # return self.query(**Q), self.doc(**D)
        return self.score(self.query(**Q), self.doc(**D))

    def query(self, input_ids, attention_mask=None, token_type_ids=None):
        # input_ids, attention_mask = input_ids.to("cuda"), attention_mask.to("cuda")
        Q = self.bert(input_ids, attention_mask=attention_mask)[0]
        # Q_pooled_outputs = Q_outputs[1]
        Q = self.linear(Q)
        # return Q
        return torch.nn.functional.normalize(Q, p=2, dim=2)

    def doc(self, input_ids, attention_mask=None, token_type_ids=None):
        # input_ids, attention_mask = input_ids.to("cuda"), attention_mask.to("cuda")
        D = self.bert(input_ids, attention_mask=attention_mask)[0]
        D = self.linear(D)

        mask = torch.tensor(self.mask(input_ids), device="cuda").unsqueeze(2).float()
        D = D * mask

        D = torch.nn.functional.normalize(D, p=2, dim=2)

        # if not keep_dims:
        #     D, mask = D.cpu().to(dtype=torch.float16), mask.cpu().bool().squeeze(-1)
        #     D = [d[mask[idx]] for idx, d in enumerate(D)]

        return D

    def score(self, Q, D):
        if self.similarity_metric == "cosine":
            return (Q @ D.permute(0, 2, 1)).max(2).values.sum(1)

        assert self.similarity_metric == "l2"
        return (-1.0 * ((Q.unsqueeze(2) - D.unsqueeze(1)) ** 2).sum(-1)).max(-1).values.sum(-1)

    def mask(self, input_ids):
        mask = [[(x not in self.skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
        return mask

In [7]:
class CustomSampler(Sampler) :
    def __init__(self, data_source, batch_size) :
        self.data_source = data_source
        self.batch_size = batch_size

    def __iter__(self) :
        n = len(self.data_source)
        index_list = []
        while True :
            out = True
            for i in range(self.batch_size) :
                tmp_data = random.randint(0, n-1)
                index_list.append(tmp_data)
            for f, s in zip(index_list, index_list[1:]) :
                if abs(s-f) <= 2 :
                    out = False
            if out == True :
                break

        while True : # 추가 삽입
            tmp_data = random.randint(0, n-1)
            if (tmp_data not in index_list) and \
                (abs(tmp_data-index_list[-i]) > 2 for i in range(1,self.batch_size+1)) \
            : 
                index_list.append(tmp_data)
            if len(index_list) == n :
                break
        return iter(index_list)

    def __len__(self) :
        return len(self.data_source)

In [8]:
class DenseRetrieval:
    def __init__(self,
        args,
        dataset,
        # tokenizer,
        q_tokenizer,
        d_tokenizer,
        colbert_encoder,
        df

    ):
        """
        학습과 추론에 사용될 여러 셋업을 마쳐봅시다.
        """

        self.args = args
        self.dataset = dataset

        self.tokenizer = tokenizer
        self.q_tokenizer = q_tokenizer
        self.d_tokenizer = d_tokenizer
        self.colbert_encoder = colbert_encoder

        self.df = df
        # self.p_encoder=None
        # self.q_encoder=None

    def train(self, args=None, tokenizer = None, df=None):
        if args is None:
            args = self.args
        if tokenizer is None :
            tokenizer = self.tokenizer

        # q_seqs = tokenizer(self.dataset['question'], padding="max_length", truncation=True, return_tensors='pt')
        # p_seqs = tokenizer(self.dataset['context'], padding="max_length", truncation=True, return_tensors='pt') 

        # train_dataset = TensorDataset(p_seqs['input_ids'], p_seqs['attention_mask'], p_seqs['token_type_ids'], 
        #                 q_seqs['input_ids'], q_seqs['attention_mask'], q_seqs['token_type_ids'])
        # train_dataloader = DataLoader(train_dataset, batch_size=args.per_device_train_batch_size)

        # tensorize_triples(query_tokenizer, doc_tokenizer, queries, positives, negatives, bsize):
        train_dataloader = tensorize_triples(
            self.q_tokenizer
            , self.d_tokenizer
            , self.df["question"]
            , self.df["original_context"]
            , self.df["context"]
            , self.args.per_device_train_batch_size, )

        no_decay = ["bias" ,"LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {"params": [p for n, p in self.colbert_encoder.named_parameters() if not any(nd in n for nd in no_decay)], "weight_decay": args.weight_decay},
            {"params": [p for n, p in self.colbert_encoder.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0}
        ]
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            # eps=args.adam_epsilon
        )

        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
        
        global_step = 0

        # self.p_encoder.zero_grad()
        # self.q_encoder.zero_grad()
        self.colbert_encoder.zero_grad()
        # torch.cuda.empty_cache()

        train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
        # self.q_encoder.train()
        # self.p_encoder.train()
        self.colbert_encoder.train()
        
        criterion = nn.CrossEntropyLoss()

        for epoch, _ in enumerate(train_iterator):
            # epoch_iterator = tqdm(train_dataloader, desc="Iteration")
            # loss_value=0 # Accumulation할 때 진행
            losses = 0
            for step, batch in enumerate(train_dataloader):
                # if torch.cuda.is_available():
                #     batch = tuple(t.to_device('cuda') for t in batch)

                # D
                p_inputs = {'input_ids': batch[0][0].cuda(),
                            'attention_mask': batch[0][1].cuda()
                            }

                # Q
                q_inputs = {'input_ids': batch[1][0].cuda(),
                            'attention_mask': batch[1][1].cuda()
                            }

                sim_scores = self.colbert_encoder(Q=q_inputs, D=p_inputs)

                # Calculate similarity score & loss
                # sim_scores = self.colbert_encoder.score(q_outputs, p_outputs)

                # target: position of positive samples = diagonal element 
                targets = torch.zeros(args.per_device_train_batch_size).long()
                
                if torch.cuda.is_available():
                    targets = targets.to('cuda')

                # sim_scores = F.log_softmax(sim_scores, dim=1)
                # loss = -F.log_softmax(sim_scores)[:,0].mean()
                # print(f'sim_scores {sim_scores} | targets {targets}')
                # print(f'sim_scores.shape, targets.shape {sim_scores.shape} | {targets.shape}')
                
                # sim_scores shaped as batch_size x 2
                sim_scores = sim_scores.view(-1, 2)
                
                # get mean of the loss
                loss = criterion(sim_scores, targets)
                losses += loss.item()
                if step % 100 == 0 :
                    print(f'{epoch}epoch loss: {losses/(step+1)}') # Accumulation할 경우 주석처리

                
                #################ACCUMULATION###############################
                # loss_value += loss
                # if (step+1) % args.gradient_accumulation_steps == 0 :
                #     optimizer.step()
                #     scheduler.step()
                #     self.q_encoder.zero_grad()
                #     self.p_encoder.zero_grad()
                #     global_step += 1
                #     print(loss_value/args.gradient_accumulation_steps)
                #     loss_value = 0
                ############################################################
                self.colbert_encoder.zero_grad()
                # self.p_encoder.zero_grad()
                loss.backward()
                optimizer.step()
                scheduler.step()

                global_step += 1
                
                #torch.cuda.empty_cache()
                del p_inputs, q_inputs

        return self.colbert_encoder

In [9]:
args = TrainingArguments(
    output_dir="dense_retireval",
    evaluation_strategy="epoch",
    learning_rate=3e-6,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=1,
    num_train_epochs=5,
    weight_decay=0.01
)

# 혹시 위에서 사용한 encoder가 있다면 주석처리 후 진행해주세요 (CUDA ...)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
# q_encoder = BertEncoder.from_pretrained(model_checkpoint).to(args.device)

In [11]:
colbert_encoder = ColBERT.from_pretrained(model_checkpoint).to(args.device)

Some weights of the model checkpoint at klue/bert-base were not used when initializing ColBERT: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing ColBERT from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing ColBERT from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of ColBERT were not initialized from the model checkpoint at klue/bert-base and are newly initialized: ['linear.weight'

In [12]:
df = pd.read_csv('/opt/ml/data/colbertdata_join_top10_wikipedia.csv')

In [13]:
!nvidia-smi

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Thu Nov  4 05:00:50 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.80.02    Driver Version: 450.80.02    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-PCIE...  Off  | 00000000:00:05.0 Off |                  Off |
| N/A   34C    P0    36W / 250W |   1703MiB / 32510MiB |      0%      Default |
|                               |            

In [14]:
# Retriever는 아래와 같이 사용할 수 있도록 코드를 짜봅시다.
q_tokenizer = QueryTokenizer()
d_tokenizer = DocTokenizer()

retriever = DenseRetrieval(
    args=args,
    dataset=train_dataset,
    q_tokenizer = q_tokenizer,
    d_tokenizer = d_tokenizer,
    colbert_encoder=colbert_encoder,
    df=df
)
colbert_encoder = retriever.train()

Epoch:   0%|          | 0/5 [00:00<?, ?it/s]

0epoch loss: 29.455158233642578
0epoch loss: 10.449555490866747
0epoch loss: 7.939226127679075


Epoch:  20%|██        | 1/5 [04:25<17:41, 265.27s/it]

1epoch loss: 4.91359806060791
1epoch loss: 3.2785977432043247
1epoch loss: 3.0742776646691174


Epoch:  40%|████      | 2/5 [08:51<13:16, 265.53s/it]

2epoch loss: 3.0598464012145996
2epoch loss: 2.2762490493236203
2epoch loss: 2.123538496630702


Epoch:  60%|██████    | 3/5 [13:18<08:51, 265.85s/it]

3epoch loss: 2.677745819091797
3epoch loss: 1.830045880362539
3epoch loss: 1.7334236750258736


Epoch:  80%|████████  | 4/5 [17:44<04:25, 265.99s/it]

4epoch loss: 1.7828195095062256
4epoch loss: 1.5680520944666154
4epoch loss: 1.5433576952906984


Epoch: 100%|██████████| 5/5 [22:10<00:00, 266.18s/it]


In [16]:
torch.save(colbert_encoder, '/opt/ml/code/models/colbert_encoder.pt')

In [12]:
import torch

class ModelInference():
    # https://github.dev/stanford-futuredata/ColBERT/tree/master/colbert/evaluation
    def __init__(self, colbert: ColBERT, amp=False):
        assert colbert.training is False

        self.colbert = colbert
        self.query_tokenizer = QueryTokenizer(colbert.query_maxlen)
        self.doc_tokenizer = DocTokenizer(colbert.doc_maxlen)

        self.amp_manager = MixedPrecisionManager(amp)

    def query(self, *args, to_cpu=False, **kw_args):
        with torch.no_grad():
            with self.amp_manager.context():
                Q = self.colbert.query(*args, **kw_args)
                return Q.cpu() if to_cpu else Q

    def doc(self, *args, to_cpu=False, **kw_args):
        with torch.no_grad():
            with self.amp_manager.context():
                D = self.colbert.doc(*args, **kw_args)
                return D.cpu() if to_cpu else D

    def queryFromText(self, queries, bsize=None, to_cpu=False):
        if bsize:
            batches = self.query_tokenizer.tensorize(queries, bsize=bsize)
            batches = [self.query(input_ids, attention_mask, to_cpu=to_cpu) for input_ids, attention_mask in batches]
            return torch.cat(batches)

        input_ids, attention_mask = self.query_tokenizer.tensorize(queries)
        return self.query(input_ids, attention_mask)

    def docFromText(self, docs, bsize=None, keep_dims=True, to_cpu=False):
        if bsize:
            batches, reverse_indices = self.doc_tokenizer.tensorize(docs, bsize=bsize)

            batches = [self.doc(input_ids, attention_mask, keep_dims=keep_dims, to_cpu=to_cpu)
                       for input_ids, attention_mask in batches]

            if keep_dims:
                D = _stack_3D_tensors(batches)
                return D[reverse_indices]

            D = [d for batch in batches for d in batch]
            return [D[idx] for idx in reverse_indices.tolist()]

        input_ids, attention_mask = self.doc_tokenizer.tensorize(docs)
        return self.doc(input_ids, attention_mask, keep_dims=keep_dims)

    def score(self, Q, D, mask=None, lengths=None, explain=False):
        if lengths is not None:
            assert mask is None, "don't supply both mask and lengths"

            mask = torch.arange(D.size(1), device=DEVICE) + 1
            mask = mask.unsqueeze(0) <= lengths.to(DEVICE).unsqueeze(-1)

        scores = (D @ Q)
        scores = scores if mask is None else scores * mask.unsqueeze(-1)
        scores = scores.max(1)

        if explain:
            assert False, "TODO"

        return scores.values.sum(-1).cpu()


def _stack_3D_tensors(groups):
    bsize = sum([x.size(0) for x in groups])
    maxlen = max([x.size(1) for x in groups])
    hdim = groups[0].size(2)

    output = torch.zeros(bsize, maxlen, hdim, device=groups[0].device, dtype=groups[0].dtype)

    offset = 0
    for x in groups:
        endpos = offset + x.size(0)
        output[offset:endpos, :x.size(1)] = x
        offset = endpos

    return output


In [13]:
def retrieve(args):
    inference = ModelInference(args.colbert, amp=args.amp)
    ranker = Ranker(args, inference, faiss_depth=args.faiss_depth)

    ranking_logger = RankingLogger(Run.path, qrels=None)
    milliseconds = 0

    with ranking_logger.context('ranking.tsv', also_save_annotations=False) as rlogger:
        queries = args.queries
        qids_in_order = list(queries.keys())

        for qoffset, qbatch in batch(qids_in_order, 100, provide_offset=True):
            qbatch_text = [queries[qid] for qid in qbatch]

            rankings = []

            for query_idx, q in enumerate(qbatch_text):
                torch.cuda.synchronize('cuda:0')
                s = time.time()

                Q = ranker.encode([q])
                pids, scores = ranker.rank(Q)

                torch.cuda.synchronize()
                milliseconds += (time.time() - s) * 1000.0

                if len(pids):
                    print(qoffset+query_idx, q, len(scores), len(pids), scores[0], pids[0],
                          milliseconds / (qoffset+query_idx+1), 'ms')

                rankings.append(zip(pids, scores))

            for query_idx, (qid, ranking) in enumerate(zip(qbatch, rankings)):
                query_idx = qoffset + query_idx

                if query_idx % 100 == 0:
                    print_message(f"#> Logging query #{query_idx} (qid {qid}) now...")

                ranking = [(score, pid, None) for pid, score in itertools.islice(ranking, args.depth)]
                rlogger.log(qid, ranking, is_ranked=True)

    print('\n\n')
    print(ranking_logger.filename)
    print("#> Done.")
    print('\n\n')


## Torch practice

In [14]:
# https://stackoverflow.com/questions/48377214/runtimeerror-dimension-out-of-range-expected-to-be-in-range-of-1-0-but-go/48389451
# https://stackoverflow.com/questions/61501417/input-dimension-for-crossentropy-loss-in-pytorch
b_logits = torch.tensor([0.1198, 0.1911], requires_grad=True)
b_labels = torch.tensor([1])
loss_criterion = nn.CrossEntropyLoss()

loss = loss_criterion( b_logits.view(1,-1), b_labels )
loss

tensor(0.6581, grad_fn=<NllLossBackward>)

In [15]:
import torch
import torch.nn as nn
# import nll loss

#                       [127.0519, 152.8302, | 142.6836, 144.4586, | 142.2092, 145.1991]
b_logits = torch.tensor([127.0519, 152.8302, 142.6836, 144.4586, 142.2092, 145.1991], requires_grad=True)
b_labels = torch.tensor([0,0,0])
loss_criterion = nn.CrossEntropyLoss()
# loss_criterion = nn.NLLLoss()

# (A, B) (A>B)

# split b_logits by arg.per_device_train_batch_size -> 6 divided by model.args.per_device_train_batch_size
# b_logits_batch_1 = b_logits[:2]
# b_logits_batch_2 = b_logits[2:4]
# b_logits_batch_3 = b_logits[4:]

b_logits = b_logits.view(-1, 2)

# make 3 x 2 tensor using torch.stack
# b_logits = torch.stack([b_logits_batch_1, b_logits_batch_2, b_logits_batch_3])

b_logits

tensor([[127.0519, 152.8302],
        [142.6836, 144.4586],
        [142.2092, 145.1991]], grad_fn=<ViewBackward>)

In [16]:
b_labels

tensor([0, 0, 0])

In [17]:
loss = loss_criterion(b_logits, b_labels)

loss

tensor(10.2496, grad_fn=<NllLossBackward>)