In [1]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

In [2]:
# Load saved data
model = t.load('103001_22072019_checkpoints_gpus_3_epchs/checkpoint_3.pth', map_location='cpu')
dataset = t.load('103001_22072019_checkpoints_gpus_3_epchs/dataset.pth')

In [3]:
def get_proportions(doc_weights):
    """
    Softmax document weights to get proportions
    """
    return F.softmax(doc_weights, dim=1)

def get_doc_vectors(doc_weights, topic_embeds):
    """
    Multiply by proportions by topic embeddings to get document vectors
    """
    proportions = get_proportions(doc_weights)
    doc_vecs = t.matmul(topic_embeds, t.t(proportions))

    return t.t(doc_vecs)

In [4]:
topic_embeds = model["model_state_dict"]["topic_embeds"]
word_embeds = model["model_state_dict"]["word_embeds.weight"]
doc_weights = model["model_state_dict"]["doc_weights.weight"]
vocab = list(dataset['term_freq_dict'].keys())
doc_embeds = get_doc_vectors(doc_weights, topic_embeds)

In [5]:
def get_idx2vec():
    idx2vec = dict()
    with open(r"103001_22072019_checkpoints_gpus_3_epchs/00002/de_epoch_2/tensors.tsv", "r") as f:
        for i, line in enumerate(f):
            vec = line.replace('\n', '').split('\t')
            vec = t.tensor([float(x) for x in vec])
            idx2vec[str(i)] = vec
    return idx2vec

In [6]:
idx2vec = get_idx2vec()
idx2doc = dataset['idx2doc']

In [None]:
for key in idx2doc:
    print('ID:', key)
    print('PKG:', idx2doc[key])
    print('VEC:\n',idx2vec[key])
    print()

In [7]:
def wordvec2idx(word_vec):
    return np.where(word_embeds.numpy() == word_vec.numpy())[0][0]

def vec2word(word_vec):
    idx = wordvec2idx(word_vec)
    return vocab[idx]

In [8]:
def get_n_closest_word_vecs(topic_vec, n=20):
    dist = F.cosine_similarity(word_embeds, topic_vec.unsqueeze(dim=1).transpose(0, 1))
    index_sorted = dist.argsort()
    return index_sorted[:n]

In [9]:

for i, topic in enumerate(topic_embeds.transpose(0, 1)):
    # Get 10 closest word_embeds
    # Print word_embeds words
    top_10 = get_n_closest_word_vecs(topic)
    
    print(f'\nTOPIC: {i}')
    for word_vec in top_10:
        print(vocab[word_vec])


TOPIC: 0
'UQ
1XLNGUTPKI=
warren
Globally
I-75
malleable
5TLNBP
KFK
Similiarly
unprotected
capsules
Christ-killers
contrib/FAQ-Xt
UZ6LO
EI:6EI:6EI:6E
spies
menubar
W1=W=WW7
headline
1EF1

TOPIC: 1
Vervaeke
MainTextSW
RL-10
congress/103rd/HCR3
ZQR\ZAH8Z+SK'AH
1V=G9V=G9V=G9
eurocentric
dot-matrix
wordperfect.com
2PU
Siegel
Fairmount
gearing
1993Apr5.073813.5246
Apr06.184114.73926
120/90
uncompression
1993Apr2.155057.808
U=75
spit

TOPIC: 2
U*C
gammas
0.742
M13R6T=
Concrete
BE4
odder
unaided
MRFTOV
MODULA-2
recollection
1-line
'=J/F8
15.6
FCIY
/P6BF_.BN7NI=
130.9
=U38VONBI^.Q4R1
_Since
furnished

TOPIC: 3
AZO
NAB
impatiently
/G8U
veterans
opting
NRHJ*GIZ
2kHz
BIY
fprintf
=T=W=
misfortune/bad
AAO
C4zAyM.M9u
+4U-34U-34R-002D/Q
flamingest
5HL0
basenotes
gravities
jamesc

TOPIC: 4
cipher
dayhoff.med.Virginia.EDU
01110100
fledging
flagged
34211
4.8
pevasive
Larry_Murphy
Hanrahan
2T555L
P\1466T44\W=
15:50:36
2e30
-G
unheard-of
oak.shu.ac.uk
M/*E7E_/
converter_data
charon.bloomington.in.us

TOPI