In [None]:
# Notebook configuration: fill these in before running the remaining cells.
BASE_DIR = '' # Working directory; MODELS_DIR is built by plain concatenation, so leave empty or end with '/'
MODELS_DIR = f'{BASE_DIR}models/' # Models directory (the pickled unigram model is written below it)
EXPACE_DIR = '' # ExPACE corpus folder (available in the supplementary data); include the trailing '/'

In [None]:
import os
import pickle
import re
from collections import defaultdict
from os import listdir
from random import shuffle

In [None]:
def read_file(file):
    """Return the complete text of ``file``, decoded as UTF-8."""
    with open(file, encoding="utf-8") as handle:
        text = handle.read()
    return text
    
# Collect the corpus files. Only names that END with ".txt" are kept (a substring
# test would also pick up e.g. "notes.txt.bak"), and os.path.join builds a valid
# path whether or not EXPACE_DIR carries a trailing separator.
expace_files = [
    os.path.join(EXPACE_DIR, file)
    for file in listdir(EXPACE_DIR)
    if file.endswith('.txt')
]
# Randomise the processing order (in place).
shuffle(expace_files)

In [None]:
from transformers import GPT2TokenizerFast
# Load the fast GPT-2 tokenizer for the 'gpt2' checkpoint; it is handed to
# UnigramModel further down to split sentences into subword tokens.
gpt2_tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')

In [None]:
class TokenizerWrapper:
    """Wrap a tokenizer so that out-of-vocabulary text survives an encode/decode round trip.

    Whitespace-separated tokens containing a character that is not a key of the
    tokenizer vocabulary are swapped for ``[KEEP<n>]`` placeholders before
    encoding, and the placeholders are swapped back after decoding. When
    ``config['keep_parentheses']`` is truthy, parenthesised spans are protected
    the same way.
    """

    # Numeric bracket markers such as "[12]" that may already be present in a sentence.
    _BRACKET_NUMBER_RE = re.compile(r'\[([0-9]+)\]')
    # One or more adjacent parenthesised groups, e.g. "(a)" or "(a)(b)".
    _PARENTHESES_RE = re.compile(r'(\(.*?\))+')

    def __init__(self, tokenizer, config=None, device='cuda'):
        """Store the tokenizer and options.

        Args:
            tokenizer: object exposing ``get_vocab``, ``encode``, ``decode`` and ``__call__``.
            config: optional dict of options; only ``'keep_parentheses'`` is read here.
            device: device the encoded tensors are moved to (defaults to ``'cuda'``).
        """
        self.device = device
        self.tokenizer = tokenizer
        # A fresh dict per instance: a mutable default argument would be shared by all instances.
        self.config = {} if config is None else config
        # Snapshot of the vocabulary entries, used for membership tests only.
        self.vocab = set(tokenizer.get_vocab())

    def encode_sentence(self, sent):
        """Encode one sentence; return ``(input_ids, placeholder_map)``."""
        sent_special, map_tokens = self.add_special_tokens(sent)
        inputs_ids = self.tokenizer.encode(sent_special, return_tensors='pt').to(device=self.device)

        return inputs_ids, map_tokens

    def get_special_sentences(self, sentences):
        """Apply ``add_special_tokens`` to each sentence.

        Returns:
            ``(special_sentences, list_map_tokens)``: two lists aligned with ``sentences``.
        """
        special_sentences = []
        list_map_tokens = []
        for sent in sentences:
            sent_special, map_tokens = self.add_special_tokens(sent)
            special_sentences.append(sent_special)
            list_map_tokens.append(map_tokens)

        return special_sentences, list_map_tokens

    def tokenize(self, sentence):
        """Return the placeholder-substituted form of ``sentence``.

        A single string is treated as a one-element batch; any other iterable is
        treated as a batch of sentences. The result has the same shape as
        ``get_special_sentences``.

        NOTE(review): the previous body referenced an undefined name and returned
        nothing, so the intended return value is an assumption; confirm with callers.
        """
        sentences = [sentence] if isinstance(sentence, str) else list(sentence)
        return self.get_special_sentences(sentences)

    def encode_batch_inference(self, sentences):
        """Encode a batch as padded 'pt' tensors moved to ``self.device``.

        Returns:
            ``(encodings, list_map_tokens)``.
        """
        special_sentences, map_tokens = self.get_special_sentences(sentences)
        encodings_input = self.tokenizer(
            special_sentences, truncation=True, padding='longest', return_tensors='pt'
        ).to(self.device)

        return encodings_input, map_tokens

    def decode_batch(self, outputs, map_tokens, original_sentences):
        """Decode each output with its own placeholder map and original sentence."""
        return [
            self.decode_sentence(output, map_token, original)
            for output, map_token, original in zip(outputs, map_tokens, original_sentences)
        ]

    def encode_batch_training(self, sentences):
        """Encode a batch with padding, without tensor conversion or device transfer.

        Returns:
            ``(encodings, list_map_tokens)``.
        """
        special_sentences, map_tokens = self.get_special_sentences(sentences)
        encodings_input = self.tokenizer(special_sentences, truncation=True, padding='longest')

        return encodings_input, map_tokens

    def decode_sentence(self, output, map_tokens, original=None):
        """Decode ``output`` and restore the text hidden behind placeholders."""
        sentence = self.tokenizer.decode(output, skip_special_tokens=True)
        return self.replace_special_tokens(sentence, map_tokens, original)

    def in_vocab(self, char):
        """Return True when ``char`` is an entry of the tokenizer vocabulary."""
        return char in self.vocab

    def has_oov(self, sent):
        """Return True when any non-space character of ``sent`` is out of vocabulary."""
        return any(not self.in_vocab(char) for char in set(sent) if char != ' ')

    def add_special_tokens(self, sent):
        """Replace unencodable tokens (and optionally parentheses) with placeholders.

        Placeholder numbering starts after the largest ``[<number>]`` marker
        already present in ``sent``. Whitespace is normalised to single spaces
        because the sentence is split and re-joined.

        Returns:
            ``(updated_sentence, keep_to_tokens)`` where ``keep_to_tokens`` maps each
            placeholder to the text it replaced, in creation order.
        """
        # Compare the markers numerically: a string sort would rank "[9]" above "[10]".
        start_index = max(
            (int(number) for number in self._BRACKET_NUMBER_RE.findall(sent)),
            default=0,
        )
        keep_to_tokens = {}

        def next_placeholder():
            nonlocal start_index
            start_index += 1
            return f'[KEEP{start_index}]'

        def convert_token(token):
            if all(self.in_vocab(char) for char in token):
                return token
            placeholder = next_placeholder()
            last_char = token[-1]
            if last_char == ',' or last_char == '.':
                # Keep trailing punctuation outside the placeholder.
                keep_to_tokens[placeholder] = token[:-1]
                return placeholder + last_char
            keep_to_tokens[placeholder] = token
            return placeholder

        updated_sent = ' '.join(convert_token(token) for token in sent.split())

        if self.config.get('keep_parentheses'):
            def stash_parentheses(match):
                placeholder = next_placeholder()
                keep_to_tokens[placeholder] = match[0]
                return placeholder

            # Placeholders contain no parentheses, so a single left-to-right pass is enough.
            updated_sent = self._PARENTHESES_RE.sub(stash_parentheses, updated_sent)

        return updated_sent, keep_to_tokens

    def replace_special_tokens(self, sentence, keep_to_tokens, original):
        """Substitute placeholders in ``sentence`` with the text they stand for.

        If a placeholder is missing from ``sentence`` and ``original`` is truthy,
        ``original`` is returned unchanged as a fallback.
        """
        # Newest placeholder first: a parenthesised span is stored after the
        # per-token placeholders it may contain, so expanding it first makes those
        # inner placeholders visible instead of being reported as missing.
        for key in reversed(list(keep_to_tokens)):
            if key not in sentence and original:
                return original
            sentence = sentence.replace(key, keep_to_tokens[key])
        return sentence
    
