## Imports

In [None]:
from  functools import partial
import numpy as np
import pandas as pd
from shadow.plot import *
set_things()
import scipy.integrate
from sklearn.metrics import r2_score


# Notebook-wide plotting defaults, applied in one call.
# NOTE(review): `plt` is expected to come from the star import above
# (presumably matplotlib.pyplot) -- confirm.
plt.rcParams.update({
    'figure.max_open_warning': 100,
    'font.family': 'Arial',
    'mathtext.bf': 'Arial',
    'axes.formatter.use_mathtext': True,
})

    
from matplotlib import rc

# Select the sans-serif family with Arial as its only candidate face.
# 'sans-serif' is not a valid Python identifier, so it is passed by unpacking.
rc('font', family='sans-serif', **{'sans-serif': ['Arial']})
# Keep text rendering inside matplotlib (no external TeX).
rc('text', usetex=False)


def set_font(p,plt):
    """Set every text-size entry of ``plt.rcParams`` to ``p``.

    Parameters
    ----------
    p : font size written to each key.
    plt : object exposing a mutable ``rcParams`` mapping.
    """
    size_keys = (
        'axes.titlesize',
        'axes.labelsize',
        'font.size',
        'figure.titlesize',
        'legend.fontsize',
        'xtick.labelsize',
        'ytick.labelsize',
    )
    plt.rcParams.update(dict.fromkeys(size_keys, p))

        
# Apply a uniform size of 26 to all text elements of subsequent figures.
set_font(p=26, plt=plt)

In [None]:
def rename(name):
    l = [name[0]]
    for i in name[1:]:
        if i!=i.lower():
            l += [' ']
        l += [i]
    l = ''.join(l)
    if 'land' not in l:
        r = l.split('and')
        if len(r)>1:
            if r[1].lower()!=r[1]:
                l = ' and'.join(r)
    return l



## Statewise population

In [None]:
# Population per state / union territory, keyed by the CamelCase name with
# no spaces.  `optim` looks a state up by stripping the last 9 characters
# from a data column name, so keys must match the column prefixes exactly.
# Entries that are commented out are simply absent from the mapping.
# NOTE(review): figures look like census head counts -- source and year are
# not stated here; confirm before reuse.
statewise_pop = {'UttarPradesh': 199812341,
 'Maharashtra': 112374333,
 'Bihar': 104099452,
 'WestBengal': 91276115,
 'MadhyaPradesh': 72626809,
 'TamilNadu': 72147030,
 'Rajasthan': 68548437,
 'Karnataka': 61095297,
 'Gujarat': 60439692,
 'AndhraPradesh': 49577103,
 'Odisha': 41974218,
 'Telangana': 35003674,
 'Kerala': 33406061,
 'Jharkhand': 32988134,
 'Assam': 31205576,
 'Punjab': 27743338,
 'Chhattisgarh': 25545198,
 'Haryana': 25351462,
 'Uttarakhand': 10086292,
 'HimachalPradesh': 6864602,
 #'Tripura': 3673917,
 #'Meghalaya': 2966889,
 'Manipur': 2570390,
 #'Nagaland': 1978502,
 'Goa': 1458545,
 'ArunachalPradesh': 1383727,
 'Mizoram': 1097206,
 #'Sikkim': 610577,
 'Delhi': 16787941,
 'JammuandKashmir': 12267032,
 'Puducherry': 1247953,
 'Chandigarh': 1055450,
 #'DadraandNagarHaveliandDamanandDiu': 585764,
 'AndamanandNicobarIslands': 380581,
 'Ladakh': 274000,
 #'Lakshadweep': 64473,
}

## Data for India

In [None]:
# Load the full data table and discard its final row.
# NOTE(review): the reason for dropping the last row is not visible here
# (possibly an incomplete or summary row) -- confirm.
big_df = pd.read_csv('./BIG_DF.csv').iloc[:-1]

# Module-level placeholders; get_data() returns fresh values and does not
# assign to these names.
ydata0 = None
DATA = None

