In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle

In [2]:
class SkipGramModel(nn.Module):
    """Skip-gram style model that learns document embeddings via negative sampling.

    A document embedding (``d_embeddings``) is scored against word embeddings
    (``u_embeddings``) for positive words and negatively-sampled words. The loss
    is the negative-sampling objective
    ``-sum(log sigmoid(d . u_pos)) - sum(log sigmoid(-(d . u_neg)))``.

    Args:
        doc_size: number of documents (rows of ``d_embeddings``).
        voc_size: vocabulary size (rows of ``u_embeddings`` and ``v_embeddings``).
        emb_dimension: embedding width shared by all three tables.
    """

    def __init__(self, doc_size, voc_size, emb_dimension):
        super().__init__()
        self.voc_size = voc_size
        self.doc_size = doc_size
        self.emb_dimension = emb_dimension
        # sparse=True yields sparse gradients, so the training optimizer must
        # support them (e.g. optim.SGD or optim.SparseAdam).
        self.d_embeddings = nn.Embedding(doc_size, emb_dimension, sparse=True)
        self.u_embeddings = nn.Embedding(voc_size, emb_dimension, sparse=True)
        # NOTE(review): v_embeddings is allocated and initialised but never read
        # in forward(); both positive and negative words look up u_embeddings.
        # It is kept so state_dict keys and any external references stay valid.
        self.v_embeddings = nn.Embedding(voc_size, emb_dimension, sparse=True)
        self.init_emb()

    def init_emb(self):
        """Fill every embedding table uniformly from [-r, r], r = (0.5 / emb_dimension) * 10."""
        # Parenthesised only to make the precedence explicit; the value is
        # unchanged. It is 10x the classic word2vec range of 0.5 / dim.
        # NOTE(review): confirm 10x (rather than 0.5 / (dim * 10)) is intended.
        initrange = (0.5 / self.emb_dimension) * 10
        # Same order as before (d, u, v) so seeded runs draw identical values.
        for table in (self.d_embeddings, self.u_embeddings, self.v_embeddings):
            table.weight.data.uniform_(-initrange, initrange)

    def forward(self, doc_u, pos_v, neg_v):
        """Return the summed negative-sampling loss as a scalar tensor.

        Args:
            doc_u: index tensor of document id(s).
            pos_v: index tensor of positive (context) word ids.
            neg_v: index tensor of negatively-sampled word ids.

        NOTE(review): scores are ``emb_d @ emb_words.T``, i.e. every document in
        ``doc_u`` is scored against every word in ``pos_v`` / ``neg_v``. That is
        the intended pairing only when ``doc_u`` holds a single document per
        call. For a batch of (doc, word) pairs it would also count cross pairs.
        Confirm against the caller in main.py.
        """
        emb_d = self.d_embeddings(doc_u)
        emb_pos = self.u_embeddings(pos_v)
        emb_neg = self.u_embeddings(neg_v)

        pos_scores = torch.matmul(emb_d, emb_pos.transpose(0, 1))
        pos_loss = -torch.sum(F.logsigmoid(pos_scores))

        neg_scores = torch.matmul(emb_d, emb_neg.transpose(0, 1))
        neg_loss = -torch.sum(F.logsigmoid(-neg_scores))

        return pos_loss + neg_loss

    def save_embedding(self, cuda, dataset):
        """Pickle the document-embedding matrix to ``./<dataset>/weights.txt``.

        Args:
            cuda: kept for backward compatibility only. The weights are always
                copied to CPU, so the flag no longer has to match the device the
                model actually lives on.
            dataset: directory (relative to the current working directory) to
                write into. It must already exist; ``open`` does not create it.

        The file contains a pickled numpy array despite its ``.txt`` extension.
        """
        # .cpu() is a no-op for CPU tensors, so one path covers both devices.
        # Before this change, cuda=False on a GPU model raised inside .numpy().
        doc_emb = self.d_embeddings.weight.detach().cpu().numpy()

        with open('./' + dataset + '/weights.txt', 'wb') as f:
            pickle.dump(doc_emb, f)

In [3]:
# Train by running main.py as a separate process via IPython's `!` shell escape;
# its stdout is captured as this cell's output.
# NOTE(review): `python` is resolved through the shell PATH and may not be the
# interpreter backing this kernel.
!python main.py

1 - 1
2 - 1
3 - 1
4 - 2
5 - 2
6 - 2
7 - 2
8 - 2
9 - 2
10 - 2
11 - 2
12 - 2
13 - 3
14 - 3
15 - 4
16 - 4
17 - 4
18 - 4
19 - 4
20 - 4
21 - 4
22 - 4
23 - 4
24 - 5
25 - 5
26 - 5
27 - 5
28 - 5
29 - 5
30 - 5
31 - 5
32 - 5
33 - 5
34 - 5
35 - 5
36 - 5
37 - 5
38 - 5
39 - 5
40 - 5
41 - 5
42 - 5
43 - 5
44 - 5
45 - 5
46 - 5
47 - 5
48 - 5
49 - 5
50 - 5
51 - 5
52 - 5
53 - 5
54 - 5
55 - 5
56 - 5
57 - 5
58 - 5
59 - 5
60 - 5
61 - 5
62 - 5
63 - 5
64 - 5
65 - 5
66 - 5
67 - 5
68 - 5
69 - 6
70 - 6
71 - 6
72 - 6
73 - 6
74 - 6
75 - 6
76 - 6
77 - 6
78 - 6
79 - 6
80 - 6
81 - 6
82 - 6
83 - 6
84 - 6
85 - 6
86 - 6
87 - 6
88 - 6
89 - 6
90 - 6
91 - 6
92 - 6
93 - 6
94 - 6
95 - 6
96 - 6
97 - 6
98 - 6
99 - 6
100 - 6
101 - 6
102 - 7
103 - 7
104 - 7
105 - 7
106 - 7
107 - 7
108 - 7
109 - 7
110 - 7
111 - 7
112 - 7
113 - 7
114 - 7
115 - 7
116 - 7
117 - 7
118 - 7
119 - 7
120 - 7
121 - 7
122 - 7
123 - 7
124 - 7
125 - 7
126 - 7
127 - 7
128 - 7
129 - 7
130 - 7
131 - 7
132 - 7
133 - 7
134 - 7
135 - 7
136 - 7
137 - 7
138 - 7
139 

In [5]:
!pip freeze > requirements.txt