# Wrap the GPT-2 tokenizer loaded above (no other tokenizer is defined in this notebook).
tokenizer_wrapper = TokenizerWrapper(gpt2_tokenizer)

In [None]:
from nltk import word_tokenize, sent_tokenize
from tqdm import tqdm
from functools import reduce

class UnigramModel():
    """Unigram frequency model over lower-cased subword tokens.

    Attributes are filled by ``build``:
      * ``unigram_model``: token -> occurrence count
      * ``word_count``: total number of counted tokens
      * ``vocab_count``: number of distinct tokens
    """

    # Sentences longer than this many characters are skipped while building.
    MAX_SENT_CHARS = 1024

    def __init__(self, tokenizer):
        """Create an empty model that splits text with ``tokenizer.tokenize``."""
        self.unigram_model = defaultdict(int)
        self.word_count = 0
        self.vocab_count = 0
        self.tokenizer = tokenizer

    def tokenize_subwords(self, sent):
        """Split ``sent`` into subword strings with the word-boundary marker removed."""
        # This is the special token use by GPT-2 to separate words
        return [item.replace('Ġ', '') for item in self.tokenizer.tokenize(sent)]

    def build(self, files, limit=None):
        """Count subword tokens over ``files`` (only the first ``limit`` when given).

        Counts accumulate into ``unigram_model``; the totals are recomputed at the end.
        """
        for file in tqdm(files[:limit]):
            content = read_file(file)
            for sent in sent_tokenize(content):
                if len(sent) > self.MAX_SENT_CHARS:
                    continue
                for word in self.tokenize_subwords(sent):
                    self.unigram_model[word.lower()] += 1

        self.word_count = sum(self.unigram_model.values())
        self.vocab_count = len(self.unigram_model)

    def prob(self, words):
        """Return the product of ``prob_token`` over the subword tokens of ``words``.

        Tokens are lower-cased exactly as in ``build``, so capitalised input is
        matched against the stored counts. Text that yields no tokens returns
        1.0 (the empty product) instead of raising.
        """
        token_probs = (self.prob_token(sub.lower()) for sub in self.tokenize_subwords(words))
        return reduce(lambda acc, value: acc * value, token_probs, 1.0)

    def prob_token(self, token):
        """Return ``(count(token) + 1) / (word_count + 1)`` for one token."""
        # .get() avoids the defaultdict side effect of storing a zero entry for every unseen token.
        return (self.unigram_model.get(token, 0) + 1) / (self.word_count + 1)

In [None]:
# Empty unigram model that tokenizes with the GPT-2 tokenizer; counts are filled in by build().
uni_model = UnigramModel(gpt2_tokenizer)

In [None]:
# Count subword frequencies over the ExPACE files; no limit is passed, so every listed file is read.
uni_model.build(expace_files)

In [None]:
# Persist the fitted model. The target folder is created first so the cell also
# works on a fresh checkout, and os.path.join avoids the doubled '/' that plain
# concatenation produced (MODELS_DIR already ends with a separator).
unigram_model_dir = os.path.join(MODELS_DIR, 'unigram-expace')
os.makedirs(unigram_model_dir, exist_ok=True)
with open(os.path.join(unigram_model_dir, 'model-v2.pkl'), 'wb') as f:
    pickle.dump(uni_model, f)