# B - A Closer Look at Word Embeddings

Source: https://github.com/bentrevett/pytorch-sentiment-analysis/blob/master/B%20-%20A%20Closer%20Look%20at%20Word%20Embeddings.ipynb

In [1]:
import torchtext.vocab

glove = torchtext.vocab.GloVe(name = '6B', dim = 100)

print(f'There are {len(glove.itos)} words in the vocabulary')

.vector_cache/glove.6B.zip: 862MB [06:29, 2.22MB/s]                           
100%|█████████▉| 398697/400000 [00:18<00:00, 22326.00it/s]

There are 400000 words in the vocabulary


In [2]:
glove.vectors.shape

torch.Size([400000, 100])

In [3]:
glove.itos[:10]

['the', ',', '.', 'of', 'to', 'and', 'in', 'a', '"', "'s"]

In [4]:
glove.stoi['the']

0

In [5]:
glove.vectors[glove.stoi['am']].shape

torch.Size([100])

In [0]:
def get_vectors(embeddings,word):
  assert word in embeddings.stoi, f'{word} not in embedding'
  return embeddings.vectors[embeddings.stoi[word]]

In [10]:
get_vectors(glove,'the')

tensor([-0.0382, -0.2449,  0.7281, -0.3996,  0.0832,  0.0440, -0.3914,  0.3344,
        -0.5755,  0.0875,  0.2879, -0.0673,  0.3091, -0.2638, -0.1323, -0.2076,
         0.3340, -0.3385, -0.3174, -0.4834,  0.1464, -0.3730,  0.3458,  0.0520,
         0.4495, -0.4697,  0.0263, -0.5415, -0.1552, -0.1411, -0.0397,  0.2828,
         0.1439,  0.2346, -0.3102,  0.0862,  0.2040,  0.5262,  0.1716, -0.0824,
        -0.7179, -0.4153,  0.2033, -0.1276,  0.4137,  0.5519,  0.5791, -0.3348,
        -0.3656, -0.5486, -0.0629,  0.2658,  0.3020,  0.9977, -0.8048, -3.0243,
         0.0125, -0.3694,  2.2167,  0.7220, -0.2498,  0.9214,  0.0345,  0.4674,
         1.1079, -0.1936, -0.0746,  0.2335, -0.0521, -0.2204,  0.0572, -0.1581,
        -0.3080, -0.4162,  0.3797,  0.1501, -0.5321, -0.2055, -1.2526,  0.0716,
         0.7056,  0.4974, -0.4206,  0.2615, -1.5380, -0.3022, -0.0734, -0.2831,
         0.3710, -0.2522,  0.0162, -0.0171, -0.3898,  0.8742, -0.7257, -0.5106,
        -0.5203, -0.1459,  0.8278,  0.27

# Closest Words

In [0]:
import torch

def closest_words(embeddings, vector, n=10):
  distances = [(word,torch.dist(vector, embeddings.vectors[embeddings.stoi[word]]).item()) for word in embeddings.itos ]
  distances = sorted(distances, key = lambda w: w[1])[:n]
  return distances

In [0]:
def print_tuples(distances):
  for i in range(len(distances)):
    print(f'{distances[i][0]}  ({distances[i][1]})')

In [24]:
print_tuples(closest_words(glove,get_vectors(glove,'bihar')))

bihar  (0.0)
uttar  (2.9150240421295166)
pradesh  (3.3416378498077393)
orissa  (3.469266176223755)
jharkhand  (3.4747467041015625)
andhra  (3.6897246837615967)
odisha  (3.730837345123291)
punjab  (3.8919103145599365)
chhattisgarh  (3.904574394226074)
haryana  (3.9727556705474854)


In [25]:
print_tuples(closest_words(glove,get_vectors(glove,'patna')))

patna  (0.0)
lucknow  (3.612200975418091)
ranchi  (3.615757942199707)
guwahati  (3.8395986557006836)
kanpur  (4.065101623535156)
bhubaneswar  (4.141567230224609)
bihar  (4.1439690589904785)
raipur  (4.148581027984619)
kolkata  (4.247442245483398)
chandigarh  (4.313122272491455)


In [26]:
print_tuples(closest_words(glove,get_vectors(glove,'begusarai')))

begusarai  (0.0)
khagaria  (1.7411344051361084)
khargone  (1.999775767326355)
balrampur  (2.055911064147949)
samastipur  (2.085280656814575)
sitamarhi  (2.0947346687316895)
saharsa  (2.0988593101501465)
supaul  (2.117133140563965)
hardoi  (2.171771287918091)
barwani  (2.203012466430664)


In [33]:
print_tuples(closest_words(glove,get_vectors(glove,'kolkata')))

kolkata  (0.0)
calcutta  (3.1519007682800293)
chennai  (3.326792001724243)
pune  (3.395953416824341)
bangalore  (3.5001564025878906)
ahmedabad  (3.839594841003418)
hyderabad  (4.00040340423584)
jaipur  (4.1450042724609375)
chandigarh  (4.1639556884765625)
patna  (4.247442245483398)


# Analogies

In [0]:
def analogies(embedding, word1, word2, word3):
  embedding1 = embedding.vectors[embedding.stoi[word1]]
  embedding2 = embedding.vectors[embedding.stoi[word2]]
  embedding3 = embedding.vectors[embedding.stoi[word3]]

  candidate_words = closest_words(embedding, embedding2+embedding3-embedding1, n = 13)
  candidate_words = [item for item in candidate_words if item[0] not in [word1,word2,word3]]

  print_tuples(candidate_words)

In [13]:
analogies(glove,'man','king','women')

queen  (6.213318347930908)
monarch  (6.567999362945557)
crown  (6.573651313781738)
kingdom  (6.731592655181885)
commonwealth  (6.757242202758789)
wives  (6.773597717285156)
vii  (6.794887542724609)
throne  (6.834307670593262)
both  (6.846220970153809)
men  (6.866441249847412)
viii  (6.884921550750732)


In [14]:
analogies(glove,'begusarai','patna','varanasi')

lucknow  (5.748076915740967)
allahabad  (5.9805803298950195)
kanpur  (6.119074821472168)
amritsar  (6.2075653076171875)
hyderabad  (6.331844806671143)
jaipur  (6.411177158355713)
ahmedabad  (6.450562953948975)
lahore  (6.480228424072266)
delhi  (6.541592121124268)
chandigarh  (6.556024074554443)
calcutta  (6.5794243812561035)


In [18]:
analogies(glove,'mouse','rodents','crocodile')

crocodiles  (5.169244766235352)
reptiles  (5.428557395935059)
amphibians  (5.666388034820557)
alligators  (5.788325309753418)
carnivores  (5.947257995605469)
lungfish  (5.973628997802734)
mammals  (5.990443229675293)
snakes  (5.997899532318115)
hippopotamuses  (6.031688690185547)
vermin  (6.059043884277344)
fishes  (6.074267864227295)


In [19]:
analogies(glove,'shiva','hindu','jesus')

religious  (5.854363918304443)
christian  (6.0805583000183105)
christians  (6.3269782066345215)
islamic  (6.456313610076904)
muslim  (6.458624839782715)
catholic  (6.532078266143799)
fundamentalist  (6.5643486976623535)
secular  (6.682084560394287)
religion  (6.693620681762695)
jewish  (6.7063469886779785)
orthodox  (6.805178642272949)


# Correcting Spelling Mistakes

In [20]:
glove = torchtext.vocab.GloVe(name = '840B', dim = 300)


.vector_cache/glove.840B.300d.zip: 0.00B [00:00, ?B/s][A
.vector_cache/glove.840B.300d.zip:   0%|          | 8.19k/2.18G [00:00<42:39:46, 14.2kB/s][A
.vector_cache/glove.840B.300d.zip:   0%|          | 49.2k/2.18G [00:00<30:28:02, 19.8kB/s][A
.vector_cache/glove.840B.300d.zip:   0%|          | 221k/2.18G [00:00<21:28:10, 28.2kB/s] [A
.vector_cache/glove.840B.300d.zip:   0%|          | 909k/2.18G [00:00<15:03:36, 40.1kB/s][A
.vector_cache/glove.840B.300d.zip:   0%|          | 3.67M/2.18G [00:01<10:32:16, 57.3kB/s][A
.vector_cache/glove.840B.300d.zip:   0%|          | 9.60M/2.18G [00:01<7:21:33, 81.8kB/s] [A
.vector_cache/glove.840B.300d.zip:   1%|          | 12.6M/2.18G [00:01<5:09:05, 117kB/s] [A
.vector_cache/glove.840B.300d.zip:   1%|          | 18.1M/2.18G [00:01<3:35:59, 167kB/s][A
.vector_cache/glove.840B.300d.zip:   1%|          | 21.5M/2.18G [00:01<2:31:20, 237kB/s][A
.vector_cache/glove.840B.300d.zip:   1%|          | 26.9M/2.18G [00:01<1:45:51, 338kB/s][A
.vector_c

In [24]:
glove.vectors.shape

torch.Size([2196017, 300])

In [26]:
word_vector = get_vectors(glove, 'relieable')
print(word_vector.shape)

print_tuples(closest_words(glove, word_vector))

torch.Size([300])
relieable  (0.0)
relyable  (5.03660249710083)
realible  (5.26101541519165)
realiable  (5.471879482269287)
relable  (5.540150165557861)
relaible  (5.5916619300842285)
reliabe  (5.641175270080566)
relaiable  (5.880186080932617)
stabel  (5.959342956542969)
consitant  (5.998053073883057)


In [0]:
reliable_vector = get_vectors(glove, 'relieable')
reliable_misspellings = ['relieable', 'relyable', 'realible', 'realiable', 
                         'relable', 'relaible', 'reliabe', 'relaiable']

diff_reliable = [(reliable_vector - get_vectors(glove, s)).unsqueeze(0) 
                 for s in reliable_misspellings]



We take the average of these 8 'difference from reliable' vectors to get our "misspelling vector".

In [30]:
misspelling_vector = torch.cat(diff_reliable,dim=0)
print(misspelling_vector.shape)
misspelling_vector = misspelling_vector.mean(dim=0)
print(misspelling_vector.shape)

torch.Size([8, 300])
torch.Size([300])


In [32]:
word_vector = get_vectors(glove, 'becuase')

print_tuples(closest_words(glove, word_vector + misspelling_vector))

becuase  (3.59653639793396)
becasue  (3.856720209121704)
b/c  (4.49685525894165)
infact  (4.52193546295166)
unfortunatly  (4.61118221282959)
becuse  (4.624024391174316)
beacuse  (4.633757591247559)
shouldnt  (4.640401363372803)
beacause  (4.6744065284729)
becouse  (4.705304145812988)


In [33]:
word_vector = get_vectors(glove, 'defintiely')

print_tuples(closest_words(glove, word_vector + misspelling_vector))

defintiely  (3.59653639793396)
defnitely  (4.337471008300781)
definietly  (4.384809494018555)
defenitely  (4.460355281829834)
defiantely  (4.463236331939697)
definintely  (4.4975504875183105)
definitaly  (4.502630233764648)
deffinitely  (4.531066417694092)
definetley  (4.547483921051025)
definelty  (4.567103862762451)


In [34]:

word_vector = get_vectors(glove, 'consistant')

print_tuples(closest_words(glove, word_vector + misspelling_vector))

consistant  (3.59653639793396)
consistantly  (5.743580341339111)
consitant  (6.119101524353027)
inconsistant  (6.291146755218506)
consitent  (6.334517955780029)
relieable  (6.367373943328857)
consistent  (6.377477169036865)
consistenly  (6.4559855461120605)
usuall  (6.457441806793213)
consisent  (6.514623641967773)
