In [None]:
__authors__ = "Anton Gochev, Jaro Habr, Yan Jiang, Samuel Kahn"
__version__ = "XCS224u, Stanford, Spring 2021"

## Colours with static embeddings

1. [Setup](#Setup)
1. [Dataset](#Dataset)
    1. [Filtered Corpus](#Filtered-Corpus)
    1. [Bake-Off Corpus](#Bake-Off-Corpus)
1. [Baseline-System](#Baseline-System)
1. [Experiments](#Experiments)
    1. [BERT Embeddings](#BERT-Embeddings)
    1. [XLNet Embeddings](#XLNet-Embeddings)
    1. [RoBERTa Embeddings](#RoBERTa-Embeddings)
    1. [ELECTRA Embeddings](#ELECTRA-Embeddings)

## Setup

This notebook explores how the baseline model performs with different pre-trained static embeddings extracted from transformer models: BERT, XLNet, RoBERTa and ELECTRA.
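Here "static" means one fixed vector per token, read from a model's input embedding layer rather than from its contextual hidden states. The sketch below shows the idea on a toy matrix. `select_static_embeddings` is a hypothetical helper, and we assume (without having checked) that `mu.extract_input_embeddings` works along similar lines. For a Hugging Face model, the full matrix is available via `model.get_input_embeddings().weight`.

```python
from typing import Callable, Dict, List, Sequence, Tuple


def select_static_embeddings(
    texts: Sequence[str],
    tokenize: Callable[[str], List[str]],
    token_to_id: Dict[str, int],
    embedding_matrix: List[List[float]],
) -> Tuple[List[List[float]], List[str]]:
    """Return (embeddings, vocab) restricted to tokens seen in `texts`.

    Row i of `embeddings` is the static vector for vocab[i], copied from
    the full input embedding matrix of a (hypothetical) pretrained model.
    """
    vocab: List[str] = []
    seen = set()
    for text in texts:
        for tok in tokenize(text):
            # Keep each known token once, in order of first appearance.
            if tok in token_to_id and tok not in seen:
                seen.add(tok)
                vocab.append(tok)
    embeddings = [list(embedding_matrix[token_to_id[tok]]) for tok in vocab]
    return embeddings, vocab


# Toy stand-ins for a tokenizer vocabulary and its embedding matrix.
toy_token_to_id = {"[UNK]": 0, "dark": 1, "blue": 2, "green": 3}
toy_matrix = [[0.0, 0.0], [0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
toy_emb, toy_vocab = select_static_embeddings(
    ["dark blue", "blue"], str.split, toy_token_to_id, toy_matrix
)
print(toy_vocab)  # ['dark', 'blue']
print(toy_emb)    # [[0.1, 0.2], [0.3, 0.4]]
```

Tokens missing from the toy vocabulary are simply skipped here; a real implementation would more likely map them to an unknown-token vector.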

In [None]:
from utils.colors import ColorsCorpusReader
import os
import pandas as pd
from sklearn.model_selection import train_test_split
import torch
from utils.torch_color_describer import ContextualColorDescriber, create_example_dataset
import utils.utils as utils
from utils.utils import UNK_SYMBOL, START_SYMBOL, END_SYMBOL
import matplotlib.pyplot as plt
import matplotlib.patches as mpatch
import numpy as np
from baseline.model import (
    BaselineTokenizer, BaselineColorEncoder, BaselineDescriber,
    BaselineEmbedding, BaselineLSTMDescriber, GloVeEmbedding
)
from experiment.vision import ConvolutionalColorEncoder

from transformers import (
    BertTokenizer, BertModel,
    XLNetTokenizer, XLNetModel,
    RobertaTokenizer, RobertaModel,
    ElectraTokenizer, ElectraModel,    
)

import utils.model_utils as mu
import experiment.helper as eh

In [None]:
utils.fix_random_seeds()

## Dataset

This exploration of the dataset counts the examples in each condition and plots the word distribution to check for data imbalance.

### Filtered Corpus

The filtered corpus is the full dataset used in assignment 4. The following code looks at the composition of the dataset: the number of examples in each condition and the word counts of the color descriptions.

In [None]:
COLORS_SRC_FILENAME = os.path.join(
    "data", "colors", "filteredCorpus.csv"
)

In [None]:
corpus = ColorsCorpusReader(
    COLORS_SRC_FILENAME,
    word_count=None,
    normalize_colors=True
)

In [None]:
examples = list(corpus.read())

In [None]:
len(examples)

In [None]:
subset_examples = mu.extract_colour_examples(examples, from_word_count=5)

In [None]:
close_examples = [example for example in examples if example.condition == "close"]
split_examples = [example for example in examples if example.condition == "split"]
far_examples = [example for example in examples if example.condition == "far"]

In [None]:
print(f"close: {len(close_examples)}")
print(f"split: {len(split_examples)}")
print(f"far: {len(far_examples)}")
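The three list comprehensions above scan the corpus once per condition. `collections.Counter` produces the same counts in a single pass and would also surface any unexpected condition label. The sketch uses a hypothetical `ToyExample` stand-in that models only the `condition` and `contents` attributes this notebook reads from the corpus examples.

```python
from collections import Counter
from typing import List, NamedTuple


class ToyExample(NamedTuple):
    # Hypothetical stand-in for a corpus example; only the attributes
    # used in this notebook (condition, contents) are modelled.
    condition: str
    contents: str


def count_conditions(examples: List[ToyExample]) -> Counter:
    """Count examples per context condition (close / split / far)."""
    return Counter(ex.condition for ex in examples)


toy_examples = [
    ToyExample("close", "the darker blue"),
    ToyExample("far", "green"),
    ToyExample("split", "bright red"),
    ToyExample("close", "teal"),
]
toy_counts = count_conditions(toy_examples)
for condition in ("close", "split", "far"):
    print(f"{condition}: {toy_counts[condition]}")
# prints close: 2, split: 1, far: 1 (one per line)
```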

To understand the datasets (training and bake-off) in more detail, refer to [colors_in_context.ipynb](colors_in_context.ipynb), which shows the distribution of the colour examples across the different splits.

### Bake-Off Corpus

The following code analyses the bake-off dataset. We will look at the number of examples in each condition as well as the word counts used to describe the colors.

In [None]:
BAKE_OFF_COLORS_SRC_FILENAME = os.path.join(
    "data", "colors", "cs224u-colors-bakeoff-data.csv"
)

In [None]:
bake_off_corpus = ColorsCorpusReader(
    BAKE_OFF_COLORS_SRC_FILENAME,
    word_count=None,
    normalize_colors=True
)

In [None]:
bake_off_examples = list(bake_off_corpus.read())
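Both dataset descriptions mention the word counts of the colour descriptions, but no cell computes them. A minimal sketch, assuming whitespace tokenization and that each example's text is a plain string in `ex.contents`; on the real data one could pass `(ex.contents for ex in bake_off_examples)` instead of the toy list.

```python
from collections import Counter
from typing import Iterable


def word_count_distribution(descriptions: Iterable[str]) -> Counter:
    """Map description length (whitespace-separated words) to frequency."""
    return Counter(len(text.split()) for text in descriptions)


toy_descriptions = ["blue", "dark blue", "the darker blue", "teal", "bright red"]
toy_dist = word_count_distribution(toy_descriptions)
for n_words in sorted(toy_dist):
    print(n_words, toy_dist[n_words])
# prints "1 2", "2 2", "3 1" (one per line)
```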

In [None]:
import pickle

# Load the ResNet18 colour embeddings stored in the pickle file
emb_path = os.path.join(
    "data", "colors", "resnet18_color_embeddings.pickle"
)
with open(emb_path, "rb") as f:
    resnet_emb = pickle.load(f)


In [None]:
len(resnet_emb)

## Baseline-System

This baseline system is based on assignment 4; the experiments below swap in different pretrained token embeddings and the matching tokenized sequences.

### Model training - full dataset

The full color context dataset is used for final baseline model training.

In [None]:
rawcols, texts = zip(*[[ex.colors, ex.contents] for ex in examples]) 
raw_colors_test_bo, texts_test_bo = zip(*[[ex.colors, ex.contents] for ex in bake_off_examples])

raw_colors_train, raw_colors_test, texts_train, texts_test = \
        train_test_split(rawcols, texts)
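`train_test_split` is called with its defaults, so the data are shuffled and a quarter of the examples (rounded up) go to the test set. Reproducibility therefore presumably relies on the seed set by `utils.fix_random_seeds()`; passing `random_state` explicitly would remove that dependency. The pure-Python sketch below mimics the default proportions so they are easy to see; it is an illustration, not scikit-learn's implementation, and yields a different split.

```python
import math
import random
from typing import List, Sequence, Tuple


def shuffled_split(
    xs: Sequence, ys: Sequence, test_size: float = 0.25, seed: int = 0
) -> Tuple[List, List, List, List]:
    """Shuffle paired data and split it into train/test parts.

    Follows train_test_split's default proportions (test count rounded
    up) and return order, but uses Python's own RNG.
    """
    assert len(xs) == len(ys)
    idx = list(range(len(xs)))
    random.Random(seed).shuffle(idx)
    n_test = math.ceil(test_size * len(xs))
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    return (
        [xs[i] for i in train_idx], [xs[i] for i in test_idx],
        [ys[i] for i in train_idx], [ys[i] for i in test_idx],
    )


x_tr, x_te, y_tr, y_te = shuffled_split(list(range(10)), list("abcdefghij"))
print(len(x_tr), len(x_te))  # 7 3
```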

## Experiments

Results from the experiments:

| Model | Embeddings | Unit | h-params | Protocol | Training Results | Bake-off Results |
| --- | --- | --- | --- | --- | --- | --- |
| BERT 'bert-base-cased' | ResNet18 + static | LSTM | default | Stopping after epoch 16. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 186.72212171554565 CPU times: user 51min 9s, sys: 21min 52s, total: 1h 13min 1s Wall time: 23min 55s | {'listener_accuracy': 0.3704996169886799, 'corpus_bleu': 0.4271658095590124} | {'listener_accuracy': 0.34711964549483015, 'corpus_bleu': 0.6166938495719256} |
| BERT 'bert-base-cased' | ResNet18+Fourier + static | LSTM | default | Stopping after epoch 112. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 46.370091795921326 CPU times: user 6h 7min 3s, sys: 2h 36min 50s, total: 8h 43min 53s Wall time: 2h 53min 23s | {'listener_accuracy': 0.8209209294408035, 'corpus_bleu': 0.5904509830096333} | {'listener_accuracy': 0.9148202855736091, 'corpus_bleu': 0.7719649525541963} |
| BERT 'bert-base-cased' | ResNet18+Fourier + static | LSTM | hid_dim=100 | Stopping after epoch 89. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 72.01054859161377 CPU times: user 5h 32min 19s, sys: 2h 8min 14s, total: 7h 40min 33s Wall time: 2h 47min 38s | {'listener_accuracy': 0.8190484296535875, 'corpus_bleu': 0.6143887490272109} | {'listener_accuracy': 0.8956179222058099, 'corpus_bleu': 0.7866291971865217} |
| XLNet 'xlnet-base-cased' | ResNet18 + static | LSTM | default | Stopping after epoch 145. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 23.32861888408661 CPU times: user 8h 17min 22s, sys: 3h 23min 31s, total: 11h 40min 53s Wall time: 3h 46min 43s | {'listener_accuracy': 0.7549578687547877, 'corpus_bleu': 0.5648871242223105} | {'listener_accuracy': 0.8670605612998523, 'corpus_bleu': 0.7533564516719005} |
| XLNet 'xlnet-base-cased' | ResNet18+Fourier + static | LSTM | default | Stopping after epoch 125. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 33.39379119873047 CPU times: user 7h 32min 8s, sys: 3h 6min 32s, total: 10h 38min 40s Wall time: 3h 28min 54s | {'listener_accuracy': 0.835645586858456, 'corpus_bleu': 0.5903786021680395} | {'listener_accuracy': 0.9113737075332349, 'corpus_bleu': 0.766034610276092} |
| XLNet 'xlnet-base-cased' | ResNet18+Fourier + static | LSTM | hid_dim=100 | Stopping after epoch 76. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 89.37767934799194 CPU times: user 5h 32min 12s, sys: 2h 1min 52s, total: 7h 34min 4s Wall time: 2h 42min 29s | {'listener_accuracy': 0.8366669503787556, 'corpus_bleu': 0.581864699762294} | {'listener_accuracy': 0.9162973904480551, 'corpus_bleu': 0.7633163869123465} |
| XLNet 'xlnet-base-cased' | ResNet18+Fourier + static | LSTM | hid_dim=250 | Stopping after epoch 56. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 115.53824925422668 CPU times: user 6h 38min 43s, sys: 1h 55min 55s, total: 8h 34min 38s Wall time: 3h 48min 39s | {'listener_accuracy': 0.6549493573921185, 'corpus_bleu': 0.2996265804208502} | {'listener_accuracy': 0.7311669128508124, 'corpus_bleu': 0.38023324553545096} |
| XLNet 'xlnet-base-cased' (2) | ResNet18+Fourier + static | LSTM | hid_dim=250 | Stopping after epoch 50. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 127.68089389801025 CPU times: user 4h 43min 49s, sys: 1h 25min 23s, total: 6h 9min 12s Wall time: 2h 33min 41s | {'listener_accuracy': 0.8478168354753596, 'corpus_bleu': 0.5908243462277566} | {'listener_accuracy': 0.9148202855736091, 'corpus_bleu': 0.770364787937182} |
| XLNet 'xlnet-base-cased' | ResNet18+Fourier + static | GRU | hid_dim=250 | Stopping after epoch 33. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 148.86750602722168 CPU times: user 3h 22min 40s, sys: 1h 1min 45s, total: 4h 24min 25s Wall time: 1h 53min 35s 39s | {'listener_accuracy': 0.6495020852838539, 'corpus_bleu': 0.29153560410619433} | {'listener_accuracy': 0.7257508616445101, 'corpus_bleu': 0.3740184682732033} |
| XLNet 'xlnet-base-cased' (2) | ResNet18+Fourier + static | GRU | hid_dim=250 | Stopping after epoch 40. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 143.11642217636108 CPU times: user 3h 26min 43s, sys: 1h 4min 33s, total: 4h 31min 16s Wall time: 1h 50min 59s | {'listener_accuracy': 0.8321559281640991, 'corpus_bleu': 0.5825406845918764} | {'listener_accuracy': 0.7257508616445101, 'corpus_bleu': 0.3740184682732033} |
| XLNet 'xlnet-base-cased' (3) | ResNet18+Fourier + static | GRU | hid_dim=250 | Stopping after epoch 51. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 127.30778098106384 CPU times: user 5h 21min 10s, sys: 1h 36min 17s, total: 6h 57min 27s Wall time: 2h 55min 45s | {'listener_accuracy': 0.8382841092858967, 'corpus_bleu': 0.5808065691051519} | {'listener_accuracy': 0.914327917282127, 'corpus_bleu': 0.7633533127675665} |
| RoBERTa 'roberta-base' | ResNet18 + static | LSTM | default | Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 197.75724983215332 CPU times: user 52min 7s, sys: 19min 12s, total: 1h 11min 20s Wall time: 22min 11s | {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} | {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} |
| RoBERTa 'roberta-base' | ResNet18+Fourier + static | LSTM | default | Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 201.13099479675293 CPU times: user 50min 9s, sys: 19min 13s, total: 1h 9min 22s Wall time: 21min 49s | {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} | {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} |
| RoBERTa 'roberta-base' | ResNet18+Fourier + static | LSTM | hid_dim=100 | Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 196.342670917511 CPU times: user 1h 38s, sys: 20min 40s, total: 1h 21min 18s Wall time: 28min 6s | {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} | {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} |
| ELECTRA 'google/electra-small-discriminator' | ResNet18 + static | LSTM | default | Stopping after epoch 200. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 12.498226076364517 CPU times: user 9h 58min 7s, sys: 3h 51min 46s, total: 13h 49min 54s Wall time: 4h 28min 38s | {'listener_accuracy': 0.8018554770618777, 'corpus_bleu': 0.5976330421393174} | {'listener_accuracy': 0.8936484490398818, 'corpus_bleu': 0.7839809145958113} |
| ELECTRA 'google/electra-small-discriminator' | ResNet18+Fourier + static | LSTM | default | Stopping after epoch 90. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 70.57668948173523 CPU times: user 4h 9min 47s, sys: 1h 40min 13s, total: 5h 50min Wall time: 1h 54min 18s | {'listener_accuracy': 0.8173461571197549, 'corpus_bleu': 0.6020019877608638} | {'listener_accuracy': 0.913343180699163, 'corpus_bleu': 0.7838522381057177} |
| ELECTRA 'google/electra-small-discriminator' | ResNet18+Fourier + static | LSTM | hid_dim=100 | Stopping after epoch 101. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 55.01819384098053 CPU times: user 6h 42min 17s, sys: 2h 18min 46s, total: 9h 1min 3s Wall time: 3h 14min 16s | {'listener_accuracy': 0.8438164950208529, 'corpus_bleu': 0.6099872860729817} | {'listener_accuracy': 0.9192516001969473, 'corpus_bleu': 0.7889350641465928} |
| ELECTRA 'google/electra-small-discriminator' | ResNet18+Fourier + static | LSTM | hid_dim=250 | Stopping after epoch 63. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 108.05819869041443 CPU times: user 6h 51min 1s, sys: 1h 52min 51s, total: 8h 43min 53s Wall time: 3h 45min 36s | {'listener_accuracy': 0.8496893352625755, 'corpus_bleu': 0.6152868875583275} | {'listener_accuracy': 0.9256523879862137, 'corpus_bleu': 0.7888655244301115} |
| ELECTRA 'google/electra-small-discriminator' | ResNet18+Fourier + static | GRU | hid_dim=250 | Stopping after epoch 50. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 127.47516798973083 CPU times: user 5h 2min 18s, sys: 1h 25min 8s, total: 6h 27min 27s Wall time: 2h 42min 58s | {'listener_accuracy': 0.8374329730189803, 'corpus_bleu': 0.59792027562232} | {'listener_accuracy': 0.9148202855736091, 'corpus_bleu': 0.7844196479990682} |
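With this many runs the table is easy to misread. Collecting the scores in a dictionary and sorting them makes the ranking explicit. The snippet below uses four rows from the table above (values copied verbatim) and ranks them by bake-off listener accuracy; the dictionary layout is our own choice.

```python
# Four rows copied from the results table:
# (model, colour features, unit, h-params) -> bake-off listener accuracy.
bake_off_listener_accuracy = {
    ("BERT", "ResNet18+Fourier", "LSTM", "default"): 0.9148202855736091,
    ("XLNet", "ResNet18+Fourier", "LSTM", "hid_dim=100"): 0.9162973904480551,
    ("ELECTRA", "ResNet18+Fourier", "LSTM", "hid_dim=100"): 0.9192516001969473,
    ("ELECTRA", "ResNet18+Fourier", "LSTM", "hid_dim=250"): 0.9256523879862137,
}

# Sort configurations from best to worst listener accuracy.
ranked = sorted(
    bake_off_listener_accuracy.items(), key=lambda kv: kv[1], reverse=True
)
for config, acc in ranked:
    print(f"{acc:.4f}  {' | '.join(config)}")
```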

In [None]:
color_encoder = ConvolutionalColorEncoder(True)
%time colors_train, colors_test, colors_bo = \
    eh.create_colours_sets(color_encoder, raw_colors_train, raw_colors_test, raw_colors_test_bo)
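The results table separates runs using ResNet18 colour embeddings alone from runs that add Fourier features. We assume the `True` argument to `ConvolutionalColorEncoder` switches the Fourier features on, but have not confirmed this in `experiment.vision`. The sketch shows a Fourier representation of the kind used in the assignment 4 baseline, for a colour given as three components scaled to [0, 1] (`normalize_colors=True` above); `fourier_color_features` is a hypothetical name.

```python
import cmath
import math
from typing import List, Sequence


def fourier_color_features(color: Sequence[float]) -> List[float]:
    """Fourier-basis features for one colour with components in [0, 1].

    For every (a, b, c) in {0, 1, 2}^3 we compute
    exp(-2*pi*i*(a*x + b*y + c*z)) and keep its real and imaginary
    parts: 27 complex values give 54 real features.
    """
    x, y, z = color
    values = [
        cmath.exp(-2j * math.pi * (a * x + b * y + c * z))
        for a in range(3) for b in range(3) for c in range(3)
    ]
    return [v.real for v in values] + [v.imag for v in values]


feats = fourier_color_features([0.5, 1.0, 0.25])
print(len(feats))  # 54
```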

In [None]:
colors = {'train': colors_train, 'test': colors_test, 'bo': colors_bo}

In [None]:
hidden_dims = [50, 100, 150, 250]
start_index = 7000
end_index = 15000
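`hidden_dims`, `start_index` and `end_index` are passed to `eh.run_hiddim_options` in the commented-out sweep cells below. We have not inspected that helper; one plausible reading is that it trains one model per hidden size on the slice `[start_index:end_index]` of the training data and evaluates each. The sketch shows that loop shape, with a hypothetical `train_and_score` callback standing in for fitting and evaluating a describer.

```python
from typing import Callable, Dict, Sequence


def sweep_hidden_dims(
    hidden_dims: Sequence[int],
    train_x: Sequence,
    train_y: Sequence,
    start_index: int,
    end_index: int,
    train_and_score: Callable[[int, Sequence, Sequence], float],
) -> Dict[int, float]:
    """Train/evaluate one model per hidden size on a slice of the data.

    `train_and_score` is a hypothetical callback: it receives the hidden
    size plus the sliced inputs/targets and returns a score.
    """
    xs = train_x[start_index:end_index]
    ys = train_y[start_index:end_index]
    return {dim: train_and_score(dim, xs, ys) for dim in hidden_dims}


# Dummy scorer so the loop runs without a model: score = dim / 1000.
toy_scores = sweep_hidden_dims(
    [50, 100, 150, 250], list(range(20)), list(range(20)), 5, 15,
    lambda dim, xs, ys: dim / 1000,
)
best_dim = max(toy_scores, key=toy_scores.get)
print(best_dim)  # 250
```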

### BERT Embeddings

In [None]:
bert_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
bert_model = BertModel.from_pretrained('bert-base-cased')

In [None]:
%time tokens_train, tokens_test, tokens_bo = \
    eh.create_tokens_sets(bert_tokenizer, texts_train, texts_test, texts_test_bo)
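`eh.create_tokens_sets` turns each description into a token sequence with the pretrained tokenizer. In the assignment 4 setup each sequence is wrapped in start and end symbols before training the describer, and we assume this helper does the same with the `START_SYMBOL` and `END_SYMBOL` imported above. The sketch uses a whitespace tokenizer and placeholder symbol values, not the real ones from `utils.utils`.

```python
from typing import Callable, List, Sequence

# Placeholder values; the notebook imports the real START_SYMBOL and
# END_SYMBOL from utils.utils.
TOY_START = "<s>"
TOY_END = "</s>"


def wrap_token_sequences(
    texts: Sequence[str], tokenize: Callable[[str], List[str]]
) -> List[List[str]]:
    """Tokenize each text and add start/end markers around it."""
    return [[TOY_START] + tokenize(text) + [TOY_END] for text in texts]


wrapped = wrap_token_sequences(["dark blue", "teal"], str.split)
print(wrapped)  # [['<s>', 'dark', 'blue', '</s>'], ['<s>', 'teal', '</s>']]
```

With a real subword tokenizer, a single word may become several tokens, so sequence lengths differ between BERT, XLNet, RoBERTa and ELECTRA.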

In [None]:
%time bert_embeddings, bert_vocab = mu.extract_input_embeddings(texts_train, bert_model, bert_tokenizer, add_special_tokens=True)

***Experiment with different hidden_dim size***

| Results - BERT LSTM - static embeddings + ResNet+Fourier |
| ------- |
| Stopping after epoch 44. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 46.76516532897949 <br /> train 50 - {'listener_accuracy': 0.45654949357392116, 'corpus_bleu': 0.3822134017057766} <br /> bake-off 50 - {'listener_accuracy': 0.42146725750861647, 'corpus_bleu': 0.5439987449504363} |
| Stopping after epoch 53. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 46.292728900909424 <br /> train 100 - {'listener_accuracy': 0.5274491446080517, 'corpus_bleu': 0.5350260003935228} <br /> bake-off 100 - {'listener_accuracy': 0.5701624815361891, 'corpus_bleu': 0.7108937241938815} |
| Stopping after epoch 94. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 41.2493839263916 <br /> train 150 - {'listener_accuracy': 0.5913694782534684, 'corpus_bleu': 0.5246605959686265} <br /> bake-off 150 - {'listener_accuracy': 0.6814377154111275, 'corpus_bleu': 0.701929740281104} |
| Stopping after epoch 62. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 43.86564564704895 <br /> train 250 - {'listener_accuracy': 0.597242318495191, 'corpus_bleu': 0.5090204246106848} <br /> bake-off 250 - {'listener_accuracy': 0.6578040374199902, 'corpus_bleu': 0.6742907898833734} |

| Results - BERT GRU - static embeddings + ResNet+Fourier |
| ------- |
| Stopping after epoch 53. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 43.16489589214325 <br /> train 50 - {'listener_accuracy': 0.45374074389309726, 'corpus_bleu': 0.5557992772981372} <br /> bake-off 50 - {'listener_accuracy': 0.48695224027572626, 'corpus_bleu': 0.7114784693608719} |
| Stopping after epoch 50. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 45.62858486175537 <br /> train 100 - {'listener_accuracy': 0.530172780662184, 'corpus_bleu': 0.5187928224826821} <br /> bake-off 100 - {'listener_accuracy': 0.5839487936976858, 'corpus_bleu': 0.6876888605354385} |
| Stopping after epoch 55. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 44.40023684501648 <br /> train 150 - {'listener_accuracy': 0.6003915226827815, 'corpus_bleu': 0.1633409451091387} <br /> bake-off 150 - {'listener_accuracy': 0.6681437715411127, 'corpus_bleu': 0.14443461637230007} |
| Stopping after epoch 93. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 41.163474321365356 <br /> train 250 - {'listener_accuracy': 0.6660992424887224, 'corpus_bleu': 0.5453195319279439} <br /> bake-off 250 - {'listener_accuracy': 0.7154111275233875, 'corpus_bleu': 0.720590807786267} |

*Execution time* <br /> 
CPU times: user 4h 47min 57s, sys: 1h 35min 35s, total: 6h 23min 33s
Wall time: 2h 36min 18s


In [None]:
# Uncomment the code in the next two cells to run the hidden_dim experiments

In [None]:
#tokens = {'train': tokens_train, 'test': tokens_test, 'bo': tokens_bo}

In [None]:
# unit='GRU' runs the GRU variant; change `unit` for the LSTM runs
#%time eh.run_hiddim_options(hidden_dims, start_index, end_index, bert_vocab, bert_embeddings, colors, tokens, unit='GRU')

***Baseline model using BERT pretrained embeddings and vocab***

In [None]:
model = BaselineLSTMDescriber(
    bert_vocab,
    embedding=bert_embeddings,
    early_stopping=True,
    hidden_dim=250
)

In [None]:
%time _ = model.fit(colors_train, tokens_train)

Evaluate on test data

In [None]:
%time model.evaluate(colors_test, tokens_test)

Evaluate on bake-off data

In [None]:
%time model.evaluate(colors_bo, tokens_bo)
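`evaluate` reports `listener_accuracy` and `corpus_bleu`. Listener accuracy asks whether the describer's utterance lets a listener pick out the target colour: the utterance is scored with each candidate treated as the target, and the example counts as correct when the best-scoring candidate is the true target. The simplified sketch below computes that fraction from precomputed scores; how `BaselineLSTMDescriber.evaluate` actually scores candidates is not shown in this notebook, so the score input is an assumption.

```python
from typing import Sequence


def listener_accuracy(
    candidate_scores: Sequence[Sequence[float]], target_indices: Sequence[int]
) -> float:
    """Fraction of examples whose highest-scoring candidate is the target.

    candidate_scores[i][c] is the (log-)probability of example i's
    utterance when candidate colour c is treated as the target.
    """
    correct = 0
    for scores, target in zip(candidate_scores, target_indices):
        predicted = max(range(len(scores)), key=scores.__getitem__)
        correct += int(predicted == target)
    return correct / len(target_indices)


candidate_logprobs = [[-3.2, -1.1, -2.5], [-0.4, -2.0, -0.9], [-1.5, -1.6, -0.2]]
print(listener_accuracy(candidate_logprobs, [1, 2, 2]))  # 0.6666666666666666
```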

### XLNet Embeddings

In [None]:
xlnet_tokenizer = XLNetTokenizer.from_pretrained('xlnet-base-cased')
xlnet_model = XLNetModel.from_pretrained('xlnet-base-cased')

In [None]:
%time tokens_train, tokens_test, tokens_bo = \
    eh.create_tokens_sets(xlnet_tokenizer, texts_train, texts_test, texts_test_bo)

In [None]:
%time xlnet_embeddings, xlnet_vocab = mu.extract_input_embeddings(texts_train, xlnet_model, xlnet_tokenizer, add_special_tokens=True)

***Experiment with different hidden_dim size***

| Results - XLNet LSTM - static embeddings + ResNet+Fourier |
| ------- |
| Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 51.43806719779968 <br /> train 50 - {'listener_accuracy': 0.40028938633075156, 'corpus_bleu': 0.11601226111917491} <br /> bake-off 50 - {'listener_accuracy': 0.34859675036927623, 'corpus_bleu': 0.12250562016041742} |
| Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 50.56151986122131 <br /> train 100 - {'listener_accuracy': 0.38420291088603287, 'corpus_bleu': 0.10000000000000002} <br /> bake-off 100 - {'listener_accuracy': 0.38010832102412606, 'corpus_bleu': 0.10000000000000002} |
| Stopping after epoch 65. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 44.73309278488159 <br /> train 150 - {'listener_accuracy': 0.5334922121031577, 'corpus_bleu': 0.4307617849191753} <br /> bake-off 150 - {'listener_accuracy': 0.6011816838995568, 'corpus_bleu': 0.3972355769230769} |
| Stopping after epoch 81. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 41.55567979812622 <br /> train 250 - {'listener_accuracy': 0.5929015235339178, 'corpus_bleu': 0.5084326142252633} <br /> bake-off 250 - {'listener_accuracy': 0.6932545544066963, 'corpus_bleu': 0.6762724153392196} |

| Results - XLNet GRU - static embeddings + ResNet+Fourier |
| ------- |
| Stopping after epoch 22. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 49.95226716995239 <br /> train 50 - {'listener_accuracy': 0.4078644991063069, 'corpus_bleu': 0.22546245072544133} <br /> bake-off 50 - {'listener_accuracy': 0.4017725258493353, 'corpus_bleu': 0.367694263699723} |
| Stopping after epoch 32. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 47.21669912338257 <br /> train 100 - {'listener_accuracy': 0.4464209719976168, 'corpus_bleu': 0.5122018513333958} <br /> bake-off 100 - {'listener_accuracy': 0.4938453963564746, 'corpus_bleu': 0.688742937744571} |
| Stopping after epoch 35. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 46.30550456047058 <br /> train 150 - {'listener_accuracy': 0.5110222146565665, 'corpus_bleu': 0.4031163071177077} <br /> bake-off 150 - {'listener_accuracy': 0.5558838010832102, 'corpus_bleu': 0.40430663221360896} |
| Stopping after epoch 42. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 44.303759813308716 <br /> train 250 - {'listener_accuracy': 0.5473657332538939, 'corpus_bleu': 0.5113511314807752} <br /> bake-off 250 - {'listener_accuracy': 0.6272772033481043, 'corpus_bleu': 0.6787631418303108} |

*Execution time* <br />
CPU times: user 2h 33min 14s, sys: 49min 24s, total: 3h 22min 39s
Wall time: 1h 23min 39s

In [None]:
# Uncomment the code in the next two cells to run the hidden_dim experiments

In [None]:
#tokens = {'train': tokens_train, 'test': tokens_test, 'bo': tokens_bo}

In [None]:
# unit='GRU' runs the GRU variant; change `unit` for the LSTM runs
#%time eh.run_hiddim_options(hidden_dims, start_index, end_index, xlnet_vocab, xlnet_embeddings, colors, tokens, unit='GRU')

***Baseline model using XLNet pretrained embeddings and vocab***

In [None]:
model_xlnet_lstm = BaselineLSTMDescriber(
    vocab=xlnet_vocab,
    embedding=xlnet_embeddings,
    early_stopping=True,
    hidden_dim=250
)

In [None]:
%time _ = model_xlnet_lstm.fit(colors_train, tokens_train)

Evaluate on test data

In [None]:
%time model_xlnet_lstm.evaluate(colors_test, tokens_test)

Evaluate on bake-off data

In [None]:
%time model_xlnet_lstm.evaluate(colors_bo, tokens_bo)

In [None]:
model_xlnet_gru = BaselineDescriber(
    vocab=xlnet_vocab,
    embedding=xlnet_embeddings,
    early_stopping=True,
    hidden_dim=250
)

In [None]:
%time _ = model_xlnet_gru.fit(colors_train, tokens_train)

Evaluate on test data

In [None]:
%time model_xlnet_gru.evaluate(colors_test, tokens_test)

Evaluate on bake-off data

In [None]:
%time model_xlnet_gru.evaluate(colors_bo, tokens_bo)

### RoBERTa Embeddings

In [None]:
roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
roberta_model = RobertaModel.from_pretrained('roberta-base')

In [None]:
%time tokens_train, tokens_test, tokens_bo = \
    eh.create_tokens_sets(roberta_tokenizer, texts_train, texts_test, texts_test_bo)

In [None]:
%time roberta_embeddings, roberta_vocab = mu.extract_input_embeddings(texts_train, roberta_model, roberta_tokenizer, add_special_tokens=True)

***Experiment with different hidden_dim size***

| Results - RoBERTa GRU - static embeddings + ResNet+Fourier|
| ------- |
| Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 50.86355710029602 <br /> train 50 - {'listener_accuracy': 1.0, 'corpus_bleu': 0.15} <br /> bake-off 50 - {'listener_accuracy': 1.0, 'corpus_bleu': 0.15} |
| Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 52.07556104660034 <br /> train 100 - {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} <br /> bake-off 100 - {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} |
| Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 50.25857353210449 <br /> train 150 - {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} <br /> bake-off 150 - {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} |
| Stopping after epoch 12. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 50.29149055480957 <br /> train 250 - {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} <br /> bake-off 250 - {'listener_accuracy': 1.0, 'corpus_bleu': 0.2} |

*Execution time* <br /> 
CPU times: user 1h 16min 18s, sys: 21min 54s, total: 1h 38min 12s
Wall time: 41min 2s

In [None]:
# Uncomment the code in the next two cells to run the hidden_dim experiments

In [None]:
#tokens = {'train': tokens_train, 'test': tokens_test, 'bo': tokens_bo}

In [None]:
# unit='GRU' runs the GRU variant; change `unit` for the LSTM runs
#%time eh.run_hiddim_options(hidden_dims, start_index, end_index, roberta_vocab, roberta_embeddings, colors, tokens, unit='GRU')

***Baseline model using RoBERTa pretrained embeddings and vocab***

In [None]:
model = BaselineLSTMDescriber(
    roberta_vocab,
    embedding=roberta_embeddings,
    early_stopping=True,
    hidden_dim=250
)

In [None]:
%time _ = model.fit(colors_train, tokens_train)

Evaluate on test data

In [None]:
%time model.evaluate(colors_test, tokens_test)

Evaluate on bake-off data

In [None]:
%time model.evaluate(colors_bo, tokens_bo)

### ELECTRA Embeddings

In [None]:
electra_tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')
electra_model = ElectraModel.from_pretrained('google/electra-small-discriminator')

In [None]:
%time tokens_train, tokens_test, tokens_bo = \
    eh.create_tokens_sets(electra_tokenizer, texts_train, texts_test, texts_test_bo)

In [None]:
%time electra_embeddings, electra_vocab = mu.extract_input_embeddings(texts_train, electra_model, electra_tokenizer, add_special_tokens=True)

***Experiment with different hidden_dim size***

| Results - ELECTRA LSTM - static embeddings + ResNet+Fourier|
| ------- |
| Stopping after epoch 23. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 47.63471579551697 <br /> train 50 - {'listener_accuracy': 0.3909268873946719, 'corpus_bleu': 0.15} <br /> bake-off 50 - {'listener_accuracy': 0.3845396356474643, 'corpus_bleu': 0.15} |
| Stopping after epoch 21. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 49.45027256011963 <br /> train 100 - {'listener_accuracy': 0.42471699719125033, 'corpus_bleu': 0.5404861636039842} <br /> bake-off 100 - {'listener_accuracy': 0.4244214672575086, 'corpus_bleu': 0.6658470617162527} |
| Stopping after epoch 63. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 41.79938507080078 <br /> train 150 - {'listener_accuracy': 0.5634522086986127, 'corpus_bleu': 0.5305960626612526} <br /> bake-off 150 - {'listener_accuracy': 0.6149679960610537, 'corpus_bleu': 0.6977797477288923} |
| Stopping after epoch 59. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 44.45513558387756 <br /> train 250 - {'listener_accuracy': 0.5720486849944676, 'corpus_bleu': 0.5589953211796204} <br /> bake-off 250 - {'listener_accuracy': 0.6297390448055146, 'corpus_bleu': 0.7396116726551871} |

| Results - ELECTRA GRU - static embeddings + ResNet+Fourier|
| ------- |
| Stopping after epoch 39. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 47.41162323951721 <br /> train 50 - {'listener_accuracy': 0.4632734700825602, 'corpus_bleu': 0.4777896222522073} <br /> bake-off 50 - {'listener_accuracy': 0.4938453963564746, 'corpus_bleu': 0.657231763491947} |
| Stopping after epoch 34. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 46.33916139602661 <br /> train 100 - {'listener_accuracy': 0.4330581326070304, 'corpus_bleu': 0.5023199247333534} <br /> bake-off 100 - {'listener_accuracy': 0.4401772525849335, 'corpus_bleu': 0.6706403921940264} |
| Stopping after epoch 37. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 43.32993298768997 <br /> train 150 - {'listener_accuracy': 0.5265128947144437, 'corpus_bleu': 0.5373655274110541} <br /> bake-off 150 - {'listener_accuracy': 0.5824716888232397, 'corpus_bleu': 0.7040877364082542} |
| Stopping after epoch 65. Validation score did not improve by tol=1e-05 for more than 10 epochs. Final error is 44.20422172546387 <br /> train 250 - {'listener_accuracy': 0.5704315260873266, 'corpus_bleu': 0.5305960626612526} <br /> bake-off 250 - {'listener_accuracy': 0.6366322008862629, 'corpus_bleu': 0.6977797477288923} |

*Execution time* <br /> 
CPU times: user 2h 54min 38s, sys: 54min 51s, total: 3h 49min 29s
Wall time: 1h 33min 24s

In [None]:
# Uncomment the code in the next two cells to run the hidden_dim experiments

In [None]:
#tokens = {'train': tokens_train, 'test': tokens_test, 'bo': tokens_bo}

In [None]:
# unit='GRU' runs the GRU variant; change `unit` for the LSTM runs
#%time eh.run_hiddim_options(hidden_dims, start_index, end_index, electra_vocab, electra_embeddings, colors, tokens, unit='GRU')

***Baseline model using ELECTRA pretrained embeddings and vocab***

In [None]:
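# Save the validation scores of the XLNet GRU model trained in the XLNet section
scores_path = os.path.join("data", "experiments", "xlnet_gru_resnet_vscores_2.txt")
with open(scores_path, "w") as f:
    f.write(str(model_xlnet_gru.validation_scores))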
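The cell above stores the scores with `str()`, so the file holds a Python literal. Assuming `validation_scores` is a list of plain floats, `ast.literal_eval` reads it back without resorting to `eval`. The sketch round-trips a toy list through a temporary file instead of the real path; `save_scores` and `load_scores` are hypothetical helpers.

```python
import ast
import os
import tempfile


def save_scores(path: str, scores: list) -> None:
    """Write scores as a Python literal, as the cell above does."""
    with open(path, "w") as f:
        f.write(str(scores))


def load_scores(path: str) -> list:
    """Parse the literal back into a list without using eval()."""
    with open(path) as f:
        return ast.literal_eval(f.read())


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "vscores.txt")
    save_scores(path, [0.41, 0.52, 0.57])
    loaded = load_scores(path)
print(loaded)  # [0.41, 0.52, 0.57]
```

Writing the list with `json.dump` instead would make the file readable outside Python as well.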

In [None]:
model_lstm = BaselineLSTMDescriber(
    vocab=electra_vocab,
    embedding=electra_embeddings,
    early_stopping=True,
    hidden_dim=250
)

In [None]:
%time _ = model_lstm.fit(colors_train, tokens_train)

Evaluate on test data

In [None]:
%time model_lstm.evaluate(colors_test, tokens_test)

Evaluate on bake-off data

In [None]:
%time model_lstm.evaluate(colors_bo, tokens_bo)

In [None]:
model_gru = BaselineDescriber(
    vocab=electra_vocab,
    embedding=electra_embeddings,
    early_stopping=True,
    hidden_dim=250
)

In [None]:
%time _ = model_gru.fit(colors_train, tokens_train)

Evaluate on test data

In [None]:
%time model_gru.evaluate(colors_test, tokens_test)

Evaluate on bake-off data

In [None]:
%time model_gru.evaluate(colors_bo, tokens_bo)