In [None]:
import artm
import pandas as pd
import numpy as np
import json

import zmq

In [None]:
from utils import sample_from, get_text_processor, prepare_for_artm
from collections import Counter

# Text-processing callable built by the local `utils` module; do_all applies it
# to every incoming 'text' value before the result is counted with Counter.
text_processor = get_text_processor()

In [None]:
# Every modality the model knows about; do_all generates whichever of these
# are missing from an incoming request.
modalities_set = {'classes', 'tag', 'text'}

In [None]:
# Locations (relative to the notebook's working directory) of the trained ARTM
# model and of the filtered dictionary file.
model_file, dict_file = '../data/big_model.artm.mtx', '../data/dict.filtered'

In [None]:
# Map class id (kept as a string, since do_all looks names up with
# class_names.get(str(c), ...)) to its human-readable name. Each non-blank line
# of the file is "<id>\t<name>"; only the first tab separates the two fields.
with open('../data/insception-classes.tsv') as income:
    class_names = dict(
        tuple(field.strip() for field in line.split('\t', 1))
        for line in income
        if line.strip()  # tolerate blank lines, e.g. a trailing empty line
    )

In [None]:
# Topic layout: the first `sparsed_topics` topics ("good_*") have their
# document distributions sparsed, the remaining `smoothed_topics` ("mess_*")
# are smoothed.
# NOTE(review): these names only take effect if the model loaded below uses the
# same topic names -- TODO confirm against how big_model.artm.mtx was saved.
sparsed_topics = 39
smoothed_topics = 5

topics = ['good_%i' % i for i in range(sparsed_topics)] + ['mess_%i' % i for i in range(smoothed_topics)]

# Derive the topic count from the layout instead of hard-coding 44.
tm = artm.ARTM(num_topics=len(topics), num_processors=2)

tm.load(model_file)
# Phi regularizers: sparse the 'text' and 'classes' modalities, smooth 'tag'.
tm.regularizers.add(artm.SmoothSparsePhiRegularizer(name='text_sparser', tau=-0.7, class_ids=['text']))
tm.regularizers.add(artm.SmoothSparsePhiRegularizer(name='classes_sparser', tau=-0.3, class_ids=['classes']))
tm.regularizers.add(artm.SmoothSparsePhiRegularizer(name='tags_smoother', tau=1, class_ids=['tag']))
# Theta regularizers: sparse the "good" topics, smooth only the "mess" topics.
# (Previously topics[smoothed_topics:] was used, i.e. topics[5:], which also
# smoothed good_5..good_38 and partly cancelled their sparsing.)
tm.regularizers.add(artm.SmoothSparseThetaRegularizer(name='topic_sparser', tau=-2, topic_names=topics[:sparsed_topics]))
tm.regularizers.add(artm.SmoothSparseThetaRegularizer(name='topic_smoother', tau=2.5, topic_names=topics[sparsed_topics:]))

In [None]:
def make_sample(ps, times):
    """Draw `times` labels from the distribution `ps` and count them.

    Parameters
    ----------
    ps : pandas.Series
        Non-negative weights indexed by label (in do_all: one document column
        of a ``tm.transform(..., predict_class_id='tag')`` result). The weights
        are renormalised, so they need not sum to exactly 1.
    times : int
        Number of draws.

    Returns
    -------
    collections.Counter
        Maps each drawn label to how many times it was drawn; labels drawn
        zero times are omitted. Empty when `ps` carries no positive mass.
    """
    # np.random.multinomial raises if the first n-1 probabilities sum to more
    # than 1, which rounding can cause when weights sum to ~1. Renormalising in
    # float64 avoids that without the old "* 0.98" workaround, which silently
    # handed the missing 2% of mass to the last label.
    probs = np.asarray(ps.values, dtype=np.float64).reshape(-1)
    total = probs.sum()
    if not total > 0:  # also catches NaN
        return Counter()
    counts = np.random.multinomial(times, probs / total)
    return Counter(dict((label, count) for label, count in zip(ps.index, counts) if count > 0))

In [None]:
def get_top(row, treshold=0.95, number=5):
    """Summarise a distribution by its most probable entries.

    Entries of `row` are taken in descending order of value until either
    their cumulative value strictly exceeds `treshold` (stop, nothing added)
    or `number` entries have been taken (an extra ``('other', 1 - mass)``
    entry is appended for the remaining mass).

    Parameters
    ----------
    row : pandas.Series
        Values (probabilities) indexed by label.
    treshold : float
        Cumulative-value cut-off. The misspelled name is kept for
        backward compatibility with keyword callers.
    number : int
        Maximum number of labelled entries before the rest is collapsed into
        ``'other'``.

    Returns
    -------
    list of str
        ``'label:value'`` strings, most probable first.
    """
    sorted_row = row.sort_values(ascending=False)

    res = []
    prob_mass = 0
    # zip(index, series) yields exactly what Series.iteritems() did; iteritems
    # itself was removed in pandas 2.0.
    for k, val in zip(sorted_row.index, sorted_row):
        prob_mass += val
        res.append((k, val))
        if prob_mass > treshold:
            break
        if len(res) == number:
            res.append(('other', (1. - prob_mass)))
            break

    return ['%s:%s' % pair for pair in res]

