In [None]:
import numpy as np
import xarray as xr
import pickle
import pandas as pd
import time
import os

import datetime
from sys import getsizeof,path


import matplotlib.pyplot as plt
import matplotlib

import theano
import theano.tensor as tt
from theano.printing import debugprint
import pymc3 as pm
theano.config.optimizer="fast_run"

In [None]:
# Sample dataset: two small labelled arrays that share the "age" and "sex"
# dimensions (with only partly matching labels) and each have one dimension
# of their own ("BL" resp. "week").
a = xr.DataArray(
    np.arange(1,1+6*3*2).reshape(6,3,2),
    dims=("age","BL","sex"),
    coords={"sex":["m","f"],"age":[2,3,4,5,6,7],"BL":["A","B","C"]},
)
b = xr.DataArray(
    np.arange(51,51+3*4*4).reshape(3,4,4),
    dims=("sex","age","week"),
    coords={"sex":["f","u","m"],"age":[1,2,3,6],"week":[11,13,14,18]},
)

# Coordinates of both arrays
print(a.coords)
print(b.coords)

# Dimension order of both arrays
print(a.dims,b.dims)

# Full contents
print(a)
print(b)



In [None]:
def IndexProperties(x):
    """ Describe the spacing of a coordinate index.

        x : sequence/array of coordinate labels

        Returns a dict with the keys
          "countable" : True if the differences between consecutive labels
                        could be computed and there is at least one step
                        (i.e. numeric-like labels, two or more of them)
          "continous" : True if all steps are identical (key spelling kept,
                        callers look it up under this name)
          "stepsize"  : the smallest step, or None if not countable

        All three keys are always present so callers can read them without
        checking "countable" first.
    """
    try:
        steps = np.diff(x)
        distinct_steps = len(set(steps))
        smallest = min(steps)
    except (TypeError, ValueError):
        # TypeError: labels cannot be subtracted (e.g. strings) or the steps
        # are not hashable; ValueError: fewer than two labels, so there is
        # no step to take the minimum of.
        return {"countable":False,"continous":False,"stepsize":None}

    return {"countable":True,"continous":distinct_steps==1,"stepsize":smallest}

def IndexMap(x,common):
    """ Membership mask of x with respect to common.

        Returns a numpy array with one entry per element of x that is True
        where that element is contained in `common`.
    """
    flags = []
    for label in x:
        flags.append(label in common)
    return np.array(flags)

# Smoke test of the two helpers.
# The index has a gap (4 -> 8): the steps are not all equal, so "continous"
# is reported as False while the smallest step is 1.
x = np.array([1,2,3,4,8,9])
print(IndexProperties(x))

# 0 is not contained in x, while 1, 2 and 3 are.
print(IndexMap(np.array([0,1,2,3]),x))

In [None]:
def IndexSumToMatch(X,iX,iC,dims,axis,sum_dir="skip"):
    """ Reduce X along one axis to the entries labelled by the common index.

        X,iX    tensor and its index (one label per entry) along axis.
                X is expected to be a theano tensor: the summing paths use
                tt.zeros_like / tt.inc_subtensor.
        iC      common index, the labels to keep; must be sorted and
                contained in iX
        dims    number of dimensions of X
        axis    axis of X that iX refers to
        sum_dir what happens to entries whose label is not in iC:
                "skip"  they are dropped
                "right" they are summed into the next kept entry; entries
                        after the last kept one are dropped
                "left"  they are summed into the previous kept entry;
                        entries before the first kept one are dropped
                any other value behaves like "skip"

        Returns a tensor with len(iC) entries along axis.
    """
    def at(entry):
        # Index key that selects `entry` on `axis` and everything elsewhere.
        # Always a tuple: a list key is only treated as a multi-axis key
        # when it contains a slice, which is not the case for dims == 1.
        key = [slice(None)] * dims
        key[axis] = entry
        return tuple(key)

    # Skipped output: keep only the entries whose label is in iC
    O = X[at(IndexMap(iX,iC))]

    # Nothing to accumulate; also avoids reading iC[0] of an empty index.
    if sum_dir not in ("left","right") or len(iC) == 0:
        return O

    # Accumulator zero, shaped like one entry of X along axis
    z = tt.zeros_like(X[at(0)])

    # Loop state: position in iC (nc), number of entries accumulated since
    # the last kept entry (nsum) and their running sum (xsum)
    nc,nsum,xsum = 0,0,z
    iCnext = iC[0]
    for nx,ix in enumerate(iX):
        if ix != iCnext:
            # Label is not kept: accumulate it. A new expression is built
            # instead of `+=` so that `z` itself can never be modified.
            nsum += 1
            xsum = xsum + X[at(nx)]
        else:
            if nsum > 0:
                if sum_dir == "right":
                    # Add what was collected to the kept entry just reached
                    O = tt.inc_subtensor(O[at(nc)],xsum)
                elif sum_dir == "left" and nc != 0:
                    # Add what was collected to the previous kept entry;
                    # before the first kept entry there is none.
                    O = tt.inc_subtensor(O[at(nc-1)],xsum)

            # Advance to the next common label unless this was the last one
            if iCnext < iC[-1]:
                nc += 1
                iCnext = iC[nc]
            # Reset sum
            nsum = 0
            xsum = z

    # Edge case, left add: entries after the last kept label
    if sum_dir == "left" and nsum > 0:
        O = tt.inc_subtensor(O[at(nc)],xsum)

    return O


