In [1]:
from iit import generate_data
from transformers import BertModel, BertTokenizer
from model import TorchDeepNeuralClassifier
import numpy as np
import torch
import vsm
from sklearn.metrics import classification_report

In [2]:
vals = ['zero', 'one', 'two', 'three', 'four',
       'five', 'six', 'seven', 'eight', 'nine']

train_test_split = 0.9
X_train, y_train, X_test, y_test = generate_data(vals, train_test_split)

In [6]:
len(X_train)

900

In [3]:
X_train[:5]

[['four', 'nine', 'four'],
 ['nine', 'nine', 'zero'],
 ['six', 'one', 'one'],
 ['seven', 'three', 'four'],
 ['five', 'nine', 'three']]

In [4]:
y_train[:5]

[52, 81, 12, 49, 60]

In [5]:
X_test[:5]

[['eight', 'six', 'five'],
 ['six', 'nine', 'two'],
 ['two', 'one', 'zero'],
 ['nine', 'five', 'five'],
 ['two', 'five', 'five']]

Training a feed-forward network using randomized embeddings.

In [7]:
output_size = 163
num_inputs = 3
num_layers = 2
embed_dim = 5

mod = TorchDeepNeuralClassifier(vals, output_size, num_inputs,
                            num_layers, embed_dim, None, False)

In [8]:
mod.fit(X_train, y_train)

Finished epoch 1000 of 1000; error is 0.10345329344272614

TorchDeepNeuralClassifier(
	batch_size=1028,
	max_iter=1000,
	eta=0.001,
	optimizer_class=<class 'torch.optim.adam.Adam'>,
	l2_strength=0,
	gradient_accumulation_steps=1,
	max_grad_norm=None,
	validation_fraction=0.1,
	early_stopping=False,
	n_iter_no_change=10,
	warm_start=False,
	tol=1e-05,
	hidden_dim=50,
	hidden_activation=Tanh(),
	num_layers=2)

In [9]:
preds = mod.predict(X_test)

print("\nClassification report:")
print(classification_report(y_test, preds))


Classification report:
              precision    recall  f1-score   support

           0       0.88      1.00      0.93         7
           1       0.00      0.00      0.00         1
           2       0.00      0.00      0.00         2
           3       0.00      0.00      0.00         1
           5       0.00      0.00      0.00         1
           6       0.67      1.00      0.80         2
           8       0.67      0.50      0.57         4
          10       0.67      0.67      0.67         3
          12       0.50      1.00      0.67         2
          13       1.00      1.00      1.00         1
          14       1.00      1.00      1.00         1
          15       0.67      1.00      0.80         2
          16       1.00      0.67      0.80         3
          18       1.00      1.00      1.00         2
          20       1.00      1.00      1.00         2
          21       1.00      1.00      1.00         2
          22       1.00      1.00      1.00         1
   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Training a feed-forward network using BERT embeddings.

In [10]:
bert_weights_name = 'bert-base-uncased'
# Initialize a BERT tokenizer and BERT model based on
# `bert_weights_name`:
bert_tokenizer = BertTokenizer.from_pretrained(bert_weights_name)
bert_model = BertModel.from_pretrained(bert_weights_name)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [11]:
bert_embed = vsm.create_subword_pooling_vsm(
    vals, bert_tokenizer, bert_model, layer=1, pool_func=vsm.mean_pooling)

