<a href="https://colab.research.google.com/github/mmsamiei/just-practice-deep/blob/master/Abbas_Chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torchtext
from torchtext.data import Field
import spacy

from spacy.symbols import ORTH
my_tok = spacy.load('en')

def spacy_tok(x):
    return [tok.text for tok in my_tok.tokenizer(x)]

QUERY = Field(lower=True, tokenize=spacy_tok)
RESPONSE = Field(lower=True, tokenize=spacy_tok, is_target=True, init_token='<bos>', eos_token='<eos>')

In [0]:
import torch
if torch.cuda.is_available:
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [0]:
my_tok.tokenizer.add_special_case("don't", [{ORTH: "do"}, {ORTH: "n't"}])
my_tok.tokenizer.add_special_case("can't", [{ORTH: "can"}, {ORTH: "not"}])

In [0]:
from torchtext.data import TabularDataset

train_dataset = TabularDataset(path="./formatted_movie_lines.txt", format="CSV",
                               fields=[("query", QUERY),("response", RESPONSE)],
                               csv_reader_params={"delimiter":'\t'})

In [6]:
QUERY.build_vocab(train_dataset)
RESPONSE.build_vocab(train_dataset)
print("id of 'film' in query vocab is {}".format(QUERY.vocab.stoi['film']))
print("word of id=33 in query vocab is '{}'".format(QUERY.vocab.itos[33]))
print("len of Query vocab is {}".format(len(QUERY.vocab)))
print("len of Response vocab is {}".format(len(RESPONSE.vocab)))

id of 'film' in query vocab is 1059
word of id=33 in query vocab is 'know'
len of Query vocab is 48505
len of Response vocab is 49036


In [7]:
print("number of rows in train data is {}".format(len(train_dataset)))

number of rows in train data is 221282


In [8]:
from torchtext.data import Dataset

def my_filter_pred(example, limited_word = 3):
  if(len(example.query) < limited_word and len(example.response) < limited_word):
    return True
  else:
    return False

phase_train_dataset = Dataset(examples = train_dataset.examples,
               fields=[("query", QUERY),("response", RESPONSE)],
               filter_pred = my_filter_pred)

print("len of this phase_train_dataset is {}".format(len(phase_train_dataset)))

len of this phase_train_dataset is 1646


In [0]:
from torchtext.data import BucketIterator, interleave_keys

batch_size = 512

train_iterator = BucketIterator(dataset= phase_train_dataset, batch_size=batch_size,
                                device=device,
                                sort_key=lambda x: interleave_keys(len(x.query), len(x.response)),
                                sort = True,
                                shuffle = True,
                                repeat = False)

In [10]:
## test if data loads well?
for a in iter(train_iterator):
  print("response shape : \t",a.response.shape)
  print("query shape :    \t",a.query.shape)

response shape : 	 torch.Size([4, 512])
query shape :    	 torch.Size([2, 512])
response shape : 	 torch.Size([4, 512])
query shape :    	 torch.Size([2, 512])
response shape : 	 torch.Size([4, 512])
query shape :    	 torch.Size([2, 512])
response shape : 	 torch.Size([4, 110])
query shape :    	 torch.Size([2, 110])


In [11]:
num_batch = 0
for batch in (iter(train_iterator)):
  num_batch += 1
print("number of batch is:", num_batch)

number of batch is: 4


In [0]:
import torch.nn as nn
class Abbas(nn.Module):
  def __init__(self, src_voc_sze, trg_voc_sze, hid_sze, num_head, num_enc, num_dec):
    super(Abbas, self).__init__()
    self.hid_sze = hid_sze
    self.src_word_embedding = nn.Embedding(src_voc_sze, self.hid_sze)
    self.trg_word_embedding = nn.Embedding(trg_voc_sze, self.hid_sze)
    self.trg_pos_embedding = nn.Embedding(800, self.hid_sze)
    self.num_head = num_head
    self.transformer = nn.Transformer(self.hid_sze, self.num_head, num_enc, num_dec)
    self.fc = nn.Linear(self.hid_sze, trg_voc_sze)
  
  def forward(self, src, trg):
    #src = [src sent len, batch_size]
    #trg = [trg sent len, batch_size]
    temp_src = self.src_word_embedding(src)
    temp_trg = self.trg_word_embedding(trg)
    trg_sent_len, batch_size = trg.shape[0], trg.shape[1]
    trg_pos = self.trg_pos_embedding(torch.arange(0, trg_sent_len).unsqueeze(0).
                                     repeat(batch_size,1).to(device)).transpose(0,1)
    trg_mask = self._generate_square_subsequent_mask(trg_sent_len)
    temp_trg = temp_trg + trg_pos
    temp = self.transformer(temp_src, temp_trg, tgt_mask=trg_mask)
    return self.fc(temp)
  
  def _generate_square_subsequent_mask(self, sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    mask = mask.to(device)
    return mask


  def greedy_inference_one_sample(self, src, max_len=20):

    self.eval()
    with torch.no_grad():
      #src = [src sent len]
      src_len = src.shape[0]
      src = src.unsqueeze(1)
      #src = [sent_len, 1]
      trg = src.new_full((1,1), RESPONSE.vocab.stoi['<bos>'])
      #trg = [1,1]

      translation_step = 0
      while translation_step < max_len:
        out = self.forward(src, trg)
        out = out[-1,:]
        #out = [batch_size, trg_vocab_size]
        nex = out.argmax(dim=1).unsqueeze(0)
        #nex = [1, 1]
        trg = torch.cat((trg, nex), dim=0)
        translation_step += 1
    return trg



In [0]:
hid_dim = 512
src_voc_size = len(QUERY.vocab)
trg_voc_size = len(RESPONSE.vocab)
num_head = 8
num_enc = 6
num_dec = 4

abbas_model = Abbas(src_voc_sze=src_voc_size, trg_voc_sze=trg_voc_size,
                        hid_sze=hid_dim, num_head=num_head,
                        num_enc=num_enc, num_dec=num_dec)
abbas_model = abbas_model.to(device)

In [14]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(abbas_model):,} trainable parameters')

