In [1]:
import transformers, torch, faiss
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from mips import MIPS

# Report the versions of the three core libraries, one per line
# (transformers, torch, faiss — in that order).
print(transformers.__version__, torch.__version__, faiss.__version__, sep="\n")

# Load the pretrained GPT-2 tokenizer and language-model head.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Later cells tokenize with padding=True, which needs a pad token; reuse EOS for it.
tokenizer.pad_token = tokenizer.eos_token
# return_dict=True lets the outputs be indexed by name (e.g. outputs['hidden_states']);
# .eval() switches the module to inference mode and returns the same module.
model = GPT2LMHeadModel.from_pretrained('gpt2', return_dict=True).eval()


#Set datastore
# Keys (hidden states) and tokens live in two disk-backed arrays that are
# pre-allocated to MAXIMUM_SIZE rows and filled incrementally by the encoding cell.
# NOTE: despite the '.npy' extension these are raw np.memmap buffers without an
# NPY header -- reopen them with np.memmap (same dtype/shape), not np.load.
# mode='w+' creates the file, or overwrites an existing one, every time this cell runs.
STORE_FILE='keys.npy'
MAXIMUM_SIZE=10000   # capacity of both arrays, in rows
DIMENSION=768        # width of one key vector
all_keys = np.memmap(STORE_FILE, dtype=np.float32, mode='w+', shape=(MAXIMUM_SIZE, DIMENSION))
finished_keys = 0    # number of rows of all_keys written so far

TOKEN_FILE='tokens.npy'
# np.int was deprecated in NumPy 1.20 and removed in 1.24; use an explicit
# fixed-width type so the file layout is also the same on every platform.
all_tokens = np.memmap(TOKEN_FILE, dtype=np.int64, mode='w+', shape=(MAXIMUM_SIZE,))
all_lengths = []     # token count of each stored sentence, in insertion order
finished_tokens = 0  # number of entries of all_tokens written so far




I1030 15:29:58.193505 4481746368 __init__.py:43] Loading faiss.


3.4.0
1.7.0
1.6.1


In [2]:
batch = ["My dog is cute", "My idea is brilliant.", "My paper is very very good!"]
data = [batch] * 100
# encode data
# For every sentence we store its token ids in all_tokens and, for every position
# except the last, the last-layer hidden state in all_keys. The key at position t
# is later used to look up the token at position t+1.
for sentences in data:
    inputs = tokenizer(sentences,
                       padding=True,
                       return_length=True,
                       return_tensors="pt")
    # a sentence needs at least two tokens to yield one (key, next token) pair
    assert (inputs['length'] > 1).all()
    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        output_hidden_states=True)
    # We pick the hidden state at the last layer as the key
    keys = outputs['hidden_states'][-1]
    bsz, seq_len, dim = keys.shape
    for i in range(bsz):
        # plain int, so the running counters stay ints instead of 0-d tensors
        len_i = int(inputs['length'][i])
        n_keys = len_i - 1  # we do not need the last key
        if (finished_keys + n_keys > MAXIMUM_SIZE
                or finished_tokens + len_i > MAXIMUM_SIZE):
            raise RuntimeError("datastore is full: increase MAXIMUM_SIZE")
        # NOTE(review): the [:len_i] slices assume padding is appended on the
        # right, so the real tokens occupy the first len_i positions -- confirm
        # the tokenizer's padding side.
        all_keys[finished_keys:finished_keys + n_keys] = keys[i, :n_keys].numpy()
        all_tokens[finished_tokens:finished_tokens + len_i] = inputs['input_ids'][i, :len_i].numpy()
        finished_keys += n_keys
        finished_tokens += len_i
        # exactly one length per stored sentence; extending with the whole
        # batch here would record every length bsz times
        all_lengths.append(len_i)
        

In [3]:
# make index
# Build the search index over the keys written so far.
# NOTE(review): MIPS comes from the local `mips` module; train/add presumably map
# onto the corresponding faiss index steps -- confirm there.
INDEX_TYPE = "IVF10_HNSW32,SQ8"  # change it to 'IVF4096_HNSW32,SQ8' for actual use
mips = MIPS(DIMENSION, INDEX_TYPE, efSearch=128, nprobe=64)
mips.train(all_keys[:finished_keys])
mips.add(all_keys[:finished_keys])
# Running total of (length - 1) over all_lengths; find_in_corpus uses it to map
# a key index back to its sentence.
cumsum_keys = (np.asarray(all_lengths) - 1).cumsum()

def find_in_corpus(idx):
    """Map a key index back to its sentence and the token that follows the key.

    Reads the module-level `cumsum_keys` (running total of keys per sentence),
    `all_tokens` (flat token store) and `tokenizer`.

    Args:
        idx: index of a key, as stored in the key datastore.

    Returns:
        A tuple (decoded sentence, decoded next token, position of that token
        inside the sentence).
    """
    # cumsum_keys[s-1] <= idx < cumsum_keys[s]  =>  the key belongs to sentence s
    sent_no = np.searchsorted(cumsum_keys, idx, side='right')
    # number of keys contributed by all earlier sentences (none for the first one)
    keys_before = cumsum_keys[sent_no - 1] if sent_no > 0 else 0
    # every sentence stores one more token than keys, so the token offset of
    # sentence s is its key offset plus s
    tok_start = keys_before + sent_no
    tok_stop = cumsum_keys[sent_no] + sent_no + 1
    # the key at in-sentence position p pairs with the token at position p + 1
    next_pos = idx - keys_before + 1
    sentence_ids = all_tokens[tok_start:tok_stop]
    next_id = sentence_ids[next_pos]
    return tokenizer.decode(sentence_ids), tokenizer.decode([next_id]), next_pos

In [4]:
#test any prefix
# Encode a prefix, use the hidden state of its last token as the query key, and
# print the stored contexts whose keys are the closest matches.
topk = 3 # topk search
inputs = tokenizer("My", return_tensors="pt")
with torch.no_grad():
    outputs = model(input_ids=inputs['input_ids'],
                    attention_mask=inputs['attention_mask'],
                    output_hidden_states=True)
# We pick the hidden state at the last layer as the key
keys = outputs['hidden_states'][-1]
bsz, seq_len, dim = keys.shape
# query with the key of the last prefix token; the index expects a 2-D batch
search_key = keys[0, -1].numpy()
# NOTE(review): mips.search presumably returns faiss-style (scores, indices),
# one row per query -- confirm in mips.MIPS.
D, I = mips.search(np.array([search_key]), topk)
for rnk, (idx, dist) in enumerate(zip(I[0], D[0])):
    if idx < 0:
        # faiss fills the result with -1 when fewer than topk neighbours are
        # found; find_in_corpus would silently map that to the first sentence
        continue
    sent, word, word_pos = find_in_corpus(idx)
    # idx is the index of the matched key in the datastore, not a sentence number
    print("rank %d, distance %.2f, key %d: %s, next word position %d: %s"%(rnk, dist, idx, sent, word_pos, word))
               


rank 0, distance 0.00, sent 0: My dog is cute, next word position 1:  dog
rank 1, distance 0.00, sent 7: My paper is very very good!, next word position 1:  paper
rank 2, distance 0.00, sent 3: My idea is brilliant., next word position 1:  idea
