In [1]:
from pathlib import Path
from typing import *
import torch
import torch.optim as optim
import numpy as np
import pandas as pd
from functools import partial
from overrides import overrides

from allennlp.data import Instance
from allennlp.data.fields import TextField, SequenceLabelField, LabelField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token

TODOs:

- Make compatible with GPU
- Try replicating SST results
- Write results to MongoDB Atlas
- Store weights in S3

In [2]:
class Config(dict):
    """Dictionary of settings whose entries can also be read as attributes.

    Every keyword argument is stored twice: as a regular dict item
    (``cfg["seed"]``) and as an instance attribute (``cfg.seed``). The
    attribute view is filled in at construction time only.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Mirror each key onto the instance so dotted access works.
        for name in kwargs:
            setattr(self, name, kwargs[name])

In [3]:
# for papermill
# Run parameters; the next cell copies each of them into `config`.
testing = True  # when True, the Jigsaw reader only loads the first 10000 rows
seed = 1  # passed to torch.manual_seed
batch_size = 64  # batch size given to the BucketIterator
embed_dim = 256  # embedding size on the non-BERT path
hidden_sz = 768  # LSTM hidden size and input width of the classifier projection
dataset = "sst"  # key into reader_registry and sub-directory of ../data
n_classes = 2  # SST granularity ("<n_classes>-class") and number of output logits
max_seq_len = 128  # max_pieces for the BERT indexer; the tokenizer keeps max_seq_len - 2 pieces
bert_model = "bert-base-cased"  # pretrained BERT name; None selects the LSTM path
run_id = "replicate_0"  # checkpoint sub-directory name; None falls back to a timestamp

In [4]:
# TODO: Can we make this play better with papermill?
# Bundle the papermill parameters into one object that the rest of the
# notebook reads as `config.<name>`.
config = Config(
    testing=testing,
    seed=seed,
    batch_size=batch_size, # This is probably too large: need to handle effective vs. machine batch size
    embed_dim=embed_dim,
    hidden_sz=hidden_sz,
    dataset=dataset,
    n_classes=n_classes,
    max_seq_len=max_seq_len, # necessary to limit memory usage
    # Pass bert_model=None instead to take the non-BERT (LSTM) path.
    bert_model=bert_model,
    run_id=run_id,
)

In [5]:
import datetime
# Name of this run; used further down as the checkpoint sub-directory.
# An explicit config.run_id wins, otherwise the current time is used.
# NOTE(review): the timestamp fallback contains ':' characters, which are not
# valid in directory names on every platform - confirm this is acceptable.
now = datetime.datetime.now()
if config.run_id is None:
    RUN_ID = now.strftime("%m_%d_%H:%M:%S")
else:
    RUN_ID = config.run_id

In [6]:
# Everything for this run is read from / written under ../data/<dataset name>.
DATA_ROOT = Path("../data") / config.dataset

Set the random seed manually so that results can be replicated.

In [7]:
# Seed Python's and NumPy's global RNGs as well as torch's: seeding torch alone
# leaves any shuffling or sampling that draws from the other two unrepeatable.
import random

random.seed(config.seed)
np.random.seed(config.seed)
# Kept as the last expression so the cell still displays the torch Generator.
torch.manual_seed(config.seed)

<torch._C.Generator at 0x119502e50>

# Load Data

In [8]:
from allennlp.data.vocabulary import Vocabulary
from allennlp.data.dataset_readers import DatasetReader, StanfordSentimentTreeBankDatasetReader

### Prepare dataset

In [9]:
import glob

In [10]:
# Maps a dataset name to its reader class; filled in by the `register` decorator.
reader_registry = {}

In [11]:
def register(name: str):
    """Return a decorator that files its target in ``reader_registry`` under ``name``.

    The decorated object is handed back untouched, so the decorator can sit
    on a class definition without altering it. Registering the same name
    again overwrites the earlier entry.
    """
    def _store(reader: Callable):
        reader_registry[name] = reader
        return reader
    return _store

In [12]:
@register("jigsaw")
class JigsawDatasetReader(DatasetReader):
    """Reads a CSV with ``comment_text`` and ``toxic`` columns into instances.

    Each row's ``comment_text`` is tokenized into the ``tokens`` field and
    its ``toxic`` value becomes the ``label`` field.

    Parameters
    ----------
    tokenizer:
        Callable mapping a string to a list of token strings. Defaults to
        whitespace splitting.
    token_indexers:
        Indexers for the text field. Defaults to a single
        ``SingleIdTokenIndexer`` under the key ``"tokens"``.
    max_seq_len:
        Stored on the reader but not enforced by it; any truncation has to
        happen inside ``tokenizer``.
    """
    def __init__(self, tokenizer: Callable[[str], List[str]]=lambda x: x.split(),
                 token_indexers: Dict[str, TokenIndexer] = None, # TODO: Handle mapping from BERT
                 max_seq_len: Optional[int]=config.max_seq_len) -> None:
        super().__init__(lazy=False)
        self.tokenizer = tokenizer
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len

    @overrides
    def text_to_instance(self, tokens: List[Token], label: Optional[int] = None) -> Instance:
        """Build an instance from tokens and, when given, an integer label.

        The label field is omitted when ``label`` is None so that unlabeled
        text can be turned into an instance as well.
        """
        # TODO: Reimplement
        fields = {"tokens": TextField(tokens, self.token_indexers)}

        if label is not None:
            # skip_indexing means the label is used as the class id directly.
            # Coerce to a plain int because values read from pandas columns
            # may be NumPy integer scalars.
            fields["label"] = LabelField(label=int(label), skip_indexing=True)

        return Instance(fields)

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per CSV row (first 10000 rows when ``config.testing``)."""
        df = pd.read_csv(file_path)
        if config.testing:
            df = df.head(10000)
        # Iterate the two needed columns directly instead of building a
        # Series per row with iterrows.
        for text, toxic in zip(df["comment_text"], df["toxic"]):
            yield self.text_to_instance(
                [Token(piece) for piece in self.tokenizer(text)],
                toxic
            )

