Skip to content

Commit

Permalink
Basic seq2seq
Browse files Browse the repository at this point in the history
  • Loading branch information
lizeyan committed Sep 16, 2018
1 parent 674438c commit 93bd495
Show file tree
Hide file tree
Showing 11 changed files with 145 additions and 11 deletions.
8 changes: 5 additions & 3 deletions examples/mnist_conditional_gan_generation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
"source": [
"import torch.nn as nn\n",
"from snippets.modules import one_hot\n",
"\n",
"\n",
"class Discriminator(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
Expand Down Expand Up @@ -365,9 +367,9 @@
},
"varInspector": {
"cols": {
"lenName": 16,
"lenType": 16,
"lenVar": 40
"lenName": 16.0,
"lenType": 16.0,
"lenVar": 40.0
},
"kernels_config": {
"python": {
Expand Down
3 changes: 2 additions & 1 deletion examples/mnist_cvae_generation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@
],
"source": [
"from snippets.scaffold import TrainLoop, TestLoop\n",
"from snippets.modules import VAE, MLP, Lambda, one_hot\n",
"from snippets.modules import MLP, Lambda, one_hot\n",
"from snippets.modules.bayesian import VAE\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.distributions as dist\n",
Expand Down
3 changes: 2 additions & 1 deletion examples/mnist_vae_generation.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@
],
"source": [
"from snippets.scaffold import TrainLoop, TestLoop\n",
"from snippets.modules import VAE, MLP, Lambda\n",
"from snippets.modules import MLP, Lambda\n",
"from snippets.modules.bayesian import VAE\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.distributions as dist\n",
Expand Down
2 changes: 0 additions & 2 deletions snippets/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from .multi_layer_perceptron import MultiLayerPerceptron, MLP
from .bayesian import *
from .lambda_module import *
from .one_hot import *
from .sequence import *
2 changes: 2 additions & 0 deletions tests/modules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
from .test_one_hot import *
from .test_vae_mlp_lambda import *
from .test_lang import *
from .test_sequence import *
17 changes: 17 additions & 0 deletions tests/modules/test_lang.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from snippets.modules.sequence import Lang
import unittest
import torch


class TestLang(unittest.TestCase):
def testLang(self):
lang = Lang(name="name")
lang.add_sentences([
"a,b,c,d",
"c,e,f,g"
], tokenizer=lambda x: x.split(","))
lang.add_sentences([
"e g h j",
])
self.assertListEqual(lang.tensor_to_tokens(torch.tensor([2])), ["a"])
self.assertListEqual(lang.sentence_to_tensor("a b").tolist(), [2, 3, lang.EOS_INDEX])
87 changes: 87 additions & 0 deletions tests/modules/test_sequence.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from snippets.modules.sequence import *
import unittest


class TestSequence(unittest.TestCase):
def test_encoder(self):
seq_length = 7
batch_size = 4
n_layers = 2
n_direction = 2
hidden_size = 13
vocab_size = 11
encoder = EncoderSeq(input_size=vocab_size, hidden_size=hidden_size,
embedding_size=12, n_layers=n_layers,
bidirectional=True if n_direction > 1 else False,
reverse_input=True)
input_tensor = torch.randint(low=0, high=vocab_size,
size=(seq_length, batch_size), dtype=torch.long)
encoder_outputs, hidden = encoder(input_tensor, None)
self.assertEqual(encoder_outputs.size(),
(seq_length, batch_size, hidden_size * n_direction))
self.assertEqual(hidden.size(), (n_layers * n_direction, batch_size, hidden_size))

def test_decoder(self):
seq_length = 7
batch_size = 4
n_layers = 2
n_direction = 2
hidden_size = 13
vocab_size = 11
max_steps = 12
decoder = DecoderSeq(hidden_size=hidden_size, output_size=vocab_size,
embedding_size=12, n_layers=n_layers,
bidirectional=True if n_direction > 1 else False)
input_tensor = torch.randint(low=0, high=vocab_size,
size=(seq_length, batch_size), dtype=torch.long)
decoder_outputs, hidden = decoder(input_tensor, None)
self.assertEqual(decoder_outputs.size(),
(seq_length, batch_size, vocab_size))
self.assertEqual(hidden.size(), (n_layers * n_direction, batch_size, hidden_size))
decoder_outputs = decoder.forward_n(input_tensor[0], None, n_steps=max_steps)
self.assertEqual(decoder_outputs.size(),
(max_steps, batch_size, vocab_size))

def test_trainer(self):
seq_length = 7
batch_size = 4
n_layers = 2
n_direction = 2
hidden_size = 13
vocab_size = 11
encoder = EncoderSeq(input_size=vocab_size, hidden_size=hidden_size,
embedding_size=12, n_layers=n_layers,
bidirectional=True if n_direction > 1 else False,
reverse_input=True)
decoder = DecoderSeq(hidden_size=hidden_size, output_size=vocab_size,
embedding_size=12, n_layers=n_layers,
bidirectional=True if n_direction > 1 else False)
trainer = Seq2SeqTrainer(encoder=encoder, decoder=decoder)
trainer.teach_forcing_prob = 1
input_tensors = list(torch.randint(low=0, high=vocab_size,
size=(random.randint(5, seq_length + 1),),
dtype=torch.long) for _ in range(batch_size))
trainer.step(input_tensors, input_tensors)
trainer.teach_forcing_prob = 0
trainer.step(input_tensors, input_tensors)

def test_inference(self):
seq_length = 7
batch_size = 4
n_layers = 2
n_direction = 2
hidden_size = 13
vocab_size = 11
encoder = EncoderSeq(input_size=vocab_size, hidden_size=hidden_size,
embedding_size=12, n_layers=n_layers,
bidirectional=True if n_direction > 1 else False,
reverse_input=True)
decoder = DecoderSeq(hidden_size=hidden_size, output_size=vocab_size,
embedding_size=12, n_layers=n_layers,
bidirectional=True if n_direction > 1 else False)
target_lang = Lang(name="lang")
target_lang.add_sentence("a b c d e f g h i g k l m n")
inference = Seq2SeqInference(encoder=encoder, decoder=decoder, target_lang=target_lang)
input_tensor = torch.randint(low=0, high=vocab_size,
size=(seq_length,), dtype=torch.long)
inference(input_tensor, max_length=2)
3 changes: 2 additions & 1 deletion tests/modules/test_vae_mlp_lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
import torch.distributions as dist
import torch.nn as nn

from snippets.modules import VAE, MLP, Lambda
from snippets.modules import MLP, Lambda
from snippets.modules.bayesian import VAE
from snippets.scaffold import get_gpu_metrics


Expand Down
5 changes: 3 additions & 2 deletions tests/scaffold/test_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test(self):
metric.collect(3, 3)
self.assertEqual(metric.data, {1: 1, 2: 2, 4: 4, 3: [3, 3, 3], 5: [1, 2, 3, 4, 5]})
self.assertEqual(metric["all"], list({1: 1, 2: 2, 4: 4, 3: [3, 3, 3], 5: [1, 2, 3, 4, 5]}.values()))
self.assertEqual(metric["last"], [1, 2, 3, 4, 5])
self.assertEqual(metric[3], [3, 3, 3])
self.assertEqual(metric[[3, 4, 5]], [[3, 3, 3], 4, [1, 2, 3, 4, 5]])
self.assertEqual(metric.format(1), "test:1.000")
Expand All @@ -24,7 +25,7 @@ def test(self):
ok = True
try:
with Metric.raise_key_error():
metric[128]
_ = metric[128]
except KeyError:
ok = False
finally:
Expand All @@ -34,7 +35,7 @@ def test(self):
ok = True
try:
with Metric.raise_key_error():
metric[0, 4, 128]
_ = metric[0, 4, 128]
except KeyError:
ok = False
finally:
Expand Down
3 changes: 2 additions & 1 deletion tests/utilities/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .test_assertion import *
from .test_assertion import *
from .test_snippets import *
23 changes: 23 additions & 0 deletions tests/utilities/test_snippets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from snippets.utilities import *
import unittest


class TestSnippets(unittest.TestCase):
def test_same_length(self):
a = [1, 2, 3]
b = []
c = [1, 2, 3]
self.assertEqual(in_same_length(), True)
self.assertEqual(in_same_length(a, b, c), False)
self.assertEqual(in_same_length(a, c), True)

def test_split(self):
arr = list(range(10))
a, b, c, d = split(arr, [0.1, 0.2, 0.3, 0.4])
self.assertListEqual(a, [0])
self.assertListEqual(b, [1, 2])
self.assertListEqual(c, [3, 4, 5])
self.assertListEqual(d, [6, 7, 8, 9])
a, b, = split(arr, [0.4, 0.2])
self.assertListEqual(a, [0, 1, 2, 3])
self.assertListEqual(b, [4, 5])

0 comments on commit 93bd495

Please sign in to comment.