In [5]:
#!conda install -y -c conda-forge elasticsearch=7.17.0

In [1]:
#!conda install -y -c conda-forge accelerate sacremoses transformers

In [1]:
import os

# Root of the shared data folder; "~" is expanded to the current user's home directory.
data_dir = os.path.expanduser("~/Google Drive/Shared drives/Data")
# Name of the BEIR dataset used throughout the notebook.
dataset = "bioasq"
# Folder that holds this dataset's files under the BEIR directory.
data_path = "/".join((data_dir, "BEIR", dataset))

In [2]:
#from vectorspace.utils import json_load
#from beir.datasets.data_loader import GenericDataLoader
#_, queries, qrels = GenericDataLoader(data_path).load(split="test")
#bm25_results = json_load(f"{data_dir}/BEIR/results_{dataset}_bm25.json")
#from vectorspace.store import Store as ESDict
#from elasticsearch import Elasticsearch
#corpus = ESDict(Elasticsearch(['http://localhost:9200']), dataset)

In [3]:
# Scoring template: "{}" is filled with the document text, and the string deliberately
# ends on an opening quote so that the query can be scored as its continuation.
prompt = (
    'Documents are searched to find matches with the same content.\n'
    'The document "{}" is a good search result for "'
)


In [4]:
from vectorspace.utils import jsonl_load

# One [query text, document text] pair per record of the sentence-pairs file.
sentence_pairs = [
    [record[field] for field in ('query', 'text')]
    for record in jsonl_load(f'{data_path}/sentence_pairs_100.jsonl')
]
# One [query id, document id] pair per record of the ids file
# (presumably row-aligned with sentence_pairs -- TODO confirm).
pair_ids = [
    [record[field] for field in ('query_id', 'doc_id')]
    for record in jsonl_load(f'{data_path}/pair_ids_100.jsonl')
]

## GPT

In [5]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from scipy.spatial.distance import cosine

# Get models - The package will take care of downloading the models automatically
# For best performance: EleutherAI/gpt-j-6B
model_name = "EleutherAI/gpt-neo-125M"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Prefer Apple MPS (the original target), then CUDA, then CPU, so the cell also
# runs on machines without an MPS device instead of failing in `.to(device)`.
_mps_backend = getattr(torch.backends, "mps", None)  # absent on older torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Inference only: eval() disables dropout and other train-time behaviour.
model = AutoModelForCausalLM.from_pretrained(model_name).to(device).eval()

In [7]:
import transformers
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np

def encode(requests, tokenizer):
    """Tokenize (continuation, context) string pairs.

    requests: iterable of (continuation, context) pairs of strings.
    tokenizer: object exposing `encode(text, add_special_tokens=False)` and
        `eos_token_id`.

    Returns a list with one entry per request, in input order:
    ((context, continuation), context_token_ids, continuation_token_ids).
    Note the key tuple is (context, continuation), i.e. swapped relative to
    the input pair. An empty context is represented by the single
    end-of-text token id.
    """
    def _to_ids(text):
        return tokenizer.encode(text, add_special_tokens=False)

    encoded_requests = []
    for continuation, context in requests:
        # An empty context still needs one token to condition on: use end-of-text.
        context_ids = _to_ids(context) if context != "" else [tokenizer.eos_token_id]
        continuation_ids = _to_ids(continuation)
        encoded_requests.append(((context, continuation), context_ids, continuation_ids))
    return encoded_requests

import collections

def group(arr, fn):
    """Partition `arr` into lists of items that share the same `fn(item)` key.

    Groups are returned in order of first appearance of their key, and items
    keep their original relative order within each group. Keys must be hashable.
    """
    buckets = {}
    for item in arr:
        buckets.setdefault(fn(item), []).append(item)
    return list(buckets.values())

class Reorderer:
    """Deduplicate and sort items by a key, remembering how to undo both.

    Items whose `fn(item)` keys are equal are collapsed into one representative
    (the first occurrence). Representatives are sorted ascending by key.
    `get_original` maps per-representative results back onto every original
    position.
    """

    def __init__(self, arr, fn):
        self.size = len(arr)
        indexed = list(enumerate(arr))
        # Bucket (position, item) pairs by the item's key.
        buckets = group(indexed, lambda pair: fn(pair[1]))
        # Each entry: (all original positions in the bucket, first item of the bucket).
        collapsed = [
            ([position for position, _ in bucket], bucket[0][1])
            for bucket in buckets
        ]
        self.arr = sorted(collapsed, key=lambda entry: fn(entry[1]))

    def get_reordered(self):
        """Return the representative items in sorted-key order."""
        return [item for _, item in self.arr]

    def get_original(self, newarr):
        """Spread `newarr` (one value per representative, in the order given by
        `get_reordered`) back over the original positions.

        Raises AssertionError if some original position received no value.
        """
        restored = [None] * self.size
        filled = [False] * self.size

        for (positions, _), value in zip(self.arr, newarr):
            for position in positions:
                restored[position] = value
                filled[position] = True

        assert all(filled)

        return restored

def chunks(iter, n):
    """Yield consecutive lists of `n` items drawn from `iter`.

    The final list may be shorter than `n`; nothing is yielded for an empty
    input.
    """
    batch = []
    for item in iter:
        batch.append(item)
        if len(batch) != n:
            continue
        yield batch
        batch = []

    # Flush whatever is left over after the last full batch.
    if len(batch) > 0:
        yield batch

