In [1]:
# utils.py
import socket
import pickle
import argparse
import time
import os
from functools import reduce
import numpy as np
from scipy.special import logit
import matplotlib as mpl
# mpl.use('Agg')
# import matplotlib.pyplot as plt
from tqdm import tqdm
import mxnet as mx
# from sklearn.neighbors import NearestNeighbors
# from IPython import embed
import collections

def gpu_helper(gpu):
    """
    Pick the MXNet device context for the requested GPU index.

    :param gpu: integer GPU index; a negative value requests the CPU
    :return: ``mx.gpu(gpu)`` when the index is non-negative and ``gpu_exists(gpu)``
             reports the device as usable, otherwise ``mx.cpu()``
    """
    # short-circuit: gpu_exists is only probed for non-negative indices
    use_gpu = gpu >= 0 and gpu_exists(gpu)
    return mx.gpu(gpu) if use_gpu else mx.cpu()

def gpu_exists(gpu):
    """
    Check whether GPU ``gpu`` can be used by allocating a one-element array on it.

    :param gpu: integer GPU index passed to ``mx.gpu``
    :return: True if the allocation call succeeds, False if it raises an error
    """
    try:
        mx.nd.zeros((1,), ctx=mx.gpu(gpu))
    except Exception:
        # Any ordinary failure means the device is unusable. Catching Exception
        # (instead of a bare except) lets KeyboardInterrupt / SystemExit propagate.
        return False
    return True

def reverse_dict(d):
    """
    Swap the keys and values of a dictionary.

    :param d: dict whose values are hashable
    :return: new dict mapping each value of ``d`` to its key; when several keys share
             a value, the key iterated last wins
    """
    inverted = {}
    for key, value in d.items():
        inverted[value] = key
    return inverted

def to_numpy(X):
    """
    Convert a (possibly nested) list of array-like objects via their ``asnumpy()`` method.

    :param X: iterable whose elements are either lists (handled recursively) or
              objects exposing ``asnumpy()``
    :return: list with the same nesting as ``X`` holding the ``asnumpy()`` results
    """
    return [
        to_numpy(item) if isinstance(item, list) else item.asnumpy()
        for item in X
    ]

def stack_numpy(X,xnew):
    """
    Append ``xnew`` row-wise onto ``X``, element by element, in place.

    :param X: list of numpy arrays (or nested lists of them); modified in place
    :param xnew: structure parallel to ``X`` holding the rows to append
    :return: the same ``X`` object, each leaf replaced by ``np.vstack([old, new])``
    """
    for idx, current in enumerate(X):
        incoming = xnew[idx]
        if isinstance(incoming, list):
            # nested list: recurse so the inner arrays are stacked pairwise
            X[idx] = stack_numpy(current, incoming)
        else:
            X[idx] = np.vstack([current, incoming])
    return X


def get_topic_words_decoder_weights(D, data, ctx, k=10, decoder_weights=False):
    """
    Return the top-k vocabulary strings for every topic, ranked by decoder scores.

    :param D: decoder; needs ``collect_params()`` when ``decoder_weights`` is True,
              otherwise ``y_as_topics()`` and must be callable on its (copied) result
    :param data: dataset object exposing ``id_to_word`` or ``maps['dim2vocab']``
    :param ctx: device context the topic inputs are copied to before decoding
    :param k: number of top words to keep per topic
    :param decoder_weights: if True, rank by the transposed 'decoder0_dense0_weight'
                            parameter instead of the decoder output
    :return: list (one entry per row of the score matrix) of lists of word strings
    """
    if decoder_weights:
        scores = D.collect_params()['decoder0_dense0_weight'].data().transpose()
    else:
        topics = D.y_as_topics()
        scores = D(topics.copyto(ctx))

    # column indices sorted by descending score, truncated to the first k per row
    ranked = mx.nd.argsort(scores, axis=1, is_ascend=False)
    top_word_ids = ranked[:,:k].asnumpy()

    # prefer the flat id_to_word lookup when the dataset provides one
    if hasattr(data, 'id_to_word'):
        lookup = lambda w: data.id_to_word[int(w)]
    else:
        lookup = lambda w: data.maps['dim2vocab'][int(w)]

    return [[lookup(w) for w in topic] for topic in top_word_ids]


def get_topic_words(D, data, ctx, k=10):
    """
    Return the top-k vocabulary strings for every topic of a (y, z) decoder.

    :param D: decoder providing ``yz_as_topics()`` and callable as ``D(y, z)``
    :param data: dataset object exposing ``id_to_word`` or ``maps['dim2vocab']``
    :param ctx: device context the topic inputs are copied to before decoding
    :param k: number of top words to keep per topic
    :return: list (one entry per row of the decoder output) of lists of word strings
    """
    y, z = D.yz_as_topics()
    # z is optional: forward None unchanged when the decoder has no z topics
    z_on_ctx = z.copyto(ctx) if z is not None else None
    scores = D(y.copyto(ctx), z_on_ctx)

    # column indices sorted by descending score, truncated to the first k per row
    ranked = mx.nd.argsort(scores, axis=1, is_ascend=False)
    top_word_ids = ranked[:,:k].asnumpy()

    if hasattr(data, 'id_to_word'):
        lookup = lambda w: data.id_to_word[int(w)]
    else:
        lookup = lambda w: data.maps['dim2vocab'][int(w)]

    return [[lookup(w) for w in topic] for topic in top_word_ids]


def calc_topic_uniqueness(top_words_idx_all_topics):
    r"""
    This function calculates topic uniqueness scores for a given list of topics.
    For each topic, the uniqueness is calculated as:  (\sum_{i=1}^n 1/cnt(i)) / n,
    where n is the number of top words in the topic and cnt(i) is the counter for the number of times the word
    appears in the top words of all the topics.
    A topic with no top words gets a score of 0.0 (instead of dividing by zero).
    :param top_words_idx_all_topics: a list, each element is a list of top word indices for a topic
    :return: a dict, key is topic_id (starting from 0), value is topic_uniquness score
    """
    # build word_cnt_dict: number of times the word appears in top words
    word_cnt_dict = collections.Counter()
    for topic_words in top_words_idx_all_topics:
        word_cnt_dict.update(topic_words)

    uniqueness_dict = dict()
    for topic_id, topic_words in enumerate(top_words_idx_all_topics):
        n_words = len(topic_words)
        if n_words == 0:
            # nothing to average over; avoid ZeroDivisionError
            uniqueness_dict[topic_id] = 0.0
            continue
        cnt_inv_sum = 0.0
        for ind in topic_words:
            cnt_inv_sum += 1.0 / word_cnt_dict[ind]
        uniqueness_dict[topic_id] = cnt_inv_sum / n_words

    return uniqueness_dict

def request_pmi(topic_dict=None, filename='', port=1234):
    """
    Ask a PMI/NPMI server on the local host for topic coherence scores.

    The request is a pickled object: ``filename`` if it is non-empty, otherwise
    ``topic_dict``. The reply is read until the peer closes the connection and is
    unpickled into a dict with the keys 'pmi_dict' and 'npmi_dict'.

    This is best-effort: on any failure both scores default to 0 for every key of
    ``topic_dict`` (or to empty dicts when ``topic_dict`` is None).

    :param topic_dict: dict keyed by topic id; sent when ``filename`` is empty
    :param filename: if non-empty, sent instead of ``topic_dict``
    :param port: TCP port of the server on ``socket.gethostname()``
    :return: tuple ``(pmi_dict, npmi_dict)``
    """
    payload = filename if filename != '' else topic_dict
    try:
        # create a socket object
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # get local machine name and connect to it on the port
            host = socket.gethostname()
            s.connect((host, port))

            # sendall (not send) so a large pickle is never truncated
            s.sendall(pickle.dumps(payload))

            chunks = []
            while True:
                packet = s.recv(4096)
                if not packet:
                    # empty read: the peer closed the connection
                    break
                chunks.append(packet)
        finally:
            # release the socket on success and on every failure path
            s.close()

        # NOTE(security): unpickling executes arbitrary code from the peer; only
        # talk to a trusted server.
        res_dict = pickle.loads(b"".join(chunks))
        pmi_dict = res_dict['pmi_dict']
        npmi_dict = res_dict['npmi_dict']
    except Exception:
        # Failed to run NPMI calc: NPMI and PMI are set to 0 for all known topics.
        # topic_dict is None in filename mode, so guard the iteration.
        topic_keys = topic_dict if topic_dict is not None else ()
        pmi_dict = dict()
        npmi_dict = dict()
        for k in topic_keys:
            pmi_dict[k] = 0
            npmi_dict[k] = 0

    return pmi_dict, npmi_dict


def print_topics(topic_json, npmi_dict, topic_uniqs, data, print_topic_names=False):
    """
    Print one line per topic: ``[ <label>[*] - <uniqueness> - <npmi>]:  <words>``.

    :param topic_json: dict mapping topic id to its list of top words
    :param npmi_dict: dict mapping topic id to its NPMI score
    :param topic_uniqs: dict mapping topic id to its uniqueness score
    :param data: dataset object; ``maps['dim2topic']`` and ``selected_topics`` are
                 consulted only when ``print_topic_names`` is true
    :param print_topic_names: use the topic name instead of the id as label, and
                              mark topics listed in ``data.selected_topics`` with '*'
    """
    for topic_id, words in topic_json.items():
        if hasattr(data, 'maps') and print_topic_names:
            label = data.maps['dim2topic'][topic_id]
        else:
            label = str(topic_id)

        marker = ''
        if hasattr(data, 'selected_topics') and print_topic_names:
            if data.maps['dim2topic'][topic_id] in data.selected_topics:
                marker = '*'

        uniq_part = '{:.5g}'.format(topic_uniqs[topic_id])
        npmi_part = '{:.5g}'.format(npmi_dict[topic_id])
        prefix_msg = '[ ' + label + marker + ' - ' + uniq_part + ' - ' + npmi_part + ']: '
        print(prefix_msg, words)

def print_topic_with_scores(topic_json, **kwargs):
    """
    Print (and return) a header of average scores followed by one line per topic.

    :param topic_json: dict mapping topic id to its list of top words
    :param kwargs: dict_name: content_dict, each content_dict mapping topic id to a
                   score; the special argument sortby='xxx' (or sort_by='xxx') orders
                   the topics by descending value of the score dict named 'xxx'
    :return: the printed message as a single string
    """
    # default order: ascending topic id
    topic_keys = sorted(list(topic_json.keys()))

    sortby = kwargs.pop('sortby', None)
    if sortby is None:
        sortby = kwargs.pop('sort_by', None)

    if sortby in kwargs:
        ranking = kwargs[sortby]
        # ascending sort then reversal (keeps the original tie ordering)
        topic_keys = sorted(ranking, key=ranking.get)[::-1]

    score_names = sorted(kwargs)

    header_parts = ['Avg scores: ']
    for name in score_names:
        assert isinstance(kwargs[name], dict)
        header_parts.append('{}: {:.2f} '.format(name, mean_dict(kwargs[name])))
    header_str = ''.join(header_parts)

    entries = []
    for key in topic_keys:
        scores = ', '.join('{:.2f}'.format(kwargs[name][key]) for name in score_names)
        entries.append('T{} [{}] '.format(key, scores) + ', '.join(topic_json[key]))

    msg = header_str + '\n' + '\n'.join(entries)
    print(msg)
    return msg

In [2]:
# core.py
import pickle
import numpy as np

import matplotlib.pyplot as plt

from sklearn.metrics import log_loss, v_measure_score

import mxnet as mx
from mxnet import gluon, io


# import misc as nm
# import datasets as nuds
import scipy.sparse as sparse
import json
# import wordvectors as nuwe