In [13]:
@register("imdb")
class IMDBDatasetReader(DatasetReader):
    """Reads a directory of ``pos/*.txt`` and ``neg/*.txt`` reviews into instances.

    Parameters
    ----------
    tokenizer:
        Callable mapping a string to a list of token strings. When None,
        whitespace splitting is used.
    token_indexers:
        Indexers for the text field. Defaults to a single
        ``SingleIdTokenIndexer`` under the key ``"tokens"``.
    max_seq_len:
        Stored on the reader but not enforced by it.
    """
    def __init__(self, tokenizer=None, 
                 token_indexers: Dict[str, TokenIndexer] = None,
                 max_seq_len=None) -> None:
        super().__init__(lazy=False)
        # Fall back to whitespace splitting so the default arguments yield a
        # usable reader instead of failing later in _read.
        self.tokenizer = tokenizer if tokenizer is not None else str.split
        self.token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.max_seq_len = max_seq_len

    @overrides
    def text_to_instance(self, tokens: List[Token], label: str = None) -> Instance:
        """Build an instance from tokens and, when given, a string label."""
        fields = {"tokens": TextField(tokens, self.token_indexers)}

        # TODO: Add statistical features?

        if label is not None:
            fields["label"] = LabelField(label=label)

        return Instance(fields)

    @overrides
    def _read(self, file_path: str) -> Iterator[Instance]:
        """Yield one instance per ``.txt`` file; the sub-directory name is the label."""
        root = Path(file_path)
        for label in ["pos", "neg"]:
            # Sorted because glob order depends on the filesystem; this keeps
            # the instance order the same across machines.
            for review_path in sorted((root / label).glob("*.txt")):
                # read_text opens and closes the file; the previous
                # open(...).read() left the handle to the garbage collector.
                text = review_path.read_text(encoding="utf-8")
                yield self.text_to_instance([Token(word) for word in self.tokenizer(text)],
                                            label)

