In [0]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch as tt
import torch.nn as nn
import torch.nn.functional as F

from tqdm.notebook import tqdm

from google.colab import files

In [0]:
files.upload()

Saving data.zip to data.zip


In [0]:
!unzip data.zip

Archive:  data.zip
  inflating: data.txt                
  inflating: stopwords.txt           


In [0]:
def nltk2wn_tag(nltk_tag):
    if nltk_tag.startswith('J'):
        return wordnet.ADJ
    elif nltk_tag.startswith('V'):
        return wordnet.VERB
    elif nltk_tag.startswith('N'):
        return wordnet.NOUN
    elif nltk_tag.startswith('R'):
        return wordnet.ADV
    else:
        return None

In [0]:
import nltk
nltk.download('wordnet')
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

from nltk.stem import WordNetLemmatizer
from nltk.corpus import wordnet

lemmatizer = WordNetLemmatizer()

with open('stopwords.txt', 'r', encoding='utf-8') as f:
    stop_words = f.read().split('\n')

with open('data.txt', 'r', encoding='utf-8') as f:
    sents = f.readlines()
    sents = [[word.lower() for word in nltk.word_tokenize(sent) if word.isalpha() and word.lower() not in stop_words] for sent in tqdm(sents)]
    tagged_sents = [nltk.pos_tag(sent) for sent in tqdm(sents)]
    wn_tagged_sents = [[(word[0], nltk2wn_tag(word[1])) for word in sent] for sent in tqdm(tagged_sents)]
    sents = [[word[0] if word[1] is None else lemmatizer.lemmatize(word[0], word[1]) for word in sent] for sent in tqdm(wn_tagged_sents)]

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


HBox(children=(IntProgress(value=0, max=4551), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4551), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4551), HTML(value='')))




HBox(children=(IntProgress(value=0, max=4551), HTML(value='')))




In [0]:
class Attention(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.linear = nn.Linear(embed_size, embed_size)

    def forward(self, input_):
        mean_input = tt.mean(input_, dim=-1)
        lin_out = self.linear(mean_input).unsqueeze(2)
        bmm_out = tt.bmm(input_.transpose(1, 2), lin_out).tanh()
        res = bmm_out / sum(tt.exp(bmm_out))
        return res.squeeze(1)

In [0]:
from gensim.models import Word2Vec
EMBED_SIZE = 300
BATCH_SIZE = 32

In [0]:
w2v_model = Word2Vec(sents, size=EMBED_SIZE, min_count=1, workers=-1)
w2v_model.save("word2vec.model")

  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL


In [0]:
class AspectExtraction(nn.Module):
    def __init__(self, n_vocab=1000, d_embed=EMBED_SIZE, n_aspects=12):
        super().__init__()
        self.E = nn.Embedding(n_vocab, d_embed)
        self.T = nn.Embedding(n_aspects, d_embed)
        self.attention = Attention(d_embed)
        self.linear = nn.Linear(d_embed, n_aspects)

    def forward(self, pos, negs):
        p_t, z_s = self.predict(pos)
        r_s = F.normalize(tt.mm(self.T.weight.t(), p_t.t()).t(), dim=-1)
        e_n = self.E(negs).transpose(-2, -1)
        z_n = F.normalize(tt.mean(e_n, dim=-1), dim=-1)
        return r_s, z_s, z_n

    def predict(self, x):
        e_i = self.E(x).transpose(1, 2)
        a_i = self.attention(e_i)
        z_s = F.normalize(tt.bmm(e_i, a_i).squeeze(2), dim=-1)
        p_t = F.softmax(self.linear(z_s), dim=1)
        return p_t, z_s

    def aspects(self):
        E_n = F.normalize(self.E.weight, dim=1)
        T_n = F.normalize(self.T.weight, dim=1)
        projection = tt.mm(E_n, T_n.t()).t()
        return projection

In [0]:
def max_margin_loss(r_s, z_s, z_n):
    device = r_s.device
    pos = tt.bmm(z_s.unsqueeze(1), r_s.unsqueeze(2)).squeeze(2)
    negs = tt.bmm(z_n, r_s.unsqueeze(2)).squeeze()
    J = tt.ones(negs.shape).to(device) - pos.expand(negs.shape) + negs
    return tt.sum(tt.clamp(J, min=0.0))

In [0]:
def orthogonal_regularization(T):
    T_n = F.normalize(T, dim=1)
    I = tt.eye(T_n.shape[0]).to(T_n.device)
    return tt.norm(T_n.mm(T_n.t()) - I)

In [0]:
def _test_epoch(model, iterator, n_batches=100):
    model.eval()
    epoch_loss = 0

    with tt.no_grad():
        for i in tqdm(range(n_batches)):
            pos, neg = next(iterator)
            r_s, z_s, z_n = model(pos, neg)
            J = max_margin_loss(r_s, z_s, z_n).item()
            U = orthogonal_regularization(model.T.weight).item()
            loss = J + 0.1 * BATCH_SIZE * U
            epoch_loss += loss
    return epoch_loss / n_batches

In [0]:
def _train_epoch(model, iterator, optimizer, curr_epoch, n_batches=1000):
    model.train()
    running_loss = 0

    tqdm_range = tqdm(range(n_batches))
    for i in tqdm_range:
        pos, neg = next(iterator)
        optimizer.zero_grad()
        r_s, z_s, z_n = model(pos, neg)
        J = max_margin_loss(r_s, z_s, z_n)
        U = orthogonal_regularization(model.T.weight)
        loss = J + 0.1 * BATCH_SIZE * U
        loss.backward()
        optimizer.step()

        curr_loss = loss.data.cpu().detach().item()
        loss_smoothing = i / (i+1)
        running_loss = loss_smoothing * running_loss + (1 - loss_smoothing) * curr_loss
        tqdm_range.set_postfix(loss='%.5f' % running_loss)

    return running_loss

In [0]:
def nn_train(model, train_iterator, valid_iterator, optimizer, n_epochs=10, early_stopping=2):
    prev_loss = 100500
    es_epochs = 0
    best_epoch = None
    history = pd.DataFrame()

    for epoch in range(n_epochs):
        train_loss = _train_epoch(model, train_iterator, optimizer, epoch)
        valid_loss = _test_epoch(model, valid_iterator)
        print('validation loss %.5f' % valid_loss)

        record = {'epoch': epoch, 'train_loss': train_loss, 'valid_loss': valid_loss}
        history = history.append(record, ignore_index=True)

        if early_stopping > 0:
            if valid_loss > prev_loss:
                es_epochs += 1
            else:
                es_epochs = 0
            if es_epochs >= early_stopping:
                best_epoch = history[history.valid_loss == history.valid_loss.min()].iloc[0]
                print('Early stopping! best epoch: %d val %.5f' % (best_epoch['epoch'], best_epoch['valid_loss']))
                break
            prev_loss = min(prev_loss, valid_loss)

In [0]:
DEVICE = 'cuda'

In [0]:
tt.cuda.empty_cache()
model = AspectExtraction(n_vocab=len(w2v_model.wv.vocab)).to(DEVICE)
optimizer = tt.optim.Adam(model.parameters())

In [0]:
from random import sample

N = 50

def train_iterator(batch_size=BATCH_SIZE):
    while True:
        pos = [[w2v_model.wv.vocab[word].index for word in sent][:N] for sent in sample(sents, batch_size)]
        pos = [sent if len(sent) == N else sent + [-1] * (N - len(sent)) for sent in pos]
        negs = []
        for _ in range(batch_size):
            neg = [[[w2v_model.wv.vocab[word].index for word in sent][:N] for sent in sample(sents, batch_size)]]
            neg = [sent if len(sent) == N else sent + [-1] * (N - len(sent)) for sent in pos]
            negs.append(neg)
        yield (tt.tensor(pos).to(DEVICE), tt.tensor(negs).to(DEVICE))

In [0]:
train_iterator = train_iterator()
valid_iterator = train_iterator()

In [0]:
for instance in list(tqdm._instances): 
    tqdm._decr_instances(instance)

In [0]:
nn_train(model, train_iterator, valid_iterator, optimizer, n_epochs=5)

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0), HTML(value='')))


