# Imports

In [1]:
import torch

import neurox.data.extraction.transformers_extractor as transformers_extractor
import neurox.data.loader as data_loader
import neurox.interpretation.utils as utils
import neurox.interpretation.linear_probe as linear_probe

# Data 

In [2]:
# Corpus files for each split. Every `.word` file holds one whitespace-tokenized
# sentence per line; the matching `.label` file holds the POS tag sequence for
# that sentence, one tag per token.
train_sentences, train_labels = "data/pos_train.word", "data/pos_train.label"
dev_sentences, dev_labels = "data/pos_dev.word", "data/pos_dev.label"
test_sentences, test_labels = "data/pos_test.word", "data/pos_test.label"

In [3]:
!cat "data/pos_train.word"

Pierre Vinken , 61 years old , will join the board as a nonexecutive director Nov. 29 .
Mr. Vinken is chairman of Elsevier N.V. , the Dutch publishing group .
Rudolph Agnew , 55 years old and former chairman of Consolidated Gold Fields PLC , was named a nonexecutive director of this British industrial conglomerate .
A form of asbestos once used to make Kent cigarette filters has caused a high percentage of cancer deaths among a group of workers exposed to it more than 30 years ago , researchers reported .
The asbestos fiber , crocidolite , is unusually resilient once it enters the lungs , with even brief exposures to it causing symptoms that show up decades later , researchers said .
Lorillard Inc. , the unit of New York-based Loews Corp. that makes Kent cigarettes , stopped using crocidolite in its Micronite cigarette filters in 1956 .
Although preliminary findings were reported more than a year ago , the latest results appear in today 's New England Journal of Medicine , a foru

In [4]:
!cat "data/pos_train.label"

NNP NNP , CD NNS JJ , MD VB DT NN IN DT JJ NN NNP CD .
NNP NNP VBZ NN IN NNP NNP , DT NNP VBG NN .
NNP NNP , CD NNS JJ CC JJ NN IN NNP NNP NNP NNP , VBD VBN DT JJ NN IN DT JJ JJ NN .
DT NN IN NN RB VBN TO VB NNP NN NNS VBZ VBN DT JJ NN IN NN NNS IN DT NN IN NNS VBN TO PRP RBR IN CD NNS IN , NNS VBD .
DT NN NN , NN , VBZ RB JJ IN PRP VBZ DT NNS , IN RB JJ NNS TO PRP VBG NNS WDT VBP RP NNS JJ , NNS VBD .
NNP NNP , DT NN IN JJ JJ NNP NNP WDT VBZ NNP NNS , VBD VBG NN IN PRP$ NN NN NNS IN CD .
IN JJ NNS VBD VBN RBR IN DT NN IN , DT JJS NNS VBP IN NN POS NNP NNP NNP IN NNP , DT NN JJ TO VB JJ NN TO DT NN .
DT NNP NN VBD , `` DT VBZ DT JJ NN .
PRP VBP VBG IN NNS IN IN NN VBD IN NN VBG DT JJ NNS .
EX VBZ DT NN IN PRP$ NNS RB . ''


# Extract Representations

In [5]:
# Run bert-base-uncased over every training sentence and write the activations
# to JSON. Words split into several sub-word pieces are pooled into a single
# vector per word; aggregation="average" averages the pieces (the other
# options are "first" and "last").
transformers_extractor.extract_representations(
    "bert-base-uncased",
    train_sentences,
    "train_activations.json",
    aggregation="average",
)

Loading model: bert-base-uncased


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Reading input corpus
Preparing output file
Extracting representations from model
Sentence         : "Pierre Vinken , 61 years old , will join the board as a nonexecutive director Nov. 29 ."
Original    (018): ['Pierre', 'Vinken', ',', '61', 'years', 'old', ',', 'will', 'join', 'the', 'board', 'as', 'a', 'nonexecutive', 'director', 'Nov.', '29', '.']
Tokenized   (025): ['[CLS]', 'pierre', 'vin', '##ken', ',', '61', 'years', 'old', ',', 'will', 'join', 'the', 'board', 'as', 'a', 'none', '##x', '##ec', '##utive', 'director', 'nov', '.', '29', '.', '[SEP]']
Filtered   (023): ['pierre', 'vin', '##ken', ',', '61', 'years', 'old', ',', 'will', 'join', 'the', 'board', 'as', 'a', 'none', '##x', '##ec', '##utive', 'director', 'nov', '.', '29', '.']
Detokenized (018): ['pierre', 'vin##ken', ',', '61', 'years', 'old', ',', 'will', 'join', 'the', 'board', 'as', 'a', 'none##x##ec##utive', 'director', 'nov.', '29', '.']
Counter: 23
Hidden states:  (13, 18, 768)
# Extracted words:  18
Sentence        