In [None]:
"""
ModelParams keeps track of coordinate-ranges for Model-internal datasets

"""


class ModelParams(object):
    """ Registry of model parameters, looked up by name.

        coords : coordinate ranges handed in at construction
        params : dict, parameter name -> parameter object
    """
    def __init__(self,coords=None):
        # A fresh dict per instance; a `{}` default in the signature would be
        # one object shared by every instance created without arguments.
        self.coords = {} if coords is None else coords

        self.params = {}

    def AddParam(self,param):
        """ Register param under param.name.

            An already registered name is not overwritten; a message is
            printed instead.
        """
        pname = param.name
        if pname not in self.params:
            self.params[pname] = param
        else:
            print("Param %s already exists" % pname)

    def __getitem__(self,name):
        """ params[name] -> registered parameter, or None if unknown. """
        return self.params.get(name,None)

    def Overlap(self):
        """ Placeholder, not implemented. """
        pass

class ModelParam(object):
    """ Everything is a parameter in a bayesian model

        Couples a tensor (param) with a name and a mapping from dimension
        name to coordinate labels (coords). The key order of coords is used
        as the axis order of param.
    """
    def __init__(self,name,coords,param,is_variable=True):
        """ name        identifier of the parameter
            coords      dict, dimension name -> coordinate labels, in axis order
            param       tensor holding the values
            is_variable stored as given
        """
        self.name = name
        self.coords = coords
        self.param = param
        self.is_variable = is_variable

    def Dims(self):
        """ Dimension names in axis order (a new list on every call). """
        return list(self.coords.keys())

    def DimIndex(self,dim_name):
        """ Axis number of dim_name; raises ValueError for unknown names. """
        return self.Dims().index(dim_name)

    def Overlap(self,other,sum_missing=None):
        """ Returns overlap of both params,
            returns A overlap with B, B with A, Common coords

            Dimensions that only one of the two params has are summed out.
            The remaining axes of `other` are reordered to the axis order of
            self, then both tensors are reduced to their common coordinates
            (see Overlap_Axes).

            sum_missing : dict, dimension name -> "skip" | "left" | "right",
                          how coordinates missing on one side are treated;
                          dimensions not listed use "skip".

            Returns (None, None, None) if no dimension is shared.
        """
        dims,other_dims = self.Dims(),other.Dims()
        shared = set(dims).intersection(other_dims)
        if len(shared) == 0:
            return None,None,None

        overlap,other_overlap = self.param, other.param

        # Sum over the dimensions the other side does not have
        not_in = [d for d in dims if d not in shared]
        other_not_in = [d for d in other_dims if d not in shared]
        if len(not_in) > 0:
            overlap = overlap.sum(axis=[self.DimIndex(d) for d in not_in])
            dims = [d for d in dims if d in shared]
        if len(other_not_in) > 0:
            other_overlap = other_overlap.sum(axis=[other.DimIndex(d) for d in other_not_in])
            other_dims = [d for d in other_dims if d in shared]

        # Transpose matching dimensions: bring other into the axis order of self
        other_transpose = [other_dims.index(d) for d in dims]
        other_overlap = other_overlap.dimshuffle(other_transpose)

        # Reduce coordinates for both theano objects, both in the order of dims
        overlap_index = {k:self.coords[k] for k in dims}
        other_index = {k:other.coords[k] for k in dims}

        # Do the nasty stuff
        A,B,common_index = self.Overlap_Axes(overlap,other_overlap,overlap_index,other_index,sum_missing)
        return A,B,common_index

    @staticmethod
    def _apply_indexers(tensor,indexers):
        """ Apply one indexer (slice or boolean mask) per axis, one axis at a time.

            Applying them in a single subscript would combine several boolean
            masks element-wise (advanced indexing) instead of selecting along
            each axis independently.
        """
        ndim = len(indexers)
        for axis,indexer in enumerate(indexers):
            if isinstance(indexer,slice) and indexer == slice(None):
                continue # keeps the whole axis, nothing to do
            key = [slice(None)] * ndim
            key[axis] = indexer
            tensor = tensor[tuple(key)]
        return tensor

    @staticmethod
    def _span(index,common):
        """ Slice of `index` running from the first to the last common label. """
        labels = list(index) # list.index also serves array-like coordinates
        return slice(labels.index(common[0]),labels.index(common[-1])+1)

    def Overlap_Axes(self,A,B,A_index,B_index,sum_missing=None):
        """ Returns theano objects A,B

            For each axis, compare indizes and slice inputs A,B to match each other

            A_index, B_index : dict, dimension name -> labels; same keys in
                               the same order, matching the axes of A and B
            sum_missing      : dict, dimension name -> "skip"|"left"|"right"

            Returns A, B reduced to the common labels and the dict
            dimension name -> common labels (sorted numpy array).
        """
        if sum_missing is None:
            sum_missing = {}

        A_indexer,B_indexer,common_index = [],[],{}
        for axis,dim in enumerate(A_index.keys()):
            iA,iB = A_index[dim], B_index[dim]

            iC = np.array(sorted( set(iA).intersection(set(iB)) ))
            common_index[dim] = iC

            iPA,iPB = IndexProperties(iA),IndexProperties(iB)
            A_sum = sum_missing.get(dim,"skip")
            B_sum = sum_missing.get(dim,"skip")

            if not iPA["countable"] and not iPB["countable"]:
                # Simple case for non-countable indices, only skipping with a boolean mask
                A_indexer.append(IndexMap(iA,iC))
                B_indexer.append(IndexMap(iB,iC))

            else: # more complex case, might involve skipping/summing of elements
                iPC = IndexProperties(iC)

                # .get: a common index with fewer than two labels is reported
                # as not countable and may carry no "continous" entry.
                if iPC.get("continous",False) == True:
                    # Evenly spaced common index: an input with the same
                    # spacing only needs a contiguous slice, a finer one has
                    # to be skipped/summed down to the common labels.
                    if iPC["stepsize"] == iPA["stepsize"]:
                        A_indexer.append( self._span(iA,iC) )
                    else:
                        A = IndexSumToMatch(A,iA,iC,len(A_index.keys()),axis,A_sum)
                        A_indexer.append(slice(None))

                    if iPC["stepsize"] == iPB["stepsize"]:
                        B_indexer.append( self._span(iB,iC) )
                    else:
                        B = IndexSumToMatch(B,iB,iC,len(B_index.keys()),axis,B_sum)
                        B_indexer.append(slice(None))

                else: # involves skipping/summing in A and/or B

                    print(axis,iPC)
                    print("\nA",iA)
                    A = IndexSumToMatch(A,iA,iC,len(A_index.keys()),axis,A_sum)
                    A_indexer.append(slice(None))

                    print("\nB",iB)
                    B = IndexSumToMatch(B,iB,iC,len(B_index.keys()),axis,B_sum)
                    B_indexer.append(slice(None))

        return self._apply_indexers(A,A_indexer),self._apply_indexers(B,B_indexer),common_index

