In [1]:
'''
Changelog

-  (4/2/2023)
   * Implemented global attention for producing grapheme/phoneme embedding. The previous
     approach (concatenate all timesteps into a single vector + feedforward) was
     memory-intensive and risked contaminating the embeddings w/ 
     padding.
   * Removed some redundant layers in both encoder and decoder
   * _generate_ still doesn't work; need to discuss tokenizer w/ group

   --- Luke 
''';

'''
TODO:
Spin up Spot Instance with hosted notebook: Nathan
Training Loop must be completed: Luke
Finish the Generate Function
Tensorboard probably to be used in some capacity: Nathan/Luke in the future

--- Nathan
'''

'\nTODO:\nSpin up Spot Instance with hosted notebook: Nathan\nTraining Loop must be completed: Luke\nFinish the Generate Function\nTensorboard probably to be used in some capacity: Nathan/Luke in the future\n\n--- Nathan\n'

### Get Data

Link the data into the `/content` folder. The data lives in our shared ConnTextUL folder; you may need to move that shared folder so it sits in your Drive at the same full path used below. We may devise a more efficient way to share data in the future.



In [2]:
# Mount Google Drive at /gdrive (unmounting any previous mount first) and
# symlink the shared data folder into /content, giving /content/data.
# NOTE(review): GraphoneDataset reads "/gdrive/MyDrive/data/orth_phon_mappings/...",
# which is neither this symlink nor the "My Drive/Projects/..." path — confirm
# which layout is intended.
from google.colab import drive
drive.flush_and_unmount()
drive.mount('/gdrive', force_remount=True)
!ln -s "/gdrive/My Drive/Projects/Modeling Reading Programs/ConnTextUL/data" "/content"

Drive not mounted, so nothing to flush and unmount.
Mounted at /gdrive


Download and install the Hugging Face `transformers` package, which provides `CanineTokenizer`.

