In [None]:
from collections import OrderedDict

import numpy as np
import xarray as xr
import pickle
import pandas as pd
import time
import os

import datetime
from sys import getsizeof,path


import matplotlib.pyplot as plt
import matplotlib

import theano
import theano.tensor as tt
import pymc3 as pm
# Set Theano's global graph-optimizer option to "fast_run".
theano.config.optimizer="fast_run"

In [None]:
def tt_lognormal(x, mu, sigma,clip=True):
    """Build the symbolic per-bin mass of a log-normal CDF along the last axis.

    The log-normal CDF ``0.5 * (1 + erf((log(x) - mu) / (sqrt(2) * sigma)))``
    is evaluated element-wise, then differenced along the last axis:

    * ``result[..., 0]`` is zero,
    * ``result[..., i] = cdf[..., i] - cdf[..., i-1]`` for ``i >= 1``.

    The result has the same last-axis length as the CDF tensor. The tail mass
    ``1 - cdf[..., -1]`` is not part of the result.

    Parameters
    ----------
    x : tensor-like
        Points at which the CDF is evaluated; needs at least one axis because
        the last axis is indexed and differenced.
    mu, sigma : tensor-like
        Parameters of the underlying normal in log-space; presumably they
        broadcast against ``x`` -- TODO confirm against callers.
    clip : bool
        When true, ``x`` and ``mu`` are clipped to ``[1e-12, 1e12]`` and
        ``sigma`` to ``[1e-9, 1e12]`` before use.
        NOTE(review): clipping ``mu`` to a positive range replaces any
        non-positive ``mu`` with ``1e-12``; confirm this is intended.

    Returns
    -------
    Theano tensor expression holding the differenced CDF.
    """
    if clip:
        x = tt.clip(x,1e-12,1e12)
        mu = tt.clip(mu,1e-12,1e12)
        sigma = tt.clip(sigma,1e-9,1e12)

    cdf = .5 * (1+tt.erf((tt.log(x)-mu)/(tt.sqrt(2.)*sigma)))
    # Leading zero entry, shaped like one slice of the CDF with a length-1 last axis.
    zero = tt.shape_padright(tt.zeros_like(cdf[...,-1]),1)
    return tt.concatenate([zero,cdf[...,1:]-cdf[...,:-1]],axis=-1)
    

In [None]:
class ScanParameterEncoder(object):
    """Keep named states and parameters and convert them to/from flat lists.

    The flat layout is ``[param_1, ..., param_n, state_1, ..., state_m, *extras]``.
    ``Encode`` produces the first ``n + m`` entries of that layout and
    ``Decode`` splits a flat sequence back into named dictionaries, so
    ``Decode(Encode())`` maps every value back to its own name.
    """

    def __init__(self):
        # Insertion-ordered so that list positions are stable.
        self.state = OrderedDict()
        self.param = OrderedDict()

    def AddState(self,name,state):
        """Register (or overwrite) the state stored under ``name``."""
        self.state[name] = state

    def AddParameter(self,name,param):
        """Register (or overwrite) the parameter stored under ``name``."""
        self.param[name] = param

    def GetStateList(self):
        """Return the state values as a list, in insertion order."""
        return list(self.state.values())

    def GetParameterList(self):
        """Return the parameter values as a list, in insertion order."""
        return list(self.param.values())

    def Encode(self):
        """Return parameters followed by states as one flat list.

        The order matches what ``Decode`` expects (parameters first), so the
        two methods round-trip.
        """
        return self.GetParameterList() + self.GetStateList()

    def Decode(self,values):
        """Split a flat sequence into ``(state, params, non_seq)``.

        ``values`` must be ordered as parameters, then states, then any
        remaining entries. Returns two ``OrderedDict`` objects keyed by the
        registered names plus the leftover tail of ``values``.
        """
        n_param = len(self.param)
        end = n_param + len(self.state)
        params = OrderedDict(zip(self.param.keys(),values[:n_param]))
        state = OrderedDict(zip(self.state.keys(),values[n_param:end]))
        non_seq = values[end:]
        return state,params,non_seq
    
# Exercise the encoder with dummy arrays and print the encoded list,
# the decoded result and the plain state list.
spe = ScanParameterEncoder()
for _name, _shape in (("s1", (2, 2)), ("s2", (3, 3, 3))):
    spe.AddState(_name, np.zeros(_shape))
for _name, _shape in (("p1", (4,)), ("p2", (3, 3))):
    spe.AddParameter(_name, np.ones(_shape))

values = spe.Encode()