class ObservedData(ModelParam):
    """ Model parameter backed by a fixed dataset (registered with is_variable=False). """
    def __init__(self,name,data):
        """ data : xarray"""
        # xarray.DataArray.coords is not properly ordered, so the coordinates
        # are collected in the order of data.dims, each one sorted.
        coords = {d:sorted(data.coords[d].values) for d in data.dims}
        # Select the values in the sorted order so tensor and coords agree.
        values = data.sel(coords).values
        super(ObservedData,self).__init__(name,coords,theano.shared(values),is_variable=False)
        
        
    

In [None]:
# Overlap of the two sample datasets.
# NOTE(review): relies on `a` and `b` still being the sample DataArrays; later
# cells rebind both names, so run this cell right after the sample-dataset cell.
o1 = ObservedData("by_sex_BL_age",a)

o2 = ObservedData("by_sex_age_week",b)


# No sum_missing given: coordinates present on one side only are dropped
# ("skip"); dimensions present on one side only are summed out.
ov1,ov2,ocoords = o1.Overlap(o2)
print("coords:",ocoords)



# Evaluate both symbolic results to numpy arrays
od1 = ov1.eval()
od2 = ov2.eval()

print(od1.shape)
print(od2.shape)

print(od1)

print(od2)


# Test Cases