In [3]:
#https://pypi.org/project/transformers/
!pip install transformers
from transformers import CanineTokenizer

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.27.4-py3-none-any.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m47.2 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0
  Downloading huggingface_hub-0.13.3-py3-none-any.whl (199 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m199.8/199.8 KB[0m [31m21.7 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m53.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.13.3 tokenizers-0.13.3 transformers-4.27.4


In [213]:
from torch.utils.data import Dataset
import pandas as pd
import torch as pt
import numpy as np

In [226]:
class CUDA_Dict(dict):
    """A ``dict`` of tensors that can be moved to a device with one call.

    ``CharacterTokenizer.encode`` returns one of these so a whole batch
    (ids + mask) can be moved with ``batch.to(device)``.
    """
    def to(self,device):
        """Return a new CUDA_Dict with every value replaced by ``value.to(device)``.

        The result has the same type as ``self`` (not a plain ``dict``), so it can
        be moved again, e.g. ``batch.to('cuda').to('cpu')``.
        """
        return type(self)((key, value.to(device)) for key, value in self.items())

In [234]:
class CharacterTokenizer:
    """Character-level tokenizer with five reserved special tokens.

    Ids 0-4 are ``[BOW]``, ``[EOW]``, ``[CLS]``, ``[UNK]`` and ``[PAD]``; every
    distinct character given to the constructor gets an id from 5 upwards.

    ``encode`` lays each word out as ``[BOW] c_1 ... c_n [EOW] [PAD] ... [PAD]``
    and returns an ``attention_mask`` that is ``True`` at padding positions,
    i.e. the ``key_padding_mask`` convention of ``torch.nn`` attention layers
    (NOT the Hugging Face "1 means attend" convention).
    """

    BOW_IDX, EOW_IDX, CLS_IDX, UNK_IDX, PAD_IDX = 0, 1, 2, 3, 4

    def __init__(self,list_of_characters):
        """Build the vocabulary.

        Args:
            list_of_characters: iterable of characters (e.g. a ``set``).
                Duplicates are dropped and the characters are sorted. ``set``
                iteration order for strings changes between Python processes
                (hash randomization), so sorting keeps the char -> id mapping
                identical across sessions and checkpoints. Deduplicating keeps
                ids contiguous, so ``len(self)`` is a valid embedding size.
        """
        self.char_2_idx = {'[BOW]':0,'[EOW]':1,'[CLS]':2,'[UNK]':3,'[PAD]':4}
        for idx, character in enumerate(sorted(set(list_of_characters)), start=5):
            self.char_2_idx[character] = idx
        self.idx_2_char = {idx: char for char, idx in self.char_2_idx.items()}

    def __len__(self): return len(self.char_2_idx)

    def encode(self,list_of_strings):
        """Tokenize a string or a list of strings into one padded batch.

        Args:
            list_of_strings: a ``str`` or a non-empty ``list`` of ``str``.

        Returns:
            CUDA_Dict with
              * ``'input_ids'``: LongTensor (batch, longest + 2). Characters not
                in the vocabulary map to ``[UNK]``.
              * ``'attention_mask'``: BoolTensor of the same shape, ``True`` at
                padding positions.

        Raises:
            AssertionError: if the input is neither a str nor a list of str.
            ValueError: if the list is empty.
        """
        assert isinstance(list_of_strings,str) or (isinstance(list_of_strings,list) \
                 and all(isinstance(string,str) for string in list_of_strings)), \
            'encode() expects a str or a list of str'
        if isinstance(list_of_strings,str): list_of_strings = [list_of_strings]
        if not list_of_strings:
            raise ValueError('encode() needs at least one string')

        lengths = [len(string) for string in list_of_strings]
        seq_len = max(lengths) + 2  # room for [BOW] and [EOW]

        # [EOW] follows the last real character directly, so every position at
        # or beyond length + 2 is padding -- exactly what the mask below marks.
        tokens = pt.full((len(list_of_strings), seq_len), self.PAD_IDX, dtype=pt.long)
        for row, string in enumerate(list_of_strings):
            ids = ([self.BOW_IDX]
                   + [self.char_2_idx.get(char, self.UNK_IDX) for char in string]
                   + [self.EOW_IDX])
            tokens[row, :len(ids)] = pt.tensor(ids, dtype=pt.long)

        attention_mask = pt.arange(seq_len)[None] >= (pt.tensor(lengths) + 2)[:, None]
        return CUDA_Dict({'input_ids':tokens,'attention_mask':attention_mask})

    def decode(self,list_of_ints):
        """Map token ids back to strings.

        Args:
            list_of_ints: a single id (``int``), one sequence (``list`` of
                ``int`` or 1-D tensor), or a batch (``list`` of such lists or a
                2-D tensor).

        Returns:
            ``list`` of ``str``, one per sequence. Special tokens are kept
            verbatim (e.g. ``'[BOW]'``); ids outside the vocabulary decode to
            ``'[UNK]'``.
        """
        if isinstance(list_of_ints, pt.Tensor):
            list_of_ints = list_of_ints.tolist()
        if isinstance(list_of_ints, int):
            list_of_ints = [[list_of_ints]]
        assert isinstance(list_of_ints, list), 'decode() expects an int, a tensor or a list'
        # A flat list of ids is a single sequence.
        if all(isinstance(i, int) for i in list_of_ints):
            list_of_ints = [list_of_ints]

        return [''.join(self.idx_2_char.get(i, '[UNK]') for i in ints) for ints in list_of_ints]

In [215]:
class GraphoneDataset(Dataset):
    """GraphoneDataset

    Dataset of word/phoneme string pairs. The phoneme strings were produced with
    DeepPhonemizer. They are read from two single-column, header-less CSV files,
    ``all_orth.csv`` and ``all_phon.csv``, whose rows are aligned by position.

    Items are raw ``(orth_string, phon_string)`` tuples; tokenization happens
    later (e.g. in ``collate``).
    """

    DEFAULT_DATA_DIR = "/gdrive/MyDrive/data/orth_phon_mappings"

    def __init__(self, data_dir=DEFAULT_DATA_DIR):
        """Load both CSV files from ``data_dir``.

        Raises:
            ValueError: if the two files have a different number of rows.
        """
        import os

        # dtype=str + keep_default_na=False: real words such as "nan", "null"
        # or "NA" would otherwise be parsed as float NaN (and digit strings as
        # numbers), which breaks len() below and tokenization later.
        def read_column(filename):
            frame = pd.read_csv(os.path.join(data_dir, filename), header=None,
                                dtype=str, keep_default_na=False)
            return frame[0].to_numpy()

        # The orthography and phonology are stored in separate files
        self.letters = read_column("all_orth.csv")
        self.phons = read_column("all_phon.csv")
        if len(self.letters) != len(self.phons):
            raise ValueError(f"Dataset size mismatch! {len(self.letters)} orthographic "
                             f"vs {len(self.phons)} phonological entries")

        # Longest string across both streams (0 for empty files).
        self.max_len = max(max(map(len, self.letters), default=0),
                           max(map(len, self.phons), default=0))

    def __len__(self):
        return len(self.letters)

    def __getitem__(self, idx):
        """Return the ``idx``-th ``(orth_string, phon_string)`` pair."""
        return (self.letters[idx], self.phons[idx])

def collate(batches,orthography_tokenizer,phoneme_tokenizer):
    """Turn a list of ``(orth_string, phon_string, ...)`` items into one batch.

    Only the first two fields of each item are used. The orthographic strings
    are encoded first, then the phonological ones.

    Returns:
        dict with keys ``'orthography'`` and ``'phonology'`` holding the outputs
        of the respective tokenizer's ``encode``.
    """
    orth_strings, phon_strings = [], []
    for item in batches:
        orth_strings.append(item[0])
        phon_strings.append(item[1])

    encoded_orthography = orthography_tokenizer.encode(orth_strings)
    encoded_phonology = phoneme_tokenizer.encode(phon_strings)
    return dict(orthography=encoded_orthography, phonology=encoded_phonology)

In [36]:
class Encoder(pt.nn.Module):
    """A stack of ``num_layers`` batch-first ``nn.TransformerEncoderLayer``s.

    Args:
        d_model: embedding width.
        nhead: number of attention heads.
        num_layers: number of stacked layers.
    """
    def __init__(self, d_model=768, nhead=1, num_layers=1):
        super().__init__()
        self.transformer_encoder = pt.nn.TransformerEncoder(
            pt.nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True),
            num_layers=num_layers,
        )

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Encode ``src``; masks are forwarded unchanged to ``nn.TransformerEncoder``."""
        return self.transformer_encoder(src, mask=src_mask,
                                        src_key_padding_mask=src_key_padding_mask)
    
class Decoder(pt.nn.Module):
    """A stack of ``num_layers`` batch-first ``nn.TransformerDecoderLayer``s.

    Args:
        d_model: embedding width.
        nhead: number of attention heads.
        num_layers: number of stacked layers.
    """
    def __init__(self, d_model=768, nhead=1, num_layers=1):
        super().__init__()
        layer = pt.nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer_decoder = pt.nn.TransformerDecoder(layer, num_layers=num_layers)

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None, tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Decode ``tgt`` against ``memory``; all masks are forwarded unchanged."""
        masks = dict(tgt_mask=tgt_mask,
                     memory_mask=memory_mask,
                     tgt_key_padding_mask=tgt_key_padding_mask,
                     memory_key_padding_mask=memory_key_padding_mask)
        return self.transformer_decoder(tgt, memory, **masks)

In [163]:
class GraphoneModel(pt.nn.Module):
    """Joint grapheme/phoneme model with a single-vector bottleneck.

    Both token streams are embedded (token + learned absolute position) and
    encoded separately. Each stream then cross-attends to the other (with a
    residual connection), and the two streams are concatenated behind a learned
    ``global_embedding`` slot and mixed by another encoder. The mixed output at
    that slot is projected to a single vector per word, which is the length-1
    memory for two autoregressive decoders: one over graphemes, one over phonemes.

    Args:
        orth_vocab_size: orthography (grapheme) vocabulary size.
        phon_orth_vocab_size: phonology (phoneme) vocabulary size.
        d_model: model width.
        nhead: attention heads in every attention layer.
        num_layers: depth of every encoder/decoder stack.
        max_seq_len: number of learned position embeddings, i.e. the longest
            token sequence (special tokens included) that can be embedded.
    """
    def __init__(self, orth_vocab_size, phon_orth_vocab_size, d_model=768, nhead=1, num_layers=1, max_seq_len=21):
        super().__init__()

        self.orthography_embedding = pt.nn.Embedding(orth_vocab_size,d_model)
        self.phonology_embedding = pt.nn.Embedding(phon_orth_vocab_size,d_model)
        self.position_embedding = pt.nn.Embedding(max_seq_len,d_model)

        self.vocab_sizes = (orth_vocab_size,phon_orth_vocab_size)
        self.d_model = d_model
        self.max_seq_len = max_seq_len

        # Learned "summary" slot prepended before mixing; its output is the word encoding.
        self.global_embedding = pt.nn.Parameter(pt.randn((1,1,self.d_model))/self.d_model**.5,requires_grad=True)

        self.grapheme_encoder = Encoder(d_model=d_model, nhead=nhead, num_layers=num_layers)
        self.phoneme_encoder = Encoder(d_model=d_model, nhead=nhead, num_layers=num_layers)

        self.gp_multihead_attention = pt.nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)
        self.pg_multihead_attention = pt.nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead, batch_first=True)

        self.transformer_mixer = Encoder(d_model=self.d_model, nhead=nhead, num_layers=num_layers)
        self.reduce = pt.nn.Linear(self.d_model,self.d_model)

        self.grapheme_decoder = Decoder(d_model=self.d_model, nhead=nhead, num_layers=num_layers)
        self.linear_grapheme_decoder = pt.nn.Linear(self.d_model, self.vocab_sizes[0])

        self.phoneme_decoder = Decoder(d_model=self.d_model, nhead=nhead, num_layers=num_layers)
        self.linear_phoneme_decoder = pt.nn.Linear(self.d_model, self.vocab_sizes[1])

    def generate_triangular_mask(self, size, device):
        """Causal mask (size, size): ``True`` above the diagonal = may not attend."""
        return pt.triu(pt.ones((size, size), dtype=pt.bool, device=device), 1)

    def embed_tokens(self,tokens,mode='o'):
        """Token + position embedding of a (batch, seq) id tensor.

        ``mode='o'`` uses the orthography table, ``'p'`` the phonology table;
        both streams share the position table.
        """
        assert mode in ['o','p']
        table = self.orthography_embedding if mode == 'o' else self.phonology_embedding
        return table(tokens) + self.position_embedding.weight[None,:tokens.shape[1]]

    def embed(self, graphemes, grapheme_padding_mask, phonemes, phoneme_padding_mask):
        """Encode a batch of grapheme/phoneme id sequences into one vector per word.

        Padding masks are ``True`` at positions to ignore.

        Returns:
            (final_encoding of shape (batch, 1, d_model),
             grapheme token embeddings, phoneme token embeddings)
        """
        graphemes,phonemes = self.embed_tokens(graphemes,'o'),self.embed_tokens(phonemes,'p')

        grapheme_encoding = self.grapheme_encoder(graphemes,src_key_padding_mask=grapheme_padding_mask)
        phoneme_encoding = self.phoneme_encoder(phonemes,src_key_padding_mask=phoneme_padding_mask)

        # Cross-attention: graphemes query phonemes, and phonemes query graphemes.
        gp_encoding = self.gp_multihead_attention(grapheme_encoding, phoneme_encoding, phoneme_encoding,
                                                  key_padding_mask = phoneme_padding_mask)[0]
        pg_encoding = self.pg_multihead_attention(phoneme_encoding, grapheme_encoding, grapheme_encoding,
                                                  key_padding_mask = grapheme_padding_mask)[0]

        # Residual add, then concatenate the two streams along the sequence axis.
        gp_pg = pt.cat((gp_encoding, pg_encoding), dim=1) + pt.cat((grapheme_encoding, phoneme_encoding), dim=1)
        gp_pg_padding_mask = pt.cat((grapheme_padding_mask,phoneme_padding_mask),dim=-1)

        # Prepend the global slot (never masked) and mix everything together.
        gp_pg = pt.cat((self.global_embedding.repeat(gp_pg.shape[0],1,1),gp_pg),dim=1)
        gp_pg_padding_mask = pt.cat((pt.zeros((gp_pg.shape[0],1),device=gp_pg.device,dtype=pt.bool),gp_pg_padding_mask),dim=-1)
        mixed_encoding = self.transformer_mixer(gp_pg,src_key_padding_mask=gp_pg_padding_mask)

        final_encoding = self.reduce(mixed_encoding[:,0]).unsqueeze(-2)
        return final_encoding,graphemes,phonemes

    def forward(self, graphemes, grapheme_padding_mask, phonemes, phoneme_padding_mask):
        """Teacher-forced decoding of both streams from the bottleneck encoding.

        Output position t has seen input tokens 0..t (causal mask), so its
        logits should be trained to predict token t + 1.

        Returns:
            (grapheme logits (batch, seq_g, orth_vocab),
             phoneme logits (batch, seq_p, phon_vocab))
        """
        mixed_encoding,graphemes,phonemes = self.embed(graphemes, grapheme_padding_mask, phonemes, phoneme_padding_mask)

        grapheme_ar_mask = self.generate_triangular_mask(graphemes.shape[1],graphemes.device)
        grapheme_output = self.grapheme_decoder(graphemes, mixed_encoding, tgt_mask = grapheme_ar_mask)

        # Phonemes go through their own decoder (the same one generate() samples from).
        phoneme_ar_mask = self.generate_triangular_mask(phonemes.shape[1],phonemes.device)
        phoneme_output = self.phoneme_decoder(phonemes, mixed_encoding, tgt_mask = phoneme_ar_mask)

        grapheme_token_logits = self.linear_grapheme_decoder(grapheme_output)
        phoneme_token_logits = self.linear_phoneme_decoder(phoneme_output)
        return grapheme_token_logits, phoneme_token_logits

    @pt.no_grad()
    def generate(self, graphemes, grapheme_mask, phonemes, phoneme_mask, max_new_tokens=21, bos_token=0):
        """Sample grapheme and phoneme sequences from the encoding of a prompt.

        Args:
            graphemes, phonemes: (batch, seq) token ids of the prompt words.
            grapheme_mask, phoneme_mask: (batch, seq) bool padding masks (True = pad).
            max_new_tokens: number of tokens sampled per stream. It may not exceed
                ``max_seq_len``, because each step re-embeds the whole sequence so
                far with absolute positions.
            bos_token: id that seeds both streams (0 is ``[BOW]`` in CharacterTokenizer).

        Returns:
            LongTensor (2, batch, max_new_tokens + 1) on the model's device:
            row 0 holds graphemes, row 1 phonemes, each starting with
            ``bos_token``. Sampling does not stop at an end-of-word token.

        Raises:
            ValueError: if ``max_new_tokens > max_seq_len``.

        Side effect: puts the module in eval mode and leaves it there.
        """
        if max_new_tokens > self.max_seq_len:
            raise ValueError(f"max_new_tokens ({max_new_tokens}) exceeds max_seq_len ({self.max_seq_len})")

        self.eval()
        device = next(self.parameters()).device

        prompt_encoding = self.embed(graphemes, grapheme_mask, phonemes, phoneme_mask)[0]
        batch_size = prompt_encoding.shape[0]

        grapheme_tokens = pt.full((batch_size, 1), bos_token, dtype=pt.long, device=device)
        phoneme_tokens = grapheme_tokens.clone()

        for _ in range(max_new_tokens):
            # Re-embed the full prefix so each token gets its true position and
            # each stream uses its own embedding table.
            step_mask = self.generate_triangular_mask(grapheme_tokens.shape[1], device)
            grapheme_hidden = self.grapheme_decoder(self.embed_tokens(grapheme_tokens, 'o'),
                                                    prompt_encoding, tgt_mask=step_mask)
            phoneme_hidden = self.phoneme_decoder(self.embed_tokens(phoneme_tokens, 'p'),
                                                  prompt_encoding, tgt_mask=step_mask)

            grapheme_probs = pt.softmax(self.linear_grapheme_decoder(grapheme_hidden[:, -1]), dim=-1)
            phoneme_probs = pt.softmax(self.linear_phoneme_decoder(phoneme_hidden[:, -1]), dim=-1)

            new_grapheme_token = pt.multinomial(grapheme_probs, num_samples=1)
            new_phoneme_token = pt.multinomial(phoneme_probs, num_samples=1)

            grapheme_tokens = pt.cat((grapheme_tokens, new_grapheme_token), dim=1)
            phoneme_tokens = pt.cat((phoneme_tokens, new_phoneme_token), dim=1)

        return pt.stack((grapheme_tokens, phoneme_tokens), dim=0)

