In [None]:
# Reload imported modules automatically before each cell is executed, so
# edits to the project's own modules are picked up without restarting.
%load_ext autoreload
%autoreload 2
# Render matplotlib figures inline in the notebook output.
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
# Show which matplotlib backend ended up active.
print(plt.get_backend())

In [None]:
import contextlib
import json
import mmap
import os
import random
import time
import traceback

import numpy as np
import torch
import tqdm

from abae_pytorch.utils import linecount


# Seed every random number generator used in this notebook: `random` drives
# the split assignment and shuffling in the dataloader, NumPy drives negative
# sampling, and torch drives anything the model samples.
np.random.seed(0)
random.seed(0)
torch.manual_seed(0)


class dataloader:
    """Context manager serving padded word-index batches from a text corpus.

    The corpus at `path` is memory-mapped and treated as one example per
    newline-terminated line. Each line is assigned at random to one of the
    named splits, and the resulting byte offsets are cached in a JSON file
    (`self.meta`, created in the current working directory) so later runs
    reuse the same assignment.

    Args:
        w2i: mapping from word to integer index. It must contain '<unk>' if
            the corpus can contain words that are not keys of the mapping.
        path: path of the corpus file.
        split: mapping from split name to sampling weight; defaults to a
            single 'train' split.
    """

    def __init__(self, w2i, path, split=None):
        self.w2i = w2i
        self.path = path
        self.meta = './.' + os.path.basename(self.path) + '.meta.json'
        self.split = split if split else {'train': 1.0}

    def sample_splits(self, splits, probs):
        """Draw one entry of `splits` with probability proportional to `probs`.

        The draw is scaled by the total weight, so weights that do not sum
        to exactly 1 (by design or through rounding) still always yield a
        split instead of falling through and returning None.
        """
        thresholds = np.cumsum(probs)
        r = random.random() * thresholds[-1]
        for s, p in zip(splits, thresholds):
            if r <= p:
                return s
        return splits[-1]

    def __enter__(self):
        """Map the corpus and load (or build) the per-split line offsets."""
        self.f = open(self.path, 'rb')
        self.data = None
        try:
            self.data = mmap.mmap(self.f.fileno(), 0, access=mmap.ACCESS_COPY)
            cached = os.path.isfile(self.meta)
            if cached:
                self.read_meta()
                # An index built for other split names would only fail later
                # with a KeyError, so rebuild it instead.
                cached = set(self.offsets) == set(self.split)
                if not cached:
                    print('cached offsets in "%s" do not match the requested splits; rebuilding'
                          % self.meta)
            if not cached:
                self._index_lines()
        except BaseException:
            # __exit__ is not called when __enter__ raises: release here.
            self._close()
            raise
        return self

    def __exit__(self, *ags):
        """Release the mapping and the file.

        Any exception raised inside the `with` body is printed and then
        suppressed (the method returns True), so execution continues after
        the block.
        """
        if ags[0] is not None:
            traceback.print_exception(*ags)
        self._close()
        return True

    def _close(self):
        """Close the memory map (if one was created) and the corpus file."""
        if self.data is not None:
            self.data.close()
        self.f.close()

    def _index_lines(self):
        """Assign every newline-terminated line to a split and cache the result.

        Offsets are stored as (start, end) byte positions with `end` pointing
        at the line's newline. Text after the last newline is not indexed.
        """
        self.offsets = dict((s, []) for s in self.split)
        splits, probs = zip(*list(self.split.items()))
        desc = 'finding offsets in "%s"' % self.path
        start, size = 0, len(self.data)
        with tqdm.tqdm(total=size, desc=desc) as pbar:
            while start < size:
                # Let mmap search for the newline instead of looping over
                # single bytes in Python.
                end = self.data.find(b'\n', start)
                if end == -1:
                    break
                split = self.sample_splits(splits, probs)
                self.offsets[split].append((start, end))
                pbar.update(end + 1 - start)
                start = end + 1
        self.linecounts = dict((s, len(self.offsets[s])) for s in self.split)
        self.linecount = sum(self.linecounts[s] for s in self.split)
        self.write_meta()

        print('offsets for splits:')
        for split in self.offsets:
            print(' "%s" : %d' % (split, len(self.offsets[split])))

    def write_meta(self):
        """Write the corpus path, line counts and offsets to `self.meta`."""
        meta = {
            'path': self.path,
            'linecount': self.linecount,
            'linecounts': self.linecounts,
            'offsets': self.offsets,
        }
        with open(self.meta, 'w') as f:
            json.dump(meta, f)

    def read_meta(self):
        """Load line counts and offsets from `self.meta`.

        Raises:
            AssertionError: if the cache was written for a different path.
        """
        with open(self.meta, 'r') as f:
            meta = json.load(f)
        assert(self.path == meta['path'])
        self.linecount = meta['linecount']
        self.linecounts = meta['linecounts']
        self.offsets = meta['offsets']

    def b2i(self, batch):
        """Turn (start, end) byte offsets into a padded LongTensor of indices.

        Each slice is decoded as UTF-8 and split on whitespace. Words missing
        from `w2i` map to the '<unk>' index; rows shorter than the longest
        one are padded with the '<pad>' index (0 when `w2i` has no '<pad>').
        An empty batch yields a tensor of shape (0, 0).
        """
        sentences = [self.data[u:v].decode('utf').split() for u, v in batch]
        width = max((len(words) for words in sentences), default=0)
        pad = self.w2i['<pad>'] if '<pad>' in self.w2i else 0
        # Build the matrix as int64 directly rather than as float64.
        index = np.full((len(sentences), width), pad, dtype=np.int64)
        for row, words in enumerate(sentences):
            index[row, :len(words)] = [
                self.w2i[w] if w in self.w2i else self.w2i['<unk>'] for w in words]
        return torch.from_numpy(index)

    def batch_generator(self, split='train', device='cpu', batchsize=20, negsize=20):
        """Yield (positive, negative) batches from `split` indefinitely.

        Each item is a pair of LongTensors on `device`: the positives hold up
        to `batchsize` rows, the negatives are viewed as
        (batchsize, negsize, width) and are drawn with replacement from the
        same split. After `linecount // batchsize` batches the offsets are
        reshuffled and iteration starts over.

        Raises:
            ValueError: on first use if the split contains no lines.
        """
        linecount = self.linecounts[split]
        if linecount == 0:
            raise ValueError('split "%s" contains no lines' % split)
        batchcount = (linecount // batchsize)
        pos_offsets = self.offsets[split][:]
        neg_offsets = self.offsets[split][:]

        stime = time.time()
        print('shuffling "%s" data...' % (split, ), end='\r')
        random.shuffle(pos_offsets)
        random.shuffle(neg_offsets)
        print('shuffled "%s" data! (%0.1f s)' % (split, time.time() - stime))

        batches = 0
        while True:
            if batches == batchcount:
                # Restart the timer so the message reports this shuffle, not
                # the time since the generator was created.
                stime = time.time()
                print('shuffling data...', end='\r')
                random.shuffle(pos_offsets)
                random.shuffle(neg_offsets)
                print('shuffled data! (%0.1f s)' % (time.time() - stime))
                batches = 0
            pos_batch = pos_offsets[batches * batchsize:(batches + 1) * batchsize]
            pos_batch = self.b2i(pos_batch)
            neg_batch = np.random.choice(linecount, batchsize * negsize)
            neg_batch = self.b2i([neg_offsets[i] for i in neg_batch])
            yield (pos_batch.to(device), neg_batch.to(device).view(batchsize, negsize, -1))
            batches += 1

In [None]:
from sklearn.cluster import KMeans
import numpy as np
import gensim
import codecs
import tqdm
import time


class word2vec:
    """Vocabulary, embedding matrix and aspect matrix built from a corpus.

    Iterating over an instance yields the corpus one tokenised line at a
    time, which is the form in which it is handed to gensim for training.

    Attributes set by `embed`:
        w2i / i2w: word -> index and index -> word mappings.
        n_vocab:   number of indexed words, special tokens included.
        E:         embedding matrix with one row per index.
        d_embed:   width of `E`.
    Attributes set by `aspect`:
        n_aspects: number of clusters.
        T:         float32 matrix of L2-normalised cluster centres.
    """

    def __init__(self, corpus_path):
        self.corpus_path = corpus_path
        self.n_vocab = 0
        # Created here so `add` is usable before `embed` has been called.
        self.w2i, self.i2w = {}, {}
        self.n_special = 0

    def __iter__(self):
        with codecs.open(self.corpus_path, 'r', 'utf-8') as f:
            for line in tqdm.tqdm(f, desc='training'):
                yield line.split()

    def add(self, *words):
        """Give every not-yet-known word the next free index."""
        for word in words:
            if word not in self.w2i:
                self.w2i[word] = self.n_vocab
                self.i2w[self.n_vocab] = word
                self.n_vocab += 1

    def embed(self, model_path, d_embed, window=5, min_count=10, workers=16):
        """Load the model stored at `model_path`, or train and store one.

        NOTE(review): `size=` and `wv.vocab` look like the gensim 3.x API;
        confirm before running against a newer gensim.

        Returns:
            self, to allow chaining.
        """
        if os.path.isfile(model_path):
            model = gensim.models.Word2Vec.load(model_path)
        else:
            model = gensim.models.Word2Vec(self, size=d_embed,
                window=window, min_count=min_count, workers=workers)
            model.save(model_path)
        print('loading embeddings...', end='\r')
        self._index_vectors(model.wv)
        # Read the width from the matrix: a cached model may have been
        # trained with a width other than the `d_embed` argument.
        self.d_embed = self.E.shape[1]
        print('loaded embeddings!')
        return self

    def _index_vectors(self, vectors):
        """Rebuild the vocabulary and `E` from a word -> vector lookup.

        `vectors` must expose its words through `.vocab` and return a
        word's vector through indexing. '<pad>' and '<unk>' take indices 0
        and 1 and get all-zero rows; the remaining words follow in sorted
        order, so every index refers to exactly one word and one row of `E`.
        """
        self.w2i, self.i2w, self.n_vocab = {}, {}, 0
        self.add('<pad>', '<unk>')
        self.n_special = self.n_vocab
        # Skip corpus tokens that happen to spell a special token, which
        # would otherwise leave a row of E without an index.
        words = [w for w in sorted(vectors.vocab) if w not in self.w2i]
        self.add(*words)
        learned = np.asarray([vectors[w] for w in words])
        special = np.zeros((self.n_special, learned.shape[1]), dtype=learned.dtype)
        self.E = np.vstack([special, learned])

    def aspect(self, n_aspects):
        """Initialise the aspect matrix `T` by k-means over the embeddings.

        An alternative initialisation would be random unit vectors.

        Returns:
            self, to allow chaining.
        """
        self.n_aspects = n_aspects
        km = KMeans(n_clusters=n_aspects, random_state=0)

        stime = time.time()
        print('clustering embeddings...', end='\r')
        # Leave out the all-zero rows of the special tokens: they carry no
        # direction, and a centre made only of them could not be normalised.
        km.fit(self.E[self.n_special:])
        print('clustered embeddings! (%0.1f s)' % (time.time() - stime))
        clusters = km.cluster_centers_

        # L2 normalization
        norm_aspect_matrix = clusters / np.linalg.norm(clusters, axis=-1, keepdims=True)
        self.T = norm_aspect_matrix.astype(np.float32)
        return self

In [None]:
from sklearn.feature_extraction.text import CountVectorizer
from nltk.corpus import stopwords
from nltk.stem.wordnet import WordNetLemmatizer
import numpy as np
import codecs
import json
import os
from abae_pytorch.utils import linecount


class wikidata:
    """Prepare a raw text corpus for training.

    Construction preprocesses the corpus into `<corpus_path>.prep` (unless
    that file already exists), then builds word embeddings and an aspect
    matrix from the preprocessed text.

    Attributes:
        prep_path:  path of the preprocessed corpus.
        model_path: path of the stored word2vec model.
        w2v:        the `word2vec` instance holding vocabulary and matrices.
        n_vocab:    number of entries in `w2v.w2i`.
    """

    def __init__(self, corpus_path, d_embed=200, n_aspects=14):
        self.corpus_path = corpus_path

        self.prep_path = self.corpus_path + '.prep'
        if not os.path.isfile(self.prep_path):
            self.preprocess(self.corpus_path, self.prep_path)

        self.model_path = self.prep_path + '.w2v'
        w2v = word2vec(self.prep_path)
        w2v.embed(self.model_path, d_embed, min_count=100)
        w2v.aspect(n_aspects)
        self.n_vocab = len(w2v.w2i)
        self.w2v = w2v

    def preprocess(self, input_path, output_path, min_tokens=6, max_tokens=99):
        """Lower-case, tokenise, drop stop words and lemmatise every line.

        Lines whose token count after filtering is outside
        [min_tokens, max_tokens] are discarded. Both files are handled as
        UTF-8, which is how the preprocessed file is read back later.

        The result is written to a temporary file and moved into place only
        once it is complete, so an interrupted run cannot leave behind a
        truncated file that would be mistaken for a finished one.
        """
        lmtzr = WordNetLemmatizer()
        # A set gives constant-time membership tests for the per-token filter.
        stop = set(stopwords.words('english'))
        token = CountVectorizer().build_tokenizer()
        lc = linecount(input_path)
        tmp_path = output_path + '.tmp'
        with open(input_path, 'r', encoding='utf-8') as in_f, \
                open(tmp_path, 'w', encoding='utf-8') as out_f:
            for l in tqdm.tqdm(in_f, total=lc, desc='preprocessing "%s"' % input_path):
                tokens = [lmtzr.lemmatize(t) for t in token(l.lower()) if t not in stop]
                if min_tokens <= len(tokens) <= max_tokens:
                    out_f.write(' '.join(tokens) + '\n')
        os.replace(tmp_path, output_path)

In [None]:
from torch.nn.functional import normalize
import torch.optim as optim
import torch


def max_margin_loss(r_s, z_s, z_n):
    """Summed hinge loss between reconstructions and sentence embeddings.

    Args:
        r_s: reconstructed sentence embeddings, shape (batch, dim).
        z_s: sentence embeddings, shape (batch, dim).
        z_n: negative-sample embeddings, shape (batch, negatives, dim).

    Returns:
        Scalar tensor: the sum over every (sentence, negative) pair of
        max(0, 1 - <z_s, r_s> + <z_n, r_s>).
    """
    pos = torch.bmm(z_s.unsqueeze(1), r_s.unsqueeze(2)).squeeze(2)   # (batch, 1)
    # Squeeze only the trailing axis: a bare squeeze() would also remove the
    # batch or negatives axis whenever that axis has size 1.
    negs = torch.bmm(z_n, r_s.unsqueeze(2)).squeeze(2)               # (batch, negatives)
    J = 1.0 - pos + negs
    return torch.sum(torch.clamp(J, min=0.0))


def orthogonal_regularization(T):
    """Penalty measuring how far the rows of `T` are from being orthonormal.

    The rows are L2-normalised, their Gram matrix is compared with the
    identity, and the norm of the difference is returned as a scalar tensor.
    """
    rows = normalize(T, dim=1)
    gram = torch.mm(rows, rows.t())
    identity = torch.eye(rows.shape[0]).to(rows.device)
    return torch.norm(gram - identity)


def sample_aspects(projection, i2w, n=8):
    """Print, for each row of `projection`, the words of its `n` largest entries.

    Entries are ranked in ascending order, so the highest-scoring word is
    printed last. `i2w` maps a column index to its word.
    """
    _, ranking = torch.sort(projection, dim=1)
    for number, ranked in enumerate(ranking, start=1):
        top = ranked[-n:].detach().cpu().numpy()
        listing = ', '.join(i2w[k] for k in top)
        print('Aspect %2d: %s' % (number, listing))


def validate(ab, dl, split='val', batchsize=100, negsize=20, device='cuda'):
    """Measure the max-margin loss of `ab` over one pass of a split.

    Args:
        ab: model called as `ab(pos, neg)` and returning `(r_s, z_s, z_n)`.
        dl: data loader providing `batch_generator` and `linecounts`.
        split: name of the split to evaluate.
        batchsize, negsize, device: forwarded to `dl.batch_generator`.

    Returns:
        Mean of the summed per-batch losses (not divided by `batchsize`, in
        contrast to the value shown in the progress bar), or NaN when the
        split holds fewer than `batchsize` lines.
    """
    batches = dl.batch_generator(split, batchsize=batchsize, negsize=negsize, device=device)
    n_batches = (dl.linecounts[split] // batchsize)
    losses = []
    # Nothing is back-propagated here, so skip recording the autograd graph.
    with torch.no_grad(), \
            tqdm.tqdm(range(n_batches), total=n_batches, desc='validating') as pbar:
        for b in pbar:
            pos, neg = next(batches)
            r_s, z_s, z_n = ab(pos, neg)
            J = max_margin_loss(r_s, z_s, z_n)
            losses.append(J.item())
            x = (b + 1, np.mean(losses) / batchsize)
            pbar.set_description('BATCH: %d | MEAN-VAL-LOSS: %0.5f' % x)
    return np.mean(losses) if losses else float('nan')

        
def train(ab, dl, device='cuda',
          epochs=5, epochsize=50, initial_lr=0.02, batchsize=100, negsize=20, ortho_reg=0.1):
    """Train `ab` with Adam under a linearly decaying learning rate.

    Before training and after every epoch the model is validated on the
    'val' split and its aspects are printed; after every epoch the mean
    training loss per epoch is plotted on a logarithmic axis.

    Args:
        ab: model called as `ab(pos, neg)` returning `(r_s, z_s, z_n)`, and
            exposing `parameters()`, `aspects()` and `T.weight`.
        dl: data loader providing `batch_generator`, `linecounts` and `w2i`.
        device: device on which batches are placed.
        epochs: number of epochs.
        epochsize: number of batches per epoch.
        initial_lr: learning rate of the first step.
        batchsize, negsize: forwarded to `dl.batch_generator`.
        ortho_reg: weight of the orthogonality penalty on `ab.T.weight`.

    Returns:
        List with one list per epoch holding that epoch's per-batch losses.
    """
    train_batches = dl.batch_generator('train', batchsize=batchsize, negsize=negsize, device=device)
    i2w = dict((i, w) for w, i in dl.w2i.items())

    validate(ab, dl, 'val', batchsize, negsize, device)
    sample_aspects(ab.aspects(), i2w)

    total_steps = epochs * epochsize
    mean_losses = []
    opt = optim.Adam(ab.parameters(), lr=initial_lr)
    for e in range(epochs):
        mean_losses.append([])
        with tqdm.trange(epochsize) as pbar:
            for b in pbar:
                pos, neg = next(train_batches)
                r_s, z_s, z_n = ab(pos, neg)
                J = max_margin_loss(r_s, z_s, z_n)
                U = orthogonal_regularization(ab.T.weight)
                loss = J + ortho_reg * batchsize * U
                opt.zero_grad()
                loss.backward()
                opt.step()

                mean_losses[-1].append(loss.item())
                x = (e + 1, opt.param_groups[0]['lr'], np.mean(mean_losses[-1]) / batchsize)
                pbar.set_description('EPOCH: %d | LR: %0.5f | MEAN-TRAIN-LOSS: %0.5f' % x)

                if (b * batchsize) % 100 == 0:
                    # Progress is the number of steps taken over all epochs.
                    # The product (e + 1) * (b + 1) is not monotonic and sent
                    # the rate back up at the start of every epoch.
                    step = e * epochsize + b + 1
                    lr = initial_lr * (1.0 - step / total_steps)
                    for pg in opt.param_groups:
                        pg['lr'] = lr

        validate(ab, dl, 'val', batchsize, negsize, device)
        sample_aspects(ab.aspects(), i2w)

        all_losses = [np.mean(y) for y in mean_losses]
        plt.plot(list(range(len(all_losses))), all_losses, lw=4, marker='s')
        plt.xlabel('epoch')
        plt.ylabel('mean training loss per batch')
        plt.semilogy()
        plt.show()

    return mean_losses

In [None]:
from abae_pytorch.model import abae

# Device on which the model and every batch are placed.
device = 'cpu'

# Embedding width and number of aspects handed to `wikidata` below.
d_embed = 100
n_aspects = 20

# Training schedule: `epochs` rounds of `epochsize` batches of `batchsize`
# sentences, each paired with `negsize` negative samples.
epochs = 20
epochsize = 100
batchsize = 100
negsize = 20
initial_lr = 0.001


# Corpus to train on; the commented-out lines are alternative corpora.
#data = './data/wiki_01'
data = './data/beer.train.txt'
#data = './data/restaurant.train.txt'
# Sampling weight of each split when lines are assigned by the dataloader.
split = {'train': 0.8, 'val': 0.1, 'test': 0.1}


# Preprocesses the corpus if needed, then builds embeddings and aspects.
wd = wikidata(data, d_embed, n_aspects)

x = (wd.n_vocab, wd.w2v.d_embed, wd.w2v.n_aspects)
print('n_vocab: %d | d_embed: %d | n_aspects: %d' % x)

# Model initialised from the embedding matrix E and the aspect matrix T.
ab = abae(wd.w2v.E, wd.w2v.T).to(device)

# NOTE: dataloader.__exit__ prints and then suppresses any exception raised
# inside this block, so a failed run does not stop the notebook.
with dataloader(wd.w2v.w2i, wd.prep_path, split=split) as dl:
    train(ab, dl, 
          device=device,
          epochs=epochs, 
          epochsize=epochsize,
          batchsize=batchsize,
          negsize=negsize,
          initial_lr=initial_lr)

#

    offsets partitions in dataloader for splitting?
    validation loss measurement
    impose maximum vocab size
    sentence topic prediction function?

    word embeddings trained on partitions too...
    optionally use different w2v training corpus
    updating loss plot
    model saving/loading
    inferring n_aspects
    cli
    break into package
    documentation
    num tag for preprocessing
    downweight specificity?


# TRASH 

In [None]:
# Shell out to `wc` for the preprocessed corpus file; relies on `wd` having
# been created by the training cell.
!{'wc %s' % wd.prep_path}

class dedicated to making a structured dataset from a particular data source
    raw text -> preprocessing -> splitting
    creates a vocab
    vocab trains word embeddings

class data loader which serves batches of training/evaluation data
    requires preprocessed text to serve
    requires predetermined vocab
    vocab requires word embeddings
    
class model just the neural network parts

class wrapping the model with an interface
    training
    evaluation
    deployment

cli script covering interface