In [1]:
import pandas as pd
import numpy as np
import itertools
from tqdm import tqdm
import random
import functools
import copy

import pomegranate
from pomegranate import HiddenMarkovModel ,State
from pomegranate.distributions import IndependentComponentsDistribution
from pomegranate.distributions import NormalDistribution,DiscreteDistribution

import networkx as nx

import pickle as pkl
import matplotlib.pyplot as plt
import IPython
import seaborn as sns


# build the meta network

In [2]:
# Parameters.
# Larger configuration kept for reference:
# n_dim_of_chain=8
# n_of_chains = 5
# possible_number_of_walks = [1,2,3,4]

# Number of binary dimensions in a single chain (states are subsets of these dimensions).
n_dim_of_chain=3
# Number of chains (one network is built per chain).
n_of_chains = 2
# Candidate values for how many extra target dimensions each dimension may walk to.
possible_number_of_walks = [1,2]
# Upper bound on how many cross-chain state combinations are taken from the lazy product.
number_of_possible_states_limit = 10000000
# Compared against random.random(); any value >= 1 (such as 1.1) keeps every walk.
chance_to_be_in_path = 1.1
# Probability assigned to a dimension's own value in the discrete emission distributions.
prob_of_dim=0.7

In [3]:
def generate_all_binarys(n, arr, i, out=None):
    """Enumerate every binary vector of length ``n`` as a list of its active indices.

    The recursion fills ``arr[i:]`` with every 0/1 combination (0 first, then 1)
    and, for each completed vector, appends the sorted indices of its non-zero
    entries to the accumulator.

    Args:
        n: length of the binary vector.
        arr: mutable scratch list of length ``n``; it is overwritten in place.
        i: index of the next position to fill (callers start with 0).
        out: list that receives the results. When omitted, the module-level
            ``all_binarys`` list is used, which is the original behaviour.

    Returns:
        None. Results are appended to ``out`` (or to ``all_binarys``).
    """
    if out is None:
        # Backward-compatible default: accumulate into the notebook-level list.
        out = all_binarys

    if i == n:
        # np.nonzero builds its own array, so no defensive copy of arr is needed.
        out.append(np.nonzero(arr)[0].tolist())
        return

    for bit in (0, 1):
        arr[i] = bit
        generate_all_binarys(n, arr, i + 1, out)

def build_IX_mock(n, possible_number_of_walks):
    """Build a random mock of the per-dimension walk table.

    For every dimension ``cell_idx`` in ``range(n)`` a number of walks is drawn
    from ``possible_number_of_walks`` and that many distinct dimensions are
    sampled from ``range(n)``. The dimension itself is then added to the set.
    A draw that already contained ``cell_idx`` is rejected and repeated, so each
    accepted entry holds exactly ``number_of_walks + 1`` sorted, unique indices.

    Args:
        n: number of dimensions in a chain.
        possible_number_of_walks: candidate walk counts to sample from.

    Returns:
        dict mapping each dimension index to a sorted numpy array of the
        dimensions it may walk to (always including itself).

    Raises:
        Exception: if the largest requested walk count cannot fit in ``n``
            dimensions once the dimension itself is included.
    """
    largest_walk_count = max(possible_number_of_walks)
    if largest_walk_count + 1 > n:
        raise Exception( f"there is no {max(possible_number_of_walks)} possible")

    imx_possible_walks = {}
    for cell_idx in range(n):
        # Rejection sampling: redraw until the draw did not already contain cell_idx.
        while True:
            number_of_walks = np.random.choice(possible_number_of_walks)
            drawn_targets = np.random.choice(range(n), number_of_walks, False)
            walks_idx = np.unique(np.append(drawn_targets, np.array(cell_idx)))
            if len(walks_idx) == number_of_walks + 1:
                break
        imx_possible_walks[cell_idx] = walks_idx
    return imx_possible_walks
        
def build_network_from_IX(n,all_binarys,imx_possible_walks) :
    """Build the state -> reachable-states network of a single chain.

    Each state is a list of active dimension indices. Every active dimension
    picks one target from ``imx_possible_walks``; the set of picked targets is
    one candidate next state. Walks that lead back to the current state are
    dropped.

    Args:
        n: number of dimensions in the chain (not used by the computation).
        all_binarys: iterable of states, each a list of active dimension indices.
        imx_possible_walks: mapping dimension index -> iterable of target dimensions.

    Returns:
        dict mapping ``tuple(state)`` to a list of ``frozenset`` next states.
        Duplicate next states are kept, as before. The empty state is skipped.
    """
    network_dic = {}  # key: cell vec, value: all connected cell vecs

    for cell_vec in tqdm(all_binarys):
        if len(cell_vec) == 0:
            continue

        current_state = frozenset(cell_vec)
        targets_per_active_dim = [imx_possible_walks[cell_idx] for cell_idx in cell_vec]
        all_possible_walks_from_cell = [
            frozenset(comb) for comb in itertools.product(*targets_per_active_dim)
        ]

        # We don't want a transition to the same state. Compare as sets: the
        # iteration order of a frozenset is arbitrary, so comparing
        # tuple(walk) with tuple(cell_vec) can miss a self-transition.
        network_dic[tuple(cell_vec)] = [
            walk for walk in all_possible_walks_from_cell if walk != current_state
        ]
    return network_dic

