In [1]:
from transformers import *
import torch
from torch.utils.data import DataLoader
from collections import defaultdict


In [2]:
# Pretrained DistilBERT (uncased); in this notebook only its input word-embedding table is used.
bert = DistilBertModel.from_pretrained('distilbert-base-uncased')

In [3]:
# Word-embedding matrix: one row per vocabulary id (shape shown in the output below).
bert_emb = bert.embeddings.word_embeddings.weight.data
bert_emb.shape

torch.Size([30522, 768])

In [4]:
# Saved so VocabDataset can read it through its `embedding_file` argument.
torch.save(bert_emb, "distilbert_embedding_matrix.pt")

In [5]:
# Read the BERT WordPiece vocabulary: one token per line, row-aligned with bert_emb.
with open('bert-base-uncased-vocab.txt') as vocab_fh:
  bert_vocab = [line.strip() for line in vocab_fh]

assert len(bert_vocab) == len(bert_emb)
longest_word_in_bert_vocab = max(len(token) for token in bert_vocab)
# Number of character slots per encoded word; longer tokens are truncated.
word_length = 12
def GetCharEncoding(word, char_to_idx=None, length=None):
  """Encode `word` as a fixed-length list of character ids.

  Characters missing from the map become 0, every slot past the end of the word is 0,
  and characters beyond `length` are dropped.

  Args:
    word: string to encode.
    char_to_idx: dict mapping characters to ids. Defaults to the notebook-level
      `char_to_idx_map`, which is only loaded in a later cell -- pass it explicitly
      if calling before that cell has run.
    length: number of slots in the encoding. Defaults to the notebook-level `word_length`.

  Returns:
    list of `length` ints.
  """
  if char_to_idx is None:
    char_to_idx = char_to_idx_map
  if length is None:
    length = word_length
  enc = [0] * length
  for i, c in enumerate(word[:length]):
    enc[i] = char_to_idx.get(c, 0)
  return enc

In [6]:
# Sanity check that a GPU is visible (later cells create tensors on 'cuda').
torch.cuda.is_available()

True

In [7]:
# Previously saved IMDB data bundle; this notebook only reads its 'char_to_idx_map' entry.
# NOTE(review): relative path outside this directory; torch.load unpickles, so only load trusted files.
imdb = torch.load('../../../character_convolution/data/torch_imdb.pt')

In [8]:
# Character -> id map used by GetCharEncoding.
char_to_idx_map =  imdb['char_to_idx_map']

In [9]:
# Persisted so VocabDataset can load it through its `char_to_idx_file` argument.
torch.save(char_to_idx_map, 'char_to_idx_map.pt')

In [2]:
# Character inventory: NUL, space, punctuation, digits, then lowercase letters.
# A character's position in this list becomes its id in char_to_idx_map2.
chars = (['\x00']
         + list(' !"#$%&\'()*+,-./:;<=>?@\\^_`{|}~[]')
         + list('0123456789')
         + list('abcdefghijklmnopqrstuvwxyz'))
 

In [3]:
# Map each character to its position in `chars`. Populating it here (rather than only in the
# later fill loop) lets the lookup cell below run on a fresh kernel (Restart & Run All).
char_to_idx_map2 = {c: i for i, c in enumerate(chars)}

In [7]:
# Spot-check the ids of '[', 'm', ']' (positions 32, 56, 33 in `chars`).
# NOTE: requires char_to_idx_map2 to be populated from `chars` first.
char_to_idx_map2['['], char_to_idx_map2['m'], char_to_idx_map2[']']

(32, 56, 33)

In [5]:
# Give every character in `chars` its list position as its integer id.
char_to_idx_map2.update((ch, pos) for pos, ch in enumerate(chars))

In [14]:
# Persist the alternative character map built from `chars`.
torch.save(char_to_idx_map2, 'char_to_idx_map2.pt')

In [27]:
# Fixed-length character-id encoding of every lower-cased vocab token, in vocab order
# (GetCharEncoding uses the notebook-level char_to_idx_map).
word_encodings = [GetCharEncoding(token.lower()) for token in bert_vocab]

In [28]:
# Inspect the encoding of the first vocab token.
word_encodings[0]

[34, 55, 40, 43, 36, 0, 0, 0, 0, 0, 0, 0]

In [55]:

# Keep only words whose (truncated, lower-cased) character encoding is unique in the vocab:
# a colliding encoding would map one input to several different target embeddings.
repeat_map = defaultdict(int)
for w in word_encodings:
  repeat_map[tuple(w)] += 1