In [175]:
# Use the first GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = pt.device('cuda:0' if pt.cuda.is_available() else 'cpu')

In [235]:
# Seed for the train/validation split so the split is identical on every run.
SPLIT_SEED = 0
TRAIN_FRACTION = .8

ds = GraphoneDataset()

# sorted(): iteration order of a set of str changes between Python processes
# (hash randomization), so sort to get the same vocabulary ids every session.
orthography_tokenizer = CharacterTokenizer(sorted(set(''.join(ds.letters))))
phonology_tokenizer = CharacterTokenizer(sorted(set(''.join(ds.phons))))

n_train = int(TRAIN_FRACTION * len(ds))
train,validation = pt.utils.data.random_split(ds, (n_train, len(ds) - n_train),
                                              generator=pt.Generator().manual_seed(SPLIT_SEED))

collate_fn = lambda x: collate(x,orthography_tokenizer,phonology_tokenizer)
train_loader = pt.utils.data.DataLoader(train, batch_size=64, shuffle=True,collate_fn = collate_fn)
val_loader = pt.utils.data.DataLoader(validation, batch_size=64, collate_fn = collate_fn)

In [230]:
# Smoke test: run embed, forward and generate once on random, unpadded inputs.
gm = GraphoneModel(len(orthography_tokenizer),len(phonology_tokenizer), max_seq_len=100)