def build_model_networks(n_of_chains,n,all_binarys,imx_possible_walks) :
    """Build one network per chain, keyed by chain index.

    Every chain is built by calling ``build_network_from_IX`` with the same
    arguments, once per chain.

    Returns:
        dict mapping chain index (0 .. n_of_chains - 1) to that chain's network.
    """
    return {
        chain_idx: build_network_from_IX(n, all_binarys, imx_possible_walks)
        for chain_idx in range(n_of_chains)
    }

def build_pathways_mock(imx_possible_walks) :
    """Assign a random pathway id (1-4) to every vector in ``imx_possible_walks``.

    Args:
        imx_possible_walks: iterable of vectors (each convertible with ``tuple``).
            NOTE(review): iterating a dict yields its keys, so passing the dict
            returned by ``build_IX_mock`` would hand integer keys to ``tuple`` —
            confirm which vectors this is meant to receive.

    Returns:
        dict mapping ``tuple(vec)`` to a pathway id drawn with ``np.random.choice``.
    """
    hard_walk_to_pathways_map = {}
    for vec in imx_possible_walks:
        # Store into the dict that is returned; the old code wrote to an
        # undefined name and raised NameError on the first iteration.
        hard_walk_to_pathways_map[tuple(vec)] = np.random.choice([1,2,3,4])
    return hard_walk_to_pathways_map

# Module-level accumulator that generate_all_binarys appends to (the function returns nothing).
all_binarys = [] 
# Enumerate every subset of the n_dim_of_chain dimensions as a list of active indices.
generate_all_binarys(n_dim_of_chain,[None]*n_dim_of_chain,0)
# Random mock of the allowed per-dimension walks; the same table is passed to every chain.
imx_possible_walks = build_IX_mock(n_dim_of_chain,possible_number_of_walks)
# One state -> reachable-states network per chain, all built from the same inputs.
model_networks = build_model_networks(n_of_chains,n_dim_of_chain,all_binarys,imx_possible_walks)

100%|██████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<00:00, 7977.75it/s]
100%|████████████████████████████████████████████████████████████████████████████████████████████| 8/8 [00:00<?, ?it/s]


In [4]:
# Lazily enumerate every naive combination of current states across chains:
# one state key taken from each chain's network.
all_curr_state_comb_between_networks = itertools.product(*model_networks.values())

# We believe this can later be filtered based on the real data; for now just cap the count.
filtered_curr_state_comb_between_networks = itertools.islice(
    all_curr_state_comb_between_networks,
    number_of_possible_states_limit,
)

# Two independent views of the same lazy stream: one feeds the walk lookup below,
# the other is zipped back with the results later on.
(
    filtered_curr_state_comb_between_networks_1,
    filtered_curr_state_comb_between_networks_2,
) = itertools.tee(filtered_curr_state_comb_between_networks)

# For every combination of states, collect each chain's list of possible walks
# from its current state (still lazy, nothing is computed yet).
all_walks_across_chains = (
    [model_networks[net_idx][chain_state] for net_idx, chain_state in enumerate(comb)]
    for comb in filtered_curr_state_comb_between_networks_1
)

In [5]:
def is_walk_in_path(x,_path,chance_to_be_in_path):
    """Randomly decide whether a walk belongs to a pathway.

    ``x`` and ``_path`` are accepted but not inspected; the decision is a single
    uniform draw compared against ``chance_to_be_in_path``.

    Returns:
        True when the draw from ``random.random()`` is below the threshold.
    """
    draw = random.random()
    return draw < chance_to_be_in_path

def return_filtered_walks_per_curr_chain(_walks_per_curr_chain,path,chance_to_be_in_path):
    """Keep, for every chain, only the walks accepted by ``is_walk_in_path``.

    Args:
        _walks_per_curr_chain: iterable with one collection of walks per chain.
        path: pathway id forwarded to ``is_walk_in_path``.
        chance_to_be_in_path: threshold forwarded to ``is_walk_in_path``.

    Returns:
        list of lists, one per chain, holding the accepted walks in their original order.
    """
    return [
        [walk for walk in chain_walks if is_walk_in_path(walk, path, chance_to_be_in_path)]
        for chain_walks in _walks_per_curr_chain
    ]