# One flag per *word* (not per distinct encoding) so the mask lines up with bert_emb rows.
valid_list = [1 if repeat_map[tuple(w)] <= 1 else 0 for w in word_encodings]
# Build the mask on the CPU so it can index both the CPU lists and bert_emb.
valid_mask = torch.tensor(valid_list, dtype=torch.bool)
word_char_encoding = torch.tensor(word_encodings, dtype=torch.int64)[valid_mask].to('cuda')
valid_list = torch.tensor(valid_list, dtype=torch.int64, device='cuda')
embeddings = bert_emb[valid_mask]

In [57]:
# Bundle of preprocessed tensors that is written to disk below.
data = {}

In [59]:
# Both tensors are stored on the GPU.
# NOTE(review): torch.save keeps the device, so loading this file later needs CUDA or a map_location.
data['word_char_encoding'] = word_char_encoding.cuda()
data['embeddings'] = embeddings.cuda()
data['char_to_idx_map'] = char_to_idx_map

In [60]:
# Persist the preprocessed vocabulary bundle.
torch.save(data, 'preproc_distilbert_vocab_emb.pt')

In [61]:
# Two approaches to compare:
# 1. Load the preprocessed data directly onto the GPU.
# 2. Load the embedding dict and word-text dict on the CPU and use multiprocessing to move each word
#    onto the GPU during training -- more flexible and simpler code (could at least be part of the experiment).

