# Load and Pre-process Data

## Prepare Environment

### Libraries

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import unicodedata
import datetime
from scipy.integrate import solve_ivp
from scipy.stats import poisson
from scipy.optimize import minimize
import json
import plotly.graph_objects as go
import plotly.express as px

### Constants

In [None]:
# Adjustables
# -- keep updating and attempting
# Own data: use your own dataset here (it must be paired with the
# population and migration data below).
NCOV_DB_FILENAME = "data/ncov_20200205.csv"
CHECKPOINT_FILENAME = "checkpoints/seir_fit_result_20200205"
# own population data
POPULATION_DB_FILENAME = "data/ChinaPopulation.csv"
# NOTE: "Migaration" presumably matches the file names on disk; rename the
# files as well before correcting the spelling here.
MIGRATION_INDEX_FILENAME = "data/BaiduMigarationIndex.csv"
MIGRATION_DEST_FILENAME = "data/BaiduMigarationTo.csv"

OPTIONS = dict(
    # cities with population above this are included in the analysis
    POPULATION_THRESHOLD=3e6,
    # exposed -> infected rate; fixed, not fitted
    ALPHA=1/7,
    # optimiser starting point: beta, gamma and infections on day 0
    INIT_BETA=0.4,
    INIT_GAMMA=0.15,
    INIT_INIT_INFECT=50,
    # number of most recent days used for the "recent inlier" count
    RECENT_INLIER_DAYS=5,
    # days simulated for each fitted model
    SIMULATION_DAYS=90,
    # observation days sampled per RANSAC round
    RANSAC_SAMPLE_NUM=4,
    # number of RANSAC rounds
    ROUNDS=1000,
    # soft bounds used by the regulariser: a value below small_* or above
    # large_* is penalised (in log space) with weight *_k
    REGULARISE=dict(
        small_beta=1/1000,
        small_beta_k=1,
        large_beta=10,
        large_beta_k=1,
        small_gamma=1/30,
        small_gamma_k=1,
        large_gamma=1/2,
        large_gamma_k=1,
        small_I0=2,
        small_I0_k=1,
        large_I0=1000,
        large_I0_k=1)
    )

# Factual constants
# -- don't change often unless you know what you are doing
# Exclude global cases. See discussion in "Data Source".
# Each place name is listed once; the list is only used for membership tests.
GLOBAL_CITIES = [
    '美国', '法国', '澳大利亚', '意大利', '德国', '香港',
    '韩国', '英国', '日本', '泰国', '台湾', '新加坡',
    '俄罗斯', '芬兰', '加拿大', '马来西亚', '菲律宾', '印度',
    '越南',  '阿联酋', '柬埔寨', '斯里兰卡', '尼泊尔'
]
DAY_LIWENLIANG = np.datetime64("2020-01-01") # t=0, name is a tribute to our fallen hero



### Helper Functions


In [None]:
# Functions to parse infection database
def match_place_name(name1, name2):
    """Return True when either place name is a prefix of the other.

    This lets a short name such as "武汉" match a longer one such as
    "武汉市", in either argument order.
    """
    return name1.startswith(name2) or name2.startswith(name1)
    
def get_city_id(db, q):
    """Find the index of the row of ``db`` whose ``'City'`` matches ``q``.

    Matching is done with ``match_place_name`` (prefix match in either
    direction), so "武汉" finds "武汉市".

    :param db: DataFrame with a ``'City'`` column.
    :param q: place name to look up; anything that is not a str (e.g. a
        NaN from pandas) yields None.
    :return: the index label of the single matching row, or None when
        nothing matches or the match is ambiguous.
    """
    if not isinstance(q, str):
        return None
    hits = db['City'].map(lambda city: match_place_name(city, q))
    matches = db.index[hits].tolist()
    if len(matches) != 1:
        return None
    return matches[0]

def parse_chinese_datetime(s, year=2020):
    """Parse a Chinese month/day string such as "2月5日" into a datetime.

    Full-width digits (e.g. "２月５日") are folded to ASCII by NFKD
    normalisation; any other non-ASCII characters are dropped.

    :param s: string containing "月" (month marker) followed by "日"
        (day marker).
    :param year: year to attach, since the strings carry none
        (default 2020).
    :return: ``datetime.datetime(year, month, day)``.
    :raises ValueError: if ``s`` is not a str, lacks the markers in the
        expected order, or the extracted numbers do not form a valid date.
    """
    if not isinstance(s, str):
        raise ValueError(f"Unknown Date {s}")

    month_pos = s.find("月")
    day_pos = s.find("日")
    # find() returns -1 for a missing marker, and the slices below would
    # then silently read the wrong characters (e.g. "5日" became 5 May).
    if month_pos < 0 or day_pos <= month_pos:
        raise ValueError(f"Unknown Date {s}")

    def _ascii_digits(text):
        # NFKD maps full-width digits to ASCII; other non-ASCII is dropped.
        return unicodedata.normalize('NFKD', text)\
            .encode('ascii', 'ignore')\
            .decode()

    month_s = _ascii_digits(s[:month_pos])
    date_s = _ascii_digits(s[month_pos + 1:day_pos])
    return datetime.datetime(year, int(month_s), int(date_s))

