# In [7]:
import numpy as np
from operator import itemgetter

# In [4]:
# Sample controller data. Each inner list holds the candidate values of one
# component for one entry (presumably one state -- confirm against the data
# source); the last entry offers only two candidates.
y1 = [
    [1, 1, 2],
    [1, 2, 1],
    [1, 2, 2],
    [1, 1],
]
y2 = [
    [0.2, 0.3, 0.2],
    [0.2, 0.2, 0.4],
    [0.2, 0.2, 0.6],
    [0.2, 0.3],
]

# y1 is deliberately used twice, as both the first and the third component.
Y_train = [y1, y2, y1]

# In [8]:
# Number of control-input components, i.e. how many per-component lists
# Y_train holds (and therefore the length of every zipped action tuple).
y_dim = len(Y_train)

def zipped_controller(Y):
    """Zip per-component value lists into per-state lists of action tuples.

    ``Y`` holds one entry per control-input component, and each entry holds
    one list of candidate values per state.  For example, with

        y1 = [[1, 1, 2],
              [1, 2, 1]]

        y2 = [[0.2, 0.3, 0.2],
              [0.2, 0.2, 0.4]]

    ``zipped_controller([y1, y2, y1])`` returns

        [[(1, 0.2, 1), (1, 0.3, 1), (2, 0.2, 2)],
         [(1, 0.2, 1), (2, 0.2, 2), (1, 0.4, 1)]]

    States may offer different numbers of candidates (ragged input is fine).
    An empty ``Y`` yields an empty controller.

    Raises:
        ValueError: if the components do not all cover the same number of
            states.
    """
    # Checked explicitly because zip() would silently drop trailing states.
    if len({len(component) for component in Y}) > 1:
        raise ValueError("all components must cover the same number of states")

    # Outer zip walks the states in lockstep across components; inner zip
    # pairs up the i-th candidate of every component into one action tuple.
    # Plain zip is used (rather than numpy stacking) so that both rectangular
    # and ragged inputs are handled identically.
    return [list(zip(*per_state)) for per_state in zip(*Y)]

def frequency_compute(li, dim=None):
    """Count how often each value occurs in every control-input component.

    ``li`` is a zipped controller: a list of rows, each row a list of action
    tuples.  The result is ``(histograms, total_tups)`` where ``histograms``
    has one dictionary per component mapping value -> number of occurrences,
    and ``total_tups`` is the number of tuples visited.

    A component that takes only one value throughout is reported as an empty
    dictionary.  For example, in

        [[(1, 0.2), (1, 0.3)],
         [(1, 0.2), (1, 0.4)]]

    value 1 is the only action available in the first component, so that
    component carries no information and should be ignored when choosing the
    max action.

    Args:
        li: the zipped controller.
        dim: number of components to count.  When omitted it is taken from
            the length of the first tuple encountered, so the function does
            not depend on any module-level state.  When given, only the
            first ``dim`` entries of each tuple are counted.

    If ``li`` contains no tuples and ``dim`` is omitted, the component count
    is unknown and an empty list of histograms is returned.

    Runs in a single pass over all tuples.
    """
    histograms = None if dim is None else [{} for _ in range(dim)]
    total_tups = 0
    for row in li:
        for tup in row:
            total_tups += 1
            if histograms is None:
                # First tuple seen: it fixes the number of components.
                histograms = [{} for _ in tup]
            for histogram, value in zip(histograms, tup):
                histogram[value] = histogram.get(value, 0) + 1

    if histograms is None:
        histograms = []

    # A dictionary with a single key means that component never varies, so
    # it is blanked out and later stages skip it.
    pruned_d_list = [{} if len(h) == 1 else h for h in histograms]

    return pruned_d_list, total_tups

def get_best(freq_list):
    """Pick the component and value that pruning should keep.

    ``freq_list`` holds one value -> count dictionary per control-input
    component.  The most frequent value of each component is determined, and
    the component whose most frequent value has the highest count wins
    (the earliest component wins ties).

    Returns ``(component_index, value)``.  Empty dictionaries are treated as
    ``(-1, -1)``, so when every dictionary is empty the result is ``(0, -1)``.

    Work is linear in the total number of dictionary entries.
    """
    winners = []
    best_index, best_count = 0, 0
    for position, histogram in enumerate(freq_list):
        if histogram:
            winner = max(histogram.items(), key=itemgetter(1))
        else:
            # Placeholder for a component that was pruned away.
            winner = (-1, -1)
        winners.append(winner)

        # Strict comparison keeps the first component on equal counts.
        if winner[1] > best_count:
            best_index, best_count = position, winner[1]

    return (best_index, winners[best_index][0])

def single_prune(controller):
    """Apply one round of max pruning to a zipped controller.

    The most frequent value along any single component is looked up via
    frequency_compute and get_best; every row then keeps only the tuples
    carrying that value in that component.  A row without any such tuple
    ends up empty.

    Makes a constant number of passes over the state-action pairs.
    """
    # The tuple count returned alongside the histograms is not needed here.
    histograms, _ = frequency_compute(controller)
    component, keep_value = get_best(histograms)

    pruned = []
    for row in controller:
        pruned.append([tup for tup in row if tup[component] == keep_value])
    return pruned

def determinize(full_controller):
    """Return a determinized controller using the max pruning strategy.

    single_prune is applied repeatedly until no row offers more than one
    action tuple.  The input list is not modified; a controller that is
    already deterministic (including an empty one) is returned as is.

    If a pruning pass removes nothing, the loop stops and the controller is
    returned in its current state instead of spinning forever; in that case
    some rows may still hold more than one tuple.

    Each pass is linear in the number of state-action pairs.
    """
    controller = full_controller
    # any() rather than max([...]) so an empty controller does not raise.
    while any(len(row) > 1 for row in controller):
        pruned = single_prune(controller)
        # Pruning only ever filters tuples, so an unchanged tuple count
        # means no progress was made and further passes would not help.
        if sum(map(len, pruned)) == sum(map(len, controller)):
            break
        controller = pruned
    return controller
        
# Determinize the controller built from the sample data in Y_train.
determinize(zipped_controller(Y_train))

# NOTE(review): the literal below appears to be the notebook output recorded
# for the call above; as a bare expression it has no effect at runtime.
[[(1, 0.2, 1)], [(1, 0.2, 1)], [(1, 0.2, 1)], [(1, 0.2, 1)]]