In [1]:
import os

from dataset.dataset import *
from pytorch_pipeline.ptpl import PyTorchPipeline
from model.model import *

import torch
import torch.nn as nn

import random
from tqdm import tqdm

In [14]:
# Model dimensions; they are also embedded in the checkpoint file name below.
hidden_dim = 64
emb_dim = 32

# Hyper-parameters handed to the pipeline.
hparams = dict(
    hidden_dim=hidden_dim,
    emb_dim=emb_dim,
    num_batches=32,
    path2load=f"./weights/hidden_size_{hidden_dim}_emb_dim_{emb_dim}.pt",
)

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


In [15]:
# Prepare data: load the corpus, then split every bucket of `perSentence`
# into a training part and a validation part.

# A fixed seed on a dedicated RNG makes the split identical after every
# "Restart & Run All" without touching the global `random` state.
SPLIT_SEED = 0
VAL_FRACTION = 0.2
split_rng = random.Random(SPLIT_SEED)

path = os.path.join("./data", "eng-fra.txt")
perSentence, vocab = load_data(path)
vocab_size = len(vocab)

train_data = {}
val_data = {}

for key, values in perSentence.items():
    # Shuffle a copy so the buckets inside `perSentence` keep their order.
    shuffled = list(values)
    split_rng.shuffle(shuffled)

    # The first VAL_FRACTION of the shuffled bucket is validation, the rest
    # is training. A separate name avoids clashing with the total below.
    n_val = int(VAL_FRACTION * len(shuffled))
    val_data[key] = shuffled[:n_val]
    train_data[key] = shuffled[n_val:]

# Totals across all buckets.
train_size = sum(len(bucket) for bucket in train_data.values())
val_size = sum(len(bucket) for bucket in val_data.values())

print("Size of the training data: ", train_size)
print("Size of the validation data: ", val_size)
print()


Total number of sentences is 135842 



100%|██████████| 135842/135842 [00:00<00:00, 580348.20it/s]


The vocabulary size if 5384 

The size of the newly created text data is 122885, 90% of the original text
Size of the training data:  98310
Size of the validation data:  24575



In [16]:
# define a model
embedding = nn.Embedding(vocab_size, emb_dim)
encoder = Encoder(emb_dim, hidden_dim)
decoder = Decoder(hidden_dim, vocab_size)
model = Seq2Seq(embedding, encoder, decoder)
model.to(device)

Your vector field callable (nn.Module) should have both time `t` and state `x` as arguments, we've wrapped it for you.


Seq2Seq(
  (embedding): Embedding(5384, 32)
  (encoder): Encoder(
    (rnn): GRU(32, 64)
    (dropout): Dropout(p=0.3, inplace=False)
  )
  (decoder): Decoder(
    (ode_solve): Neural ODE:
    	- order: 1        
    	- solver: RungeKutta4()
    	- adjoint solver: RungeKutta4()        
    	- tolerances: relative 0.001 absolute 0.001        
    	- adjoint tolerances: relative 0.0001 absolute 0.0001        
    	- num_parameters: 8320        
    	- NFE: 0.0
    (fc): Linear(in_features=64, out_features=5384, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
  )
)

In [17]:
# Pipeline settings. Criterion and optimizer are left unset, and the split
# dictionaries are passed under the dataloader keys.
pipeline_configs = dict(
    device=device,
    criterion=None,
    optimizer=None,
    train_dataloader=train_data,
    val_dataloader=val_data,
    print_logs=True,
    wb=False,
)

# Wrap the model together with its settings and hyper-parameters.
ptpl = PyTorchPipeline(
    project_name="gru_node",
    configs=pipeline_configs,
    hparams=hparams,
    model=model,
)

PyTorch pipeline for gru_node is set up


In [18]:
# Draw 10 sentences, without replacement, from the bucket stored under key 5.
# NOTE(review): key 5 presumably selects sentences by length — confirm
# against load_data.
bucket = perSentence[5]
sampleSentences = random.sample(bucket, 10)

In [19]:
# Map every sampled sentence from token ids back to words and print it.
for token_ids in sampleSentences:
    print(" ".join(vocab.itos[token_id] for token_id in token_ids))

i don t date . <EOS>
who s in control ? <EOS>
your cough worries me . <EOS>
the pain was unbearable . <EOS>
i like your style . <EOS>
he s a <UNK> . <EOS>
anyone can do it . <EOS>
you ve been infected . <EOS>
i love traveling alone . <EOS>
what station is it ? <EOS>


In [20]:
# Turn each sampled id sequence into a flat tensor and stack them into a
# single 2-D tensor with one sentence per row.
rows = [torch.tensor(ids).view(-1) for ids in sampleSentences]
batch = torch.stack(rows)
print(batch.size())

torch.Size([10, 6])


In [21]:
# Load the checkpoint whose path was assembled in the hyper-parameter dict.
checkpoint_path = hparams["path2load"]
ptpl.load(checkpoint_path)

In [24]:
# Run the sampled batch through the pipeline.
# NOTE(review): the meaning of the "train" argument is defined by
# PyTorchPipeline.predict, which is not visible here — confirm it is the
# intended mode for inference.
predict_mode = "train"
output_data = ptpl.predict(batch, predict_mode)

In [26]:
# Show the dimensions of the prediction tensor.
print(output_data.size())

torch.Size([6, 10, 5384])


In [32]:
# Greedy decoding: swap the first two axes of `output_data`, then take the
# argmax over the last axis to get one predicted token id per position.
temp = output_data.transpose(1, 0).argmax(2)
print(temp.shape)

# Print every reference sentence followed by the model's prediction for it.
# Pairing with zip, instead of a hard-coded range(10), keeps the loop in step
# with however many sentences were sampled.
for reference_ids, predicted_ids in zip(sampleSentences, temp.cpu()):
    ground_truth_sentence = " ".join(vocab.itos[w] for w in reference_ids)
    sentence = " ".join(vocab.itos[w] for w in predicted_ids.tolist())
    print(ground_truth_sentence)
    print(sentence)

torch.Size([10, 6])
i don t date . <EOS>
<EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
who s in control ? <EOS>
<EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
your cough worries me . <EOS>
<EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
the pain was unbearable . <EOS>
<EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
i like your style . <EOS>
<EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
he s a <UNK> . <EOS>
<EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
anyone can do it . <EOS>
<EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
you ve been infected . <EOS>
<EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
i love traveling alone . <EOS>
<EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
what station is it ? <EOS>
<EOS> <EOS> <EOS> <EOS> <EOS> <EOS>