In [12]:
bert_embed

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,758,759,760,761,762,763,764,765,766,767
zero,0.368413,0.714821,-0.532399,-0.153238,-0.184203,0.009934,0.028352,-0.162773,0.163669,-0.471809,...,0.63745,-0.782,0.181008,-1.160265,0.005396,-0.899586,0.350256,-0.099035,-0.274523,0.308869
one,0.213226,0.484864,-0.032716,-0.026842,0.090642,-0.086201,0.195947,-0.16962,0.540456,-0.261407,...,0.455036,-0.426076,0.282586,-0.676856,-0.045337,-0.42813,0.227807,0.216206,0.112196,-0.128516
two,-0.121083,0.16106,-0.549375,-0.472711,-0.146909,0.155344,-0.13319,-0.483241,0.173656,-0.328344,...,0.30386,0.224042,0.113168,-0.807975,-0.281597,-0.575383,0.138091,0.198803,-0.091571,-0.442494
three,0.104971,0.313852,-0.34021,-0.434407,-0.049186,0.154321,-0.099751,-0.506876,0.389955,-0.244507,...,0.379173,0.134788,0.23369,-0.651362,-0.219017,-0.658988,0.203868,0.215337,-0.093137,-0.477617
four,0.039704,0.210273,-0.526674,-0.221241,-0.102211,0.203096,-0.147871,-0.331833,0.267574,-0.48117,...,0.240441,-0.040645,0.122894,-0.991067,0.013572,-0.848153,0.368776,0.249214,-0.265129,-0.095733
five,-0.039133,0.148628,-0.443105,-0.602812,0.024338,-0.077868,-0.239623,-0.467639,0.117604,-0.661229,...,0.410685,-0.126665,0.254384,-0.834398,-0.037767,-0.695847,0.18872,0.23778,-0.017827,-0.366614
six,0.304316,0.202454,-0.50693,-0.336193,0.042346,-0.201237,-0.326339,-0.235801,0.244673,-0.600611,...,0.417288,-0.329488,0.127992,-0.932242,-0.041394,-0.668947,-0.093982,0.385106,0.048364,-0.338787
seven,0.092242,0.176112,-0.475537,-0.412275,0.071916,-0.310987,-0.04849,-0.409864,0.006404,-0.807861,...,0.095583,-0.178779,0.421195,-0.693458,0.123755,-0.667253,0.076955,-0.13888,-0.079268,-0.244648
eight,-0.033224,0.192466,-0.507554,-0.419275,0.235861,0.105365,-0.25565,-0.2007,-0.053316,-0.733084,...,0.2336,-0.17198,0.28649,-0.620115,-0.104281,-0.58495,0.222393,0.159457,-0.355556,-0.570065
nine,0.105645,0.060501,-0.367164,-0.43336,0.425262,-0.032186,-0.136358,-0.358969,0.12983,-0.574778,...,0.081581,-0.280892,0.540993,-0.752147,-0.059858,-0.446215,0.033923,-0.036796,-0.296937,-0.314274


In [13]:
output_size = 163
num_inputs = 3
num_layers = 2
embed_dim = bert_embed.shape[1]
freeze_embedding = True

mod_bert_embed = TorchDeepNeuralClassifier(vals, output_size, num_inputs,
                            num_layers, embed_dim, bert_embed,
                            freeze_embedding)

In [15]:
mod_bert_embed.fit(X_train, y_train)

Finished epoch 1000 of 1000; error is 0.04465832561254501

TorchDeepNeuralClassifier(
	batch_size=1028,
	max_iter=1000,
	eta=0.001,
	optimizer_class=<class 'torch.optim.adam.Adam'>,
	l2_strength=0,
	gradient_accumulation_steps=1,
	max_grad_norm=None,
	validation_fraction=0.1,
	early_stopping=False,
	n_iter_no_change=10,
	warm_start=False,
	tol=1e-05,
	hidden_dim=50,
	hidden_activation=Tanh(),
	num_layers=2)

In [16]:
preds = mod_bert_embed.predict(X_test)

print("\nClassification report:")
print(classification_report(y_test, preds))


Classification report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00         7
           1       0.00      0.00      0.00         1
           2       0.00      0.00      0.00         2
           3       0.25      1.00      0.40         1
           5       1.00      1.00      1.00         1
           6       0.33      0.50      0.40         2
           8       1.00      0.50      0.67         4
          10       1.00      1.00      1.00         3
          12       1.00      0.50      0.67         2
          13       1.00      1.00      1.00         1
          14       0.33      1.00      0.50         1
          15       1.00      1.00      1.00         2
          16       0.67      0.67      0.67         3
          18       1.00      0.50      0.67         2
          20       1.00      1.00      1.00         2
          21       1.00      1.00      1.00         2
          22       1.00      1.00      1.00         1
   

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