In [6]:
#we now build diffrent stream for each pathway - **only comb where all the walks are in the same pathway are possible **
def filter_comb_of_walks_across_chains(_walks_per_curr_comb,chance_to_be_in_path,paths=(1,2,3,4,5)):
    """Collect the unique cross-chain walk combinations that share a pathway.

    For every pathway the per-chain walks are filtered with
    ``return_filtered_walks_per_curr_chain``. If every chain still has at least
    one walk, the cartesian product of the per-chain walks is added to the
    result. Duplicates are dropped and first-seen order is kept.

    Args:
        _walks_per_curr_comb: one collection of walks per chain.
        chance_to_be_in_path: threshold forwarded to the per-walk filter.
        paths: pathway ids to try; defaults to the previously hard-coded 1..5.

    Returns:
        list of tuples, each holding one walk per chain.
    """
    all_combs_of_walks_all_paths = []
    seen = set()  # O(1) duplicate detection for hashable combinations

    for _path in paths:
        # keep only walks in the pathway
        _walks_per_curr_comb_in_path = return_filtered_walks_per_curr_chain(_walks_per_curr_comb,_path,chance_to_be_in_path)

        # A chain with no walk in this pathway makes the whole product empty, so skip early.
        if any(len(_walks) == 0 for _walks in _walks_per_curr_comb_in_path):
            continue

        for _comb in itertools.product(*_walks_per_curr_comb_in_path):
            try:
                if _comb in seen:
                    continue
                seen.add(_comb)
            except TypeError:
                # Unhashable walks (e.g. lists): fall back to the linear equality scan.
                if _comb in all_combs_of_walks_all_paths:
                    continue
            all_combs_of_walks_all_paths.append(_comb)
    return all_combs_of_walks_all_paths

# Lazily turn each per-chain walk listing into its list of allowed cross-chain combinations.
filt_comb_walks_across_chains = (
    filter_comb_of_walks_across_chains(_walks_per_curr_comb, chance_to_be_in_path)
    for _walks_per_curr_comb in all_walks_across_chains
)

# Pair every combination of current states with its combinations of walks.
state_comb_to_walks_comb = zip(
    filtered_curr_state_comb_between_networks_2,
    filt_comb_walks_across_chains,
)

# Drop the state combinations that have no possible walk at all.
meta_state_comb_to_walks_comb = (
    _states_and_walks
    for _states_and_walks in state_comb_to_walks_comb
    if len(_states_and_walks[1]) > 0
)

Explanation of the result `state_comb_to_walks_comb` — each element is a pair `(states, walks)`:

`states` (element `[0]`): the current state of every network, one entry per chain.

`walks` (element `[1]`): the list of possible walk combinations; every row holds one next state per chain:

`walks[i][j]`: the state that chain `j` (currently in `states[j]`) walks to in combination `i`.

# pick sub network

In [7]:
# i=0

# for s in meta_state_comb_to_walks_comb :
#     i=i+1
# print(i)
# raise

In [8]:
def build_long_state_vector(set_of_states,n_dim_of_chain):
    """Merge per-chain states into one state over the concatenated dimensions.

    The dimensions of chain ``i`` are shifted by ``i * n_dim_of_chain`` so that
    every chain occupies its own index range in the combined vector.

    Args:
        set_of_states: sized sequence with one collection of active dimensions per chain.
        n_dim_of_chain: number of dimensions in a single chain.

    Returns:
        frozenset of the shifted active dimension indices of all chains.
    """
    n_chains = len(set_of_states)
    return frozenset(
        dim + chain_idx * n_dim_of_chain
        for chain_idx, small_state in zip(range(n_chains), set_of_states)
        for dim in small_state
    )

In [9]:
# Maximum number of (state combination, walks) samples used to build the pomegranate network.
size_of_pomegranate_network = 10000

# Take the first samples of the lazy stream and split them into two iterators:
# `_state_comb_to_walks_comb` is consumed right here, `state_comb_to_walks_comb`
# is consumed later when the HMM is built. Both are single-use.
state_comb_to_walks_comb,_state_comb_to_walks_comb =itertools.tee(itertools.islice(meta_state_comb_to_walks_comb,size_of_pomegranate_network))

# long current state (frozenset) -> list of long next states (frozensets)
state_comb_to_walks_comb_dict = {}

# `total=` gives the bar its upper bound; passing the int positionally makes
# tqdm treat it as the iterable, so the bar never knows its total.
with tqdm(total=size_of_pomegranate_network) as pbar :
    for sample in _state_comb_to_walks_comb :
        curr_state = build_long_state_vector(sample[0],n_dim_of_chain)

        next_possible = [build_long_state_vector(_next,n_dim_of_chain) for _next in sample[1]]

        state_comb_to_walks_comb_dict[curr_state] = next_possible
        pbar.update(1)

49it [00:00, 12284.57it/s]


# explore network

# build the pomegranate network