print("values", values)
print("decode", spe.Decode(values))

print("list", spe.GetStateList())

In [None]:
class TimeSplitLoop(ScanParameterEncoder):
    """Two-resolution shift register advanced step by step with ``theano.scan``.

    Axis 0 of the state tensor consists of ``major[0]*scaling`` fine-grained
    rows followed by ``major[1]`` coarse-grained rows. Axis 1 has length
    ``minor[0]*scaling + minor[1]``; ``other_dims`` trailing axes of length 1
    are appended.

    Parameters
    ----------
    major, minor : pair of ints or None
        ``(fine_blocks, coarse_cells)`` for axis 0 and axis 1 respectively.
        ``None`` selects ``(3, 52)`` for ``major`` and ``(3, 10)`` for ``minor``.
    scaling : int
        Number of fine-grained rows per fine block; also the number of fine
        sub-steps unrolled inside one scan step.
    other_dims : int
        Number of trailing length-1 axes appended to the state.
    """

    def __init__(self,major=None,minor=None,scaling=7,other_dims=0):
        super(TimeSplitLoop,self).__init__()

        # Identity test: `== None` is evaluated element-wise for array-like
        # arguments and would not yield a single truth value.
        if major is None:
            major = (3,52)
        if minor is None:
            minor = (3,10)

        self.major = major
        self.minor = minor
        self.scaling = scaling

        self.other_dims = other_dims

    def update_diagonal(self):
        """Placeholder; currently does nothing."""
        pass

    def initial(self,dtype="float64"):
        """Create the all-ones initial state and register it as ``"state"``.

        Each call overwrites the previously registered ``"state"`` entry and
        prints the chosen shape. Returns the symbolic tensor.
        """
        shape = [
            self.major[0]*self.scaling+self.major[1],
            self.minor[0]*self.scaling+self.minor[1],
        ]
        shape += [1]*self.other_dims

        initial = tt.ones(shape,dtype)
        self.AddState("state",initial)

        print("Shape of state",shape,self.major,self.minor)

        return initial

    def run_fine_unrolled(self,n=20):
        """Run ``n`` scan steps, each unrolling ``scaling`` fine sub-steps.

        Registers a length-``n`` sequence of the constant 1.01 as parameter
        ``"value"`` and scans over it, starting from ``self.initial()``.

        Per fine sub-step: the last fine row is added to an accumulator, the
        fine rows are shifted down by one, and row 0 is multiplied by the
        step's ``value``. After the sub-steps the coarse rows are shifted down
        by one and the accumulator is written to the first coarse row.

        Returns the scan output, i.e. the state after every step.
        """
        values = tt.cast([1.01]*n,"float64")
        self.AddParameter("value",values)

        # Number of fine-grained rows at the start of axis 0.
        fine_len = self.major[0]*self.scaling

        def fn(*args):
            # scan passes sequences first, then the previous outputs; Decode
            # turns them back into name -> tensor dictionaries.
            s,p,non_seq = self.Decode(args)

            state = s["state"]
            value = p["value"]

            # Accumulator shaped like one row of the state, so that trailing
            # `other_dims` axes are handled as well.
            acc = tt.zeros_like(state[0])
            for _ in range(self.scaling):
                # Collect the row that is about to leave the fine-grained part.
                acc += state[fine_len-1]
                # Shift the fine-grained rows down by one; row 0 keeps its value.
                state = tt.set_subtensor(state[1:fine_len],state[:fine_len-1])
                # Update row 0.
                state = tt.set_subtensor(state[0],state[0]*value)

            # Shift the coarse-grained rows down by one and inject the
            # accumulated fine-grained rows at their start.
            state = tt.set_subtensor(state[fine_len+1:],state[fine_len:-1])
            state = tt.set_subtensor(state[fine_len],acc)

            s["state"] = state

            # scan expects the new outputs as a list.
            return list(s.values())

        states,updates = theano.scan(fn,outputs_info=self.initial(),sequences=self.GetParameterList())

        return states

In [None]:
# Build a small loop and show the shape of its initial state.
tsl = TimeSplitLoop((2,6,),(2,4,))

print(tsl.initial().eval().shape)

# Time graph construction plus evaluation of the state after the last step.
t0 = time.time()
result = tsl.run_fine_unrolled()
states = result[-1].eval()
t1 = time.time()

print(t1-t0)

print("Results")
theano.printing.debugprint(result)
print("States")
# `states` is an evaluated NumPy array; debugprint only accepts Theano graph
# objects and raises TypeError for it, so print the values directly.
print(states)