In [10]:
from hmmlearn import hmm

In [5]:
import numpy as np

In [46]:
states = ('Healthy', 'Fever')
end_state = 'E'
 
observations = ('normal', 'cold', 'dizzy')
 
start_probability = {'Healthy': 0.6, 'Fever': 0.4}
 
transition_probability = {
   'Healthy' : {'Healthy': 0.69, 'Fever': 0.3, 'E': 0.01},
   'Fever' : {'Healthy': 0.4, 'Fever': 0.59, 'E': 0.01},
   }
 
emission_probability = {
   'Healthy' : {'normal': 0.5, 'cold': 0.4, 'dizzy': 0.1},
   'Fever' : {'normal': 0.1, 'cold': 0.3, 'dizzy': 0.6},
   }

In [72]:
states_ws = ('biased_h', 'all_h', 'none_h', 'correct_h')
observations_ws = ('biased', 'all', 'none', 'correct', '0')
start_probability_ws = {'biased_h': 0.7, 'all_h': 0.05, 'none_h': 0.05, 'correct_h': 0.2 }

transition_ws = {
    'biased_h': {'biased_h': 0.6, 'all_h':0.1, 'none_h': 0.05, 'correct_h': 0.05, 'E':0.1},
    'all_h': {'biased_h': 0.3, 'all_h':0.5, 'none_h': 0.05, 'correct_h': 0.05, 'E':0.1},
    'none_h': {'biased_h': 0.3, 'all_h':0.3, 'none_h': 0.25, 'correct_h': 0.05, 'E':0.1},
    'correct_h': {'biased_h': 0.3, 'all_h':0.1, 'none_h': 0.00, 'correct_h': 0.5, 'E':0.1}
}

emissions_ws = {
    'biased_h': {'biased': 0.3, 'all':1e-6, 'none': 1e-6, 'correct': 1e-6, '0': 0.7},
    'all_h': {'biased': 1e-6, 'all':0.25, 'none': 1e-6, 'correct': 1e-6, '0': 0.75},
    'none_h': {'biased': 1e-6, 'all':1e-6, 'none': 0.2, 'correct': 1e-6, '0': 0.8},
    'correct_h': {'biased': 1e-6, 'all':1e-6, 'none': 1e-6, 'correct': 0.3, '0': 0.7}
}

In [89]:
states_ws_r = ('biased', 'all', 'none', 'correct', '0')
observations_ws_r = ('biased_o', 'all_o', 'none_o', 'correct_o')
start_probability_ws_r = {'biased': 0.7, 'all': 0.05, 'none': 0.05, 'correct': 0.2, '0': 1e-6  }

transition_ws_r = {
    'biased': {'biased': 0.1, 'all':0.05, 'none': 0.02, 'correct': 0.02, 'E':0.1, '0': 0.71},
    'all': {'biased': 0.01, 'all':0.1, 'none': 0.02, 'correct': 0.02, 'E':0.1, '0': 0.8},
    'none': {'biased': 0.01, 'all':0.01, 'none': 0.1, 'correct': 0.02, 'E':0.1, '0': 0.8},
    'correct': {'biased': 0.01, 'all':0.01, 'none': 0.02, 'correct': 0.1, 'E':0.1, '0': 0.8},
    '0': {'biased': 0.01, 'all':0.1, 'none': 0.02, 'correct': 0.02, 'E':0.01, '0': 0.8},
}

emissions_ws_r = {
    'biased': {'biased_o': 1, 'all_o':1e-6, 'none_o': 1e-6, 'correct_o': 1e-6},
    'all': {'biased_o': 1e-6, 'all_o':1, 'none_o': 1e-6, 'correct_o': 1e-6},
    'none': {'biased_o': 1e-6, 'all_o':1e-6, 'none_o': 1, 'correct_o': 1e-6},
    'correct': {'biased_o': 1e-6, 'all_o':1e-6, 'none_o': 1e-6, 'correct_o': 1},
    '0': {'biased_o': 0.25, 'all_o':0.25, 'none_o': 0.25, 'correct_o': 0.25}
}

