In [None]:
import numpy as np
import warnings
import plotly.graph_objs as go
from plotly.offline import init_notebook_mode, iplot, plot
import plotly.figure_factory as ff
# Plotly offline notebook setup, executed once at import time (a module-level
# side effect). The heatmap helpers below call iplot() when in_notebook=True.
init_notebook_mode(connected=True)

def create_taskload_heatmap(taskload_parameters, taskload_array, in_notebook=False,
                            colorscale='Jet', annotate_heatmap=True):
    """
    Render the taskload array as an annotated plotly heatmap.

    Rows are job IDs and columns are time intervals. The x-axis tick labels
    (drawn along the top) show the total taskload of each time interval,
    i.e. the column sums of ``taskload_array``.

    Parameters
    ----------
    taskload_parameters : dict
        Must provide ``'number_jobs'`` and ``'number_time_intervals'``.
        Presumably these match ``taskload_array.shape``; this is not checked.
    taskload_array : numpy.ndarray
        2-D array indexed ``[job, time_interval]``.
    in_notebook : bool
        Display with ``iplot`` when True, otherwise with ``plotly.offline.plot``.
    colorscale : str
        Plotly colorscale name passed straight through.
    annotate_heatmap : bool
        When True each cell is labelled with its taskload value; when False
        every cell label is an empty string.
    """
    n_intervals = taskload_parameters['number_time_intervals']
    n_jobs = taskload_parameters['number_jobs']

    # Per-interval totals become the top-axis tick labels.
    interval_totals = taskload_array.sum(axis=0)

    hover_labels = np.empty((n_jobs, n_intervals), dtype='object')
    for job in range(n_jobs):
        for interval in range(n_intervals):
            hover_labels[job, interval] = 'time interval {}  job #{}'.format(interval, job)

    cell_labels = (taskload_array if annotate_heatmap
                   else np.empty(taskload_array.shape, dtype='str'))

    fig = ff.create_annotated_heatmap(
        z=taskload_array,
        x=list(range(n_intervals)),
        y=list(range(n_jobs)),
        xgap=1,
        ygap=1,
        annotation_text=cell_labels,
        colorscale=colorscale,
        text=hover_labels,
        hoverinfo='text',
    )

    top_axis = go.layout.XAxis(
        tickvals=list(range(n_intervals)),
        ticktext=[str(int(total)) for total in interval_totals],
        tickfont={'size': 10},
        tickangle=30,
        side='top',
    )
    fig.layout.update({'title': 'Taskload Matrix',
                       'xaxis': top_axis,
                       'yaxis': go.layout.YAxis(title='Job ID')})

    show = iplot if in_notebook else plot
    show(fig)
        
        
def is_solution_valid(taskload_array, solution, partition_max):
    """
    Check that no partition is overloaded in any time interval.

    For every time interval (column) and every partition index from 0 up to
    ``solution.max()``, the taskloads of the jobs assigned to that partition
    are summed. If a sum exceeds ``partition_max``, a warning naming the
    partition and interval is emitted. Checking continues so that every
    violation is reported.

    Parameters
    ----------
    taskload_array : numpy.ndarray
        2-D array indexed ``[job, time_interval]``. Must have the same number
        of rows as ``solution``.
    solution : numpy.ndarray
        2-D array indexed ``[job, time_interval]`` holding each job's
        partition index. Negative values are never checked, because only
        partitions ``0 .. solution.max()`` are examined.
    partition_max : number
        Maximum total taskload a single partition may carry in one interval.

    Returns
    -------
    bool
        True when no partition exceeds ``partition_max``; False otherwise.
        An empty solution (no jobs or no time intervals) is trivially valid.
    """
    if solution.size == 0:
        # Nothing is assigned, so nothing can be overloaded. This also avoids
        # ``max()`` raising ValueError on a zero-size array.
        return True

    max_partition_index = int(solution.max())

    valid = True
    for ti in range(solution.shape[1]):
        interval_loads = taskload_array[:, ti]
        interval_assignment = solution[:, ti]

        for partition in range(max_partition_index + 1):
            # Sum the taskloads of the jobs placed in this partition.
            job_workload = interval_loads[interval_assignment == partition].sum()

            if job_workload > partition_max:
                warnings.warn('solution not valid for partition ' + str(partition) + ' in time interval ' + str(ti))
                valid = False

    return valid