def get_data(name,p):
    """Extract the ``Date`` and ``name`` columns of ``big_df`` from row ``p`` on.

    Parameters
    ----------
    name : column label in ``big_df``.
    p : starting row label for the ``.loc`` slice (inclusive).

    Returns
    -------
    (ydata0, DATA)
        ``DATA`` is the two-column frame from ``p`` onward; ``ydata0`` holds
        the values of its second column with the final row left out.
    """
    window = big_df[['Date', name]].loc[p:]
    # Column 1 is `name`; the last row is excluded from the fitted series
    # but remains in the returned frame.
    series = window.values[:-1, 1]
    return series, window

## Model class

In [None]:
class Multi_SEIR():
    """Compartmental ODE model on an ``n x n`` grid of sub-populations.

    The state vector is the concatenation of five flattened
    ``(n_population, n_population)`` grids in the order X, W, Y, Z, D.
    The time derivative of each grid is the sum of the transition callables
    registered for it with :meth:`add` under the keys
    ``'X_d', 'W_d', 'Y_d', 'Z_d', 'D_d'``.

    Attributes
    ----------
    n_population : grid side length.
    y0 : 1-D initial state vector of length ``5 * n_population**2``.
    y : ``None`` until :meth:`solve` runs, then the ``odeint`` result
        (one row per requested time point).
    diff : dict mapping each derivative key to its list of callables.
    """

    # Derivative keys in state-vector order.
    _COMPARTMENTS = ('X_d', 'W_d', 'Y_d', 'Z_d', 'D_d')

    def __init__(self,sub_population_num,l1=None):
        """Create the model.

        Parameters
        ----------
        sub_population_num : grid side length.
        l1 : optional list or tuple ``[X0, W0, Y0, Z0, D0]`` of initial
            grids.  Any other value (including ``None``) selects the
            defaults: X0 = 1000, Y0 = 1 and W0 = Z0 = D0 = 0 in every cell.
        """
        self.n_population = sub_population_num
        self.y = None
        # isinstance() so that a tuple is honoured instead of being silently
        # replaced by the defaults.
        if isinstance(l1, (list, tuple)):
            X0,W0,Y0,Z0,D0 = l1
        else:
            shape = (sub_population_num,sub_population_num)
            X0 = np.ones(shape)*1000
            W0 = np.zeros(shape)
            Y0 = np.ones(shape)*1
            Z0 = np.zeros(shape)
            D0 = np.zeros(shape)
        # np.ravel also accepts nested lists, not only ndarrays.
        self.y0 = np.hstack([np.ravel(grid) for grid in (X0,W0,Y0,Z0,D0)])
        self.diff = {key: [] for key in self._COMPARTMENTS}
        # Separator placed before the header lines of __repr__; it is never
        # mutated, so repeated repr() calls give the same text.
        self._spacer = '\n'

    def sumit(self,t,y,fs,*args,**kwargs):
        """Sum the contributions of the callables ``fs`` at time ``t``.

        Each callable is invoked as ``fn(t, y, self.y, **kwargs)`` and must
        return something that can be added to a vector of length
        ``n_population**2``.  ``self.y`` is ``None`` until a solve finished.
        """
        s = np.zeros(self.n_population*self.n_population)
        for i in fs:
            s += i(t,y,self.y,**kwargs)
        return s

    def solve(self,t,**kwargs):
        """Integrate from ``y0`` over the time points ``t``; store in ``self.y``.

        Extra keyword arguments are forwarded to ``scipy.integrate.odeint``.
        """
        def system(y,t):
            # Derivative blocks are stacked in the same order as the state.
            ret = np.hstack([self.sumit(t,y,self.diff[k]) for k in self.diff])
            return ret.ravel()
        self.y = scipy.integrate.odeint(system,self.y0,t,args=(),**kwargs)

    def __repr__(self):
        """Describe sizes, initial grids, transitions and any final state."""
        dim = self.n_population
        n = dim*dim
        names = list(self.diff.keys())

        def grid_repr(vector,i):
            # i-th compartment of a flat state vector, shown as a grid.
            return repr(vector[i*n:i*n+n].reshape(dim,dim))

        parts = [
            self._spacer + "Multi-SEIR model with following parameters:",
            self._spacer + "Total number of sub population: {}".format(dim),
            "\n" + "Initial conditions: ",
        ]
        for i in range(5):
            parts.append("\n\t{}: {}".format(names[i],grid_repr(self.y0,i)))

        parts.append("\n\n" + "Transitions:")
        for k,v in self.diff.items():
            parts.append("\n\t\n{}: ".format(k))
            parts.append("\n\t\t{}".format(v))

        # State at the last solved time point, when a solution exists.
        if self.y is not None:
            for i in range(5):
                parts.append("\n\t\t{}: {}".format(names[i],grid_repr(self.y[-1],i)))

        return ''.join(parts)

    def add(self,name,fs):
        """Register transition callable(s) ``fs`` under derivative key ``name``.

        ``fs`` may be one callable or a list/tuple of callables.  Each is
        wrapped with ``partial(..., dim=self.n_population)``.  An unknown
        ``name`` raises ``KeyError``.
        """
        if not isinstance(fs, (list, tuple)):
            fs = [fs]
        for f in fs:
            self.diff[name].append(partial(f,dim=self.n_population))
            

