In [None]:
from collections import OrderedDict

import numpy as np
import xarray as xr
import pickle
import pandas as pd
import time
import os

import datetime
from sys import getsizeof,path


import matplotlib.pyplot as plt
import matplotlib

import theano
import theano.tensor as tt
import pymc3 as pm
# Select Theano's "fast_run" graph-optimization mode for everything compiled below.
theano.config.optimizer = "fast_run"

In [None]:
class ScanParameterEncoder(object):
    """Flatten named states and parameters into one list, and split it back.

    ``Encode`` yields every registered state value (in registration order)
    followed by every registered parameter value (in registration order);
    ``Decode`` reverses that layout into two name-keyed OrderedDicts.
    """

    def __init__(self):
        # name -> value registries; OrderedDict keeps registration order,
        # which defines the positions used by Encode/Decode.
        self.state = OrderedDict()
        self.param = OrderedDict()

    def AddState(self, name, state):
        """Register ``state`` under ``name`` (overwrites an existing entry)."""
        self.state[name] = state

    def AddParameter(self, name, param):
        """Register ``param`` under ``name`` (overwrites an existing entry)."""
        self.param[name] = param

    def Encode(self):
        """Return a new list: all state values first, then all parameter values."""
        flat = []
        flat.extend(self.state.values())
        flat.extend(self.param.values())
        return flat

    def Decode(self, values):
        """Split a sequence laid out like ``Encode()`` into ``(states, params)``.

        The first ``len(self.state)`` items are paired with the state names and
        the rest with the parameter names. ``zip`` stops at the shorter side, so
        surplus or missing items are not reported.
        """
        split_at = len(self.state)
        head, tail = values[:split_at], values[split_at:]
        return (
            OrderedDict(zip(self.state, head)),
            OrderedDict(zip(self.param, tail)),
        )
    
# Round-trip demo: register two states and two parameters, flatten them with
# Encode(), then recover the name-keyed OrderedDicts with Decode().
spe = ScanParameterEncoder()
spe.AddState("s1", np.zeros(shape=(2, 2)))
spe.AddState("s2", np.zeros(shape=(3, 3, 3)))
spe.AddParameter("p1", np.ones(shape=(4,)))
spe.AddParameter("p2", np.ones(shape=(3, 3)))

values = spe.Encode()
print("values", values)
print("decode", spe.Decode(values))

In [None]:
class TimeSplitLoop(ScanParameterEncoder):
    """Experimental two-resolution shift register driven by ``theano.scan``.

    Axis 0 of the state is split into a fine-grained region of
    ``major[0] * scaling`` rows followed by a rough-grained region of
    ``major[1]`` rows. Each scan step shifts the fine region ``scaling`` times
    (one row per sub-step, accumulating the row that falls off its end). It then
    shifts the rough region by one row and injects that accumulation at its front.

    Parameters
    ----------
    major : tuple of two ints, optional
        ``(fine_blocks, rough_rows)`` for axis 0; defaults to ``(3, 52)``.
        NOTE(review): ``run_fine_unrolled`` writes row ``major[0] * scaling``,
        so ``major[1]`` must be at least 1.
    minor : tuple of two ints, optional
        Same layout for axis 1; defaults to ``(3, 10)``. Only the total size
        ``minor[0] * scaling + minor[1]`` is used (in ``initial``).
    scaling : int
        Number of fine rows per rough row.
    other_dims : int
        Number of extra trailing length-1 axes appended to the state shape.
    """

    def __init__(self, major=None, minor=None, scaling=7, other_dims=0):
        # Initialise the inherited encoder registries; without this the
        # inherited AddState/AddParameter/Encode/Decode raise AttributeError.
        super(TimeSplitLoop, self).__init__()
        # `is None` rather than `== None`: `==` is ambiguous for array-like input.
        if major is None:
            major = (3, 52)
        if minor is None:
            minor = (3, 10)

        self.major = major
        self.minor = minor
        self.scaling = scaling
        self.other_dims = other_dims

    def update_diagonal(self):
        """Placeholder; not implemented yet."""
        pass

    def initial(self, dtype="float64"):
        """Return a symbolic all-ones initial state.

        The shape is ``(major[0]*scaling + major[1], minor[0]*scaling + minor[1])``
        followed by ``other_dims`` axes of length 1.
        """
        shape = [
            self.major[0] * self.scaling + self.major[1],
            self.minor[0] * self.scaling + self.minor[1],
        ] + [1] * self.other_dims
        return tt.ones(shape, dtype)

    def run_fine_unrolled(self, n=10, growth=1.01):
        """Build a scan over ``n`` steps and return the stacked symbolic states.

        The ``scaling`` fine sub-steps inside one scan step are unrolled with a
        Python loop at graph-construction time.

        Parameters
        ----------
        n : int
            Number of scan steps.
        growth : float
            Factor that row 0 is multiplied by on every fine sub-step
            (previously hard-coded to 1.01).

        Returns
        -------
        TensorVariable
            The ``theano.scan`` output, with the step axis first.
        """
        values = tt.cast([growth] * n, "float64")
        fine_end = self.major[0] * self.scaling  # first row of the rough region

        def step(value, state):
            # zeros_like makes acc exactly the shape/dtype of one state row,
            # including trailing `other_dims` axes, so it can be written back.
            acc = tt.zeros_like(state[0])
            for _ in range(self.scaling):
                # Shift fine-grained by 1, accumulate value from last fine-grained cell.
                acc += state[fine_end - 1]
                state = tt.set_subtensor(state[1:fine_end], state[:fine_end - 1])
                state = tt.set_subtensor(state[0], state[0] * value)

            # Shift rough-grained by 1, inject fine-accumulated at the beginning.
            state = tt.set_subtensor(state[fine_end + 1:], state[fine_end:-1])
            state = tt.set_subtensor(state[fine_end], acc)
            return state

        # scan returns (outputs, updates); the updates are not needed here.
        states, _updates = theano.scan(
            step, outputs_info=[self.initial()], sequences=[values]
        )
        return states
        

In [None]:
# Build the scan graph for a small configuration, then compile and evaluate it.
# The measured interval covers graph construction, compilation and execution.
tsl = TimeSplitLoop((2, 6), (2, 4))

print(tsl.initial().eval().shape)

t0 = time.perf_counter()  # monotonic, high-resolution clock for intervals
result = tsl.run_fine_unrolled()
states = result[-1].eval()  # evaluated final state (a numeric array)
t1 = time.perf_counter()

print("elapsed [s]:", t1 - t0)

# Symbolic graph of the scan output.
theano.printing.debugprint(result)
# `states` is already an evaluated array rather than a graph, so print its values directly.
print(states)

In [None]:
# Dict test: can `theano.scan` hand the recurrent state to `fn` by name?
# No. scan always calls `fn` with positional arguments (sequences, then previous
# outputs, then non_sequences), so `**kwargs` stays empty. A dict entry in
# `outputs_info` must use the documented keys "initial" (and optionally "taps");
# an arbitrary key such as "state" does not supply an initial state.

def fn(state):
    """Shift a 1-D state one position to the right; element 0 keeps its value.

    Still callable as ``fn(state=...)``, matching the previous keyword-only usage.
    """
    return tt.set_subtensor(state[1:], state[:-1])

initial = tt.cast(range(10), "float64")
# scan returns an (outputs, updates) tuple; only the outputs are evaluated.
r, r_updates = theano.scan(
    fn, outputs_info={"initial": initial, "taps": [-1]}, n_steps=10
)

print(r.eval())