In [14]:
@register("sst")
class SSTDatasetReader(StanfordSentimentTreeBankDatasetReader):
    """Stanford Sentiment Treebank reader that can re-tokenize each sentence.

    Parameters
    ----------
    tokenizer:
        Optional callable mapping a sentence string to a list of token
        strings. When None, the tokens from the parent reader are used as is.
    *args, **kwargs:
        Forwarded to ``StanfordSentimentTreeBankDatasetReader``. Unless the
        caller passes ``granularity``, it is set to
        ``"<config.n_classes>-class"``.
    """
    def __init__(self, *args, tokenizer=None, **kwargs):
        # setdefault lets a caller override the granularity; previously that
        # raised a duplicate-keyword TypeError.
        kwargs.setdefault("granularity", f"{config.n_classes}-class")
        super().__init__(*args, **kwargs)
        self._tokenizer = tokenizer
        
    @overrides
    def text_to_instance(self, tokens: List[str], sentiment: str=None) -> Instance:
        """
        Forcibly re-tokenize the input (e.g. into wordpieces) when a tokenizer
        was supplied, then defer to the parent implementation.
        """
        if self._tokenizer is not None:
            # The tokenizer works on raw text, so glue the tokens back together.
            tokens = self._tokenizer(" ".join(tokens))
        return super().text_to_instance(tokens, sentiment=sentiment)

### Prepare token handlers

In [15]:
from allennlp.data.token_indexers import WordpieceIndexer, SingleIdTokenIndexer
from pytorch_pretrained_bert.tokenization import BertTokenizer

class BertIndexerCustom(WordpieceIndexer):
    """
    Wordpiece indexer backed by a pretrained BERT vocabulary.

    Loads the ``BertTokenizer`` for ``pretrained_model`` and forwards its
    vocabulary and wordpiece tokenizer to ``WordpieceIndexer`` under the
    ``"bert"`` namespace, while exposing the start/end tokens, ``max_pieces``
    and the lowercasing options as constructor arguments.

    ``do_lowercase`` must agree with the model name: an assertion fails for a
    ``-cased`` model with lowercasing on, or an ``-uncased`` model with it off.
    """
    def __init__(self, pretrained_model: str,
                 use_starting_offsets: bool = False,
                 do_lowercase: bool = True,
                 never_lowercase: List[str] = None,
                 max_pieces: int = 512,
                 start_tokens=["[CLS]"],
                 end_tokens=["[SEP]"]) -> None:
        # Refuse a casing flag that contradicts the model name.
        is_cased_model = pretrained_model.endswith("-cased")
        is_uncased_model = pretrained_model.endswith("-uncased")
        assert not (is_cased_model and do_lowercase)
        assert do_lowercase or not is_uncased_model

        tokenizer = BertTokenizer.from_pretrained(pretrained_model,
                                                  do_lower_case=do_lowercase)
        indexer_options = dict(
            vocab=tokenizer.vocab,
            wordpiece_tokenizer=tokenizer.wordpiece_tokenizer.tokenize,
            namespace="bert",
            use_starting_offsets=use_starting_offsets,
            max_pieces=max_pieces,
            do_lowercase=do_lowercase,
            never_lowercase=never_lowercase,
            start_tokens=start_tokens,
            end_tokens=end_tokens,
        )
        super().__init__(**indexer_options)

In [16]:
# Pick the (token_indexer, tokenizer) pair for the run: plain single-id
# indexing with whitespace tokens, or BERT wordpieces.
if config.bert_model is None:
    token_indexer = SingleIdTokenIndexer(
        lowercase_tokens=False,  # don't lowercase by default
    )

    def tokenizer(x):
        return x.split()
else:
    token_indexer = BertIndexerCustom(
        pretrained_model=config.bert_model,
        max_pieces=config.max_seq_len,
        # lowercase only for "uncased" checkpoints
        do_lowercase="uncased" in config.bert_model,
    )

    # The sequence has to be truncated at tokenization time.
    # Two slots are held back, presumably for the indexer's start/end tokens.
    # NOTE(review): the same wordpiece tokenizer is also handed to the indexer;
    # confirm the pieces produced here are not split a second time.
    def tokenizer(s: str):
        pieces = token_indexer.wordpiece_tokenizer(s)
        return pieces[:config.max_seq_len - 2]

01/18/2019 15:19:20 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /Users/keitakurita/.pytorch_pretrained_bert/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1


In [17]:
# Look up the reader class registered for the configured dataset and build it
# with the tokenizer and indexer selected in the previous cell.
reader_cls = reader_registry[config.dataset]
reader = reader_cls(tokenizer=tokenizer,
                    token_indexers={"tokens": token_indexer})