def make_solution_groups_array(model, taskload_parameters):
    """
    Build a ``[job, time_interval]`` array of the group each job is assigned to.

    Reads ``model.p.get_values()``, a mapping from index tuples
    ``(job, group, time_interval)`` to variable values. Every entry whose
    value rounds to 1 marks that job as assigned to that group in that time
    interval. Cells with no selected entry stay 0.0. If several groups are
    selected for the same (job, time_interval), the last one iterated wins.

    Parameters
    ----------
    model : object
        Solved model exposing ``p.get_values()``. Presumably ``p`` is a binary
        Pyomo variable; confirm against the model definition.
    taskload_parameters : dict
        Must provide ``'number_jobs'`` and ``'number_time_intervals'``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(number_jobs, number_time_intervals)``.
    """
    number_jobs = taskload_parameters['number_jobs']
    number_time_intervals = taskload_parameters['number_time_intervals']

    solution_array = np.zeros((number_jobs, number_time_intervals))

    for index, value in model.p.get_values().items():
        # Solvers can return binary values slightly off 1.0 (e.g. 0.9999999),
        # and unsolved variables come back as None. Rounding instead of
        # ``value == 1`` avoids silently dropping such assignments.
        if value is None or round(value) != 1:
            continue
        job, group, time_interval = index[0], index[1], index[2]
        solution_array[job, time_interval] = group

    return solution_array


def make_solution_agents_array(model, taskload_parameters):
    """
    Build a ``[job, time_interval]`` array of the agent handling each job.

    A job's agent is found in two steps:

    * ``model.p.get_values()`` (indexed ``(job, group, time_interval)``) gives
      the job's group in each interval.
    * ``model.q.get_values()`` (indexed ``(agent, group, time_interval)``) gives
      the agent assigned to that group in the same interval.

    An entry counts as selected when its value rounds to 1. Cells with no
    selected job assignment stay 0.0.

    Parameters
    ----------
    model : object
        Solved model exposing ``p.get_values()`` and ``q.get_values()``.
        Presumably both are binary Pyomo variables; confirm against the model.
    taskload_parameters : dict
        Must provide ``'number_jobs'`` and ``'number_time_intervals'``.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(number_jobs, number_time_intervals)``.

    Raises
    ------
    KeyError
        If a job is assigned to a group that has no selected agent in that
        time interval.
    """
    number_jobs = taskload_parameters['number_jobs']
    number_time_intervals = taskload_parameters['number_time_intervals']

    solution_array = np.zeros((number_jobs, number_time_intervals))

    def _selected_indexes(values):
        # Solver output for binary variables may be slightly off 1.0, and
        # unsolved variables are None, so round instead of testing ``== 1``.
        return [index for index, value in values.items()
                if value is not None and round(value) == 1]

    job_time_to_group = {(index[0], index[2]): index[1]
                         for index in _selected_indexes(model.p.get_values())}
    group_time_to_agent = {(index[1], index[2]): index[0]
                           for index in _selected_indexes(model.q.get_values())}

    for (job, time_interval), group in job_time_to_group.items():
        solution_array[job, time_interval] = group_time_to_agent[(group, time_interval)]

    return solution_array


