<a href="https://colab.research.google.com/github/mmsamiei/just-practice-deep/blob/master/leon_language_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import numpy as np

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [0]:
class LM(nn.Module):
  
  def __init__(self, hid_size, vocab_size, n_head, n_layers, max_len, device):
    super().__init__()

    self.device = device
    
    self.hid_size = hid_size
    self.max_len = max_len

    self.embedding = nn.Embedding(vocab_size, hid_size)

    self.position_enc = nn.Embedding(self.max_len, self.hid_size)
    self.position_enc.weight.data = self.position_encoding_init(self.max_len, self.hid_size)
    self.scale = torch.sqrt(torch.FloatTensor([self.hid_size])).to(device)

    self.layer_norm = nn.LayerNorm(self.hid_size)
    self.decoder_layer = nn.TransformerDecoderLayer(d_model=hid_size, nhead = n_head)
    self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=n_layers, norm=self.layer_norm)
    self.fc = nn.Linear(hid_size, vocab_size)

    self._init_weights()
  
  def forward(self, x):
    sent_len, batch_size = x.shape[0], x.shape[1]
    memory_mask = self.generate_complete_mask(sent_len)
    tgt_mask = self.generate_triangular_mask(sent_len)
    memory = torch.zeros(1, batch_size, self.hid_size, device=self.device)

    temp = x
    temp = self.embedding(temp)

    pos = torch.arange(0,sent_len).unsqueeze(1).repeat(1,batch_size).to(self.device)
    temp_pos_emb = self.position_enc(pos)

    temp = temp * self.scale + temp_pos_emb
    temp = self.decoder(temp, memory, tgt_mask=tgt_mask)
    temp = self.fc(temp)
    return temp

  def _init_weights(self):
    for p in self.parameters():
      if p.dim() > 1:
        nn.init.xavier_uniform_(p)

  def generate_triangular_mask(self, size):
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
            Unmasked positions are filled with float(0.0).
        """
        mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)).to(device)
        return
        
  def generate_complete_mask(self, size):
        r"""Generate a square mask for the sequence. The masked positions are filled with float('-inf').
            Unmasked positions are filled with float(0.0).
        """
        mask = torch.empty(size, size).to(device)
        mask.fill_(float('-inf'))
        return mask

  def generate_sequence(self, src):
    #src = [sent_len]
    src = src.unsqueeze(1)
    #src = [sent_len, 1]
    generate_step = 0
    while generate_step < 10:
      out = self.forward(src)
      #out = [sent_len + 1, 1, vocab_size]
      out = torch.argmax(out[-1, :], dim=1) # [1]
      out = out.unsqueeze(0) #[1,1]
      src = torch.cat((src, out), dim=0)
      generate_step += 1
    src = src.squeeze(1)
    return src
  
  def position_encoding_init(self, n_position, d_pos_vec):
    ''' Init the sinusoid position encoding table '''

    # keep dim 0 for padding token position encoding zero vector
    position_enc = np.array([
        [pos / np.power(10000, 2*i/d_pos_vec) for i in range(d_pos_vec)]
        if pos != 0 else np.zeros(d_pos_vec) for pos in range(n_position)])

    position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i
    position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1
    temp = torch.from_numpy(position_enc).type(torch.FloatTensor)
    temp = temp.to(self.device)
    return temp


In [0]:
sent_len = 20
batch_size = 64
vocab_size = 1000
hid_dim = 512
x = torch.LongTensor(20, 64).random_(1000).to(device)
lm = LM(512, 1000, 8, 6, 1024,device).to(device)

In [0]:
lm(x).shape

torch.Size([20, 64, 1000])

In [0]:
#print(lm(x)[0,0])

In [0]:
import torchtext
from torchtext import data
import spacy
 
my_tok = spacy.load('en')
 
def spacy_tok(x):
    return [tok.text for tok in my_tok.tokenizer(x)]
 
TEXT = data.Field(lower=True, tokenize=spacy_tok)

In [0]:
from torchtext.datasets import WikiText2
 
train, valid, test = WikiText2.splits(TEXT) # loading custom datasets requires passing in the field, but nothing else.

downloading wikitext-2-v1.zip


wikitext-2-v1.zip: 100%|██████████| 4.48M/4.48M [00:00<00:00, 8.47MB/s]


extracting


In [0]:
TEXT.build_vocab(train)
print(len(TEXT.vocab))

28870


In [0]:
train_iter, valid_iter, test_iter = data.BPTTIterator.splits(
    (train, valid, test),
    batch_size=64,
    bptt_len=32, # this is where we specify the sequence length
    device=device,
    repeat=False)

print(len(train_iter))


1093


In [0]:
b = next(iter(train_iter))
b.__dict__.keys()

dict_keys(['batch_size', 'dataset', 'fields', 'text', 'target'])

In [0]:
vocab_size = len(TEXT.vocab)
hid_size = 512
model = LM(hid_size, vocab_size, 8, 6, 1024, device).to(device)

In [0]:
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
    
    def zero_grad(self):
        self.optimizer.zero_grad()


In [0]:
optimizer = NoamOpt(hid_size, 1, 2000,
            torch.optim.Adam(model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
criterion = torch.nn.CrossEntropyLoss()

In [46]:
from tqdm import tqdm

clip = 1

N_EPOCH = 3
for epoch in range(N_EPOCH):
  epoch_loss = 0
  model.train()
  for batch in tqdm(train_iter):
    optimizer.zero_grad()
    batch_text = batch.text
    batch_target = batch.target
    result = model(batch_text)
    loss = criterion(result.view(-1, result.shape[-1]), batch_target.view(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
    optimizer.step()
    epoch_loss += loss.item()
  print("\n", epoch_loss / len(train_iter))


  0%|          | 0/8737 [00:00<?, ?it/s][A
  0%|          | 1/8737 [00:00<1:55:32,  1.26it/s][A
  0%|          | 2/8737 [00:01<1:30:49,  1.60it/s][A
  0%|          | 3/8737 [00:01<1:10:23,  2.07it/s][A
  0%|          | 4/8737 [00:01<55:47,  2.61it/s]  [A
  0%|          | 5/8737 [00:01<45:18,  3.21it/s][A
  0%|          | 6/8737 [00:01<38:01,  3.83it/s][A
  0%|          | 7/8737 [00:01<32:32,  4.47it/s][A
  0%|          | 8/8737 [00:01<28:50,  5.05it/s][A
  0%|          | 9/8737 [00:02<25:58,  5.60it/s][A
  0%|          | 10/8737 [00:02<24:02,  6.05it/s][A
  0%|          | 11/8737 [00:02<23:11,  6.27it/s][A
  0%|          | 12/8737 [00:02<22:14,  6.54it/s][A
  0%|          | 13/8737 [00:02<21:37,  6.72it/s][A
  0%|          | 14/8737 [00:02<21:04,  6.90it/s][A
  0%|          | 15/8737 [00:02<20:59,  6.92it/s][A
  0%|          | 16/8737 [00:02<20:47,  6.99it/s][A
  0%|          | 17/8737 [00:03<20:29,  7.09it/s][A
  0%|          | 18/8737 [00:03<20:19,  7.15it/s][A
  


 3.268285764741368



  0%|          | 1/8737 [00:00<1:38:39,  1.48it/s][A
  0%|          | 2/8737 [00:00<1:14:47,  1.95it/s][A
  0%|          | 3/8737 [00:00<58:09,  2.50it/s]  [A
  0%|          | 4/8737 [00:01<46:30,  3.13it/s][A
  0%|          | 5/8737 [00:01<38:23,  3.79it/s][A
  0%|          | 6/8737 [00:01<32:40,  4.45it/s][A
  0%|          | 7/8737 [00:01<28:41,  5.07it/s][A
  0%|          | 8/8737 [00:01<26:04,  5.58it/s][A
  0%|          | 9/8737 [00:01<24:02,  6.05it/s][A
  0%|          | 10/8737 [00:01<22:35,  6.44it/s][A
  0%|          | 11/8737 [00:02<21:34,  6.74it/s][A
  0%|          | 12/8737 [00:02<20:56,  6.94it/s][A
  0%|          | 13/8737 [00:02<20:28,  7.10it/s][A
  0%|          | 14/8737 [00:02<20:09,  7.21it/s][A
  0%|          | 15/8737 [00:02<19:53,  7.31it/s][A
  0%|          | 16/8737 [00:02<19:40,  7.39it/s][A
  0%|          | 17/8737 [00:02<19:35,  7.42it/s][A
  0%|          | 18/8737 [00:02<19:32,  7.43it/s][A
  0%|          | 19/8737 [00:03<19:27,  7.47it/s

KeyboardInterrupt: ignored

In [47]:
TEXT.vocab.itos[25645]

'frilled'

In [51]:
source_sentence = ["i","like"]
print(source_sentence)
model.eval()
print(' '.join(source_sentence))
print()
x = TEXT.numericalize([source_sentence]).to(device).squeeze(1)
generated_sequence =model.generate_sequence(x)
words = [TEXT.vocab.itos[word_idx] for word_idx in generated_sequence]
print(' '.join(words))

['i', 'like']
i like

i like the the the the the the < the the "


In [0]:
pos = torch.arange(0,100).unsqueeze(1).repeat(1,64)

In [0]:
pos

tensor([[ 0,  0,  0,  ...,  0,  0,  0],
        [ 1,  1,  1,  ...,  1,  1,  1],
        [ 2,  2,  2,  ...,  2,  2,  2],
        ...,
        [97, 97, 97,  ..., 97, 97, 97],
        [98, 98, 98,  ..., 98, 98, 98],
        [99, 99, 99,  ..., 99, 99, 99]])