class Data(object):
    '''
    Data Generator object. Main functionality is contained in ``minibatch'' method
    and ``subsampled_labeled_data'' if training in a semi-supervised fashion.
    Introducing new datasets requires implementing ``load'' and possibly overwriting
    portions of ``__init__''.
    '''
    def __init__(self, batch_size=1, data_path='', ctx=mx.cpu(0)):
        '''
        Constructor for Data.
        Args
        ----
        batch_size: int, default 1
          An integer specifying the batch size - required for precompiling the graph.
        data_path: string, default ''
          This is primarily used by Mulan to specify which dataset to load from Mulan,
          e.g., data_path='bibtex'.
        ctx: mxnet device context, default mx.cpu(0)
          Which device to store/run the data and model on.
        Returns
        -------
        Data object
        '''
        self.batch_size = batch_size
        if data_path == '':
            data, labels, maps = self.load()
        else:
            data, labels, maps = self.load(data_path)
        self.ctx = ctx

        data_names = ['train','valid','test','train_with_labels','valid_with_labels','test_with_labels']
        label_names = ['train_label', 'valid_label', 'test_label']

        self.data = dict(zip(data_names, data))
        self.labels = dict(zip(label_names, labels))

        # repeat data to at least match batch_size
        for k, v in self.data.items():
            if v is not None and v.shape[0] < self.batch_size:
                print('NOTE: Number of samples for {0} is smaller than batch_size ({1}<{2}). Duplicating samples to exceed batch_size.'.format(k,v.shape[0],self.batch_size))
                if type(v) is np.ndarray:
                    self.data[k] = np.tile(v, (self.batch_size // v.shape[0] + 1, 1))
                else:
                    self.data[k] = mx.nd.tile(v, (self.batch_size // v.shape[0] + 1, 1))

        for k, v in self.labels.items():
            if v is not None and v.shape[0] < self.batch_size:
                print('NOTE: Number of samples for {0} is smaller than batch_size ({1}<{2}). Duplicating samples to exceed batch_size.'.format(k,v.shape[0],self.batch_size))
                self.labels[k] = np.tile(v, (self.batch_size // v.shape[0] + 1, ))

        map_names = ['vocab2dim','dim2vocab','topic2dim','dim2topic']
        self.maps = dict(zip(map_names, maps))

        # Build the loaders from self.data (not the raw ``data`` list) so that splits
        # duplicated above to reach batch_size are the ones actually iterated; with
        # last-batch discarding, an undersized split would otherwise yield no batch.
        self.dataloaders = {}
        self.dataiters = {}
        for name, split in self.data.items():
            loader = self.dataloader(split, batch_size)
            self.dataloaders[name] = loader
            self.dataiters[name] = iter(loader) if loader is not None else None
        self.wasreset = dict(zip(data_names, np.ones(len(data_names), dtype='bool')))

        self.data_dim = self.data['train'].shape[1]
        if self.data['train_with_labels'] is not None:
            # labelled matrices hold features and labels side by side
            self.label_dim = self.data['train_with_labels'].shape[1] - self.data['train'].shape[1]


    def dataloader(self, data, batch_size, shuffle=True):
        '''
        Constructs a data loader for generating minibatches of data.
        Args
        ----
        data: numpy array or other array type accepted by NDArrayIter, no default
          The data from which to load minibatches. None yields no loader.
        batch_size: integer, no default
          The # of samples returned in each minibatch.
        shuffle: boolean, default True
          Whether or not to shuffle the data prior to returning the data loader.
        Returns
        -------
        A gluon DataLoader for numpy input, an io.NDArrayIter otherwise, or None if
        data is None. Both are configured to discard an incomplete last batch.
        '''
        if data is None:
            return None
        if type(data) is np.ndarray:
            return gluon.data.DataLoader(data, batch_size, last_batch='discard', shuffle=shuffle)
        return io.NDArrayIter(data={'data': data}, batch_size=batch_size, shuffle=shuffle, last_batch_handle='discard')

    def force_reset_data(self, key, shuffle=True):
        '''
        Resets minibatch index to zero to restart an epoch.
        Args
        ----
        key: string, no default
          Required to select appropriate data in ``data'' object,
          e.g., 'train', 'test', 'train_with_labels', 'test_with_labels'.
        shuffle: boolean, default True
          Whether or not to shuffle the data prior to returning the data loader.
          Only used for numpy-backed splits, whose loader is rebuilt.
        Returns
        -------
        Nothing.
        '''
        if self.data[key] is not None:
            if type(self.data[key]) is np.ndarray:
                # a DataLoader iterator cannot be rewound: build a fresh one
                self.dataloaders[key] = self.dataloader(self.data[key], self.batch_size, shuffle)
                self.dataiters[key] = iter(self.dataloaders[key])
            else:
                self.dataiters[key].hard_reset()
            self.wasreset[key] = True

    def _next_numpy_minibatch(self, key, pad_width):
        '''
        Draws the next batch of a numpy-backed split, flattens it to
        (batch_size, -1), optionally zero-pads it and copies it to self.ctx.
        Raises StopIteration when the current epoch is exhausted.
        '''
        mb = self.dataiters[key].__next__().reshape((self.batch_size, -1))
        if pad_width > 0:
            mb = mx.nd.concat(mb, mx.nd.zeros((self.batch_size, pad_width)))
        return mb.copyto(self.ctx)

    def minibatch(self, key, pad_width=0):
        '''
        Returns a minibatch of data (stored on device self.ctx).
        When the current epoch is exhausted the iterator is reset and the first
        batch of the next epoch is returned.
        Args
        ----
        key: string, no default
          Required to select appropriate data in ``data'' object,
          e.g., 'train', 'test', 'train_with_labels', 'test_with_labels'.
        pad_width: integer, default 0
          The amount to zero-pad the labels to match the dimensionality of z.
          Only applied to numpy-backed splits.
        Returns
        -------
        minibatch: NDArray on device self.ctx
          An NDArray of size batch_size x # of features, or None if the split
          has no data.
        '''
        if self.dataiters[key] is None:
            return None
        if type(self.data[key]) is np.ndarray:
            try:
                return self._next_numpy_minibatch(key, pad_width)
            except StopIteration:
                # end of epoch: rebuild the loader and start over
                self.force_reset_data(key)
                return self._next_numpy_minibatch(key, pad_width)
        try:
            return self.dataiters[key].__next__().data[0].as_in_context(self.ctx)
        except StopIteration:
            # end of epoch: rewind the NDArrayIter and start over
            self.dataiters[key].hard_reset()
            return self.dataiters[key].__next__().data[0].as_in_context(self.ctx)

    def get_documents(self, key, split_on=None):
        '''
        Retrieves a minibatch of documents via ``data'' object parameter.
        Args
        ----
        key: string, no default
          Required to select appropriate data in ``data'' object,
          e.g., 'train', 'test', 'train_with_labels', 'test_with_labels'.
        split_on: integer, default None
          Useful if self.data[key] contains both data and labels in one
          matrix and want to split them, e.g., split_on = data_dim.
        Returns
        -------
        minibatch: NDArray if split_on is None, else [NDarray, NDArray]
        '''
        if 'labels' in key:
            # label_pad_width is not set by this base class; fall back to no padding
            # when a subclass has not defined it.
            batch = self.minibatch(key, pad_width=getattr(self, 'label_pad_width', 0))
        else:
            batch = self.minibatch(key)
        if split_on is not None:
            batch, labels = batch[:,:split_on], batch[:,split_on:]
            return batch, labels
        else:
            return batch

    @staticmethod
    def visualize_series(y, ylabel, file, args, iteration, total_samples, labels=None):
        '''
        Plots and saves a figure of y vs iterations and epochs to file.
        Nothing is plotted or saved when y is empty.
        Args
        ----
        y: a list (of lists) or numpy array, no default
          A list (of possibly another list) of numbers to plot.
        ylabel: string, no default
          The label for the y-axis.
        file: string, no default
          A path with filename to save the figure to (appended to args['saveto']).
        args: dictionary, no default
          A dictionary of model, training, and evaluation specifications.
          The keys 'batch_size' and 'saveto' are read here.
        iteration: integer, no default
          The current iteration in training.
        total_samples: integer, no default
          The total number of samples in the dataset - used along with batch_size
          to convert iterations to epochs.
        labels: list of strings, default None
          If y is a list of lists, the labels contains names for each element
          in the nested list. This is used to create an appropriate legend
          for the plot.
        Returns
        -------
        Nothing.
        '''
        if len(y) > 0:
            fig = plt.figure()
            ax = plt.subplot(111)
            # x-axis in epochs: iterations * batch_size / dataset size
            x = np.linspace(0, iteration, num=len(y)) * args['batch_size'] / total_samples
            y = np.array(y)
            if len(y.shape) > 1:
                for i in range(y.shape[1]):
                    if labels is None:
                        plt.plot(x,y[:,i])
                    else:
                        plt.plot(x,y[:,i], label=labels[i])
            else:
                plt.plot(x,y)
            ax.set_ylabel(ylabel)
            ax.set_xlabel('Epochs')
            plt.grid(True)

            ax2 = ax.twiny()

            # https://pythonmatplotlibtips.blogspot.com/2018/01/add-second-x-axis-below-first-x-axis-python-matplotlib-pyplot.html
            # Decide the ticklabel position in the new x-axis,
            # then convert them to the position in the old x-axis
            # xticks list seems to be padded with extra lower and upper ticks --> subtract 2 from length
            newlabel = np.around(np.linspace(0, iteration, num=len(ax.get_xticks())-2)).astype('int') # labels of the xticklabels: the position in the new x-axis
            ax2.set_xticks(newlabel * args['batch_size'] / total_samples)
            ax2.set_xticklabels(newlabel//1000)

            ax2.xaxis.set_ticks_position('bottom') # set the position of the second x-axis to bottom
            ax2.xaxis.set_label_position('bottom') # set the position of the second x-axis to bottom
            ax2.spines['bottom'].set_position(('outward', 36))
            ax2.set_xlabel('Thousand Iterations')
            ax2.set_xlim(ax.get_xlim())

            if labels is not None:
                lgd = ax.legend(loc='center left', bbox_to_anchor=(1.05, 1))
                # bbox_extra_artists is the savefig keyword that makes the tight
                # bounding box include the legend placed outside the axes.
                fig.savefig(args['saveto']+file, bbox_extra_artists=[lgd], bbox_inches='tight')
            else:
                fig.tight_layout()
                fig.savefig(args['saveto']+file)
            # close this specific figure so repeated calls do not accumulate figures
            plt.close(fig)

    def load(self, path=''):
        '''
        Loads data, labels and maps from path.
        This base implementation only returns placeholders; datasets are expected
        to override it and return the same three-part structure.
        Args
        ----
        path: string, default ''
          A path to the data file.
        Returns
        -------
        data: list of numpy arrays
          A list of the different subsets of data, in the order
          `train', `valid', `test', 'train_with_labels', 'valid_with_labels',
          'test_with_labels'.
        labels: list
          Labels for `train', `valid', `test' (None entries here).
        maps: list of dictionaries
          A list of dictionaries for mapping between dimensions and strings,
          e.g., 'vocab2dim', 'dim2vocab', 'topic2dim', 'dim2topic'.
        '''
        split_names = ['train','valid','test','train_with_labels','valid_with_labels','test_with_labels']
        data = [np.empty((1,1)) for _ in split_names]
        # __init__ unpacks (data, labels, maps), so labels must be returned too
        labels = [None, None, None]
        maps = [{'a':0}, {0:'a'}, {'Letters':0}, {0:'Letters'}]
        self.data_path = path + '***.npz'
        return data, labels, maps


class ENet(gluon.HybridBlock):
    '''
    A gluon HybridBlock Encoder (skeleton) class.

    ``init_weights`` reads the attributes ``weights_file``, ``model_ctx``, ``freeze``
    and ``n_layers``; none of them is set here, so a subclass is expected to
    define them before calling it.
    '''
    def __init__(self):
        '''
        Constructor for Encoder.
        Args
        ----
        None
        Returns
        -------
        Encoder object
        '''
        super(ENet, self).__init__()

    def hybrid_forward(self, x):
        '''
        Encodes x. Placeholder: always raises NotImplementedError.
        NOTE(review): Gluon invokes hybrid_forward(F, x, ...) on HybridBlocks, so
        subclasses presumably override this with the extra ``F`` argument - confirm.
        Args
        ----
        x: mx.NDArray or sym, No default
          Input to encoder.
        Returns (should)
        -------
        params: list of NDArray or sym
          parameters for the encoding distribution
        samples: NDArray or sym
          samples drawn from the encoding distribution
        '''
        raise NotImplementedError('Need to write your own Encoder that inherits from ENet. Put this file in models/.')

    def init_weights(self, weights=None):
        '''
        Initializes the encoder weights. Default is Xavier initialization.
        Order of preference: ``weights`` argument, then ``self.weights_file`` read as
        an MXNet parameter file, then the same file read as a pickle, then random
        (Xavier) initialization.
        Args
        ----
        weights: list of numpy arrays, No default
          Weights to load into the model. Not required. Preference is to
          load weights from file.
        Returns
        -------
        Nothing.
        '''
        loaded = False
        source = 'keyword argument'
        if self.weights_file != '' and weights is None:
            try:
                self.load_params(self.weights_file, self.model_ctx)
                source = 'mxnet weights file: '+self.weights_file
                print('NOTE: Loaded encoder weights from '+source+'.')
                if self.freeze:
                    self.freeze_params()
                    print('NOTE: Froze encoder weights from '+source+'.')
                weights = None
                loaded = True
            except Exception:
                # Loading as MXNet parameters failed: fall back to a pickled weight
                # list. The file is closed deterministically via the with-block.
                # NOTE(security): pickle.load runs arbitrary code from the file;
                # only use trusted weight files.
                with open(self.weights_file,'rb') as weights_fh:
                    weights = pickle.load(weights_fh)
                source = 'pickle file: '+self.weights_file
        if weights is not None:
            # explicit weight lists are only supported for models without hidden layers
            assert self.n_layers == 0
            for p,w in zip(self.collect_params().values(), weights):
                if w is not None:
                    p.initialize(mx.init.Constant(mx.nd.array(w.squeeze())), ctx=self.model_ctx)
                    if self.freeze:
                        # a zero learning-rate multiplier keeps the parameter fixed
                        p.lr_mult = 0.
            print('NOTE: Loaded encoder weights from '+source+'.')
            if self.freeze:
                print('NOTE: Froze encoder weights from '+source+'.')
            loaded = True
        if not loaded:
            self.collect_params().initialize(mx.init.Xavier(), ctx=self.model_ctx)
            print('NOTE: Randomly initialized encoder weights.')

    def freeze_params(self):
        '''
        Freezes all parameters by setting their learning-rate multiplier to zero.
        '''
        for p in self.collect_params().values():
            p.lr_mult = 0.


class DNet(gluon.HybridBlock):
    '''
    A gluon HybridBlock Decoder (skeleton) class.

    ``init_weights`` reads the attributes ``weights_file``, ``model_ctx``, ``freeze``
    and ``n_layers``; none of them is set here, so a subclass is expected to
    define them before calling it.
    '''
    def __init__(self):
        '''
        Constructor for Decoder.
        Args
        ----
        None
        Returns
        -------
        Decoder object
        '''
        super(DNet, self).__init__()

    def hybrid_forward(self, y, z):
        '''
        Decodes (y, z). Placeholder: always raises NotImplementedError.
        NOTE(review): Gluon invokes hybrid_forward(F, ...) on HybridBlocks, so
        subclasses presumably override this with the extra ``F`` argument - confirm.
        Args
        ----
        y: mx.NDArray or sym, no default
          First input to decoder.
        z: mx.NDArray or sym, no default
          Second input to decoder.
        Returns (should)
        -------
        params: list of NDArray or sym
          parameters for the decoding distribution
        samples: NDArray or sym
          samples drawn from the decoding distribution. None if sampling is not implemented.
        '''
        raise NotImplementedError('Need to write your own Decoder that inherits from DNet. Put this file in models/.')

    def init_weights(self, weights=None):
        '''
        Initializes the decoder weights. Default is Xavier initialization.
        Order of preference: ``weights`` argument, then ``self.weights_file`` read as
        an MXNet parameter file, then the same file read as a pickle, then random
        (Xavier) initialization.
        Args
        ----
        weights: list of numpy arrays, No default
          Weights to load into the model. Not required. Preference is to
          load weights from file.
        Returns
        -------
        Nothing.
        '''
        loaded = False
        source = 'keyword argument'
        if self.weights_file != '' and weights is None:
            try:
                self.load_params(self.weights_file, self.model_ctx)
                source = 'mxnet weights file: '+self.weights_file
                print('NOTE: Loaded decoder weights from '+source+'.')
                if self.freeze:
                    self.freeze_params()
                    print('NOTE: Froze decoder weights from '+source+'.')
                weights = None
                loaded = True
            except Exception:
                # Loading as MXNet parameters failed: fall back to a pickled weight
                # list. The file is closed deterministically via the with-block.
                # NOTE(security): pickle.load runs arbitrary code from the file;
                # only use trusted weight files.
                with open(self.weights_file,'rb') as weights_fh:
                    weights = pickle.load(weights_fh)
                source = 'pickle file: '+self.weights_file
        if weights is not None:
            # explicit weight lists are only supported for models without hidden layers
            assert self.n_layers == 0
            for p,w in zip(self.collect_params().values(), weights):
                if w is not None:
                    p.initialize(mx.init.Constant(mx.nd.array(w.squeeze())), ctx=self.model_ctx)
                    if self.freeze:
                        # a zero learning-rate multiplier keeps the parameter fixed
                        p.lr_mult = 0.
            print('NOTE: Loaded decoder weights from '+source+'.')
            if self.freeze:
                print('NOTE: Froze decoder weights from '+source+'.')
            loaded = True
        if not loaded:
            self.collect_params().initialize(mx.init.Xavier(), ctx=self.model_ctx)
            print('NOTE: Randomly initialized decoder weights.')

    def freeze_params(self):
        '''
        Freezes all parameters by setting their learning-rate multiplier to zero.
        '''
        for p in self.collect_params().values():
            p.lr_mult = 0.


class Compute(object):
    '''
    Skeleton class to manage training, testing, and retrieving outputs.
    See ``compute_op.py'' for ``flesh''.
    '''
    def __init__(self,  data, Enc, Dec,  Dis_y, args):
        '''
        Constructor for Compute.
        Creates one gluon Trainer per network (encoder, decoder, discriminator).
        Args
        ----
        data: Data object providing minibatches.
        Enc: encoder network; must expose ``model_ctx`` and ``collect_params()``.
        Dec: decoder network; must expose ``collect_params()``.
        Dis_y: discriminator network; must expose ``collect_params()``.
        args: dictionary with at least 'ndim_y' and 'optim'
          ('Adam', 'Adadelta', 'RMSprop' or 'SGD'); 'learning_rate' is read for
          every optimizer except Adadelta and 'weight_decay' for SGD.
        Returns
        -------
        Compute object
        Raises
        ------
        ValueError if args['optim'] is not one of the supported optimizers.
        '''
        self.data = data
        self.Enc = Enc
        self.Dec = Dec
        self.Dis_y = Dis_y
        self.args = args
        self.model_ctx = Enc.model_ctx
        self.ndim_y = args['ndim_y']

        weights_enc = Enc.collect_params()
        weights_dec = Dec.collect_params()
        weights_dis_y = Dis_y.collect_params()

        optim = self.args['optim']
        if optim == 'Adam':
            optim_name = 'adam'
            # encoder and decoder override beta1; the discriminator does not
            enc_dec_params = {'learning_rate': self.args['learning_rate'], 'beta1': 0.99}
            dis_params = {'learning_rate': self.args['learning_rate']}
        elif optim == 'Adadelta':
            # note: learning rate has no effect on Adadelta --> https://mxnet.incubator.apache.org/_modules/mxnet/optimizer.html#AdaDelta
            optim_name = 'adadelta'
            enc_dec_params = dis_params = {'rescale_grad': 1}
        elif optim == 'RMSprop':
            optim_name = 'rmsprop'
            enc_dec_params = dis_params = {'learning_rate': self.args['learning_rate'], 'epsilon': 1e-10, 'alpha': 0.9}
        elif optim == 'SGD':
            optim_name = 'sgd'
            enc_dec_params = dis_params = {'learning_rate': self.args['learning_rate'], 'wd': self.args['weight_decay'], 'rescale_grad': 1., 'momentum': 0.0, 'lazy_update': False}
        else:
            # fail with a clear message instead of an UnboundLocalError further down
            raise ValueError("Unsupported optimizer '{}': expected one of 'Adam', 'Adadelta', 'RMSprop', 'SGD'.".format(optim))

        # each Trainer gets its own copy of the option dict
        self.optimizer_enc = gluon.Trainer(weights_enc, optim_name, dict(enc_dec_params))
        self.optimizer_dec = gluon.Trainer(weights_dec, optim_name, dict(enc_dec_params))
        self.optimizer_dis_y = gluon.Trainer(weights_dis_y, optim_name, dict(dis_params))
        self.weights_enc = weights_enc
        self.weights_dec = weights_dec
        self.weights_dis_y = weights_dis_y

    def train_op(self):
        '''
        Trains the model using one minibatch of data.
        Skeleton: returns four None placeholders.
        '''
        return None, None, None, None

    def test_op(self, num_samples=None, num_epochs=None, reset=True, dataset='test'):
        '''
        Evaluates the model using num_samples.
        Skeleton: optionally resets the data iterators and returns four None
        placeholders.
        Args
        ----
        num_samples: integer, default None
          The number of samples to evaluate on. This is converted to
          evaluating on (num_samples // batch_size) minibatches.
        num_epochs: integer, default None
          The number of epochs to evaluate on. This used if num_samples
          is not specified. If neither is specified, defaults to 1 epoch.
          Not used by this skeleton.
        reset: bool, default True
          Whether to reset the test data index to 0 before iterating
          through and evaluating on minibatches.
        dataset: string, default 'test':
          Which dataset to evaluate on: 'valid' or 'test'.
        '''
        if num_samples is None:
            # default to the full split; the value is not used further by this skeleton
            num_samples = self.data.data[dataset].shape[0]

        if reset:
            # Reset Data to Index Zero
            self.data.force_reset_data(dataset)
            self.data.force_reset_data(dataset+'_with_labels')

        return None, None, None, None

    def get_outputs(self, num_samples=None, num_epochs=None, reset=True, dataset='test'):
        '''
        Retrieves raw outputs from model for num_samples.
        Skeleton: optionally resets the data iterators and returns six None
        placeholders.
        Args
        ----
        num_samples: integer, default None
          The number of samples to evaluate on. This is converted to
          evaluating on (num_samples // batch_size) minibatches.
        num_epochs: integer, default None
          The number of epochs to evaluate on. This used if num_samples
          is not specified. If neither is specified, defaults to 1 epoch.
          Not used by this skeleton.
        reset: bool, default True
          Whether to reset the test data index to 0 before iterating
          through and evaluating on minibatches.
        dataset: string, default 'test':
          Which dataset to evaluate on: 'valid' or 'test'.
        '''
        if num_samples is None:
            # default to the full split; the value is not used further by this skeleton
            num_samples = self.data.data[dataset].shape[0]

        if reset:
            # Reset Data to Index Zero
            self.data.force_reset_data(dataset)
            self.data.force_reset_data(dataset+'_with_labels')

        return None, None, None, None, None, None

In [4]:
# npmi_calc.py
import sys
import re
import math
import os
import time
import socket
import itertools

# Import the pickle implementation that matches the running interpreter's major
# version (``_pickle`` on Python 3, ``cPickle`` otherwise) and record that version
# in ``py_version`` so later code can choose version-specific pickle options.
py_version = 2
if sys.version_info.major == 3:
    import _pickle as pickle

    py_version = 3
else:
    import cPickle as pickle
# log_prefix = re.compile(r"^\[[^\]]+\]")
# topics_pattern = re.compile(r'[T|t]opics from epoch:([^\s]+)\s*\(num_topics:([0-9]+)\):')

# Matches the separator characters ('-' or '_') used to split a phrase into words.
phrase_split_pattern = re.compile(r'-|_')


def get_terminal_width():
    """
    Return the terminal width in columns, or 80 when it cannot be determined.

    :return: ``os.get_terminal_size().columns``, or 80 if that call fails
    """
    try:
        term_cols = os.get_terminal_size().columns
    except Exception:
        # The lookup failed (e.g. raised OSError, or the function is missing):
        # use the conventional default. Unlike a bare except, this does not
        # swallow KeyboardInterrupt / SystemExit.
        term_cols = 80
    return term_cols


def print_center(string):
    """
    Print ``string`` centred within the terminal width.

    The text is wrapped in two spaces on each side and then padded with spaces to
    the width reported by ``get_terminal_width()``; when the padding is odd the
    extra space goes on the right. Text wider than the terminal is printed unpadded.

    :param string: text to print
    """
    total_width = get_terminal_width()
    padded = '  ' + string + '  '
    slack = total_width - len(padded)
    left = slack // 2
    right = slack - left
    # a negative repeat count yields '', so over-long text is printed as-is
    print(' ' * left + padded + ' ' * right)


def print_header(string, skipline=True, symbol='#', doubleline=False):
    """
    Print ``string`` centred between full-width rules made of ``symbol``.

    :param string: header text, printed via ``print_center``
    :param skipline: if true, print an empty line before and after the header
    :param symbol: string repeated terminal-width times to form each rule
    :param doubleline: if true, print two rules above and below instead of one
    """
    rule = symbol * get_terminal_width()
    rule_repeats = 2 if doubleline else 1

    if skipline:
        print()
    for _ in range(rule_repeats):
        print(rule)
    print_center(string)
    for _ in range(rule_repeats):
        print(rule)
    if skipline:
        print()


class bcolors:
    """ANSI escape sequences for coloured / styled terminal output."""
    # reset: switches all attributes off again
    ENDC = '\033[0m'

    # text styles
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # foreground colours
    FAIL = '\033[91m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    OKBLUE = '\033[94m'
    HEADER = '\033[95m'

    # ready-made coloured status labels
    pass_str = ''.join((OKGREEN, 'Pass', ENDC))
    fail_str = ''.join((FAIL, 'Fail', ENDC))


def print_warning(*args):
    """Print the space-joined string arguments wrapped in the WARNING colour codes."""
    message = ' '.join(args)
    print('{}{}{}'.format(bcolors.WARNING, message, bcolors.ENDC))


def print_green(*args):
    """Print the space-joined string arguments wrapped in the OKGREEN colour codes."""
    message = ' '.join(args)
    print('{}{}{}'.format(bcolors.OKGREEN, message, bcolors.ENDC))


def print_blue(*args):
    """Print the space-joined string arguments wrapped in the OKBLUE colour codes."""
    message = ' '.join(args)
    print('{}{}{}'.format(bcolors.OKBLUE, message, bcolors.ENDC))


def print_error(*args):
    """Print the space-joined string arguments wrapped in the FAIL colour codes."""
    message = ' '.join(args)
    print('{}{}{}'.format(bcolors.FAIL, message, bcolors.ENDC))


class RefCorpus:
    """Reference corpus (vocabulary + inverted index) used for PMI/NPMI scoring.

    Both inputs are pickle files expected in the current user's home
    directory. Pickle files can execute arbitrary code when loaded, so they
    must come from a trusted source.
    """

    def __init__(self):
        """Record the locations of the two pickle files; nothing is loaded yet."""
        user_home = os.path.expanduser('~')
        self.wiki_invind_file = os.path.join(user_home, 'wikipedia.inv_index.pkl')
        self.wiki_dict_file = os.path.join(user_home, 'wikipedia.dict.pkl')

    def load_corpus(self):
        """Read both pickle files and expose their content as attributes.

        Sets ``corpus_vocab``, ``inv_index`` and ``corpus_size`` on the
        instance and prints the vocabulary size.
        """
        print_header('Loading reference corpus')

        # Only the Python 3 unpickler accepts (and needs) an explicit encoding.
        load_kwargs = {'encoding': 'utf-8'} if py_version == 3 else {}
        with open(self.wiki_dict_file, 'rb') as vocab_fh:
            vocab = pickle.load(vocab_fh, **load_kwargs)

        print(len(vocab))

        with open(self.wiki_invind_file, 'rb') as index_fh:
            index, n_docs = pickle.load(index_fh)

        self.corpus_vocab = vocab  # looked up by word; values are keys into inv_index
        self.inv_index = index  # word_id: [doc_id1, doc_id2...]
        self.corpus_size = n_docs  # number of documents


def get_docs_from_index(w, corpus_vocab, inv_index):
    """Return the documents that contain ``w`` according to the inverted index.

    Tokens containing '-' or '_' are treated as phrases: the result is the
    intersection of the documents of their component words. This handles the
    phrases of the NYT corpus, which would otherwise largely count as
    out-of-vocabulary. Unknown words yield an empty set.
    """
    if phrase_split_pattern.search(w):
        return intersecting_docs(w, corpus_vocab, inv_index)
    if w in corpus_vocab:
        return inv_index[corpus_vocab[w]]
    return set()


def intersecting_docs(phrase, corpus_vocab, inverted_index):
    """Return the set of documents that contain every word of ``phrase``.

    ``phrase`` is split on '-' and '_' (``phrase_split_pattern``). If any
    component word is missing from ``corpus_vocab`` the phrase is considered
    absent from the corpus and an empty set is returned.

    The result is always a fresh set; the posting lists in
    ``inverted_index`` are never modified.
    """
    words = re.split(phrase_split_pattern, phrase)
    intersect_docs = None
    for word in words:
        if word not in corpus_vocab:
            # if any of the words in the phrase is not the corpus, the phrase also is not in the corpus
            return set()
        word_docs = inverted_index[corpus_vocab[word]]
        if intersect_docs is None:
            # Seed from the first word only. Testing ``not intersect_docs``
            # here would re-seed after the intersection became empty and
            # return the documents of a later word alone.
            intersect_docs = set(word_docs)
        else:
            intersect_docs.intersection_update(word_docs)
        if not intersect_docs:
            # Once empty, the intersection can never grow again.
            return intersect_docs
    return intersect_docs if intersect_docs is not None else set()


def get_pmi(docs_1, docs_2, corpus_size):
    """Compute PMI and normalised PMI of two words from their document sets.

    Args:
        docs_1: non-empty set of ids of the documents containing word 1.
        docs_2: non-empty set of ids of the documents containing word 2.
        corpus_size: total number of documents in the reference corpus.

    Returns:
        ``(pmi, npmi)``. Both are 0.0 when the words never co-occur.
        ``npmi`` is ``pmi / -log p(w1, w2)``; when the words co-occur in
        every document that denominator is zero, and ``npmi`` is reported
        as 0.0 by convention instead of raising ``ZeroDivisionError``.

    Raises:
        AssertionError: if either document set is empty.
    """
    assert len(docs_1)
    assert len(docs_2)
    small, big = (docs_1, docs_2) if len(docs_1) < len(docs_2) else (docs_2, docs_1)
    n_common = len(small.intersection(big))
    pmi = 0.0
    npmi = 0.0
    if n_common:
        pmi = math.log(corpus_size) + math.log(n_common) - math.log(len(docs_1)) - math.log((len(docs_2)))
        log_joint = math.log(n_common) - math.log(corpus_size)  # log p(w1, w2)
        if log_joint:
            npmi = -1 * pmi / log_joint

    return pmi, npmi


def get_idf(w, inv_index, corpus_vocab, corpus_size):
    """Return the inverse document frequency ``log(corpus_size / (df + 1))`` of ``w``.

    ``df`` is the number of documents returned by ``get_docs_from_index``;
    the added 1.0 keeps the ratio finite for words with no documents.
    """
    doc_freq = len(get_docs_from_index(w, corpus_vocab, inv_index))
    smoothed_freq = doc_freq + 1.0
    return math.log(corpus_size / smoothed_freq)


def test_pmi(inv_index, corpus_vocab, corpus_size):
    """Sanity-check PMI on hand-picked word pairs against the loaded corpus.

    For every pair both words must have documents in the index and their
    PMI must be positive; the PMI of each pair is printed. Finally the first
    pair must score higher than the second.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    word_pairs = (
        ("apple", "ipad"),
        ("monkey", "business"),
        ("white", "house"),
        ("republican", "democrat"),
        ("china", "usa"),
        ("president", "bush"),
        ("president", "george_bush"),
        ("president", "george-bush"),
    )
    pmis = []
    for first, second in word_pairs:
        first_docs = get_docs_from_index(first, corpus_vocab, inv_index)
        second_docs = get_docs_from_index(second, corpus_vocab, inv_index)
        assert len(first_docs)
        assert len(second_docs)
        pmi, _ = get_pmi(first_docs, second_docs, corpus_size)
        assert pmi > 0.0
        print("Testing PMI: w1: {}  w2: {}  pmi: {}".format(first, second, pmi))
        pmis.append(pmi)
    # pmi(apple, ipad) > pmi(monkey, business)
    assert pmis[0] > pmis[1]



def get_topic_pmi(wlist, corpus_vocab, inv_index, corpus_size, max_words_per_topic):
    """Average PMI and NPMI over all word pairs of one topic.

    Args:
        wlist: the topic's words, most important first.
        corpus_vocab: reference-corpus vocabulary, passed to ``get_docs_from_index``.
        inv_index: reference-corpus inverted index, passed to ``get_docs_from_index``.
        corpus_size: number of documents in the reference corpus.
        max_words_per_topic: only the first this-many words of ``wlist`` are used.

    Returns:
        ``(pmi, npmi, num_pairs)`` where ``num_pairs`` counts the pairs in
        which both words have at least one document; the averages are taken
        over those pairs and are 0.0 when there is none.
    """
    num_pairs = 0
    pmi = 0.0
    npmi = 0.0
    # compute topic coherence only for the first max_words_per_topic words in each topic.
    wlist = wlist[:max_words_per_topic]
    if len(wlist) < 2:
        # No pair can be formed, so no index lookup is needed.
        return pmi, npmi, num_pairs

    # Look every word up once instead of once per pair it takes part in;
    # phrase lookups in particular intersect several posting lists.
    doc_sets = [get_docs_from_index(w, corpus_vocab, inv_index) for w in wlist]
    for w1docs, w2docs in itertools.combinations(doc_sets, 2):
        if len(w1docs) and len(w2docs):
            word_pair_pmi, word_pair_npmi = get_pmi(w1docs, w2docs, corpus_size)
            pmi += word_pair_pmi
            npmi += word_pair_npmi
            num_pairs += 1
    if num_pairs:
        pmi /= num_pairs
        npmi /= num_pairs
    return pmi, npmi, num_pairs


def calc_pmi_for_all_topics(topic_dict, corpus_vocab, inv_index, corpus_size):
    """Score every topic in ``topic_dict`` with its average PMI and NPMI.

    Args:
        topic_dict: maps a topic key to that topic's word list; the full
            list of each topic is used.
        corpus_vocab, inv_index, corpus_size: reference corpus data, passed
            through to ``get_topic_pmi``.

    Returns:
        ``(pmi_dict, npmi_dict)``, both keyed like ``topic_dict``.
    """
    pmi_dict, npmi_dict = {}, {}
    for topic_key in topic_dict.keys():
        words = topic_dict[topic_key]
        topic_pmi, topic_npmi, _ = get_topic_pmi(
            words, corpus_vocab, inv_index, corpus_size, len(words))
        pmi_dict[topic_key] = topic_pmi
        npmi_dict[topic_key] = topic_npmi

    return pmi_dict, npmi_dict


def launch_socket(port=1234, ref_corpus=None):
    """Run a blocking TCP server on localhost that answers PMI/NPMI requests.

    The listening socket is opened first and the reference corpus is loaded
    afterwards (unless ``ref_corpus`` is given). A quick PMI self-test is
    run, then connections are served one at a time, forever.

    Request: one pickled object, either a topic dict or a string
    ``'<pickle file>:<index>'`` selecting ``['Topic Words'][index]`` from
    that file. Response: a pickled
    ``{'pmi_dict': ..., 'npmi_dict': ...}``, after which the connection is
    closed.

    SECURITY: requests are unpickled, which can execute arbitrary code. The
    server binds to the loopback interface only; never expose it to
    untrusted clients.

    Args:
        port: TCP port to listen on.
        ref_corpus: an already loaded ``RefCorpus``; when ``None`` a new one
            is created and loaded.
    """
    print_header('Launching socket at port {}'.format(port))
    serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        # Loopback only (not socket.gethostname()).
        host = socket.gethostbyname('localhost')
        serversocket.bind((host, port))

        # queue up to 5 requests
        serversocket.listen(5)
        print_header('Socket ready port {}'.format(port))

        # set up socket first, then load corpus
        if ref_corpus is None:
            ref_corpus = RefCorpus()
            ref_corpus.load_corpus()

        test_pmi(ref_corpus.inv_index, ref_corpus.corpus_vocab, ref_corpus.corpus_size)
        print_header('Test done')

        while True:
            clientsocket, addr = serversocket.accept()
            start = time.time()
            print("Got a connection from %s" % str(addr))

            served = False
            try:
                # The client keeps its sending side open while it waits for
                # the answer, so reading until EOF would block forever. The
                # request is therefore taken from a single recv(); requests
                # that do not arrive in one piece fail to unpickle and are
                # reported below.
                data = clientsocket.recv(1024 * 1024)
                received = pickle.loads(data)
                print(received)
                if isinstance(received, str):
                    # rsplit keeps file names that contain ':' intact.
                    filename, i = received.rsplit(':', 1)
                    with open(filename, 'rb') as f:
                        topic_dict = pickle.load(f)['Topic Words'][int(i)]
                else:
                    topic_dict = received
                print('received data')
                print('Time elapsed: {:.2f}s'.format(time.time() - start))
                print(topic_dict)
                pmi_dict, npmi_dict = calc_pmi_for_all_topics(topic_dict,
                                                              ref_corpus.corpus_vocab,
                                                              ref_corpus.inv_index,
                                                              ref_corpus.corpus_size)

                print('completed calculation')
                print('Time elapsed: {:.2f}s'.format(time.time() - start))
                result_dict = {'pmi_dict': pmi_dict, 'npmi_dict': npmi_dict}

                # sendall: a plain send() may transmit only part of the payload.
                clientsocket.sendall(pickle.dumps(result_dict))
                served = True
            except Exception as e:
                print('Error occured when receiving packet')
                print(e)
            finally:
                # Closing signals end-of-response to the client on every path.
                clientsocket.close()

            if served:
                print('Connection closed')
                print('Time elapsed: {:.2f}s'.format(time.time() - start))
    finally:
        # Release the port when the server stops (e.g. on KeyboardInterrupt).
        serversocket.close()


def request_pmi(topic_dict, port=1234):
    """Ask the local PMI server (see ``launch_socket``) to score ``topic_dict``.

    The topic dict is sent pickled; the pickled reply is read until the
    server closes the connection.

    Args:
        topic_dict: maps a topic key to that topic's word list.
        port: TCP port the server listens on.

    Returns:
        ``(pmi_dict, npmi_dict)`` keyed like ``topic_dict``. If the request
        fails for any reason an error is printed and every score is 0.
    """
    try:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            # The server binds to the loopback address of 'localhost';
            # socket.gethostname() may resolve to a different address and
            # then the connection is refused.
            host = socket.gethostbyname('localhost')
            s.connect((host, port))

            # sendall: a plain send() may transmit only part of the payload.
            s.sendall(pickle.dumps(topic_dict))

            # The server closes the connection once the reply is complete.
            data = []
            while True:
                packet = s.recv(4096)
                if not packet:
                    break
                data.append(packet)
        finally:
            # Close on every path so a failed request does not leak the socket.
            s.close()

        # Reply comes from the trusted local server; pickle is unsafe on untrusted data.
        res_dict = pickle.loads(b"".join(data))
        pmi_dict = res_dict['pmi_dict']
        npmi_dict = res_dict['npmi_dict']
    except Exception:
        # Best effort: scoring must not abort the caller.
        print_error('Failed to run NPMI calc, NPMI and PMI set to 0.0')
        pmi_dict = dict()
        npmi_dict = dict()
        for k in topic_dict:
            pmi_dict[k] = 0
            npmi_dict[k] = 0

    return pmi_dict, npmi_dict

In [6]:
# compute_op.py
import numpy as np

from tqdm import tqdm

from mxnet import nd, autograd, gluon, io

# from core import Compute
# from utils import to_numpy, stack_numpy
# from diff_sample import normal
import os


def mmd_loss(x, y, ctx_model, t=0.1, kernel='diffusion'):
    '''
    Computes an MMD-style discrepancy between two batches of samples.

    With the default kernel the samples are compared through the
    information diffusion kernel exp(-arccos(<sqrt(x), sqrt(y)>)^2 / t).
    With kernel='tv' the pairwise L1 distance is used instead of a kernel.

    :param x: 2-D NDArray, batch_size x latent dimension
    :param y: 2-D NDArray with the same latent dimension; its batch size may
              differ from that of x
    :param ctx_model: mxnet context on which helper arrays are created
    :param t: temperature of the diffusion kernel (ignored for kernel='tv')
    :param kernel: 'tv' for the L1 variant, anything else selects the
                   diffusion kernel
    :return: NDArray holding sum_xx + sum_yy - sum_xy

    NOTE: the within-batch terms divide by n * (n - 1), so a batch with a
    single sample produces a division by zero.
    '''
    eps = 1e-6
    n, _ = x.shape
    m = y.shape[0]
    if kernel == 'tv':
        # NOTE(review): the within-batch sums visit each unordered pair once
        # but are divided by the number of ordered pairs, and a distance is
        # combined with the signs of a similarity kernel. Kept as in the
        # original -- confirm the intended scaling and sign.
        sum_xx = nd.zeros(1, ctx=ctx_model)
        for i in range(n):
            for j in range(i+1, n):
                sum_xx = sum_xx + nd.norm(x[i] - x[j], ord=1)
        sum_xx = sum_xx / (n * (n-1))

        sum_yy = nd.zeros(1, ctx=ctx_model)
        for i in range(m):
            for j in range(i+1, m):
                sum_yy = sum_yy + nd.norm(y[i] - y[j], ord=1)
        sum_yy = sum_yy / (m * (m-1))

        sum_xy = nd.zeros(1, ctx=ctx_model)
        for i in range(n):
            for j in range(m):
                sum_xy = sum_xy + nd.norm(x[i] - y[j], ord=1)
        # Normalise the cross term itself; the original divided sum_yy a
        # second time here and left sum_xy as a raw sum.
        sum_xy = sum_xy / (n * m)
    else:
        qx = nd.sqrt(nd.clip(x, eps, 1))
        qy = nd.sqrt(nd.clip(y, eps, 1))
        xx = nd.dot(qx, qx, transpose_b=True)
        yy = nd.dot(qy, qy, transpose_b=True)
        xy = nd.dot(qx, qy, transpose_b=True)

        def diffusion_kernel(a, tmpt):
            # The normalising constant (4 * pi * t)^(-dim / 2) is omitted.
            return nd.exp(- nd.square(nd.arccos(a)) / tmpt)

        # Clip below 1 before arccos is applied to the inner products.
        k_xx = diffusion_kernel(nd.clip(xx, 0, 1-eps), t)
        k_yy = diffusion_kernel(nd.clip(yy, 0, 1-eps), t)
        k_xy = diffusion_kernel(nd.clip(xy, 0, 1-eps), t)

        # Exclude each sample's similarity with itself. The masks are sized
        # per batch so that x and y may contain different numbers of samples.
        off_diag_x = 1 - nd.eye(n, ctx=ctx_model)
        off_diag_y = 1 - nd.eye(m, ctx=ctx_model)
        sum_xx = (k_xx * off_diag_x).sum() / (n * (n-1))
        sum_yy = (k_yy * off_diag_y).sum() / (m * (m-1))
        sum_xy = 2 * k_xy.sum() / (n * m)
    return sum_xx + sum_yy - sum_xy


class Unsupervised(Compute):
    '''
    Class to manage training, testing, and
    retrieving outputs.
    '''
    def __init__(self, data, Enc, Dec,  Dis_y, args):
        '''
        Constructor. All arguments are forwarded unchanged to the base
        class (Compute) constructor.

        Args
        ----
        data: data provider; the methods of this class call
          data.get_documents(...), data.force_reset_data(...) and read
          data.data[...].
        Enc: encoder network, called as Enc(docs).
        Dec: decoder network, called on the softmax of the encoder output.
        Dis_y: discriminator network used by the adversarial training step.
        args: dict-like configuration (e.g. 'batch_size', 'recon_alpha').

        NOTE(review): the attributes used by the methods below (self.Enc,
        self.args, self.model_ctx, self.ndim_y, self.optimizer_*, ...) are
        presumably set by the base class constructor -- confirm there.
        '''
        super(Unsupervised, self).__init__(data, Enc, Dec, Dis_y, args)

    def unlabeled_train_op_mmd_combine(self, update_enc=True):
        '''
        Runs one training step of the MMD model on a batch of training
        documents: reconstruction loss, optionally an MMD term between the
        encoder output and Dirichlet samples, and optionally a squared
        penalty on the pre-softmax encoder output. Encoder and decoder
        optimizers are stepped once.

        Args
        ----
        update_enc: bool, default True
          Not used by this method.

        Returns
        -------
        Tuple of
          reconstruction loss (float),
          MMD loss (float, 0.0 when args['adverse'] is falsy),
          latent_max (numpy array): per latent dimension, the number of
            documents whose pre-softmax encoder output peaks there, divided
            by args['batch_size'],
          latent_entropy (float): mean entropy of the latent vectors fed to
            the decoder,
          latent_v (numpy array): mean of those latent vectors over the batch,
          dirich_entropy (float): mean entropy of the Dirichlet samples.
        '''
        batch_size = self.args['batch_size']
        model_ctx = self.model_ctx
        eps = 1e-10

        # Retrieve data
        docs = self.data.get_documents(key='train')

        # Samples from the Dirichlet prior; batch_size comes from args, not from docs.
        y_true = np.random.dirichlet(np.ones(self.ndim_y) * self.args['dirich_alpha'], size=batch_size)
        y_true = nd.array(y_true, ctx=model_ctx)

        with autograd.record():
            ### reconstruction phase ###
            y_onehot_u = self.Enc(docs)
            y_onehot_u_softmax = nd.softmax(y_onehot_u)
            if self.args['latent_noise'] > 0:
                # Blend the latent vectors with fresh Dirichlet noise before decoding.
                y_noise = np.random.dirichlet(np.ones(self.ndim_y) * self.args['dirich_alpha'], size=batch_size)
                y_noise = nd.array(y_noise, ctx=model_ctx)
                y_onehot_u_softmax = (1 - self.args['latent_noise']) * y_onehot_u_softmax + self.args['latent_noise'] * y_noise
            x_reconstruction_u = self.Dec(y_onehot_u_softmax)

            # Despite the name, `logits` holds log-probabilities (log_softmax output).
            logits = nd.log_softmax(x_reconstruction_u)
            loss_reconstruction = nd.mean(nd.sum(- docs * logits, axis=1))
            loss_total = loss_reconstruction * self.args['recon_alpha']

            ### mmd phase ###
            if self.args['adverse']:
                # Second encoder pass on the same documents, without the noise blending above.
                y_fake = self.Enc(docs)
                y_fake = nd.softmax(y_fake)
                loss_mmd = mmd_loss(y_true, y_fake, ctx_model=model_ctx, t=self.args['kernel_alpha'])
                loss_total = loss_total + loss_mmd

            if self.args['l2_alpha'] > 0:
                # Squared penalty on the pre-softmax encoder output.
                loss_total = loss_total + self.args['l2_alpha'] * nd.mean(nd.sum(nd.square(y_onehot_u), axis=1))

            loss_total.backward()

        self.optimizer_enc.step(1)
        self.optimizer_dec.step(1)  # self.m.args['batch_size']

        # Monitoring statistics only; nothing below affects the parameter update.
        latent_max = nd.zeros(self.args['ndim_y'], ctx=model_ctx)
        for max_ind in nd.argmax(y_onehot_u, axis=1):
            latent_max[max_ind] += 1.0
        latent_max /= batch_size
        latent_entropy = nd.mean(nd.sum(- y_onehot_u_softmax * nd.log(y_onehot_u_softmax + eps), axis=1))
        latent_v = nd.mean(y_onehot_u_softmax, axis=0)
        dirich_entropy = nd.mean(nd.sum(- y_true * nd.log(y_true + eps), axis=1))

        if self.args['adverse']:
            loss_mmd_return = loss_mmd.asscalar()
        else:
            loss_mmd_return = 0.0
        return nd.mean(loss_reconstruction).asscalar(), loss_mmd_return, latent_max.asnumpy(), latent_entropy.asscalar(), latent_v.asnumpy(), dirich_entropy.asscalar()


    def retrain_enc(self, l2_alpha=0.1):
        '''
        Runs one encoder-only training step on a batch of training documents.

        The loss is the reconstruction loss plus l2_alpha times the mean
        L1 norm (ord=1, despite the parameter name) of the pre-softmax
        encoder output. Only the encoder optimizer is stepped.

        Returns
        -------
        The total loss of the step as a Python scalar.
        '''
        batch = self.data.get_documents(key='train')
        with autograd.record():
            ### reconstruction phase ###
            enc_out = self.Enc(batch)
            reconstruction = self.Dec(nd.softmax(enc_out))
            log_probs = nd.log_softmax(reconstruction)
            nll = nd.mean(nd.sum(- batch * log_probs, axis=1))

            ### penalty on the encoder output ###
            penalty = nd.mean(nd.norm(enc_out, ord=1, axis=1))
            loss = nll + l2_alpha * penalty
            loss.backward()

        self.optimizer_enc.step(1)
        return loss.asscalar()


    def unlabeled_train_op_adv_combine_add(self, update_enc=True):
        '''
        Runs one training step of the GAN model on a batch of training
        documents. The reconstruction loss and, when args['adverse'] is
        truthy, the discriminator and generator losses are added into a
        single objective; encoder, decoder and discriminator optimizers are
        all stepped after one backward pass.

        Args
        ----
        update_enc: bool, default True
          Not used by this method.

        Returns
        -------
        Tuple of: mean discriminator loss, mean generator loss, mean
        reconstruction loss, z-discriminator confidence on true / fake
        samples (never updated here, always 0), y-discriminator confidence
        on true / fake samples, latent_max (numpy array), latent entropy,
        latent_v (numpy array), entropy of the Dirichlet samples.
        '''
        batch_size = self.args['batch_size']
        model_ctx = self.model_ctx
        eps = 1e-10
        ##########################
        ### unsupervised phase ###
        ##########################
        # Retrieve data
        docs = self.data.get_documents(key='train')

        # Discriminator targets: class 0 for prior samples, class 1 for encoder outputs.
        class_true = nd.zeros(batch_size, dtype='int32', ctx=model_ctx)
        class_fake = nd.ones(batch_size, dtype='int32', ctx=model_ctx)
        loss_reconstruction = nd.zeros((1,), ctx=model_ctx)

        ### adversarial phase ###
        # Defaults reported when args['adverse'] is falsy; the two z-confidences are never overwritten.
        discriminator_z_confidence_true = nd.zeros(shape=(1,), ctx=model_ctx)
        discriminator_z_confidence_fake = nd.zeros(shape=(1,), ctx=model_ctx)
        discriminator_y_confidence_true = nd.zeros(shape=(1,), ctx=model_ctx)
        discriminator_y_confidence_fake = nd.zeros(shape=(1,), ctx=model_ctx)
        loss_discriminator = nd.zeros(shape=(1,), ctx=model_ctx)
        dirich_entropy = nd.zeros(shape=(1,), ctx=model_ctx)

        ### generator phase ###
        loss_generator = nd.zeros(shape=(1,), ctx=model_ctx)

        ### reconstruction phase ###
        with autograd.record():
            y_u = self.Enc(docs)
            y_onehot_u_softmax = nd.softmax(y_u)
            x_reconstruction_u = self.Dec(y_onehot_u_softmax)

            # Despite the name, `logits` holds log-probabilities (log_softmax output).
            logits = nd.log_softmax(x_reconstruction_u)
            # Per-document loss (not averaged over the batch here).
            loss_reconstruction = nd.sum(- docs * logits, axis=1)
            loss_total = loss_reconstruction * self.args['recon_alpha']

            if self.args['adverse']: #and np.random.rand()<0.8:
                # Prior samples; batch_size comes from args, not from docs.
                y_true = np.random.dirichlet(np.ones(self.ndim_y) * self.args['dirich_alpha'], size=batch_size)
                y_true = nd.array(y_true, ctx=model_ctx)
                dy_true = self.Dis_y(y_true)
                dy_fake = self.Dis_y(y_onehot_u_softmax)
                # Mean probability the discriminator assigns to the correct class.
                discriminator_y_confidence_true = nd.mean(nd.softmax(dy_true)[:, 0])
                discriminator_y_confidence_fake = nd.mean(nd.softmax(dy_fake)[:, 1])
                softmaxCEL = gluon.loss.SoftmaxCrossEntropyLoss()
                loss_discriminator = softmaxCEL(dy_true, class_true) + \
                                       softmaxCEL(dy_fake, class_fake)
                # Generator objective: encoder outputs should be classified as prior samples.
                loss_generator = softmaxCEL(dy_fake, class_true)
                # NOTE(review): discriminator and generator losses are summed into one
                # objective and every network receives gradients from both -- confirm
                # this is intended rather than alternating updates.
                loss_total = loss_total + loss_discriminator + loss_generator
                dirich_entropy = nd.mean(nd.sum(- y_true * nd.log(y_true + eps), axis=1))

        loss_total.backward()

        self.optimizer_enc.step(batch_size)
        self.optimizer_dec.step(batch_size)
        self.optimizer_dis_y.step(batch_size)

        # Monitoring statistics only; nothing below affects the parameter update.
        latent_max = nd.zeros(self.args['ndim_y'], ctx=model_ctx)
        for max_ind in nd.argmax(y_onehot_u_softmax, axis=1):
            latent_max[max_ind] += 1.0
        latent_max /= batch_size
        latent_entropy = nd.mean(nd.sum(- y_onehot_u_softmax * nd.log(y_onehot_u_softmax + eps), axis=1))
        latent_v = nd.mean(y_onehot_u_softmax, axis=0)

        return nd.mean(loss_discriminator).asscalar(), nd.mean(loss_generator).asscalar(), nd.mean(loss_reconstruction).asscalar(), \
               nd.mean(discriminator_z_confidence_true).asscalar(), nd.mean(discriminator_z_confidence_fake).asscalar(), \
               nd.mean(discriminator_y_confidence_true).asscalar(), nd.mean(discriminator_y_confidence_fake).asscalar(), \
               latent_max.asnumpy(), latent_entropy.asscalar(), latent_v.asnumpy(), dirich_entropy.asscalar()


    def test_synthetic_op(self):
        '''
        Encodes the training set batch by batch and returns the softmax of
        the encoder outputs.

        Returns
        -------
        NDArray with ceil(num_train_samples / batch_size) * batch_size rows
        and self.ndim_y columns; rows of batch b occupy
        [b * batch_size, (b + 1) * batch_size).
        '''
        batch_size = self.args['batch_size']
        dataset = 'train'
        num_samps = self.data.data[dataset].shape[0]
        batches = int(np.ceil(num_samps / batch_size))
        enc_out = nd.zeros(shape=(batches * batch_size, self.ndim_y))
        for batch in range(batches):
            # 1. Retrieve data
            # Fetched unconditionally, like every other method of this class.
            # The original only fetched when args['data_source'] == 'Ian', so
            # any other value left `docs` unassigned and the next line failed.
            docs = self.data.get_documents(key=dataset)
            # 2. Encode
            y_onehot_u = self.Enc(docs)
            y_onehot_softmax = nd.softmax(y_onehot_u)
            enc_out[batch*batch_size:(batch+1)*batch_size, :] = y_onehot_softmax

        return enc_out

    def test_op(self, num_samples=None, num_epochs=None, reset=True, dataset='test'):
        '''
        Evaluates the model using num_samples.
        Args
        ----
        num_samples: integer, default None
          The number of samples to evaluate on. This is converted to
          evaluating on ceil(num_samples / batch_size) minibatches.
        num_epochs: integer, default None
          The number of epochs to evaluate on. This used if num_samples
          is not specified. If neither is specified, defaults to 1 epoch.
        reset: bool, default True
          Whether to reset the test data index to 0 before iterating
          through and evaluating on minibatches.
        dataset: string, default 'test':
          Which dataset to evaluate on: 'valid' or 'test'.
        Returns
        -------
        u_loss: float or 'NA'
          Mean reconstruction loss over the unlabeled minibatches, or the
          string 'NA' when there is no unlabeled data for `dataset`.
        l_loss: float
          Mean cross-entropy loss over the labeled minibatches (0.0 when
          there is no labeled data).
        l_acc: float
          Mean, over the labeled minibatches, of the label value found at
          the predicted class (0.0 when there is no labeled data).
        '''
        batch_size = self.args['batch_size']

        if num_samples is None and num_epochs is None:
            # assume full dataset evaluation
            num_epochs = 1

        if reset:
            # Reset Data to Index Zero
            if self.data.data[dataset] is not None:
                self.data.force_reset_data(dataset)
            if self.data.data[dataset + '_with_labels'] is not None:
                self.data.force_reset_data(dataset+'_with_labels')

        # Unlabeled Data
        u_loss = 'NA'
        if self.data.data[dataset] is not None:
            u_loss = 0
            if num_samples is None:
                num_samps = self.data.data[dataset].shape[0] * num_epochs
            else:
                num_samps = num_samples
            batches = int(np.ceil(num_samps / batch_size))
            batch_iter = range(batches)
            if batches > 1: batch_iter = tqdm(batch_iter, desc='unlabeled')
            for batch in batch_iter:
                # 1. Retrieve data
                docs = self.data.get_documents(key=dataset)

                # 2. Compute loss
                y_u = self.Enc(docs)
                y_onehot_u_softmax = nd.softmax(y_u)
                x_reconstruction_u = self.Dec(y_onehot_u_softmax)

                logits = nd.log_softmax(x_reconstruction_u)
                loss_recon_unlabel = nd.sum(- docs * logits, axis=1)

                # 3. Convert to numpy
                u_loss += nd.mean(loss_recon_unlabel).asscalar()
            if batches:
                # Guard against num_samples == 0 (no minibatch evaluated).
                u_loss /= batches

        # Labeled Data
        l_loss = 0.0
        l_acc = 0.0
        if self.data.data[dataset+'_with_labels'] is not None:
            l_loss = 0
            if num_samples is None:
                num_samps = self.data.data[dataset+'_with_labels'].shape[0] * num_epochs
            else:
                num_samps = num_samples
            batches = int(np.ceil(num_samps / batch_size))
            batch_iter = range(batches)
            if batches > 1: batch_iter = tqdm(batch_iter, desc='labeled')
            softmaxCEL = gluon.loss.SoftmaxCrossEntropyLoss(sparse_label=False)
            for batch in batch_iter:
                # 1. Retrieve data
                labeled_docs, labels = self.data.get_documents(key=dataset+'_with_labels', split_on=self.data.data_dim)
                # 2. Compute loss
                # Encode the labeled documents just fetched. The original
                # encoded `docs`, i.e. the last unlabeled batch (or an
                # undefined name when there is no unlabeled data).
                y_u = self.Enc(labeled_docs)
                y_onehot_u_softmax = nd.softmax(y_u)
                class_pred = nd.argmax(y_onehot_u_softmax, axis=1)
                # Label value at the predicted class for every document.
                l_a = labels[list(range(labels.shape[0])), class_pred]
                l_acc += nd.mean(l_a).asscalar()
                # Normalise each label row to sum to 1 before using it as a soft target.
                labels = labels / nd.sum(labels, axis=1, keepdims=True)
                # NOTE(review): the loss is given softmax outputs and, with its
                # default from_logits=False, applies log_softmax again --
                # confirm whether y_u was meant to be passed instead.
                l_l = softmaxCEL(y_onehot_u_softmax, labels)

                # 3. Convert to numpy
                l_loss += nd.mean(l_l).asscalar()
            if batches:
                l_loss /= batches
                l_acc /= batches

        return u_loss, l_loss, l_acc


    def save_latent(self, saveto):
        '''
        Encodes the train, validation and test sets and saves the
        (pre-softmax) encoder outputs as .npy files in `saveto`.

        The files are named args['domain'] + 'train_latent.npy',
        'val_latent.npy' and 'test_latent.npy'. Data is read in order in
        batches of args['batch_size'] and an incomplete last batch is
        discarded, so each file holds a whole number of batches.

        If the data loaders cannot be created (for instance because there is
        no validation or test set) a message is printed and nothing is saved.
        '''
        before_softmax = True
        batch_size = self.args['batch_size']
        # (key in self.data.data, suffix of the output file name)
        splits = [('train', 'train_latent.npy'),
                  ('valid', 'val_latent.npy'),
                  ('test', 'test_latent.npy')]

        # Build every loader up front so that nothing is written when any
        # split is unavailable.
        try:
            use_gluon_loader = type(self.data.data['train']) is np.ndarray
            loaders = []
            for key, _ in splits:
                split_data = self.data.data[key]
                if use_gluon_loader:
                    dataset = gluon.data.dataset.ArrayDataset(split_data)
                    loader = gluon.data.DataLoader(dataset, batch_size, shuffle=False, last_batch='discard')
                else:
                    loader = io.NDArrayIter(data={'data': split_data}, batch_size=batch_size,
                                            shuffle=False, last_batch_handle='discard')
                loaders.append(loader)
        except Exception:
            print("Loading error during save_latent. Probably caused by not having validation or test set!")
            return

        for (key, suffix), loader in zip(splits, loaders):
            num_rows = self.data.data[key].shape[0]
            latent = self._encode_latent_split(loader, num_rows, before_softmax)
            np.save(os.path.join(saveto, self.args['domain'] + suffix), latent)

    def _encode_latent_split(self, loader, num_rows, before_softmax=True):
        '''Encodes all batches of `loader`; returns only the rows actually written.'''
        batch_size = self.args['batch_size']
        latent = np.zeros((num_rows, self.ndim_y))
        # Counted explicitly: a loader that yields no batch must produce an
        # empty result rather than rely on a stale or undefined loop index.
        rows_written = 0
        for i, data in enumerate(loader):
            if type(data) is io.DataBatch:
                data = data.data[0].as_in_context(self.model_ctx)
            else:
                data = data.as_in_context(self.model_ctx)
            output = self.Enc(data)
            if not before_softmax:
                output = nd.softmax(output)
            latent[i*batch_size:(i+1)*batch_size] = output.asnumpy()
            rows_written = (i+1) * batch_size
        # Drop the rows of the discarded incomplete last batch.
        return latent[:rows_written]

In [8]:
# run.py
import os
import shutil
import argparse
import datetime
import pickle
import time

import numpy as np
import matplotlib as mpl
mpl.use('Agg')

import sys
sys.path.append('../')

import mxnet as mx
mx.random.seed(int(time.time()))

# from utils import gpu_helper, gpu_exists, calc_topic_uniqueness, get_topic_words_decoder_weights, request_pmi, print_topic_with_scores, print_topics


from tqdm import tqdm

def parse_args():
    '''
    Parse the command line, resolve the classes to train with and prepare the results directory.

    Side effects
    ------------
    * Imports the domain class selected by --domain, the model classes selected by --model and
      the `Unsupervised` compute class.
    * Creates a time-stamped results directory (with weights/ sub-directories) under --saveto,
      copies the source files of the experiment into it and writes `args.txt` and `args.p` there.
    * Rewrites args['saveto'] to that time-stamped directory (it ends with '/') and appends the
      device tag ' (gpuN)' or ' (cpu)' to args['description'].

    Returns
    -------
    tuple
      (Compute, Domain, Encoder, Decoder, Discriminator_y, args) where `args` is the dict of
      parsed options.

    Raises
    ------
    NotImplementedError
      If --domain or --model is not one of the supported names.
    '''
    parser = argparse.ArgumentParser(description='Training WAE in MXNet')
    parser.add_argument('-dom','--domain', type=str, default='twenty_news', help='domain to run', required=False)
    parser.add_argument('-data','--data_path', type=str, default='', help='file path for dataset', required=False)
    parser.add_argument('-max_labels','--max_labels', type=int, default=100, help='max number of topics to specify as labels for a single training document', required=False)
    parser.add_argument('-max_labeled_samples','--max_labeled_samples', type=int, default=10, help='max number of labeled samples per topic', required=False)
    parser.add_argument('-label_seed','--label_seed', type=lambda x: int(x) if x != 'None' else None, default=None, help='random seed for subsampling the labeled dataset', required=False)
    parser.add_argument('-mod','--model', type=str, default='dirichlet', help='model to use', required=False)
    parser.add_argument('-desc','--description', type=str, default='', help='description for the experiment', required=False)
    parser.add_argument('-alg','--algorithm', type=str, default='standard', help='algorithm to use for training: standard', required=False)
    parser.add_argument('-bs','--batch_size', type=int, default=256, help='batch_size for training', required=False)
    parser.add_argument('-opt','--optim', type=str, default='Adam', help='encoder training algorithm', required=False)
    parser.add_argument('-lr','--learning_rate', type=float, default=1e-4, help='learning rate', required=False)
    parser.add_argument('-l2','--weight_decay', type=float, default=0., help='weight decay', required=False)
    parser.add_argument('-e_nh','--enc_n_hidden', type=int, nargs='+', default=[128], help='# of hidden units for encoder or list of hiddens for each layer', required=False)
    parser.add_argument('-e_nl','--enc_n_layer', type=int, default=1, help='# of hidden layers for encoder, set to -1 if passing list of n_hiddens', required=False)
    parser.add_argument('-e_nonlin','--enc_nonlinearity', type=str, default='sigmoid', help='type of nonlinearity for encoder', required=False)
    parser.add_argument('-e_weights','--enc_weights', type=str, default='', help='file path for encoder weights', required=False)
    parser.add_argument('-e_freeze','--enc_freeze', type=lambda x: (str(x).lower() == 'true'), default=False, help='whether to freeze the encoder weights', required=False)
    parser.add_argument('-lat_nonlin','--latent_nonlinearity', type=str, default='', help='type of to use prior to decoder', required=False)
    parser.add_argument('-d_nh','--dec_n_hidden', type=int, nargs='+', default=[128], help='# of hidden units for decoder or list of hiddens for each layer', required=False)
    parser.add_argument('-d_nl','--dec_n_layer', type=int, default=0, help='# of hidden layers for decoder', required=False)
    parser.add_argument('-d_nonlin','--dec_nonlinearity', type=str, default='', help='type of nonlinearity for decoder', required=False)
    parser.add_argument('-d_weights','--dec_weights', type=str, default='', help='file path for decoder weights', required=False)
    parser.add_argument('-d_freeze','--dec_freeze', type=lambda x: (str(x).lower() == 'true'), default=False, help='whether to freeze the decoder weights', required=False)
    parser.add_argument('-d_word_dist','--dec_word_dist', type=lambda x: (str(x).lower() == 'true'), default=False, help='whether to init decoder weights with training set word distributions', required=False)
    parser.add_argument('-dis_nh','--dis_n_hidden', type=int, nargs='+', default=[128], help='# of hidden units for encoder or list of hiddens for each layer', required=False)
    parser.add_argument('-dis_nl','--dis_n_layer', type=int, default=1, help='# of hidden layers for encoder, set to -1 if passing list of n_hiddens', required=False)
    parser.add_argument('-dis_nonlin','--dis_nonlinearity', type=str, default='sigmoid', help='type of nonlinearity for discriminator', required=False)
    parser.add_argument('-dis_y_weights','--dis_y_weights', type=str, default='', help='file path for discriminator_y weights', required=False)
    parser.add_argument('-dis_z_weights','--dis_z_weights', type=str, default='', help='file path for discriminator_z weights', required=False)
    parser.add_argument('-dis_freeze','--dis_freeze', type=lambda x: (str(x).lower() == 'true'), default=False, help='whether to freeze the encoder weights', required=False)
    parser.add_argument('-include_w','--include_weights', type=str, nargs='*', default=[], help='weights to train on (default is all weights) -- all others are kept fixed; Ex: E.z_encoder D.decoder', required=False)
    parser.add_argument('-eps','--epsilon', type=float, default=1e-8, help='epsilon param for Adam', required=False)
    parser.add_argument('-mx_it','--max_iter', type=int, default=50001, help='max # of training iterations', required=False)
    parser.add_argument('-train_stats_every','--train_stats_every', type=int, default=100, help='skip train_stats_every iterations between recording training stats', required=False)
    parser.add_argument('-eval_stats_every','--eval_stats_every', type=int, default=100, help='skip eval_stats_every iterations between recording evaluation stats', required=False)
    parser.add_argument('-ndim_y','--ndim_y', type=int, default=256, help='dimensionality of y - topic indicator', required=False)
    parser.add_argument('-ndim_x','--ndim_x', type=int, default=2, help='dimensionality of p(x) - data distribution', required=False)
    parser.add_argument('-saveto','--saveto', type=str, default='', help='path prefix for saving results', required=False)
    parser.add_argument('-gpu','--gpu', type=int, default=-2, help='if/which gpu to use (-1: all, -2: None)', required=False)
    parser.add_argument('-hybrid','--hybridize', type=lambda x: (str(x).lower() == 'true'), default=False, help='declaritive True (hybridize) or imperative False', required=False)
    parser.add_argument('-full_npmi','--full_npmi', type=lambda x: (str(x).lower() == 'true'), default=False, help='whether to compute NPMI for full trajectory', required=False)
    parser.add_argument('-eot','--eval_on_test', type=lambda x: (str(x).lower() == 'true'), default=False, help='whether to evaluate on the test set (True) or validation set (False)', required=False)
    parser.add_argument('-verb','--verbose', type=lambda x: (str(x).lower() == 'true'), default=True, help='whether to print progress to stdout', required=False)
    parser.add_argument('-dirich_alpha','--dirich_alpha', type=float, default=1e-1, help='param for Dirichlet prior', required=False)
    parser.add_argument('-adverse','--adverse', type=lambda x: (str(x).lower() == 'true'), default=True, help='whether to turn on adverserial training (MMD or GAN). set to False if only train auto-encoder', required=False)
    parser.add_argument('-update_enc','--update_enc', type=lambda x: (str(x).lower() == 'true'), default=True, help='whether to update encoder for unlabed_train_op()', required=False)
    parser.add_argument('-labeled_loss_lambda','--labeled_loss_lambda', type=float, default=1.0, help='param for Dirichlet noise for label', required=False)
    parser.add_argument('-train_mode','--train_mode', type=str, default='mmd', help="set to mmd or adv (for GAN)", required=False)
    parser.add_argument('-kernel_alpha','--kernel_alpha', type=float, default=1.0, help='param for information diffusion kernel', required=False)
    parser.add_argument('-recon_alpha','--recon_alpha', type=float, default=-1.0, help='multiplier of the reconstruction loss when combined with mmd loss', required=False)
    parser.add_argument('-recon_alpha_adapt','--recon_alpha_adapt', type=float, default=-1.0, help='adaptively change recon_alpha so that [total loss = mmd + recon_alpha_adapt * recon loss], set to -1 if no adapt', required=False)
    parser.add_argument('-dropout_p','--dropout_p', type=float, default=-1.0, help='dropout probability in encoder', required=False)
    parser.add_argument('-l2_alpha','--l2_alpha', type=float, default=-1.0, help='alpha multipler for L2 regularization on latent vector', required=False)
    parser.add_argument('-latent_noise','--latent_noise', type=float, default=0.0, help='proportion of dirichlet noise added to the latent vector after softmax', required=False)
    parser.add_argument('-topic_decoder_weight','--topic_decoder_weight', type=lambda x: (str(x).lower() == 'true'), default=False, help='extract topic words based on decoder weights or decoder outputs', required=False)
    parser.add_argument('-retrain_enc_only','--retrain_enc_only', type=lambda x: (str(x).lower() == 'true'), default=False, help='only retrain the encoder for reconstruction loss', required=False)
    parser.add_argument('-l2_alpha_retrain','--l2_alpha_retrain', type=float, default=0.1, help='alpha multipler for L2 regularization on encoder output during retraining', required=False)
    args = vars(parser.parse_args())

    # Resolve the dataset wrapper. The imports are deferred so that only the
    # selected domain module has to be importable.
    # NOTE(review): the default --domain ('twenty_news') is not in this list, so a
    # run without an explicit --domain ends in NotImplementedError -- confirm intended.
    if args['domain'] == 'twenty_news_sklearn':
        from examples.domains.twenty_news_sklearn_wae import TwentyNews as Domain
    elif args['domain'] == 'wikitext-103':
        from examples.domains.wikitext103_wae import Wikitext103 as Domain
    elif args['domain'] == 'nytimes-pbr':
        from examples.domains.nyt_wae import Nytimes as Domain
    elif args['domain'] == 'ag_news_csv':
        from examples.domains.ag_news_wae import Agnews as Domain
    elif args['domain'] == 'dbpedia_csv':
        from examples.domains.dbpedia_wae import Dbpedia as Domain
    elif args['domain'] == 'yelp_review_polarity_csv':
        from examples.domains.yelp_polarity_wae import YelpPolarity as Domain
    elif args['domain'] == 'lda_synthetic':
        from examples.domains.lda_synthetic import LdaSynthetic as Domain
    else:
        raise NotImplementedError(args['domain'])

    if args['model'] == 'dirichlet':
        from dirichlet import Encoder, Decoder, Discriminator_y
    else:
        raise NotImplementedError(args['model'])

    from compute_op import Unsupervised as Compute

    assert args['latent_noise'] >= 0 and args['latent_noise'] <= 1
    if args['description'] == '':
        args['description'] = args['domain'] + '-' + args['algorithm'] + '-' + args['model']
        # No --un_label_coeffs option is registered on the parser, so indexing
        # args['un_label_coeffs'] directly raised KeyError for every run with the
        # default description. This script always trains with the Unsupervised
        # compute class, so fall back to coefficients that select the '-unsup' tag
        # while still honouring the key should a caller ever provide it.
        un_label_coeffs = args.get('un_label_coeffs', (1.0, 0.0))
        if un_label_coeffs[0] > 0 and un_label_coeffs[1] == 0:
            args['description'] += '-unsup'
        elif un_label_coeffs[0] > 0 and un_label_coeffs[1] > 0:
            args['description'] += '-semisup'
        else:
            args['description'] += '-sup'
    elif args['description'].isdigit():
        # A purely numeric description is treated as a run id appended to the default name.
        args['description'] = args['domain'] + '-' + args['algorithm'] + '-' + args['model'] + '-' + args['description']

    if args['saveto'] == '':
        args['saveto'] = 'examples/results/' + args['description'].replace('-','/')

    # The trailing '/' matters: later code builds file names with plain string
    # concatenation (saveto + 'args.txt', saveto + 'weights/...').
    saveto = args['saveto'] + '/' + datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S') + '/'
    # exist_ok=True guarantees the weight directories exist even when the
    # time-stamped run directory itself was already present.
    os.makedirs(saveto, exist_ok=True)
    for weights_dir in ('encoder', 'decoder', 'discriminator_y', 'discriminator_z'):
        os.makedirs(saveto + '/weights/' + weights_dir, exist_ok=True)

    # Snapshot the source files next to the results (paths are relative to the cwd).
    for source_file in ('compute_op.py', 'core.py', 'run.py', 'utils.py'):
        shutil.copy(os.path.realpath(source_file), os.path.join(saveto, source_file))

    # NOTE(review): the model classes are imported from a top-level `dirichlet`
    # module above, while the snapshot is taken from models/<model>.py -- confirm
    # both refer to the same file.
    model_file = args['model']+'.py'
    shutil.copy(os.path.realpath('models/'+model_file), os.path.join(saveto, model_file))
    args['saveto'] = saveto

    # Human-readable dump of every non-empty option, one '--key value' per line.
    with open(saveto+'args.txt', 'w') as file:
        for key, val in args.items():
            if val != '':
                if isinstance(val, list) or isinstance(val, tuple):
                    val = [str(v) for v in val]
                    file.write('--'+str(key)+' '+' '.join(val)+'\n')
                else:
                    file.write('--'+str(key)+' '+str(val)+'\n')

    if args['gpu'] >= 0 and gpu_exists(args['gpu']):
        args['description'] += ' (gpu'+str(args['gpu'])+')'
    else:
        args['description'] += ' (cpu)'

    # Use a context manager so the pickle file is flushed and closed deterministically.
    with open(args['saveto']+'args.p', 'wb') as args_file:
        pickle.dump(args, args_file)

    return Compute, Domain, Encoder, Decoder, Discriminator_y, args


def _dump_pickle(obj, path):
    '''Pickle `obj` to `path`, closing the file handle even if pickling fails.'''
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle)


def _evaluate_topics(Dec, data, model_ctx, eval_record, decoder_weights, key_suffix, source_label):
    '''Extract the top topic words, score them (uniqueness and NPMI), record and print the result.

    `decoder_weights` is forwarded to get_topic_words_decoder_weights; `key_suffix` selects the
    eval_record keys ('' or '2'); `source_label` is only used in the printed summary line.
    '''
    topic_words = get_topic_words_decoder_weights(Dec, data, model_ctx, decoder_weights=decoder_weights)
    topic_uniqs = calc_topic_uniqueness(topic_words)
    eval_record['Topic Uniqueness' + key_suffix].append(np.mean(list(topic_uniqs.values())))
    # topic index -> list of top words
    topic_json = dict(enumerate(topic_words))
    _, npmi_dict = request_pmi(topic_dict=topic_json, port=1234)
    eval_record['NPMI' + key_suffix].append(np.mean(list(npmi_dict.values())))
    print("Topic Eval ({}): Uniq={:.5g}, NPMI={:.5g}".format(
        source_label, eval_record['Topic Uniqueness' + key_suffix][-1], eval_record['NPMI' + key_suffix][-1]))
    eval_record['Top Words' + key_suffix].append(topic_json)
    print_topics(topic_json, npmi_dict, topic_uniqs, data)


def run_experiment(Compute, Domain, Encoder, Decoder, Discriminator_y, args):
    '''
    Build the data set and networks, train them, and save weights, records and latent features.

    Args
    ----
    Compute: class
      Training driver; instantiated as Compute(data, Enc, Dec, Dis_y, args).
    Domain: class
      Data set wrapper; instantiated with batch_size, data_path, ctx and saveto.
    Encoder, Decoder, Discriminator_y: classes
      Network classes, instantiated from the corresponding entries of `args`.
    args: dict
      Experiment options as returned by parse_args(). args['saveto'] is used as a
      string prefix, so it has to end with a path separator. args['recon_alpha'] is
      overwritten in place when it is negative.

    Raises
    ------
    ValueError
      If args['train_mode'] is neither 'mmd' nor 'adv' while full training is requested.
    '''
    # Fail fast: previously an unknown train_mode fell through both branches of the
    # training loop and crashed with UnboundLocalError on the first accumulation.
    if not args['retrain_enc_only'] and args['train_mode'] not in ('mmd', 'adv'):
        raise ValueError("train_mode must be 'mmd' or 'adv', got {!r}".format(args['train_mode']))

    print('\nSaving to: '+args['saveto'])

    model_ctx = gpu_helper(args['gpu'])

    data = Domain(batch_size=args['batch_size'], data_path=args['data_path'], ctx=model_ctx, saveto=args['saveto'])
    print('train dimension = ', data.data['train'].shape)
    # Row sums of the training matrix, averaged over rows; the matrix may be a
    # numpy array or an mxnet NDArray.
    if type(data.data['train']) is np.ndarray:
        mean_length = np.mean(np.sum(data.data['train'], axis=1))
    else:
        mean_length = mx.nd.mean(mx.nd.sum(data.data['train'], axis=1)).asscalar()
    vocab_size = data.data['train'].shape[1]
    if data.data['train_with_labels'] is not None:
        print('train_with_labels dimension = ', data.data['train_with_labels'].shape)

    # A negative recon_alpha means "derive it from the data".
    if args['recon_alpha'] < 0:
        args['recon_alpha'] = 1.0 / (mean_length * np.log(vocab_size))
    print('Setting recon_alpha to {}'.format(args['recon_alpha']))

    Enc = Encoder(model_ctx=model_ctx, batch_size=args['batch_size'], input_dim=args['ndim_x'], ndim_y=args['ndim_y'],
                  n_hidden=args['enc_n_hidden'], n_layers=args['enc_n_layer'], nonlin=args['enc_nonlinearity'],
                  weights_file=args['enc_weights'], freeze=args['enc_freeze'], latent_nonlin=args['latent_nonlinearity'])
    Dec = Decoder(model_ctx=model_ctx, batch_size=args['batch_size'], output_dim=args['ndim_x'], ndim_y=args['ndim_y'],
                  n_hidden=args['dec_n_hidden'], n_layers=args['dec_n_layer'], nonlin=args['dec_nonlinearity'],
                  weights_file=args['dec_weights'], freeze=args['dec_freeze'], latent_nonlin=args['latent_nonlinearity'])
    Dis_y = Discriminator_y(model_ctx=model_ctx, batch_size=args['batch_size'], ndim_y=args['ndim_y'],
                            n_hidden=args['dis_n_hidden'], n_layers=args['dis_n_layer'],
                            nonlin=args['dis_nonlinearity'], weights_file=args['dis_y_weights'],
                            freeze=args['dis_freeze'], latent_nonlin=args['latent_nonlinearity'])
    # Encoder/decoder either resume from the given parameter files or start from
    # fresh weights; the discriminator is always freshly initialised.
    if args['enc_weights']:
        Enc.load_parameters(args['enc_weights'], ctx=model_ctx)
    else:
        Enc.init_weights()
    if args['dec_weights']:
        Dec.load_parameters(args['dec_weights'], ctx=model_ctx)
    else:
        Dec.init_weights()
    Dis_y.init_weights()
    if args['hybridize']:
        print('NOTE: Hybridizing Encoder and Decoder (Declaritive mode).')
        Enc.hybridize()
        Dec.hybridize()
        Dis_y.hybridize()
    else:
        print('NOTE: Not Hybridizing Encoder and Decoder (Imperative mode).')

    compute = Compute(data, Enc, Dec,  Dis_y, args)

    N_train = data.data['train'].shape[0]

    # max_iter counts outer passes; each pass runs N_train // batch_size train ops.
    epochs = range(args['max_iter'])
    if args['verbose']:
        print(' ')
        epochs = tqdm(epochs, desc=args['description'])

    train_record = {'loss_discriminator':[], 'loss_generator':[], 'loss_reconstruction':[], 'latent_max_distr':[],
                    'latent_avg_entropy':[], 'latent_avg':[], 'dirich_avg_entropy':[], 'loss_labeled':[]}
    eval_record = {'NPMI':[], 'Topic Uniqueness':[], 'Top Words':[],
                   'NPMI2':[], 'Topic Uniqueness2':[], 'Top Words2':[],
                   'u_loss_train':[], 'l_loss_train':[],
                   'u_loss_val':[], 'l_loss_val':[],
                   'u_loss_test':[], 'l_loss_test':[],
                   'l_acc_train':[], 'l_acc_val':[], 'l_acc_test':[]}

    total_iterations_train = N_train // args['batch_size']
    training_start_time = time.time()
    # `i` names the saved weight files below; keep it defined even if the loop never runs.
    i = 0
    if args['retrain_enc_only']:
        print('Retraining encoder ONLY!')
        for i in epochs:
            sum_loss_autoencoder = 0.0
            epoch_start_time = time.time()
            for itr in range(total_iterations_train):
                loss_reconstruction = compute.retrain_enc(args['l2_alpha_retrain'])
                sum_loss_autoencoder += loss_reconstruction
            if args['verbose']:
                print("Epoch {} done in {} sec - loss: a={:.5g} - total {} min".format(
                    i + 1, int(time.time() - epoch_start_time),
                    sum_loss_autoencoder / total_iterations_train,
                    int((time.time() - training_start_time) // 60)))
    else:
        for i in epochs:
            sum_loss_generator = 0.0
            sum_loss_discriminator = 0.0
            sum_loss_autoencoder = 0.0
            sum_discriminator_z_confidence_true = 0.0
            sum_discriminator_z_confidence_fake = 0.0
            sum_discriminator_y_confidence_true = 0.0
            sum_discriminator_y_confidence_fake = 0.0
            # Never incremented in this (unsupervised) loop; recorded and printed as 0.
            sum_loss_labeled = 0.0

            latent_max_distr = np.zeros(args['ndim_y'])
            latent_entropy_avg = 0.0
            latent_v_avg = np.zeros(args['ndim_y'])
            dirich_avg_entropy = 0.0

            epoch_start_time = time.time()
            for itr in range(total_iterations_train):
                if args['train_mode'] == 'mmd':
                    loss_reconstruction, loss_discriminator, latent_max, latent_entropy, latent_v, dirich_entropy = \
                        compute.unlabeled_train_op_mmd_combine(update_enc=args['update_enc'])
                    # The MMD op reports no generator loss or discriminator confidences.
                    loss_generator, \
                    discriminator_z_confidence_true, discriminator_z_confidence_fake, \
                    discriminator_y_confidence_true, discriminator_y_confidence_fake = 0,0,0,0,0
                else:
                    # train_mode == 'adv' (validated at the top of the function)
                    loss_discriminator, loss_generator, loss_reconstruction, \
                    discriminator_z_confidence_true, discriminator_z_confidence_fake, \
                    discriminator_y_confidence_true, discriminator_y_confidence_fake, \
                    latent_max, latent_entropy, latent_v, dirich_entropy = \
                        compute.unlabeled_train_op_adv_combine_add(update_enc=args['update_enc'])

                sum_loss_discriminator += loss_discriminator
                sum_loss_generator += loss_generator
                sum_loss_autoencoder += loss_reconstruction
                sum_discriminator_z_confidence_true += discriminator_z_confidence_true
                sum_discriminator_z_confidence_fake += discriminator_z_confidence_fake
                sum_discriminator_y_confidence_true += discriminator_y_confidence_true
                sum_discriminator_y_confidence_fake += discriminator_y_confidence_fake

                latent_max_distr += latent_max
                latent_entropy_avg += latent_entropy
                latent_v_avg += latent_v
                dirich_avg_entropy += dirich_entropy

            # Per-pass averages over the train ops.
            train_record['loss_discriminator'].append(sum_loss_discriminator / total_iterations_train)
            train_record['loss_generator'].append(sum_loss_generator / total_iterations_train)
            train_record['loss_reconstruction'].append(sum_loss_autoencoder / total_iterations_train)
            train_record['latent_max_distr'].append(latent_max_distr / total_iterations_train)
            train_record['latent_avg_entropy'].append(latent_entropy_avg / total_iterations_train)
            train_record['latent_avg'].append(latent_v_avg / total_iterations_train)
            train_record['dirich_avg_entropy'].append(dirich_avg_entropy / total_iterations_train)
            train_record['loss_labeled'].append(sum_loss_labeled / total_iterations_train)
            if args['verbose']:
                print("Epoch {} done in {} sec - loss: g={:.5g}, d={:.5g}, a={:.5g}, label={:.5g} - disc_z: true={:.1f}%, fake={:.1f}% - disc_y: true={:.1f}%, fake={:.1f}% - total {} min".format(
                    i + 1, int(time.time() - epoch_start_time),
                    sum_loss_generator / total_iterations_train,
                    sum_loss_discriminator / total_iterations_train,
                    sum_loss_autoencoder / total_iterations_train,
                    sum_loss_labeled / total_iterations_train,
                    sum_discriminator_z_confidence_true / total_iterations_train * 100,
                    sum_discriminator_z_confidence_fake / total_iterations_train * 100,
                    sum_discriminator_y_confidence_true / total_iterations_train * 100,
                    sum_discriminator_y_confidence_fake / total_iterations_train * 100,
                    int((time.time() - training_start_time) // 60)))
                print('Latent avg entropy = {}, dirich_entropy={}'.format(
                    train_record['latent_avg_entropy'][-1], train_record['dirich_avg_entropy'][-1]))

            # Evaluate on the last pass and on every eval_stats_every-th pass (including pass 0).
            if i == (args['max_iter'] - 1) or (args['eval_stats_every'] > 0 and i % args['eval_stats_every'] == 0):
                if args['recon_alpha_adapt'] > 0 and i == 0:
                    # Rescale recon_alpha once, from the loss ratio observed in the first pass.
                    compute.args['recon_alpha'] = train_record['loss_discriminator'][-1] / \
                                                  train_record['loss_reconstruction'][-1]
                    compute.args['recon_alpha'] = abs(compute.args['recon_alpha']) * args['recon_alpha_adapt']
                    print("recon_alpha adjusted to {}".format(compute.args['recon_alpha']))

                # NOTE(review): parse_args() accepts 'lda_synthetic' but not 'synthetic',
                # so this branch is only reachable with hand-built args -- confirm intended.
                if args['domain'] == 'synthetic':
                    enc_out = compute.test_synthetic_op()
                    np.save(os.path.join(args['saveto'], "enc_out_epoch{}".format(i)), enc_out.asnumpy())
                else:
                    # topic words from the decoder output, then from the decoder weight matrix
                    _evaluate_topics(Dec, data, model_ctx, eval_record,
                                     decoder_weights=False, key_suffix='', source_label='decoder output')
                    _evaluate_topics(Dec, data, model_ctx, eval_record,
                                     decoder_weights=True, key_suffix='2', source_label='decoder weight')

                    # evaluate train, validate and test losses for w/ w/o labels;
                    # splits that are missing keep the 0 placeholders in the printout.
                    u_loss_train, l_loss_train, u_loss_val, l_loss_val, u_loss_test, l_loss_test, l_acc_train, \
                    l_acc_val, l_acc_test = 0,0,0,0,0,0,0,0,0
                    u_loss_train, l_loss_train, l_acc_train = compute.test_op(dataset='train')
                    eval_record['u_loss_train'].append(u_loss_train)
                    eval_record['l_loss_train'].append(l_loss_train)
                    eval_record['l_acc_train'].append(l_acc_train)
                    if data.data['valid'] is not None:
                        u_loss_val, l_loss_val, l_acc_val = compute.test_op(dataset='valid')
                        eval_record['u_loss_val'].append(u_loss_val)
                        eval_record['l_loss_val'].append(l_loss_val)
                        eval_record['l_acc_val'].append(l_acc_val)
                    if data.data['test'] is not None:
                        u_loss_test, l_loss_test, l_acc_test = compute.test_op(dataset='test')
                        eval_record['u_loss_test'].append(u_loss_test)
                        eval_record['l_loss_test'].append(l_loss_test)
                        eval_record['l_acc_test'].append(l_acc_test)
                    print("Train loss u-l-acc: {:.5g}-{:.5g}-{:.5g}, Val: {:.5g}-{:.5g}-{:.5g}, Test: {:.5g}-{:.5g}-{:.5g}".format(
                        u_loss_train, l_loss_train, l_acc_train, u_loss_val, l_loss_val, l_acc_val, u_loss_test, l_loss_test, l_acc_test))

                    # Checkpoint both records; the helper closes the files instead of
                    # leaking one open handle per evaluation.
                    _dump_pickle(train_record, args['saveto']+'train_record.p')
                    _dump_pickle(eval_record, args['saveto']+'eval_record.p')

    # save final weights
    Enc.save_parameters(args['saveto']+'weights/encoder/Enc_'+str(i))
    Dec.save_parameters(args['saveto']+'weights/decoder/Dec_'+str(i))
    Dis_y.save_parameters(args['saveto']+'weights/discriminator_y/Dis_y_'+str(i))

    # save the latent features
    compute.save_latent(args['saveto'])

    if args['domain'] == 'lda_synthetic':
        # save the decoder weight matrix
        params = Dec.collect_params()
        params = params['decoder0_dense0_weight'].data().transpose()
        np.save(args['saveto']+'decoder_weight.npy', params.asnumpy())

    print('Done! ' + args['description'])
    print('\nSaved to: '+args['saveto'])


In [9]:
# model.dirichlet.py
import numpy as np
from scipy.special import logit, expit

import mxnet as mx
from mxnet.gluon import nn

# from core import ENet, DNet


class Encoder(ENet):
    '''
    Feed-forward encoder: a stack of gluon Dense layers mapping the input to ndim_y outputs.
    '''
    def __init__(self, model_ctx, batch_size, input_dim, n_hidden=64, ndim_y=16, ndim_z=10, n_layers=0, nonlin=None,
                 weights_file='', freeze=False, latent_nonlin='sigmoid', **kwargs):
        '''
        Build the encoder network.
        Args
        ----
        model_ctx: mxnet device context, No default
          Stored on the instance as `model_ctx`.
        batch_size: integer, No default
          Stored on the instance as `batch_size`; not used to build the layers.
        input_dim: integer, No default
          Number of input units of the first Dense layer.
        n_hidden: integer or list, default 64
          Width(s) of the hidden layers.
          If n_layers >= 0, a single width is used for every hidden layer
          (for a list, only its first element is used).
          If n_layers < 0, the sequence is read as one width per hidden layer.
        ndim_y: integer, default 16
          Number of units of the final (activation-free) Dense layer.
        ndim_z: integer, default 10
          Stored on the instance as `ndim_z`; not used to build the layers.
        n_layers: integer, default 0
          Number of hidden layers, or a negative value to take the
          layer count from len(n_hidden).
        nonlin: string, default None
          Activation of every hidden layer; '' is treated as None.
        weights_file: string, default ''
          Stored on the instance as `weights_file`.
        freeze: boolean, default False
          Stored on the instance as `freeze`.
        latent_nonlin: string, default 'sigmoid'
          Stored on the instance as `latent_nonlin`; not applied in hybrid_forward.
        **kwargs:
          Accepted and ignored.
        '''
        super(Encoder, self).__init__()

        # Normalise n_hidden into one width per hidden layer.
        if n_layers < 0:
            n_layers = len(n_hidden)
            print('NOTE: Encoder reading n_hidden as list.')
        else:
            if isinstance(n_hidden, list):
                n_hidden = n_hidden[0]
                print('NOTE: Encoder ignoring list of hiddens because n_layer >= 0. Just using first element.')
            n_hidden = n_layers*[n_hidden]

        # An empty string means "no activation".
        nonlin = None if nonlin == '' else nonlin

        with self.name_scope():
            self.main = nn.HybridSequential(prefix='encoder')
            fan_in = input_dim
            # Hidden layers: each one consumes the previous layer's width.
            for width in n_hidden:
                self.main.add(nn.Dense(width, in_units=fan_in, activation=nonlin))
                fan_in = width
            # Output layer, left without an activation.
            self.main.add(nn.Dense(ndim_y, in_units=fan_in, activation=None))

        # Keep the construction settings available on the instance.
        self.model_ctx = model_ctx
        self.batch_size = batch_size
        self.input_dim = input_dim
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.ndim_y = ndim_y
        self.ndim_z = ndim_z
        self.nonlin = nonlin
        self.latent_nonlin = latent_nonlin
        self.weights_file = weights_file
        self.freeze = freeze
        self.dist_params = [None]

    def hybrid_forward(self, F, x):
        '''
        Run the input through the Dense stack.
        Args
        ----
        F: mxnet.nd or mxnet.sym, No default
          Passed implicitly by gluon; not used directly here.
        x: NDarray or mxnet symbol, No default
          The input to the encoder.
        Returns
        -------
        y: NDarray or mxnet symbol
          Output of the final Dense layer (ndim_y units, no activation applied).
        '''
        return self.main(x)


class Decoder(DNet):
    '''
    A gluon HybridBlock Decoder class with Multinomial likelihood, p(x|z).

    The network is a single linear Dense layer (no activation) that takes
    ndim_y inputs.
    '''
    def __init__(self, model_ctx, batch_size, output_dim, ndim_y=16,  n_hidden=64, n_layers=0, nonlin='',
                 weights_file='', freeze=False, latent_nonlin='', **kwargs):
        '''
        Constructor for Multinomial decoder.
        Args
        ----
        model_ctx: mxnet device context, No default
          Stored on the instance as self.model_ctx.
        batch_size: integer, No default
          The minibatch size. Stored on the instance.
        output_dim: integer, No default
          Stored on the instance. Also used as the width of the Dense
          layer when no hidden size is available (n_layers == 0, the
          default, or an empty n_hidden list with n_layers < 0).
        ndim_y: integer, default 16
          The number of input units of the Dense layer.
        n_hidden: integer or list, default 64
          If n_layers >= 0, an integer (or the first element of a list)
          that is repeated n_layers times.
          If n_layers < 0, a list that is used as given.
          The first element, when present, is the width of the Dense layer.
        n_layers: integer, default 0
          The number of hidden sizes to keep (see n_hidden). A negative
          value means "take the count from the n_hidden list".
        nonlin: string, default ''
          Stored on the instance ('' is stored as None). It is not applied
          to the Dense layer, which is always linear.
        weights_file: string, default ''
          Stored on the instance; not read by this constructor.
        freeze: boolean, default False
          Stored on the instance; not used by this constructor.
        latent_nonlin: string, default ''
          Stored on the instance; not used by this constructor.
        **kwargs:
          Accepted and ignored.
        Returns
        -------
        Multinomial decoder object.
        '''
        super(Decoder, self).__init__()

        if n_layers >= 0:
            if isinstance(n_hidden, list):
                n_hidden = n_hidden[0]
                print('NOTE: Decoder ignoring list of hiddens because n_layer >= 0. Just using first element.')
            n_hidden = n_layers*[n_hidden]
        else:
            n_layers = len(n_hidden)
            print('NOTE: Decoder reading n_hidden as list.')

        if nonlin == '':
            nonlin = None

        # With the default n_layers=0 the list above is empty, and indexing
        # n_hidden[0] used to raise IndexError. Fall back to output_dim so a
        # decoder without hidden sizes maps ndim_y directly to output_dim.
        layer_units = n_hidden[0] if n_hidden else output_dim

        with self.name_scope():
            self.main = nn.HybridSequential(prefix='decoder')
            self.main.add(nn.Dense(layer_units, in_units=ndim_y, activation=None))

        self.model_ctx = model_ctx
        self.batch_size = batch_size
        self.ndim_y = ndim_y
        self.n_hidden = n_hidden
        self.output_dim = output_dim
        self.n_layers = n_layers
        self.nonlin = nonlin
        self.latent_nonlin = latent_nonlin
        self.weights_file = weights_file
        self.freeze = freeze

    def hybrid_forward(self, F, y):
        '''
        Passes the input through the decoder.
        Args
        ----
        F: mxnet.nd or mxnet.sym, No default
          This will be passed implicitly when calling hybrid forward.
          It is not used by this method.
        y: NDarray or mxnet symbol, No default
          The input to the decoder.
        Returns
        -------
        The result of applying self.main to y. No sampling is performed.
        '''
        return self.main(y)

    def y_as_topics(self, eps=1e-10):
        '''
        Builds an ndim_y x ndim_y identity matrix as an mxnet NDArray.
        Args
        ----
        eps: float, default 1e-10
          Unused; kept so existing callers that pass it keep working.
        Returns
        -------
        NDArray holding np.eye(self.ndim_y).
        '''
        return mx.nd.array(np.eye(self.ndim_y))

class Discriminator_y(ENet):
    '''
    A gluon HybridBlock Discriminator Class for y
    '''

    def __init__(self, model_ctx, batch_size, output_dim=2, ndim_y=16, n_hidden=64, n_layers=0, nonlin='sigmoid',
                 weights_file='', freeze=False, latent_nonlin='sigmoid', apply_softmax=False, **kwargs):
        '''
        Constructor for Discriminator Class for y.
        Args
        ----
        model_ctx: mxnet device context, No default
          Stored on the instance as self.model_ctx.
        batch_size: integer, No default
          The minibatch size. Stored on the instance.
        output_dim: integer, default 2
          The number of units of the final (linear) Dense layer.
        ndim_y: integer, default 16
          The number of input units of the first Dense layer.
        n_hidden: integer or list, default 64
          If n_layers >= 0, an integer (or the first element of a list)
          giving the number of hidden units in every hidden layer.
          If n_layers < 0, a list whose elements give the number of hidden
          units in each hidden layer.
        n_layers: integer, default 0
          The number of hidden layers. A negative value means "take the
          count from the n_hidden list".
        nonlin: string, default 'sigmoid'
          The nonlinearity to use in every hidden layer ('' means none).
        weights_file: string, default ''
          Stored on the instance; not read by this constructor.
        freeze: boolean, default False
          Stored on the instance; not used by this constructor.
        latent_nonlin: string, default 'sigmoid'
          Stored on the instance. Here it only selects which NOTE is
          printed at construction time.
        apply_softmax: boolean, default False
          If True, hybrid_forward returns F.softmax of the network output
          instead of the raw output.
        **kwargs:
          Accepted and ignored.
        Returns
        -------
        Discriminator object.
        '''
        super(Discriminator_y, self).__init__()

        if n_layers >= 0:
            if isinstance(n_hidden, list):
                n_hidden = n_hidden[0]
                # Message names this class (it previously said "Decoder", a copy-paste slip).
                print('NOTE: Discriminator_y ignoring list of hiddens because n_layer >= 0. Just using first element.')
            n_hidden = n_layers * [n_hidden]
        else:
            n_layers = len(n_hidden)
            print('NOTE: Discriminator_y reading n_hidden as list.')

        if latent_nonlin != 'sigmoid':
            print('NOTE: Latent z will be fed to decoder in logit-space (-inf,inf).')
        else:
            print('NOTE: Latent z will be fed to decoder in probability-space (0,1).')

        if nonlin == '':
            nonlin = None

        # Stack n_layers hidden Dense layers, then a linear output layer.
        in_units = ndim_y
        with self.name_scope():
            self.main = nn.HybridSequential(prefix='discriminator_y')
            for units in n_hidden[:n_layers]:
                self.main.add(nn.Dense(units, in_units=in_units, activation=nonlin))
                in_units = units
            self.main.add(nn.Dense(output_dim, in_units=in_units, activation=None))

        self.model_ctx = model_ctx
        self.batch_size = batch_size
        self.ndim_y = ndim_y
        self.n_hidden = n_hidden
        self.output_dim = output_dim
        self.n_layers = n_layers
        self.nonlin = nonlin
        self.latent_nonlin = latent_nonlin
        self.weights_file = weights_file
        self.freeze = freeze
        self.apply_softmax = apply_softmax

    def hybrid_forward(self, F, y):
        '''
        Passes the input through the discriminator.
        Args
        ----
        F: mxnet.nd or mxnet.sym, No default
          This will be passed implicitly when calling hybrid forward.
        y: NDarray or mxnet symbol, No default
          The input to the discriminator.
        Returns
        -------
        The output of self.main applied to y, passed through F.softmax
        when self.apply_softmax is True.
        '''
        # Named "scores" rather than "logit" so the local does not shadow
        # the scipy.special.logit function imported at the top of the file.
        scores = self.main(y)
        if self.apply_softmax:
            return F.softmax(scores)
        return scores

In [10]:
# Unpack the six values returned by parse_args(). This assignment rebinds the
# names Encoder, Decoder and Discriminator_y to whatever parse_args() returns
# for them.
# NOTE(review): the output recorded for this cell is
# "NotImplementedError: twenty_news"; presumably parse_args() does not
# support that setting in this notebook - confirm before relying on the cell.
Compute, Domain, Encoder, Decoder, Discriminator_y, args = parse_args()

NotImplementedError: twenty_news