## Transition functions

In [None]:
def f(t,y,y_,B=0,sign=1,dim=1):
    """Transmission term ``sign * B * X * (row-sum of Y / row-sum of N)``.

    ``y`` is a flat state vector whose first ``5*dim*dim`` entries are the
    X, W, Y, Z, D grids (each ``dim x dim``); ``N = X + Y + W + Z``.
    The ratio is computed per grid row and broadcast across that row.
    ``t`` and ``y_`` are accepted for interface uniformity and not used.

    Returns the flattened ``dim*dim`` result.
    """
    cells = dim*dim
    # One reshape yields all five grids as views along the first axis.
    x_grid, w_grid, y_grid, z_grid, _d_grid = y[:5*cells].reshape(5, dim, dim)
    total = x_grid + y_grid + w_grid + z_grid
    row_fraction = y_grid.sum(axis=1) / total.sum(axis=1)
    flow = sign*B*x_grid*row_fraction.reshape(-1, 1)
    return flow.ravel()

def g(t,y,y_,sig=0,sign=1,dim=1):
    """Linear term ``sign * sig * W`` on the W grid of the state vector.

    ``y`` is a flat state vector laid out as the X, W, Y, Z, D grids
    (each ``dim x dim``).  ``t`` and ``y_`` are not used.

    Returns the flattened ``dim*dim`` result.
    """
    cells = dim*dim
    grids = y[:5*cells].reshape(5, dim, dim)
    w_grid = grids[1]
    flow = sign*sig*w_grid
    return flow.ravel()

def h(t,y,y_,gam=0,rho=0,sign=1,dim=1):
    """Linear term ``sign * gam * Y * (1 - rho)`` on the Y grid.

    ``y`` is a flat state vector laid out as the X, W, Y, Z, D grids
    (each ``dim x dim``).  ``t`` and ``y_`` are not used.

    Returns the flattened ``dim*dim`` result.
    """
    cells = dim*dim
    grids = y[:5*cells].reshape(5, dim, dim)
    y_grid = grids[2]
    flow = sign*gam*y_grid*(1-rho)
    return flow.ravel()

def q(t,y,y_,rho=0.01,sign=1,dim=1):
    """Linear term ``sign * rho * Y`` on the Y grid.

    ``y`` is a flat state vector laid out as the X, W, Y, Z, D grids
    (each ``dim x dim``).  ``t`` and ``y_`` are not used.

    Returns the flattened ``dim*dim`` result.
    """
    cells = dim*dim
    grids = y[:5*cells].reshape(5, dim, dim)
    y_grid = grids[2]
    flow = sign*rho*y_grid
    return flow.ravel()


## Wrapper functions over model for curve fit

In [None]:
def func(t,B,POP=1.3e9,I0=258):
    """Return the Y trajectory of the single-population model.

    This is the model function handed to the curve fit: ``t`` is the
    sequence of time points and ``B`` the transmission parameter being
    fitted.  ``POP`` and ``I0`` set the initial X and Y values.

    The model construction used to be duplicated line for line from
    ``func2``; it now delegates to ``func2`` so the two cannot drift apart.

    Returns
    -------
    1-D array with one value per entry of ``t``.
    """
    # With one sub-population the state columns are X, W, Y, Z, D, so
    # column 2 is Y.
    return func2(t,B,POP=POP,I0=I0)[:,2]