validation loss 0.59514


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0), HTML(value='')))


validation loss 0.01967


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0), HTML(value='')))


validation loss 0.00861


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0), HTML(value='')))


validation loss 0.00237


HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))




HBox(children=(IntProgress(value=0), HTML(value='')))


validation loss 0.00262


In [0]:
model.aspects().shape

torch.Size([12, 57869])

In [0]:
NUM = 10

aspects = []

for _, index in zip(*model.aspects().sort(dim=1)):
    index = index[-NUM:].detach().cpu().numpy()
    aspects.append([w2v_model.wv.index2word[i] for i in index])
    print(f'Aspect: {aspects[-1]}')

Aspect: ['pud', 'pattinson', 'shirtfront', 'yara', 'suffragist', 'busies', 'bardamu', 'murdered', 'nin', 'olam']
Aspect: ['bickert', 'abolish', 'securitised', 'biography', 'institut', 'closeups', 'unseaworthy', 'topped', 'perilously', 'sploding']
Aspect: ['committal', 'transience', 'deficit', 'numerical', 'hauteur', 'yuletide', 'flouride', 'nt', 'loanee', 'sire']
Aspect: ['wardrop', 'fireman', 'catfishers', 'ark', 'wireheads', 'roble', 'bellen', 'szkola', 'regulation', 'georgics']
Aspect: ['schwartz', 'midia', 'wtf', 'bareminerals', 'bebe', 'throats', 'choo', 'beis', 'unspecific', 'chemikal']
Aspect: ['busybody', 'fart', 'legouad', 'norma', 'portable', 'kidulthood', 'slapdash', 'sauvage', 'bauhaus', 'januarians']
Aspect: ['richie', 'generation', 'clarkson', 'dunkley', 'progeny', 'discogs', 'discoverable', 'tydfil', 'abbott', 'fernwood']
Aspect: ['griezmann', 'extramarital', 'outstation', 'commercialized', 'connecticut', 'sacrament', 'freak', 'fluke', 'wordchazer', 'anthrax']
Aspect: ['

In [0]:
from itertools import combinations


def calculate_coherence(w2v_model, term_rankings):
    overall_coherence = 0.0
    for topic_index in range(len(term_rankings)):
        pair_scores = []
        for pair in combinations( term_rankings[topic_index], 2 ):
            pair_scores.append( w2v_model.wv.similarity(pair[0], pair[1]) )
        topic_score = sum(pair_scores) / len(pair_scores)
        print(f'Topic {topic_index}: {topic_score}')
        overall_coherence += topic_score
    return overall_coherence / len(term_rankings)

In [0]:
import warnings
warnings.filterwarnings("ignore")

print('overall coherence:', calculate_coherence(w2v_model, aspects))

Topic 0: 0.0061295477932112085
Topic 1: -0.005702072462170488
Topic 2: -0.008498671112789048
Topic 3: -0.007325978023517463
Topic 4: -0.009897147411377066
Topic 5: 0.0007009131407054762
Topic 6: 0.009437471370781875
Topic 7: 0.017422807183013193
Topic 8: 0.004483165240122212
Topic 9: -0.0020589512410677142
Topic 10: -0.0017562802674041854
Topic 11: -0.002074044073621432
overall coherence: 7.173001132388056e-05