In [79]:
def fwd_bkw(observations, states, start_prob, trans_prob, emm_prob, end_st):
    """Forward–backward algorithm."""
    
    # Forward part of the algorithm
    fwd = []
    f_prev = None
    prev_f_sum = None
    b_prev = None
    for i, observation_i in enumerate(observations):
        f_curr = {}
        for st in states:
            if i == 0:
                # base case for the forward part
                prev_f_sum = start_prob[st]
            else:
                prev_f_sum = sum(f_prev[k] * trans_prob[k][st] for k in states)

            f_curr[st] = emm_prob[st][observation_i] * prev_f_sum

        fwd.append(f_curr)
        f_prev = f_curr
    p_fwd = sum(f_curr[k] * trans_prob[k][end_st] for k in states)
    # Backward part of the algorithm
    bkw = []
    for i, observation_i_plus in enumerate(reversed(observations[1:] + (None,))):
        b_curr = {}
        for st in states:
            if i == 0:
                # base case for backward part
                b_curr[st] = trans_prob[st][end_st]
            else:
                b_curr[st] = sum(trans_prob[st][l] * emm_prob[l][observation_i_plus] * b_prev[l] for l in states)

        bkw.insert(0,b_curr)
        b_prev = b_curr

    p_bkw = sum(start_prob[l] * emm_prob[l][observations[0]] * b_curr[l] for l in states)

    # Merging the two parts
    posterior = []
    for i in range(len(observations)):
        posterior.append({st: fwd[i][st] * bkw[i][st] / p_fwd for st in states})
#     assert p_fwd == p_bkw
    return fwd, bkw, posterior

In [80]:
f_o, b_o, sm_o = fwd_bkw(observations,
                   states,
                   start_probability,
                   transition_probability,
                   emission_probability,
                   end_state)

In [81]:
for st, s in zip(states, sm_o):
    print(st, s)

Healthy {'Healthy': 0.8770110375573259, 'Fever': 0.1229889624426741}
Fever {'Healthy': 0.623228030950954, 'Fever': 0.3767719690490461}


In [82]:
f, b, sm = fwd_bkw(observations_ws,
                   states_ws,
                   start_probability_ws,
                   transition_ws,
                   emissions_ws,
                   end_state)

KeyError: 'biased_h'

In [83]:
for st, s in zip(states_ws, sm):
    print(st, s)

biased_h {'biased': 0.9999997857064603, 'all': 1.4285578370348236e-07, 'none': 1.4287345451417008e-08, 'correct': 5.714755337012051e-08, '0': 2.8571156740696466e-12}
all_h {'biased': 1.9999574583463306e-06, 'all': 0.9999814004223323, 'none': 1.9998812065208888e-06, 'correct': 3.999993025240146e-07, '0': 1.4199739700193356e-05}
none_h {'biased': 4.99983975712261e-07, 'all': 4.999733414240889e-06, 'none': 0.9999495026519841, 'correct': 4.9997035159286335e-06, '0': 3.999792710984649e-05}
correct_h {'biased': 4.999952500820091e-07, 'all': 5.00196487483373e-07, 'none': 4.999744013390266e-06, 'correct': 0.9999900001204725, '0': 3.999943776459759e-06}


In [90]:
f, b, sm = fwd_bkw(observations_ws_r,
                   states_ws_r,
                   start_probability_ws_r,
                   transition_ws_r,
                   emissions_ws_r,
                   end_state)

In [91]:
for st, s in zip(states_ws_r, sm):
    print(st, s)

