In [23]:
import tensorflow as tf
import json

import numpy

from keras.models import load_model
from keras.utils import plot_model

from pretrained_models import *
from preprocess import preprocess_text, clean

In [2]:
# Notebook configuration; the flag is forwarded to get_serialized_models to
# choose which serialized models get loaded.
config = {
    'use_only_cpu_compatible_models': True
}
models_paths = get_serialized_models(config['use_only_cpu_compatible_models'])

# Names assigned at the top level of a cell are already module globals, so no
# `global` declaration is needed for the functions below to read them.
models = [load_model(path) for path in models_paths]

# Keep a handle on the graph that was current when the models were loaded;
# get_models_predictions re-enters it before calling model.predict.
graph = tf.get_default_graph()

In [68]:
def compute_category_difference(previous_result, new_result):
    """Return the absolute gap between two category probabilities.

    Both arguments are mappings with a 'probability' entry; any other keys
    are ignored. The result is symmetric in its arguments.
    """
    probability_shift = previous_result['probability'] - new_result['probability']
    return abs(probability_shift)

def _probability_for_label(predictions, label):
    """Return the model-averaged probability `predictions` assigns to `label`.

    `predictions` is a result dict as built by get_models_predictions. The
    label order is read from the first model's labelled probabilities and used
    to index the averaged probabilities. If the label cannot be located, the
    probability of the prediction's own top category is returned instead.
    """
    labelled_per_model = predictions.get('probabilities_of_models_with_labels') or []
    averaged = predictions.get('models_averaged_probabilities') or []
    if labelled_per_model:
        for index, entry in enumerate(labelled_per_model[0]):
            if entry['label'] == label and index < len(averaged):
                return averaged[index]
    return predictions['most_probable_category']['probability']


def compute_words_importance(text, averaged_most_probable_category):
    """Score each word by how much removing it changes the reference category.

    For every word of the cleaned text, the text is re-scored with all
    occurrences of that word removed, and the importance is the absolute
    change in probability of the *same* category as
    `averaged_most_probable_category`. Comparing against whichever category
    happens to be on top after the removal would mix probabilities of
    different labels whenever the top category switches.

    Args:
        text: raw input text; it is passed through `clean` before splitting.
        averaged_most_probable_category: dict with 'label' and 'probability'
            describing the reference category of the full text.

    Returns:
        A list of (word, importance) tuples, one per token in text order
        (repeated words appear repeatedly with the same score). Empty when
        the cleaned text contains no words.
    """
    # split() without an argument drops the empty tokens that split(' ')
    # produces for repeated, leading or trailing whitespace.
    words = clean(text).split()
    reference_label = averaged_most_probable_category['label']

    # Every distinct word needs a full ensemble prediction, so score each
    # distinct word only once.
    importance_by_word = {}
    for word in words:
        if word in importance_by_word:
            continue
        # NOTE(review): for a text made of a single distinct word this is an
        # empty string; confirm preprocess_text / the models accept that.
        text_without_word = ' '.join(other for other in words if other != word)
        predictions = get_models_predictions(text_without_word)
        same_category = {
            'label': reference_label,
            'probability': _probability_for_label(predictions, reference_label),
        }
        importance_by_word[word] = compute_category_difference(
            averaged_most_probable_category, same_category
        )

    return [(word, importance_by_word[word]) for word in words]

In [69]:
def get_models_predictions(text):
    """Run every loaded model on `text` and summarise the class probabilities.

    The text goes through `preprocess_text` once and the result is fed to each
    model in `models` inside the captured `graph`. Each prediction must have a
    leading axis of length 1, which is squeezed away (numpy.squeeze raises
    otherwise).

    Returns a dict with:
        'probabilities_of_models': one list of probabilities per model.
        'probabilities_of_models_with_labels': the same values, each wrapped
            as {'label': ..., 'probability': ...} in label order.
        'models_averaged_probabilities': element-wise mean across models.
        'most_probable_category': label and probability of the highest
            averaged value.
    """
    labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
    model_input = preprocess_text(text)

    # Predictions have to run under the graph the models were loaded into.
    with graph.as_default():
        raw_probabilities = []
        for model in models:
            prediction = model.predict(model_input)
            raw_probabilities.append(numpy.squeeze(prediction, axis=0).tolist())

    # Pair every score with its label, position by position.
    probabilities_with_labels = [
        [
            {'label': labels[position], 'probability': scores[position]}
            for position in range(len(labels))
        ]
        for scores in raw_probabilities
    ]

    averaged_probabilities = numpy.average(raw_probabilities, axis=0).tolist()
    top_position = numpy.argmax(averaged_probabilities)
    top_category = {
        'label': labels[top_position],
        'probability': numpy.max(averaged_probabilities)
    }

    return {
        'probabilities_of_models': raw_probabilities,
        'probabilities_of_models_with_labels': probabilities_with_labels,
        'models_averaged_probabilities': averaged_probabilities,
        'most_probable_category': top_category
    }

In [74]:
def predict(text):
    """Return the ensemble predictions for `text` plus per-word importances.

    The result is the dict produced by get_models_predictions with one extra
    key, 'word_importances', holding the output of compute_words_importance
    evaluated against the text's most probable category.
    """
    predictions = get_models_predictions(text)
    top_category = predictions['most_probable_category']
    predictions['word_importances'] = compute_words_importance(text, top_category)
    return predictions

In [75]:
# Demo: run the full pipeline (ensemble predictions plus per-word importances)
# on a sample sentence and display the resulting dict.
predict('I will fuckin kill you!')

{'models_averaged_probabilities': [0.996890127658844,
  0.4814731180667877,
  0.9424256980419159,
  0.9009456038475037,
  0.7575989067554474,
  0.018248425796628],
 'most_probable_category': {'label': 'toxic',
  'probability': 0.996890127658844},
 'probabilities_of_models': [[0.9966059923171997,
   0.40833544731140137,
   0.9401798248291016,
   0.9129313826560974,
   0.6990986466407776,
   0.017510369420051575],
  [0.9971742630004883,
   0.5546107888221741,
   0.9446715712547302,
   0.8889598250389099,
   0.8160991668701172,
   0.018986482173204422]],
 'probabilities_of_models_with_labels': [[{'label': 'toxic',
    'probability': 0.9966059923171997},
   {'label': 'severe_toxic', 'probability': 0.40833544731140137},
   {'label': 'obscene', 'probability': 0.9401798248291016},
   {'label': 'threat', 'probability': 0.9129313826560974},
   {'label': 'insult', 'probability': 0.6990986466407776},
   {'label': 'identity_hate', 'probability': 0.017510369420051575}],
  [{'label': 'toxic', 'proba

In [71]:
# Demo: show only the per-word importances; the reference category comes from
# a fresh prediction on the same sentence.
compute_words_importance('I will fuckin kill you!', get_models_predictions('I will fuckin kill you!')['most_probable_category'])

[('i', 0.00015231966972351074),
 ('will', 0.0004157721996307373),
 ('fuckin', 0.039737969636917114),
 ('kill', 0.0010592639446258545),
 ('you', 0.001063704490661621)]