# Using LM-inspector with a WSD-classifier

Import libraries

In [1]:
import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

from lm_inspect import LanguageModelInspector
from word_sense_disambiguation import bert_encoder, label_encoder, Xval as Xtest, Yval as Ytest



Load trained classifier from binary file.

In [2]:
# Rebuild the classifier with the same layout as the saved checkpoint:
# BERT encoder -> dropout -> linear layer with 358 output features.
# load_state_dict below is strict by default, so this layout has to match
# the checkpoint exactly.
seq = torch.nn.Sequential(
    bert_encoder,
    torch.nn.Dropout(0.2),
    torch.nn.Linear(bert_encoder.output_size, out_features=358),
)

# map_location='cpu' lets the checkpoint load on machines that lack the device
# it was saved from; load_state_dict then copies the tensors onto whatever
# device the module's parameters already live on.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load('models/KB-bert-swedish-cased-wsd.pt', map_location='cpu')

# Switch to inference mode so the dropout layer is disabled and the inspected
# outputs are deterministic across runs.
seq.eval()

# Kept as the last expression so the cell still reports the key-matching result.
seq.load_state_dict(state_dict)

<All keys matched successfully>

Load transformers config and tokenizer

In [3]:
# Checkpoint identifier shared by the config and the tokenizer.
MODEL_NAME = 'KB/bert-base-swedish-cased'

# Config flags requesting hidden states and attention weights in the outputs.
config = AutoConfig.from_pretrained(
    MODEL_NAME,
    output_hidden_states=True,
    output_attentions=True,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, config=config)

Get the positions of the ambiguous words

In [4]:
# One entry per test example: the value stored under its 'pos' key.
# NOTE(review): despite the name, these are word positions rather than token
# ids (per the notebook text); the name is kept because the configure() call
# further down refers to it.
input_ids = [example['pos'] for example in Xtest]

Initialize LM-inspector object and set the configuration.

In [5]:
# Wrap the classifier together with the test data, tokenizer and label encoder.
inspector = LanguageModelInspector(seq, Xtest, Ytest, tokenizer, label_encoder)

# configure() stays the final expression, so the cell displays its return value.
# NOTE(review): layer/head look like the layer and attention-head indices to
# inspect, and correct_only presumably keeps only correctly classified
# examples -- confirm against the lm_inspect documentation.
inspector.configure(
    label='gälla_1_1',
    layer=[0, 6, 11],
    head=[0, 6],
    correct_only=True,
    input_id=input_ids,
)

Evaluating data
 1530  /  1525
Done.


<lm_inspect.inspect.LanguageModelInspector at 0x7fbbdc592780>

Apply an inspection method to the current configuration

In [6]:
# Request the top 100 entries with return_type="all"; the result is displayed
# as the cell output (a list of (token, score) pairs in the saved run).
inspector.topk_most_attended_to(
    k=100,
    return_type="all",
)

[('Det', 0.2818432152271271),
 ('att', 0.155453160405159),
 ('gäller', 0.14438840746879578),
 ('[UNK]', 0.05462219566106796),
 ('således', 0.03643447160720825),
 ('lätt', 0.02866438589990139),
 ('Därför', 0.012868179939687252),
 ('på', 0.009184379130601883),
 ('skriver', 0.008755034767091274),
 ('om', 0.008551124483346939),
 ('-', 0.008212991058826447),
 ('är', 0.007640307303518057),
 ('debatten', 0.0075257159769535065),
 ('sätt', 0.0071131582371890545),
 ('Ett', 0.006638072431087494),
 ('Intresset', 0.00587509386241436),
 ('som', 0.00586360227316618),
 ('kan', 0.0056860256008803844),
 ('mycket', 0.005482356064021587),
 ('tilltänkta', 0.005444932263344526),
 ('för', 0.005351259373128414),
 ('ta', 0.005305282771587372),
 ('en', 0.005230214446783066),
 ('försämrad', 0.005061602219939232),
 ('kamp', 0.005042547360062599),
 ('fördomar', 0.004798736423254013),
 (',', 0.0046339742839336395),
 ('före', 0.004595036618411541),
 ('sig', 0.004476252943277359),
 ('testa', 0.004403362981975079),
 (

Visualize results scope-wise.

In [7]:
# Request the top 8 entries with return_type="scope" and visualization enabled;
# the saved run returns a dict with 'indices' and 'values'.
inspector.topk_most_attended_to(
    k=8,
    return_type="scope",
    visualize=True,
)

{'indices': [[[1, 48, 68, 2200, 39875, 1048, 19, 2861],
   [160, 945, 1, 2200, 48, 1048, 6036, 408]],
  [[1, 1048, 48, 6406, 100, 160, 59, 54],
   [48, 6036, 1, 1048, 160, 945, 2847, 19946]],
  [[945, 160, 48, 6036, 2861, 692, 1, 864],
   [945, 160, 48, 6036, 2861, 864, 52, 178]]],
 'values': [[[0.09006540477275848,
    0.043870627880096436,
    0.03544776886701584,
    0.03315487504005432,
    0.027392113581299782,
    0.0218181274831295,
    0.016506437212228775,
    0.016503773629665375],
   [0.9955369234085083,
    0.0018758677178993821,
    0.0010604319395497441,
    0.00022522658400703222,
    0.00022192652977537364,
    0.00012656854232773185,
    8.965009328676388e-05,
    8.707543747732416e-05]],
  [[0.09289173781871796,
    0.07521997392177582,
    0.062305331230163574,
    0.0337488017976284,
    0.03044206276535988,
    0.02488088794052601,
    0.021416064351797104,
    0.020704317837953568],
   [0.5778059363365173,
    0.13729195296764374,
    0.12334316223859787,
    0.06