-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
42 lines (28 loc) · 1.11 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
import torch.nn as nn
import torch.nn.init as init
class Glove(nn.Module):
    """GloVe (Global Vectors) word-embedding model.

    Minimises the weighted least-squares GloVe objective over co-occurrence
    counts X_ij:

        f(X_ij) * (u_i . v_j + b_i + c_j - log X_ij)^2

    with the weighting function f(X) = min(1, (X / x_max)^alpha).
    Two embedding tables are kept (center `emb_u` and context `emb_v`),
    each with its own per-word scalar bias.
    """

    def __init__(self, vocab_size, emb_size, x_max=100, alpha=0.75):
        """
        Args:
            vocab_size: number of words in the vocabulary.
            emb_size: dimensionality of each word vector.
            x_max: co-occurrence cutoff of the weighting function f.
            alpha: exponent of the weighting function f.
        """
        super(Glove, self).__init__()
        self.emb_u = nn.Embedding(vocab_size, emb_size)  # center-word vectors
        self.emb_v = nn.Embedding(vocab_size, emb_size)  # context-word vectors
        self.bias_u = nn.Embedding(vocab_size, 1)        # center-word biases
        self.bias_v = nn.Embedding(vocab_size, 1)        # context-word biases
        # All parameters here are 2-D embedding weights, so Xavier applies.
        for param in self.parameters():
            init.xavier_uniform_(param, gain=.1)
        self.alpha = alpha
        self.x_max = x_max

    def forward(self, i, j, w):
        """Return the mean weighted GloVe loss for a batch of word pairs.

        Args:
            i: LongTensor of center-word indices, shape (batch,).
            j: LongTensor of context-word indices, shape (batch,).
            w: FloatTensor of co-occurrence counts X_ij, shape (batch,);
               must be strictly positive since its log is taken.

        Returns:
            Scalar tensor: the mean weighted squared error over the batch.
        """
        l_vecs = self.emb_u(i)
        r_vecs = self.emb_v(j)
        # BUG FIX: nn.Embedding(vocab, 1) returns biases with shape
        # (batch, 1); without flattening, `sim + l_bias + r_bias`
        # broadcasts to a (batch, batch) matrix and the loss mixes
        # terms across unrelated pairs. Flatten to (batch,).
        l_bias = self.bias_u(i).view(-1)
        r_bias = self.bias_v(j).view(-1)
        log_covals = torch.log(w)
        # f(X) = (X / x_max)^alpha, clipped to 1 for X >= x_max.
        weight = torch.pow(w / self.x_max, self.alpha).clamp(max=1.0)
        sim = (l_vecs * r_vecs).sum(1).view(-1)
        x = (sim + l_bias + r_bias - log_covals) ** 2
        loss = torch.mul(x, weight)
        return loss.mean()

    def get_embedding(self, idx):
        """Return the final word vector for `idx` as a NumPy array.

        Follows the standard GloVe recipe of summing the center and
        context vectors. BUG FIX: the original applied
        `.detach().cpu().numpy()` only to the second operand (operator
        precedence), adding a live autograd tensor to an ndarray; sum
        first, then detach and convert.
        """
        return (self.emb_u.weight[idx] + self.emb_v.weight[idx]).detach().cpu().numpy()