## Load the checkpoint

**Notice:** Set `timesfm_backend` in the cell below to match your machine (`"cpu"`, `"gpu"` or `"tpu"`). This notebook runs on GPU by default.

We load the TimesFM 2.0 500M PyTorch checkpoint (`google/timesfm-2.0-500m-pytorch`) from Hugging Face.

In [None]:
import timesfm
timesfm_backend = "gpu"  # @param

# Hyperparameters handed to the model constructor; the backend comes from the
# form field above.
timesfm_hparams = timesfm.TimesFmHparams(
    backend=timesfm_backend,
    per_core_batch_size=32,
    horizon_len=128,
    num_layers=50,
    use_positional_embedding=False,
    context_len=2048,
)

# The weights are identified by their Hugging Face repository id.
timesfm_checkpoint = timesfm.TimesFmCheckpoint(
    huggingface_repo_id="google/timesfm-2.0-500m-pytorch"
)

model = timesfm.TimesFm(hparams=timesfm_hparams, checkpoint=timesfm_checkpoint)

In [1]:
import pandas as pd
import numpy as np
from collections import defaultdict

In [None]:

# Read the raw daily records and parse the onset dates (year-month-day strings).
df = pd.read_csv('raw_day_data.csv').assign(
    onset_date=lambda frame: pd.to_datetime(frame['onset_date'], format='%Y-%m-%d')
)
max_onset_date = df['onset_date'].max()
# Keep only rows dated within the 364 days ending at the most recent onset date.
# NOTE(review): get_batched_data_fn needs context_len + horizon_len rows per
# series (364 + 7 with the benchmark settings), so this cut-off may leave no
# complete window -- confirm how much history is meant to be kept.
df = df.loc[df['onset_date'] > max_onset_date - pd.Timedelta(days=364)]

Unnamed: 0,disease,county,onset_date,value,last_n_days,last_n_days_neighbor,weekday,month,year
0,发热伴,东平县,2013-04-22,0,0,0.000000,0,4,2013
1,发热伴,东平县,2013-04-23,0,0,0.000000,1,4,2013
2,发热伴,东平县,2013-04-24,0,0,0.000000,2,4,2013
3,发热伴,东平县,2013-04-25,0,0,0.000000,3,4,2013
4,发热伴,东平县,2013-04-26,0,0,0.000000,4,4,2013
...,...,...,...,...,...,...,...,...,...
3667575,肾综合,龙口市,2024-12-04,0,0,0.000000,2,12,2024
3667576,肾综合,龙口市,2024-12-05,0,0,0.000000,3,12,2024
3667577,肾综合,龙口市,2024-12-06,0,0,0.608152,4,12,2024
3667578,肾综合,龙口市,2024-12-07,0,0,0.000000,5,12,2024


In [None]:
# Distinct county and disease labels present in the (filtered) frame.
counties, diseases = (df[column].unique() for column in ('county', 'disease'))
# df_counties = pd.read_csv('df_counties.csv')

In [None]:
# Data pipelining
def get_batched_data_fn(
    disease: str,
    county: str,
    batch_size: int = 128,
    context_len: int = 364,
    horizon_len: int = 7,
):
  """Builds sliding-window examples for one disease and returns a batch generator.

  Reads the notebook-level `df` (columns used: "disease", "county",
  "onset_date", "value") and, when `county` is None, the notebook-level
  `counties` collection.

  For every selected county the series is ordered by "onset_date" and cut into
  windows of `context_len + horizon_len` rows, advancing `horizon_len` rows at
  a time. The last window that fits completely is included.

  Args:
    disease: Value matched against the "disease" column.
    county: Value matched against the "county" column, or None to use every
      entry of `counties`.
    batch_size: Maximum number of examples per yielded batch.
    context_len: Number of rows used as model input in each window.
    horizon_len: Number of rows following the context in each window; also the
      stride between consecutive windows.

  Returns:
    A zero-argument generator function. Each yielded batch is a dict of
    equally long lists with the keys "disease", "county", "inputs" (context
    values), "outputs" (horizon values) and "onset_date" (the dates of the
    whole window, so the last `horizon_len` entries are the horizon dates).
    Nothing is yielded when no series is long enough for a single window.
  """
  examples = defaultdict(list)
  # A distinct local name keeps the `county` argument from being shadowed.
  target_counties = counties if county is None else [county]
  # The disease filter does not depend on the county, so apply it only once.
  disease_df = df[df["disease"] == disease]
  window_len = context_len + horizon_len
  num_examples = 0

  for county_name in target_counties:
    # Windows are positional, so make the chronological order explicit.
    sub_df = disease_df[disease_df["county"] == county_name].sort_values(
        "onset_date", kind="stable"
    )
    values = sub_df["value"].tolist()
    dates = sub_df["onset_date"].tolist()

    # "+ 1" keeps the final window that ends exactly on the last row.
    for start in range(0, len(values) - window_len + 1, horizon_len):
      context_end = start + context_len
      window_end = context_end + horizon_len
      num_examples += 1
      examples["disease"].append(disease)
      examples["county"].append(county_name)
      examples["inputs"].append(values[start:context_end])
      examples["outputs"].append(values[context_end:window_end])
      examples["onset_date"].append(dates[start:window_end])

  def data_fn():
    for batch_start in range(0, num_examples, batch_size):
      batch_end = batch_start + batch_size
      yield {key: rows[batch_start:batch_end] for key, rows in examples.items()}

  return data_fn


In [None]:
import time
from tqdm.notebook import tqdm
# from statsmodels.tsa.arima.model import ARIMA

# Benchmark
# Runs the model over every window of every disease (all counties) and writes
# one row per forecast day to 'df_data_raw_timesfm.csv'.
batch_size = 128
context_len = 364
horizon_len = 7

df_data = defaultdict(list)
metrics_mean = defaultdict(list)
for disease in tqdm(diseases):
  input_data = get_batched_data_fn(disease, None, batch_size, context_len, horizon_len)
  metrics = defaultdict(list)

  for example in input_data():
    point_forecast, _ = model.forecast(
        inputs=example["inputs"], freq=[0] * len(example["inputs"])
    )
    # The integer cast truncates toward zero; negative counts are then clipped to 0.
    raw_forecast = np.array(point_forecast).astype(np.int16)
    raw_forecast[raw_forecast < 0] = 0

    # One output row per forecast day: repeat each window's labels horizon_len
    # times so they line up with the row-major flattened forecasts below.
    df_data["disease"].extend(np.repeat(example["disease"], horizon_len))
    df_data["county"].extend(np.repeat(example["county"], horizon_len))
    # Date labels are exported only when the data pipeline provides them; the
    # last horizon_len dates of each window are the forecast days.
    if "onset_date" in example:
      df_data["onset_date"].extend(
          day for window_dates in example["onset_date"] for day in window_dates[-horizon_len:]
      )
    # df_data["number_of_cases"].extend(np.reshape(example["outputs"], -1))
    # Positional reshape argument: the `shape=` keyword is not accepted by
    # NumPy releases older than 2.1.
    df_data["raw_timesfm"].extend(raw_forecast[:, :horizon_len].reshape(-1))

forecast_frame = pd.DataFrame(data=df_data)
if forecast_frame.empty:
  # An empty export usually means no series had enough rows for one window.
  print(
      "Warning: no forecast windows were produced; check that each series has "
      "at least context_len + horizon_len rows."
  )
forecast_frame.to_csv('df_data_raw_timesfm.csv')

