<a href="https://colab.research.google.com/github/deterministic-algorithms-lab/Speech-Explorations/blob/master/Wikipron/g2p_seq2seq_attn_lstm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Installing and Importing Packages

In [None]:
!pip install git+https://github.com/deepmind/dm-haiku
!pip install git+git://github.com/deepmind/optax.git
!pip install -q -U trax

In [1]:
import haiku as hk
import jax.numpy as jnp
import jax

In [2]:
lg = 'eng' 

In [3]:
config = {
    #data params
    'lang' : lg,
    'n_epochs' : 500,
    'batch_size' : 4,
    'data_files' : [f'{lg}.tsv'],
    'input_vocab_size' : None,
    'output_vocab_size' : None, 
    
    #optimizer params
    'max_grad_norm' : 1,
    'learning_rate' : 1e-4,

    #Model Params
    'd_model' : 512,
}

# Getting Data

In [None]:
!pip install wikipron

In [7]:
#!wikipron {lg} > {lg}.tsv
!shuf {lg}.tsv > shuf_{lg}.tsv
!rm {lg}.tsv
!mv shuf_{lg}.tsv {lg}.tsv

In [4]:
import pandas as pd
import functools

class tsv_loader:
    
    def __init__(self, config):
       self.config = config
       self.df = pd.read_csv(f'{lg}.tsv', sep='\t', names=['graphemes', 'phonemes'])
       self.setup_data()
       self.unique = {}
       self.unique['graphemes'] = self.get_unique(key='graphemes')
       self.unique['phonemes'] = self.get_unique(key='phonemes')
       self.token_to_id = {}
       self.id_to_token = {}
       self.generate_encodings(key='graphemes')
       self.generate_encodings(key='phonemes')
       self.set_config()
       
    def setup_data(self):
        
        def split_func(x):
            x['graphemes'] = [char for char in x['graphemes'].strip()]
            x['phonemes'] = x['phonemes'].strip().split()

        self.df.apply(split_func, axis=1)

    def get_unique(self, key):
        return functools.reduce(lambda a,b: set(a).union(set(b)), self.df[key].to_list(), set([]))
    
    def generate_encodings(self, key):
        self.token_to_id[key]= {}
        self.token_to_id[key]['<pad>'] = len(self.token_to_id[key])
        self.token_to_id[key]['<s>'] = len(self.token_to_id[key])
        self.token_to_id[key]['</s>'] = len(self.token_to_id[key])

        self.token_to_id[key].update( { token : id+3 for id, token in enumerate(list(self.unique[key])) } )
        self.id_to_token[key] = { id : token for token, id in self.token_to_id[key].items() }

    def set_config(self):
        self.config['max_length'] = {}
        self.config['max_length']['graphemes'] = functools.reduce(lambda a,b : max(a,len(b)), self.df['graphemes'], 0) + 2
        self.config['max_length']['phonemes'] = functools.reduce(lambda a,b : max(a,len(b)), self.df['phonemes'], 0) + 2
        self.config['input_vocab_size'] = len(self.token_to_id['graphemes'])
        self.config['output_vocab_size'] = len(self.token_to_id['phonemes'])
        
    def generate_batches(self):
        i = 0
        while i<len(self.df)//self.config['batch_size']:
            graphemes_batch = self.df['graphemes'][i:i+self.config['batch_size']]
            phonemes_batch = self.df['phonemes'][i:i+self.config['batch_size']]
            i+=1
            yield graphemes_batch, phonemes_batch
    
    def encode_batch(self, batch, key='graphemes', add_eos=True):
        padded_batch = []
        
        for elem in batch:
            padded_batch.append( ['<s>'] + elem + (['</s>'] if add_eos else [])
                                +['<pad>']
                                *(self.config['max_length'][key]-len(elem)) 
                               )
        return    jnp.asarray( [ [self.token_to_id[key][token] 
                                  for token in elem] 
                                  for elem in padded_batch], dtype=jnp.int16)
        
    def decode_batch(self, batch, key='phonemes'):
        decoded_batch = []
        for elem in batch :
            decoded_seq = []
            for id in elem:
                if id==self.token_to_id[key]['</s>']:
                    break
                if id!=self.token_to_id[key]['<pad>'] and id!=self.token_to_id[key]['<s>']:
                    decoded_seq.append(self.id_to_token[key][int(id)])
            decoded_batch.append(decoded_seq)
        return decoded_batch

In [5]:
data_loader = tsv_loader(config)
n_examples = len(data_loader.df)

# Trax LSTM

In [6]:
from trax.models.rnn import LSTMSeq2SeqAttn
from trax.layers.metrics import CategoryCrossEntropy
from trax import shapes

In [7]:
lstm = LSTMSeq2SeqAttn(input_vocab_size=config['input_vocab_size'],
                       target_vocab_size=config['output_vocab_size'],
                       d_model=config['d_model'])

In [8]:
grapheme_batch, phoneme_batch = next(data_loader.generate_batches())
input = data_loader.encode_batch(grapheme_batch)
labels = data_loader.encode_batch(phoneme_batch, key='phonemes')