def func2(t,B,POP=1.3e9,I0=258):
    """Solve the single-population model and return the full state history.

    Parameters
    ----------
    t : sequence of time points passed to the solver.
    B : transmission parameter used by the ``f`` transition.
    POP : initial X value.
    I0 : initial Y value; W starts at ``2*I0`` and Z, D at zero.

    Returns
    -------
    2-D array, one row per time point and one column per compartment
    (X, W, Y, Z, D).
    """
    sig, gam, rho = 1/7, 1/3, 0

    y_start = np.array([[I0]])
    initial_grids = [np.array([[POP]]), 2*y_start, y_start, 0*y_start, 0*y_start]
    model1 = Multi_SEIR(1, initial_grids)

    # Derivative key -> terms, in the order they are summed.
    transitions = (
        ('X_d', [partial(f,B=B,sign=-1)]),
        ('W_d', [partial(f,B=B,sign=1),
                 partial(g,sig=sig,sign=-1)]),
        ('Y_d', [partial(g,sig=sig,sign=1),
                 partial(h,gam=gam,rho=rho,sign=-1),
                 partial(q,rho=rho,sign=-1)]),
        ('Z_d', [partial(h,gam=gam,rho=rho,sign=1)]),
        ('D_d', [partial(q,rho=rho,sign=1)]),
    )
    for key, terms in transitions:
        model1.add(key, terms)

    model1.solve(t)
    return model1.y

## Curve fitting using scipy.optimize.curve_fit

In [None]:
from scipy.optimize import curve_fit
R0 = None


def get_me_R0(func,ydata0,infectious_period=3):
    """Fit the single parameter of ``func`` to ``ydata0`` and scale it to R0.

    Parameters
    ----------
    func : callable ``func(t, B)`` returning one model value per time point.
        It receives ``range(len(ydata0))`` as ``t``.
    ydata0 : observed series, one value per time step.
    infectious_period : factor that converts the fitted parameter into R0.
        The default 3 is the value that was previously hard-coded; it
        corresponds to ``1/gam`` with ``gam = 1/3`` as set in ``func``.

    Returns
    -------
    (R0, pcov)
        ``R0`` is the fitted parameter array times ``infectious_period``;
        ``pcov`` is the covariance matrix reported by ``curve_fit`` for the
        fitted parameter (not for R0).
    """
    # Start at 1 and keep the parameter within [0.1, 20].
    popt, pcov = curve_fit(func,range(len(ydata0)),ydata0,p0=[1],bounds=[0.1,20])
    R0 = popt*infectious_period
    return R0, pcov


## Plot

In [None]:
def pretty_date(date):
    """Format a ``'<day>/<month>/...'`` string as ``'<day> <Month name>'``.

    The first ``/``-separated field is echoed unchanged; the second is read
    as a 1-based month number.

    Raises
    ------
    IndexError
        If ``date`` contains no ``/``.
    ValueError
        If the month field is not an integer in 1..12.
    """
    # All twelve months: the previous four-entry table failed from May
    # onward and misspelled February.
    months = ('January','February','March','April','May','June','July',
              'August','September','October','November','December')
    fields = date.split('/')
    month = int(fields[1])
    # Reject 0 and negatives explicitly; they would otherwise index from
    # the end of the table.
    if not 1 <= month <= 12:
        raise ValueError("month out of range in date {!r}".format(date))
    return "{} {}".format(fields[0],months[month-1])
        