In [None]:
def return_relevant_multi_distribution(state_vactor,prob_of_dim,n_dim_of_chain,n_of_chains,dist_option = "discrete") :
    """Build the emission distribution of a state, one independent component per dimension.

    The state is first expanded into a multi-hot vector over all
    ``n_of_chains * n_dim_of_chain`` dimensions.

    Args:
        state_vactor: collection of active dimension indices of the state.
        prob_of_dim: probability given to a dimension's own value in the
            "discrete" option (the flipped value gets ``1 - prob_of_dim``).
        n_dim_of_chain: number of dimensions per chain.
        n_of_chains: number of chains.
        dist_option: "discrete" (default) for one ``DiscreteDistribution`` per
            dimension, or "normal" for a ``NormalDistribution`` centred on the
            dimension's value with standard deviation 0.1.

    Returns:
        IndependentComponentsDistribution over all dimensions.

    Raises:
        ValueError: if ``dist_option`` is not one of the supported options
            (this used to surface later as an unbound-local error).
    """
    if dist_option not in ("normal", "discrete"):
        raise ValueError(f"unknown dist_option {dist_option!r}; expected 'normal' or 'discrete'")

    multi_hot_vector_state = np.zeros(n_of_chains*n_dim_of_chain)
    multi_hot_vector_state[list(state_vactor)] = 1

    if dist_option == "normal":
        per_dim_distributions = [NormalDistribution(value, 0.1) for value in multi_hot_vector_state]
    else:
        per_dim_distributions = [
            DiscreteDistribution({value: prob_of_dim, (1 - value): (1 - prob_of_dim)})
            for value in multi_hot_vector_state
        ]
    return IndependentComponentsDistribution(per_dim_distributions)

def return_relevant_state(state_vector,prob_of_dim, n_dim_of_chain,n_of_chains) :
    """Create the pomegranate ``State`` for a long state vector.

    The emission distribution comes from ``return_relevant_multi_distribution``
    (with its default option) and the state is named after the sorted list of
    its active dimensions, e.g. ``"[2, 5]"``.
    """
    return State(
        return_relevant_multi_distribution(state_vector, prob_of_dim, n_dim_of_chain, n_of_chains),
        str(sorted(state_vector)),
    )

In [None]:
# state_vactor = frozenset({2, 5})
# multi_hot_vector_state = np.zeros((n_of_chains*n_dim_of_chain,1))
# multi_hot_vector_state[list(state_vactor)] = 1

# dist = return_relevant_multi_distribution(state_vactor)

In [None]:
# multi_hot_vector_state.T

In [None]:
# dist.probability(np.array([[1., 1., 0., 0., 0., 1.]]))
# dist.distributions

In [None]:
# For now we take a random number of states.
# NOTE: `state_comb_to_walks_comb` is a single-use iterator; re-running this cell
# without re-running the cell that creates it yields no samples.

first = True  # flips to False once the start transition has been added


def _get_or_create_pomp_state(model, holder, long_state):
    """Return the pomegranate State for ``long_state``, creating and registering it on first use."""
    pomp_state = holder.get(long_state)
    if pomp_state is None:
        pomp_state = return_relevant_state(long_state, prob_of_dim, n_dim_of_chain, n_of_chains)
        model.add_states(pomp_state)
        holder[long_state] = pomp_state
    return pomp_state


state_holder = {}  # long state (frozenset) -> pomegranate State already added to the model
markov_model = HiddenMarkovModel('first_try')
next_pomp_state = None

# `total=` sets the bar's upper bound; a positional int would be taken as the iterable.
with tqdm(total=size_of_pomegranate_network) as pbar :
    for sample in state_comb_to_walks_comb :
        curr_state = build_long_state_vector(sample[0],n_dim_of_chain)
        curr_pomp_state = _get_or_create_pomp_state(markov_model, state_holder, curr_state)

        for _next in sample[1] :
            next_possible = build_long_state_vector(_next,n_dim_of_chain)
            next_pomp_state = _get_or_create_pomp_state(markov_model, state_holder, next_possible)

            if first :
                # The very first current state is the only one wired to the model's start.
                markov_model.add_transition(markov_model.start,curr_pomp_state,probability =0.5)
                first = False
            markov_model.add_transition(curr_pomp_state,next_pomp_state,probability =0.1)
        pbar.update(1)

if next_pomp_state is None :
    # Without this guard the line below fails with an opaque NameError.
    raise RuntimeError("no transitions were added: state_comb_to_walks_comb produced no samples")

# The last next-state that was visited is the only one wired to the model's end.
markov_model.add_transition(next_pomp_state,markov_model.end,probability =0.1)

markov_model.bake()

In [None]:
print("finish bake")
# NOTE(review): a bare `raise` with no active exception raises RuntimeError. Here it
# appears to be a deliberate stop so that "Run All" halts after baking — confirm,
# and remove it before sharing the notebook.
raise

