This notebook shows a simple demo of how to use `EmbGAMClassifier`. It follows the familiar sklearn-style interface (`fit`, `predict`, `predict_proba`), but it uses a language model to extract embeddings during training, so fitting can be slow. After fitting, the model reduces to a linear model over ngrams whose coefficients are stored in a dictionary, so prediction is fast.
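Why the fitted model can be reduced to a linear model over ngrams can be sketched in a few lines of NumPy. This is a minimal illustration, not the `imodelsx` implementation. It assumes that a text is represented by the *sum* of its ngram embeddings, and it uses a hypothetical `embed` function that returns a fixed random vector per ngram in place of a real language model. Because the classifier is linear, its score equals the intercept plus one scalar coefficient per ngram, and those scalars can be cached in a dictionary.

```python
import zlib
import numpy as np

DIM = 16  # size of the toy embeddings

def embed(ngram: str) -> np.ndarray:
    """Hypothetical stand-in for a language-model embedding:
    a fixed pseudo-random vector per ngram, seeded by a hash of the string."""
    return np.random.default_rng(zlib.crc32(ngram.encode())).normal(size=DIM)

def text_embedding(ngrams: list) -> np.ndarray:
    # Assumption: a text is represented by the sum of its ngram embeddings
    return np.sum([embed(g) for g in ngrams], axis=0)

# Toy weights and intercept; in the real classifier these are learned during fit
w = np.random.default_rng(0).normal(size=DIM)
b = 0.5

ngrams = ["a", "gripping", "film", "a gripping", "gripping film"]

# Score from the summed embedding (this needs the embedding model) ...
score_full = w @ text_embedding(ngrams) + b

# ... equals the intercept plus a sum of per-ngram scalars, which can be cached once
coefs = {g: float(w @ embed(g)) for g in ngrams}
score_cached = b + sum(coefs[g] for g in ngrams)

print(np.isclose(score_full, score_cached))  # True
```

Once a dictionary like `coefs` is filled in, scoring a text needs only dictionary lookups and additions, with no call to the embedding model.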

In [1]:
%load_ext autoreload
%autoreload 2
from imodelsx import EmbGAMClassifier
import datasets
import numpy as np

### Load some data
Here, we load training and validation data from the Rotten Tomatoes movie-review dataset. To keep things fast, we randomly subsample 300 examples from each split.
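The sampling in the next cell uses unseeded `np.random.choice`, so each run draws a different 300-example subset and the accuracies reported later will vary between runs. One way to make the subsample reproducible is a seeded NumPy `Generator`. The helper `subsample_indices` below is hypothetical, and the seed `0` and total size `1000` are arbitrary illustrative choices.

```python
import numpy as np

def subsample_indices(n_total: int, size: int, seed: int = 0) -> list:
    """Hypothetical helper: draw `size` distinct indices from range(n_total)
    using a seeded generator, so the same subset is drawn on every run."""
    rng = np.random.default_rng(seed)
    idx = rng.choice(n_total, size=size, replace=False)
    return idx.tolist()  # plain Python ints

a = subsample_indices(1000, 300, seed=0)
b = subsample_indices(1000, 300, seed=0)
print(a == b)  # True: same seed, same subset
```

In the notebook, this could be used as `dset.select(subsample_indices(len(dset), 300))`.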

In [2]:
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))

dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

Using custom data configuration default
Found cached dataset rotten_tomatoes (/home/chansingh/.cache/huggingface/datasets/rotten_tomatoes/default/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)



Using custom data configuration default
Found cached dataset rotten_tomatoes (/home/chansingh/.cache/huggingface/datasets/rotten_tomatoes/default/1.0.0/40d411e45a6ce3484deed7cc15b82a53dad9a72aafd9f86f8f227134bec5ca46)



### Fit EmbGAMClassifier
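Fitting EmbGAM takes a single call to `fit`. The classifier accepts a few hyperparameters, which are described [here](https://csinva.io/emb-gam/). Below, we use a DistilBERT checkpoint fine-tuned on Rotten Tomatoes and extract ngrams up to length 2.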

In [12]:
m = EmbGAMClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',  # Hugging Face model used to compute embeddings
    ngrams=2,  # extract ngrams up to length 2
    all_ngrams=True,  # also use lower-order ngrams (here, unigrams)
)
m.fit(dset['text'], dset['label'])

initializing model...


Some weights of the model checkpoint at textattack/distilbert-base-uncased-rotten-tomatoes were not used when initializing DistilBertModel: ['classifier.weight', 'pre_classifier.bias', 'classifier.bias', 'pre_classifier.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


calculating embeddings...


100%|██████████| 300/300 [00:01<00:00, 249.14it/s]


training linear model...
caching linear coefs...


100%|██████████| 7463/7463 [00:28<00:00, 261.17it/s]

After caching, coefs_dict_ len 7463
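With `ngrams=2` and `all_ngrams=True`, the model works with both unigrams and bigrams of each text (the coefficients in the next section include both, e.g. `watchable` and `compellingly watchable`). The sketch below shows that enumeration. It is illustrative only: the hypothetical `extract_ngrams` splits on whitespace, and the library's actual tokenization may differ.

```python
def extract_ngrams(text: str, ngrams: int = 2, all_ngrams: bool = True) -> list:
    """Hypothetical helper: list the ngrams of `text`.

    With all_ngrams=True, every order from 1 up to `ngrams` is included;
    otherwise only ngrams of exactly length `ngrams` are returned.
    """
    tokens = text.split()  # assumption: whitespace tokenization
    orders = range(1, ngrams + 1) if all_ngrams else [ngrams]
    out = []
    for n in orders:
        for i in range(len(tokens) - n + 1):
            out.append(" ".join(tokens[i:i + n]))
    return out

print(extract_ngrams("compellingly watchable ."))
# ['compellingly', 'watchable', '.', 'compellingly watchable', 'watchable .']
```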





### Interpretation

We now have a linear model over ngrams. During `fit`, the model precomputed a linear coefficient for every ngram it saw in the training data and stored them in the dictionary `m.coefs_dict_`. Let's look at the ngrams with the largest positive and negative coefficients.

In [13]:
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))

Total ngram coefficients:  7463
Most positive ngrams
	 watchable . 5.99
	 compellingly watchable 5.67
	 poignant and 5.57
	 interesting and 5.49
	 thoughtful and 5.45
	 likableness . 5.39
	 watchable 5.37
	 soulful and 5.35
Most negative ngrams
	 too formulaic -3.57
	 haphazard , -3.5
	 fuzziness . -3.44
	 mess , -3.42
	 apparently reassembled -3.36
	 idiotic and -3.28
	 unimpressively fussy -3.25
	 dumb , -3.25
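Because `m.coefs_dict_` is an ordinary dictionary from ngram strings to coefficients (as the cell above uses it), the top and bottom entries can also be pulled out with `heapq.nlargest` and `heapq.nsmallest`, which avoids sorting the whole dictionary twice. The toy dictionary below stands in for `m.coefs_dict_`; its values are copied from the output above.

```python
import heapq

# Toy stand-in for m.coefs_dict_ (values taken from the output above)
coefs_dict = {
    "watchable .": 5.99,
    "compellingly watchable": 5.67,
    "poignant and": 5.57,
    "too formulaic": -3.57,
    "haphazard ,": -3.5,
    "mess ,": -3.42,
}

k = 2
# nlargest/nsmallest keep only k items in a heap instead of sorting everything
top = heapq.nlargest(k, coefs_dict.items(), key=lambda item: item[1])
bottom = heapq.nsmallest(k, coefs_dict.items(), key=lambda item: item[1])

print("Most positive ngrams")
for ngram, coef in top:
    print("\t", ngram, round(coef, 2))
print("Most negative ngrams")
for ngram, coef in bottom:
    print("\t", ngram, round(coef, 2))
```

With the fitted model, pass `m.coefs_dict_.items()` and `k=8` to reproduce the listing above.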


### Predictions
Now let's make predictions. This is fast because `predict` only looks up ngram coefficients in the precomputed dictionary `m.coefs_dict_` instead of running the language model.
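A minimal sketch of what such a cached prediction amounts to, under stated assumptions: a text's score is taken to be the linear model's intercept plus the sum of the cached coefficients of its ngrams; ngrams missing from the dictionary contribute nothing; the predicted class is 1 when the score is positive; and probabilities use a logistic link. The dictionary, intercept, and `predict_from_cache` helper are hypothetical, and this is not the `imodelsx` source.

```python
import numpy as np

def extract_ngrams(text, n_max=2):
    # Assumption: whitespace tokens, all ngrams of length 1..n_max
    toks = text.split()
    return [" ".join(toks[i:i + n]) for n in range(1, n_max + 1)
            for i in range(len(toks) - n + 1)]

def predict_from_cache(texts, coefs_dict, intercept):
    """Hypothetical sketch: score = intercept + sum of cached ngram coefficients.
    Ngrams absent from coefs_dict are skipped (treated as 0) and counted."""
    scores, n_unseen = [], 0
    for text in texts:
        score = intercept
        for g in extract_ngrams(text):
            if g in coefs_dict:
                score += coefs_dict[g]
            else:
                n_unseen += 1
        scores.append(score)
    scores = np.array(scores)
    proba = 1 / (1 + np.exp(-scores))  # assumption: logistic link
    preds = (scores > 0).astype(int)
    return preds, proba, n_unseen

coefs_dict = {"watchable": 5.37, "mess": -3.0, "a": 0.1}  # toy values
preds, proba, n_unseen = predict_from_cache(
    ["a watchable film", "a mess"], coefs_dict, intercept=-0.2)
print(preds, n_unseen)  # [1 0] 4
```

Every unseen ngram (here `film`, `a watchable`, `watchable film`, `a mess`) is silently dropped from the score, which is why caching coefficients for new inputs can matter.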

In [14]:
preds = m.predict(dset['text'])
print('acc_train', np.mean(preds == dset['label']))
preds_proba = m.predict_proba(dset['text'])

acc_train 0.5766666666666667


In [5]:
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

100%|██████████| 300/300 [00:00<00:00, 5479.10it/s]

acc_val 0.6633333333333333





Note: the validation texts contain ngrams that never appeared in the training data, so they have no entry in `m.coefs_dict_`. To infer coefficients for them, we call `cache_linear_coefs` on the validation inputs. This computes the missing coefficients and adds them to `m.coefs_dict_`; then we call `predict` as before. Unlike `predict`, this step runs the language model, so it takes noticeably longer.

In [68]:
m.cache_linear_coefs(dset_val['text'])
preds = m.predict(dset_val['text'])
print('acc_val', np.mean(preds == dset_val['label']))

Some weights of the model checkpoint at textattack/distilbert-base-uncased-rotten-tomatoes were not used when initializing DistilBertModel: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
100%|██████████| 6359/6359 [00:25<00:00, 249.83it/s]


coefs_dict_ len 13748


100%|██████████| 300/300 [00:00<00:00, 11890.41it/s]

acc_val 0.7933333333333333
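A minimal sketch of the idea behind this caching step, assuming each coefficient is the dot product of the linear model's weights with that ngram's embedding: only ngrams missing from the dictionary are embedded, and their coefficients are added in place. The `embed` function, toy weights, and `cache_missing_coefs` helper are hypothetical stand-ins, and for brevity only unigrams are considered.

```python
import zlib
import numpy as np

DIM = 8
w = np.random.default_rng(0).normal(size=DIM)  # toy linear-model weights

def embed(ngram: str) -> np.ndarray:
    # Hypothetical stand-in for the language model: a fixed random vector per ngram
    return np.random.default_rng(zlib.crc32(ngram.encode())).normal(size=DIM)

def cache_missing_coefs(texts, coefs_dict, weights):
    """Hypothetical sketch: add coefficients for ngrams not yet in coefs_dict.
    Already-cached ngrams are left untouched and are not re-embedded."""
    missing = {g for t in texts for g in t.split() if g not in coefs_dict}
    for g in sorted(missing):
        coefs_dict[g] = float(weights @ embed(g))
    return len(missing)

coefs_dict = {"watchable": 1.0}
n_new = cache_missing_coefs(["watchable and moving", "moving"], coefs_dict, w)
print(n_new, len(coefs_dict))  # 2 3
```

Calling the helper again on the same texts embeds nothing, since every ngram is already cached.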



