In [39]:
import torch, model, importlib
importlib.reload(model);

In [40]:
import datasets

dataset = datasets.load_dataset('daily_dialog')
dataset

Found cached dataset daily_dialog (C:/Users/Erwan/.cache/huggingface/datasets/daily_dialog/default/1.0.0/1d0a58c7f2a4dab5ed9d01dbde8e55e0058e589ab81fce5c2df929ea810eabcd)
100%|██████████| 3/3 [00:00<00:00, 61.65it/s]


DatasetDict({
    train: Dataset({
        features: ['dialog', 'act', 'emotion'],
        num_rows: 11118
    })
    validation: Dataset({
        features: ['dialog', 'act', 'emotion'],
        num_rows: 1000
    })
    test: Dataset({
        features: ['dialog', 'act', 'emotion'],
        num_rows: 1000
    })
})

In [41]:
from torchtext.vocab import GloVe, vocab

pretrained_vectors = GloVe(name="6B", dim=50)
pretrained_vocab = vocab(pretrained_vectors.stoi)
pretrained_vocab.insert_token("<unk>", 0)
pretrained_vocab.insert_token("<pad>", 1)
pretrained_vocab.set_default_index(0)
pretrained_embeddings = pretrained_vectors.vectors
pretrained_embeddings = torch.cat((torch.zeros(1,pretrained_embeddings.shape[1]),pretrained_embeddings))
vocab_stoi = pretrained_vocab.get_stoi()

In [42]:
import numpy as np
from torchtext.data import get_tokenizer
tokenizer = get_tokenizer("basic_english")

max_len = max([
    len(tokenizer(sentence))
     for dialog in dataset['train']['dialog'] + dataset['validation']['dialog'] + dataset['test']['dialog']
     for sentence in dialog 
])
max_len

290

In [43]:
@np.vectorize
def numericalize(word, vocab_stoi):
    return vocab_stoi[word] if word in vocab_stoi else 0

def preprocess(sentence, max_len):
    # Tokenizer, word to indices, padding
    sentence = tokenizer(sentence.lower())
    sentence = sentence + ["<pad>" for _ in range(max_len-len(sentence))]
    sentence = numericalize(sentence, vocab_stoi)
    return sentence
    

In [44]:
# For each dialog, we take the first 5 utterances
new_dataset = dict()
for split in ['train', 'validation', 'test']:
    new_dataset[split] = [
        (
            np.array([preprocess(sentence, max_len=max_len) for sentence in dialog[:5]]),
            np.array([a for a in act[:5]])
        )
        for dialog, act in zip(dataset[split]['dialog'], dataset[split]['act'])
        if len(dialog) >= 5
    ]

In [45]:
new_dataset['train'][0][0].shape # (sequence_length, max_sentence_length)

(5, 290)

In [46]:
new_dataset['train'][0][1].shape # (sequence_length,)

(5,)

In [47]:
class DialogActDataset(torch.utils.data.Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return dict(
            dialog = torch.tensor(self.data[idx][0]),
            act = torch.tensor(self.data[idx][1])
        )

In [48]:
kwargs = dict(
    batch_size = 64,
    shuffle = True,
    drop_last = True
)

train_loader = torch.utils.data.DataLoader(
    DialogActDataset(new_dataset['train']), **kwargs
)
val_loader = torch.utils.data.DataLoader(
    DialogActDataset(new_dataset['validation']), **kwargs
)
test_loader = torch.utils.data.DataLoader(
    DialogActDataset(new_dataset['test']), **kwargs
)

In [49]:
import torchinfo

network = model.Seq2SeqModel(
    nb_classes = 5,
    sequence_length = 5,
    hidden_size = 128,
    pretrained_embeddings = pretrained_embeddings,
)
torchinfo.summary(network)

Layer (type:depth-idx)                   Param #
Seq2SeqModel                             --
├─Embedding: 1-1                         (20,000,050)
├─HierarchicalEncoder: 1-2               --
│    └─GRU: 2-1                          138,240
│    └─DiscontinuedGRU: 2-2              --
│    │    └─GRUCell: 3-1                 148,224
│    │    └─GRUCell: 3-2                 148,224
│    └─GRU: 2-3                          296,448
├─SoftGuidedAttentionDecoder: 1-3        --
│    └─Linear: 2-4                       513
│    └─GRUCell: 2-5                      394,752
├─Linear: 1-4                            1,285
Total params: 21,127,736
Trainable params: 1,127,686
Non-trainable params: 20,000,050

In [50]:
next(iter(train_loader))['dialog'].shape
# (batch_size, context_length, max_sentence_length)

torch.Size([64, 5, 290])

In [51]:
network(next(iter(train_loader))['dialog']).shape
# (batch_size, context_length, nb_classes)

torch.Size([64, 5, 5])