In [1]:
!git clone https://github.com/karpathy/mingpt.git
%cd mingpt

Cloning into 'mingpt'...
remote: Enumerating objects: 175, done.[K
remote: Total 175 (delta 0), reused 0 (delta 0), pack-reused 175[K
Receiving objects: 100% (175/175), 1.37 MiB | 5.87 MiB/s, done.
Resolving deltas: 100% (101/101), done.
/content/mingpt


In [2]:
!nvidia-smi

NVIDIA-SMI has failed because it couldn't communicate with the NVIDIA driver. Make sure that the latest NVIDIA driver is installed and running.



In [3]:

import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [4]:
from mingpt.utils import set_seed
set_seed(42)

In [5]:
import urllib
import re
import random
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from collections import OrderedDict, Counter

In [7]:
!gdown --id 1oAsRKCXzzzvhcElawXfjv1CBUo_Kr8g_
!gdown --id 1YRW6j_Gp0Xxkad-fHNzyfECxssYAsmRE

Downloading...
From: https://drive.google.com/uc?id=1oAsRKCXzzzvhcElawXfjv1CBUo_Kr8g_
To: /content/mingpt/cmudict-0.7b
3.72MB [00:00, 118MB/s]
Downloading...
From: https://drive.google.com/uc?id=1YRW6j_Gp0Xxkad-fHNzyfECxssYAsmRE
To: /content/mingpt/cmudict.symbols
100% 281/281 [00:00<00:00, 476kB/s]


In [8]:
IS_KAGGLE = False
import os 

CMU_DICT_PATH = os.path.join(
    '.', 'cmudict-0.7b')
CMU_SYMBOLS_PATH = os.path.join(
    '.','cmudict.symbols')

# Skip words with numbers or symbols
ILLEGAL_CHAR_REGEX = "[^A-Z-'.]"

# Only 3 words are longer than 20 chars
# Setting a limit now simplifies training our model later
MAX_DICT_WORD_LEN = 20
MIN_DICT_WORD_LEN = 2


def load_clean_phonetic_dictionary():

    def is_alternate_pho_spelling(word):
        # No word has > 9 alternate pronounciations so this is safe
        return word[-1] == ')' and word[-3] == '(' and word[-2].isdigit() 

    def should_skip(word):
        if not word[0].isalpha():  # skip symbols
            return True
        if word[-1] == '.':  # skip abbreviations
            return True
        if re.search(ILLEGAL_CHAR_REGEX, word):
            return True
        if len(word) > MAX_DICT_WORD_LEN:
            return True
        if len(word) < MIN_DICT_WORD_LEN:
            return True
        return False

    phonetic_dict = {}
    with open(CMU_DICT_PATH, encoding="ISO-8859-1") as cmu_dict:
        for line in cmu_dict:

            # Skip commented lines
            if line[0:3] == ';;;':
                continue

            word, phonetic = line.strip().split('  ')

            # Alternate pronounciations are formatted: "WORD(#)  F AH0 N EH1 T IH0 K"
            # We don't want to the "(#)" considered as part of the word
            if is_alternate_pho_spelling(word):
                word = word[:word.find('(')]

            if should_skip(word):
                continue

            if word not in phonetic_dict:
                phonetic_dict[word] = []
            phonetic_dict[word].append(phonetic)

    if IS_KAGGLE: # limit dataset to 5,000 words
        phonetic_dict = {key:phonetic_dict[key] 
                         for key in random.sample(list(phonetic_dict.keys()), 5000)}
    return phonetic_dict

phonetic_dict = load_clean_phonetic_dictionary()
example_count = np.sum([len(prons) for _, prons in phonetic_dict.items()])

In [9]:
phonetic_dict[' '] = ' '
print("\n".join([k+' --> '+phonetic_dict[k][0] for k in random.sample(list(phonetic_dict.keys()), 10)]))
print('\nAfter cleaning, the dictionary contains %s words and %s pronunciations (%s are alternate pronunciations).' % 
      (len(phonetic_dict), example_count, (example_count-len(phonetic_dict))))

PEDIGREES --> P EH1 D AH0 G R IY0 Z
BRUNCH --> B R AH1 N CH
AMERIFIRST --> AH0 M EH1 R IH0 F ER0 S T
SALE'S --> S EY1 L Z
ESTHETICS --> EH0 S TH EH1 T IH0 K S
DRIVES --> D R AY1 V Z
DEXTROSE --> D EH1 K S T R OW0 S
CELONA --> CH EH0 L OW1 N AH0
RUNNELLS --> R AH1 N AH0 L Z
BRANDTNER --> B R AE1 N T N ER0

After cleaning, the dictionary contains 124815 words and 133569 pronunciations (8754 are alternate pronunciations).


In [10]:
data = ''.join(x.lower() + '\n' for x in phonetic_dict)
targets = ''.join(re.sub('\d', '',phonetic_dict[x][0]) + '\n' for x in phonetic_dict)

In [11]:
class Tokenizer:
    def __init__(self, data, vocab_size):
        self.vocab_size = vocab_size
        self.vocab = self.build_vocab(data)
        
        self.stoi = { ch:i for i,ch in enumerate(self.vocab) }
        self.itos = { i:ch for i,ch in enumerate(self.vocab) }
    
    def sort_vocab(self, vocab):
        """
        Vocab should have the followind order: hashtag, numbers, characters sorted by length.
        Hashtags should go first, because they will be used as dividers on tokenization step.
        Numbers should go before characters, because token ids are numbers. Otherwise token ids will be considered as usual numbers and replaced twice.
        """
        sorted_vocab = sorted(vocab, key=lambda x: len(x), reverse=True)
        tag = [int(s) for s in sorted_vocab if s == '#']
        
        numeric = [int(s) for s in sorted_vocab if s.isnumeric()]
        numeric = [str(s) for s in sorted(numeric, reverse=True)]
        rest = [s for s in sorted_vocab if not s.isnumeric()]
        
        sorted_vocab = tag + numeric + rest
        
        return sorted_vocab
    
    def build_vocab(self, data):
        """
        Build vocabluary using BPE alghorithm.
        """
        vocab = set(data)
        if len(vocab) > self.vocab_size:
            raise ValueError('Vocab size should be greater than unique char count')

        # check all available characters
        char_set = {c for c in vocab if c.isalpha()}
        
        # candidates dictionary will contain a set of all available tokens to search
        candidate_dict = dict().fromkeys(char_set, 0)
        
        # occurrences will contain all matched tokens and the count, how many times the token has been found.
        token_occurrences = OrderedDict()
        while len(vocab) < self.vocab_size:
            for candidate in candidate_dict.keys():
                occurrences = data.count(candidate)
                candidate_dict[candidate] = occurrences

            candidate_dict = {candidate: count for candidate, count in candidate_dict.items() if count}
            vocab.update(set(candidate_dict.keys()))
            token_occurrences.update(candidate_dict)

            # build new candidates
            temp_candidate_set = set()
            for char in char_set:
                # don't test candidates with occurency <= 2. New candidates won't have occurency higher than 2
                temp_candidate_set.update({candidate + char for candidate in candidate_dict.keys() if token_occurrences[candidate] > 2})

            candidate_dict = dict().fromkeys(temp_candidate_set, 0)

        tokens_to_remove = len(vocab) - self.vocab_size
        token_occurrences = OrderedDict(sorted(token_occurrences.items(), key=lambda x: x[1], reverse=True))
        for _ in range(tokens_to_remove):
            token, _ = token_occurrences.popitem()
            vocab.remove(token)

        sorted_vocab = self.sort_vocab(vocab)
        
        # add a special token for unknown tokens
        sorted_vocab.append('<unk>')
        self.vocab_size += 1 # plus <unk> special token
        
        return sorted_vocab
    
    def tokenize(self, data, block_size):
        for token in self.vocab:
            data = data.replace(token, f'#{self.stoi[token]}#')

        # If everything went well, first and last characters won't have # pair. Need to trim them
        data = data[1:-1]
        # Split by ## pairs
        tokenized_text = data.split('##')
        # Filter empty strings
        tokenized_text = [x for x in tokenized_text if x]
        result = []
        for tokenized in tokenized_text:
            # In case other single # found, replace them with <unk> special token, marking the element as unknown
            if '#' in tokenized:
                for unknown_candidate in tokenized.split('#'):
                    if unknown_candidate.isnumeric():
                        result.append(self.itos[int(unknown_candidate)])
                    else:
                        result.append('<unk>')
            else:
                result.append(self.itos[int(tokenized)])

        # all texts should have equal size. We can make text length equal by filling text with spaces
        for _ in range(block_size - len(result)):
            result.append(' ')
            
        # in case the sentence is longer, than block_size, we trim the sentence
        return result[:block_size]
    
    def encode(self, data):
        return [self.stoi[s] for s in data]
    
    def decode(self, data):
        return ''.join([self.itos[int(i)] for i in data])

In [12]:
vocab_size = 49
# building vocabluary can take some time. ~5 minutes for 10_000 tokens for each tokenizer. 
tokenizer_data = Tokenizer(data, vocab_size)
tokenizer_targets = Tokenizer(targets, vocab_size)

In [23]:
tokenizer_data.encode(" ")l,

[46]

In [13]:
torch.save(tokenizer_data,'../tok_data.pt')
torch.save(tokenizer_targets,'../tok_tar.pt')

In [17]:
tokenizer_data = torch.load('../tok_data.pt')
tokenizer_targets = torch.load('../tok_tar.pt')

In [21]:
tokenizer_data.encode(" ")

[46]

In [25]:
import math
from torch.utils.data import Dataset

class Dataset(Dataset):

    def __init__(self, original, modern, tokenizer_targets, tokenizer_data, block_size):
        self.tokenizer_targets = tokenizer_targets
        self.tokenizer_data = tokenizer_data
        
        self.block_size = block_size * 2
        self.original = [tokenizer_targets.tokenize(t, block_size) for t in original]
        self.modern = [tokenizer_data.tokenize(t, block_size) for t in modern]
    
    def __len__(self):
        return len(self.original)

    def __getitem__(self, idx):
        """
        The idea is to get a sentence in a modern English
        and translate it to Shakespeare English.
        
        In the init method we already split a sentence into tokens and filled with spaces,
        to have an equal sentence size. In this method we just encode the tokens to
        ids (a list of numbers), and we're trying to map ids sequences
        (original Englisn and modern English)
        """
        
        modern_text = self.tokenizer_data.encode(self.modern[idx])
        original_text = self.tokenizer_targets.encode(self.original[idx])
        dix = modern_text + original_text
        
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        y[:int(self.block_size / 2) - 1] = -100
        
        return x, y

In [26]:

# Shuffle texts by lines
texts = list(zip(data.splitlines(), targets.splitlines()))
random.shuffle(texts)

data, targets = zip(*texts)

In [27]:

# Split texts into train, test and validation datasets
train_dataset_size = round(0.9 * len(data))
test_dataset_size = round(0.1 * len(data))

train_data = data[:train_dataset_size]
test_modern = data[train_dataset_size:train_dataset_size + test_dataset_size]

train_targets = targets[:train_dataset_size]
test_original = targets[train_dataset_size:train_dataset_size + test_dataset_size]


In [28]:

block_size = 100  # the estimate how long lines the text could be (token count)

train_dataset = Dataset(train_targets, train_data, tokenizer_targets, tokenizer_data, block_size)
test_dataset = Dataset(test_original, test_modern, tokenizer_targets, tokenizer_data, block_size)

In [29]:
from mingpt.model import GPT, GPTConfig
mconf = GPTConfig(tokenizer_targets.vocab_size, train_dataset.block_size,
                  n_layer=2, n_head=4, n_embd=512)
model = GPT(mconf)

05/02/2021 13:55:05 - INFO - mingpt.model -   number of parameters: 6.459392e+06


In [None]:
from mingpt.trainer import Trainer, TrainerConfig

tokens_per_epoch = len(train_dataset) * block_size
train_epochs = 20

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=train_epochs, batch_size=64, learning_rate=3e-4,
                      lr_decay=True, warmup_tokens=tokens_per_epoch, final_tokens=train_epochs*tokens_per_epoch,
                      num_workers=2)
trainer = Trainer(model, train_dataset, test_dataset, tconf)
trainer.train()

In [34]:
torch.save(model.state_dict(),'../model.pt')

In [35]:
def predict(context):
  x = torch.tensor(tokenizer_data.encode(tokenizer_data.tokenize(context, block_size)), dtype=torch.long)[None,...].to(trainer.device)
  y = sample(model, x, block_size, temperature=1.0, sample=True, top_k=10)[0]

  predicted = y[block_size:]
  return tokenizer_targets.decode(predicted)

In [45]:
from mingpt.utils import sample
from random import choice
model.to('cuda:0')
for _ in range(5):
    idx = choice(range(len(test_original)))

    context = test_modern[idx]
    
    print(f'Word:                    {context}')
    print(f'Predicted pronouncation: {predict(context)}')
    print(f'Real pronunciation:      {test_original[idx]}')
    print('--------------------------------------------------')

Word:                    consumers'
Predicted pronouncation: K AH N S UW M ER Z                                                                                     
Real pronunciation:      K AH N S UW M ER Z
--------------------------------------------------
Word:                    zsa
Predicted pronouncation: SH AA                                                                                                 
Real pronunciation:      ZH AA
--------------------------------------------------
Word:                    backbiting
Predicted pronouncation: B AE K B AY T IH NG                                                                                     
Real pronunciation:      B AE K B AY T IH NG
--------------------------------------------------
Word:                    gurion
Predicted pronouncation: G Y UH R IY AH N                                                                                       
Real pronunciation:      G Y UH R IY AH N
------------------------------------

In [42]:
print(predict('aditya'))
print(predict('vishal'))

AH D IH T Y AH                                                                                         
V IH SH AH L                                                                                           
