<a href="https://colab.research.google.com/github/mmsamiei/just-practice-deep/blob/master/Bradley_Chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torchtext
from torchtext.data import Field
import spacy

from spacy.symbols import ORTH
my_tok = spacy.load('en')

def spacy_tok(x):
    return [tok.text for tok in my_tok.tokenizer(x)]

QUERY = Field(lower=True, tokenize=spacy_tok)
RESPONSE = Field(lower=True, tokenize=spacy_tok, is_target=True, init_token='<bos>', eos_token='<eos>')

In [0]:
import torch
if torch.cuda.is_available:
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [0]:
my_tok.tokenizer.add_special_case("don't", [{ORTH: "do"}, {ORTH: "n't"}])

In [0]:
from torchtext.data import TabularDataset

train_dataset = TabularDataset(path="./formatted_movie_lines.txt", format="CSV",
                               fields=[("query", QUERY),("response", RESPONSE)],
                               csv_reader_params={"delimiter":'\t'})

In [5]:
QUERY.build_vocab(train_dataset)
RESPONSE.build_vocab(train_dataset)
print(QUERY.vocab.stoi['film'])
print(QUERY.vocab.itos[33])
print(len(QUERY.vocab))
print(len(RESPONSE.vocab))

1059
're
48505
49036


In [0]:
from torchtext.data import BucketIterator

batch_size = 16

train_iterator = BucketIterator(dataset= train_dataset, batch_size=batch_size
                                ,device=device
                                ,sort_key=lambda x: data.interleave_keys(len(x.query), len(x.response))
                                , repeat = False)

In [7]:
for batch in (iter(train_iterator)):
  print(batch.query.shape)

torch.Size([56, 16])
torch.Size([24, 16])
torch.Size([32, 16])
torch.Size([47, 16])
torch.Size([134, 16])
torch.Size([45, 16])
torch.Size([66, 16])
torch.Size([39, 16])
torch.Size([49, 16])
torch.Size([105, 16])
torch.Size([91, 16])
torch.Size([32, 16])
torch.Size([65, 16])
torch.Size([65, 16])
torch.Size([49, 16])
torch.Size([61, 16])
torch.Size([29, 16])
torch.Size([40, 16])
torch.Size([45, 16])
torch.Size([72, 16])
torch.Size([53, 16])
torch.Size([46, 16])
torch.Size([96, 16])
torch.Size([72, 16])
torch.Size([48, 16])
torch.Size([33, 16])
torch.Size([18, 16])
torch.Size([67, 16])
torch.Size([34, 16])
torch.Size([32, 16])
torch.Size([42, 16])
torch.Size([31, 16])
torch.Size([25, 16])
torch.Size([76, 16])
torch.Size([38, 16])
torch.Size([49, 16])
torch.Size([22, 16])
torch.Size([22, 16])
torch.Size([34, 16])
torch.Size([199, 16])
torch.Size([53, 16])
torch.Size([26, 16])
torch.Size([23, 16])
torch.Size([36, 16])
torch.Size([47, 16])
torch.Size([45, 16])
torch.Size([27, 16])
torch.Size

In [0]:
import torch.nn as nn
class Bradley(nn.Module):
  def __init__(self, src_voc_sze, trg_voc_sze, hid_sze, num_head, num_enc, num_dec):
    super(Bradley, self).__init__()
    self.hid_sze = hid_sze
    self.src_word_embedding = nn.Embedding(src_voc_sze, self.hid_sze)
    self.trg_word_embedding = nn.Embedding(trg_voc_sze, self.hid_sze)
    self.trg_pos_embedding = nn.Embedding(1000, self.hid_sze)
    self.num_head = num_head
    self.transformer = nn.Transformer(self.hid_sze, self.num_head, num_enc, num_dec)
    self.fc = nn.Linear(self.hid_sze, trg_voc_sze)
  
  def forward(self, src, trg):
    #src = [src sent len, batch_size]
    #trg = [trg sent len, batch_size]
    temp_src = self.src_word_embedding(src)
    temp_trg = self.trg_word_embedding(trg)
    trg_sent_len, batch_size = trg.shape[0], trg.shape[1]
    trg_pos = self.trg_pos_embedding(torch.arange(0, trg_sent_len).unsqueeze(0).
                                     repeat(batch_size,1).to(device)).transpose(0,1)
    trg_mask = self._generate_square_subsequent_mask(trg_sent_len)
    temp_trg = temp_trg + trg_pos
    temp = self.transformer(temp_src, temp_trg, tgt_mask=trg_mask)
    return self.fc(temp)
  
  def _generate_square_subsequent_mask(self, sz):
    mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    mask = mask.to(device)
    return mask


  def greedy_inference_one_sample(self, src, max_len=10):

    self.eval()
    #src = [src sent len]
    src_len = src.shape[0]
    src = src.unsqueeze(1)
    #src = [sent_len, 1]
    trg = src.new_full((1,1), RESPONSE.vocab.stoi['<bos>'])
    #trg = [1,1]

    translation_step = 0
    while translation_step < max_len:
      out = self.forward(src, trg)
      out = out[-1,:]
      for i, val in enumerate(out[0]):
        if(val > 7):
          print(i, val)
      #out = [batch_size, trg_vocab_size]
      nex = out.argmax(dim=1).unsqueeze(0)
      #nex = [1, 1]
      trg = torch.cat((trg, nex), dim=0)
      translation_step += 1
    return trg