def plot_this(name,DATA,ydata0,y,R0,pcov,R2):
    """Plot the model prediction against the observed series on a log axis.

    Parameters
    ----------
    name : column name; the part before the first ``_`` is passed through
        ``rename`` for the title.
    DATA : frame with a ``Date`` column; its first and last entries give the
        period shown in the title.
    ydata0 : observed values, drawn as diamond markers.
    y : predicted values, drawn as a line.
    R0 : array whose first element is printed as R0 in the annotation.
    pcov : accepted for call compatibility; not used.
    R2 : coefficient of determination printed in the annotation.
    """
    Date = DATA['Date'].values

    # The label is bound before plotting; it used to be assigned after the
    # plot call, which raised UnboundLocalError on every invocation.
    label = "Prediction"
    plt.plot(y,lw=2,label=label)

    # Fit summary, positioned in axes-relative coordinates.
    txt = "R$^2$ = {:.2f}\nR$_0$ = {:.2f}".format(R2,*R0.ravel())
    axes = plt.gca()
    axes.text(0.7,0.1,txt,transform=axes.transAxes,)

    label = "Observed (I$_0$ = {})".format(ydata0[0])
    plt.plot(ydata0,'D',label=label)

    plt.yscale('log')
    plt.legend(loc=2)
    plt.xlabel('Time (days)')
    plt.ylabel('Population')

    period = "({} to {}, 2020)".format(pretty_date(Date[0]),pretty_date(Date[-1]))
    plt.title(rename(name.split('_')[0])+"\n"+"{}".format(period))



In [None]:
# Every column of big_df whose name contains "_I_cumsum"; these are the
# series fitted below.
states = [column for column in big_df.columns if "_I_cumsum" in column]

def optim(state,start=40,stop=60,r2_threshold=0.9):
    """Fit one state's series over several start rows and keep the best fit.

    For each start row ``p`` in ``range(start, stop)`` the series from ``p``
    onward is fitted and scored with R^2.  The search stops at the first
    window whose R^2 exceeds ``r2_threshold``; otherwise the window with the
    highest R^2 is used.  Plotting is left to the caller (see ``plot_this``).

    Parameters
    ----------
    state : data column name ending in ``"_I_cumsum"``; the prefix is the
        key into ``statewise_pop``.
    start, stop : bounds of the start-row search (defaults 40 and 60, the
        values that were previously hard-coded).
    r2_threshold : R^2 above which the search stops early (default 0.9).

    Returns
    -------
    dict
        ``{state: [R0, R2, pcov, last_value]}`` where ``pcov`` is the
        variance reported for the fitted parameter and ``last_value`` is the
        final observation of the chosen window.

    Raises
    ------
    KeyError
        If the state prefix is not in ``statewise_pop``.
    ValueError
        If ``range(start, stop)`` is empty.
    """
    suffix = "_I_cumsum"
    population = statewise_pop[state[:-len(suffix)]]

    def fit_window(p):
        # Fit the series starting at row p; return (R2, R0, pcov, last value).
        ydata0, _ = get_data(state,p)
        f1 = partial(func,POP=population,I0=ydata0[0])
        R0, pcov = get_me_R0(f1,ydata0)
        # get_me_R0 returns 3x the fitted parameter by default; undo that to
        # evaluate the model at the fitted value.
        y = f1(range(len(ydata0)),R0/3)
        R2 = r2_score(ydata0.ravel(),y.ravel())
        return R2, R0, pcov, ydata0[-1]

    # `best` starts as None rather than a sentinel score so that a series
    # whose every fit is very poor still yields a result instead of crashing.
    best = None
    for p in range(start,stop):
        fit = fit_window(p)
        if fit[0] > r2_threshold:
            best = fit
            break
        if best is None or best[0] < fit[0]:
            best = fit

    if best is None:
        raise ValueError("empty search range: start={}, stop={}".format(start,stop))

    R2, R0, pcov, num = best
    return {state: [R0[0],R2,pcov[0,0],num]}
    
# Fit every selected column and collect the per-state summaries, printing
# each column name once its fit is done.
set_font(26,plt)
R0_values = dict()
for state in states:
    R0_values.update(optim(state))
    print(state)


In [None]:
import pickle
# Persist the per-state fit summaries next to the notebook.
with open('./R0_values.pkl', mode='wb') as pickle_file:
    pickle.dump(R0_values, pickle_file)


In [None]:
# One row per fitted column, best R2 first; the trailing bare expression
# displays the table in the notebook.
R0_df = (
    pd.DataFrame(
        list(R0_values.values()),
        index=R0_values.keys(),
        columns=['R0','R2','pcov','num'],
    )
    .sort_values(by=['R2'], ascending=False)
)
R0_df