Sentence         : "A Lorillard spokewoman said , `` This is an old story ."
Original    (012): ['A', 'Lorillard', 'spokewoman', 'said', ',', '``', 'This', 'is', 'an', 'old', 'story', '.']
Tokenized   (018): ['[CLS]', 'a', 'lori', '##llar', '##d', 'spoke', '##woman', 'said', ',', '`', '`', 'this', 'is', 'an', 'old', 'story', '.', '[SEP]']
Filtered   (016): ['a', 'lori', '##llar', '##d', 'spoke', '##woman', 'said', ',', '`', '`', 'this', 'is', 'an', 'old', 'story', '.']
Detokenized (012): ['a', 'lori##llar##d', 'spoke##woman', 'said', ',', '``', 'this', 'is', 'an', 'old', 'story', '.']
Counter: 16
Hidden states:  (13, 12, 768)
# Extracted words:  12
Sentence         : "We 're talking about years ago before anyone heard of asbestos having any questionable properties ."
Original    (016): ['We', "'re", 'talking', 'about', 'years', 'ago', 'before', 'anyone', 'heard', 'of', 'asbestos', 'having', 'any', 'questionable', 'properties', '.']
Tokenized   (019): ['[CLS]', 'we', "'", 're', 'talking

In [6]:
# Same extraction as for the training split, applied to the dev sentences.
# Sub-word activations are averaged per word (alternatives: "first", "last").
transformers_extractor.extract_representations(
    "bert-base-uncased",
    dev_sentences,
    "dev_activations.json",
    aggregation="average",
)

Loading model: bert-base-uncased


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Reading input corpus
Preparing output file
Extracting representations from model
Sentence         : "The Arizona Corporations Commission authorized an 11.5 % rate increase at Tucson Electric Power Co. , substantially lower than recommended last month by a commission hearing officer and barely half the rise sought by the utility ."
Original    (037): ['The', 'Arizona', 'Corporations', 'Commission', 'authorized', 'an', '11.5', '%', 'rate', 'increase', 'at', 'Tucson', 'Electric', 'Power', 'Co.', ',', 'substantially', 'lower', 'than', 'recommended', 'last', 'month', 'by', 'a', 'commission', 'hearing', 'officer', 'and', 'barely', 'half', 'the', 'rise', 'sought', 'by', 'the', 'utility', '.']
Tokenized   (042): ['[CLS]', 'the', 'arizona', 'corporations', 'commission', 'authorized', 'an', '11', '.', '5', '%', 'rate', 'increase', 'at', 'tucson', 'electric', 'power', 'co', '.', ',', 'substantially', 'lower', 'than', 'recommended', 'last', 'month', 'by', 'a', 'commission', 'hearing', 'officer', '

Sentence         : "South Korean President Roh Tae Woo , brushing aside suggestions that the won be revalued again , said the currency 's current level against the dollar is `` appropriate . ''"
Original    (031): ['South', 'Korean', 'President', 'Roh', 'Tae', 'Woo', ',', 'brushing', 'aside', 'suggestions', 'that', 'the', 'won', 'be', 'revalued', 'again', ',', 'said', 'the', 'currency', "'s", 'current', 'level', 'against', 'the', 'dollar', 'is', '``', 'appropriate', '.', "''"]
Tokenized   (040): ['[CLS]', 'south', 'korean', 'president', 'ro', '##h', 'tae', 'woo', ',', 'brushing', 'aside', 'suggestions', 'that', 'the', 'won', 'be', 'rev', '##al', '##ue', '##d', 'again', ',', 'said', 'the', 'currency', "'", 's', 'current', 'level', 'against', 'the', 'dollar', 'is', '`', '`', 'appropriate', '.', "'", "'", '[SEP]']
Filtered   (038): ['south', 'korean', 'president', 'ro', '##h', 'tae', 'woo', ',', 'brushing', 'aside', 'suggestions', 'that', 'the', 'won', 'be', 'rev', '##al', '##ue', '##d', 

In [7]:
# Same extraction as for the training split, applied to the test sentences.
# Sub-word activations are averaged per word (alternatives: "first", "last").
transformers_extractor.extract_representations(
    "bert-base-uncased",
    test_sentences,
    "test_activations.json",
    aggregation="average",
)

Loading model: bert-base-uncased


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Reading input corpus
Preparing output file
Extracting representations from model
Sentence         : "Rockwell International Corp. 's Tulsa unit said it signed a tentative agreement extending its contract with Boeing Co. to provide structural parts for Boeing 's 747 jetliners ."
Original    (028): ['Rockwell', 'International', 'Corp.', "'s", 'Tulsa', 'unit', 'said', 'it', 'signed', 'a', 'tentative', 'agreement', 'extending', 'its', 'contract', 'with', 'Boeing', 'Co.', 'to', 'provide', 'structural', 'parts', 'for', 'Boeing', "'s", '747', 'jetliners', '.']
Tokenized   (036): ['[CLS]', 'rockwell', 'international', 'corp', '.', "'", 's', 'tulsa', 'unit', 'said', 'it', 'signed', 'a', 'tentative', 'agreement', 'extending', 'its', 'contract', 'with', 'boeing', 'co', '.', 'to', 'provide', 'structural', 'parts', 'for', 'boeing', "'", 's', '747', 'jet', '##liner', '##s', '.', '[SEP]']
Filtered   (034): ['rockwell', 'international', 'corp', '.', "'", 's', 'tulsa', 'unit', 'said', 'it', 'signed', '

Sentence         : "In January , he accepted the position of vice chairman of Carlyle Group , a merchant banking concern ."
Original    (019): ['In', 'January', ',', 'he', 'accepted', 'the', 'position', 'of', 'vice', 'chairman', 'of', 'Carlyle', 'Group', ',', 'a', 'merchant', 'banking', 'concern', '.']
Tokenized   (022): ['[CLS]', 'in', 'january', ',', 'he', 'accepted', 'the', 'position', 'of', 'vice', 'chairman', 'of', 'carly', '##le', 'group', ',', 'a', 'merchant', 'banking', 'concern', '.', '[SEP]']
Filtered   (020): ['in', 'january', ',', 'he', 'accepted', 'the', 'position', 'of', 'vice', 'chairman', 'of', 'carly', '##le', 'group', ',', 'a', 'merchant', 'banking', 'concern', '.']
Detokenized (019): ['in', 'january', ',', 'he', 'accepted', 'the', 'position', 'of', 'vice', 'chairman', 'of', 'carly##le', 'group', ',', 'a', 'merchant', 'banking', 'concern', '.']
Counter: 20
Hidden states:  (13, 19, 768)
# Extracted words:  19
Sentence         : "SHEARSON LEHMAN HUTTON Inc ."
Original  

# Train a Linear Probe

In [10]:
# Width of one bert-base-uncased layer (the extractor log above reports hidden
# states of shape (13, n_words, 768)). load_activations presumably uses it to
# infer the number of layers per token — it reports 13.0 below; confirm.
BERT_HIDDEN_SIZE = 768
activations, num_layers = data_loader.load_activations("train_activations.json", BERT_HIDDEN_SIZE)

Loading json activations from train_activations.json...
10 13.0


In [11]:
# Pair each training sentence with its tag sequence and activations.
# MAX_SENT_LEN is the maximum sentence length accepted by load_data; longer
# sentences are presumably skipped — confirm against the neurox docs.
MAX_SENT_LEN = 512
tokens = data_loader.load_data(train_sentences, train_labels, activations, MAX_SENT_LEN)

In [12]:
# Build the (X, y) training tensors from the per-word activations.
# "NN" is the fallback tag neurox assigns to labels never seen in training;
# it is the most frequent tag in this set (30 of 224 tokens, see stats below).
X, y, mapping = utils.create_tensors(tokens, activations, "NN")
label2idx, idx2label, src2idx, idx2src = mapping

# L1 and L2 regularisation strengths for the logistic-regression probe.
LAMBDA_L1 = 0.001
LAMBDA_L2 = 0.001

# The probe is a PyTorch model with randomly initialised weights, so seed torch
# to make the loss curve and the selected neurons reproducible across
# Restart & Run All. NOTE(review): confirm no other RNG is used in training.
SEED = 0
torch.manual_seed(SEED)
probe = linear_probe.train_logistic_regression_probe(X, y, lambda_l1=LAMBDA_L1, lambda_l2=LAMBDA_L2)

Number of tokens:  224
length of source dictionary:  142
length of target dictionary:  29
224
Total instances: 224
['about', 'up', 'Inc.', 'preliminary', 'appear', 'We', 'There', 'Rudolph', 'named', 'said', 'in', 'an', 'now', 'has', '1956', 'causing', 'industrial', 'new', 'Nov.', 'results']
Number of samples:  224
Stats: Labels with their frequencies in the final set
VBD 7
PRP 4
POS 1
. 10
TO 5
RBR 2
VBN 5
VBZ 7
WDT 2
, 15
NN 30
CC 1
`` 1
NNS 19
EX 1
MD 1
VBP 3
'' 1
RP 1
CD 5
JJ 18
NNP 25
JJS 1
RB 4
DT 20
PRP$ 2
VBG 5
IN 25
VB 3
Training classification probe
Creating model...
Number of training instances: 224
Number of classes: 29


epoch [1/10]: 0it [00:00, ?it/s]

Epoch: [1/10], Loss: 0.1108


epoch [2/10]: 0it [00:00, ?it/s]

Epoch: [2/10], Loss: 0.0472


epoch [3/10]: 0it [00:00, ?it/s]

Epoch: [3/10], Loss: 0.0447


epoch [4/10]: 0it [00:00, ?it/s]

Epoch: [4/10], Loss: 0.0408


epoch [5/10]: 0it [00:00, ?it/s]

Epoch: [5/10], Loss: 0.0352


epoch [6/10]: 0it [00:00, ?it/s]

Epoch: [6/10], Loss: 0.0296


epoch [7/10]: 0it [00:00, ?it/s]

Epoch: [7/10], Loss: 0.0248


epoch [8/10]: 0it [00:00, ?it/s]

Epoch: [8/10], Loss: 0.0209


epoch [9/10]: 0it [00:00, ?it/s]

Epoch: [9/10], Loss: 0.0181


epoch [10/10]: 0it [00:00, ?it/s]

Epoch: [10/10], Loss: 0.0160


# Get Top Neurons

In [13]:
# Rank neurons by their weights in the trained probe, both overall and per POS
# class. The 0.01 threshold is presumably the fraction of total weight mass the
# selected neurons must cover — confirm against the neurox docs.
top_neurons, top_neurons_per_class = linear_probe.get_top_neurons(
    probe,
    0.01,
    label2idx,
)

In [14]:
# Selected neuron indices for each POS class. This is the cell's last
# expression, so Jupyter renders it with rich display.
top_neurons_per_class

{'VBD': array([5539, 3630, 6812, 8643, 4527,   42, 8348, 6126, 3113]), 'PRP': array([4140, 5883, 2093, 3099, 1083, 6836, 2858, 4699, 3528]), 'POS': array([5430, 7572, 4832, 4900]), '.': array([9396, 8628, 4248, 4738, 8219, 4949, 4863]), 'TO': array([4423, 2988, 4709,  589, 1021, 3799, 5349, 5549]), 'RBR': array([9123, 1653, 4527, 5229, 4538, 5899]), 'VBN': array([5848, 4013, 7928, 6307, 7625, 5083, 6018, 3965,  793, 1409]), 'VBZ': array([3659, 7607, 1275,  173, 2188, 3940, 6335, 1463,  430]), 'WDT': array([7817, 4572, 5150,  921, 8830, 5107]), ',': array([8219, 5214, 4137, 2601, 6894, 2756, 3822,  774, 1518, 5060, 3215]), 'NN': array([3459, 6057, 4100, 6579, 5830, 1124, 4057, 5005, 7123, 6305, 1669,
       4604, 4958, 6531, 3489, 9775, 2158, 8128, 3594, 4812]), 'CC': array([ 482, 6862, 6185, 5478]), '``': array([8776, 7240, 7052, 3684]), 'NNS': array([3146, 2016, 8265, 4542, 8684, 5113, 6192,  787, 4789, 1831, 2954,
        931, 9033,  356, 4411, 4021]), 'EX': array([5582, 5069, 8006, 