The model has 111,238,540 trainable parameters


In [0]:
for p in abbas_model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [0]:
class NoamOpt:
    "Optim wrapper that implements rate."
    def __init__(self, model_size, factor, warmup, optimizer):
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0
        
    def step(self):
        "Update parameters and rate"
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p['lr'] = rate
        self._rate = rate
        self.optimizer.step()
        
    def rate(self, step = None):
        "Implement `lrate` above"
        if step is None:
            step = self._step
        return self.factor * \
            (self.model_size ** (-0.5) *
            min(step ** (-0.5), step * self.warmup ** (-1.5)))
    
    def zero_grad(self):
        self.optimizer.zero_grad()

In [24]:
from tqdm import tqdm

criterion = nn.CrossEntropyLoss(ignore_index=RESPONSE.vocab.stoi['<pad>'])
optimizer = NoamOpt(hid_dim, 1, 2000,
            torch.optim.Adam(abbas_model.parameters(), lr=0, betas=(0.9, 0.98), eps=1e-9))
epoch_number = 100

for epoch in range(1, epoch_number+1):
  epoch_loss = 0
  abbas_model.train()
  for i, batch in enumerate(iter(train_iterator)):
    src = batch.query
    trg = batch.response
    optimizer.zero_grad()
    out = abbas_model(src, trg)
    loss = criterion(out[:-1,:].view(-1, out.shape[2]), trg[1:,:].view(-1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(abbas_model.parameters(), 0.5)
    optimizer.step()
    epoch_loss += loss.item()
  avr_epoch_loss = epoch_loss / len(train_iterator)
  print("epoch {} loss is: {}".format(epoch, avr_epoch_loss))


epoch 1 loss is: 9.392859935760498
epoch 2 loss is: 9.378726482391357
epoch 3 loss is: 9.347369909286499
epoch 4 loss is: 9.305662870407104
epoch 5 loss is: 9.247285842895508
epoch 6 loss is: 9.177646398544312
epoch 7 loss is: 9.100929021835327
epoch 8 loss is: 9.016826152801514
epoch 9 loss is: 8.925750017166138
epoch 10 loss is: 8.834258079528809
epoch 11 loss is: 8.738357305526733
epoch 12 loss is: 8.642095565795898
epoch 13 loss is: 8.547211408615112
epoch 14 loss is: 8.445712566375732
epoch 15 loss is: 8.343979597091675
epoch 16 loss is: 8.237195014953613
epoch 17 loss is: 8.124429941177368
epoch 18 loss is: 8.004160046577454
epoch 19 loss is: 7.8751060962677
epoch 20 loss is: 7.737109899520874
epoch 21 loss is: 7.590033650398254
epoch 22 loss is: 7.436927080154419
epoch 23 loss is: 7.274193286895752
epoch 24 loss is: 7.106865644454956
epoch 25 loss is: 6.9355162382125854
epoch 26 loss is: 6.765461087226868
epoch 27 loss is: 6.580899953842163
epoch 28 loss is: 6.397045135498047
ep

In [25]:
source_sentence = phase_train_dataset[50].query
#source_sentence = ['what','is','?']
print(source_sentence)

['pupils', '?']


In [27]:
x = QUERY.numericalize([source_sentence]).to(device)
x = x.flatten()
print(x.shape)
result = abbas_model.greedy_inference_one_sample(x)

torch.Size([2])


In [28]:
result = result.flatten()
for wrd_ind in result:
  print(RESPONSE.vocab.itos[wrd_ind])

<bos>
please
.
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>
<eos>


In [0]:
print(RESPONSE.vocab.itos[2])
print(RESPONSE.vocab.itos[5])
print(RESPONSE.vocab.itos[8])
print(RESPONSE.vocab.itos[3])
print(RESPONSE.vocab.itos[22])
print(RESPONSE.vocab.itos[4])
print(RESPONSE.vocab.itos[66])
print(RESPONSE.vocab.itos[125])


<bos>
,
?
<eos>
!
.
think
little


In [0]:
def batch_index_to_strings(trg):
  # trg = [sent_len, batch_size]
  temp = trg.transpose(0,1)
  for i, row in enumerate(temp):
    print(row)

batch_index_to_strings(result)

tensor([2, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
       device='cuda:0')


In [0]:
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

NameError: ignored

In [0]:
mask = (torch.triu(torch.ones(5, 5)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

In [0]:
mask