def make_solution_heatmap(model, taskload_parameters, agent_parameters, taskload_array,
                          agent_heatmap=True, in_notebook=False, colorscale='Jet',
                          annotate_heatmap=True):
    """
    Plot a solved model's job assignments as an annotated plotly heatmap.

    Cell colour is the assigned agent (``agent_heatmap=True``, from
    ``make_solution_agents_array``) or the assigned group (otherwise, from
    ``make_solution_groups_array``). Cell labels show the taskload values
    from ``taskload_array``, and the top-axis tick labels show each time
    interval's total taskload.

    Parameters
    ----------
    model : object
        Solved model forwarded to the solution-array builders.
    taskload_parameters : dict
        Must provide ``'number_jobs'`` and ``'number_time_intervals'``.
    agent_parameters : object
        Not read by this function.
    taskload_array : numpy.ndarray
        2-D array indexed ``[job, time_interval]``.
    agent_heatmap : bool
        Colour by agent when True, by group when False.
    in_notebook : bool
        Display with ``iplot`` when True, otherwise with ``plotly.offline.plot``.
    colorscale : str
        Plotly colorscale name passed straight through.
    annotate_heatmap : bool
        When True cells are labelled with taskload values; when False every
        cell label is an empty string.
    """
    n_jobs = taskload_parameters['number_jobs']
    n_intervals = taskload_parameters['number_time_intervals']

    build_solution = (make_solution_agents_array if agent_heatmap
                      else make_solution_groups_array)
    solution_array = build_solution(model, taskload_parameters)

    # Per-interval totals become the top-axis tick labels.
    interval_totals = taskload_array.sum(axis=0)

    hover_labels = np.empty((n_jobs, n_intervals), dtype='object')
    for job in range(n_jobs):
        for interval in range(n_intervals):
            hover_labels[job, interval] = 'time interval {}  job #{}'.format(interval, job)

    cell_labels = (taskload_array if annotate_heatmap
                   else np.empty(taskload_array.shape, dtype='str'))

    fig = ff.create_annotated_heatmap(
        z=solution_array,
        x=list(range(n_intervals)),
        y=list(range(n_jobs)),
        xgap=1,
        ygap=1,
        annotation_text=cell_labels,
        colorscale=colorscale,
        text=hover_labels,
        hoverinfo='text',
        showscale=True,
    )

    top_axis = go.layout.XAxis(
        tickvals=list(range(n_intervals)),
        ticktext=[str(int(total)) for total in interval_totals],
        tickfont={'size': 10},
        tickangle=30,
        side='top',
    )
    fig.layout.update({'title': 'Job to Console Solution',
                       'xaxis': top_axis,
                       'yaxis': go.layout.YAxis(title='Job ID')})

    show = iplot if in_notebook else plot
    show(fig)
        
        

def get_solution_stats(group_solution_array, agent_solution_array):
    """
    Summarise how a solution's assignments change across time intervals.

    Both arrays are 2-D and indexed ``[job, time_interval]``. Presumably they
    come from ``make_solution_groups_array`` and
    ``make_solution_agents_array``; confirm at the call site.

    Returns
    -------
    dict
        ``'time_intervals_working'``
            Sum over time intervals of the number of distinct group values
            in that interval's column.
        ``'group_job_reconfig_count'``
            Number of (job, interval) pairs where the job's group at the
            next interval differs from its group at this interval.
        ``'group_reconfig_count'``
            Number of consecutive interval pairs whose group columns differ
            for at least one job.
        ``'agent_reconfig_count'``
            For each consecutive interval pair whose agent columns differ,
            the smaller of the two columns' distinct-agent counts, summed.
    """

    def _nunique_per_column(arr):
        # Count distinct values per column: sort each column and count the
        # boundaries between runs of equal values.
        if arr.shape[0] == 0:
            # No rows means no values. The sort-and-compare trick would
            # otherwise report 1 per column.
            return np.zeros(arr.shape[1:], dtype=int)
        ordered = np.sort(arr, axis=0)
        return (ordered[1:] != ordered[:-1]).sum(axis=0) + 1

    solution_stats = {}

    solution_stats['time_intervals_working'] = int(_nunique_per_column(group_solution_array).sum())

    # changed[i, j] is True when job i's group differs between interval j and j+1.
    changed = group_solution_array[:, 1:] != group_solution_array[:, :-1]
    solution_stats['group_job_reconfig_count'] = int(changed.sum())
    solution_stats['group_reconfig_count'] = int(changed.any(axis=0).sum())

    agent_reconfig_count = 0
    for j in range(agent_solution_array.shape[1] - 1):
        current = agent_solution_array[:, j]
        following = agent_solution_array[:, j + 1]
        if not np.array_equal(current, following):
            agent_reconfig_count += min(len(np.unique(current)),
                                        len(np.unique(following)))

    solution_stats['agent_reconfig_count'] = agent_reconfig_count

    return solution_stats