In [None]:
# states = [state.name for state in markov_model.states]
# Q = markov_model.dense_transition_matrix()

# G = nx.MultiDiGraph()
# labels={}
# edge_labels={}

# for i, origin_state in enumerate(states):
#     for j, destination_state in enumerate(states):
#         rate = Q[i][j]
#         if rate > 0:
#             G.add_edge(origin_state,
#                        destination_state,
#                        weight=rate,
#                        label="{:.02f}".format(rate))
#             edge_labels[(origin_state, destination_state)] = label="{:.02f}".format(rate)
            
# from nxviz import CircosPlot

# c = CircosPlot(G)
# c.draw()

# simulate samples

In [None]:
# Maximum number of steps taken in each simulated random walk.
size_of_possible_rw = 50
# Number of random-walk sequences to simulate.
number_of_seqs = 50000

In [None]:
def pick_random_next_stage(_possible_next_steps,state_comb_to_walks_comb_dict,counter = 0,max_attempts = 50) :
    """Pick a random next state that is itself a known state of the meta network.

    Candidates are drawn uniformly from ``_possible_next_steps``. A candidate
    that is not a key of ``state_comb_to_walks_comb_dict`` (it has no outgoing
    walks recorded) is removed from the list and another one is drawn.

    Args:
        _possible_next_steps: list of candidate next states. It is mutated:
            rejected candidates are removed, so pass a copy.
        state_comb_to_walks_comb_dict: mapping whose keys are the valid states.
            Membership is checked against it directly instead of scanning the
            module-level ``all_possible_states`` list, which holds the same keys.
        counter: number of attempts already used (kept for backward compatibility).
        max_attempts: give up once this many candidates have been rejected.

    Returns:
        The chosen state, or None when the candidates or the attempts run out.
    """
    while counter < max_attempts and len(_possible_next_steps) > 0 :
        candidate = random.choice(_possible_next_steps)
        if candidate in state_comb_to_walks_comb_dict :
            return candidate
        _possible_next_steps.remove(candidate)
        counter = counter + 1
    return None
    

In [None]:
# Optional filters (disabled): restrict the states to nodes present in the baked model
# and drop isolated nodes.
# nodes_in_model = [node.name for node in markov_model.graph.nodes]
# isolated = [node.name for node in nx.algorithms.isolate.isolates(markov_model.graph)]

# Every state that has at least one recorded outgoing walk (the keys of the walk dictionary).
all_possible_states = list(state_comb_to_walks_comb_dict.keys())
# all_possible_states = [_state for _state in all_possible_states if str(sorted(_state)) in nodes_in_model ]
# all_possible_states = [_state for _state in all_possible_states if str(sorted(_state)) not in isolated ]

# Number of candidate start states for the simulated walks.
n_of_states_in_meta_network = len(all_possible_states)

In [None]:
# Simulate `number_of_seqs` random walks over the meta network.
seqs = []
for i in range(number_of_seqs):
    # Every walk starts from a uniformly chosen known state.
    curr_random_state = random.choice(all_possible_states)
    seq = [curr_random_state]

    for j in range(size_of_possible_rw) :
        # Work on a copy: pick_random_next_stage removes rejected candidates in place.
        possible_next_steps = list(state_comb_to_walks_comb_dict[curr_random_state])
        curr_random_state = pick_random_next_stage(possible_next_steps,state_comb_to_walks_comb_dict)

        if curr_random_state is None :
            # Dead end: no valid next state, so the walk ends early. The restart
            # code that used to follow this `break` was unreachable and is removed.
            break

        seq.append(curr_random_state)

    seqs.append(seq)
    

In [None]:
# Export the model graph as {state object: [neighbour state objects]} ...
_G_dict_of_lists = nx.to_dict_of_lists(markov_model.graph)
# ... and swap the state objects for their names so the adjacency matrix is readable.
_G_dict_of_lists_clean = dict(
    (_src.name, [_dst.name for _dst in _dsts])
    for _src, _dsts in _G_dict_of_lists.items()
)

# NOTE(review): from_dict_of_lists is called without `create_using`, which by
# networkx's default yields an undirected Graph — confirm that losing the edge
# direction is intended here.
_G = nx.from_dict_of_lists(_G_dict_of_lists_clean)
adj_df = nx.to_pandas_adjacency(_G)

In [None]:
def return_multi_hot(state_set,n_dim_of_chain,n_of_chains) :
    """Encode a state as a flat multi-hot vector.

    Args:
        state_set: collection of active dimension indices.
        n_dim_of_chain: number of dimensions per chain.
        n_of_chains: number of chains.

    Returns:
        1-D float array of length ``n_of_chains * n_dim_of_chain`` holding 1 at
        the active indices and 0 elsewhere.
    """
    total_dims = n_of_chains * n_dim_of_chain
    multi_hot = np.zeros(total_dims)
    multi_hot[list(state_set)] = 1
    return multi_hot

