In [None]:
import re, collections
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
class WordPiece():
  """Greedy WordPiece-style subword tokenizer trained on a text file.

  Training starts from character-level symbols (each word gets a trailing
  ``</w>`` end-of-word marker) and repeatedly merges the adjacent symbol pair
  with the highest score ``log(P(ab) / (P(a) * P(b)))``.  The learned
  vocabulary is stored in ``sorted_tokens`` (longest first), and
  ``tokenize_word`` / ``tokenize_sentence`` segment new text with it.
  """

  def __init__(self, n_merges, data_path, min_ouccurance_count):
    """Train the tokenizer on ``data_path``.

    Args:
      n_merges: Maximum number of merge iterations.
      data_path: Path to a UTF-8 text file used as the training corpus.
      min_ouccurance_count: A symbol pair must occur at least this many times
        in the corpus to be merged.  (The misspelled name is kept so existing
        keyword callers keep working.)
    """
    self.n_merges = n_merges
    self.path = data_path
    self.min_ouccurance_count = min_ouccurance_count
    # Learned vocabulary, ordered longest/most frequent first.
    self.sorted_tokens = self.generate_tokens()

  def get_vocab(self, filename):
    """Read ``filename`` and count every word as a space-separated symbol sequence.

    Each word becomes its characters separated by spaces plus ``' </w>'``,
    e.g. ``'low'`` -> ``'l o w </w>'``.

    Returns:
      collections.defaultdict(int) mapping symbol sequence -> frequency.
    """
    vocab = collections.defaultdict(int)
    with open(filename, 'r', encoding='utf-8') as fs:
      for line in fs:
        # split() with no argument drops empty strings; split(" ") turned
        # repeated spaces and blank lines into a bogus ' </w>' "word".
        for word in line.split():
          vocab[' '.join(word) + ' </w>'] += 1
    return vocab

  def get_pair_stats(self, vocab, tokens, min_count=1):
    """Score every pair of adjacent symbols occurring in ``vocab``.

    The score is the WordPiece likelihood gain
    ``log(P(ab) / (P(a) * P(b))) = log(count(ab) * N / (count(a) * count(b)))``
    where ``N`` is the total token count.

    Args:
      vocab: Mapping from space-separated symbol sequence to frequency.
      tokens: Mapping from symbol to its total frequency in ``vocab``.
      min_count: Pairs occurring fewer than this many times are omitted.

    Returns:
      collections.defaultdict(float) mapping ``(left, right)`` -> score.
    """
    pair_counts = collections.defaultdict(int)
    for word, freq in vocab.items():
      symbols = word.split()
      for left, right in zip(symbols, symbols[1:]):
        pair_counts[left, right] += freq
    total = sum(tokens.values())
    pair_likelihood = collections.defaultdict(float)
    for (left, right), count in pair_counts.items():
      if count < min_count:
        continue
      pair_likelihood[left, right] = float(
          np.log(count * total / (tokens[left] * tokens[right])))
    return pair_likelihood

  def merge(self, pair, v_in):
    """Fuse every standalone occurrence of ``pair`` in each word of ``v_in``.

    Args:
      pair: Tuple ``(left, right)`` of adjacent symbols to merge.
      v_in: Mapping from space-separated symbol sequence to frequency.

    Returns:
      New dict of merged symbol sequences with their frequencies.
    """
    # Escape the symbols (corpus text may contain regex metacharacters) and
    # require whitespace/string boundaries on both sides so ('a', 'b') does
    # not match the tail of a longer symbol as in 'ca b'.
    pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
    merged = ''.join(pair)
    v_out = {}
    for word, freq in v_in.items():
      # A function replacement keeps backslashes in `merged` literal.
      new_word = pattern.sub(lambda _match: merged, word)
      v_out[new_word] = v_out.get(new_word, 0) + freq
    return v_out

  def extract_tokens(self, vocab):
    """Split every vocabulary entry into its current symbols.

    Returns:
      ``(tokens, vocab_tokens)``: ``tokens`` maps each symbol to its total
      frequency; ``vocab_tokens`` maps each word (symbols concatenated) to
      its list of symbols.
    """
    tokens = collections.defaultdict(int)
    vocab_tokens = {}
    for word, freq in vocab.items():
      word_tokens = word.split()
      for token in word_tokens:
        tokens[token] += freq
      vocab_tokens[''.join(word_tokens)] = word_tokens
    return tokens, vocab_tokens

  def measure_token_length(self, token):
    """Length of ``token``, counting a trailing ``'</w>'`` marker as one character."""
    if token.endswith('</w>'):
      return len(token) - len('</w>') + 1
    return len(token)

  def generate_tokens(self):
    """Learn the subword vocabulary from ``self.path``.

    Runs up to ``self.n_merges`` merges and stops early when no adjacent pair
    occurs at least ``self.min_ouccurance_count`` times.  The token count
    before training and after each merge is stored in ``self.n_tokens``
    (useful for plotting vocabulary growth).

    Returns:
      Tokens sorted by (length, frequency), descending, where ``'</w>'``
      counts as a single character.
    """
    vocab = self.get_vocab(self.path)
    min_count = self.min_ouccurance_count or 1
    tokens, _ = self.extract_tokens(vocab)
    self.n_tokens = [len(tokens)]
    for _ in range(self.n_merges):
      pairs = self.get_pair_stats(vocab, tokens, min_count=min_count)
      if not pairs:
        break
      best_pair = max(pairs, key=pairs.get)
      vocab = self.merge(best_pair, vocab)
      tokens, _ = self.extract_tokens(vocab)
      self.n_tokens.append(len(tokens))
    ranked = sorted(
        tokens.items(),
        key=lambda item: (self.measure_token_length(item[0]), item[1]),
        reverse=True)
    return [token for token, _freq in ranked]

  def _find_token_span(self, string):
    """Return ``(start, end)`` of the first ``sorted_tokens`` entry found in ``string``.

    Tokens are tried in ``sorted_tokens`` order and the leftmost acceptable
    occurrence is used.  An occurrence that would cut through a trailing
    ``'</w>'`` marker (e.g. a plain ``'w'`` token inside it) is rejected.
    Returns ``None`` when no token occurs.
    """
    marker = '</w>'
    length = len(string)
    # Start index of a trailing end-of-word marker; `length` means "none".
    marker_start = length - len(marker) if string.endswith(marker) else length
    for token in self.sorted_tokens:
      if not token:
        # An empty token would match everywhere and recurse forever.
        continue
      start = string.find(token)
      while start != -1:
        end = start + len(token)
        # Accept spans wholly before the marker, or spans that include the
        # whole marker by ending exactly at the end of the string.
        if end <= marker_start or (start <= marker_start and end == length):
          return start, end
        start = string.find(token, start + 1)
    return None

  def tokenize_word(self, string, unknown_token='</u>'):
    """Segment ``string`` into vocabulary tokens, greedily and recursively.

    The highest-ranked token found in ``string`` becomes one piece; the text
    to its left and right is tokenized the same way.  Non-empty text that
    contains no known token becomes ``unknown_token``.

    Returns:
      The tokens joined by single spaces (``''`` for an empty ``string``).
    """
    if string == "":
      return ''
    if not self.sorted_tokens:
      return unknown_token
    span = self._find_token_span(string)
    if span is None:
      # No vocabulary token occurs anywhere in this text.
      return unknown_token
    start, end = span
    pieces = (
        self.tokenize_word(string[:start], unknown_token),
        string[start:end],
        self.tokenize_word(string[end:], unknown_token),
    )
    return ' '.join(piece for piece in pieces if piece)

  def tokenize_sentence(self, sentence, unknown_token='</u>'):
    """Tokenize each whitespace-separated word of ``sentence``.

    ``'</w>'`` is appended to every word before calling ``tokenize_word``.

    Returns:
      1-D numpy array of token strings (empty when there are no words).
    """
    tokens = []
    for word in sentence.split():
      tokens.extend(self.tokenize_word(word + '</w>', unknown_token).split())
    return np.array(tokens, dtype=str)

In [None]:
# Training configuration for WordPiece: corpus path, maximum number of merge
# iterations, and the minimum occurrence count handed to the tokenizer.
path = '/content/sample.txt'
n_merges = 500
min_ouccurance_count = 2
# NOTE: despite its name, `byte_pair_encoding` holds a WordPiece instance.
byte_pair_encoding = WordPiece(n_merges, path, min_ouccurance_count)

In [None]:
# Learned vocabulary is sorted longest-first; preview it instead of dumping every token.
byte_pair_encoding.sorted_tokens[:20]

In [None]:
# tokenize_word expects the '</w>' end-of-word marker to be appended by the caller.
byte_pair_encoding.tokenize_word(string="Cambridge</w>")

' Cambridg  e  </w> '

In [None]:
# tokenize_sentence splits on whitespace and appends '</w>' to each word itself.
byte_pair_encoding.tokenize_sentence(sentence="this is a notes")

array(['', 't', '', 'his</w>', '', 'is</w>', '', 'a', '', '</w>', '', 'n',
       '', 'o', '', 't', '', 'e', '', 's</w>', ''], dtype='<U7')