In [1]:
import transformers, torch, faiss
import numpy as np
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from mips import MIPS

# Report the version of each core dependency (one per line, same order as the imports)
for lib in (transformers, torch, faiss):
    print(lib.__version__)

I1129 22:40:14.643877 4525012416 __init__.py:43] Loading faiss.


3.4.0
1.7.0
1.6.1


In [2]:
# Load GPT-2 and its tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# Reuse the EOS token as the padding token so that batches can be padded
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained('gpt2', return_dict=True)
model = model.eval()  # evaluation mode: we only run forward passes


# Set datastore
# Keys live in a disk-backed array of shape (MAXIMUM_SIZE, DIMENSION).
# mode='w+' creates the file, or overwrites an existing one.
MAXIMUM_SIZE=10000
DIMENSION=768
STORE_FILE='keys.npy'
all_keys = np.memmap(STORE_FILE, dtype=np.float32, mode='w+', shape=(MAXIMUM_SIZE, DIMENSION))
finished_keys = 0  # number of rows of all_keys filled so far

# Token ids of every stored sentence, concatenated in a second memmap.
# `np.int` was a deprecated alias of the builtin `int` and was removed in
# NumPy 1.24; the builtin yields the same dtype. The file must be reopened
# with this same dtype when it is read back.
TOKEN_FILE='tokens.npy'
all_tokens = np.memmap(TOKEN_FILE, dtype=int, mode='w+', shape=(MAXIMUM_SIZE,))
all_lengths = []  # token count of each stored sentence, in insertion order
finished_tokens = 0  # number of entries of all_tokens filled so far

In [3]:
# suppose we have the following data
batch = ["My dog is cute",
         "My idea is brilliant.",
         "My paper is very very good!",
         "My cat is also cute"]
data = [batch] * 10

# encode data
# For every sentence of length L we store L-1 keys (the hidden state at each
# position except the last) and all L token ids, so that key j of a sentence
# can later be matched with the token at position j+1 (its "next word").
for batch in data:
    inputs = tokenizer(batch,
                       padding=True,
                       return_length=True,
                       return_tensors="pt")
    # inputs['length'] is used below as the un-padded token count of each
    # sentence -- TODO confirm this holds for the tokenizer version in use.
    assert (inputs['length'] > 1).all()
    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'],
                        attention_mask = inputs['attention_mask'],
                        output_hidden_states=True)
    # We pick the hidden state at the last layer as the key
    keys = outputs['hidden_states'][-1]
    # Plain Python ints keep the running counters as ints rather than 0-dim tensors
    lengths = inputs['length'].tolist()
    for i, len_i in enumerate(lengths):
        n_keys = len_i - 1 # we do not need the last key
        if finished_keys + n_keys > MAXIMUM_SIZE or finished_tokens + len_i > MAXIMUM_SIZE:
            raise ValueError("datastore is full: increase MAXIMUM_SIZE (currently %d)" % MAXIMUM_SIZE)
        all_keys[finished_keys:finished_keys+n_keys] = keys[i,:n_keys].numpy()
        all_tokens[finished_tokens:finished_tokens+len_i] = inputs['input_ids'][i,:len_i].numpy()
        finished_keys += n_keys
        finished_tokens += len_i
        # Record each sentence's length exactly once. Extending with the whole
        # batch inside this loop would duplicate every length bsz times.
        all_lengths.append(len_i)

# Each sentence contributes one more token than keys, so the two counters
# must differ by exactly the number of recorded sentences.
assert finished_tokens - finished_keys == len(all_lengths), "datastore bookkeeping is inconsistent; re-run the setup cell"
        

In [4]:
# make index
INDEX_TYPE = "Flat" # change it to 'IVF4096_HNSW32,SQ8' or whatever when dealing with big data
mips = MIPS(DIMENSION, INDEX_TYPE, efSearch=128, nprobe=64)
# Only the first finished_keys rows of the memmap have been written
mips.train(all_keys[:finished_keys])
mips.add(all_keys[:finished_keys])
# cumsum_keys[s] = total number of keys stored for sentences 0..s
# (a sentence of length L contributes L-1 keys)
cumsum_keys = np.cumsum(np.array(all_lengths)-1)


# save everything
CUMSUM_KEYS_FILE = 'cumsum.npy'
MIPS_INDEX_FILE = 'mips.index'
# Push pending memmap writes to disk before the files are reopened elsewhere
all_keys.flush()
all_tokens.flush()
# Context manager so the file is closed (and flushed) deterministically
with open(CUMSUM_KEYS_FILE, 'wb') as cumsum_file:
    np.save(cumsum_file, cumsum_keys)
mips.save(MIPS_INDEX_FILE)


print ("preprocess done!")

preprocess done!