def return_multi_hot_vectors(vectors,n_dim_of_chain,n_of_chains) :
    """Encode a sequence of states as an array with one multi-hot row per state."""
    rows = []
    for state_set in vectors:
        rows.append(return_multi_hot(state_set, n_dim_of_chain, n_of_chains))
    return np.array(rows)
    
# Alternative (disabled): sub-sample 8 states, with replacement, from every simulated walk.
# sampled_seqs = [return_multi_hot_vectors(random.choices(s,k=8),n_dim_of_chain,n_of_chains) for s in seqs]
# Encode every simulated walk as an array of multi-hot rows (one row per visited state).
sampled_seqs = [return_multi_hot_vectors(s,n_dim_of_chain,n_of_chains) for s in seqs]

In [None]:
# Draw one sample from the model; with path=True the result unpacks into the emitted
# sequence and a second value `p` (presumably the hidden-state path — confirm against
# the pomegranate docs). NOTE: this rebinds `seq`, which the simulation loop also used.
seq,p = markov_model.sample(path=True)

In [None]:
# Length of the sampled sequence.
len(seq)

In [None]:
# Display the second value returned by markov_model.sample(path=True).
p

In [None]:
# Ask the model to predict the hidden states for the sequence it just sampled.
markov_model.predict(seq)

In [None]:
def find_most_likely_states(markov_model,samples,k=3) :
    """Rank the model's emitting states for every distinct sample.

    Each state that has a distribution scores every distinct sample with
    ``state.distribution.probability``. Within each sample column the scores are
    replaced by their rank (0 = lowest probability) via a double argsort.

    Args:
        markov_model: object exposing ``states``; each state has ``name`` and
            ``distribution`` (states whose distribution is None are skipped).
        samples: iterable of observation vectors; duplicates are collapsed.
        k: accepted but not used by the current implementation.

    Returns:
        DataFrame indexed by (state, state_idx) with one column per distinct
        sample (named by ``str`` of the sample), holding the rank of each state.
    """
    distinct_samples = [list(sample) for sample in {tuple(sample) for sample in samples}]

    score_rows = [
        state.distribution.probability(distinct_samples).tolist() + [state.name, state_idx]
        for state_idx, state in enumerate(markov_model.states)
        if state.distribution is not None
    ]

    header = [str(sample) for sample in distinct_samples] + ["state", "state_idx"]
    scores = pd.DataFrame(data=score_rows, columns=header).set_index(["state", "state_idx"])
    return scores.apply(lambda column: column.argsort().argsort())

def return_corresponding_states_to_samples(markov_model,seq) :
    """Map every distinct sample to its best-ranked state.

    The columns of the ranking table from ``find_most_likely_states`` are
    processed in order. For each column the search starts at the row offset
    given by the ``state_idx`` of the previous winner, and the label of the
    highest-ranked remaining row is taken.

    Returns:
        dict mapping the column name of each sample to the chosen state name.
    """
    ranking = find_most_likely_states(markov_model, seq)

    sample_to_state = {}
    row_offset = 0
    for sample_key in ranking.columns:
        # idxmax returns the (state, state_idx) index label; the state_idx part
        # becomes the row offset for the next column.
        best_state, row_offset = ranking[sample_key].iloc[row_offset:].idxmax()
        sample_to_state[sample_key] = best_state
    return sample_to_state

def create_continuous_observations(markov_model,seq,G) : 
    """Return the simple paths in ``G`` between the states matched to the first changing pair of samples.

    Each sample in ``seq`` is matched to a state name via
    ``return_corresponding_states_to_samples``. The function walks over the
    consecutive sample pairs, skips pairs whose samples are element-wise equal,
    and for the first pair that differs returns the generator produced by
    ``all_simple_paths`` (with cutoff 50) between the two matched states.

    Returns:
        The ``all_simple_paths`` result for the first differing pair, or None
        when no consecutive samples differ.

    NOTE(review): this looks unfinished — ``new_seq`` is never filled and the
    ``return`` sits inside the loop, so later pairs are never processed.
    """
    sample_to_state = return_corresponding_states_to_samples(markov_model,seq)
    
    new_seq = []  # NOTE(review): never used after this point
    for _curr_sample,_next_sample in zip(seq,seq[1:]) : 
        # consecutive identical observations carry no transition
        if (all(_curr_sample == _next_sample)):
            continue
        # look the states up by the string form of the sample, matching the ranking table's column names
        _curr_state = sample_to_state[str(_curr_sample.tolist())]
        _next_state = sample_to_state[str(_next_sample.tolist())]
        _simple_paths = nx.simple_paths.all_simple_paths(G,_curr_state,_next_state,cutoff=50)
        # returns on the first differing pair
        return _simple_paths



# model with samples

In [None]:
# plt.imshow(markov_model.dense_transition_matrix())

