In [4]:
import sys
sys.path.append('../')

In [1]:
import torch

In [18]:
from transformers.modeling_bert import BertConfig, BertEmbeddings, BertModel
from transformers.tokenization_bert import BertTokenizer

from mlpack.bert import BertForNERClassification, BertForSpanNERClassification, BertNERSpanHandler
from mlpack.utils import to_device

In [103]:
pretrained_model = '../models/bert-pretrained-portuguese-model/'
# pretrained_model = 'bert-base-cased'

In [104]:
tokenizer = BertTokenizer.from_pretrained(pretrained_model)

In [105]:
config = BertConfig.from_pretrained(pretrained_model, do_lower_case=False)

In [106]:
model = BertForNERClassification.from_pretrained(pretrained_model, output_hidden_states=True, num_labels=7, pooler='last')

In [107]:
model

BertForNERClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(29794, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_

In [8]:
embeddings = model.get_input_embeddings()

In [32]:
casa_id = tokenizer.convert_tokens_to_ids('Home')
casa_embed = embeddings(torch.tensor(casa_id))

In [33]:
apartamento_id = tokenizer.convert_tokens_to_ids('House')
apartamento_embed = embeddings(torch.tensor(apartamento_id))

In [34]:
casa_apart_dist = torch.dist(casa_embed, apartamento_embed).item()
casa_apart_dist

0.7628178000450134

In [35]:
closer = []
for id in range(28998):
    tensor_id = torch.tensor(id)
    embed = embeddings(tensor_id)
    dist_casa = torch.dist(apartamento_embed, embed).item()
    if dist_casa < casa_apart_dist:
        token = tokenizer.convert_ids_to_tokens(id)
        print(f'Token {token} - dist {dist_casa}')
        closer.append((token, dist_casa))

Token [unused6] - dist 0.7522004842758179
Token [unused7] - dist 0.750287652015686
Token [unused9] - dist 0.7610279321670532
Token [unused15] - dist 0.7504995465278625
Token [unused19] - dist 0.7601041197776794
Token [unused26] - dist 0.746816098690033
Token [unused28] - dist 0.7613805532455444
Token [unused29] - dist 0.7509776949882507
Token [unused36] - dist 0.749943196773529
Token [unused38] - dist 0.7584313154220581
Token [unused42] - dist 0.7573304772377014
Token [unused59] - dist 0.7622300386428833
Token [unused70] - dist 0.7529243230819702
Token [unused75] - dist 0.7440377473831177
Token [unused76] - dist 0.7616509199142456
Token [unused78] - dist 0.7521815896034241
Token [unused81] - dist 0.760960578918457
Token [unused82] - dist 0.7566332817077637
Token [unused84] - dist 0.7580102682113647
Token [SEP] - dist 0.7325114011764526
Token [MASKC] - dist 0.75566565990448
Token " - dist 0.75141841173172
Token $ - dist 0.7575690150260925
Token & - dist 0.7463468909263611
Token , - dist

Token Of - dist 0.7562370300292969
Token hear - dist 0.7578892111778259
Token ##P - dist 0.7415839433670044
Token Western - dist 0.7565302848815918
Token morning - dist 0.758612871170044
Token Minister - dist 0.757713794708252
Token 1988 - dist 0.7599549293518066
Token ##tion - dist 0.758813202381134
Token Black - dist 0.7549208998680115
Token Canadian - dist 0.7627585530281067
Token Spanish - dist 0.7617514133453369
Token replaced - dist 0.7617276906967163
Token Christian - dist 0.7600794434547424
Token whose - dist 0.7530258297920227
Token doesn - dist 0.7607004046440125
Token ##ur - dist 0.7472218871116638
Token beginning - dist 0.7620699405670166
Token Records - dist 0.7454589605331421
Token ##L - dist 0.7442261576652527
Token aircraft - dist 0.7594611048698425
Token ##ble - dist 0.7494584918022156
Token Act - dist 0.7354598641395569
Token court - dist 0.745551586151123
Token ##ge - dist 0.7515998482704163
Token countries - dist 0.7472957372665405
Token river - dist 0.7438541650772

Token handle - dist 0.7514903545379639
Token violence - dist 0.7623521089553833
Token wrapped - dist 0.7578823566436768
Token tiny - dist 0.7503054141998291
Token expansion - dist 0.737644612789154
Token Nothing - dist 0.7536083459854126
Token Main - dist 0.7485902905464172
Token inch - dist 0.757022500038147
Token Saturday - dist 0.7264042496681213
Token Senior - dist 0.7478981614112854
Token representing - dist 0.755411684513092
Token Kansas - dist 0.7539322376251221
Token ##chi - dist 0.7602689862251282
Token Dan - dist 0.761563241481781
Token streets - dist 0.7600846886634827
Token pp - dist 0.7595121264457703
Token 54 - dist 0.740977942943573
Token Philippines - dist 0.7523927092552185
Token 1900 - dist 0.7604998350143433
Token Khan - dist 0.7482261061668396
Token signal - dist 0.7501083016395569
Token sword - dist 0.7620552182197571
Token cool - dist 0.7627732157707214
Token engaged - dist 0.7448939681053162
Token rail - dist 0.758895754814148
Token successfully - dist 0.74074476

Token agriculture - dist 0.7600550651550293
Token viewed - dist 0.7568116784095764
Token Wolf - dist 0.7552013993263245
Token associate - dist 0.7499995827674866
Token Chile - dist 0.7400230169296265
Token detail - dist 0.7600547671318054
Token attorney - dist 0.7587430477142334
Token boyfriend - dist 0.7413733005523682
Token Spencer - dist 0.7621419429779053
Token ##ig - dist 0.7502814531326294
Token demolished - dist 0.7565109729766846
Token Lisa - dist 0.7558459639549255
Token Coach - dist 0.7589486241340637
Token forcing - dist 0.7565313577651978
Token Dream - dist 0.7428439855575562
Token cargo - dist 0.74570631980896
Token Murphy - dist 0.740909218788147
Token ##master - dist 0.7274851202964783
Token doorway - dist 0.749448835849762
Token traded - dist 0.7540512681007385
Token Sure - dist 0.7616360187530518
Token placing - dist 0.7619470357894897
Token attending - dist 0.7437468767166138
Token principles - dist 0.7456985712051392
Token component - dist 0.7496834993362427
Token ma

Token Want - dist 0.7600114941596985
Token ##rine - dist 0.7456742525100708
Token Say - dist 0.7577276825904846
Token besides - dist 0.752362847328186
Token centered - dist 0.758194088935852
Token Engineers - dist 0.7583601474761963
Token tourism - dist 0.7393489480018616
Token announcement - dist 0.7587196826934814
Token recover - dist 0.7499109506607056
Token mainland - dist 0.749853253364563
Token Johann - dist 0.7387924194335938
Token airline - dist 0.7537830471992493
Token Congo - dist 0.7627495527267456
Token Syrian - dist 0.7495896816253662
Token Casey - dist 0.7571022510528564
Token progressive - dist 0.7285419702529907
Token coup - dist 0.7415719032287598
Token cultures - dist 0.7557610869407654
Token ##lessly - dist 0.7567746043205261
Token Sofia - dist 0.7465019226074219
Token trunk - dist 0.7534539103507996
Token Liz - dist 0.7453515529632568
Token Woods - dist 0.7369931936264038
Token Angels - dist 0.7362606525421143
Token broadcasts - dist 0.7436011433601379
Token backwar

Token Joy - dist 0.7596299648284912
Token payments - dist 0.7595782279968262
Token ##dam - dist 0.7469829320907593
Token Hood - dist 0.7459940314292908
Token Mad - dist 0.7341384291648865
Token Ari - dist 0.7576184868812561
Token ##uts - dist 0.758640468120575
Token molecules - dist 0.7509202361106873
Token strictly - dist 0.7576670050621033
Token ##rg - dist 0.7487295269966125
Token ##mba - dist 0.7419440746307373
Token grabbing - dist 0.759154200553894
Token receptor - dist 0.7441903352737427
Token blown - dist 0.7615011930465698
Token Listen - dist 0.7282482385635376
Token resolve - dist 0.7576540112495422
Token Shannon - dist 0.731082558631897
Token overwhelming - dist 0.7514384984970093
Token ##rice - dist 0.7483528852462769
Token 170 - dist 0.7544373273849487
Token fury - dist 0.7459710836410522
Token nerves - dist 0.7585766911506653
Token rectangular - dist 0.7329514026641846
Token sworn - dist 0.7239536643028259
Token Animal - dist 0.7262796759605408
Token Southwest - dist 0.75

Token loses - dist 0.7616186738014221
Token ##rade - dist 0.755206823348999
Token ##tead - dist 0.7587268352508545
Token cruiser - dist 0.7519554495811462
Token ##psy - dist 0.7557570338249207
Token Garion - dist 0.7615184187889099
Token 1817 - dist 0.7555610537528992
Token publishes - dist 0.7425903677940369
Token Gates - dist 0.7459990978240967
Token Worcester - dist 0.7466804385185242
Token Belarus - dist 0.7411664128303528
Token radius - dist 0.7535321712493896
Token Marathon - dist 0.7305841445922852
Token 1819 - dist 0.7454259991645813
Token affection - dist 0.743445098400116
Token ##bro - dist 0.7369738221168518
Token recreation - dist 0.7465788125991821
Token jewelry - dist 0.7483558654785156
Token ##logue - dist 0.7416423559188843
Token ##rop - dist 0.7601491212844849
Token 117 - dist 0.7617184519767761
Token ##II - dist 0.757817268371582
Token slower - dist 0.7418792843818665
Token Volunteer - dist 0.7617170810699463
Token celebrations - dist 0.7560518980026245
Token regards 

Token Letter - dist 0.7624093294143677
Token fare - dist 0.7616270184516907
Token traders - dist 0.7544752359390259
Token ##mere - dist 0.7542606592178345
Token Fortune - dist 0.7593047618865967
Token cooperative - dist 0.7552448511123657
Token ##him - dist 0.7540717720985413
Token sucking - dist 0.7595446109771729
Token ##NP - dist 0.7587978839874268
Token Elijah - dist 0.7571523189544678
Token armour - dist 0.7466819286346436
Token 132 - dist 0.7543515563011169
Token Author - dist 0.7340481877326965
Token Soldier - dist 0.7615483403205872
Token il - dist 0.7509771585464478
Token ##bat - dist 0.7402808666229248
Token ##lord - dist 0.7394492030143738
Token ##runner - dist 0.7479046583175659
Token ##tles - dist 0.7451634407043457
Token ##body - dist 0.7592694759368896
Token Livingston - dist 0.7481658458709717
Token trends - dist 0.7451572418212891
Token Melody - dist 0.7573683261871338
Token eighty - dist 0.7487573027610779
Token mock - dist 0.7374382615089417
Token stretches - dist 0.

Token Bride - dist 0.752274751663208
Token ##ience - dist 0.746091365814209
Token Packers - dist 0.7502084970474243
Token ##erry - dist 0.7625519633293152
Token ##stra - dist 0.7583048939704895
Token dispersed - dist 0.7576380372047424
Token expose - dist 0.7383139729499817
Token ##hibition - dist 0.754204511642456
Token orbital - dist 0.7508030533790588
Token ##ević - dist 0.7594588398933411
Token backgrounds - dist 0.7446560859680176
Token replica - dist 0.7469742298126221
Token obstacles - dist 0.7614869475364685
Token vowels - dist 0.7527710795402527
Token surpassed - dist 0.7624675631523132
Token Zambia - dist 0.7529264688491821
Token Alpine - dist 0.7519590258598328
Token Animals - dist 0.7502521872520447
Token digit - dist 0.759819507598877
Token presses - dist 0.7591578960418701
Token posterior - dist 0.7606363892555237
Token ##bas - dist 0.7580209374427795
Token representations - dist 0.747143566608429
Token Paralympics - dist 0.7533887624740601
Token cartridge - dist 0.761356

Token searches - dist 0.7509945631027222
Token VHS - dist 0.7561338543891907
Token Jacqueline - dist 0.752798318862915
Token Redskins - dist 0.7556543946266174
Token desktop - dist 0.7611891031265259
Token ##smos - dist 0.7562421560287476
Token Irwin - dist 0.7515566349029541
Token Process - dist 0.7598098516464233
Token ##bson - dist 0.7571933269500732
Token yielded - dist 0.7614344358444214
Token Flores - dist 0.7467120885848999
Token armament - dist 0.7433210611343384
Token neighbours - dist 0.7482308745384216
Token ##tour - dist 0.7607412338256836
Token ##jet - dist 0.7464047074317932
Token 1774 - dist 0.7541900873184204
Token monasteries - dist 0.736182451248169
Token ##inted - dist 0.7609127163887024
Token ##inae - dist 0.7625336050987244
Token ##lifting - dist 0.747992992401123
Token Sands - dist 0.7545312643051147
Token mast - dist 0.7544474005699158
Token fills - dist 0.7564542293548584
Token ##roma - dist 0.7573080658912659
Token ##dius - dist 0.7562419176101685
Token ##racti

Token launches - dist 0.7574054598808289
Token differed - dist 0.740811824798584
Token ##child - dist 0.7553233504295349
Token confirming - dist 0.743786096572876
Token scaled - dist 0.749045729637146
Token ##resh - dist 0.7524169087409973
Token Understanding - dist 0.7590721845626831
Token asserts - dist 0.7499985098838806
Token tails - dist 0.7571061849594116
Token comparisons - dist 0.7619969248771667
Token 186 - dist 0.759238064289093
Token Libby - dist 0.7548491954803467
Token ledge - dist 0.7459426522254944
Token Missionary - dist 0.7574305534362793
Token Varsity - dist 0.7516738176345825
Token ##nial - dist 0.7481257319450378
Token ##dson - dist 0.7599060535430908
Token stellar - dist 0.7535808682441711
Token superiority - dist 0.7613144516944885
Token Lila - dist 0.7506697177886963
Token Petit - dist 0.7539182901382446
Token distorted - dist 0.7513890266418457
Token ISIL - dist 0.7537394165992737
Token distinctly - dist 0.7495111227035522
Token ##cats - dist 0.7440828084945679


Token Desire - dist 0.7517167925834656
Token Routledge - dist 0.7538386583328247
Token BWF - dist 0.7619349956512451
Token ##stituting - dist 0.7449301481246948
Token Pig - dist 0.7613992094993591
Token Telecommunications - dist 0.7535715699195862
Token Larson - dist 0.747948169708252
Token Wonderland - dist 0.7470545768737793
Token ##nife - dist 0.7571470141410828
Token predominant - dist 0.741470456123352
Token tripped - dist 0.7530940771102905
Token ##ruly - dist 0.7550777196884155
Token hiss - dist 0.7458145022392273
Token ##crat - dist 0.7555007338523865
Token Côte - dist 0.7487894296646118
Token listener - dist 0.753464937210083
Token Santana - dist 0.748290479183197
Token IAAF - dist 0.74591463804245
Token Seneca - dist 0.7585077881813049
Token ##upon - dist 0.7428340315818787
Token ##hop - dist 0.7513964176177979
Token ##sad - dist 0.7430835366249084
Token Them - dist 0.7443279027938843
Token Petra - dist 0.7530657649040222
Token framing - dist 0.7541533708572388
Token ##mata -

Token 315 - dist 0.7617630362510681
Token ##erham - dist 0.737848162651062
Token Demons - dist 0.7612036466598511
Token surfing - dist 0.758048951625824
Token Rana - dist 0.7603890895843506
Token insisting - dist 0.7460370659828186
Token ##jong - dist 0.7331537008285522
Token residues - dist 0.7548360824584961
Token Consider - dist 0.7575878500938416
Token ##elling - dist 0.7433325052261353
Token ##hlete - dist 0.7460285425186157
Token ##biotics - dist 0.7427068948745728
Token ##gat - dist 0.7560054659843445
Token chanting - dist 0.7407026886940002
Token affluent - dist 0.7589202523231506
Token fullback - dist 0.7566543817520142
Token geologic - dist 0.7580081820487976
Token interiors - dist 0.7620987892150879
Token Inverness - dist 0.7273285388946533
Token ##gno - dist 0.7617389559745789
Token 281 - dist 0.7472253441810608
Token leveled - dist 0.7479708790779114
Token migrate - dist 0.7449647784233093
Token Frog - dist 0.7516729235649109
Token Rare - dist 0.7408618330955505
Token ##io

Token clearer - dist 0.7379012703895569
Token ##narl - dist 0.7523199319839478
Token Print - dist 0.7417882680892944
Token automation - dist 0.7626647353172302
Token flashback - dist 0.7555941939353943
Token Casimir - dist 0.748885452747345
Token differentiated - dist 0.7552894353866577
Token ##aks - dist 0.7534441351890564
Token Evaluation - dist 0.7383720278739929
Token Kobe - dist 0.7544509172439575
Token Wet - dist 0.7623026967048645
Token bandage - dist 0.7564948797225952
Token ##pipe - dist 0.7400298118591309
Token McIntyre - dist 0.7420646548271179
Token flopped - dist 0.7619701623916626
Token reassure - dist 0.7502260208129883
Token sincerity - dist 0.7598525285720825
Token ##teacher - dist 0.7442008852958679
Token ##idated - dist 0.755143404006958
Token ##nast - dist 0.7571526765823364
Token 1690 - dist 0.7463250756263733
Token Bai - dist 0.739280104637146
Token expands - dist 0.7564147710800171
Token Britannia - dist 0.7513495683670044
Token Steps - dist 0.7478920221328735
To

In [36]:
len(closer)

8136

In [37]:
sorted(closer, key=lambda x:x[1])[1:11]

[('Psychological', 0.7049500346183777),
 ('##マ', 0.7090457677841187),
 ('considered', 0.7103508114814758),
 ('##ics', 0.7105996012687683),
 ('##fuse', 0.7110041379928589),
 ('Э', 0.7111547589302063),
 ('commented', 0.7129557132720947),
 ('steals', 0.7133222222328186),
 ('THE', 0.7150970697402954),
 ('embroidered', 0.7157056927680969)]

# Checking compositions

In [50]:
word1 = torch.tensor(tokenizer.convert_tokens_to_ids('king'))
word2 = torch.tensor(tokenizer.convert_tokens_to_ids('man'))
word3 = torch.tensor(tokenizer.convert_tokens_to_ids('woman'))
word4 = torch.tensor(tokenizer.convert_tokens_to_ids('queen'))

In [51]:
emb1 = embeddings(word1)
emb2 = embeddings(word2)
emb3 = embeddings(word3)
emb4 = embeddings(word4)

In [53]:
emb_final = emb1 - emb2 + emb3

In [58]:
torch.dist(emb2, emb4)

tensor(0.7867, grad_fn=<DistBackward>)

In [42]:
closer = []
for id in range(28998):
    tensor_id = torch.tensor(id)
    token = tokenizer.convert_ids_to_tokens(id)
    embed = embeddings(tensor_id)
    dist = torch.dist(emb_final, embed).item()
    closer.append((token, dist))

In [44]:
sorted(closer, key=lambda x:x[1])[:20]

[('king', 0.7744460105895996),
 ('woman', 0.7895422577857971),
 ('Falls', 1.0190023183822632),
 ('ல', 1.0250002145767212),
 ('Sergeant', 1.0281301736831665),
 ('plates', 1.028336524963379),
 ('remind', 1.0291839838027954),
 ('Illinois', 1.0295065641403198),
 ('##ulatory', 1.0297354459762573),
 ('seasons', 1.0298035144805908),
 ('##pear', 1.0299683809280396),
 ('##ke', 1.0304064750671387),
 ('raiding', 1.0307337045669556),
 ('gasp', 1.0307921171188354),
 ('okay', 1.0311487913131714),
 ('##iq', 1.0315492153167725),
 ('similarly', 1.0319414138793945),
 ('nominees', 1.0321558713912964),
 ('accurately', 1.0326005220413208),
 ('ф', 1.032941460609436)]

# Handler

In [108]:
LABELS = [
    'O',
    'B-PER',
    'I-PER',
    'B-ORG',
    'I-ORG',
    'B-LOC',
    'I-LOC',
]

In [125]:
handler = BertNERSpanHandler(model, tokenizer, LABELS)

In [126]:
features, word_offset = handler.create_feature_from_text('Minha tia mora numa casa. Meu tio mora num apartamento.')

In [127]:
feat = features[0]

In [128]:
input_ids, input_mask, label_mask = handler._prepare_tensors(feat)
handler.model.eval()
with torch.no_grad():
    out = handler.model.predict_logits(input_ids, input_mask)

In [129]:
out.shape

torch.Size([1, 512, 7])

In [130]:
tokens = tokenizer.convert_ids_to_tokens(feat.input_ids)
tokens[5], tokens[11]

('casa', 'apartamento')

In [131]:
emb1 = out[0][5].to('cpu')
emb2 = out[0][11].to('cpu')

In [132]:
torch.dist(emb1, emb2)

tensor(0.3247)

In [133]:
tokens[5], tokens[2]

('casa', 'tia')

In [134]:
emb2 = out[0][2].to('cpu')
torch.dist(emb1, emb2)

tensor(0.7595)

In [135]:
tokens[5], tokens[8]

('casa', 'tio')

In [136]:
emb2 = out[0][8].to('cpu')
torch.dist(emb1, emb2)

tensor(0.7530)

In [139]:
emb1 = out[0][5].to('cpu')
for emb2, tok in zip(out[0], tokens):
    print(torch.dist(emb1, emb2.to('cpu')), tok)

tensor(0.4138) [CLS]
tensor(0.6642) minha
tensor(0.7595) tia
tensor(0.4471) mora
tensor(0.4626) numa
tensor(0.) casa
tensor(0.6063) .
tensor(0.7590) meu
tensor(0.7530) tio
tensor(0.4535) mora
tensor(0.5799) num
tensor(0.3247) apartamento
tensor(0.4638) .
tensor(0.4137) [SEP]
tensor(0.3949) [PAD]
tensor(0.4695) [PAD]
tensor(0.6260) [PAD]
tensor(0.5698) [PAD]
tensor(0.4655) [PAD]
tensor(0.4496) [PAD]
tensor(0.6142) [PAD]
tensor(0.5209) [PAD]
tensor(0.4925) [PAD]
tensor(0.5621) [PAD]
tensor(0.6967) [PAD]
tensor(0.5975) [PAD]
tensor(0.4504) [PAD]
tensor(0.3565) [PAD]
tensor(0.4761) [PAD]
tensor(0.5101) [PAD]
tensor(0.4719) [PAD]
tensor(0.4511) [PAD]
tensor(0.5203) [PAD]
tensor(0.5406) [PAD]
tensor(0.6249) [PAD]
tensor(0.5786) [PAD]
tensor(0.6509) [PAD]
tensor(0.5017) [PAD]
tensor(0.5847) [PAD]
tensor(0.5308) [PAD]
tensor(0.6678) [PAD]
tensor(0.4808) [PAD]
tensor(0.4740) [PAD]
tensor(0.4440) [PAD]
tensor(0.5073) [PAD]
tensor(0.4923) [PAD]
tensor(0.4817) [PAD]
tensor(0.4902) [PAD]
tensor(0.4

# Handler

In [64]:
LABELS = [
    'O',
    'B-PER',
    'I-PER',
    'B-ORG',
    'I-ORG',
    'B-LOC',
    'I-LOC',
]

In [65]:
handler = BertNERSpanHandler(model, tokenizer, LABELS)

In [86]:
features, word_offset = handler.create_feature_from_text('My aunt lives in a house. My uncle lives in an mansion.')

In [87]:
feat = features[0]

In [88]:
input_ids, input_mask, label_mask = handler._prepare_tensors(feat)
handler.model.eval()
with torch.no_grad():
    out = handler.model.predict_logits(input_ids, input_mask)

In [89]:
out.shape

torch.Size([1, 512, 7])

In [99]:
tokens = tokenizer.convert_ids_to_tokens(feat.input_ids)
tokens[6], tokens[13]
tokens

['[CLS]',
 'My',
 'aunt',
 'lives',
 'in',
 'a',
 'house',
 '.',
 'My',
 'uncle',
 'lives',
 'in',
 'an',
 'mansion',
 '.',
 '[SEP]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD]',
 '[PAD

In [92]:
emb1 = out[0][6].to('cpu')
emb2 = out[0][13].to('cpu')

In [93]:
torch.dist(emb1, emb2)

tensor(0.4014)

In [94]:
tokens[6], tokens[2]

('house', 'aunt')

In [95]:
emb2 = out[0][2].to('cpu')
torch.dist(emb1, emb2)

tensor(0.7505)

In [97]:
tokens[6], tokens[9]

('house', 'uncle')

In [98]:
emb2 = out[0][9].to('cpu')
torch.dist(emb1, emb2)

tensor(0.8965)

In [102]:
emb1 = out[0][2].to('cpu')
for emb2 in out[0]:
    print(torch.dist(emb1, emb2.to('cpu')))

tensor(0.7361)
tensor(1.0668)
tensor(0.)
tensor(0.8373)
tensor(0.8490)
tensor(0.4706)
tensor(0.7505)
tensor(0.7663)
tensor(0.6523)
tensor(0.2612)
tensor(0.7719)
tensor(0.8929)
tensor(0.6274)
tensor(0.6020)
tensor(0.7211)
tensor(1.4901)
tensor(0.6175)
tensor(0.5478)
tensor(0.6123)
tensor(0.4432)
tensor(0.4535)
tensor(0.3888)
tensor(0.3829)
tensor(0.5379)
tensor(0.6086)
tensor(0.4358)
tensor(0.2399)
tensor(0.3324)
tensor(0.2748)
tensor(0.3250)
tensor(0.4043)
tensor(0.4430)
tensor(0.6631)
tensor(0.6638)
tensor(0.5308)
tensor(0.2377)
tensor(0.3210)
tensor(0.3785)
tensor(0.5499)
tensor(0.5243)
tensor(0.6112)
tensor(0.5431)
tensor(0.2861)
tensor(0.3131)
tensor(0.3358)
tensor(0.3439)
tensor(0.4903)
tensor(0.4904)
tensor(0.6138)
tensor(0.6640)
tensor(0.4385)
tensor(0.3887)
tensor(0.3115)
tensor(0.3792)
tensor(0.4621)
tensor(0.5415)
tensor(0.5966)
tensor(0.5205)
tensor(0.2807)
tensor(0.3480)
tensor(0.3988)
tensor(0.3287)
tensor(0.6528)
tensor(0.6031)
tensor(0.6115)
tensor(0.6551)
tensor(0.4494)