<a href="https://colab.research.google.com/github/mmsamiei/just-practice-deep/blob/master/Bradley_Chatbot.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torchtext
from torchtext.data import Field
import spacy

from spacy.symbols import ORTH
my_tok = spacy.load('en')

def spacy_tok(x):
    return [tok.text for tok in my_tok.tokenizer(x)]

QUERY = Field(lower=True, tokenize=spacy_tok)
RESPONSE = Field(lower=True, tokenize=spacy_tok, is_target=True, init_token='<bos>', eos_token='<eos>')

In [0]:
import torch
if torch.cuda.is_available:
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [0]:
my_tok.tokenizer.add_special_case("don't", [{ORTH: "do"}, {ORTH: "n't"}])

In [0]:
from torchtext.data import TabularDataset

train_dataset = TabularDataset(path="./simpleR21.csv", format="CSV",
                               fields=[("query", QUERY),("response", RESPONSE)],
                               csv_reader_params={"delimiter":'\t'})

In [7]:
QUERY.build_vocab(train_dataset)
RESPONSE.build_vocab(train_dataset)
print(QUERY.vocab.stoi['film'])
print(QUERY.vocab.itos[33])
print(len(QUERY.vocab))

13
religion
66


In [0]:
from torchtext.data import BucketIterator

train_iterator = BucketIterator(dataset= train_dataset, batch_size=16
                                ,device=device
                                ,sort_key=lambda x: data.interleave_keys(len(x.query), len(x.response))
                                , repeat = False)

In [9]:
for batch in (iter(train_iterator)):
  print(batch.query.shape)

torch.Size([7, 16])
torch.Size([6, 16])
torch.Size([12, 16])
torch.Size([7, 16])


In [0]:
import torch.nn as nn
class Bradley(nn.Module):
  def __init__(self, src_voc_sze, trg_voc_sze, hid_sze, num_head, num_enc, num_dec):
    super(Bradley, self).__init__()
    self.hid_sze = hid_sze
    self.src_word_embedding = nn.Embedding(src_voc_sze, self.hid_sze)
    self.trg_word_embedding = nn.Embedding(trg_voc_sze, self.hid_sze)
    self.num_head = num_head
    self.transformer = nn.Transformer(self.hid_sze, self.num_head, num_enc, num_dec)
    self.fc = nn.Linear(self.hid_sze, trg_voc_sze)
  
  def forward(self, src, trg):
    temp_src = self.src_word_embedding(src)
    temp_trg = self.src_word_embedding(trg)
    temp = self.transformer(temp_src, temp_trg)
    return self.fc(temp)


In [0]:
bradley_model = Bradley(src_voc_sze=len(QUERY.vocab), trg_voc_sze=len(RESPONSE.vocab),
                        hid_sze=256, num_head=4,
                        num_enc=4, num_dec=2)
bradley_model = bradley_model.to(device)

In [79]:
criterion = nn.CrossEntropyLoss(ignore_index=RESPONSE.vocab.stoi['<pad>'])
lr = 0.5
optimizer = torch.optim.SGD(bradley_model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)
epoch_number = 10

for epoch in range(1, epoch_number+1):
  bradley_model.train()
  for batch in iter(train_iterator):
    src = batch.query
    trg = batch.query
    out = bradley_model(src, trg)
    loss = criterion(out.view(-1, out.shape[2]), trg.view(-1))
    print(loss.item())
    loss.backward()
    torch.nn.utils.clip_grad_norm_(bradley_model.parameters(), 0.5)
    optimizer.step()

5.154428482055664
3.9079723358154297
3.331676959991455
3.0151851177215576
2.178809642791748
2.5015740394592285
1.9549877643585205
1.8358585834503174
1.3946281671524048
1.6510370969772339
1.0656342506408691
1.4822208881378174
1.3972320556640625
0.8926560282707214
0.9742722511291504
0.752997100353241
0.8059644103050232
0.9954825043678284
0.4343242049217224
0.5353960990905762
0.2735218405723572
0.6622867584228516
0.722267210483551
0.3450008034706116
0.30299052596092224
0.30385521054267883
0.22809918224811554
0.4106246829032898
0.19677285850048065
0.2871069312095642
0.1225435733795166
0.1825486421585083
0.22453446686267853
0.10449979454278946
0.0790897011756897
0.07790741324424744
0.07518883049488068
0.03419020399451256
0.045394983142614365
0.024038830772042274


In [83]:
source_sentence = train_dataset[7].query
print(' '.join(source_sentence))

what is your favorite team ?


In [0]:
x = QUERY.numericalize([source_sentence]).to(device)