In [9]:
lstm.init_weights_and_state((shapes.signature(input), shapes.signature(labels)))

## Loss and Update Functions

In [10]:
def loss_fn(params, input, labels, mask):
    prev_weights = lstm.weights
    lstm.weights = params
    logits, labels = lstm((input, labels))
    logits, labels = logits[:,:-1,:], labels[:, 1:]
    logits = jax.vmap(jnp.multiply, in_axes=(2, None), out_axes=2)(logits, labels!=0)
    loss = CategoryCrossEntropy()((logits, labels))
    lstm.weights = prev_weights
    return loss

@jax.jit
def update(params, opt_state, input, labels, mask):
    batch_loss, grads = jax.value_and_grad(loss_fn)(params, input, labels, mask)
    updates, opt_state = opt.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, batch_loss

params = lstm.weights

# Optimizer

In [11]:
import optax
import numpy as np

In [12]:
def make_lr_schedule(warmup_percentage, total_steps):
    
    def lr_schedule(step):
        percent_complete = step/total_steps
        
        #0 or 1 based on whether we are before peak
        before_peak = jax.lax.convert_element_type((percent_complete<=warmup_percentage),
                                                   np.float32)
        #Factor for scaling learning rate
        scale = ( before_peak*(percent_complete/warmup_percentage)
                + (1-before_peak) ) * (1-percent_complete)
        
        return scale
    
    return lr_schedule

In [13]:
total_steps = config['n_epochs']*(n_examples//config['batch_size'])

lr_schedule = make_lr_schedule(warmup_percentage=0.1, total_steps=total_steps)

In [14]:
opt = optax.chain(
    optax.clip_by_global_norm(config['max_grad_norm']),
    optax.adam(learning_rate=config['learning_rate']),
    optax.scale_by_schedule(lr_schedule),
)
opt_state = opt.init(params)

# Training Loop

In [15]:
for _ in range(config['n_epochs']):
    losses = []
    i=0
    for grapheme_batch, phoneme_batch in data_loader.generate_batches():
        input = data_loader.encode_batch(grapheme_batch)
        labels = data_loader.encode_batch(phoneme_batch, key='phonemes')
        mask = list(labels!=0)
        params, opt_state, batch_loss = update(params, opt_state, input, labels, mask)
        losses.append(batch_loss)
        if i%100==0:
            print(sum(losses)/100)
            losses = []
        i+=1

0.04482197
0.044597864
0.043900125
0.042742714
0.04131021
0.040239327
0.0397346
0.039486874
0.03934227
0.039235782
0.039152663
0.03908205
0.038992964
0.03889775
0.038761962
0.038601343
0.03842967
0.03836578
0.038281117
0.038347516
0.038225465
0.038306188
0.03782034
0.037672333
0.03761198
0.037461594
0.037201207
0.037030663
0.036825392
0.0365864
0.03643137
0.036413636
0.03621721
0.036176622
0.036547292
0.036158137
0.035956524
0.036052078
0.03617616
0.03566277
0.035829414
0.03896674
0.03678335
0.039515298
0.037940636
0.03844556
0.036840256
0.03828436
0.037817996
0.036145147
0.036238234
0.03681084
0.03743762
0.037317835
0.035944905
0.035325892
0.03505342
0.035487898
0.034950305
0.03482866
0.034202993
0.03464408
0.03421261
0.033963185
0.03424811
0.034192257
0.034341436
0.03404871
0.03528464
0.03434643
0.033301167
0.03433478
0.03333162
0.03420208
0.03343512
0.03306187
0.034235395
0.033275098
0.033736177
0.032264438
0.03257799
0.0328663
0.031970326
0.034451813
0.03183285
0.03180996
0.0319322

# Generating Phonemes

In [16]:
lstm = LSTMSeq2SeqAttn(input_vocab_size=config['input_vocab_size'],
                       target_vocab_size=config['output_vocab_size'],
                        )

lstm.init_weights_and_state((shapes.signature(input), shapes.signature(labels)))

lstm.weights = params

graphemes = [['a', 'b', 'a', 'c', 'k']]
phonemes = [[]]

grapheme_token_ids = data_loader.encode_batch(graphemes)
phoneme_token_ids = data_loader.encode_batch(phonemes, key='phonemes', add_eos=False)

for i in range(config['max_length']['phonemes']-1):
    
    logits = lstm((grapheme_token_ids, 
                              phoneme_token_ids))[0]
    
    preds = jnp.argmax(logits[:,i,:], axis=-1)
    
    phoneme_token_ids = jax.ops.index_update(phoneme_token_ids, 
                                             jax.ops.index[:,i+1],
                                             preds)
    
print(phoneme_token_ids)


[[ 1 66 40 20 26  2  2  2  2  2  2  2  2  2  2  2  0]]


In [17]:
print(data_loader.decode_batch(phoneme_token_ids))

[['ə', 'ˈb', 'æ', 'k']]