biased {'biased': 0.9999991051815444, 'all': 9.419139939885114e-08, 'none': 6.593401912885324e-08, 'correct': 2.6373603998072e-07, '0': 4.709569969942557e-07}
all {'biased': 3.946049382073379e-07, 'all': 0.2197800948032498, 'none': 1.1988000098636477e-07, 'correct': 8.791218577301748e-08, '0': 0.780219302799625}
none {'biased': 4.289791923540292e-08, 'all': 4.545449488743264e-07, 'none': 0.09090903338124995, 'correct': 2.7272637287374683e-07, '0': 0.909090196449509}
correct {'biased': 2.4999925064215364e-07, 'all': 2.2954465108577234e-06, 'none': 6.818158180487945e-07, 'correct': 0.49999847855637986, '0': 0.4999982941820405}


# Import real probas

In [95]:
%cd /Users/georgi/dev/dialogue_modeling

/Users/georgi/dev/dialogue_modeling


In [96]:
from featurisers.raw_wason_featuriser import calculate_stats, preprocess_conversation_dump
from solution_tracker.augment_with_solution import merge_with_solution_raw, merge_with_solution_annotation_message_level

In [97]:
from supporting_classifiers.agreement_classifier import *
from solution_tracker.simple_sol import solution_tracker, process_raw_to_solution_tracker
import spacy
import string
from read_data import read_solution_annotaions, read_wason_dump, read_3_lvl_annotation_file
import pandas as pd
from featurisers.raw_wason_featuriser import get_y

In [98]:
raw_data = read_wason_dump('data/all/')


In [99]:
nlp = spacy.load("en_core_web_sm")

In [106]:
def process_solution_simple(solution):
    res = set()
    for item in solution:
        for a_k, a_v in allowed.items():
            if item in a_v:
                res.update(a_k[0])
    return "".join(list(res))

In [107]:
allowed = {
    'vowels': {'A', 'O', 'U', 'E', 'I', 'Y'},
    'consonants': {'B', 'C', 'D', 'F', 'G', 'H', 'J', 'K', 'L', 'M', 'N', 'P', 'Q', 'R', 'S', 'T', 'V', 'X', 'Z', 'W'},
    'odds': {'1', '3', '5', '7', '9'},
    'evens': {'0', '2', '4', '6', '8'}
}

In [110]:
from collections import defaultdict

In [184]:
start_counts = defaultdict(lambda:0)
transition_counts = defaultdict(lambda:defaultdict(lambda:0))
emission_counts = defaultdict(lambda:defaultdict(lambda:0))

In [195]:
for item in raw_data:
    prepr = preprocess_conversation_dump(item.raw_db_conversation)    
    item.wason_messages_from_raw()
    item.preprocess_everything(nlp)
    agreement_predictor = Predictor('models/agreement.pkl')
    sol_tracker = solution_tracker(item, False, agreement_predictor)

    usr_sol = {}
    for s_t in sol_tracker:
        track = process_solution_simple(s_t['value'])
        if s_t['type'] == 'INITIAL':
            usr_sol[s_t['user']] = track
            start_counts[track] += 1
            emission_counts[track][track] += 1
        elif s_t['type'] == 'SUBMIT':
            emission_counts[track][track] += 1 
            last_user_transition = usr_sol.get(s_t['user'], None)
            transition_counts[last_user_transition][track] += 1
            usr_sol[s_t['user']] = track
        elif s_t['type'] == 'MENTION':
            for u_n, u_v in usr_sol.items(): 
                emission_counts[u_v][track] += 1 
                transition_counts[u_v][u_v] += 1
    for u_n, u_v in usr_sol.items():    
        transition_counts[u_v]['E'] += 1
#     res = []
#     for m in sol_tracker:
#         res.append(process_solution_simple(m['value']))
#     solutions_per_dialogue.append(res)

In [196]:
transition_counts