def _model_call(inps, model):
    """
    inps: a torch tensor of shape [batch, sequence]
    the size of sequence may vary from call to call
    returns: a torch tensor of shape [batch, sequence, vocab] with the
    logits retuned from the model
    """
    return model(inps)[0][:, :, :50257]

def _loglikelihood_tokens(requests, model, max_length, device, disable_tqdm=False, batch_size=1, 
                          instruction_len=0, tokenizer=None):
    """Sum the log-probabilities the model assigns to each request's continuation.

    requests: sequence of (cache_key, context_enc, continuation_enc), where the
        last two are non-empty lists of token ids (the shape produced by
        `encode`). The cache key is not used here.
    model: model passed through to `_model_call`.
    max_length: maximum number of input positions fed to the model.
    device: torch device the input batch is moved to.
    disable_tqdm: if True, hide the progress bar.
    batch_size: number of (deduplicated) requests per forward pass.
    instruction_len: number of leading context tokens that are always kept
        when an over-long sequence has to be truncated; truncation then removes
        tokens right after that prefix.
    tokenizer: accepted for interface compatibility; not used.

    Returns a list of floats, one per request and in the original request
    order: the summed log-probability of the continuation tokens given the
    preceding tokens. If truncation cuts into the continuation itself, only
    the surviving (trailing) continuation tokens are scored.

    Raises AssertionError on empty context/continuation, on a continuation
    longer than `max_length`, or on an `instruction_len` outside
    [0, max_length].
    """
    # With instruction_len > max_length the window below would be empty or
    # negative and the slicing would silently select the wrong tokens.
    assert 0 <= instruction_len <= max_length
    # How many trailing tokens of (rest of context + continuation) are kept
    # behind the instruction prefix. The +1 accounts for the final token,
    # which is only ever a prediction target and is dropped from the input.
    window = max_length + 1 - instruction_len

    def _collate(x):
        # Sort key: longest sequences first; the token tuple makes identical
        # requests share one key so they are scored only once.
        toks = x[1] + x[2]
        return (-len(toks), tuple(toks))

    # TODO: automatic (variable) batch size detection for vectorization
    reord = Reorderer(requests, _collate)
    res = []
    with torch.no_grad():
        for chunk in chunks(tqdm(reord.get_reordered(), disable=disable_tqdm), batch_size):
            rows = []      # model input token ids, one list per request
            targets = []   # continuation token ids actually scored, per request

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= max_length
                kept = (context_enc[:instruction_len]
                        + (context_enc[instruction_len:] + continuation_enc)[-window:])
                # The last token has nothing after it to predict, so it is not fed in.
                rows.append(list(kept[:-1]))
                # Only continuation tokens that survived truncation can be
                # aligned with logits; normally this is the whole continuation.
                targets.append(list(continuation_enc[-window:]))

            # Pad to the longest row of this batch. The sort order usually puts
            # it first, but that is not guaranteed once truncation is involved.
            padding_length = max(len(row) for row in rows)
            batch = torch.tensor(
                [row + [0] * (padding_length - len(row)) for row in rows],
                dtype=torch.long).to(device)

            # [batch, seq, vocab]
            multi_logits = F.log_softmax(_model_call(batch, model), dim=-1).cpu()

            for logits, row, cont in zip(multi_logits, rows, targets):
                inplen = len(row)
                contlen = len(cont)
                # Position i predicts token i+1, so the last `contlen` unpadded
                # positions are the ones predicting the continuation tokens.
                logits = logits[inplen - contlen:inplen].unsqueeze(0)  # [1, seq, vocab]
                # cont_toks :: [1, seq]
                cont_toks = torch.tensor(cont, dtype=torch.long).unsqueeze(0)

                # cont_toks are the vocab indices that make up the perfect continuation
                # Hence we gather those vocab indices from the logits, i.e. their log-probabilities
                logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(-1)  # [1, seq]
                # Sum to get a total score of that continuation
                res.append(float(logits.sum()))

    return reord.get_original(res)

In [8]:
# Each request is (continuation, context): the query is the continuation that
# gets scored, and the context is the prompt template filled with the document.
sentences = [
    (query_text, prompt.format(doc_text))
    for query_text, doc_text in sentence_pairs
]
encoded = encode(sentences, tokenizer)


In [10]:
import os

# Switch off the tokenizers library's internal parallelism for this process.
os.environ.update({"TOKENIZERS_PARALLELISM": "false"})

In [12]:
%%time
max_len = model.config.max_position_embeddings
# Token count of the fixed prompt text that precedes the "{}" document slot;
# these leading tokens are kept when over-long inputs are truncated.
instruction_len = len(tokenizer.tokenize(prompt[:prompt.index("{")]))
# Use the same `device` the model was moved to, rather than repeating the literal.
log_probs = _loglikelihood_tokens(encoded[:200], model, max_len, device, instruction_len=instruction_len, 
                                  tokenizer=tokenizer, batch_size=32)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [01:19<00:00,  2.52it/s]


CPU times: user 2.3 s, sys: 28.4 s, total: 30.7 s
Wall time: 1min 22s