In [None]:
def _by_document(predictions):
    """Return `predictions` with its columns sorted by label.

    do_all treats each column as one document (presumably the item ids of the
    ARTM batch -- TODO confirm against prepare_for_artm).
    """
    return predictions.T.sort_index().T


def _expand_counts(counts):
    """Flatten an ``{item: count}`` mapping into a list repeating each item `count` times.

    Items keep the mapping's iteration order. An empty mapping yields ``[]``
    (the previous ``reduce`` without an initial value raised TypeError), and
    the work is linear instead of repeated list concatenation.
    """
    return [item for item, count in counts.items() for _ in range(count)]


def do_all(income_json, temp_name='temp'):
    """Complete the missing modalities of incoming documents and infer their topics.

    Parameters
    ----------
    income_json : str or object
        A JSON string, or an already-decoded object, that is passed to
        ``pd.DataFrame``. Rows are documents, columns are modalities from
        ``modalities_set``; row labels are echoed back as ``img_url``.
        Presumably ``{modality: {img_url: value}}`` -- TODO confirm with the client.
    temp_name : str
        Name passed to ``prepare_for_artm`` when building the ARTM batch.

    Returns
    -------
    str
        Text of a JSON array with one object per document, holding every
        modality column (given or generated), ``topics`` (list of floats from
        ``tm.transform``) and ``img_url``.

    Given modalities are normalised first: ``text`` via ``text_processor`` then
    ``Counter``; ``tag`` via ``Counter``; ``classes`` via ``np.array`` then
    ``sample_from``. Modalities of ``modalities_set`` absent from the input are
    predicted with the module-level model ``tm``: classes are sampled and mapped
    to names through ``class_names`` (``'unknown'`` if missing), text is
    summarised with ``get_top`` and tags are drawn with ``make_sample``.
    """
    data_in = json.loads(income_json) if isinstance(income_json, basestring) else income_json
    df_in = pd.DataFrame(data_in)

    if 'text' in df_in:
        df_in.text = df_in.text.apply(text_processor).apply(Counter)

    if 'tag' in df_in:
        df_in.tag = df_in.tag.apply(Counter)

    if 'classes' in df_in:
        df_in.classes = df_in.classes.apply(np.array).apply(sample_from)

    batch = prepare_for_artm(df_in, temp_name)

    modalities_to_generate = modalities_set - set(df_in.columns)

    # Generated columns are assigned positionally via list(...), which assumes
    # the label-sorted prediction columns line up with df_in's row order.
    # NOTE(review): confirm this holds for every input shape.
    if 'classes' in modalities_to_generate:
        sampled_classes = _by_document(tm.transform(batch, predict_class_id='classes')) \
            .apply(lambda col: sample_from(col.values), axis=0)
        df_in['classes'] = list(sampled_classes.apply(
            lambda sample: [class_names.get(str(c), 'unknown') for c in _expand_counts(sample)]))

    if 'text' in modalities_to_generate:
        df_in['text'] = list(_by_document(tm.transform(batch, predict_class_id='text'))
                             .apply(lambda col: get_top(col, 0.2, 20), axis=0))

    if 'tag' in modalities_to_generate:
        df_in['tag'] = list(_by_document(tm.transform(batch, predict_class_id='tag'))
                            .apply(lambda col: make_sample(col, 5), axis=0)
                            .apply(_expand_counts))

    # One list of topic probabilities per document, in sorted document order.
    df_in['topics'] = [[float(v) for v in row]
                       for _, row in tm.transform(batch).T.sort_index().iterrows()]

    print(df_in)

    outcome = []
    for u, data in df_in.iterrows():
        ans = data.to_dict()
        ans['img_url'] = u
        outcome.append(u'%s\n' % json.dumps(ans))

    return '[%s]' % ', \n'.join(outcome)

In [None]:
# ZeroMQ reply (REP) socket on TCP port 1349, bound on all interfaces ("*").
# NOTE(review): there is no authentication and received requests are passed
# straight to do_all; restrict the bind address if this should not be
# reachable from the network.
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind("tcp://*:1349")

In [None]:
import traceback

# Serve forever: every received JSON request is answered with exactly one reply.
while True:
    print('ready to recieve')
    income = socket.recv_json()
    print('income is %s' % (income,))
    try:
        res = do_all(income)
    except Exception as exc:
        # A REP socket must send a reply before it may receive again. Letting the
        # exception escape would leave the client blocked and the socket in a
        # state where the next recv fails, so log it and reply with an error.
        traceback.print_exc()
        res = json.dumps({'error': repr(exc)})
    print('res is %s' % (res,))
    # Under Python 2 the reply may be a byte `str`; send_string expects text.
    if isinstance(res, bytes):
        res = res.decode('utf8')
    socket.send_string(res)