In [18]:
# Read the splits for the chosen dataset. Only SST has a validation split;
# for the other datasets val_ds is None.
if config.dataset == "imdb":
    # NOTE(review): DATA_ROOT already ends in the dataset name, so this
    # resolves to ../data/imdb/imdb/aclImdb - confirm that layout.
    data_dir = DATA_ROOT / "imdb" / "aclImdb"
    train_ds = reader.read(data_dir / "train")
    test_ds = reader.read(data_dir / "test")
    val_ds = None
elif config.dataset == "sst":
    data_dir = DATA_ROOT / "trees"
    train_ds = reader.read(data_dir / "train.txt")
    val_ds = reader.read(data_dir / "dev.txt")
    test_ds = reader.read(data_dir / "test.txt")
else:
    # Any other dataset name is read as a pair of CSV files.
    train_ds = reader.read(DATA_ROOT / "train.csv")
    test_ds = reader.read(DATA_ROOT / "test_proced.csv")
    val_ds = None

0it [00:00, ?it/s]01/18/2019 15:19:21 - INFO - allennlp.data.dataset_readers.stanford_sentiment_tree_bank -   Reading instances from lines in file at: ../data/sst/trees/train.txt
6920it [00:02, 2378.88it/s]
0it [00:00, ?it/s]01/18/2019 15:19:23 - INFO - allennlp.data.dataset_readers.stanford_sentiment_tree_bank -   Reading instances from lines in file at: ../data/sst/trees/dev.txt
872it [00:00, 2349.07it/s]
0it [00:00, ?it/s]01/18/2019 15:19:24 - INFO - allennlp.data.dataset_readers.stanford_sentiment_tree_bank -   Reading instances from lines in file at: ../data/sst/trees/test.txt
1821it [00:00, 2850.56it/s]


In [19]:
len(train_ds), len(test_ds)

(6920, 1821)

### Prepare vocabulary

In [20]:
# Build the vocabulary from the training instances only.
vocab = Vocabulary.from_instances(train_ds)
if config.bert_model is not None: 
    # Presumably copies the indexer's wordpiece vocabulary into `vocab` up front.
    # NOTE(review): this is a private method of the indexer - confirm it is
    # needed here and stable across AllenNLP versions.
    token_indexer._add_encoding_to_vocabulary(vocab)

01/18/2019 15:19:24 - INFO - allennlp.data.vocabulary -   Fitting token dictionary from dataset.
100%|██████████| 6920/6920 [00:00<00:00, 120423.46it/s]


### Prepare iterator

In [21]:
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.data.iterators import BucketIterator

In [22]:
# TODO: Allow for customization
# Bucket instances by the "num_tokens" padding key of the "tokens" field and
# emit the biggest batch first.
iterator = BucketIterator(batch_size=config.batch_size, 
                          biggest_batch_first=True,
                          sorting_keys=[("tokens", "num_tokens")],)
# Give the iterator the vocabulary it indexes instances with.
iterator.index_with(vocab)

### Read sample

In [23]:
import warnings
# Pull a single batch for inspection; warnings raised while batching are silenced.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    batch = next(iter(iterator(train_ds)))

In [24]:
batch

