In [1]:
import numpy as np
from random import gauss
from math import sqrt
import matplotlib as mpl
mpl.use('nbagg')
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt

In [2]:
def SimplifyState(x):
    """Relabel the distinct values of ``x`` as consecutive integers.

    The smallest distinct value becomes 0, the next 1, and so on.  The
    relabelled array has the same shape and dtype as ``x`` (it starts as
    ``np.zeros_like(x)``); entries that never compare equal to any unique
    value (e.g. NaN) keep that initial 0.

    Returns
    -------
    (relabelled array, np.arange(number_of_distinct_values))
    """
    relabeled = np.zeros_like(x)
    # np.unique already returns its result sorted.
    unique_states = np.unique(x)
    for label, state in enumerate(unique_states):
        relabeled[x == state] = label
    return relabeled, np.arange(len(unique_states))

def StateList2D21D(x, statenum):
    """Collapse each column of a 2-D state array into a single code.

    Row ``i`` of ``x`` is treated as the base-``statenum`` digit with weight
    ``statenum ** i``.  Returns a float array of length ``x.shape[1]``.
    """
    code = np.zeros(x.shape[1])
    for digit, row in enumerate(x):
        code = code + row * statenum ** digit
    return code

def _lagged_code(series, lags, maxlag, nstates):
    """Encode the lagged states ``series[t + lag]`` (one per lag in ``lags``)
    as a single code per time step ``t``, for ``t`` in ``maxlag .. len(series) - 1``.

    The returned array is therefore aligned with ``series[maxlag:]``.
    """
    n = len(series)
    rows = np.array([series[maxlag + lag:n + lag] for lag in lags])
    return StateList2D21D(rows, nstates)


def TE(x, y, xhist_list, yhist_list):
    """Williams-Beer partial information decomposition (PID) of what the past
    of x and the past of y tell about the present of y.

    Three variables are histogrammed jointly:
      x -> joint state of x at the lags in ``xhist_list``   ("x-hist")
      v -> joint state of y at the lags in ``yhist_list``   ("y-hist")
      r -> current value of y                                ("y-now")

    Parameters
    ----------
    x, y : 1-D array-like of discrete states, equal length.
    xhist_list, yhist_list : non-empty sequences of non-positive integer lags
        relative to the present sample, e.g. ``[-1, -2, -3]``.  Order does
        not matter; ``0`` means "same time step".

    Returns
    -------
    dict with keys
      ('Beer', 'Red')   -> redundancy I_min(R; X, V)
      ('Beer', 'UIx2y') -> information unique to the x-history
      ('Beer', 'UIy2y') -> information unique to the y-history
      ('Beer', 'Syn')   -> synergy
    all in bits.  Transfer entropy I(R; X | V) equals UIx2y + Syn.

    Raises
    ------
    ValueError
        If x and y differ in length, a lag list is empty, a lag is positive,
        or the series are not longer than the largest lag.
    """
    x, xstates = SimplifyState(np.asarray(x))
    y, ystates = SimplifyState(np.asarray(y))
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    if len(xhist_list) == 0 or len(yhist_list) == 0:
        raise ValueError("xhist_list and yhist_list must be non-empty")
    lags = list(xhist_list) + list(yhist_list)
    if max(lags) > 0:
        raise ValueError("history lags must be <= 0, e.g. [-1, -2]")
    # Longest lag over both lists, so the lag lists need not be sorted.
    maxlag = -min(lags)
    if maxlag >= len(y):
        raise ValueError("series are too short for the requested lags")

    # Each history uses its own lag list; all are aligned with y[maxlag:].
    xhist = _lagged_code(x, xhist_list, maxlag, len(xstates))
    yhist = _lagged_code(y, yhist_list, maxlag, len(ystates))
    # Float avoids overflow in `max + 1` for small integer dtypes.
    ynow = y[maxlag:].astype(float)

    # One bin per observed state; the extra right edge closes the last bin.
    xedges = np.append(np.unique(xhist), xhist.max() + 1)
    vedges = np.append(np.unique(yhist), yhist.max() + 1)
    redges = np.append(np.unique(ynow), ynow.max() + 1)
    dat = np.column_stack((xhist, yhist, ynow))
    N, _ = np.histogramdd(dat, bins=(xedges, vedges, redges))

    # Joint and marginal distributions; axis 0 -> x, 1 -> v, 2 -> r.
    # keepdims keeps every array broadcastable against the 3-D joint.
    pxvr = N / N.sum()
    pxv = pxvr.sum(axis=2, keepdims=True)
    pxr = pxvr.sum(axis=1, keepdims=True)
    pvr = pxvr.sum(axis=0, keepdims=True)
    px = pxvr.sum(axis=(1, 2), keepdims=True)
    pv = pxvr.sum(axis=(0, 2), keepdims=True)
    pr = pxvr.sum(axis=(0, 1), keepdims=True)

    # Empty cells give 0 * log2(0) = nan, which nansum drops (0 log 0 := 0);
    # the divide/invalid warnings they raise are expected, so silence them.
    with np.errstate(divide='ignore', invalid='ignore'):
        xr_terms = pxr * np.log2(pxr / px / pr)
        vr_terms = pvr * np.log2(pvr / pv / pr)
        xvr_terms = pxvr * np.log2(pxvr / pxv / pr)
    MIxr = np.nansum(xr_terms)
    MIvr = np.nansum(vr_terms)
    MIxvR = np.nansum(xvr_terms)
    # p(r)-weighted specific information per y-now value; summing the
    # element-wise minimum gives the Williams-Beer I_min redundancy.
    PI_xR = np.nansum(xr_terms, axis=(0, 1))
    PI_vR = np.nansum(vr_terms, axis=(0, 1))
    R = np.sum(np.minimum(PI_xR, PI_vR))

    return {
        ('Beer', 'Red'): R,
        ('Beer', 'UIx2y'): MIxr - R,
        ('Beer', 'UIy2y'): MIvr - R,
        ('Beer', 'Syn'): MIxvR - MIxr - MIvr + R,
    }