In [15]:
class VocabDataset(torch.utils.data.Dataset):
  """Dataset with features: token characters and labels: token index or embedding.

  Also provides a way to add random words after the token of interest and random misspellings.

  Each item is a dict with:
    'features':          character-id encoding of the word (``word_length`` slots, 0 = padding),
                         possibly misspelled and/or followed by randomly chosen vocabulary words.
    'target_embeddings': the word's row of the embedding matrix.
    'labels':            the word's row index into ``self.embedding_matrix``.

  Vocabulary words whose truncated, lower-cased encoding collides with another word's are
  dropped at construction time, so every kept encoding maps to exactly one embedding.

  Args:
    vocab_file: UTF-8 text file with one vocabulary token per line (line i matches embedding row i).
    char_to_idx_file: ``torch.save``-d dict mapping characters to ids; must contain ' ' and 'a'-'z'.
    embedding_file: ``torch.save``-d tensor with one row per vocabulary line.
    word_length: number of character slots per word; longer words are truncated.
    shuffle: permute the rows once at construction time.
    normalize: L2-normalize every embedding row.
    misspelling_rate: probability that an item is misspelled; None/0 disables misspelling.
      The sentinel string "not set" makes item access raise ValueError.
    misspelling_transforms: number of edits applied to a misspelled word (an int when
      misspelling is enabled).
    misspelling_type: "add", "repeat", "substitute", "delete" or "transpose";
      None picks a random edit for every transform.
    add_next_word: stored on the instance; not read by this class.
    add_random_count: number of attempts, each taken with probability 0.8, to append a
      random vocabulary word after the current one.
    space_freq: probability that an appended word is preceded by a space.
    device: device the embedding matrix is moved to.
  """
  def __init__(self, vocab_file, char_to_idx_file, embedding_file, word_length=12, shuffle=False, normalize=False,
               misspelling_rate=None, misspelling_transforms=None, misspelling_type=None,
               add_next_word=False, add_random_count=0, space_freq=1.0, device='cpu'):
    # Set before preprocessing, which places the embeddings on this device.
    self.device = device
    self.data = self._preprocess(vocab_file, char_to_idx_file, embedding_file, word_length, device=device)
    self.keys = ['word_char_encoding', 'embeddings']
    if normalize:
      self.data['embeddings'] = self.data['embeddings'] / self.data['embeddings'].norm(dim=-1, p=2).unsqueeze(-1)
    self.misspelling_rate = misspelling_rate
    self.misspelling_transforms = misspelling_transforms
    self.misspelling_type = misspelling_type
    self.add_next_word = add_next_word
    self.add_random_count = add_random_count
    self.space_freq = space_freq
    if shuffle:
      # Gather into fresh tensors with the same permutation so encodings and embeddings stay
      # row-aligned; an in-place `x[r] = x` would read and write overlapping storage.
      perm = torch.randperm(len(self))
      for k in self.keys:
        self.data[k] = self.data[k][perm]
    # Bound after normalization/shuffling so label i always refers to the row returned as target i.
    self.embedding_matrix = self.data['embeddings']

    self.data['word_indices'] = torch.arange(len(self.data['embeddings']))
    self.char_to_idx_map = self.data['char_to_idx_map']

    chars_for_insertion = [
         'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p',
         'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

    # Ids of the letters that insertions/substitutions may introduce.
    self.insertion_lookup = torch.tensor([self.char_to_idx_map[c] for c in chars_for_insertion])
    # Candidate edits used when misspelling_type is None.
    self.transform_options = [self.add_letter,
                              self.substitute_letter,
                              self.transpose_letters,
                              self.delete_letter,
                              self.repeat_letter]

  def __len__(self):
    return len(self.data[self.keys[0]])

  def __getitem__(self, idx):
    """Return the (possibly corrupted) character features, target embedding and label for `idx`."""
    if self.misspelling_rate == "not set":
      raise ValueError("Must set mispelling rate (or explicitly set to None) as a hyperparameter")
    char_encoded_input = self.data['word_char_encoding'][idx]

    item = {
      'target_embeddings': self.data['embeddings'][idx],
      'labels': self.data['word_indices'][idx],
    }

    if self.misspelling_rate and torch.rand((1,)) <= self.misspelling_rate:
      char_encoded_input = self.apply_misspelling(char_encoded_input)

    # End of the original word after the first appended word (kept for the currently
    # disabled 'end_of_word_index' output).
    first_end_ind = None
    for _ in range(self.add_random_count):
      if torch.rand(size=()) < .8:  # skip 1/5 of the time
        char_encoded_input, end_indx = self.add_random_next_word(char_encoded_input)
        # `is None`, not truthiness: an end index of 0 is a valid tensor value.
        if first_end_ind is None:
          first_end_ind = end_indx
          #item['end_of_word_index'] = first_end_ind
    item['features'] = char_encoded_input

    return item

  @staticmethod
  def GetCharEncoding(word, char_to_idx_map, word_length):
    """Return `word` as `word_length` character ids; unknown characters and padding are 0."""
    enc = [0] * word_length
    for i, c in enumerate(word[:word_length]):
      enc[i] = char_to_idx_map.get(c, 0)
    return enc

  def _preprocess(self, vocab_file, char_to_idx_file, embedding_file, word_length, device='cpu'):
    """Load vocab, embeddings and char map; encode words and drop those with colliding encodings.

    Returns:
      dict with 'word_char_encoding' (int64, one row per kept word), 'embeddings'
      (the matching embedding rows, on `device`) and 'char_to_idx_map'.
    """
    with open(vocab_file, encoding='utf-8') as f:
      bert_vocab = [line.strip() for line in f]

    bert_emb = torch.load(embedding_file)
    assert len(bert_vocab) == len(bert_emb)

    char_to_idx_map = torch.load(char_to_idx_file)

    word_encodings = [self.GetCharEncoding(word.lower(), char_to_idx_map, word_length)
                      for word in bert_vocab]

    repeat_map = defaultdict(int)
    for w in word_encodings:
      repeat_map[tuple(w)] += 1
    # One flag per *word* (not per distinct encoding) so the mask lines up with bert_emb rows.
    keep = torch.tensor([repeat_map[tuple(w)] <= 1 for w in word_encodings], dtype=torch.bool)
    word_char_encoding = torch.tensor(word_encodings, dtype=torch.int64)[keep]

    embeddings = bert_emb[keep].to(device)
    return {"word_char_encoding": word_char_encoding,
            "embeddings": embeddings,
            "char_to_idx_map": char_to_idx_map}

  def random_location(self, word_size):
    """Uniform random index in [0, word_size); returns 0 when word_size <= 1."""
    # Clamping to 1 avoids torch.randint(0, 0) raising once repeated edits have shortened a word.
    return torch.randint(0, max(int(word_size), 1), size=())

  def random_letter(self,):
    """Id of a uniformly chosen letter a-z."""
    return self.insertion_lookup[torch.randint(0, self.insertion_lookup.shape[0], size=())]

  def word_len(self, word):
    """Number of non-padding (non-zero) slots."""
    return (word != 0).sum()

  def add_letter(self, word):
    """Insert a random letter at a random position up to just after the word, shifting the rest right."""
    word = word.clone()
    n = self.random_location(min(word.shape[0], self.word_len(word) + 1))
    p2 = word[n:].clone()
    rl = self.random_letter()
    word[n+1:] = p2[:-1]  # this will remove the last letter of some words (very few. not concerned for now)
    word[n] = rl
    return word

  def repeat_letter(self, word):
    """Duplicate the letter at a random position, shifting the rest right."""
    word = word.clone()
    n = self.random_location(self.word_len(word))
    p2 = word[n:].clone()
    word[n+1:] = p2[:-1]  # this will remove the last letter of some words (very few. not concerned for now)
    return word

  def substitute_letter(self, word):
    """Replace the letter at a random position with a random letter."""
    word = word.clone()
    n = self.random_location(self.word_len(word))
    rl = self.random_letter()
    word[n] = rl
    return word

  def delete_letter(self, word):
    """Remove the letter at a random position, shifting the rest left and padding the end."""
    word = word.clone()
    n = self.random_location(self.word_len(word))
    word[n:-1] = word[n+1:].clone()
    word[-1] = 0
    return word

  def transpose_letters(self, word):
    """Swap two adjacent letters at a random position; words shorter than 2 are unchanged."""
    word = word.clone()
    if self.word_len(word) < 2:
      # Nothing to swap; avoids moving a letter into a padding slot.
      return word
    n = self.random_location(self.word_len(word) - 1)
    t = word[n].clone()
    word[n] = word[n+1]
    word[n+1] = t
    return word

  def AlterWord(self, word, transforms, misspelling_type=None):
    """Apply `transforms` edits of `misspelling_type` (random per edit when None) to `word`.

    Words with at most 2 letters are returned unchanged.

    Raises:
      ValueError: if `misspelling_type` is a non-empty string that is not a known edit.
    """
    if self.word_len(word) <= 2:
      return word

    named_transforms = {
      "add": self.add_letter,
      "repeat": self.repeat_letter,
      "substitute": self.substitute_letter,
      "delete": self.delete_letter,
      "transpose": self.transpose_letters,
    }
    if misspelling_type and misspelling_type not in named_transforms:
      raise ValueError(f"Unknown misspelling_type: {misspelling_type!r}")

    for _ in range(transforms):
      if misspelling_type:
        change_fn = named_transforms[misspelling_type]
      else:
        change_fn = self.transform_options[torch.randint(0, len(self.transform_options), size=())]
      word = change_fn(word)
    return word

  def apply_misspelling(self, char_encoded_input):
    """Misspell using the instance's configured transform count and type."""
    return self.AlterWord(char_encoded_input,
                          transforms=self.misspelling_transforms,
                          misspelling_type=self.misspelling_type
    )

  def add_random_next_word(self, char_encoded_input):
    """Append a random vocabulary word after the current word.

    Returns:
      (new encoding, index just after the original word). A word with no padding slot is
      returned unchanged together with its last index.
    """
    pads = torch.where(char_encoded_input == 0)
    if pads[0].nelement() == 0:
      # No room to append; report the last slot (a 0-d int64 tensor, like `end` below).
      return char_encoded_input, torch.tensor(char_encoded_input.shape[0] - 1)
    end = pads[0][0]
    add_word = torch.randint(0, len(self.data['embeddings']), ())
    char_encoded_input = char_encoded_input.clone()
    # Precede the added word with a space with probability space_freq.
    if torch.rand(size=()) >= self.space_freq:
      char_encoded_input[end:] = self.data['word_char_encoding'][add_word][:len(char_encoded_input)-end]
    else:
      char_encoded_input[end] = self.data['char_to_idx_map'][' ']  # space char
      char_encoded_input[end+1:] = self.data['word_char_encoding'][add_word][:len(char_encoded_input)-end-1]
    return char_encoded_input, end

In [14]:
%%timeit
vocab_file = 'bert-base-uncased-vocab.txt'
char_to_idx_file = 'char_to_idx_map.pt'
embedding_file = 'distilbert_embedding_matrix.pt'
bs = 128
# Pass word_length explicitly so the encoding length matches the preprocessing cells (word_length = 12).
vds = VocabDataset(vocab_file, char_to_idx_file, embedding_file, word_length=word_length, add_random_count=2)
train_loader = DataLoader(vds, batch_size=bs, shuffle=True, num_workers=10)
# One full pass over the loader, copying each batch of targets to the GPU and touching it there.
for d in train_loader:
  a = d['target_embeddings'].cuda()
  a += 1

1.03 s ± 14.3 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# Preprocessing inline: 299 ms ± 4.44 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Adding 2 random words: 9.66 s ± 37.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
#   (everything on the GPU) -- next: add the random words on the CPU and compare.

# Everything on the CPU with 12 workers, then move to the GPU: 957 ms ± 26.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 2 workers: 1.88 s ± 29 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# 1 worker: 3.41 s ± 29.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# Without loading the preprocessed data: 188 ms ± 2.32 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# With loading the preprocessed data: 245 ms ± 1.04 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)