https://arxiv.org/abs/1804.02341
$ tar -xzvf assets/dataset.tar.gz -C assets
$ python create_ds.py
This creates all images in the dataset: 8 colors (red, blue, green, white, gray, yellow, cyan, magenta) x 5 shapes (box, sphere, cylinder, torus (donut), ellipsoid) x 100 samples per combination, each rendered at a different location and angle (4,000 images with the default settings).
Command line options:
--n_samples=100 (number of samples per color, shape combination)
--seed=0 (random seed)
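The grid described above can be enumerated as in the minimal sketch below. It only illustrates the dataset layout (8 colors x 5 shapes x n_samples poses, 4,000 records with the defaults). The function name enumerate_dataset, the pose fields x, y and angle, and their ranges are hypothetical; they are not taken from create_ds.py, which does the actual rendering.

```python
import random

COLORS = ["red", "blue", "green", "white", "gray", "yellow", "cyan", "magenta"]
SHAPES = ["box", "sphere", "cylinder", "torus", "ellipsoid"]


def enumerate_dataset(n_samples=100, seed=0):
    """Return one record per image: every (shape, color) pair repeated n_samples times.

    The pose fields (x, y, angle) and their ranges are hypothetical placeholders
    for "different locations and angles"; create_ds.py may sample them differently.
    """
    rng = random.Random(seed)  # seeded so the same arguments give the same poses
    records = []
    for shape in SHAPES:
        for color in COLORS:
            for i in range(n_samples):
                records.append({
                    "shape": shape,
                    "color": color,
                    "sample": i,
                    "x": rng.uniform(-1.0, 1.0),       # hypothetical position range
                    "y": rng.uniform(-1.0, 1.0),
                    "angle": rng.uniform(0.0, 360.0),  # hypothetical rotation in degrees
                })
    return records


records = enumerate_dataset()
print(len(records))  # 4000 = 8 colors x 5 shapes x 100 samples
```

Because the sampler is seeded, rerunning with the same --seed value reproduces the same poses, which is presumably why the script exposes a seed option.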
Requirements:
- Python 3.6
- PyTorch == 0.4
- matplotlib
$ python train.py
Command line options:
--lr=6e-4 (learning rate)
--batch_size=50 (number of images in a batch)
--num_rounds=20000 (total number of training rounds)
--num_games_per_round=20 (number of games per round)
--vocab_size=5 (vocabulary size)
--max_sentence_len=20 (maximum sentence length)
--data_n_samples=100 (number of samples per (color, shape) combination)
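The training options listed above could be declared with Python's standard argparse module as in the sketch below. Only the flag names and defaults come from the list; whether train.py actually uses argparse, and the helper name build_parser, are assumptions.

```python
import argparse


def build_parser():
    # Hypothetical parser mirroring the documented train.py options and defaults.
    p = argparse.ArgumentParser(description="Referential game training (sketch)")
    p.add_argument("--lr", type=float, default=6e-4, help="learning rate")
    p.add_argument("--batch_size", type=int, default=50, help="number of images in a batch")
    p.add_argument("--num_rounds", type=int, default=20000, help="total number of training rounds")
    p.add_argument("--num_games_per_round", type=int, default=20, help="number of games per round")
    p.add_argument("--vocab_size", type=int, default=5, help="vocabulary size")
    p.add_argument("--max_sentence_len", type=int, default=20, help="maximum sentence length")
    p.add_argument("--data_n_samples", type=int, default=100,
                   help="number of samples per (color, shape) combination")
    return p


# Override two options; everything else keeps its documented default.
args = build_parser().parse_args(["--lr=1e-3", "--vocab_size=10"])
print(args.lr, args.vocab_size, args.batch_size)  # 0.001 10 50
```

argparse accepts long options both as --lr=1e-3 (the form used in the list above) and as --lr 1e-3.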
Sample output from round #9702:
...
message: 'ec', speaker object: ('ellipsoid', 'red'), speaker score: 0.96, listener object: ('ellipsoid', 'red'), label: 1, listener score: 0.89
message: 'ec', speaker object: ('box', 'blue'), speaker score: 0.96, listener object: ('box', 'blue'), label: 1, listener score: 0.97
message: 'cbe', speaker object: ('torus', 'magenta'), speaker score: 0.96, listener object: ('torus', 'magenta'), label: 1, listener score: 0.96
message: 'ecb', speaker object: ('sphere', 'magenta'), speaker score: 0.97, listener object: ('sphere', 'magenta'), label: 1, listener score: 0.97
message: 'ea', speaker object: ('box', 'white'), speaker score: 0.96, listener object: ('box', 'white'), label: 1, listener score: 0.96
message: 'ee', speaker object: ('sphere', 'cyan'), speaker score: 0.96, listener object: ('sphere', 'cyan'), label: 1, listener score: 0.97
message: 'eeb', speaker object: ('box', 'gray'), speaker score: 0.97, listener object: ('box', 'gray'), label: 1, listener score: 0.97
message: 'ecac', speaker object: ('box', 'cyan'), speaker score: 0.97, listener object: ('box', 'cyan'), label: 1, listener score: 0.95
message: 'ecc', speaker object: ('torus', 'green'), speaker score: 0.97, listener object: ('torus', 'green'), label: 1, listener score: 0.95
message: 'bed', speaker object: ('ellipsoid', 'cyan'), speaker score: 0.96, listener object: ('ellipsoid', 'cyan'), label: 1, listener score: 0.73
message: 'b', speaker object: ('sphere', 'white'), speaker score: 0.97, listener object: ('sphere', 'white'), label: 1, listener score: 0.97
message: 'bee', speaker object: ('torus', 'white'), speaker score: 0.97, listener object: ('torus', 'white'), label: 1, listener score: 0.92
message: 'cdb', speaker object: ('box', 'magenta'), speaker score: 0.96, listener object: ('box', 'yellow'), label: 0, listener score: 0.00
message: 'ec', speaker object: ('torus', 'green'), speaker score: 0.97, listener object: ('torus', 'blue'), label: 0, listener score: 0.00
message: 'ebe', speaker object: ('ellipsoid', 'yellow'), speaker score: 0.97, listener object: ('ellipsoid', 'green'), label: 0, listener score: 0.00
message: 'cd', speaker object: ('torus', 'red'), speaker score: 0.95, listener object: ('torus', 'cyan'), label: 0, listener score: 0.00
message: 'cb', speaker object: ('sphere', 'red'), speaker score: 0.96, listener object: ('sphere', 'blue'), label: 0, listener score: 0.00
message: 'ebc', speaker object: ('cylinder', 'white'), speaker score: 0.96, listener object: ('cylinder', 'cyan'), label: 0, listener score: 0.00
message: 'd', speaker object: ('ellipsoid', 'white'), speaker score: 0.99, listener object: ('ellipsoid', 'blue'), label: 0, listener score: 0.00
message: 'cda', speaker object: ('torus', 'red'), speaker score: 0.96, listener object: ('torus', 'white'), label: 0, listener score: 0.00
message: 'ebc', speaker object: ('cylinder', 'white'), speaker score: 0.96, listener object: ('cylinder', 'gray'), label: 0, listener score: 0.07
message: 'ecb', speaker object: ('box', 'blue'), speaker score: 0.96, listener object: ('box', 'white'), label: 0, listener score: 0.00
message: 'eeb', speaker object: ('ellipsoid', 'green'), speaker score: 0.97, listener object: ('ellipsoid', 'magenta'), label: 0, listener score: 0.00
message: 'ecac', speaker object: ('box', 'cyan'), speaker score: 0.97, listener object: ('box', 'white'), label: 0, listener score: 0.00
message: 'ebc', speaker object: ('sphere', 'blue'), speaker score: 0.97, listener object: ('sphere', 'white'), label: 0, listener score: 0.00
message: 'ed', speaker object: ('torus', 'gray'), speaker score: 0.95, listener object: ('torus', 'red'), label: 0, listener score: 0.00
message: 'cc', speaker object: ('cylinder', 'red'), speaker score: 1.00, listener object: ('cylinder', 'blue'), label: 0, listener score: 0.00
message: 'ede', speaker object: ('ellipsoid', 'blue'), speaker score: 0.96, listener object: ('sphere', 'blue'), label: 0, listener score: 0.00
message: 'ce', speaker object: ('torus', 'yellow'), speaker score: 0.97, listener object: ('sphere', 'yellow'), label: 0, listener score: 0.00
message: 'cda', speaker object: ('torus', 'red'), speaker score: 0.96, listener object: ('cylinder', 'red'), label: 0, listener score: 0.00
message: 'cdb', speaker object: ('box', 'magenta'), speaker score: 0.96, listener object: ('torus', 'magenta'), label: 0, listener score: 0.00
message: 'ebe', speaker object: ('cylinder', 'gray'), speaker score: 0.97, listener object: ('ellipsoid', 'gray'), label: 0, listener score: 0.00
message: 'ce', speaker object: ('torus', 'yellow'), speaker score: 0.96, listener object: ('sphere', 'yellow'), label: 0, listener score: 0.00
message: 'ccc', speaker object: ('cylinder', 'magenta'), speaker score: 1.00, listener object: ('box', 'magenta'), label: 0, listener score: 0.00
message: 'eb', speaker object: ('torus', 'cyan'), speaker score: 0.96, listener object: ('ellipsoid', 'cyan'), label: 0, listener score: 0.00
message: 'beb', speaker object: ('sphere', 'gray'), speaker score: 0.97, listener object: ('torus', 'gray'), label: 0, listener score: 0.00
message: 'cec', speaker object: ('cylinder', 'cyan'), speaker score: 0.98, listener object: ('sphere', 'cyan'), label: 0, listener score: 0.00
message: 'cdb', speaker object: ('box', 'magenta'), speaker score: 0.96, listener object: ('sphere', 'green'), label: 0, listener score: 0.00
message: 'ebc', speaker object: ('sphere', 'blue'), speaker score: 0.97, listener object: ('box', 'red'), label: 0, listener score: 0.00
message: 'ebd', speaker object: ('torus', 'blue'), speaker score: 0.95, listener object: ('sphere', 'green'), label: 0, listener score: 0.67
message: 'eebb', speaker object: ('ellipsoid', 'green'), speaker score: 0.97, listener object: ('ellipsoid', 'green'), label: 1, listener score: 0.96
message: 'ccd', speaker object: ('box', 'red'), speaker score: 0.98, listener object: ('cylinder', 'red'), label: 0, listener score: 0.00
message: 'b', speaker object: ('sphere', 'white'), speaker score: 0.97, listener object: ('cylinder', 'green'), label: 0, listener score: 0.00
message: 'c', speaker object: ('box', 'red'), speaker score: 0.95, listener object: ('sphere', 'red'), label: 0, listener score: 0.00
message: 'b', speaker object: ('ellipsoid', 'gray'), speaker score: 0.96, listener object: ('torus', 'red'), label: 0, listener score: 0.00
message: 'ebb', speaker object: ('sphere', 'green'), speaker score: 0.95, listener object: ('box', 'white'), label: 0, listener score: 0.00
message: 'ea', speaker object: ('box', 'white'), speaker score: 0.96, listener object: ('torus', 'white'), label: 0, listener score: 0.00
message: 'ed', speaker object: ('torus', 'gray'), speaker score: 0.96, listener object: ('box', 'white'), label: 0, listener score: 0.92
message: 'bee', speaker object: ('sphere', 'gray'), speaker score: 0.97, listener object: ('cylinder', 'white'), label: 0, listener score: 0.00
message: 'ed', speaker object: ('torus', 'gray'), speaker score: 0.96, listener object: ('cylinder', 'gray'), label: 0, listener score: 0.00
batch accuracy 0.96
batch loss 0.09183049947023392
*******
Round average accuracy: 96.90
Round average sentence length: 2.6
Round average loss: 0.1
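In the sample above, each line pairs the speaker's target object with the object shown to the listener, and label is 1 exactly when the two match; listener score appears to be the listener's estimated probability of a match. The sketch below parses lines in this format and scores the listener with a 0.5 threshold. The threshold and the helper names parse_line and batch_accuracy are assumptions, not necessarily what train.py does. Applied to the 50 game lines shown above (the same number as the default batch size), this threshold gives 48 correct, i.e. 0.96, consistent with the reported batch accuracy; the only misses are the two label 0 lines with listener scores 0.67 and 0.92.

```python
import re

# Matches one game line of the sample output.
LINE_RE = re.compile(
    r"message: '(?P<message>[^']*)', "
    r"speaker object: \('(?P<s_shape>\w+)', '(?P<s_color>\w+)'\), "
    r"speaker score: (?P<s_score>[\d.]+), "
    r"listener object: \('(?P<l_shape>\w+)', '(?P<l_color>\w+)'\), "
    r"label: (?P<label>[01]), "
    r"listener score: (?P<l_score>[\d.]+)"
)


def parse_line(line):
    """Parse one game line into a dict, or return None if it does not match."""
    m = LINE_RE.search(line)
    if m is None:
        return None
    return {
        "message": m["message"],
        "speaker_object": (m["s_shape"], m["s_color"]),
        "listener_object": (m["l_shape"], m["l_color"]),
        "label": int(m["label"]),
        "listener_score": float(m["l_score"]),
    }


def batch_accuracy(lines, threshold=0.5):
    """Fraction of games where (listener_score > threshold) agrees with the label.

    The 0.5 threshold is an assumption about how accuracy is computed.
    Non-game lines (e.g. "batch loss ...") are skipped.
    """
    games = [g for g in map(parse_line, lines) if g is not None]
    if not games:
        raise ValueError("no game lines found")
    correct = sum((g["listener_score"] > threshold) == bool(g["label"]) for g in games)
    return correct / len(games)


sample = [
    "message: 'ec', speaker object: ('box', 'blue'), speaker score: 0.96, "
    "listener object: ('box', 'blue'), label: 1, listener score: 0.97",
    "message: 'ed', speaker object: ('torus', 'gray'), speaker score: 0.96, "
    "listener object: ('box', 'white'), label: 0, listener score: 0.92",
    "batch accuracy 0.96",
]
print(batch_accuracy(sample))  # 0.5: the second game is a listener error
```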
This repository aims to follow the methods described in the paper as closely as possible.
There are two known differences:
- This implementation uses fewer convolutional layers, and fewer filters in each layer.
- The dataset uses a torus (donut) shape instead of the paper's capsule.