defaultdict(<function __main__.<lambda>()>,
            {'ev': defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
                         {'ev': 1971,
                          'ov': 32,
                          'v': 28,
                          'o': 2,
                          'E': 151,
                          'coev': 1,
                          'oev': 13,
                          'oecv': 14,
                          'e': 4,
                          'oe': 1,
                          'ce': 3,
                          'cev': 2,
                          'oc': 1}),
             'ov': defaultdict(<function __main__.<lambda>.<locals>.<lambda>()>,
                         {'E': 148,
                          'ov': 793,
                          'ev': 11,
                          'oecv': 1,
                          'oev': 1,
                          'v': 4,
                          'oc': 1}),
             'ce': defaultdict(<function __main__.<lambda>.<locals>.<lamb

In [197]:
def normalize_dict(dictionary):
    total = sum(d for d in dictionary.values())
    res_norm = {}
    for key, value in dictionary.items():
        res_norm[key] = value / total
    return res_norm

def normalize_dict_nested(dictionary):
    res_norm = {}

    for upper_key, upper_value in dictionary.items():
        inner_norm = {}
        
        total = sum(d for d in upper_value.values())
        for key, value in upper_value.items():
            inner_norm[key] = value / total
        res_norm[upper_key] = inner_norm
    return res_norm

In [242]:
start_probabilities = normalize_dict(start_counts)
transition_prob = normalize_dict_nested(transition_counts)
emission_prob = normalize_dict_nested(emission_counts)

In [243]:
transition_prob

{'ev': {'ev': 0.8866396761133604,
  'ov': 0.014394961763382817,
  'v': 0.012595591542959963,
  'o': 0.000899685110211426,
  'E': 0.06792622582096267,
  'coev': 0.000449842555105713,
  'oev': 0.005847953216374269,
  'oecv': 0.006297795771479982,
  'e': 0.001799370220422852,
  'oe': 0.000449842555105713,
  'ce': 0.001349527665317139,
  'cev': 0.000899685110211426,
  'oc': 0.000449842555105713},
 'ov': {'E': 0.1543274244004171,
  'ov': 0.8269030239833159,
  'ev': 0.011470281543274244,
  'oecv': 0.0010427528675703858,
  'oev': 0.0010427528675703858,
  'v': 0.004171011470281543,
  'oc': 0.0010427528675703858},
 'ce': {'ce': 0.9270833333333334,
  'ov': 0.020833333333333332,
  'oev': 0.006944444444444444,
  'v': 0.013888888888888888,
  'E': 0.020833333333333332,
  'e': 0.006944444444444444,
  'ev': 0.003472222222222222},
 'v': {'ev': 0.03180760279286268,
  'v': 0.8789759503491078,
  'E': 0.06283941039565555,
  'ov': 0.016291698991466253,
  'coe': 0.0007757951900698216,
  'e': 0.00310318076027

In [244]:
len(emission_prob)

20

In [260]:
def fwd_bkw_modified(observations, states, start_prob, trans_prob, emm_prob, end_st):
    """Forward–backward algorithm."""
    
    # Forward part of the algorithm
    fwd = []
    f_prev = None
    prev_f_sum = None
    b_prev = None
    for i, observation_i in enumerate(observations):
        f_curr = {}
        for st in states:
            if i == 0:
                # base case for the forward part
                prev_f_sum = start_prob[st]
            else:
                prev_f_sum = sum(f_prev.get(k, 1e-6) * trans_prob[k].get(st, 1e-6) for k in states if k in trans_prob)

            if st in emm_prob:
                f_curr[st] = emm_prob[st].get(observation_i, 1e-6) * prev_f_sum
            else:
                f_curr[st] = 1e-6

        fwd.append(f_curr)
        f_prev = f_curr
    p_fwd = sum(f_curr.get(k, 1e-6) * trans_prob[k].get(end_st, 1e-6) for k in states if k in trans_prob)
    # Backward part of the algorithm
    bkw = []
    for i, observation_i_plus in enumerate(reversed(observations[1:] + (None,))):
        b_curr = {}
        for st in states:
            if i == 0:
                # base case for backward part
                if st in trans_prob:
                    b_curr[st] = trans_prob[st].get(end_st, 1e-6)
                else:
                    b_curr[st] = 1e-6
            else:
                b_curr[st] = sum(trans_prob[st].get(l, 1e-6) * emm_prob[l].get(observation_i_plus, 1e-6) * b_prev.get(l, 1e-6) for l in states if st in trans_prob)

        bkw.insert(0,b_curr)
        b_prev = b_curr

#     p_bkw = sum(start_prob[l] * emm_prob[l][observations[0]] * b_curr[l] for l in states)

    # Merging the two parts
#     posterior = []
#     for i in range(len(observations)):
#         for st in states:
#         posterior.append({st: fwd[i][st] * bkw[i][st] / p_fwd for st in states})
#     assert p_fwd == p_bkw

    # Merging the two parts
    posterior = {}
    for i, o_int in enumerate(observations):
        for s in states:
#             print(i, o_int)

#             print('{}>{}'.format(o_int, st))
            posterior['{}>{}'.format(o_int, s)] = fwd[i][s] * bkw[i][s] / p_fwd 
    return fwd, bkw, posterior

In [261]:
f, b, sm = fwd_bkw_modified(
                observations=tuple(emission_prob.keys()),
                states=tuple(start_probabilities.keys()),
                start_prob=start_probabilities,
                trans_prob=transition_prob,
                emm_prob=emission_prob,
                end_st=end_state)

In [262]:
sm

{'ov>ov': 0.022873693169519357,
 'ov>ev': 0.22202703438659036,
 'ov>ce': 0.07308333298962051,
 'ov>oecv': 0.03504262803971624,
 'ov>o': 0.3853000359745458,
 'ov>c': 0.006949333828747641,
 'ov>v': 0.08764677372015121,
 'ov>coev': 0.00044442849408213146,
 'ov>ocv': 2.674775515429449e-05,
 'ov>oe': 0.0009520736666423345,
 'ov>oev': 0.034422453876815556,
 'ov>e': 0.015012168405124262,
 'ov>': 0.0008820644392297219,
 'ov>cv': 0.0032995299461511195,
 'ov>co': 8.676557773848636e-05,
 'ov>oc': 0.11168105835604106,
 'ov>cev': 0.00026982896597077375,
 'ov>coe': 1.9915528655294953e-09,
 'ov>cov': 3.954636309986324e-08,
 'ov>oec': 6.870242914928129e-09,
 'ev>ov': 0.01395715097368565,
 'ev>ev': 0.22805637037140367,
 'ev>ce': 0.0733701540542651,
 'ev>oecv': 0.03633224522796028,
 'ev>o': 0.38615429233064436,
 'ev>c': 0.0054862027616768925,
 'ev>v': 0.08899082672956338,
 'ev>coev': 0.00022192911461351613,
 'ev>ocv': 2.341354990530564e-10,
 'ev>oe': 0.0003649211079229056,
 'ev>oev': 0.03568308894489524

In [216]:
flatten = {}
for st, s in zip(tuple(emission_prob.keys()), sm):
    for k, item in s.items():
        flatten["{}>{}".format(st, k)] = item

In [267]:
dict(sorted(sm.items(), key=lambda item: item[1], reverse=True))

{'coe>coe': 0.9898411200955106,
 'cov>coe': 0.9898202281257586,
 'oec>v': 0.9791635058664153,
 'cev>v': 0.6782035047619064,
 'oe>oe': 0.6423874495684423,
 'coev>oe': 0.6085814097542509,
 'e>oe': 0.5249787838390159,
 'ocv>oe': 0.504868000978853,
 'oev>oev': 0.4867531566619209,
 'co>oev': 0.484232396442178,
 '>oev': 0.48110427010545054,
 'cv>oev': 0.48056443232760315,
 'co>ev': 0.48026297179515637,
 'c>oe': 0.4030004230082211,
 'ce>o': 0.3987611676700507,
 'ev>o': 0.38615429233064436,
 'ov>o': 0.3853000359745458,
 'oecv>o': 0.37593920569758854,
 'oc>ev': 0.3488598167601991,
 'cv>ev': 0.3426976653570225,
 'v>o': 0.3218327743815401,
 'ocv>oev': 0.30756753928068653,
 'oc>oev': 0.28922165259576016,
 'cev>oev': 0.2866541446575701,
 'o>o': 0.28565274482612873,
 '>ev': 0.28013854653337406,
 'oev>oe': 0.27054035849510905,
 'o>oe': 0.25549771148664624,
 'ev>ev': 0.22805637037140367,
 'ov>ev': 0.22202703438659036,
 'oc>v': 0.21974773007003123,
 'oev>ev': 0.21894917253400464,
 'coev>oev': 0.2086305

In [234]:
dict(sorted(sm.items(), key=lambda item: item[1], reverse=True))

{'coe>coe': 0.9898411200955106,
 'coe>cov': 0.9898202281257586,
 'v>oec': 0.9791635058664153,
 'v>cev': 0.6782035047619064,
 'oe>oe': 0.6423874495684423,
 'oe>coev': 0.6085814097542509,
 'oe>e': 0.5249787838390159,
 'oe>ocv': 0.504868000978853,
 'oev>oev': 0.4867531566619209,
 'oev>co': 0.484232396442178,
 'oev>': 0.48110427010545054,
 'oev>cv': 0.48056443232760315,
 'ev>co': 0.48026297179515637,
 'oe>c': 0.4030004230082211,
 'o>ce': 0.3987611676700507,
 'o>ev': 0.38615429233064436,
 'o>ov': 0.3853000359745458,
 'o>oecv': 0.37593920569758854,
 'ev>oc': 0.3488598167601991,
 'ev>cv': 0.3426976653570225,
 'o>v': 0.3218327743815401,
 'oev>ocv': 0.30756753928068653,
 'oev>oc': 0.28922165259576016,
 'oev>cev': 0.2866541446575701,
 'o>o': 0.28565274482612873,
 'ev>': 0.28013854653337406,
 'oe>oev': 0.27054035849510905,
 'oe>o': 0.25549771148664624,
 'ev>ev': 0.22805637037140367,
 'ev>ov': 0.22202703438659036,
 'v>oc': 0.21974773007003123,
 'ev>oev': 0.21894917253400464,
 'oev>coev': 0.2086305

In [166]:
len(emission_prob.keys())

21

In [167]:
len(sm)

21

In [265]:
flattened_simple={}
for start, trans in transition_prob.items():
    for to, score in trans.items():
        flattened_simple['{}>{}'.format(start, to)] = score

In [266]:
dict(sorted(flattened_simple.items(), key=lambda item: item[1], reverse=True))

{'cov>cov': 0.96875,
 'ocv>ocv': 0.9464285714285714,
 'oec>oec': 0.9387755102040817,
 'co>co': 0.9285714285714286,
 'ce>ce': 0.9270833333333334,
 'coev>coev': 0.9178082191780822,
 'cev>cev': 0.9122807017543859,
 'oecv>oecv': 0.8922018348623854,
 'o>o': 0.8914285714285715,
 'e>e': 0.889937106918239,
 'ev>ev': 0.8866396761133604,
 'c>c': 0.8847926267281107,
 'coe>coe': 0.88,
 'v>v': 0.8789759503491078,
 'oc>oc': 0.8782608695652174,
 'oev>oev': 0.8746987951807229,
 '>': 0.8695652173913043,
 'cv>cv': 0.8653846153846154,
 'oe>oe': 0.8426966292134831,
 'ov>ov': 0.8269030239833159,
 'None>ov': 0.6666666666666666,
 'ov>E': 0.1543274244004171,
 'None>oecv': 0.1111111111111111,
 'None>oev': 0.1111111111111111,
 'None>v': 0.1111111111111111,
 '>v': 0.08695652173913043,
 'ev>E': 0.06792622582096267,
 'oev>E': 0.06746987951807229,
 'c>ev': 0.06451612903225806,
 'oecv>E': 0.06422018348623854,
 'v>E': 0.06283941039565555,
 'ocv>ov': 0.05357142857142857,
 'cv>v': 0.04807692307692308,
 'cv>ev': 0.04807