In [None]:
# Fit the model on the simulated multi-hot sequences using 3 parallel jobs; the fit's
# return value is kept in `improvement` (presumably the log-probability improvement — confirm).
improvement = markov_model.fit(sampled_seqs,n_jobs=3)

In [None]:
# Heatmap of the model's dense state-to-state transition matrix.
fig = plt.figure(figsize=(25,25))
transition_matrix = markov_model.dense_transition_matrix()
sns.heatmap(transition_matrix, cmap='viridis')

In [None]:
# Rebuild the adjacency table: export the model graph as
# {state object: [neighbour state objects]} ...
_G_dict_of_lists = nx.to_dict_of_lists(markov_model.graph)
# ... and swap the state objects for their names so the adjacency matrix is readable.
_G_dict_of_lists_clean = dict(
    (_src.name, [_dst.name for _dst in _dsts])
    for _src, _dsts in _G_dict_of_lists.items()
)

# NOTE(review): from_dict_of_lists is called without `create_using`, which by
# networkx's default yields an undirected Graph — confirm that losing the edge
# direction is intended here.
_G = nx.from_dict_of_lists(_G_dict_of_lists_clean)
adj_df = nx.to_pandas_adjacency(_G)

In [None]:
# Heatmap of the model graph's adjacency table, with rows and columns sorted by state name.
fig = plt.figure(figsize=(25,25))
sns.heatmap(adj_df.sort_index().sort_index(axis=1), cmap='viridis')

In [None]:
# One record per observed transition in the simulated sequences:
# [active indices of the previous row, active indices of the current row, 1].
# The indices are stored as the string form of the numpy index array, and the
# trailing 1 is a counter that is summed later. Built directly as a list
# instead of using a comprehension only for its append side effects.
all_walks_seqs = [
    [str(np.where(_prev_sample)[0]), str(np.where(_curr_sample)[0]), 1]
    for _samples_seq in sampled_seqs
    for _prev_sample, _curr_sample in zip(_samples_seq, _samples_seq[1:])
]


In [None]:
# One row per observed transition: from-state, to-state and a counter of 1.
all_walks_seqs_df = pd.DataFrame(columns=["from","to","walk"],data=all_walks_seqs)
# Sum the counters to get how many times each (from, to) pair was observed.
agg_walks_seqs_df = all_walks_seqs_df.groupby(["from","to"]).sum()
# Reshape into a from x to table; pairs that were never observed become 0.
walks_seqs_df = agg_walks_seqs_df.reset_index().pivot(index="from",columns="to").fillna(0)

In [None]:
# Heatmap of the boolean mask marking (from, to) pairs observed more than 5 times.
fig = plt.figure(figsize=(25,25))
sns.heatmap((walks_seqs_df>5).sort_index().sort_index(axis=1), cmap='viridis')

# test

In [None]:
import numpy as np
# Toy emission table for the experiments below; later cells index it as
# e_table[state][obs], i.e. rows are states and columns are observations.
e_table = np.array([[1,5,12],[4,3,12],[6,1,3]])
e_table

In [None]:
# For every observation (a column of e_table) keep the indices of the k states
# with the highest table value, best first.
best_states_idx_per_sample = []
k=2
for states_per_obs in e_table.T :
    # argsort is ascending: take the last k indices and reverse them
    best_states_idx_per_sample.append(states_per_obs.argsort()[-k:][::-1])
best_states_idx_per_sample

In [None]:
from itertools import product
from networkx.algorithms.shortest_paths import dijkstra_path,has_path
# Aliases used by the search loop below.
_emission_table = e_table
_top_e_per_state =best_states_idx_per_sample

# Lazy cartesian product: every way of picking one top-k state per observation.
# It is a single-use iterator.
all_comb_of_possible_states = product(*_top_e_per_state)

In [None]:
def __viterbi_fixed_states(model,comb) : 
    # NOTE(review): unfinished stub. It iterates over the states in `comb` but
    # never uses `model`, `_curr_state` or `_next_state`, calls dijkstra_path with
    # a single undefined name `b` (elsewhere in this notebook dijkstra_path is
    # called with a graph, a source and a target), and returns nothing.
    _curr_state = "first"
    
    for _state in comb : 
        _next_state = _state
        dijkstra_path(b)
        

# Search over all combinations of top-k states for the best-scoring combination and path.
# -10000 is a sentinel that the first evaluated combination is expected to beat.
best_log_pos = -10000
best_path = None
for comb in all_comb_of_possible_states :
    # Sum the table entries of the chosen state for each observation position.
    _log_pos_states = sum([_emission_table[state][obs] for state,obs in zip(comb,range(len(comb)))])
    # NOTE(review): __viterbi_fixed_states is defined with two parameters (model, comb)
    # but is called with none, so this line raises TypeError as written.
    _best_path = __viterbi_fixed_states()
    _log_pos_path = markov_model.log_probability(_best_path)
    _log_pos = _log_pos_path + _log_pos_states

    # keep the best total score seen so far
    if _log_pos > best_log_pos :
        best_log_pos = _log_pos
        best_path = _best_path


