In [123]:
import numpy as np
from copy import deepcopy
from normalizer import Scale
import math

In [124]:
class BaseStateConstructor:
    """
    Base node of a state-constructor graph.

    Calling a node with a raw observation first evaluates every parent node
    on that same observation, then hands the list of parent outputs to
    ``process_observation``. A node without parents receives the raw
    observation itself, wrapped in a one-element list.

    Subclasses implement ``process_observation``.
    """

    def __init__(self):
        # Upstream nodes whose outputs feed this node; empty for a root node.
        self.parents = []
        # Only stored by set_children; nothing in this class reads it.
        self.children = None

    def __call__(self, o):
        """
        Evaluates this node on the raw observation ``o`` and returns the
        result of ``process_observation``.
        """
        if len(self.parents) > 0:
            # Every parent is called with the same raw observation.
            upstream = [parent(o) for parent in self.parents]
        else:
            # base case: a root node consumes the raw observation directly
            upstream = [o]
        return self.process_observation(upstream)

    def set_parents(self, parents):
        """Replaces the collection of parent nodes."""
        self.parents = parents

    def set_children(self, children):
        """Replaces the stored children."""
        self.children = children

    def process_observation(self, o_parents):
        """
        takes a list (one entry per parent, or the raw observation alone
        for a root node) and returns a VECTOR

        Must be overridden; the base implementation raises
        NotImplementedError.
        """
        raise NotImplementedError
        

In [125]:
class Identity(BaseStateConstructor):
    """
    Passes its single input through unchanged
    """
    def process_observation(self, o_parents):
        """
        takes a one-element list and returns that element as-is
        """
        n_inputs = len(o_parents)
        assert n_inputs == 1
        passthrough = o_parents[0]
        return passthrough
    

In [126]:

class KOrderHistory(BaseStateConstructor):
    """
    Keeps a running list of the most recent k observations
    """
    def __init__(self, k=1):
        super().__init__()
        self.k = k
        # Stored observations, oldest first.
        self.obs_history = []
        # Initialised here but not read or updated by this class.
        self.num_elements = 0

    def process_observation(self, o_parents):
        """
        takes a one-element list, records its entry in the history and
        returns the history as an array with the oldest entry first.

        While fewer than k observations have been recorded, the result is
        padded at the end with repeats of the newest observation so that it
        still has k entries.
        """
        assert len(o_parents) == 1
        self.obs_history.append(o_parents[0])

        # Discard the oldest entry once the history exceeds k.
        if len(self.obs_history) > self.k:
            del self.obs_history[0]

        # Work on a copy so padding never leaks into the stored history.
        window = deepcopy(self.obs_history)
        missing = self.k - len(window)
        if missing > 0:
            window.extend([window[-1]] * missing)

        return np.array(window)
    

In [127]:
class MemoryTrace(BaseStateConstructor):
    """
    Exponentially decayed running trace of the incoming observations:

        trace <- (1 - trace_decay) * observation + trace_decay * trace
    """
    def __init__(self, trace_decay):
        super().__init__()
        self.trace_decay = trace_decay
        # None until the first observation arrives.
        self.trace = None

    def process_observation(self, o_parents):
        """
        takes a one-element list, folds its entry into the trace and
        returns the updated trace
        """
        assert len(o_parents) == 1
        latest = o_parents[0]
        if self.trace is not None:
            decay = self.trace_decay
            self.trace = (1-decay)*latest + decay*self.trace
        else:
            # first observation received: it becomes the trace as-is
            # (the same object is stored, no copy is made)
            self.trace = latest
        return self.trace

In [128]:
class LastK(BaseStateConstructor):
    """
    Keeps only the last k entries of the incoming observation.

    A 1D observation is sliced to its last k elements. A 2D observation is
    sliced to its last k rows and then flattened into a 1D vector.
    """
    def __init__(self, k):
        super().__init__()
        self.k = k
        # Not used by this class; kept so existing attribute access still works.
        self.trace = None

    def process_observation(self, o_parents):
        """
        takes a one-element list and returns a VECTOR holding the last k
        entries (1D input) or the flattened last k rows (2D input).

        Raises ValueError when the observation is neither 1D nor 2D.
        """
        assert(len(o_parents)) == 1
        o_parent = o_parents[0]
        ndim = len(o_parent.shape)
        if ndim not in (1, 2):
            # Previously this fell through both branches and failed with an
            # UnboundLocalError on the return statement.
            raise ValueError(
                "LastK expects a 1D or 2D observation, got %d dimensions" % ndim
            )

        if self.k == 0:
            # -0 == 0, so [-0:] would select everything instead of nothing.
            tail = o_parent[:0]
        else:
            tail = o_parent[-self.k:]

        if ndim == 2:
            tail = tail.flatten()
        return tail

