Augmenting Interpretable Models with LLMs during Training

This repo contains the code to reproduce the experiments in the Aug-imodels paper (Nature Communications, 2023). For a simple scikit-learn-style interface to Aug-imodels, use the imodelsX library. A quickstart example follows.

Installation: pip install imodelsx

from imodelsx import AugLinearClassifier, AugTreeClassifier, AugLinearRegressor, AugTreeRegressor
import datasets
import numpy as np

# set up data
dset = datasets.load_dataset('rotten_tomatoes')['train']
dset = dset.select(np.random.choice(len(dset), size=300, replace=False))
dset_val = datasets.load_dataset('rotten_tomatoes')['validation']
dset_val = dset_val.select(np.random.choice(len(dset_val), size=300, replace=False))

# fit model
m = AugLinearClassifier(
    checkpoint='textattack/distilbert-base-uncased-rotten-tomatoes',
    ngrams=2, # use bigrams
)
m.fit(dset['text'], dset['label'])

# predict
preds = m.predict(dset_val['text'])
# convert labels to an array so the comparison is elementwise
print('acc_val', np.mean(preds == np.array(dset_val['label'])))

# interpret
print('Total ngram coefficients: ', len(m.coefs_dict_))
print('Most positive ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1], reverse=True)[:8]:
    print('\t', k, round(v, 2))
print('Most negative ngrams')
for k, v in sorted(m.coefs_dict_.items(), key=lambda item: item[1])[:8]:
    print('\t', k, round(v, 2))
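
One way to read the coefficients printed above: in a linear model over ngrams, a text's score is the sum of the coefficients of the ngrams it contains, so each ngram's contribution can be inspected on its own. The following is a minimal sketch of that idea, assuming whitespace tokenization and a made-up coefficient table. The names `extract_ngrams`, `score_text`, and `toy_coefs` are hypothetical and not part of imodelsX, and the library's own tokenization and intercept handling may differ.

```python
# Illustrative sketch only: the coefficients below are invented,
# not output from a fitted AugLinearClassifier.

def extract_ngrams(text, max_n=2):
    """Return all whitespace-token ngrams of length 1..max_n, lowercased."""
    tokens = text.lower().split()
    return [
        " ".join(tokens[i:i + n])
        for n in range(1, max_n + 1)
        for i in range(len(tokens) - n + 1)
    ]


def score_text(text, coefs, intercept=0.0, max_n=2):
    """Sum the coefficients of known ngrams; unknown ngrams contribute 0."""
    contributions = {}
    for ng in extract_ngrams(text, max_n):
        if ng in coefs:
            contributions[ng] = contributions.get(ng, 0.0) + coefs[ng]
    return intercept + sum(contributions.values()), contributions


# hypothetical stand-in for a dictionary shaped like m.coefs_dict_
toy_coefs = {"great": 1.5, "boring": -2.0, "not great": -1.0}

score, parts = score_text("not great and boring", toy_coefs)
print("score:", score)
for ngram, value in sorted(parts.items(), key=lambda item: item[1]):
    print("\t", ngram, value)
```

With a fitted model, a dictionary such as `m.coefs_dict_` could be passed in place of `toy_coefs`, provided its keys are ngram strings produced by a compatible tokenization.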

Reference:

@misc{ch2022augmenting,
    title={Augmenting Interpretable Models with LLMs during Training},
    author={Chandan Singh and Armin Askari and Rich Caruana and Jianfeng Gao},
    year={2022},
    eprint={2209.11799},
    archivePrefix={arXiv},
    primaryClass={cs.AI}
}