In [0]:
bradley_model = Bradley(src_voc_sze=len(QUERY.vocab), trg_voc_sze=len(RESPONSE.vocab),
                        hid_sze=64, num_head=4,
                        num_enc=3, num_dec=2)
bradley_model = bradley_model.to(device)

In [46]:
from tqdm import tqdm

criterion = nn.CrossEntropyLoss(ignore_index=RESPONSE.vocab.stoi['<pad>'])
lr = 0.5
optimizer = torch.optim.SGD(bradley_model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
epoch_number = 2

print(len(train_dataset)/batch_size)

for epoch in range(1, epoch_number+1):
  print("epoch ", epoch)
  bradley_model.train()
  for i, batch in enumerate(iter(train_iterator)):
    src = batch.query
    trg = batch.response
    optimizer.zero_grad()
    out = bradley_model(src, trg)
    loss = criterion(out[:-1,:].view(-1, out.shape[2]), trg[1:,:].view(-1))
    if(i%50==0):
      print("{}: {}".format(i, loss.item()))
    loss.backward()
    torch.nn.utils.clip_grad_norm_(bradley_model.parameters(), 0.5)
    optimizer.step()

13830.125
epoch  1
0: 10.899937629699707
50: 6.901618957519531
100: 6.507107734680176
150: 6.122533321380615
200: 6.052553653717041
250: 6.0089898109436035
300: 6.295837879180908
350: 6.096019268035889
400: 6.526276111602783
450: 5.659459590911865
500: 5.27095365524292
550: 5.2232441902160645
600: 5.813779354095459
650: 5.456448078155518
700: 5.803133010864258
750: 5.831628322601318
800: 6.0747904777526855
850: 5.233911037445068
900: 5.8898844718933105
950: 5.412067890167236
1000: 5.825284481048584
1050: 5.700466156005859
1100: 5.8393778800964355
1150: 5.119700908660889
1200: 5.418527603149414
1250: 5.604645252227783
1300: 5.449784278869629
1350: 5.345880031585693
1400: 5.425142288208008
1450: 5.454386234283447
1500: 5.3811821937561035
1550: 4.960111618041992
1600: 5.390223026275635
1650: 4.6738996505737305
1700: 5.0065693855285645
1750: 5.391326904296875
1800: 4.8972978591918945
1850: 4.975146293640137
1900: 4.940505504608154
1950: 5.012862682342529
2000: 5.221890926361084
2050: 5.593

KeyboardInterrupt: ignored

In [29]:
source_sentence = train_dataset[55].query
print(source_sentence)

['it', "'s", 'more']


In [30]:
x = QUERY.numericalize([source_sentence]).to(device)
x = x.flatten()
print(x.shape)
result = bradley_model.greedy_inference_one_sample(x)

torch.Size([3])
4 tensor(9.6419, device='cuda:0', grad_fn=<SelectBackward>)
7 tensor(7.6678, device='cuda:0', grad_fn=<SelectBackward>)
8 tensor(8.3314, device='cuda:0', grad_fn=<SelectBackward>)
12 tensor(7.0404, device='cuda:0', grad_fn=<SelectBackward>)
13 tensor(9.2171, device='cuda:0', grad_fn=<SelectBackward>)
15 tensor(10.2563, device='cuda:0', grad_fn=<SelectBackward>)
16 tensor(9.7627, device='cuda:0', grad_fn=<SelectBackward>)
18 tensor(7.6901, device='cuda:0', grad_fn=<SelectBackward>)
21 tensor(8.1877, device='cuda:0', grad_fn=<SelectBackward>)
23 tensor(7.0643, device='cuda:0', grad_fn=<SelectBackward>)
24 tensor(7.0875, device='cuda:0', grad_fn=<SelectBackward>)
28 tensor(7.4394, device='cuda:0', grad_fn=<SelectBackward>)
32 tensor(7.0171, device='cuda:0', grad_fn=<SelectBackward>)
38 tensor(8.9275, device='cuda:0', grad_fn=<SelectBackward>)
39 tensor(7.5258, device='cuda:0', grad_fn=<SelectBackward>)
44 tensor(7.3637, device='cuda:0', grad_fn=<SelectBackward>)
48 tensor(

In [32]:
RESPONSE.vocab.itos[73]

'"'

In [31]:
result

tensor([[ 2],
        [73],
        [73],
        [73],
        [73],
        [73],
        [73],
        [73],
        [73],
        [73],
        [73]], device='cuda:0')

In [14]:
def batch_index_to_strings(trg):
  # trg = [sent_len, batch_size]
  temp = trg.transpose(0,1)
  for i, row in enumerate(temp):
    print(row)

batch_index_to_strings(result)

tensor([   2, 4160, 4160, 4160, 4160, 4160, 4160, 4160, 4160, 4160, 4160],
       device='cuda:0')


In [0]:
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

In [0]:
mask = (torch.triu(torch.ones(5, 5)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

In [40]:
mask

tensor([[0., -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0.]])