In [None]:
# Collect a chain of nodes in which each node is reachable from the previously
# collected one. The chain is anchored at the node with enumeration index 3.
g = markov_model.graph
l=[]
# `pre` is the last collected node. It used to be read before it was assigned
# (NameError on a fresh kernel), so start with None and skip until the anchor.
# NOTE(review): skipping the nodes before index 3 is an assumption about the
# intent — confirm.
pre = None
for i,n in enumerate(g.nodes()) :
    if i==3 :
        pre = n
        l.append(n)
        continue
    if pre is None :
        # nodes before the anchor have nothing to be compared against
        continue
    if has_path(g,pre,n) :
        l.append(n)
        pre = n
         


In [None]:
# Number of nodes collected into the reachable chain.
len(l)

In [None]:
import math

def func(u, v, d):
    """Edge-weight callback: exp of the edge's 'weight' attribute (1 when absent).

    ``u`` and ``v`` (the edge endpoints) are accepted but not used; ``d`` is the
    edge attribute mapping.
    """
    return math.exp(d.get('weight', 1))


# Names of the states on the shortest path between the 2nd and the 7th collected
# node, using `func` as the edge weight. Requires `l` to hold at least 7 nodes.
[s.name for s in dijkstra_path(g,l[1],l[6],weight=func)]


In [None]:
# NOTE(review): looks like leftover exploration; it is not clear that the node view
# has a `first` attribute — confirm or remove.
g.nodes.first

In [None]:
def viterbi(self,source_state,target_state):
    """Find the most probable transition path from one state index to another.

    Only transition probabilities are used (no emissions). ``V[t][st]`` holds
    the best probability of being in state ``st`` after exactly ``t``
    transitions starting from ``source_state``, together with the predecessor
    that achieves it. Paths of up to ``len(self.states) - 1`` transitions are
    considered, and the earliest time step with the highest probability of
    being in ``target_state`` is selected.

    Args:
        self: model exposing ``states`` and ``dense_transition_matrix()``,
            indexable as ``matrix[from_idx][to_idx]``.
        source_state: integer index of the start state.
        target_state: integer index of the end state.

    Returns:
        (opt, max_prob): the list of state indices on the best path, ending at
        ``target_state``, and its probability. When the target cannot be reached
        the result is ``([target_state], 0.0)``. This case used to raise
        TypeError because the time index was initialised as a float.
    """
    n_states = len(self.states)
    states = list(range(n_states))
    trans_p = self.dense_transition_matrix()

    # t = 0: all probability mass sits on the source state.
    V = [{st: {"prob": 1 if st == source_state else 0, "prev": None} for st in states}]

    # Run Viterbi for t > 0.
    for t in range(1, n_states):
        previous_layer = V[t - 1]
        layer = {}
        for st in states:
            # Start from the first state; strict '>' keeps the lowest-index predecessor on ties.
            prev_st_selected = states[0]
            max_tr_prob = previous_layer[prev_st_selected]["prob"] * trans_p[prev_st_selected][st]
            for prev_st in states[1:]:
                tr_prob = previous_layer[prev_st]["prob"] * trans_p[prev_st][st]
                if tr_prob > max_tr_prob:
                    max_tr_prob = tr_prob
                    prev_st_selected = prev_st
            layer[st] = {"prob": max_tr_prob, "prev": prev_st_selected}
        V.append(layer)

    # Pick the time step at which the target is most probable (earliest wins on ties).
    best_time_point_for_target = 0  # must stay an int: it is fed to range() below
    max_prob = 0.0
    for time, time_data in enumerate(V):
        target_prob = time_data[target_state]["prob"]
        if target_prob > max_prob:
            best_time_point_for_target = time + 1
            max_prob = target_prob

    # Follow the back-pointers from the target to the first time step.
    opt = [target_state]
    previous = target_state
    for t in range(best_time_point_for_target - 2, -1, -1):
        previous = V[t + 1][previous]["prev"]
        opt.insert(0, previous)

    return opt,max_prob


In [None]:
# Most probable transition path from state index 7 to state index 3, and its log-probability.
opt,max_prob = viterbi(markov_model,7,3)
print(opt)
print(np.log(max_prob))

In [None]:
# Display the model's dense transition matrix.
markov_model.dense_transition_matrix()

In [None]:
# Draw one emission from every state whose name contains neither 'start' nor 'end'.
[s.distribution.sample() for s in markov_model.states if (not ('start' in  s.name)) and (not ('end' in  s.name)) ]

In [None]:
# Draw one emission from the distribution of the state at index 2.
markov_model.states[2].distribution.sample()