def _random_prompt(batch_size=11, seq_len=21, high=10):
    """Random (graphemes, grapheme_mask, phonemes, phoneme_mask) with no padding."""
    graphemes = pt.randint(0, high, (batch_size, seq_len))
    grapheme_mask = pt.zeros((batch_size, seq_len), dtype=pt.bool)
    phonemes = pt.randint(0, high, (batch_size, seq_len))
    phoneme_mask = pt.zeros((batch_size, seq_len), dtype=pt.bool)
    return graphemes, grapheme_mask, phonemes, phoneme_mask

for smoke_step in (gm.embed, gm, gm.generate):
    _ = smoke_step(*_random_prompt())

  return torch._native_multi_head_attention(


In [243]:
import tqdm.auto

PAD_IDX = 4  # id of '[PAD]' in CharacterTokenizer; padded targets are ignored by the loss
num_epochs = 100
learning_rate = 1e-3

gm.to(device)
opt = pt.optim.Adam(gm.parameters(), learning_rate)
loss_fn = pt.nn.CrossEntropyLoss(ignore_index=PAD_IDX)


def sequence_loss(batch):
    """Summed next-token cross-entropy of the grapheme and phoneme decoders.

    The decoders are causally masked, so the output at position t has already
    seen input token t. It is therefore scored against token t + 1; scoring it
    against token t would let the model drive the loss down by copying its input.
    """
    orthography,phonology = batch['orthography'].to(device),batch['phonology'].to(device)
    grapheme_logits, phoneme_logits = gm(orthography['input_ids'],orthography['attention_mask'],
                                         phonology['input_ids'],phonology['attention_mask'])
    loss = loss_fn(grapheme_logits[:, :-1].transpose(1, 2), orthography['input_ids'][:, 1:])
    return loss + loss_fn(phoneme_logits[:, :-1].transpose(1, 2), phonology['input_ids'][:, 1:])


def mean(values):
    """Arithmetic mean, or NaN for an empty list."""
    return sum(values) / len(values) if values else float('nan')


pbar = tqdm.auto.tqdm(range(num_epochs),position=0)
for epoch in pbar:
    gm.train()
    train_losses = []
    for batch in train_loader:
        loss = sequence_loss(batch)
        opt.zero_grad()
        loss.backward()
        opt.step()
        train_losses.append(loss.item())

    gm.eval()
    with pt.no_grad():
        val_losses = [sequence_loss(batch).item() for batch in val_loader]

    pbar.set_postfix(train_loss=mean(train_losses), val_loss=mean(val_losses))

  0%|          | 0/100 [00:00<?, ?it/s]


KeyboardInterrupt: ignored