In [3]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
import numpy as np
from transformers import AutoTokenizer, AutoModel
from transformers import BatchEncoding

os.chdir("C:\\Users\\ndgig\\Repositories\\embed_gram")
print(os.getcwd())

%load_ext autoreload
%autoreload 2

C:\Users\ndgig\Repositories\embed_gram


In [88]:
class TokenizedDataset(Dataset):
    def __init__(self, inputs):
        self.inputs = inputs

    def __len__(self):
        return len(self.inputs["input_ids"])

    def __getitem__(self, idx):
        return {k: v[idx].squeeze(0) for k, v in self.inputs.items()}

def get_ngram_idx(input_ids, ngram_range=(3, 6)):
    if isinstance(input_ids, torch.Tensor):
        input_ids = input_ids.cpu().numpy()
    ngram_idx = []
    for ngram_size in range(ngram_range[0], ngram_range[1] + 1):
        idx = np.vstack([np.arange(input_ids.shape[1]) + i for i in range(ngram_size)]).T
        idx = idx[(idx[:, -1] < input_ids.shape[1])]
        ngram_idx.append(idx)
    return ngram_idx

class NgramEncoder:
    def __init__(self, model_name, device="cuda"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval().to(device)

    @property
    def device(self):
        return self.model.device

    def extract_ngrams(self, input_ids, token_embeds, ngram_range=(3, 6)):
        ngrams = []
        ngram_vecs = []
        ngram_idx = get_ngram_idx(input_ids, ngram_range)
        for idx in ngram_idx:
            ngrams.append(input_ids[:, idx])
            ngram_vecs.append(token_embeds[:, idx].mean(axis=2))
        ngrams = [self.tokenizer.batch_decode(np.vstack(x)) for x in ngrams]
        ngram_vecs = [np.vstack(x) for x in ngram_vecs]
        ngrams = [y for x in ngrams for y in x]
        ngram_vecs = np.vstack(ngram_vecs)
        return ngrams, ngram_vecs
    
    def encode(self, docs, max_length=512, batch_size=32, amp=True, amp_dtype=torch.bfloat16):
        if max_length is None:
            if self.tokenizer.model_max_length is None:
                raise ValueError(
                    "max_length must be specified if tokenizer.model_max_length is None"
                )
            max_length = self.tokenizer.model_max_length
            print(f"max_length set to {max_length}")
        inputs = self.tokenizer(docs, max_length=max_length, padding="longest", truncation=True, return_tensors="pt")
        loader = DataLoader(TokenizedDataset(inputs), batch_size=batch_size, shuffle=False, pin_memory=True)
        outputs = []
        with torch.no_grad():
            with torch.cuda.amp.autocast(enabled=amp, dtype=amp_dtype):
                for batch in loader:
                    batch = {k: v.to(self.device, non_blocking=True) for k, v in batch.items()}
                    outputs.append(self.model(**batch).last_hidden_state.cpu().numpy())
        outputs = np.concatenate(outputs, axis=0)
        return self.extract_ngrams(inputs["input_ids"], outputs)

In [89]:
np.vstack([np.arange(75) + i for i in range(3)]).T

array([[ 0,  1,  2],
       [ 1,  2,  3],
       [ 2,  3,  4],
       [ 3,  4,  5],
       [ 4,  5,  6],
       [ 5,  6,  7],
       [ 6,  7,  8],
       [ 7,  8,  9],
       [ 8,  9, 10],
       [ 9, 10, 11],
       [10, 11, 12],
       [11, 12, 13],
       [12, 13, 14],
       [13, 14, 15],
       [14, 15, 16],
       [15, 16, 17],
       [16, 17, 18],
       [17, 18, 19],
       [18, 19, 20],
       [19, 20, 21],
       [20, 21, 22],
       [21, 22, 23],
       [22, 23, 24],
       [23, 24, 25],
       [24, 25, 26],
       [25, 26, 27],
       [26, 27, 28],
       [27, 28, 29],
       [28, 29, 30],
       [29, 30, 31],
       [30, 31, 32],
       [31, 32, 33],
       [32, 33, 34],
       [33, 34, 35],
       [34, 35, 36],
       [35, 36, 37],
       [36, 37, 38],
       [37, 38, 39],
       [38, 39, 40],
       [39, 40, 41],
       [40, 41, 42],
       [41, 42, 43],
       [42, 43, 44],
       [43, 44, 45],
       [44, 45, 46],
       [45, 46, 47],
       [46, 47, 48],
       [47, 4

In [90]:
ng = NgramEncoder("microsoft/deberta-v3-small")
ng



<__main__.NgramEncoder at 0x18a81d27200>

In [91]:
ngrams, ngram_vecs = ng.encode(["hello world", "goodbye world", "seriously goodbye", "oh no this one is much longer"], batch_size=32)
ngram_vecs

array([[ 0.419083  , -0.20553964,  0.20437278, ...,  0.0238938 ,
         0.08819325, -0.22163458],
       [ 0.40543842, -0.2113377 ,  0.2095865 , ...,  0.0425633 ,
         0.08094309, -0.2358725 ],
       [ 0.09142544, -0.32302827,  0.06221219, ...,  0.06342554,
        -0.1965838 , -0.17689157],
       ...,
       [-0.21452475,  0.14686857,  0.34473768, ..., -0.02249772,
        -0.14121933,  0.5197796 ],
       [-0.39555085, -0.00161369,  0.21678615, ...,  0.09426519,
        -0.05263494,  0.48606697],
       [-0.5440827 , -0.04249418,  0.19302572, ...,  0.08400199,
        -0.05081439,  0.30111313]], dtype=float32)

In [92]:
ngrams, ngram_vecs.shape

(['[CLS] hello world',
  'hello world[SEP]',
  'world[SEP][PAD]',
  '[SEP][PAD][PAD]',
  '[PAD][PAD][PAD]',
  '[PAD][PAD][PAD]',
  '[PAD][PAD][PAD]',
  '[CLS] goodbye world',
  'goodbye world[SEP]',
  'world[SEP][PAD]',
  '[SEP][PAD][PAD]',
  '[PAD][PAD][PAD]',
  '[PAD][PAD][PAD]',
  '[PAD][PAD][PAD]',
  '[CLS] seriously goodbye',
  'seriously goodbye[SEP]',
  'goodbye[SEP][PAD]',
  '[SEP][PAD][PAD]',
  '[PAD][PAD][PAD]',
  '[PAD][PAD][PAD]',
  '[PAD][PAD][PAD]',
  '[CLS] oh no',
  'oh no this',
  'no this one',
  'this one is',
  'one is much',
  'is much longer',
  'much longer[SEP]',
  '[CLS] hello world[SEP]',
  'hello world[SEP][PAD]',
  'world[SEP][PAD][PAD]',
  '[SEP][PAD][PAD][PAD]',
  '[PAD][PAD][PAD][PAD]',
  '[PAD][PAD][PAD][PAD]',
  '[CLS] goodbye world[SEP]',
  'goodbye world[SEP][PAD]',
  'world[SEP][PAD][PAD]',
  '[SEP][PAD][PAD][PAD]',
  '[PAD][PAD][PAD][PAD]',
  '[PAD][PAD][PAD][PAD]',
  '[CLS] seriously goodbye[SEP]',
  'seriously goodbye[SEP][PAD]',
  'goodbye[SEP][P

In [93]:
len(ngrams)

88

In [94]:
ngram_vecs

array([[ 0.419083  , -0.20553964,  0.20437278, ...,  0.0238938 ,
         0.08819325, -0.22163458],
       [ 0.40543842, -0.2113377 ,  0.2095865 , ...,  0.0425633 ,
         0.08094309, -0.2358725 ],
       [ 0.09142544, -0.32302827,  0.06221219, ...,  0.06342554,
        -0.1965838 , -0.17689157],
       ...,
       [-0.21452475,  0.14686857,  0.34473768, ..., -0.02249772,
        -0.14121933,  0.5197796 ],
       [-0.39555085, -0.00161369,  0.21678615, ...,  0.09426519,
        -0.05263494,  0.48606697],
       [-0.5440827 , -0.04249418,  0.19302572, ...,  0.08400199,
        -0.05081439,  0.30111313]], dtype=float32)

In [95]:
from sklearn.neighbors import NearestNeighbors

nn = NearestNeighbors(n_neighbors=5, metric="cosine")
query = ngram_vecs[-1]
nn.fit(ngram_vecs)
dists, idx = nn.kneighbors([query])
[ngrams[i] for i in idx[0]]

['this one is much longer[SEP]',
 'one is much longer[SEP]',
 'this one is much longer',
 'no this one is much longer',
 'one is much longer']