In [129]:
class Flatten(BaseStateConstructor):
    """
    Flattens its single input into a 1D vector.

    Inherits from BaseStateConstructor like every other constructor in this
    file, so that it is callable and can be wired into a graph with
    set_parents.
    """
    def process_observation(self, o_parents):
        """
        takes a one-element list and returns the result of calling
        ``.flatten()`` on its entry
        """
        assert(len(o_parents)) == 1
        return o_parents[0].flatten()

In [130]:
class Concatenate(BaseStateConstructor):
    """
    Joins the outputs of all parents along axis 0
    """
    def process_observation(self, o_parents):
        """
        takes a list of arrays and returns their concatenation along the
        first axis, in parent order
        """
        joined = np.concatenate(o_parents, axis=0)
        return joined

In [131]:
class Normalize(BaseStateConstructor):
    """
    Applies a Scale normalizer, built from scaler and bias, to its single
    input
    """
    def __init__(self, scaler, bias):
        super().__init__()
        # Scale comes from the normalizer module; how it combines scaler
        # and bias is defined there, not here.
        self.normalizer = Scale(scaler, bias)

    def process_observation(self, o_parents):
        """
        takes a one-element list and returns the normalizer applied to its
        entry
        """
        assert len(o_parents) == 1
        return self.normalizer(o_parents[0])

In [132]:
class WindowAverage(BaseStateConstructor):
    """
    Averages every window_size observations (rows) together
    """
    def __init__(self, window_size):
        super().__init__()
        self.window_size = window_size

    def process_observation(self, o_parents):
        """
        takes a one-element list holding a 2D array and returns a 2D array
        with one averaged row per window.

        Leading rows that do not fill a whole window are dropped, so the
        windows are aligned to the end of the input.
        """
        assert len(o_parents) == 1
        o = o_parents[0]
        assert len(o.shape) == 2

        width = self.window_size
        # Skip the leftover rows at the front so the rest splits evenly.
        leftover = o.shape[0] % width
        trimmed = o[leftover:]
        assert trimmed.shape[0] % width == 0

        # (n_windows, window_size, n_columns), averaged over each window
        windows = trimmed.reshape(-1, width, trimmed.shape[1])
        return np.mean(windows, axis=1)


In [133]:
class Beginning(BaseStateConstructor):
    """
    Selects the first row of a 2D input
    """
    def process_observation(self, o_parents):
        """
        takes a one-element list holding a 2D array and returns its first
        row
        """
        assert len(o_parents) == 1
        rows = o_parents[0]
        assert len(rows.shape) == 2
        first_row = rows[0, :]
        return first_row
    
    
class End(BaseStateConstructor):
    """
    Selects the last row of a 2D input
    """
    def process_observation(self, o_parents):
        """
        takes a one-element list holding a 2D array and returns its last
        row
        """
        assert len(o_parents) == 1
        rows = o_parents[0]
        assert len(rows.shape) == 2
        last_row = rows[-1, :]
        return last_row
    
class Mid(BaseStateConstructor):
    """
    Selects the middle row of a 2D input.

    For an odd number of rows this is the exact middle row; for an even
    number it is the first row of the second half.
    """
    def process_observation(self, o_parents):
        """
        takes a one-element list holding a 2D array and returns its middle
        row.

        An input with zero rows raises IndexError.
        """
        assert(len(o_parents)) == 1
        o = o_parents[0]
        assert(len(o.shape)==2)
        # Floor division matches the old ceil(n/2) for even n, but lands on
        # the true middle for odd n (ceil gave the last row for n=3 and ran
        # past the end for n=1).
        i = o.shape[0] // 2
        return o[i, :]