In [None]:
# Seems to be the right thing
def IndexWalker(X,iX,iC,sum_dir="skip"):
    """ 1-D reference version of the skip/sum walk.

        X,iX    array and its index (one label per entry)
        iC      common index, the labels to keep
        sum_dir "skip":  entries with labels outside iC are dropped
                "right": they are added to the next kept entry
                "left":  they are added to the previous kept entry

        Returns the kept entries of X including whatever was summed into them.
    """
    kept = X[IndexMap(iX,iC)]
    if sum_dir == "skip":
        return kept

    # position in iC, number of entries collected since the last kept label,
    # and their running sum
    pos,pending,acc = 0,0,0
    target = iC[0]
    final_label = iC[-1]
    for i,label in enumerate(iX):
        if label == target:
            if pending > 0:
                if sum_dir == "right":
                    kept[pos] += acc
                elif sum_dir == "left" and pos != 0:
                    kept[pos-1] += acc

            # only move on while there is a further common label
            if target < final_label:
                pos += 1
                target = iC[pos]

            # start collecting anew
            pending = 0
            acc = 0
        else:
            pending += 1
            acc += X[i]

    # Entries after the last kept label belong to it when summing leftwards
    if sum_dir == "left" and pending > 0:
        kept[pos] += acc

    return kept

# Values and labels coincide (0..12), which makes the sums easy to check by eye.
x = np.arange(13,dtype="int32")
ix = x

# Common indices to try out
ics = [
    [5,7],
    [2,3,4,11],
    [3,4,5,7,9,11],
]

for ic in ics:
    print("\nlength:",len(x),len(ix),len(ic))
    print(ic)
    for sd in ("skip","left","right"):
        print(sd,IndexWalker(x,ix,ic,sd))


# 1-Test

In [None]:


# Two all-ones datasets: `a` has the dimensions sex, age, BL, week and `b`
# has sex, age, week, with only partly matching labels.
a_coords={"sex":["m","f"],"age":[0,1,2,3,4,5,6,7,8,9,10],"BL":["A","B"],"week":[11,14,18,20]}
b_coords={"sex":["f","u","m"],"age":[0,5,10],"week":[11,13,14,18,19]}

# Every cell is 1, so each entry of the overlap equals the number of input
# cells that were summed into it.
a = np.ones([len(x) for x in a_coords.values()],dtype="int64")
b = np.ones([len(x) for x in b_coords.values()],dtype="int64")

# The dict key order defines the axis order of both arrays.
a = xr.DataArray(a,dims=a_coords.keys(),coords=a_coords)
b = xr.DataArray(b,dims=b_coords.keys(),coords=b_coords)

print(a.coords)
print(b.coords)
print(a.dims,b.dims)