{'tokens': {'tokens': tensor([[  101,   169,  2131, 18757,  1513,   112,  1116,  8447,   118, 26161,
            2064, 28137,  1110,   118,  2069, 22672, 28137,   170,  5439, 14439,
             176, 20901,  1162, 28137,  8508,  5498,  1114,   170,  1716,  2572,
            1176,   144, 16724,  9654,  1107,   170,  3223, 28137, 19972,  8967,
            1105,   170,  2566,  1112,  2430,  6775,  1158,  1884,  1197, 28137,
            1830, 24891,  2254, 28137,  7535,  1964, 28137,  1161, 28137,  7641,
            2158, 11012,  4695,  9603,   119,   112,   102,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0],
          [  101,   142, 28138,  1942, 28138,  1759,  1272,  1157, 22593,  6639,
            2953, 12788,  1158,  3981,  1116,   117,  1489, 28137,  4980,  1813,
           28137, 11015,  1823,  6603,  2249, 23825,  1633,   117,   127, 28137,
            4980,  1813, 28137, 11015,  8633,  56

In [25]:
batch["tokens"]["tokens"]

tensor([[  101,   169,  2131, 18757,  1513,   112,  1116,  8447,   118, 26161,
          2064, 28137,  1110,   118,  2069, 22672, 28137,   170,  5439, 14439,
           176, 20901,  1162, 28137,  8508,  5498,  1114,   170,  1716,  2572,
          1176,   144, 16724,  9654,  1107,   170,  3223, 28137, 19972,  8967,
          1105,   170,  2566,  1112,  2430,  6775,  1158,  1884,  1197, 28137,
          1830, 24891,  2254, 28137,  7535,  1964, 28137,  1161, 28137,  7641,
          2158, 11012,  4695,  9603,   119,   112,   102,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0],
        [  101,   142, 28138,  1942, 28138,  1759,  1272,  1157, 22593,  6639,
          2953, 12788,  1158,  3981,  1116,   117,  1489, 28137,  4980,  1813,
         28137, 11015,  1823,  6603,  2249, 23825,  1633,   117,   127, 28137,
          4980,  1813, 28137, 11015,  8633,  5631,  4982,  1105,  1275, 28137,
          4980

In [26]:
batch["tokens"]["tokens"].shape

torch.Size([8, 84])

# Prepare Model

In [27]:
import torch
import torch.nn as nn
import torch.optim as optim

In [28]:
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.token_embedders.bert_token_embedder import BertEmbedder, PretrainedBertEmbedder
from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, PytorchSeq2VecWrapper
from allennlp.modules.stacked_bidirectional_lstm import StackedBidirectionalLstm
from allennlp.nn.util import get_text_field_mask

In [29]:
class LstmEncoder(nn.Module):
    """Encodes a padded batch of sequences into one vector per sequence.

    The vector is the final hidden state of the last LSTM layer; for a
    bidirectional LSTM the forward and backward states are concatenated.

    Parameters
    ----------
    lstm:
        An ``nn.LSTM`` (the notebook uses a batch-first bidirectional one).
    """
    def __init__(self, lstm):
        super().__init__()
        self.lstm = lstm

    def get_output_dim(self) -> int:
        """Width of the vector returned by ``forward``."""
        directions = 2 if self.lstm.bidirectional else 1
        return self.lstm.hidden_size * directions

    def forward(self, x, mask): # TODO: replace with allennlp built in modules
        """Return the final LSTM state for every sequence in the batch.

        Parameters
        ----------
        x:
            Padded input batch, laid out as the wrapped LSTM expects.
        mask:
            ``(batch, seq)`` tensor that is non-zero on real tokens, or None.
            When given, the sequences are packed so that padding positions
            never reach the LSTM; without packing the final states depend on
            how much padding a sequence received. None runs the LSTM over
            the padded input as is.
        """
        if mask is None:
            _, (state, _) = self.lstm(x)
        else:
            from torch.nn.utils.rnn import pack_padded_sequence

            batch_dim = 0 if self.lstm.batch_first else 1
            # A packed sequence cannot hold empty rows, so a fully masked
            # row is treated as having length 1.
            lengths = mask.long().sum(dim=1).clamp(min=1)
            # Sort by length (longest first) before packing; this is the
            # ordering pack_padded_sequence requires by default.
            sorted_lengths, order = lengths.sort(descending=True)
            packed = pack_padded_sequence(x.index_select(batch_dim, order),
                                          sorted_lengths.tolist(),
                                          batch_first=self.lstm.batch_first)
            _, (state, _) = self.lstm(packed)
            # state is (layers * directions, batch, hidden) in sorted order;
            # undo the sort so rows line up with the input batch again.
            _, restore = order.sort()
            state = state.index_select(1, restore)

        if self.lstm.bidirectional:
            # The last two entries are the top layer's forward and backward states.
            return torch.cat([state[-2], state[-1]], dim=1)
        return state[-1]

In [30]:
class BertPooler(nn.Module):
    """Pools a sequence of hidden states into one vector per sequence.

    Takes the hidden state of the first token and passes it through a dense
    layer followed by tanh. The original note says the code was copied
    (presumably from the BERT implementation's pooler); the dense layer here
    is freshly initialised rather than loaded from a checkpoint.

    Parameters
    ----------
    hidden_size:
        Width of the incoming hidden states and of the output. Defaults to
        768, the value that was previously hard-coded.
    """
    def __init__(self, hidden_size: int = 768):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.activation = nn.Tanh()

    def get_output_dim(self) -> int:
        """Width of the vector returned by ``forward``."""
        return self.dense.out_features

    def forward(self, hidden_states, mask=None):
        """Return ``tanh(dense(hidden_states[:, 0]))``.

        ``mask`` is accepted so the pooler is call-compatible with the other
        encoder in this notebook; it is not used.
        """
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

In [31]:
class SentimentAnalysisModel(Model):
    """Sentence classifier: embed tokens, encode to one vector, project to logits.

    Parameters
    ----------
    word_embeddings:
        Module that turns token ids into embeddings. A ``TextFieldEmbedder``
        receives the whole ``tokens`` dict; any other module receives the
        ``tokens["tokens"]`` id tensor.
    encoder:
        Module called as ``encoder(embeddings, mask)`` that returns one
        vector per sequence.
    out_sz:
        Number of output classes.

    The vocabulary is taken from the notebook-level ``vocab`` variable.
    """
    def __init__(self, word_embeddings: nn.Module,
                 encoder: nn.Module,
                 out_sz=config.n_classes):
        super().__init__(vocab)
        self.word_embeddings = word_embeddings
        self.encoder = encoder
        # Size the projection from the encoder when it reports its output
        # width; otherwise keep the previous assumption of config.hidden_sz.
        get_output_dim = getattr(encoder, "get_output_dim", None)
        in_sz = get_output_dim() if callable(get_output_dim) else config.hidden_sz
        self.projection = nn.Linear(in_sz, out_sz)
        self.accuracy = CategoricalAccuracy()
        
    def forward(self,
                tokens: Dict[str, torch.Tensor],
                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:
        """Compute class logits and, when ``label`` is given, the loss.

        Returns a dict with ``"class_logits"`` and, if ``label`` is not None,
        ``"loss"`` (mean cross-entropy). The accuracy metric is updated as a
        side effect and read back through ``get_metrics``.
        """
        mask = get_text_field_mask(tokens)
        if isinstance(self.word_embeddings, TextFieldEmbedder):
            # Text-field embedders index the dict by key themselves.
            embeddings = self.word_embeddings(tokens)
        else:
            embeddings = self.word_embeddings(tokens["tokens"])
        state = self.encoder(embeddings, mask)

        class_logits = self.projection(state)
        
        output = {"class_logits": class_logits}
        if label is not None:
            # The metric call only accumulates statistics, so its result is
            # not stored in the output dict.
            self.accuracy(class_logits, label)
            output["loss"] = nn.functional.cross_entropy(class_logits, label)

        return output

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """Return the accumulated accuracy, optionally resetting the metric."""
        return {"accuracy": self.accuracy.get_metric(reset)}

In [32]:
# Choose the embedder/encoder pair that matches the tokenization path.
if config.bert_model is not None:
    # BERT path: pretrained embedder plus first-token pooling.
    word_embeddings = PretrainedBertEmbedder(
        pretrained_model=config.bert_model,
        top_layer_only=True, # conserve memory
    )
    encoder = BertPooler()
else:
    # Baseline path: trainable embeddings fed to a bidirectional LSTM.
    token_embedding = Embedding(
        num_embeddings=vocab.get_vocab_size('tokens'),
        embedding_dim=config.embed_dim,
    )
    word_embeddings = BasicTextFieldEmbedder({"tokens": token_embedding})

    # NOTE(review): the LSTM is bidirectional, so the encoder output may be
    # wider than config.hidden_sz - verify it matches the model's projection.
    encoder = LstmEncoder(
        nn.LSTM(config.embed_dim, config.hidden_sz, batch_first=True, bidirectional=True)
    )

01/18/2019 15:19:26 - INFO - pytorch_pretrained_bert.modeling -   loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased.tar.gz from cache at /Users/keitakurita/.pytorch_pretrained_bert/a803ce83ca27fecf74c355673c434e51c265fb8a3e0e57ac62a80e38ba98d384.681017f415dfb33ec8d0e04fe51a619f3f01532ecea04edbfd48c5d160550d9c
01/18/2019 15:19:26 - INFO - pytorch_pretrained_bert.modeling -   extracting archive file /Users/keitakurita/.pytorch_pretrained_bert/a803ce83ca27fecf74c355673c434e51c265fb8a3e0e57ac62a80e38ba98d384.681017f415dfb33ec8d0e04fe51a619f3f01532ecea04edbfd48c5d160550d9c to temp dir /var/folders/hy/1czs1y5j2d58zgkqx6w_wnpw0000gn/T/tmp1vlo11x0
01/18/2019 15:19:30 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 

In [33]:
# Assemble the classifier from the embedder/encoder pair built above.
model = SentimentAnalysisModel(
    word_embeddings, 
    encoder, 
    out_sz=config.n_classes,
)

In [34]:
model

SentimentAnalysisModel(
  (word_embeddings): PretrainedBertEmbedder(
    (bert_model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(28996, 768)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): BertLayerNorm()
        (dropout): Dropout(p=0.1)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0): BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): BertLayerNorm()
                (

### Basic sanity checks

In [35]:
np.isnan(list(model.word_embeddings.parameters())[0].detach().numpy()).any()

False

In [36]:
[np.isnan(x.detach().numpy()).any() for x in list(model.encoder.parameters())]

[False, False]

In [37]:
[np.isinf(x.detach().numpy()).any() for x in list(model.encoder.parameters())]

[False, False]

In [38]:
# Run the embedder and the encoder by hand on the sample batch to eyeball
# the encoder output.
tokens = batch["tokens"]
encoder(model.word_embeddings(tokens["tokens"]), get_text_field_mask(tokens))

tensor([[-0.2845,  0.5205, -0.2903,  ..., -0.6036, -0.2421, -0.0638],
        [-0.2780,  0.4673, -0.3143,  ..., -0.6079, -0.1663, -0.1194],
        [-0.2507,  0.3541, -0.3031,  ..., -0.6354, -0.2382, -0.0679],
        ...,
        [-0.1060,  0.4919, -0.3137,  ..., -0.7491, -0.2599,  0.0784],
        [ 0.2190,  0.4185, -0.3076,  ..., -0.4171, -0.1773,  0.1692],
        [-0.1925,  0.2830, -0.1771,  ..., -0.5368, -0.2289, -0.0171]],
       grad_fn=<TanhBackward>)

In [39]:
loss = model(**batch)["loss"]

In [40]:
nn.CrossEntropyLoss()(model(**batch)["class_logits"][:10, :], batch["label"][:10])

tensor(0.7603, grad_fn=<NllLossBackward>)

In [41]:
loss

tensor(0.7166, grad_fn=<NllLossBackward>)

In [42]:
loss.backward()

In [43]:
[x.grad for x in list(model.encoder.parameters())]

[tensor([[ 0.0026, -0.0017, -0.0017,  ..., -0.0019,  0.0011,  0.0007],
         [ 0.0004, -0.0003, -0.0003,  ..., -0.0003,  0.0002,  0.0001],
         [-0.0049,  0.0032,  0.0032,  ...,  0.0035, -0.0020, -0.0014],
         ...,
         [-0.0017,  0.0011,  0.0011,  ...,  0.0013, -0.0007, -0.0005],
         [-0.0005,  0.0003,  0.0003,  ...,  0.0004, -0.0002, -0.0001],
         [-0.0043,  0.0028,  0.0028,  ...,  0.0032, -0.0019, -0.0011]]),
 tensor([ 8.0603e-03,  1.3489e-03, -1.5124e-02,  2.5406e-03, -1.3403e-02,
         -5.6139e-03, -1.1616e-02,  7.7399e-03,  1.2475e-02,  7.3855e-03,
         -2.7034e-03,  1.4998e-03,  6.3677e-03,  7.0324e-03, -1.0463e-02,
          6.5658e-03,  9.1537e-04,  5.8868e-03, -4.9838e-04,  1.1895e-03,
         -1.6767e-02,  8.7264e-04,  1.0641e-02,  9.3373e-04, -1.0958e-04,
         -2.4146e-03,  1.1429e-02, -2.5398e-03,  6.6740e-03,  8.7339e-03,
          1.6180e-03, -3.7155e-04,  6.2519e-03,  6.1777e-03,  1.1338e-03,
         -9.2085e-04, -6.2459e-03, -4.97

# Train

In [44]:
from allennlp.training.trainer import Trainer

In [45]:
# Adam over everything returned by model.parameters().
# NOTE(review): whether the BERT weights themselves are updated depends on the
# embedder's requires_grad setting - confirm before tuning the learning rate.
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [46]:
# Extra keyword arguments forwarded to the Trainer below.
training_options = {
    # TODO: Add appropriate learning rate scheduler
    "should_log_parameter_statistics": True,
    "should_log_learning_rate": True,
    # presumably the early-stopping patience, in epochs - TODO confirm
    "patience": 3,
    "num_epochs": 10,
}

In [47]:
# Wire model, optimizer and data into the AllenNLP trainer. Checkpoints go to
# <DATA_ROOT>/ckpts/<RUN_ID>.
# NOTE(review): val_ds is None for every dataset except "sst" - confirm the
# "patience" option is still valid without a validation set.
trainer = Trainer(
    model=model,
    optimizer=optimizer,
    iterator=iterator,
    train_dataset=train_ds,
    validation_dataset=val_ds,
    serialization_dir=DATA_ROOT / "ckpts" / RUN_ID,
    **training_options,
)

In [48]:
trainer.train()

01/18/2019 15:19:38 - INFO - allennlp.training.trainer -   Beginning training.
01/18/2019 15:19:38 - INFO - allennlp.training.trainer -   Epoch 0/9
01/18/2019 15:19:38 - INFO - allennlp.training.trainer -   Peak CPU memory usage MB: 1180.868608
01/18/2019 15:19:38 - INFO - allennlp.training.trainer -   Training
accuracy: 0.7677, loss: 0.4799 ||: 100%|██████████| 109/109 [09:06<00:00,  6.18s/it]
01/18/2019 15:28:45 - INFO - allennlp.training.trainer -   Validating
accuracy: 0.8360, loss: 0.3880 ||: 100%|██████████| 14/14 [00:57<00:00,  5.25s/it]
01/18/2019 15:29:43 - INFO - allennlp.training.trainer -                     Training |  Validation
01/18/2019 15:29:43 - INFO - allennlp.training.trainer -   loss          |     0.480  |     0.388
01/18/2019 15:29:43 - INFO - allennlp.training.trainer -   accuracy      |     0.768  |     0.836
01/18/2019 15:29:43 - INFO - allennlp.training.trainer -   cpu_memory_MB |  1180.869  |       N/A
01/18/2019 15:29:44 - INFO - allennlp.training.trainer 

01/18/2019 16:28:26 - INFO - allennlp.training.trainer -   Training
accuracy: 0.8273, loss: 0.3847 ||: 100%|██████████| 109/109 [08:01<00:00,  3.76s/it]
01/18/2019 16:36:27 - INFO - allennlp.training.trainer -   Validating
accuracy: 0.8222, loss: 0.4070 ||: 100%|██████████| 14/14 [00:55<00:00,  5.02s/it]
01/18/2019 16:37:23 - INFO - allennlp.training.trainer -                     Training |  Validation
01/18/2019 16:37:23 - INFO - allennlp.training.trainer -   loss          |     0.385  |     0.407
01/18/2019 16:37:23 - INFO - allennlp.training.trainer -   accuracy      |     0.827  |     0.822
01/18/2019 16:37:23 - INFO - allennlp.training.trainer -   cpu_memory_MB |  1224.270  |       N/A
01/18/2019 16:37:24 - INFO - allennlp.training.trainer -   Epoch duration: 00:08:57
01/18/2019 16:37:24 - INFO - allennlp.training.trainer -   Estimated training time remaining: 0:19:26
01/18/2019 16:37:24 - INFO - allennlp.training.trainer -   Epoch 8/9
01/18/2019 16:37:24 - INFO - allennlp.trainin

{'peak_cpu_memory_MB': 1224.269824,
 'training_duration': '01:36:04',
 'training_start_epoch': 0,
 'training_epochs': 9,
 'epoch': 9,
 'training_accuracy': 0.8236994219653179,
 'training_loss': 0.39802677273203474,
 'training_cpu_memory_MB': 1224.269824,
 'validation_accuracy': 0.8325688073394495,
 'validation_loss': 0.3854148132460458,
 'best_epoch': 8,
 'best_validation_accuracy': 0.8474770642201835,
 'best_validation_loss': 0.3624090020145689}