In [1]:
__authors__ = "Anton Gochev, Jaro Habr, Yan Jiang, Samuel Kahn"
__version__ = "XCS224u, Stanford, Spring 2021"

## Colours with static embeddings

1. [Setup](#Setup)
1. [Dataset](#Dataset)
    1. [Filtered Corpus](#Filtered-Corpus)
    1. [Bake-Off Corpus](#Bake-Off-Corpus)
1. [Baseline-System](#Baseline-System)
1. [Experiments](#Experiments)
    1. [BERT Embeddings](#BERT-Embeddings)
    1. [XLNet Embeddings](#XLNet-Embeddings)
    1. [RoBERTa Embeddings](#RoBERTa-Embeddings)
    1. [ELECTRA Embeddings](#ELECTRA-Embeddings)

## Setup

This notebook explores how the baseline model performs when it uses different pre-trained static embeddings extracted from transformers such as BERT, XLNet, RoBERTa, and ELECTRA.
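
As a rough illustration of what "static" means here: each token gets one fixed vector from the transformer's input-embedding table, regardless of the surrounding words. The sketch below uses a made-up toy table and a hypothetical helper, `collect_static_embeddings`; it is not the code behind `mu.extract_input_embeddings`, whose internals are not shown in this notebook. With a Hugging Face model, the table itself would typically come from `model.get_input_embeddings()`.

```python
# Toy stand-in for a transformer's input-embedding table; the numbers are made up.
toy_vocab = {"[CLS]": 0, "[SEP]": 1, "blue": 2, "dark": 3, "green": 4}
toy_table = [
    [0.0, 0.1],  # [CLS]
    [0.2, 0.3],  # [SEP]
    [0.4, 0.5],  # blue
    [0.6, 0.7],  # dark
    [0.8, 0.9],  # green
]

def collect_static_embeddings(tokenized_texts, vocab, table):
    """Return (vectors, tokens): one fixed vector per distinct corpus token."""
    seen = []
    for tokens in tokenized_texts:
        for token in tokens:
            if token in vocab and token not in seen:
                seen.append(token)
    # The lookup ignores context: "blue" maps to the same row everywhere.
    vectors = [table[vocab[token]] for token in seen]
    return vectors, seen

corpus_tokens = [["[CLS]", "dark", "blue", "[SEP]"], ["[CLS]", "blue", "[SEP]"]]
toy_embeddings, toy_corpus_vocab = collect_static_embeddings(
    corpus_tokens, toy_vocab, toy_table
)
print(toy_corpus_vocab)
```

The return order (embeddings first, vocabulary second) mirrors how the notebook unpacks `mu.extract_input_embeddings`; everything else is an assumption made for illustration.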

In [2]:
from utils.colors import ColorsCorpusReader
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from utils.torch_color_describer import ContextualColorDescriber, create_example_dataset
import utils.utils as utils
from utils.utils import UNK_SYMBOL, START_SYMBOL, END_SYMBOL
import matplotlib.pyplot as plt
import matplotlib.patches as mpatch
import numpy as np
from baseline.model import (
    BaselineTokenizer, BaselineColorEncoder,
    BaselineEmbedding, BaselineLSTMDescriber, GloVeEmbedding
)
from experiment.vision import ConvolutionalColorEncoder

from transformers import (
    BertTokenizer, BertModel,
    XLNetTokenizer, XLNetModel,
    RobertaTokenizer, RobertaModel,
    ElectraTokenizer, ElectraModel,    
)

import utils.model_utils as mu
import experiment.helper as eh

In [3]:
utils.fix_random_seeds()

## Dataset

This exploration of the dataset counts the examples for the different classes and plots the word distribution in order to reveal any data imbalance issues.

### Filtered Corpus

The filtered corpus is the full dataset used in assignment 4. The following code looks at the composition of the dataset: the number of examples in each condition and the word counts of the color descriptions.
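
The per-condition and word-count tallies described above can be sketched with `collections.Counter`. The `ToyExample` tuple and its four rows are invented stand-ins for the corpus examples; only the attribute names `condition` and `contents` are taken from this notebook.

```python
from collections import Counter, namedtuple

# Hypothetical stand-in for a corpus example; only the two fields used below.
ToyExample = namedtuple("ToyExample", ["condition", "contents"])

toy_examples = [
    ToyExample("close", "the darker blue one"),
    ToyExample("far", "purple"),
    ToyExample("split", "bright green"),
    ToyExample("far", "blue"),
]

# Number of examples per condition.
condition_counts = Counter(ex.condition for ex in toy_examples)

# Number of descriptions per whitespace-delimited word count.
word_counts = Counter(len(ex.contents.split()) for ex in toy_examples)

print(condition_counts)
print(sorted(word_counts.items()))
```

On the real corpus the same two lines would run over `examples` instead of `toy_examples`; splitting on whitespace is a simplification and may differ from the tokenizer used later.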

In [4]:
COLORS_SRC_FILENAME = os.path.join(
    "data", "colors", "filteredCorpus.csv"
)

In [5]:
corpus = ColorsCorpusReader(
    COLORS_SRC_FILENAME,
    word_count=None,
    normalize_colors=True
)

In [6]:
examples = list(corpus.read())

In [7]:
len(examples)

46994

In [8]:
subset_examples = mu.extract_colour_examples(examples, from_word_count=5)

In [9]:
close_examples = [example for example in examples if example.condition == "close"]
split_examples = [example for example in examples if example.condition == "split"]
far_examples = [example for example in examples if example.condition == "far"]

In [10]:
print(f"close: {len(close_examples)}")
print(f"split: {len(split_examples)}")
print(f"far: {len(far_examples)}")

close: 15519
split: 15693
far: 15782


For more detail on the datasets (training and bake-off), refer to [colors_in_context.ipynb](colors_in_context.ipynb). That notebook shows the distribution of the colour examples among the different splits.

### Bake-Off Corpus

The following code analyses the bake-off dataset. We look at the number of examples for each condition as well as the word counts of the color descriptions.

In [11]:
BAKE_OFF_COLORS_SRC_FILENAME = os.path.join(
    "data", "colors", "cs224u-colors-bakeoff-data.csv"
)

In [12]:
bake_off_corpus = ColorsCorpusReader(
    BAKE_OFF_COLORS_SRC_FILENAME,
    word_count=None,
    normalize_colors=True
)

In [13]:
bake_off_examples = list(bake_off_corpus.read())

In [14]:
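import pickle
emb_path = os.path.join(
    "data", "colors", "resnet18_color_embeddings.pickle"
)
# Load the precomputed ResNet18 colour embeddings; the context manager
# closes the file even if unpickling fails.
with open(emb_path, 'rb') as file:
    resnet_emb = pickle.load(file)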


In [15]:
len(resnet_emb)

13890

## Baseline-System

The baseline system is based on assignment 4, but uses different token embeddings and token sequences.

### Model training - full dataset

The full color context dataset is used for final baseline model training.

In [16]:
# Pair each example's colour context with its description.
rawcols, texts = zip(*[[ex.colors, ex.contents] for ex in examples])
b_raw_colors_test, b_texts_test = zip(*[[ex.colors, ex.contents] for ex in bake_off_examples])

# Hold out part of the filtered corpus for testing (default split sizes).
raw_colors_train, raw_colors_test, texts_train, texts_test = \
        train_test_split(rawcols, texts)

def create_text_tokens(tokenizer):
    """Tokenize the train, test and bake-off descriptions with `tokenizer`."""
    tokens_train = [ mu.tokenize_colour_description(text, tokenizer, True) for text in texts_train ]
    tokens_test = [ mu.tokenize_colour_description(text, tokenizer, True) for text in texts_test ]
    tokens_bo = [ mu.tokenize_colour_description(text, tokenizer, True) for text in b_texts_test ]

    return tokens_train, tokens_test, tokens_bo

def create_colours_data():
    """Encode the train, test and bake-off colour contexts."""
    color_encoder = ConvolutionalColorEncoder(True)
    colors_train = [ color_encoder.encode_color_context(colors) for colors in raw_colors_train ]
    colors_test = [ color_encoder.encode_color_context(colors) for colors in raw_colors_test ]
    colors_bo = [ color_encoder.encode_color_context(colors) for colors in b_raw_colors_test ]

    return colors_train, colors_test, colors_bo
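
For readers without access to `mu.tokenize_colour_description`, the following is a minimal sketch of what such a helper might do. It assumes that the third positional argument (`True` in the calls above) asks for start and end symbols to be added, and that `START_SYMBOL` and `END_SYMBOL` have the values `"<s>"` and `"</s>"`. Both points are assumptions, and `wrap_tokens` is a hypothetical name.

```python
# Assumed values; the notebook imports these names from utils.utils.
START_SYMBOL = "<s>"
END_SYMBOL = "</s>"

def wrap_tokens(text, tokenize, add_boundaries=True):
    """Tokenize `text` and optionally surround the tokens with boundary symbols."""
    tokens = list(tokenize(text))
    if add_boundaries:
        return [START_SYMBOL] + tokens + [END_SYMBOL]
    return tokens

# Whitespace splitting stands in for a transformer tokenizer's `tokenize` method.
wrapped = wrap_tokens("dark blue", str.split)
print(wrapped)
```

Passing the tokenizer's own `tokenize` method in place of `str.split` would yield subword tokens that line up with the vocabulary of the extracted embeddings.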

## Experiments

Results from the experiments:

| Model | Embeddings | Unit | h-params | Training Log | Test Results | Bake-off Results |
| --- | --- | --- | --- | --- | --- | --- |
| BERT 'bert-base-cased' | ResNet18 + static | LSTM | default | Stopping after epoch 16. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 186.72212171554565 CPU times: user 51min 9s, sys: 21min 52s, total: 1h 13min 1s Wall time: 23min 55s | {'listener_accuracy': 0.3704996169886799, 'corpus_bleu': 0.4271658095590124} | {'listener_accuracy': 0.34711964549483015, 'corpus_bleu': 0.6166938495719256} |
| BERT 'bert-base-cased' | ResNet18+Fourier + static | LSTM | default | Stopping after epoch 112. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 46.370091795921326 CPU times: user 6h 7min 3s, sys: 2h 36min 50s, total: 8h 43min 53s Wall time: 2h 53min 23s | {'listener_accuracy': 0.8209209294408035, 'corpus_bleu': 0.5904509830096333} | {'listener_accuracy': 0.9148202855736091, 'corpus_bleu': 0.7719649525541963} |
| BERT 'bert-base-cased' | ResNet18+Fourier + static | LSTM | hid_dim=100 | Stopping after epoch 89. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 72.01054859161377 CPU times: user 5h 32min 19s, sys: 2h 8min 14s, total: 7h 40min 33s Wall time: 2h 47min 38s | {'listener_accuracy': 0.8190484296535875, 'corpus_bleu': 0.6143887490272109} | {'listener_accuracy': 0.8956179222058099, 'corpus_bleu': 0.7866291971865217} |
| XLNet 'xlnet-base-cased' | ResNet18 + static | LSTM | default | Stopping after epoch 145. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 23.32861888408661 CPU times: user 8h 17min 22s, sys: 3h 23min 31s, total: 11h 40min 53s Wall time: 3h 46min 43s | {'listener_accuracy': 0.7549578687547877, 'corpus_bleu': 0.5648871242223105} | {'listener_accuracy': 0.8670605612998523, 'corpus_bleu': 0.7533564516719005} |
| XLNet 'xlnet-base-cased' | ResNet18+Fourier + static | LSTM | default | Stopping after epoch 125. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 33.39379119873047 CPU times: user 7h 32min 8s, sys: 3h 6min 32s, total: 10h 38min 40s Wall time: 3h 28min 54s | {'listener_accuracy': 0.835645586858456, 'corpus_bleu': 0.5903786021680395} | {'listener_accuracy': 0.9113737075332349, 'corpus_bleu': 0.766034610276092} |
| XLNet 'xlnet-base-cased' | ResNet18+Fourier + static | LSTM | hid_dim=100 | Stopping after epoch 76. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 89.37767934799194 CPU times: user 5h 32min 12s, sys: 2h 1min 52s, total: 7h 34min 4s Wall time: 2h 42min 29s | {'listener_accuracy': 0.8366669503787556, 'corpus_bleu': 0.581864699762294} | {'listener_accuracy': 0.9162973904480551, 'corpus_bleu': 0.7633163869123465} |
| RoBERTa 'roberta-base' | ResNet18 + static | LSTM | default | Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 197.75724983215332 CPU times: user 52min 7s, sys: 19min 12s, total: 1h 11min 20s Wall time: 22min 11s | {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} | {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} |
| RoBERTa 'roberta-base' | ResNet18+Fourier + static | LSTM | default | Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 201.13099479675293 CPU times: user 50min 9s, sys: 19min 13s, total: 1h 9min 22s Wall time: 21min 49s | {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} | {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} |
| RoBERTa 'roberta-base' | ResNet18+Fourier + static | LSTM | hid_dim=100 | Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 196.342670917511 CPU times: user 1h 38s, sys: 20min 40s, total: 1h 21min 18s Wall time: 28min 6s | {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} | {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} |
| ELECTRA 'google/electra-small-discriminator' | ResNet18 + static | LSTM | default | Stopping after epoch 200. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 12.498226076364517 CPU times: user 9h 58min 7s, sys: 3h 51min 46s, total: 13h 49min 54s Wall time: 4h 28min 38s | {'listener_accuracy': 0.8018554770618777, 'corpus_bleu': 0.5976330421393174} | {'listener_accuracy': 0.8936484490398818, 'corpus_bleu': 0.7839809145958113} |
| ELECTRA 'google/electra-small-discriminator' | ResNet18+Fourier + static | LSTM | default | Stopping after epoch 90. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 70.57668948173523 CPU times: user 4h 9min 47s, sys: 1h 40min 13s, total: 5h 50min Wall time: 1h 54min 18s | {'listener_accuracy': 0.8173461571197549, 'corpus_bleu': 0.6020019877608638} | {'listener_accuracy': 0.913343180699163, 'corpus_bleu': 0.7838522381057177} |
| ELECTRA 'google/electra-small-discriminator' | ResNet18+Fourier + static | LSTM | hid_dim=100 | Stopping after epoch 101. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 55.01819384098053 CPU times: user 6h 42min 17s, sys: 2h 18min 46s, total: 9h 1min 3s Wall time: 3h 14min 16s | {'listener_accuracy': 0.8438164950208529, 'corpus_bleu': 0.6099872860729817} | {'listener_accuracy': 0.9192516001969473, 'corpus_bleu': 0.7889350641465928} |
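
To make the rows easier to compare, the sketch below ranks the four `hid_dim=100` configurations by their bake-off `corpus_bleu`. The numbers are copied from the table; the dictionary layout and variable names are illustrative only. BLEU is used as the sort key because the RoBERTa rows pair a listener accuracy of 1.0 with the lowest BLEU in the table, so sorting on accuracy alone would put them first.

```python
# Bake-off scores copied from the hid_dim=100 rows of the results table.
bake_off_hid100 = {
    "BERT": {"listener_accuracy": 0.8956179222058099, "corpus_bleu": 0.7866291971865217},
    "XLNet": {"listener_accuracy": 0.9162973904480551, "corpus_bleu": 0.7633163869123465},
    "RoBERTa": {"listener_accuracy": 1.0, "corpus_bleu": 0.2},
    "ELECTRA": {"listener_accuracy": 0.9192516001969473, "corpus_bleu": 0.7889350641465928},
}

# Sort from highest to lowest bake-off BLEU.
ranked = sorted(
    bake_off_hid100.items(),
    key=lambda item: item[1]["corpus_bleu"],
    reverse=True,
)

for name, scores in ranked:
    print(f"{name}: bleu={scores['corpus_bleu']:.4f}, acc={scores['listener_accuracy']:.4f}")
```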

In [17]:
%time colors_train, colors_test, colors_bo = create_colours_data()

Using cache found in /Users/antongochev/.cache/torch/hub/pytorch_vision_v0.6.0


CPU times: user 3h 52min 57s, sys: 23min 55s, total: 4h 16min 53s
Wall time: 2h 53min 20s


In [29]:
colors = {'train': colors_train, 'test': colors_test, 'bo': colors_bo}

In [19]:
hidden_dims = [50, 100, 150, 250]
start_index = 7000
end_index = 15000

### BERT Embeddings

In [18]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
bert_model = BertModel.from_pretrained('bert-base-cased')

In [19]:
tokens_train, tokens_test, tokens_bo = create_text_tokens(bert_tokenizer)

In [20]:
%time bert_embeddings, bert_vocab = mu.extract_input_embeddings(texts_train, bert_model, bert_tokenizer, add_special_tokens=True)

CPU times: user 7.51 s, sys: 35.6 ms, total: 7.54 s
Wall time: 7.55 s


***Experiment with different hidden_dim size***

| Results - BERT LSTM - static embeddings + ResNet+Fourier |
| ------- |
| Stopping after epoch 44. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 46.76516532897949 <br /> train 50 - {'listener_accuracy': 0.45654949357392116, 'corpus_bleu': 0.3822134017057766} <br /> bake-off 50 - {'listener_accuracy': 0.42146725750861647, 'corpus_bleu': 0.5439987449504363} |
| Stopping after epoch 53. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 46.292728900909424 <br /> train 100 - {'listener_accuracy': 0.5274491446080517, 'corpus_bleu': 0.5350260003935228} <br /> bake-off 100 - {'listener_accuracy': 0.5701624815361891, 'corpus_bleu': 0.7108937241938815} |
| Stopping after epoch 94. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 41.2493839263916 <br /> train 150 - {'listener_accuracy': 0.5913694782534684, 'corpus_bleu': 0.5246605959686265} <br /> bake-off 150 - {'listener_accuracy': 0.6814377154111275, 'corpus_bleu': 0.701929740281104} |
| Stopping after epoch 62. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 43.86564564704895 <br /> train 250 - {'listener_accuracy': 0.597242318495191, 'corpus_bleu': 0.5090204246106848} <br /> bake-off 250 - {'listener_accuracy': 0.6578040374199902, 'corpus_bleu': 0.6742907898833734} |

In [31]:
tokens = {'train': tokens_train, 'test': tokens_test, 'bo': tokens_bo}

# for GRU set unit='GRU'
%time eh.run_hiddim_options(hidden_dims, start_index, end_index, bert_vocab, bert_embeddings, colors, tokens)
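
The source of `eh.run_hiddim_options` is not shown in this notebook. A minimal sketch of a hidden-dimension sweep, assuming that `start_index` and `end_index` select a slice of the training data and that one model is trained per value in `hidden_dims`, might look like the following. `run_hidden_dim_sweep` and `fake_fit_and_score` are hypothetical names, and the scoring function is a stub that trains nothing.

```python
def run_hidden_dim_sweep(hidden_dims, start, end, train_colors, train_tokens, fit_and_score):
    """Call `fit_and_score` once per hidden size on the [start:end) training slice."""
    colors_slice = train_colors[start:end]
    tokens_slice = train_tokens[start:end]
    results = {}
    for dim in hidden_dims:
        results[dim] = fit_and_score(dim, colors_slice, tokens_slice)
    return results

def fake_fit_and_score(dim, colors, tokens):
    """Stub: a real version would fit a describer and return its evaluation dict."""
    return {"hidden_dim": dim, "n_train": len(colors)}

toy_colors = list(range(20))
toy_tokens = [["<s>", "blue", "</s>"]] * 20
sweep = run_hidden_dim_sweep([50, 100], 5, 15, toy_colors, toy_tokens, fake_fit_and_score)
print(sweep)
```

Under this reading, the notebook's `start_index = 7000` and `end_index = 15000` would restrict each sweep run to 8,000 training examples, which would keep the sweep much cheaper than a full training run.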

***Baseline model using BERT pretrained embeddings and vocab***

In [88]:
model = BaselineLSTMDescriber(
    bert_vocab,
    embedding=bert_embeddings,
    early_stopping=True,
    hidden_dim=100
)

In [89]:
%time _ = model.fit(colors_train, tokens_train)

Stopping after epoch 20. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 19.67329978942871

CPU times: user 3min 46s, sys: 1min 7s, total: 4min 53s
Wall time: 4min 16s
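
The stopping message in the log refers to `tol=1e-05` and a limit of 10 epochs without improvement. The sketch below shows one common way to implement that kind of early stopping. It is a generic illustration, not the code of the model class used here: the function name, the `patience` parameter, and the exact comparison are assumptions.

```python
def epochs_until_stop(validation_scores, tol=1e-5, patience=10):
    """Return the 1-based epoch at which training would stop.

    An epoch counts as an improvement only if its score beats the best
    score so far by more than `tol`. Training stops once more than
    `patience` consecutive epochs have passed without improvement.
    """
    best = None
    stale = 0
    for epoch, score in enumerate(validation_scores, start=1):
        if best is None or score > best + tol:
            best = score
            stale = 0
        else:
            stale += 1
        if stale > patience:
            return epoch
    return len(validation_scores)

# Improves for three epochs, then plateaus.
stop_epoch = epochs_until_stop([0.1, 0.2, 0.3] + [0.3] * 20)
print(stop_epoch)
```

With this logic, a run whose validation score never improves after the first epoch stops at epoch 12.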


Evaluate on test data

In [90]:
%time model.evaluate(colors_test, tokens_test)

CPU times: user 55.8 s, sys: 1.47 s, total: 57.2 s
Wall time: 56.5 s


{'listener_accuracy': 0.4415694952761937, 'corpus_bleu': 0.15}
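
As a reading aid for the `listener_accuracy` figure, the sketch below computes the fraction of examples in which the target colour receives the highest score among the three colours in the context. It assumes that the model yields one score per colour and that the target sits at index 2. The actual computation inside `model.evaluate` is not shown in this notebook and may differ.

```python
def listener_accuracy(score_rows, target_index=2):
    """Fraction of rows whose highest-scoring colour is the target."""
    correct = 0
    for scores in score_rows:
        predicted = max(range(len(scores)), key=scores.__getitem__)
        if predicted == target_index:
            correct += 1
    return correct / len(score_rows)

# Made-up scores for four examples, three colours each.
toy_scores = [
    [0.1, 0.2, 0.7],  # target wins
    [0.5, 0.3, 0.2],  # first distractor wins
    [0.2, 0.2, 0.6],  # target wins
    [0.1, 0.8, 0.1],  # second distractor wins
]
toy_accuracy = listener_accuracy(toy_scores)
print(toy_accuracy)
```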

Evaluate on bake-off data

In [91]:
%time model.evaluate(colors_bo, tokens_bo)

CPU times: user 9.11 s, sys: 200 ms, total: 9.31 s
Wall time: 9.22 s


{'listener_accuracy': 0.4741506646971935, 'corpus_bleu': 0.15}

### XLNet Embeddings

In [20]:
xlnet_tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
xlnet_model = XLNetModel.from_pretrained('xlnet-base-cased')

In [21]:
tokens_train, tokens_test, tokens_bo = create_text_tokens(xlnet_tokenizer)

In [22]:
%time xlnet_embeddings, xlnet_vocab = mu.extract_input_embeddings(texts_train, xlnet_model, xlnet_tokenizer, add_special_tokens=True)

CPU times: user 7.52 s, sys: 35.2 ms, total: 7.55 s
Wall time: 7.56 s


***Experiment with different hidden_dim size***

| Results - XLNet LSTM - static embeddings + ResNet+Fourier |
| ------- |
| Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 51.43806719779968 <br /> train 50 - {'listener_accuracy': 0.40028938633075156, 'corpus_bleu': 0.11601226111917491} <br /> bake-off 50 - {'listener_accuracy': 0.34859675036927623, 'corpus_bleu': 0.12250562016041742} |
| Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 50.56151986122131 <br /> train 100 - {'listener_accuracy': 0.38420291088603287, 'corpus_bleu': 0.10000000000000002} <br /> bake-off 100 - {'listener_accuracy': 0.38010832102412606, 'corpus_bleu': 0.10000000000000002} |
| Stopping after epoch 65. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 44.73309278488159 <br /> train 150 - {'listener_accuracy': 0.5334922121031577, 'corpus_bleu': 0.4307617849191753} <br /> bake-off 150 - {'listener_accuracy': 0.6011816838995568, 'corpus_bleu': 0.3972355769230769} |
| Stopping after epoch 81. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 41.55567979812622 <br /> train 250 - {'listener_accuracy': 0.5929015235339178, 'corpus_bleu': 0.5084326142252633} <br /> bake-off 250 - {'listener_accuracy': 0.6932545544066963, 'corpus_bleu': 0.6762724153392196} |

In [20]:
tokens = {'train': tokens_train, 'test': tokens_test, 'bo': tokens_bo}

In [22]:
#for GRU set unit='GRU'
%time eh.run_hiddim_options(hidden_dims, start_index, end_index, xlnet_vocab, xlnet_embeddings, colors, tokens)

***Baseline model using XLNet pretrained embeddings and vocab***

In [23]:
model = BaselineLSTMDescriber(
    xlnet_vocab,
    embedding=xlnet_embeddings,
    early_stopping=True,
    hidden_dim=100
)

In [55]:
%time _ = model.fit(colors_train, tokens_train)

Stopping after epoch 76. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 89.37767934799194

CPU times: user 5h 32min 12s, sys: 2h 1min 52s, total: 7h 34min 4s
Wall time: 2h 42min 29s


Evaluate on test data

In [56]:
%time model.evaluate(colors_test, tokens_test)

CPU times: user 1min 2s, sys: 3.85 s, total: 1min 6s
Wall time: 58.2 s


{'listener_accuracy': 0.8366669503787556, 'corpus_bleu': 0.581864699762294}

Evaluate on bake-off data

In [57]:
%time model.evaluate(colors_bo, tokens_bo)

CPU times: user 9.78 s, sys: 550 ms, total: 10.3 s
Wall time: 9.18 s


{'listener_accuracy': 0.9162973904480551, 'corpus_bleu': 0.7633163869123465}

### RoBERTa Embeddings

In [58]:
roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
roberta_model = RobertaModel.from_pretrained('roberta-base')

In [59]:
tokens_train, tokens_test, tokens_bo = create_text_tokens(roberta_tokenizer)

In [60]:
%time roberta_embeddings, roberta_vocab = mu.extract_input_embeddings(texts_train, roberta_model, roberta_tokenizer, add_special_tokens=True)

CPU times: user 7.62 s, sys: 111 ms, total: 7.73 s
Wall time: 7.79 s


***Baseline model using RoBERTa pretrained embeddings and vocab***

In [61]:
model = BaselineLSTMDescriber(
    roberta_vocab,
    embedding=roberta_embeddings,
    early_stopping=True,
    hidden_dim=100
)

In [62]:
%time _ = model.fit(colors_train, tokens_train)

Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 196.342670917511

CPU times: user 1h 38s, sys: 20min 40s, total: 1h 21min 18s
Wall time: 28min 6s


Evaluate on test data

In [63]:
%time model.evaluate(colors_test, tokens_test)

CPU times: user 1min 13s, sys: 5.76 s, total: 1min 19s
Wall time: 1min 13s


{'listener_accuracy': 1.0, 'corpus_bleu': 0.2}

Evaluate on bake-off data

In [64]:
%time model.evaluate(colors_bo, tokens_bo)

CPU times: user 10.7 s, sys: 544 ms, total: 11.2 s
Wall time: 10.1 s


{'listener_accuracy': 1.0, 'corpus_bleu': 0.2}

### ELECTRA Embeddings

In [25]:
electra_tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
electra_model = ElectraModel.from_pretrained('google/electra-small-discriminator')

In [26]:
tokens_train, tokens_test, tokens_bo = create_text_tokens(electra_tokenizer)

In [27]:
%time electra_embeddings, electra_vocab = mu.extract_input_embeddings(texts_train, electra_model, electra_tokenizer, add_special_tokens=True)

CPU times: user 8.52 s, sys: 6.04 ms, total: 8.53 s
Wall time: 8.53 s


***Experiment with different hidden_dim size***

| Results - ELECTRA LSTM - static embeddings + ResNet+Fourier |
| ------- |
| Stopping after epoch 23. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 47.63471579551697 <br /> train 50 - {'listener_accuracy': 0.3909268873946719, 'corpus_bleu': 0.15} <br /> bake-off 50 - {'listener_accuracy': 0.3845396356474643, 'corpus_bleu': 0.15} |
| Stopping after epoch 21. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 49.45027256011963 <br /> train 100 - {'listener_accuracy': 0.42471699719125033, 'corpus_bleu': 0.5404861636039842} <br /> bake-off 100 - {'listener_accuracy': 0.4244214672575086, 'corpus_bleu': 0.6658470617162527} |
| Stopping after epoch 63. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 41.79938507080078 <br /> train 150 - {'listener_accuracy': 0.5634522086986127, 'corpus_bleu': 0.5305960626612526} <br /> bake-off 150 - {'listener_accuracy': 0.6149679960610537, 'corpus_bleu': 0.6977797477288923} |
| Stopping after epoch 59. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 44.45513558387756 <br /> train 250 - {'listener_accuracy': 0.5720486849944676, 'corpus_bleu': 0.5589953211796204} <br /> bake-off 250 - {'listener_accuracy': 0.6297390448055146, 'corpus_bleu': 0.7396116726551871} |

In [20]:
tokens = {'train': tokens_train, 'test': tokens_test, 'bo': tokens_bo}

In [22]:
#for GRU set unit='GRU'
%time eh.run_hiddim_options(hidden_dims, start_index, end_index, electra_vocab, electra_embeddings, colors, tokens)

***Baseline model using ELECTRA pretrained embeddings and vocab***

In [68]:
model = BaselineLSTMDescriber(
    electra_vocab,
    embedding=electra_embeddings,
    early_stopping=True,
    hidden_dim=100
)

In [28]:
run_options(hidden_dims, 7000, 15000, electra_vocab, electra_embeddings)

Stopping after epoch 23. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 47.63471579551697

train 50 - {'listener_accuracy': 0.3909268873946719, 'corpus_bleu': 0.15}
bake-off 50 - {'listener_accuracy': 0.3845396356474643, 'corpus_bleu': 0.15}


Stopping after epoch 21. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 49.45027256011963

train 100 - {'listener_accuracy': 0.42471699719125033, 'corpus_bleu': 0.5404861636039842}
bake-off 100 - {'listener_accuracy': 0.4244214672575086, 'corpus_bleu': 0.6658470617162527}


Stopping after epoch 63. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 41.79938507080078

train 150 - {'listener_accuracy': 0.5634522086986127, 'corpus_bleu': 0.5305960626612526}
bake-off 150 - {'listener_accuracy': 0.6149679960610537, 'corpus_bleu': 0.6977797477288923}


Stopping after epoch 59. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 44.45513558387756

train 250 - {'listener_accuracy': 0.5720486849944676, 'corpus_bleu': 0.5589953211796204}
bake-off 250 - {'listener_accuracy': 0.6297390448055146, 'corpus_bleu': 0.7396116726551871}


In [69]:
%time _ = model.fit(colors_train, tokens_train)

Stopping after epoch 101. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 55.01819384098053

CPU times: user 6h 42min 17s, sys: 2h 18min 46s, total: 9h 1min 3s
Wall time: 3h 14min 16s


Evaluate on test data

In [70]:
%time model.evaluate(colors_test, tokens_test)

CPU times: user 56.5 s, sys: 3.72 s, total: 1min
Wall time: 52.7 s


{'listener_accuracy': 0.8438164950208529, 'corpus_bleu': 0.6099872860729817}

Evaluate on bake-off data

In [71]:
%time model.evaluate(colors_bo, tokens_bo)

CPU times: user 8.88 s, sys: 523 ms, total: 9.4 s
Wall time: 8.4 s


{'listener_accuracy': 0.9192516001969473, 'corpus_bleu': 0.7889350641465928}