o1 = ObservedData("by_sex_BL_age",a)
o2 = ObservedData("by_sex_age_week",b)

# Ages without a counterpart are summed to the left, weeks to the right;
# "sex" is not listed and therefore only skipped.
ov1,ov2,ocoords = o1.Overlap(o2,{"age":"left","week":"right"})
print("overlap_coords:",ocoords)

# Evaluate both symbolic results to numpy arrays
od1 = ov1.eval()
od2 = ov2.eval()

print(od1.shape)
print(od2.shape)

print(od1)

print(od2)


# Inspect the theano graph behind the first result
debugprint(ov1)

## Unknown Test case

In [None]:


# Synthetic 4-D array (sex, age, BL, week) in which every value encodes its
# own coordinates:  ((sex_code*1000 + age)*1000 + BL)*10000 + week
# with sex_code 1 for "m", 2 for "f" and 0 for anything else.
sr,ar,br,wr = ["m","f"],range(0,11),range(1,17),range(10,30)
coords = {"sex":sr,"age":ar,"BL":br,"week":wr}

# One open-mesh vector per axis; broadcasting them builds the whole array at
# once instead of assigning the 2*11*16*20 cells one by one in Python loops.
sex_part,age_part,bl_part,week_part = np.ix_(
    np.array([{"m":1,"f":2}.get(s,0) for s in sr],dtype="int64"),
    np.array(ar,dtype="int64"),
    np.array(br,dtype="int64"),
    np.array(wr,dtype="int64"),
)
z = ((sex_part*1000 + age_part)*1000 + bl_part)*10000 + week_part


print(z.shape)
                
# Attach the coordinates to the synthetic array. This rebinds `a`, which
# earlier cells use for other datasets.
a = xr.DataArray(z,dims=coords.keys(),coords=coords)
#print(a)     

In [None]:
# Experiment: combine a slice on the first axis with a boolean mask on the
# second axis of a theano tensor.
# Note: this rebinds `a` and `b`, which earlier cells use for datasets.
a = slice(3,9,2)
c = slice(None,3)
print(a,c)

# 20 x 4 x 2 tensor holding 0..159
b = tt.cast(np.arange(160).reshape(20,4,2),"int64")

# Mask with one entry per element of the second axis
d = tt.cast([True,False,True,False],"bool")
# Flatten the selection and show its first three values
print(tt.flatten(b[a,d]).eval()[c])

In [None]:
# slice objects do not support `+` (it raises TypeError), so two index ranges
# cannot be joined that way. np.r_ expands each slice to its indices and
# concatenates them into one integer index array.
combined_index = np.r_[slice(10,20,1),slice(30,35,2)]
combined_index

## Import Test

In [None]:


# Make the project sources importable; the path is relative to the directory
# the kernel was started in.
path.append("../src")
import Bernstein
import Population
import Cases
# From here on `ObservedData` refers to the class from the ModelParams module
# and no longer to the class defined further up in this notebook.
from ModelParams import ObservedData

In [None]:
# NOTE(review): presumably case counts by age group in 10-year and 5-year
# bins, going by the variable names - confirm against the Cases module.
age_cases10y,age_cases5y = Cases.RKI_Altersverteilung()
# Keep only the data of a single publication date
age_cases5y = age_cases5y.sel(publication=datetime.datetime(2021,3,15))

# Display the entries for week 10
age_cases5y.sel(week=10)

In [None]:
# Overlap of the RKI age-group data with an all-ones dataset on a coarser
# grid (10-year age steps, selected weeks).
o1 = ObservedData("Cases_RKI_AGs",age_cases5y)

coords = {"week":[15,20,30,40,45,50,55,60],"age":range(0,91,10),"sex":["m","f"]}
b = np.ones([len(x) for x in coords.values()],dtype="int64")
xb = xr.DataArray(b,dims=coords.keys(),coords=coords)

print(coords)

o2 = ObservedData("Cases10x10",xb)

# The keys must match the dimension names exactly: a misspelled key is never
# looked up and that dimension silently falls back to "skip".
ov1,ov2,common = o1.Overlap(o2,{"age":"left","week":"right"})

print(common)

print(ov1.eval())
print(ov2.eval())