In [5]:
# load everything
# The constants are repeated here so this cell does not depend on the
# preprocessing cells' variables; they must match the values used when the
# files were written.
MAXIMUM_SIZE=10000
DIMENSION=768
STORE_FILE='keys.npy'
TOKEN_FILE='tokens.npy'
CUMSUM_KEYS_FILE = 'cumsum.npy'
MIPS_INDEX_FILE = 'mips.index'
# mode='r' opens the existing files read-only
all_keys = np.memmap(STORE_FILE, dtype=np.float32, mode='r', shape=(MAXIMUM_SIZE, DIMENSION))
# `np.int` was a deprecated alias of the builtin `int` and was removed in
# NumPy 1.24; the builtin yields the same dtype.
all_tokens = np.memmap(TOKEN_FILE, dtype=int, mode='r', shape=(MAXIMUM_SIZE,))
# Context manager so the file handle is closed instead of leaked
with open(CUMSUM_KEYS_FILE, 'rb') as cumsum_file:
    cumsum_keys = np.load(cumsum_file)
mips = MIPS.from_built(MIPS_INDEX_FILE, nprobe=64)

# Map a flat key index back to its sentence and to the word that follows the key
def find_in_corpus(idx):
    """Locate the sentence and next word that belong to datastore key ``idx``.

    Uses the module-level ``cumsum_keys`` (cumulative number of keys per
    sentence), ``all_tokens`` (concatenated token ids) and ``tokenizer``.
    A sentence of length L owns L-1 keys but L tokens, so sentence ``s``
    starts at token offset ``cumsum_keys[s-1] + s``.

    Args:
        idx: integer position of a key in the key store / search index.

    Returns:
        A tuple ``(sentence_text, word_text, sent_idx, word_pos)`` where
        ``word_pos`` is the position inside the sentence of the token that
        follows the key, and ``word_text`` is that token decoded.

    Raises:
        IndexError: if ``idx`` is negative or not smaller than the total
            number of keys (``cumsum_keys[-1]``). A negative index would
            otherwise silently resolve to a token of the first sentence.
    """
    total_keys = cumsum_keys[-1]
    if idx < 0 or idx >= total_keys:
        raise IndexError("key index %d is out of range [0, %d)" % (idx, total_keys))
    # side='right' gives cumsum_keys[sent_idx-1] <= idx < cumsum_keys[sent_idx]
    sent_idx = int(np.searchsorted(cumsum_keys, idx, side='right'))
    # Number of keys owned by all earlier sentences (none for the first one)
    keys_before = int(cumsum_keys[sent_idx - 1]) if sent_idx > 0 else 0
    sent_start = keys_before + sent_idx
    sent_end = int(cumsum_keys[sent_idx]) + sent_idx + 1
    # +1 because the key at position j is paired with the token at j+1
    word_pos = int(idx) - keys_before + 1
    sent = all_tokens[sent_start:sent_end]
    word = sent[word_pos]
    return tokenizer.decode(sent), tokenizer.decode([word]), sent_idx, word_pos

In [6]:
# test
def test(prefix, topk):
    """Print the ``topk`` stored contexts closest to ``prefix`` and their next words.

    The prefix is encoded with the module-level ``tokenizer`` and ``model``;
    the last-layer hidden state of its final token is used as the query for
    ``mips.search``. Each hit is resolved with ``find_in_corpus`` and printed.

    Args:
        prefix: text whose continuation should be looked up.
        topk: number of neighbours to request from the index.
    """
    inputs = tokenizer(prefix, return_tensors="pt")
    with torch.no_grad():
        outputs = model(input_ids=inputs['input_ids'],
                        attention_mask = inputs['attention_mask'],
                        output_hidden_states=True)
    # We pick the hidden state at the last layer as the key
    keys = outputs['hidden_states'][-1]
    search_key = keys[0,-1].numpy()
    D, I = mips.search(np.array([search_key]), topk)
    for rnk, (idx, dist) in enumerate(zip(I[0], D[0])):
        # A negative label is never a valid key position. FAISS uses -1 to pad
        # results when fewer than topk neighbours are found (assuming
        # mips.search forwards FAISS labels unchanged -- TODO confirm).
        if idx < 0:
            continue
        sent, word, sent_idx, word_pos = find_in_corpus(idx)
        print("rank %d, distance %.2f, sent %d: %s, next word position %d: %s"%(rnk, dist, sent_idx, sent, word_pos, word))
# Demo queries as (prefix, topk) pairs -- edit to try any prefix / neighbour count
demo_queries = [("My dog", 3), ("My", 5)]
for query_no, (prefix, topk) in enumerate(demo_queries):
    if query_no > 0:
        print("=" * 55)  # separator line between consecutive queries
    test(prefix, topk)


rank 0, distance 0.00, sent 0: My dog is cute, next word position 2:  is
rank 1, distance 0.00, sent 8: My dog is cute, next word position 2:  is
rank 2, distance 0.00, sent 4: My dog is cute, next word position 2:  is
rank 0, distance 0.00, sent 2: My paper is very very good!, next word position 1:  paper
rank 1, distance 0.00, sent 1: My idea is brilliant., next word position 1:  idea
rank 2, distance 0.00, sent 4: My dog is cute, next word position 1:  dog
rank 3, distance 0.00, sent 3: My cat is also cute, next word position 1:  cat
rank 4, distance 0.00, sent 0: My dog is cute, next word position 1:  dog