# To save results
class NpEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy scalars and arrays.

    Used when saving results, which may hold NumPy integers, floats,
    booleans and arrays that the stock encoder rejects.
    """

    def default(self, obj):
        """Convert a NumPy object into its plain-Python equivalent.

        :raises TypeError: for any other unsupported type (from the base
            class).
        """
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            # np.bool_ is not a subclass of bool, so json rejects it.
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)

### Load data

#### Population

In [None]:
# g_: global variables
g_popu_df = pd.read_csv(POPULATION_DB_FILENAME, encoding="utf-8")
# All cities with a population record. A subset of them is selected later
# to fit the model: (1) the population must be large enough, and (2) the
# city's traffic with Wuhan must rank in the top 100 over the recorded
# period.
g_all_city_num = len(g_popu_df)

# Row of Wuhan in the population table; the epidemic is seeded there.
g_start_loc_id_in_all = get_city_id(g_popu_df, "武汉")

# Population per city: the 2018 figure where present, otherwise the 2010
# one. Values are scaled by 1e4 (the table is presumably in units of
# 10,000 people).
_popu_2018 = g_popu_df['Population2018'].to_numpy(dtype=float)
_popu_2010 = g_popu_df['Population2010'].to_numpy(dtype=float)
g_all_popu_vec = np.where(np.isnan(_popu_2018), _popu_2010, _popu_2018) * 1e4

#### Infection Number

In [None]:
# Load infection data.
# The CSV is UTF-16 with Chinese headers; they are renamed to English:
# report date, newly confirmed cases, province, city.
df0 = pd.read_csv(NCOV_DB_FILENAME, encoding="utf-16")
df0 = df0.rename(columns={"报道时间": "Date", 
                         "新增确诊": "Infected.Inc",
                         "省份": "Place.State",
                         "城市": "Place.City",})

df = df0[['Infected.Inc', 'Place.State', 'Place.City']].copy()
# Dates look like "2月5日"; parse_chinese_datetime turns them into datetimes.
df['Date'] = pd.Series(
    df0['Date'].map(parse_chinese_datetime),
    index=df.index)
# Whole days since DAY_LIWENLIANG (t=0), the model's time axis.
df['Day.Delta'] = np.round(
    (df['Date'] - DAY_LIWENLIANG) / np.timedelta64(1, 'D')).astype(int)
# -1 means "not matched to a row of g_popu_df (yet)".
df['Place.Id'] = pd.Series([-1] * len(df), index=df.index)
# Use the short name of Enshi prefecture (presumably the form used in the
# population table -- TODO confirm).
df.loc[df['Place.City']=="恩施土家族苗族自治州", 'Place.City'] = "恩施州"

# Match every record to a row of g_popu_df:
# 1. Some records are at district level, which is not in the city list. 
# we aggragate those records to cities. E.g. "通州" -> "北京"
# 2. Ignore oversea cases
# Unmatched records keep Place.Id == -1 and are dropped below. `omitted`
# counts the cases dropped that way, excluding places in GLOBAL_CITIES.
omitted = 0
omitted_ind = []  # NOTE(review): never used below
verbose = False  # set True to print each unknown (city, province) pair
for index, row in df.iterrows():
    p = row['Place.City']
    city_id = get_city_id(g_popu_df, p)
    if not city_id is None:
        df.loc[index, 'Place.Id'] = city_id
    else:
        p1 = row['Place.State']
        city_id1 = get_city_id(g_popu_df, p1)
        if not city_id1 is None:
            # The province field itself names a city in the population
            # table (e.g. "北京"), so the district-level record is
            # aggregated into that city.
            df.loc[index, 'Place.City'] = p1
            df.loc[index, 'Place.Id'] = city_id1
        elif not p1 in GLOBAL_CITIES:
            if verbose:
                print(f"{p, p1} is unknown")
            omitted += row['Infected.Inc']
df = df.drop(df.index[df['Place.Id'] < 0])
print(f"Omitted {omitted} cases of missing location information.",
      "Total:", df['Infected.Inc'].sum())

# Daily new confirmed cases: rows follow g_popu_df, columns are days since
# DAY_LIWENLIANG; records for the same city and day are summed.
g_days = df["Day.Delta"].max() + 1
g_all_daily_increase = np.zeros((g_all_city_num, g_days))
for i, r in df.iterrows():
    ci = r['Place.Id']
    t = r['Day.Delta']
    g_all_daily_increase[ci, t] += r['Infected.Inc']

#### Migration

In [None]:
# Estimate the unit of the Baidu Migration Index: the summed "Out" index
# over rows 9..20 (inclusive; row 0 = 1 Jan 2020) is matched to a reference
# total of 4,096,800 outbound travellers (TODO: cite the source of this
# figure), giving people per index point.
mg_df = pd.read_csv(MIGRATION_INDEX_FILENAME)
total_out_index = mg_df.loc[9:20]["Out"].sum()
total_out_number = 4096800
num_per_index = total_out_number / total_out_index
print("Per Index:", num_per_index)

# Net outbound people per day, used below as Wuhan's outflow.
# Day[0]: 1 Jan 2020
net_out_bound = {
    i: (row['Out'] - row['In']) * num_per_index
    for i, row in mg_df.iterrows()
}

# Get Destination Proportion: match each destination city to a row of
# g_popu_df (-1 = no unique match).
dest_df = pd.read_csv(MIGRATION_DEST_FILENAME)
dest_df['Place.Id'] = pd.Series(
    [-1] * len(dest_df),
    index=dest_df.index)
for i, r in dest_df.iterrows():
    ci = get_city_id(g_popu_df, r['City'])
    if ci is not None:
        dest_df.loc[i, 'Place.Id'] = ci

# 'Percent' is a percentage; store it as a fraction per city.
prop_vec_ = np.zeros(g_all_city_num)
for i, r in dest_df.iterrows():
    ci = r['Place.Id']
    if ci < 0:
        # Unmatched destination. Indexing with -1 would silently credit its
        # share to the LAST city of g_popu_df (later unmatched rows
        # overwriting earlier ones), so skip it.
        continue
    prop_vec_[ci] = r['Percent'] / 100

g_trans_data_max_day = len(mg_df)
zero_trans_mat = np.zeros((g_all_city_num, g_all_city_num))
# g_all_trans_mats[t][i, j]: people travelling from city i to city j on day
# t. Only Wuhan's row is filled: its net outflow split by destination share.
g_all_trans_mats = np.zeros((g_trans_data_max_day, g_all_city_num, g_all_city_num))
for t in range(g_trans_data_max_day):
    g_all_trans_mats[t][g_start_loc_id_in_all] = prop_vec_ * net_out_bound[t]

def get_all_city_trans(t, cutoff_day=22.99):
    """
    Travel matrix over ALL cities of g_popu_df at (fractional) day t.

    Only Wuhan's row is populated, so e.g.
    ``get_all_city_trans(15)[g_start_loc_id_in_all]`` is the day-15 net
    outbound traffic from Wuhan to each city.

    :param t: time in days since 1 Jan 2020; fractional values (as passed
        by the ODE solver) are truncated to the day.
    :param cutoff_day: travel is zero for ``t >= cutoff_day``. The default
        22.99 keeps the original behaviour (presumably modelling the
        23 Jan 2020 Wuhan travel ban -- TODO confirm).
    :return: ``zero_trans_mat`` outside the covered range, otherwise that
        day's slice of ``g_all_trans_mats``. Both are shared arrays; do not
        modify the result in place.

    NOTE: travel between cities other than Wuhan is not modelled yet.
    """
    # Also stop at the end of the migration table, so a table shorter than
    # the cutoff yields zero travel instead of an IndexError.
    if t < 0 or t >= cutoff_day or t >= g_trans_data_max_day:
        return zero_trans_mat
    return g_all_trans_mats[int(t)]
    

#### Select a subset of cities for analysis

In [None]:
# Cities analysed: those whose population exceeds POPULATION_THRESHOLD and
# that are matched destinations in the migration data, plus Wuhan itself.
tmp_ids = dest_df['Place.Id'].tolist()
dest_ids = set(tmp_ids)  # O(1) membership instead of scanning the list
g_subset_place_ids = [
    i for i, p in enumerate(g_all_popu_vec)
    if p > OPTIONS['POPULATION_THRESHOLD'] and i in dest_ids
]
# Add Wuhan only if not already selected; a duplicate would repeat its row
# (and column) in every per-city array derived below.
if g_start_loc_id_in_all not in g_subset_place_ids:
    g_subset_place_ids.append(g_start_loc_id_in_all)
g_subset_place_ids = np.array(sorted(g_subset_place_ids), dtype=int)

g_daily_increase = g_all_daily_increase[g_subset_place_ids]
# Position of Wuhan within the subset.
g_start_loc_id = np.nonzero(g_subset_place_ids == g_start_loc_id_in_all)[0][0]
g_city_num = len(g_subset_place_ids)

g_popu_vec = g_all_popu_vec[g_subset_place_ids]
# Restrict every day's travel matrix to the subset (rows and columns).
g_trans_mats = g_all_trans_mats[:, g_subset_place_ids][:, :, g_subset_place_ids]
def get_city_trans(t, cutoff_day=22.99):
    """
    Travel matrix restricted to the analysed cities at (fractional) day t.

    Indices follow ``g_subset_place_ids``; only Wuhan's row
    (``g_start_loc_id``) is non-zero. This is the ``tr_fn`` passed to the
    model fits.

    :param t: time in days since 1 Jan 2020 (truncated to a whole day).
    :param cutoff_day: travel is zero for ``t >= cutoff_day``; the default
        22.99 keeps the original behaviour (presumably the 23 Jan 2020
        Wuhan travel ban -- TODO confirm).
    :return: ``(g_city_num, g_city_num)`` array.
    """
    # Also return zeros past the end of the migration table rather than
    # raising IndexError when the table is shorter than the cutoff.
    if t < 0 or t >= cutoff_day or t >= g_trans_data_max_day:
        return np.zeros((g_city_num, g_city_num))
    return g_trans_mats[int(t)]

# Define Network SEIR Model

In [None]:
class NetworkSEIR:
    """SEIR epidemic model over a network of places coupled by travel.

    Every place has S/E/I/R compartments. Infectious travellers arriving
    from other places (given by ``tr_fn``) add to the local force of
    infection. Rates are stored as logs so an unconstrained optimiser can
    fit them while they stay positive. Structure per the original note:
    Wu et al.
    """

    def __init__(self, log_alpha, log_beta, log_gamma,
                 log_init_infected,
                 start_loc_id,
                 population, tr_fn):
        """
        :param log_alpha: log of alpha, the exposed -> infected rate.
        :param log_beta: log of beta, the susceptible -> exposed rate.
        :param log_gamma: log of gamma, the infected -> recovered rate.
        :param log_init_infected: log of the number infected at the start
          location on day 0.
        :param start_loc_id: index of the start location (e.g. Wuhan) in
          ``population``.
        :param population: a vector of population in each place.
        :param tr_fn: function with signature
          tr_fn(time) -> #.cities x #.cities matrix, where
          return_mat[i, j] is the number of people travelling from city-i
          to city-j.
        """
        self.log_alpha = log_alpha
        self.log_beta = log_beta
        self.log_gamma = log_gamma
        self.log_init_infected = log_init_infected

        self.population = population
        self.start_loc_id = start_loc_id

        self.tr_fn = tr_fn
        self.place_n = len(population)

    def diff(self, t, y):
        """Right-hand side dy/dt in the form ``solve_ivp`` expects.

        :param t: time in days.
        :param y: flat state ``[S, E, I, R]``, each block ``place_n`` long.
        :return: flat derivative in the same layout.
        """
        alpha = np.exp(self.log_alpha)
        beta = np.exp(self.log_beta)
        gamma = np.exp(self.log_gamma)
        # Query the travel matrix once per evaluation.
        K = self.tr_fn(t)

        S, E, I, _ = y.reshape(4, self.place_n)
        N = self.population

        # (K.T / N)[i, j] = K[j, i] / N[j]: the fraction of place j's
        # population travelling to place i, so np.dot(K.T / N, I)[i] is the
        # number of infectious arrivals at i. Add the local infectious I
        # and divide by the local population for the per-capita hazard.
        d = beta * (np.dot(K.T / N, I) + I) / N
        dS = - d * S
        dE = d * S - alpha * E
        dI = alpha * E - gamma * I
        dR = gamma * I

        return np.concatenate([dS, dE, dI, dR])

    def get_init_state(self):
        """Flat initial state ``[S, E, I, R]`` at t=0.

        Only the start location is seeded: ``I0 = exp(log_init_infected)``
        infected and ``I0 / alpha`` exposed. Everyone else is susceptible,
        and nobody has recovered.
        """
        alpha = np.exp(self.log_alpha)
        v_init_infected = np.zeros(self.place_n)
        v_init_infected[self.start_loc_id] = np.exp(self.log_init_infected)
        v_init_exposed = v_init_infected / alpha
        v_init_suscept = self.population \
            - v_init_exposed - v_init_infected
        v_init_recov = np.zeros(self.place_n)

        return np.concatenate([
            v_init_suscept,
            v_init_exposed,
            v_init_infected,
            v_init_recov
        ])

    def integrate(self, t_eval, method="RK45"):
        """Solve the ODE from t=0 and sample it at ``t_eval``.

        :param t_eval: evaluation times; they must lie in
          ``[0, max(t_eval)]``, i.e. be non-negative.
        :param method: ``solve_ivp`` integration method.
        :return: ``(dict(S=..., E=..., I=..., R=...), sol)``: each entry is
          ``place_n`` rows by one column per returned time point, and
          ``sol`` is the raw ``solve_ivp`` result.
        """
        t_end = t_eval.max()  # the end of period to run ODE

        sol = solve_ivp(
            self.diff, [0, t_end],
            self.get_init_state(), t_eval=t_eval, method=method)
        n = self.place_n
        return dict(S=sol.y[0:n, :],
                    E=sol.y[n:2 * n, :],
                    I=sol.y[2 * n:3 * n, :],
                    R=sol.y[3 * n:4 * n, :]), sol

    def neglog_likelihood(self,
                          infection_inc_data,
                          return_cdf=False):
        """
        Per-entry Poisson negative log-likelihood of daily new cases.

        :param infection_inc_data: [place_num x T] daily new infection
          counts; column k is day k.
        :param return_cdf: if True, also compute the Poisson log-CDF.
        :return: ``(nll, log_cdf)``, each [place_num x (T-1)] for days
          1..T-1:
            nll: negative log-pmf of the observations.
            log_cdf: log-CDF, usable to judge whether an observation lies
              within a confidence interval; None unless return_cdf.
        """
        T = infection_inc_data.shape[1]
        res, _ = self.integrate(np.arange(0, T))

        expected_infected = res["I"]
        # Expected new cases = day-to-day change in I, floored to keep the
        # Poisson mean positive. NOTE(review): I is the currently
        # infectious compartment, so this change turns negative once
        # recoveries dominate.
        expected_increase = np.maximum(
            expected_infected[:, 1:] - expected_infected[:, :-1],
            1e-9)

        # Day 0 has no previous day to difference against, so it is not scored.
        real_increase = infection_inc_data[:, 1:].astype(int)

        nll = - poisson.logpmf(real_increase, expected_increase)
        log_cdf = poisson.logcdf(real_increase, expected_increase) \
            if return_cdf else None
        return nll, log_cdf

# RANSAC Model Fitting

In [None]:
# NOTE(review): overrides OPTIONS['ROUNDS'] (1000 in the Constants cell) so
# the RANSAC loop below runs only 10 rounds -- presumably a quick-test
# setting; remove this line for a full run.
OPTIONS["ROUNDS"] = 10

In [None]:
#=============================================================
# Model fitting
def regulariser(params, bounds=None):
    """
    Soft box penalty on the log-scale optimisation variables.

    For each bound ``b`` with weight ``k`` the penalty grows linearly (in
    log space) once the variable crosses ``log(b)``:
      small_<name>: k * max(log(b) - value, 0)   # below the lower bound
      large_<name>: k * max(value - log(b), 0)   # above the upper bound
    This replaces penalties like
      pen2 = np.maximum(lg - np.log(1/2), 0)     # if gamma > 1/2
      pen3 = np.maximum(np.log(1/1000) - lb, 0)  # if beta < 1/1000
      pen4 = np.maximum(lb - np.log(10), 0)      # if beta > 10

    :param params: optim variables ``(log_beta, log_gamma, log_I0)``.
    :param bounds: dict with keys small_/large_ + beta/gamma/I0 and a
        matching ``<key>_k`` weight for each; defaults to
        ``OPTIONS['REGULARISE']``.
    :return: total (non-negative) penalty.
    """
    if bounds is None:
        bounds = OPTIONS['REGULARISE']
    lb, lg, li0 = params

    pen = 0
    # Same accumulation order as before: lower then upper bound per variable.
    for name, value in (('beta', lb), ('gamma', lg), ('I0', li0)):
        lower = np.log(bounds['small_' + name])
        upper = np.log(bounds['large_' + name])
        pen += np.maximum(lower - value, 0) * bounds['small_' + name + '_k']
        pen += np.maximum(value - upper, 0) * bounds['large_' + name + '_k']
    return pen

def obj_fn(params,
           inc_data, observ_mask,
           alpha, start_id, popu_vec, tr_fn):
    """
    Objective minimised during fitting: masked Poisson NLL plus penalty.

    :param params: ``(log_beta, log_gamma, log_I0)`` being optimised.
    :param inc_data: [place_num x T] daily new infection counts.
    :param observ_mask: [place_num x (T-1)] weights selecting which
      (place, day) likelihood terms count -- fitting relies on a limited
      set of observations.
    :param alpha: exposed -> infected rate (not logged; logged here).
    :param start_id: index of the start location.
    :param popu_vec: population of each place.
    :param tr_fn: travel-matrix function for ``NetworkSEIR``.
    :return: sum of the masked NLL terms plus ``regulariser(params)`` scaled
      by the total mask weight.
    """
    log_beta, log_gamma, log_infect0 = params
    model = NetworkSEIR(np.log(alpha), log_beta, log_gamma, log_infect0,
                        start_id, popu_vec, tr_fn)
    per_entry_nll = model.neglog_likelihood(infection_inc_data=inc_data)[0]

    n_observed = observ_mask.sum()
    per_entry_nll *= observ_mask
    penalty = regulariser(params) * n_observed
    return per_entry_nll.sum() + penalty

def fit_model(obj_fn, daily_increase_data, observ_mask):
    """
    Fit (beta, gamma, I0) with Nelder-Mead and score the fitted model.

    This function depends on global variables: OPTIONS (initial guesses,
    ALPHA, SIMULATION_DAYS), g_start_loc_id, g_popu_vec and get_city_trans.

    :param obj_fn: objective taking ``(params, inc_data, observ_mask,
      alpha, start_id, popu_vec, tr_fn)``.
    :param daily_increase_data: [place_num x T] daily new infection counts.
    :param observ_mask: [place_num x (T-1)] mask of observations to fit.
    :return: ``(msol, is_inlier, seir_model, simu_resu)``:
      msol: the ``scipy.optimize.minimize`` result (``msol.x`` are logs);
      is_inlier: boolean [place_num x (T-1)], True where the observation's
        Poisson CDF under the fitted model lies strictly in (0.05, 0.95);
      seir_model: the fitted ``NetworkSEIR``;
      simu_resu: compartments simulated over SIMULATION_DAYS days.
    """
    beta0 = OPTIONS['INIT_BETA']
    gamma0 = OPTIONS['INIT_GAMMA']
    infect0_guess = OPTIONS['INIT_INIT_INFECT']
    al = OPTIONS['ALPHA']
    msol = minimize(
        obj_fn, [np.log(beta0), np.log(gamma0), np.log(infect0_guess)],
        args=(daily_increase_data,
              observ_mask,
              al,
              g_start_loc_id,
              g_popu_vec,
              get_city_trans),
        method='Nelder-Mead', options=dict(disp=False))

    lb, lg, i0 = msol.x
    seir_model = NetworkSEIR(
        np.log(al), lb, lg, i0,
        g_start_loc_id,
        g_popu_vec,
        get_city_trans)

    # Score against the data that was fitted (not a global), so the
    # inliers always refer to the caller's data.
    log_cdf = seir_model.neglog_likelihood(daily_increase_data, True)[1]
    cdf = np.exp(log_cdf)
    is_inlier = np.logical_and(cdf > 0.05, cdf < 0.95)

    simu_resu, _ = seir_model.integrate(
        np.arange(0, OPTIONS['SIMULATION_DAYS']))

    return msol, is_inlier, seir_model, simu_resu

# RANSAC-style search. Each round fits the model to a small random sample
# of observations (one city, RANSAC_SAMPLE_NUM days), then counts how many
# of ALL observations the fitted model explains ("inliers"). Every round is
# recorded in `resu`; the round with the most inliers is kept in best_*.
resu = dict(city_id=[],
            days=[],
            param_sol=[],
            fit_num=[])

# Fixed seed: the sampled cities/days depend on the exact order of rng calls.
rng = np.random.RandomState(0)
best_inlier_num = -1
g_ransac_succ = False
for epoch in range(OPTIONS['ROUNDS']):
    # One city and RANSAC_SAMPLE_NUM distinct mask columns in 20..T-2.
    obs_city_id = rng.randint(g_city_num)
    obs_days = np.sort(rng.choice(
        np.arange(20, g_daily_increase.shape[1] - 1), 
        size=(OPTIONS['RANSAC_SAMPLE_NUM'],), 
        replace=False))
    
    # Mask over the T-1 daily-increase likelihood terms; only the sampled
    # entries take part in the fit. observ_mask[obs_city_id] is a view, so
    # the chained assignment writes into observ_mask.
    observ_mask = np.zeros((g_city_num, g_daily_increase.shape[1]-1))
    observ_mask[obs_city_id][obs_days] = 1
    
    msol, is_inlier, seir_model, simu_resu = fit_model(
        obj_fn, g_daily_increase, observ_mask)
    
    lb, lg, i0 = msol.x
    beta = np.exp(lb)
    gamma = np.exp(lg)
    R0 = beta / gamma  # basic reproduction number
    
    #simu_resu, sol = seir_model.integrate(np.arange(0, 90))
    
    # Inliers over all days, and over the last RECENT_INLIER_DAYS columns.
    inlier_num = is_inlier.sum()
    rd_ = - OPTIONS['RECENT_INLIER_DAYS']
    recent_inlier_num = is_inlier[:, rd_:].sum()
    
    
    resu['city_id'].append(obs_city_id)
    resu['days'].append(obs_days)
    resu['param_sol'].append(msol.x)
    resu['fit_num'].append((inlier_num, recent_inlier_num))
    
    # Best by total inliers; ties keep the earlier round. NOTE(review): the
    # report section ranks by recent_fit instead, and best_model/best_simu
    # are not used in the visible code.
    if inlier_num > best_inlier_num:
        best_inlier_num = inlier_num
        best_model = seir_model
        best_simu = simu_resu
    
    print(f"[{epoch}] beta: {np.exp(lb):.3f}, gamma: {np.exp(lg):.3f}, " + 
          f"I0: {np.exp(i0):.1f}, R0: {R0:.2f}, fit:{inlier_num}/{recent_inlier_num}:{best_inlier_num}")
# Only reached when every round completed without raising.
g_ransac_succ = True    

In [None]:
# Checkpoint the RANSAC results and the options they were produced with.
# Run note: estimation with I0 upper bound 1000, 1000 rounds, 1 fitted city
# and 4 fitted days per round. NOTE(review): the round count may differ from
# the current OPTIONS['ROUNDS']; *_options.json records the actual settings.
if g_ransac_succ:
    outputs = ((".json", resu), ("_options.json", OPTIONS))
    for suffix, payload in outputs:
        with open(CHECKPOINT_FILENAME + suffix, "w") as f:
            json.dump(payload, f, indent=2, cls=NpEncoder)

# Graphical Report of Results - Part 1

## Estimate Parameters and Make Predictions

In [None]:
# Reload the checkpoint, so this section also works without re-running RANSAC.
with open(CHECKPOINT_FILENAME + ".json", "r") as f:
    resu = json.load(f)
# Summarise results to report: one row per RANSAC round with
# beta, gamma, I0, R0 = beta / gamma, and the two inlier counts.
d_ = [[np.exp(ps[0]), np.exp(ps[1]), np.exp(ps[2]),
       np.exp(ps[0]-ps[1])] + fn
      for ps, fn in zip(resu['param_sol'], resu['fit_num'])]
g_resu_df = pd.DataFrame(
    data=d_,
    columns=['beta', 'gamma', 'I0', 'R0', 'fit', 'recent_fit'])\
    .sort_values(by=['recent_fit'], ascending=False)

# Perform prediction using the optimal model (most recent inliers).
# An older parameter set that also fits well:
# lb, lg, li0 = np.log([0.662, 0.117, 964.0])
g_resu_optm = g_resu_df.iloc[0]
seir_optm = NetworkSEIR(
    np.log(OPTIONS['ALPHA']),
    np.log(g_resu_optm['beta']),
    np.log(g_resu_optm['gamma']),
    np.log(g_resu_optm['I0']),
    g_start_loc_id,
    g_popu_vec,
    get_city_trans)
simu_resu, sol = seir_optm.integrate(np.arange(0, 365))
sim_I = simu_resu["I"]

# Cumulative observed cases per city.
accu_data = np.cumsum(g_daily_increase, axis=1)
time_to_pred = 45
time_with_data = accu_data.shape[1]
one_day = np.timedelta64(1, 'D')
time_points_sim = DAY_LIWENLIANG + one_day * np.arange(0, time_to_pred)
# Built for the full data length, so it stays aligned with accu_data even
# when there are more data days than time_to_pred.
time_points_data = DAY_LIWENLIANG + one_day * np.arange(0, time_with_data)

nlogpmf, logcdf = seir_optm.neglog_likelihood(
    g_daily_increase, True)
cdf = np.exp(logcdf)
# -1: observation below the model's 5% tail; +1: above its 95% tail.
# dtype=int: the np.int alias was removed in NumPy 1.24.
inlier_status = np.zeros((g_city_num, time_with_data - 1), dtype=int)
inlier_status[cdf < 0.05] = -1
inlier_status[cdf > 0.95] = 1

In [None]:
# Show the selected parameter set (best recent-fit round). The stray
# `85*5` (which printed 425) was leftover debugging output.
print(g_resu_optm)

## Draw Estimated Parameters

In [None]:
# Scatter of every RANSAC round's estimate: I0 (log x-axis) vs R0,
# coloured by the recent-inlier count. The selected model (g_resu_optm) is
# overlaid as a large green triangle. The dashed circles and the labels
# "1"/"2" mark two hand-placed clusters of estimates.
# NOTE(review): the annotation x values appear to be log10 positions
# (2 -> 100, 3.05 -> ~1120), while the circle x0/x1 are raw I0 values --
# confirm plotly's log-axis conventions before repositioning.
fig = px.scatter(
    g_resu_df, x='I0', y='R0', 
    color='recent_fit', width=800, height=600)\
    .add_trace(go.Scatter(
        # highlight the selected model
        x=[g_resu_optm['I0']], y=[g_resu_optm['R0']], 
        marker=dict(symbol='triangle-up', size=24, 
                    line_color='black', line_width=3,
                    color='rgba(0, 255, 0, 0.5)'),
        showlegend=False))\
    .update_layout(
        coloraxis_colorbar=dict(
            # "Recent 5D" presumably mirrors OPTIONS['RECENT_INLIER_DAYS']
            title=dict(text="Number of Inliers<br>(Recent 5D)",
                       font_size=16),
            thicknessmode="pixels", thickness=50,
            yanchor="top", y=1,
            ticks="inside",
        dtick=20),
        yaxis=go.layout.YAxis(
            ticks='inside',
            tickfont=dict(size=14),
            title=dict(text="Basic Reproductive Number R0",
                       font_size=24),
            range=[0, 30]),
        xaxis=go.layout.XAxis(
            showticklabels=True,
            ticks='inside',
            tickfont=dict(size=16),
            type='log',
            title=dict(text="Infections on 1 Jan 2020",
                       font_size=24),
            # log axis: range is in log10 units, i.e. 10^0 .. 10^4
            range=[0, 4]),
        shapes=[
            # unfilled circle around cluster 1
            go.layout.Shape(
                type="circle",
                xref="x",
                yref="y",
                x0=80,
                y0=6,
                x1=240,
                y1=14,
                line=dict(color="rgb(0, 128, 0)",
                          width=3, dash='dash')),
            # unfilled circle around cluster 2
            go.layout.Shape(
                type="circle",
                xref="x",
                yref="y",
                x0=400,
                y0=2.5,
                x1=1200,
                y1=7,
                line=dict(color="rgb(0, 128, 0)",
                          width=3, dash='dash'),
            )],
        annotations=[
            # label "1" above circle 1
            go.layout.Annotation(
                text='1',
                font=dict(
                    size=32,
                    color="rgb(0, 128, 0)"),
                align='left',
                showarrow=False,
                xref='x',
                yref='y',
                x=2,
                y=14.5,
                bordercolor=None,),
            # label "2" below circle 2
            go.layout.Annotation(
                text='2',
                font=dict(
                    size=32,
                    color="rgb(0, 128, 0)"),
                align='left',
                showarrow=False,
                xref='x',
                yref='y',
                x=3.05,
                y=2,
                bordercolor=None,)
        ])

fig.show()
# Export an 800x600 PDF (fig.to_image needs plotly's static-image export
# support installed -- TODO confirm in the environment).
with open ("checkpoints/fig1a_param_lscale.pdf", 'wb') as f:
    f.write(fig.to_image("pdf", 800, 600))

In [None]:
# Zoom the parameter scatter onto the region of circle 1 and hide the colour
# bar. Log-axis ranges are given in log10 units. (The original first set an
# x-range of 10..500 that the next call immediately replaced.)
fig = fig.update_xaxes(range=[np.log10(80), np.log10(200)])\
    .update_yaxes(range=[5, 15], title=None)\
    .update_layout(coloraxis=dict(showscale=False),
                   height=600, width=600)
fig.show()
with open("checkpoints/fig1a_param_zoomed.pdf", 'wb') as f:
    f.write(fig.to_image("pdf", 600, 600))

## Draw Predictions for Different Places

#### English City Names

In [None]:
# English display names for the analysed cities, keyed by the Chinese name
# in the population table's 'City' column (used for figure titles).
# Some English names are homographs of distinct cities: 苏州/宿州 (Suzhou),
# 福州/抚州 (Fuzhou), 台州/泰州 (Taizhou).
g_eng_city_name = {
    "重庆市": "Chongqing",
    "上海市": "Shanghai",
    "北京市": "Beijing",
    "成都市": "Chengdu",
    "天津市": "Tianjin",
    "广州市": "Guangzhou",
    "深圳市": "Shenzhen",
    "武汉市": "Wuhan",
    "南阳市": "Nanyang",
    "临沂市": "Linyi",
    "石家庄市": "Shijiazhuang",
    "哈尔滨市": "Harbin",
    "苏州市": "Suzhou",
    "保定市": "Baoding",
    "郑州市": "Zhengzhou",
    "西安市": "Xi'an",
    "邯郸市": "Handan",
    "温州市": "Wenzhou",
    "周口市": "Zhoukou",
    "杭州市": "Hangzhou",
    "徐州市": "Xuzhou",
    "赣州市": "Ganzhou",
    "菏泽市": "Heze",
    "东莞市": "Dongguan",
    "泉州市": "Quanzhou",
    "南京市": "Nanjing",
    "阜阳市": "Fuyang",
    "商丘市": "Shangqiu",
    "南通市": "Nantong",
    "盐城市": "Yancheng",
    "驻马店市": "Zhumadian",
    "衡阳市": "Hengyang",
    "沧州市": "Cangzhou",
    "福州市": "Fuzhou",
    "邢台市": "Xingtai",
    "邵阳市": "Shaoyang",
    "长沙市": "Changsha",
    "湛江市": "Zhanjiang",
    "南宁市": "Nanning",
    "黄冈市": "Huanggang",
    "南充市": "Nanchong",
    "洛阳市": "Luoyang",
    "上饶市": "Shangrao",
    "昆明市": "Kunming",
    "无锡市": "Wuxi",
    "信阳市": "Xinyang",
    "台州市": "Taizhou",
    "常德市": "Changde",
    "新乡市": "Xinxiang",
    "合肥市": "Hefei",
    "荆州市": "Jingzhou",
    "六安市": "Lu'an",
    "襄阳市": "Xiangyang",
    "岳阳市": "Yueyang",
    "达州市": "Dazhou",
    "宜春市": "Yichun",
    "宿州市": "Suzhou",
    "安庆市": "Anqing",
    "永州市": "Yongzhou",
    "安阳市": "Anyang",
    "南昌市": "Nanchang",
    "平顶山市": "Pingdingshan",
    "亳州市": "Bozhou",
    "孝感市": "Xiaogan",
    "吉安市": "Ji'an",
    "桂林市": "Guilin",
    "怀化市": "Huaihua",
    "九江市": "Jiujiang",
    "开封市": "Kaifeng",
    "泰州市": "Taizhou",
    "惠州市": "Huizhou",
    "郴州市": "Chenzhou",
    "扬州市": "Yangzhou",
    "益阳市": "Yiyang",
    "许昌市": "Xuchang",
    "宜昌市": "Yichang",
    "抚州市": "Fuzhou",
    "株洲市": "Zhuzhou",
    "娄底市": "Loudi",
    "湘潭市": "Xiangtan",
    "濮阳市": "Puyang",
    "焦作市": "Jiaozuo",
    "厦门市": "Xiamen",
    "十堰市": "Shiyan",
    "恩施州": "Enshi",
}

### Draw Figure

In [None]:
figures = []  # NOTE(review): not used in the visible code
show_dates = True  # read by draw_city_prediction: show date tick labels


def draw_city_prediction(city_id, pdf_path=None):
    """
    Plot the optimal model's prediction against observations for one city.

    Reads globals from the report cells: g_popu_df, g_subset_place_ids,
    sim_I, accu_data, inlier_status, time_points_sim, time_points_data,
    time_to_pred, g_eng_city_name and show_dates.

    :param city_id: index within the analysed subset (a row of sim_I /
        accu_data), not a row of g_popu_df.
    :param pdf_path: if given, also write the figure there as an 800x800 PDF.
    :return: the plotly Figure.

    The translucent red band is the model's I compartment from day 10.
    Markers are cumulative observed cases from day 11, styled by that day's
    inlier_status: green down-triangle = below the model's 5% tail, red
    up-triangle = above its 95% tail, blue circle = within.
    NOTE(review): this compares the current-infectious compartment with
    cumulative confirmed cases -- confirm this is the intended comparison.
    """
    city_name = g_popu_df.loc[g_subset_place_ids[city_id], 'City']
    # A city missing from the English lookup gets an empty subtitle rather
    # than raising KeyError.
    eng_name = g_eng_city_name.get(city_name, "")
    predicted_infection_numbers = sim_I[city_id, :time_to_pred]
    observed_infection_numbers = accu_data[city_id]

    # inlier_status[:, k] scores the increase on day k+1, so columns 10:
    # line up with the observations plotted from day 11 onward.
    day_status = inlier_status[city_id, 10:]
    marker_symbols = ['triangle-down' if s_ < 0
                      else ('triangle-up' if s_ > 0 else 'circle')
                      for s_ in day_status]
    marker_colors = ['rgba(0, 192, 0, 1)' if s_ < 0
                     else ('rgba(192, 0, 0, 1)' if s_ > 0
                           else 'rgba(0, 0, 192, 0.5)')
                     for s_ in day_status]

    fig_data = [
        # Model prediction as a wide translucent band.
        go.Scatter(x=time_points_sim[10:], y=predicted_infection_numbers[10:],
                   line_width=16, line_color='rgba(255,0,0,0.3)'),
        # Observed cumulative cases.
        go.Scatter(x=time_points_data[11:],
                   y=observed_infection_numbers[11:],
                   marker=dict(
                       symbol=marker_symbols,
                       size=18,
                       line_width=1,
                       color=marker_colors),
                   mode='markers'),
    ]
    fig = go.Figure(data=fig_data)\
        .update_layout(
            autosize=False,
            showlegend=False,
            width=800,
            height=800,
            title=dict(
                text=city_name + "<br>" + eng_name,
                y=0.9, x=0.5,
                xanchor='center', yanchor='top',
                font=dict(size=60)
            ),
            margin=dict(l=100, r=0, t=0, b=0),
            yaxis=go.layout.YAxis(
                ticks='inside',
                tickfont=dict(size=40)
            ),
            xaxis=go.layout.XAxis(
                showticklabels=show_dates,
                ticks='inside',
                tickfont=dict(size=30)
            ))

    if pdf_path is not None:
        with open(pdf_path, 'wb') as f:
            f.write(fig.to_image("pdf", 800, 800))
    return fig


In [None]:
# Preview the prediction figure for the first city of the analysed subset
# (index 0 of g_subset_place_ids).
fig = draw_city_prediction(0)
fig.show()