In [1]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
nek19_trans_df = pd.read_csv('results/nek19_trans_df.csv')
# nek19_chat_df = pd.read_csv('results/nek19_chat_df.csv')
df19 = nek19_trans_df
perf_df = pd.read_csv('data/Project_RED/calculated performance data/NEKMTSCalcs.csv')

In [3]:
def transition_matrix(transitions):
    print(transitions)
    n = 1+ max(transitions) #number of states

    M = [[0]*n for _ in range(n)]

    for (i,j) in zip(transitions,transitions[1:]):
        M[i][j] += 1

    #now convert to probabilities:
    for row in M:
        s = sum(row)
        if s > 0:
            row[:] = [f/s for f in row]
    return M

In [4]:
def trans_wrapper(df, normalized=True):
    transitions = df['labels_h'].copy()

    states = sorted(set(transitions))
    print(states)

    ids = {k: 0 + i for i, k in enumerate(states)}
    transitions = transitions.map(ids)

    counts = transitions.value_counts().sort_index()
    frequencies = counts/counts.sum()

    transition = np.array(transition_matrix(transitions))
    if(normalized):
        transition = transition/np.array(frequencies)[None,:]
    return transition

In [8]:
state = ['disruption', 'follow-me', 'question', 'statement']
session_mats = [trans_wrapper(df19[df19['session'] == session]) for session in range(1,5)]

['disruption', 'follow-me', 'question', 'statement']
4081    2
4082    2
4083    2
4084    2
4085    2
       ..
6260    1
6261    2
6262    2
6263    2
6264    1
Name: labels_h, Length: 2184, dtype: int64
['disruption', 'follow-me', 'question', 'statement']
878     3
879     2
880     2
881     1
882     2
       ..
2512    2
2513    1
2514    2
2515    1
2516    2
Name: labels_h, Length: 1639, dtype: int64
['disruption', 'follow-me', 'question', 'statement']
0      1
1      3
2      2
3      3
4      2
      ..
873    1
874    2
875    3
876    2
877    3
Name: labels_h, Length: 878, dtype: int64
['disruption', 'follow-me', 'question', 'statement']
2517    2
2518    0
2519    2
2520    1
2521    2
       ..
4076    1
4077    2
4078    2
4079    3
4080    2
Name: labels_h, Length: 1564, dtype: int64


In [16]:

from scipy.stats import pearsonr

perfs = perf_df['MTSPerf'].iloc[0:4]
perfs

corrs = np.zeros((len(state), len(state)))

for i in range(0,4):
    for j in range(0,4):
        l = []
        for m in session_mats:
            l = l + [m[i][j]]
        corr, pval = pearsonr(l,perfs)
        corrs[i][j] = corr
        print("from: {}\tto: {}\tcorr: {}".format(state[i], state[j], corr))
        
corrs

from: disruption	to: disruption	corr: -0.581948017438173
from: disruption	to: follow-me	corr: 0.17832838839891785
from: disruption	to: question	corr: 0.7193263591144496
from: disruption	to: statement	corr: -0.5369095280392331
from: follow-me	to: disruption	corr: 0.3865923311242453
from: follow-me	to: follow-me	corr: 0.9102310226568568
from: follow-me	to: question	corr: -0.9064214040440033
from: follow-me	to: statement	corr: 0.5875791348925888
from: question	to: disruption	corr: -0.10472273781718294
from: question	to: follow-me	corr: -0.6178003234503924
from: question	to: question	corr: 0.9155254499612617
from: question	to: statement	corr: -0.03626713520546343
from: statement	to: disruption	corr: -0.13875913561754993
from: statement	to: follow-me	corr: -0.6878615046573555
from: statement	to: question	corr: 0.8064211661961862
from: statement	to: statement	corr: -0.43999311603755265


array([[-0.58194802,  0.17832839,  0.71932636, -0.53690953],
       [ 0.38659233,  0.91023102, -0.9064214 ,  0.58757913],
       [-0.10472274, -0.61780032,  0.91552545, -0.03626714],
       [-0.13875914, -0.6878615 ,  0.80642117, -0.43999312]])

In [7]:
def heatmaps(data, title="", verbose=False):
    plt.rcParams["figure.figsize"] = [7.00, 3.50]
    plt.rcParams["figure.autolayout"] = True
    f, axes = plt.subplots(1, 2)
    transitions = data.copy()

    # print(transitions[0:10])
    states = set(transitions)
    state_names = [s[0].upper() for s in states]

    ids = {k: 0 + i for i, k in enumerate(states)}
    transitions = transitions.map(ids)
    # print(transitions[0:10])

    counts = transitions.value_counts().sort_index()
    frequencies = counts/counts.sum()
    # print(counts)
    # print(frequencies)

    transition = np.array(transition_matrix(transitions))

    ax1 = sns.heatmap(transition, square=True, annot=True,fmt=".2f",linewidth=.5,cmap="mako_r",ax=axes[0])
    ax1.set_xticklabels(state_names)
    ax1.set_yticklabels(state_names)
    ax1.set_title(title + " transition")

    normalized = transition/np.array(frequencies)[None,:]

    ax2 = sns.heatmap(normalized, square=True, annot=True,fmt=".2f",linewidth=.5,cmap="mako_r",ax=axes[1])
    ax2.set_xticklabels(state_names)
    ax2.set_yticklabels(state_names)
    ax2.set_title(title + " normalized")
    plt.show()
    if(verbose):
        print(ids)
        print(normalized)
        print(transition)

heatmaps(nek21_trans_df['labels_h'].copy(), title="trans")