
Commit

Possibly fixing up convseq2seq.
gugarosa committed Jun 9, 2020
1 parent 6a4a053 commit 4fb12f7
Showing 4 changed files with 58 additions and 17 deletions.
5 changes: 2 additions & 3 deletions examples/applications/generation/conv_seq2seq_generation.py
@@ -10,7 +10,7 @@
 file_path = 'data/generative/chapter1_harry.txt'

 # Defines a datatype for further tensor conversion
-source = Field(lower=True)
+source = Field(lower=True, batch_first=True)

 # Creates the GenerativeDataset
 dataset = GenerativeDataset(file_path, source)
@@ -33,7 +33,6 @@
 conv_seq2seq.fit(train_iterator, epochs=10)

 # Generating artificial text
-text = conv_seq2seq.generate_text(
-    'Mr. Dursley', source, length=100, temperature=0.5)
+text = conv_seq2seq.generate_text('Mr. Dursley', source, length=100, temperature=0.5)

 print(' '.join(text))
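
For readers unfamiliar with the flag, `batch_first=True` tells the legacy `torchtext.data` API to build batches of shape (batch, seq) instead of the default (seq, batch); the model changes below index tokens accordingly. A minimal sketch, not part of the commit, with made-up toy sentences:

from torchtext.data import Dataset, Example, Field, Iterator

field = Field(lower=True, batch_first=True)

# Two toy five-token sentences, so the batch needs no padding
examples = [Example.fromlist([text], [('text', field)])
            for text in ['mr dursley was the director',
                         'of a firm called grunnings']]
dataset = Dataset(examples, [('text', field)])
field.build_vocab(dataset)

batch = next(iter(Iterator(dataset, batch_size=2, sort=False, device='cpu')))
print(batch.text.shape)  # torch.Size([2, 5]): (batch, seq), not (seq, batch)
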
45 changes: 45 additions & 0 deletions examples/applications/translating/conv_seq2seq_translation.py
@@ -0,0 +1,45 @@
+from torchtext.data import BucketIterator, Field
+
+import textformer.utils.visualization as v
+from textformer.datasets.translation import TranslationDataset
+from textformer.models.conv_seq2seq import ConvSeq2Seq
+
+# Defines the device which should be used, e.g., `cpu` or `cuda`
+device = 'cpu'
+
+# Defines the input file
+file_path = 'data/translation/europarl'
+
+# Defines datatypes for further tensor conversion
+source = Field(init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)
+target = Field(init_token='<sos>', eos_token='<eos>', lower=True, batch_first=True)
+
+# Creates the TranslationDataset
+train_dataset, val_dataset, test_dataset = TranslationDataset.splits(
+    file_path, ('.en', '.pt'), (source, target))
+
+# Builds the vocabularies
+source.build_vocab(train_dataset, min_freq=1)
+target.build_vocab(train_dataset, min_freq=1)
+
+# Gathering the <pad> token index for further ignoring
+target_pad_index = target.vocab.stoi[target.pad_token]
+
+# Creates a bucket iterator
+train_iterator, val_iterator, test_iterator = BucketIterator.splits(
+    (train_dataset, val_dataset, test_dataset), batch_size=2, sort=False, device=device)
+
+# Creating the ConvSeq2Seq model
+conv_seq2seq = ConvSeq2Seq(n_input=len(source.vocab), n_output=len(target.vocab),
+                           n_hidden=512, n_embedding=256, n_layers=1, kernel_size=3,
+                           scale=0.5, max_length=200, ignore_token=target_pad_index,
+                           init_weights=None, device=device)
+
+# Training the model
+conv_seq2seq.fit(train_iterator, val_iterator, epochs=10)
+
+# Evaluating the model
+conv_seq2seq.evaluate(test_iterator)
+
+# Calculating BLEU score
+conv_seq2seq.bleu(test_dataset, source, target)
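
As a usage note, not part of the commit: `translate_text` now returns a `(tokens, attentions)` pair, as the model diff below shows. A hedged sketch of consuming it after training, where the example sentence, the attention shape, and the use of matplotlib instead of the imported-but-unused `textformer.utils.visualization` module are all assumptions:

import matplotlib.pyplot as plt

# The method now returns the translation plus the decoder attention
translation, attentions = conv_seq2seq.translate_text(
    'the house is yellow'.split(), source, target, max_length=20)
print(' '.join(translation))

# Assuming attention of shape (batch, target_len, source_len):
# drop the batch dimension and draw a heat map
plt.imshow(attentions.squeeze(0).detach().cpu().numpy(), cmap='viridis')
plt.xlabel('source position')
plt.ylabel('target position')
plt.show()
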
23 changes: 10 additions & 13 deletions textformer/models/conv_seq2seq.py
@@ -105,15 +105,12 @@ def generate_text(self, start, field, length=10, temperature=1.0):
             # Performs the initial encoding
             conv, output = self.E(tokens)

-        # Removes the batch dimension from the tokens
-        tokens = tokens.squeeze(0)
-
         # For every possible length
         for i in range(length):
             # Inhibits the gradient from updating the parameters
             with torch.no_grad():
                 # Decodes only the last token, i.e., last sampled token
-                preds, _ = self.D(tokens[-1].unsqueeze(0), conv, output)
+                preds, _ = self.D(tokens[:,-1].unsqueeze(0), conv, output)

                 # Regularize the prediction with the temperature
                 preds /= temperature
@@ -122,10 +119,10 @@
                 sampled_token = distributions.Categorical(logits=preds).sample()

                 # Concatenate the sampled token with the input tokens
-                tokens = torch.cat((tokens, sampled_token))
+                tokens = torch.cat((tokens, sampled_token), axis=1)

         # Decodes the tokens into text
-        sampled_text = [field.vocab.itos[t] for t in tokens]
+        sampled_text = [field.vocab.itos[t] for t in tokens.squeeze(0)]

         return sampled_text
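
The batch-first bookkeeping in this loop is easy to get wrong, so here is a self-contained sketch of the same sampling pattern with a stand-in for the decoder (random logits of shape (batch, seq, vocab) instead of `self.D`):

import torch
from torch import distributions

vocab_size, temperature = 10, 0.5
tokens = torch.LongTensor([[1]])  # (batch=1, seq=1), batch-first

for _ in range(5):
    with torch.no_grad():
        # Stand-in for the decoder call: logits over the vocabulary
        preds = torch.randn(1, tokens.shape[1], vocab_size)

        # Keep the last position only and sharpen with the temperature
        preds = preds[:, -1] / temperature  # (batch, vocab)

        # Sample one token and append it along the sequence axis
        sampled_token = distributions.Categorical(logits=preds).sample()
        tokens = torch.cat((tokens, sampled_token.unsqueeze(1)), dim=1)

print(tokens.shape)  # torch.Size([1, 6])
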

@@ -161,7 +158,7 @@ def translate_text(self, start, src_field, trg_field, max_length=10):
         # Inhibits the gradient from updating the parameters
         with torch.no_grad():
             # Performs the initial encoding
-            hidden = context = self.E(tokens)
+            conv, output = self.E(tokens)

         # Creating a tensor with `<sos>` token from target vocabulary
         tokens = torch.LongTensor([trg_field.vocab.stoi[trg_field.init_token]]).unsqueeze(0).to(self.device)
@@ -171,23 +168,23 @@
             # Inhibits the gradient from updating the parameters
             with torch.no_grad():
                 # Decodes only the last token, i.e., last sampled token
-                preds, hidden = self.D(tokens[-1], hidden, context)
+                preds, atts = self.D(tokens, conv, output)

                 # Samples a token using argmax
-                sampled_token = preds.argmax(1)
+                sampled_token = preds.argmax(2)[:,-1]

                 # Concatenate the sampled token with the input tokens
-                tokens = torch.cat((tokens, sampled_token.unsqueeze(0)))
+                tokens = torch.cat((tokens, sampled_token.unsqueeze(0)), axis=1)

                 # Check if has reached the end of string
                 if sampled_token == trg_field.vocab.stoi[trg_field.eos_token]:
                     # If yes, breaks the loop
                     break

         # Decodes the tokens into text
-        translated_text = [trg_field.vocab.itos[t] for t in tokens]
+        translated_text = [trg_field.vocab.itos[t] for t in tokens.squeeze(0)]

-        return translated_text[1:]
+        return translated_text[1:], atts
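
One structural point worth spelling out: unlike the replaced RNN-style call `self.D(tokens[-1], hidden, context)`, a convolutional decoder carries no recurrent state, so each step re-decodes the whole prefix and keeps only the newest position. A small shape sketch with made-up sizes:

import torch

# Dummy decoder output for a 4-token prefix: (batch=1, seq=4, vocab=10)
preds = torch.randn(1, 4, 10)

# argmax over the vocabulary axis yields one token id per position...
ids = preds.argmax(2)    # (1, 4)

# ...and [:, -1] keeps the id predicted for the newest position
next_token = ids[:, -1]  # (1,)
print(ids.shape, next_token.shape)
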

@@ -215,7 +212,7 @@ def bleu(self, dataset, src_field, trg_field, max_length=50, n_grams=4):
         # For every example in the dataset
         for data in dataset:
             # Calculates the prediction, i.e., translated text
-            pred = self.translate_text(data.text, src_field, trg_field, max_length)
+            pred, _ = self.translate_text(data.text, src_field, trg_field, max_length)

             # Appends the prediction without the `<eos>` token
             preds.append(pred[:-1])
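
For context, the unpacking change above simply discards the attention half of the new return value. A hedged sketch of the corpus-level BLEU computation such a method typically wraps (the token lists are made up; whether `bleu` uses `torchtext.data.metrics.bleu_score` internally is an assumption):

from torchtext.data.metrics import bleu_score

# Candidates are token lists stripped of `<eos>`; references are nested one
# level deeper because each candidate may have several references
preds = [['the', 'house', 'is', 'yellow']]
targets = [[['the', 'house', 'is', 'yellow']]]

print(bleu_score(preds, targets, max_n=4, weights=[0.25] * 4))  # 1.0
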
2 changes: 1 addition & 1 deletion textformer/models/decoders/conv.py
@@ -101,7 +101,7 @@ def forward(self, y, enc_c, enc_o):
             The output and attention values.

         """

         # Creates the positions tensor
+        pos = torch.arange(0, y.shape[1]).unsqueeze(0).repeat(y.shape[0], 1)
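
The rewritten positions tensor follows the batch-first layout: one identical row of positions per batch element. A quick shape check with hypothetical sizes:

import torch

y = torch.zeros(2, 5, dtype=torch.long)  # hypothetical (batch=2, seq=5) token batch

pos = torch.arange(0, y.shape[1]).unsqueeze(0).repeat(y.shape[0], 1)
print(pos.shape)  # torch.Size([2, 5])
print(pos[0])     # tensor([0, 1, 2, 3, 4]) -> same positions for every row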
