7/3/2019

This notebook contains the oracle experiments for predicting which condition the speaker is in.

The model developed here takes as input the three colors that make up the context (target first) and predicts whether they come from the close, split, or far condition. The model also receives captions, because that is what the model interface passes in, but it ignores them. It reaches a macro-F1 score of 0.94, which is acceptable but not great: because the conditions are clearly delineated, it should be able to achieve scores much closer to 1.0.
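
To make the "clear delineations" concrete: as we understand the dataset design (Monroe et al., 2017), the condition follows from the pairwise distances between the three colors relative to a fixed threshold. The sketch below is only an illustration under that assumption. The function name `rule_based_condition`, the plain Euclidean distance, and the `threshold` argument are hypothetical stand-ins (the dataset itself uses a perceptual color distance), and none of this is code from the repository.

```python
import numpy as np

def rule_based_condition(colors, threshold):
    """Label a (target, distractor, distractor) triple by pairwise distances.

    Assumed rule (hypothetical simplification):
      close: every pair of colors is within `threshold`
      far:   every pair of colors is farther apart than `threshold`
      split: anything in between (e.g. exactly one distractor near the target)
    """
    target, distractor_1, distractor_2 = np.asarray(colors, dtype=float)
    distances = [
        np.linalg.norm(target - distractor_1),
        np.linalg.norm(target - distractor_2),
        np.linalg.norm(distractor_1 - distractor_2),
    ]
    if max(distances) <= threshold:
        return "close"
    if min(distances) > threshold:
        return "far"
    return "split"
```

If the real conditions follow a rule of this kind, a model that can recover pairwise distances from its color features should separate the three classes almost perfectly, which is why 0.94 looks low.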

In [1]:
# so we can access classes from parent directory
import sys
sys.path.append("..")

In [9]:
from monroe_data import MonroeData, MonroeDataEntry, Color # for loading in training data
import caption_featurizers                              # for getting caption representations
import color_featurizers                                # for getting color representations
from experiment import FeatureHandler                   # for combining caption and color features

from models import PytorchModel, ConditionPredictor, ColorEncoder  # model base that handles training / evaluation

In [14]:
import importlib
import models
importlib.reload(models)
from models import PytorchModel, ConditionPredictor, ColorEncoder  # model base that handles training / evaluation

In [28]:
import torch
import torch.nn as nn
import numpy as np
from sklearn.metrics import classification_report

In [6]:
# get data
train_data = MonroeData("../data/csv/train_corpus_monroe.csv", "../data/entries/train_entries_monroe.pkl")
dev_data = MonroeData("../data/csv/dev_corpus_monroe.csv", "../data/entries/dev_entries_monroe.pkl")

In [7]:
# define feature functions

# we still build a caption featurizer because the code was designed to always use both color and text
# features. The caption features will just be ignored in the model code
caption_phi_character = caption_featurizers.CaptionFeaturizer(tokenizer=caption_featurizers.CharacterTokenizer)

# the oracle's actual input: normalized Fourier features of the HSV colors
color_phi = color_featurizers.ColorFeaturizer(color_featurizers.color_phi_fourier, "hsv", normalized=True)
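
The `color_in_dim=54` passed to the model later would be consistent with a Fourier color representation in the style of Monroe et al. (2017): 3 frequencies per HSV channel give 27 complex values, and splitting them into real and imaginary parts gives 54 numbers per color. The sketch below assumes that is roughly what `color_phi_fourier` computes. `fourier_features_sketch` and `resolution` are hypothetical names, and the input is assumed to be already scaled to [0, 1] per channel.

```python
import itertools
import numpy as np

def fourier_features_sketch(hsv, resolution=3):
    """Map one normalized (h, s, v) triple to 2 * resolution**3 real features."""
    hsv = np.asarray(hsv, dtype=float)
    # all frequency triples (j, k, l) with each entry in 0..resolution-1
    freqs = np.array(list(itertools.product(range(resolution), repeat=3)))
    # complex exponential exp(-2*pi*i * (j*h + k*s + l*v)) for every triple
    values = np.exp(-2j * np.pi * (freqs @ hsv))
    # concatenate real and imaginary parts into one real-valued vector
    return np.concatenate([values.real, values.imag])

features = fourier_features_sketch([0.25, 0.5, 0.75])
print(features.shape)  # (54,)
```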

In [8]:
# we define a mapping from condition to index from closest to farthest
condition_to_idx = {"close": 0, "split": 1, "far": 2}

# condition predictor's target is to predict what color condition the participants were put in
def condition_predictor_target(data_entry):
    return condition_to_idx[data_entry.condition]
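
Because the model predicts indices, an inverse mapping is convenient for reading predictions back as condition names. A minimal sketch follows. `idx_to_condition` and `decode_conditions` are hypothetical helpers that do not exist in the repository, and the mapping is repeated so the snippet runs on its own.

```python
import numpy as np

condition_to_idx = {"close": 0, "split": 1, "far": 2}

# invert the mapping so that predicted indices can be turned back into names
idx_to_condition = {idx: name for name, idx in condition_to_idx.items()}

def decode_conditions(indices):
    """Convert an iterable of class indices into condition names."""
    return [idx_to_condition[int(i)] for i in indices]

example = decode_conditions(np.array([2, 0, 1, 1]))
print(example)  # ['far', 'close', 'split', 'split']
```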

# pass in train and dev data, our caption and color feature functions, and a function for turning an element of
# our data (train or dev) into the target. The oracle doesn't use the captions at all, but the feature handler
# expects a caption featurizer. We set randomized_colors to False because our target function shouldn't need to
# take a color index
feature_handler = FeatureHandler(train_data, dev_data, caption_phi_character, color_phi, target_fn=condition_predictor_target, randomized_colors=False)

In [22]:
# let's also create an oracle: a model that takes only the colors (no captions) in the Fourier feature space and
# predicts the condition. This should be able to get 100% accuracy. We can use the exact same training data too!

class OracleConditionPredictor(nn.Module):
    
    def __init__(self, color_in_dim, color_hidden_dim, num_conditions=3):
        super(OracleConditionPredictor, self).__init__()
        
        self.color_encoder = ColorEncoder(color_in_dim, color_hidden_dim)
        self.linear = nn.Linear(color_hidden_dim, num_conditions)
        self.logsoftmax = nn.LogSoftmax(dim=1)
        
    def forward(self, colors, caption):
        # caption is ignored because the oracle only accesses colors
        color_encs = self.color_encoder(colors)
        output_lin = self.linear(color_encs)
        outputs = self.logsoftmax(output_lin)
        return outputs
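
The final `LogSoftmax` layer turns the three linear scores into log-probabilities over the conditions. The NumPy sketch below reproduces that computation outside PyTorch to show two properties relied on later: exponentiating each row gives a probability distribution, and the argmax of the log-probabilities is the same as the argmax of the raw scores. `log_softmax` is our own helper and the scores are made up.

```python
import numpy as np

def log_softmax(logits):
    """Row-wise log-softmax, shifted by the row maximum for numerical stability."""
    logits = np.asarray(logits, dtype=float)
    shifted = logits - logits.max(axis=1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

# made-up scores for two examples over (close, split, far)
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
log_probs = log_softmax(logits)

print(np.exp(log_probs).sum(axis=1))  # each row sums to 1
print(log_probs.argmax(axis=1))       # same classes as logits.argmax(axis=1)
```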

In [11]:
X_train = feature_handler.train_features()
y_train = feature_handler.train_targets()

In [23]:
condition_predictor_oracle = ConditionPredictor(OracleConditionPredictor, optimizer=torch.optim.Adam, lr=0.004, num_epochs=5)
condition_predictor_oracle.init_model(color_in_dim=54, color_hidden_dim=100)

In [31]:
condition_predictor_oracle.fit(X_train, y_train)

---EPOCH 0---
0m 0s (0:0 0.00%) 0.0000
0m 2s (0:1000 7.90%) 0.1482
0m 4s (0:2000 15.79%) 0.1336
0m 6s (0:3000 23.69%) 0.1353
0m 8s (0:4000 31.58%) 0.1248
0m 11s (0:5000 39.48%) 0.1235
0m 13s (0:6000 47.37%) 0.1373
0m 15s (0:7000 55.27%) 0.1420
0m 17s (0:8000 63.17%) 0.1593
0m 19s (0:9000 71.06%) 0.1572
0m 21s (0:10000 78.96%) 0.1622
0m 23s (0:11000 86.85%) 0.1178
0m 26s (0:12000 94.75%) 0.1174
AFTER EPOCH 2999 - AVERAGE VALIDATION LOSS: 0.2010897552172343
---EPOCH 1---
0m 28s (1:0 0.00%) 0.0000
0m 30s (1:1000 7.90%) 0.1285
0m 32s (1:2000 15.79%) 0.1239
0m 34s (1:3000 23.69%) 0.1416
0m 36s (1:4000 31.58%) 0.1136
0m 38s (1:5000 39.48%) 0.1292
0m 40s (1:6000 47.37%) 0.1310
0m 42s (1:7000 55.27%) 0.1249
0m 44s (1:8000 63.17%) 0.1309
0m 46s (1:9000 71.06%) 0.1526
0m 48s (1:10000 78.96%) 0.1621
0m 50s (1:11000 86.85%) 0.1372
0m 52s (1:12000 94.75%) 0.1170
AFTER EPOCH 2999 - AVERAGE VALIDATION LOSS: 0.18238019279638926
---EPOCH 2---
0m 54s (2:0 0.00%) 0.0000
0m 57s (2:1000 7.90%) 0.1368
0m 59

In [32]:
condition_predictor_oracle.save_model("../model/condition_predictor_oracle10_epochs.params")

In [25]:
X_assess = feature_handler.test_features()
y_assess = feature_handler.test_targets()

In [26]:
predictions = condition_predictor_oracle.predict(X_assess)
y_hat = np.argmax(predictions, axis=1)

In [29]:
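# note: classification_report expects (y_true, y_pred). Here the predictions are passed first, so the precision
# and recall columns below are swapped and "support" counts predicted labels rather than true labels
print(classification_report(y_hat, y_assess, target_names=['close', 'split', 'far']))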

              precision    recall  f1-score   support

       close       0.98      0.96      0.97      5267
       split       0.93      0.90      0.91      5369
         far       0.91      0.96      0.93      5034

   micro avg       0.94      0.94      0.94     15670
   macro avg       0.94      0.94      0.94     15670
weighted avg       0.94      0.94      0.94     15670
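
When reading this table, keep in mind that scikit-learn's `classification_report` takes the true labels as its first argument and the predictions as its second. Passing them the other way round transposes the confusion matrix, which exchanges per-class precision and recall while leaving F1 unchanged. The sketch below demonstrates this with plain NumPy on made-up labels. `per_class_precision_recall` is a hypothetical helper, not a scikit-learn function.

```python
import numpy as np

def per_class_precision_recall(y_true, y_pred, num_classes=3):
    """Per-class precision and recall from a confusion matrix (rows = true labels)."""
    confusion = np.zeros((num_classes, num_classes), dtype=int)
    for true_label, predicted_label in zip(y_true, y_pred):
        confusion[true_label, predicted_label] += 1
    true_positives = np.diag(confusion).astype(float)
    precision = true_positives / confusion.sum(axis=0)  # column sums: predicted counts
    recall = true_positives / confusion.sum(axis=1)     # row sums: true counts
    return precision, recall

# made-up labels, only for illustration
y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 0, 1, 1, 2, 2, 2, 2])

precision, recall = per_class_precision_recall(y_true, y_pred)
swapped_precision, swapped_recall = per_class_precision_recall(y_pred, y_true)

# with the arguments swapped, "precision" is really recall and vice versa
print(np.allclose(swapped_precision, recall), np.allclose(swapped_recall, precision))
```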



In [33]:
predictions = condition_predictor_oracle.predict(X_assess)
y_hat = np.argmax(predictions, axis=1)

In [34]:
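# note: y_hat is passed in the y_true position, so precision and recall are swapped in the table below
print(classification_report(y_hat, y_assess, target_names=['close', 'split', 'far']))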

              precision    recall  f1-score   support

       close       0.98      0.97      0.98      5223
       split       0.93      0.91      0.92      5311
         far       0.92      0.95      0.94      5136

   micro avg       0.94      0.94      0.94     15670
   macro avg       0.94      0.94      0.94     15670
weighted avg       0.94      0.94      0.94     15670



In [36]:
sum(y_assess == y_hat)/len(y_hat) # accuracy

0.9442884492661135
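
This accuracy of about 0.944 agrees with the "micro avg" row in the reports above. For single-label multiclass problems that is expected, since every misclassified example counts as exactly one false positive and one false negative, so micro-averaged precision, recall, and F1 all equal accuracy. The sketch below checks this on made-up labels. `micro_f1` is a hypothetical helper written for this illustration.

```python
import numpy as np

def micro_f1(y_true, y_pred, num_classes=3):
    """Micro-averaged F1: pool true/false positives and false negatives over classes."""
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    tp = fp = fn = 0
    for c in range(num_classes):
        tp += np.sum((y_pred == c) & (y_true == c))
        fp += np.sum((y_pred == c) & (y_true != c))
        fn += np.sum((y_pred != c) & (y_true == c))
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# made-up labels, only for illustration
y_true = np.array([0, 0, 0, 1, 1, 2, 2, 2])
y_pred = np.array([0, 0, 1, 1, 2, 2, 2, 2])

accuracy = np.mean(y_true == y_pred)
print(accuracy, micro_f1(y_true, y_pred))  # both are 0.75
```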