## Importing relevant packages for finetuning

In [2]:
import os
# Ask JAX/XLA to allocate GPU memory on demand instead of preallocating a large
# pool. These variables are set before the cell that imports `timesfm` (which
# loads JAX).
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
# NOTE(review): presumably disables TensorStore on the pmap checkpoint path —
# confirm against the JAX / paxml documentation.
os.environ['JAX_PMAP_USE_TENSORSTORE'] = 'false'

In [3]:
import timesfm
import gc
import numpy as np
import pandas as pd
from timesfm import patched_decoder
from timesfm import data_loader

TimesFM v1.2.0. See https://github.com/google-research/timesfm/blob/master/README.md for updated APIs.
Loaded Jax TimesFM.


2024-12-29 11:29:48.811050: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


In [4]:
from tqdm import tqdm
import dataclasses
import IPython
import IPython.display
import matplotlib as mpl
import matplotlib.pyplot as plt
# Notebook-wide matplotlib defaults: 8x6 figures with the axes grid turned off.
mpl.rcParams['figure.figsize'] = (8, 6)
mpl.rcParams['axes.grid'] = False

## Loading TimesFM pretrained checkpoint

In [5]:
# Build the pretrained TimesFM forecaster on the GPU backend; the weights come
# from the Hugging Face repo named below.
tfm = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        backend="gpu",
        per_core_batch_size=32,
        horizon_len=128,
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-1.0-200m",
    ),
)

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

2024-12-29 11:30:02.081064: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.4 which is older than the ptxas CUDA version (12.6.85). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


Multiprocessing context has already been set.
Constructing model weights.




Constructed model weights in 1.45 seconds.
Restoring checkpoint from /home/ming/.cache/huggingface/hub/models--google--timesfm-1.0-200m/snapshots/8775f7531211ac864b739fe776b0b255c277e2be/checkpoints.


ERROR:absl:For checkpoint version > 1.0, we require users to provide
          `train_state_unpadded_shape_dtype_struct` during checkpoint
          saving/restoring, to avoid potential silent bugs when loading
          checkpoints to incompatible unpadded shapes of TrainState.


Restored checkpoint in 0.63 seconds.
Jitting decoding.
Jitted decoding in 9.57 seconds.


## Evaluating pretrained checkpoint on ETT datasets

In [6]:
# The string literal below is a disabled copy of the benchmark dataset table
# (ETT, electricity, traffic, weather). It is evaluated as a bare expression and
# has no effect; it is kept only for reference.
"""
DATA_DICT = {
    "ettm2": {
        "boundaries": [34560, 46080, 57600],
        "data_path": "../datasets/ETT-small/ETTm2.csv",
        "freq": "15min",
    },
    "ettm1": {
        "boundaries": [34560, 46080, 57600],
        "data_path": "../datasets/ETT-small/ETTm1.csv",
        "freq": "15min",
    },
    "etth2": {
        "boundaries": [8640, 11520, 14400],
        "data_path": "../datasets/ETT-small/ETTh2.csv",
        "freq": "H",
    },
    "etth1": {
        "boundaries": [8640, 11520, 14400],
        "data_path": "../datasets/ETT-small/ETTh1.csv",
        "freq": "H",
    },
    "elec": {
        "boundaries": [18413, 21044, 26304],
        "data_path": "../datasets/electricity/electricity.csv",
        "freq": "H",
    },
    "traffic": {
        "boundaries": [12280, 14036, 17544],
        "data_path": "../datasets/traffic/traffic.csv",
        "freq": "H",
    },
    "weather": {
        "boundaries": [36887, 42157, 52696],
        "data_path": "../datasets/weather/weather.csv",
        "freq": "10min",
    },
}
"""
# Active configuration. The "ettm1" key is reused here for a custom CSV with a
# daily ("D") frequency, not the ETTm1 benchmark file.
DATA_DICT = {
    "ettm1": {
        # End row indices of the train, val and test splits; the loader cell
        # turns them into the ranges [0, b0], [b0, b1] and [b1, b2].
        "boundaries": [80, 100, 115],
        "data_path": "/home/ming/aaguolishaData/time_xulie/timedata.csv",
        "freq": "D",
    }}



In [7]:
# Select the dataset entry and unpack its configuration.
dataset = "ettm1"
data_path = DATA_DICT[dataset]["data_path"]
freq = DATA_DICT[dataset]["freq"]
int_freq = timesfm.freq_map(freq)
boundaries = DATA_DICT[dataset]["boundaries"]

# Load the raw CSV. Passing the path (rather than an `open()` handle that is
# never closed) lets pandas open and close the file itself.
data_df = pd.read_csv(data_path)

# Row count and preview before the missing dates are filled in.
print("补全数据之前的数据条数:", len(data_df))
print(data_df.head())

# Re-index onto a gap-free calendar at the dataset's configured frequency.
# Dates that are absent from the CSV are inserted with a value of 0.
data_df['date'] = pd.to_datetime(data_df['date'])
date_range = pd.date_range(start=data_df['date'].min(), end=data_df['date'].max(), freq=freq)
data_df = data_df.set_index('date').reindex(date_range, fill_value=0).rename_axis('date').reset_index()

# Row count and preview after the fill.
print("补全数据之后的数据条数:", len(data_df))
print(data_df.head())

补全数据之前的数据条数: 116
        date  value
0  2014/7/21      1
1   2014/8/8      1
2  2014/8/12      2
3  2014/8/16      2
4  2014/8/18      2
补全数据之后的数据条数: 164
        date  value
0 2014-07-21      1
1 2014-07-22      0
2 2014-07-23      0
3 2014-07-24      0
4 2014-07-25      0


In [None]:

# Target column(s) to forecast; no numeric or categorical covariates are used.
ts_cols = ['value']
num_cov_cols = []
cat_cov_cols = []

# Window sizes handed to the loader as hist_len / pred_len.
context_len = 30
pred_len = 5

num_ts = len(ts_cols)
print("^^^^^^")
print(num_ts)
batch_size = 2

# Time-series data loader `dtl`, used to load and window the data.
# NOTE(review): the loader reads the CSV at `data_path` itself; the date-filled
# `data_df` built in the previous cell is not passed to it.
dtl = data_loader.TimeSeriesdata(
    data_path=data_path,
    datetime_col="date",
    num_cov_cols=num_cov_cols,
    cat_cov_cols=cat_cov_cols,
    ts_cols=np.array(ts_cols),
    # Split boundaries: [0, b0] train, [b0, b1] validation, [b1, b2] test.
    train_range=[0, boundaries[0]],
    val_range=[boundaries[0], boundaries[1]],
    test_range=[boundaries[1], boundaries[2]],
    hist_len=context_len,
    pred_len=pred_len,
    # The loader's own batch size is the number of series, not `batch_size`.
    batch_size=num_ts,
    freq=freq,
    normalize=False,
    epoch_len=None,
    holiday=False,
    permute=True,
)

In [None]:
# Check for covariates: any column besides `date` and the ts_cols would be one.
print(data_df.columns)

# Preview the first 10 rows of the training slice of the in-memory frame.
train_range = [0, boundaries[0]]
train_data = data_df.iloc[slice(*train_range)]

print(train_data.head(10))

In [131]:
#时间序列数据加载器对象 dtl
##dtl.tf_dataset() 方法将原始数据转换为 TensorFlow 数据集（tf.data.Dataset 对象）
# Turn the loader `dtl` into one tf.data.Dataset per split.
# Only the training dataset is additionally grouped with .batch(batch_size).
# NOTE(review): `shift` is presumably the stride between consecutive windows
# (1 for train, pred_len for val/test) — confirm in timesfm.data_loader.
train_batches = dtl.tf_dataset(mode="train", shift=1).batch(batch_size)
val_batches = dtl.tf_dataset(mode="val", shift=pred_len)
test_batches = dtl.tf_dataset(mode="test", shift=pred_len)

如何理解 batch？它的形状是什么？

训练批次中历史数据张量的形状为 (batch_size, num_ts, context_len)，其中：

- batch_size 是批次大小；
- num_ts 是时间序列的数量；
- context_len 是历史数据长度。

batch 是一个包含多个元素的元组，每个元素代表数据的不同部分。本 notebook 只用到其中两个元素：

- batch[0]：历史数据（past），用于模型输入；
- batch[3]：未来窗口的实际值（actuals），既是训练目标，也用于与模型预测值进行比较。

batch[1] 和 batch[2] 应为历史窗口的协变量数据（covariates，即额外的特征数据）；本 notebook 没有配置协变量，因此不会用到它们。各元素的确切含义请以 timesfm 的 data_loader 源码为准。


In [None]:
# First pass: run through every training batch so `tbatch` ends up holding the
# last one, then print the shape of its first element.
for tbatch in tqdm(train_batches.as_numpy_iterator()):
    pass
print(tbatch[0].shape)

# Debugging code.
# Second pass: print elements 0-2 of the first training batch only.
for tbatch in tqdm(train_batches.as_numpy_iterator()):
    # as_numpy_iterator() yields the TensorFlow dataset's elements as NumPy arrays.
    print("tbatch[0]\n", tbatch[0])  # history (past) values
    print("tbatch[1]\n", tbatch[1])  # labelled "covariates" by the author
    print("tbatch[2]\n", tbatch[2])  # labelled "targets" by the author — NOTE(review): verify against data_loader
    break  # stop after the first batch to keep the output short
# Because of the `break`, `tbatch` is the first batch yielded by the second
# loop, although the message below calls it the last batch. This `tbatch` is
# what the model-initialisation cell later passes to process_train_batch.
print("最后一个批次中第一个元素的形状：", tbatch[0].shape)

### MAE on the test split for the pretrained TimesFM model

In [None]:
# MAE of the pretrained model's point forecast (first value returned by
# tfm.forecast), collected per test batch and then averaged.
mae_losses = []
for batch in tqdm(test_batches.as_numpy_iterator()):
    past, actuals = batch[0], batch[3]

    forecasts, _ = tfm.forecast(list(past), [0] * past.shape[0], normalize=True)
    # Keep only as many forecast steps as there are actual values.
    forecasts = forecasts[:, : actuals.shape[1]]
    mae_losses.append(np.mean(np.abs(forecasts - actuals)))

print(f"MAE: {np.mean(mae_losses)}")

In [None]:

#我改代码
for batch in tqdm(test_batches.as_numpy_iterator()):
    past = batch[0]
    actuals = batch[3]

    # 输出查看 past 和 actuals
    print("历史数据（past）：", past)
    print("实际值（actuals）：", actuals)
    
    forecasts, _ = tfm.forecast(list(past), [0] * past.shape[0], normalize=True)
    forecasts = forecasts[:, 0 : actuals.shape[1]]
    
    # 输出预测值
    print("预测值（forecasts）：", forecasts)

    mae_losses.append(np.abs(forecasts - actuals).mean())
    break  # 如果只想查看第一个批次, 可以使用 break 退出循环

In [None]:
# Per-sample report for the pretrained model, scored on index 5 of the last
# axis of the second value returned by tfm.forecast.
mae_losses = []
sample_count = 0  # running sample number; stays 0 if the test set is empty

for sample_count, batch in enumerate(tqdm(test_batches.as_numpy_iterator()), start=1):
    past, actuals = batch[0], batch[3]
    _, forecasts = tfm.forecast(list(past), [0] * past.shape[0])
    forecasts = forecasts[:, : actuals.shape[1], 5]

    # MAE of the current sample.
    sample_mae = np.mean(np.abs(forecasts - actuals))
    mae_losses.append(sample_mae)

    # Dump the sample; arrays are converted to lists for readable output.
    print(f"样本{sample_count}：")
    print("历史数据（past）：", past.tolist())
    print("实际值（actuals）：", actuals.tolist())
    print("预测值（forecasts）：", forecasts.tolist())
    print(f"样本{sample_count} MAE: {sample_mae}")
    print("-" * 50)  # separator line

# Mean MAE over all samples.
mae_mean = np.mean(mae_losses)
print(f"所有样本平均 MAE: MAE_mean: {mae_mean}")


## Finetuning the model on the ETT dataset

In [136]:
import jax
from jax import numpy as jnp
from praxis import pax_fiddle
from praxis import py_utils
from praxis import pytypes
from praxis import base_model
from praxis import optimizers
from praxis import schedules
from praxis import base_hyperparams
from praxis import base_layer
from paxml import tasks_lib
from paxml import trainer_lib
from paxml import checkpoints
from paxml import learners
from paxml import partitioning
from paxml import checkpoint_types

In [137]:
# PAX shortcuts
NestedMap = py_utils.NestedMap
WeightInit = base_layer.WeightInit
WeightHParams = base_layer.WeightHParams
InstantiableParams = py_utils.InstantiableParams
JTensor = pytypes.JTensor
NpTensor = pytypes.NpTensor
WeightedScalars = pytypes.WeightedScalars
instantiate = base_hyperparams.instantiate
LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]
AuxLossStruct = base_layer.AuxLossStruct

AUX_LOSS = base_layer.AUX_LOSS
template_field = base_layer.template_field

# Standard prng key names
PARAMS = base_layer.PARAMS
RANDOM = base_layer.RANDOM

key = jax.random.PRNGKey(seed=1234)

In [138]:
# Fiddle config for the fine-tuning model: PatchedDecoderFinetuneModel wrapped
# around the pretrained TimesFM layer template (`tfm.model_p`) as its core layer.
model = pax_fiddle.Config(
    patched_decoder.PatchedDecoderFinetuneModel,
    name='patched_decoder_finetune',
    core_layer_tpl=tfm.model_p,
)

### We will hold the transformer layers fixed while finetuning, while training all other components.

In [139]:
# Builds the learner config: the 'avg_qloss' loss optimised with Adam under a
# cosine schedule (1e-3 -> 1e-4 over 40000 steps) and an EMA decay of 0.9999.
# Documentation is kept in `#` comments because the function body is processed
# by `auto_config`.
# NOTE(review): how `learning_rate` combines with the schedule values is not
# visible here — confirm against the praxis optimizer documentation.
@pax_fiddle.auto_config
def build_learner() -> learners.Learner:
  return pax_fiddle.Config(
      learners.Learner,
      name='learner',
      loss_name='avg_qloss',
      optimizer=optimizers.Adam(
          epsilon=1e-7,
          clip_threshold=1e2,
          learning_rate=1e-2,
          lr_schedule=pax_fiddle.Config(
              schedules.Cosine,
              initial_value=1e-3,
              final_value=1e-4,
              total_steps=40000,
          ),
          ema_decay=0.9999,
      ),
      # Linear probing i.e we hold the transformer layers fixed.
      # Variables whose path matches this regex are excluded from backprop.
      bprop_variable_exclusion=['.*/stacked_transformer_layer/.*'],
  )

In [140]:
# Single-task definition tying together the fine-tuning model config and the
# learner produced by build_learner().
task_p = tasks_lib.SingleTask(
    name='ts-learn',
    model=model,
    train=tasks_lib.SingleTask.Train(
        learner=build_learner(),
    ),
)

In [None]:
# Declare a 1x1x1 device mesh with axes (replica, data, mdl) on the model config.
task_p.model.ici_mesh_shape = [1, 1, 1]
task_p.model.mesh_axis_names = ['replica', 'data', 'mdl']

# The reshape to [1, 1, 1] only succeeds when jax.devices() returns exactly one
# device.
DEVICES = np.array(jax.devices()).reshape([1, 1, 1])
MESH = jax.sharding.Mesh(DEVICES, ['replica', 'data', 'mdl'])

# `num_devices` is what the training loop later passes to reshape_batch_for_pmap.
num_devices = jax.local_device_count()
print(f'num_devices: {num_devices}')
print(f'device kind: {jax.local_devices()[0].device_kind}')

In [None]:
# `batch` here is whatever the most recently run evaluation loop left behind
# (hidden notebook state); only its shape is printed.
print(batch[0].shape)

jax_task = task_p
# Split off a dedicated key for model initialisation and keep the rest in `key`.
key, init_key = jax.random.split(key)

# To correctly prepare a batch of data for model initialization (now that shape
# inference is merged), we take one devices*batch_size tensor tuple of data,
# slice out just one batch, then run the prepare_input_batch function over it.


def process_train_batch(batch):
    """Flattens a batched training tuple into the model's input NestedMap.

    Args:
      batch: tuple taken from `train_batches`. Element 0 holds the history
        windows and element 3 the future values, each with the window length
        on the last axis.

    Returns:
      NestedMap with `input_ts` and `actual_ts`, each reshaped to 2-D
      (rows, window_length).
    """
    # Collapse all leading axes into one and keep the window length fixed.
    # The row count is derived from the data rather than hard-coded as
    # batch_size * num_ts, because `.batch(batch_size)` can yield a short final
    # batch. With a fixed row count that batch either fails to reshape or is
    # silently re-chunked into rows of the wrong length.
    past_ts = batch[0].reshape(-1, batch[0].shape[-1])
    actual_ts = batch[3].reshape(-1, batch[3].shape[-1])
    return NestedMap(input_ts=past_ts, actual_ts=actual_ts)


def process_eval_batch(batch):
    """Wraps an evaluation tuple as a NestedMap without reshaping.

    Element 0 of `batch` becomes `input_ts` and element 3 becomes `actual_ts`.
    """
    return NestedMap(input_ts=batch[0], actual_ts=batch[3])


# Initialise the model/train state for the task, using the `tbatch` left over
# from the batch-inspection cell as the example input.
jax_model_states, _ = trainer_lib.initialize_model_state(
    jax_task,
    init_key,
    process_train_batch(tbatch),
    checkpoint_type=checkpoint_types.CheckpointType.GDA,
)

### Setting the initial model weights to the pretrained TimesFM parameters.

In [None]:
# Overwrite the freshly initialised core-layer parameters with the pretrained
# TimesFM parameters.
jax_model_states.mdl_vars['params']['core_layer'] = tfm._train_state.mdl_vars['params']
jax_vars = jax_model_states.mdl_vars
# Trigger a garbage-collection pass after swapping the parameters.
gc.collect()

### Training loop

In [144]:
# The task object that train_step / eval_step below read as a global.
jax_task = task_p


def train_step(states, prng_key, inputs):
  """Runs one single-learner training step of `jax_task` on `inputs`.

  Returns whatever `trainer_lib.train_step_single_learner` returns.
  """
  step_result = trainer_lib.train_step_single_learner(
      jax_task, states, prng_key, inputs
  )
  return step_result


def eval_step(states, prng_key, inputs):
  """Runs one single-learner evaluation step of `jax_task` on `inputs`.

  The incoming state is first converted with `to_eval_state()`.
  """
  eval_states = states.to_eval_state()
  step_result = trainer_lib.eval_step_single_learner(
      jax_task, eval_states, prng_key, inputs
  )
  return step_result

# Derive separate train/eval keys, then split each into one key per local device.
key, train_key, eval_key = jax.random.split(key, 3)
train_prng_seed = jax.random.split(train_key, num=jax.local_device_count())
eval_prng_seed = jax.random.split(eval_key, num=jax.local_device_count())

# pmap-wrapped step functions; the mapped axis is named 'batch'.
p_train_step = jax.pmap(train_step, axis_name='batch')
p_eval_step = jax.pmap(eval_step, axis_name='batch')

In [145]:
# Replicate the model state for use with the pmap-wrapped step functions.
replicated_jax_states = trainer_lib.replicate_model_state(jax_model_states)
replicated_jax_vars = replicated_jax_states.mdl_vars

In [146]:
# Mutable training-loop state.
best_eval_loss = 1e7  # lowest validation loss seen so far
step_count = 0        # number of training steps taken
patience = 0          # consecutive evaluations without improvement
# Loop limits.
NUM_EPOCHS = 20
PATIENCE = 5                 # stop once `patience` reaches this value
TRAIN_STEPS_PER_EVAL = 1000  # validate whenever step_count is a multiple of this
# Directory the training loop saves checkpoints to and the restore cell reads from.
#CHECKPOINT_DIR='/home/senrajat_google_com/ettm1_finetune'
CHECKPOINT_DIR='/home/ming/aaguolishaData/time_xulie/tfs_model'

In [147]:
def reshape_batch_for_pmap(batch, num_devices):
  """Adds a leading device axis to every array in `batch`.

  Each leaf of shape (B, ...) is reshaped to (num_devices, B // num_devices, ...).

  Args:
    batch: pytree of arrays that all carry the batch on axis 0.
    num_devices: size of the new leading axis.

  Returns:
    A pytree with the same structure as `batch`, holding the reshaped arrays.
  """

  def _split_leading_axis(leaf):
    per_device = leaf.shape[0] // num_devices
    new_shape = (num_devices, per_device) + tuple(leaf.shape[1:])
    return jnp.reshape(leaf, new_shape)

  return jax.tree.map(_split_leading_axis, batch)

In [150]:
import shutil  # 导入 shutil 模块

In [None]:
# Fine-tuning loop with periodic validation, best-checkpoint saving and
# patience-based early stopping. It mutates notebook-level state
# (replicated_jax_states, step_count, patience, best_eval_loss, tbatch, batch).

# Training losses accumulated since the last evaluation. This list has to live
# outside the step loop: when it is re-created on every step, the "mean" train
# loss that gets logged is only the loss of the most recent step.
train_losses = []
for epoch in range(NUM_EPOCHS):
    print(f"__________________Epoch: {epoch}__________________", flush=True)
    train_its = train_batches.as_numpy_iterator()
    if patience >= PATIENCE:
        print("Early stopping.", flush=True)
        break
    for batch in tqdm(train_its):
        if patience >= PATIENCE:
            print("Early stopping.", flush=True)
            break
        # Flatten the batch, then add the leading device axis that pmap expects.
        tbatch = process_train_batch(batch)
        tbatch = reshape_batch_for_pmap(tbatch, num_devices)
        replicated_jax_states, step_fun_out = p_train_step(
            replicated_jax_states, train_prng_seed, tbatch
        )
        # The loss is read from index 0 of the per-device results.
        train_losses.append(step_fun_out.loss[0])
        if step_count % TRAIN_STEPS_PER_EVAL == 0:
            print(
                f"Train loss at step {step_count}: {np.mean(train_losses)}",
                flush=True,
            )
            train_losses = []  # start a fresh window for the next report
            print("Starting eval.", flush=True)
            val_its = val_batches.as_numpy_iterator()
            eval_losses = []
            for ev_batch in tqdm(val_its):
                ebatch = process_eval_batch(ev_batch)
                ebatch = reshape_batch_for_pmap(ebatch, num_devices)
                _, step_fun_out = p_eval_step(
                    replicated_jax_states, eval_prng_seed, ebatch
                )
                eval_losses.append(step_fun_out.loss[0])
            mean_loss = np.mean(eval_losses)
            print(f"Eval loss at step {step_count}: {mean_loss}", flush=True)
            # NOTE(review): a NaN validation loss also takes this branch, so it
            # is stored as best_eval_loss and a checkpoint is saved. That
            # behaviour is kept as-is; confirm that it is intended.
            if mean_loss < best_eval_loss or np.isnan(mean_loss):
                best_eval_loss = mean_loss
                print("Saving checkpoint.")
                jax_state_for_saving = py_utils.maybe_unreplicate_for_fully_replicated(
                    replicated_jax_states
                )
                # Manually delete the old checkpoint directory before saving.
                checkpoint_path = os.path.join(CHECKPOINT_DIR, "checkpoint_1")
                if os.path.exists(checkpoint_path):
                    shutil.rmtree(checkpoint_path)
                # Save the new checkpoint.
                checkpoints.save_checkpoint(
                    jax_state_for_saving, CHECKPOINT_DIR, overwrite=True
                )
                patience = 0
                # Drop the unreplicated copy straight away to free memory.
                del jax_state_for_saving
                gc.collect()
            else:
                patience += 1
                print(f"patience: {patience}")
        step_count += 1

## Loading and evaluating the best (according to validation loss) finetuned checkpoint

In [None]:
# Restore the checkpoint stored under CHECKPOINT_DIR; `jax_model_states` is
# passed as the first argument (presumably the target state structure — confirm
# against paxml.checkpoints).
train_state = checkpoints.restore_checkpoint(jax_model_states, CHECKPOINT_DIR)
print(train_state.step)
# Copy the restored core-layer parameters back into the TimesFM wrapper and
# re-jit its decoding function.
tfm._train_state.mdl_vars['params'] = train_state.mdl_vars['params']['core_layer']
tfm.jit_decode()


In [None]:
# MAE of the fine-tuned model on the test split, scored on index 5 of the last
# axis of the second value returned by tfm.forecast.
mae_losses = []
for batch in tqdm(test_batches.as_numpy_iterator()):
    past, actuals = batch[0], batch[3]
    _, forecasts = tfm.forecast(list(past), [0] * past.shape[0])
    # Keep only as many forecast steps as there are actual values.
    forecasts = forecasts[:, : actuals.shape[1], 5]
    mae_losses.append(np.mean(np.abs(forecasts - actuals)))

print(f"MAE: {np.mean(mae_losses)}")

In [None]:
# Per-sample report for the fine-tuned model, scored the same way as the
# per-sample report for the pretrained model.
mae_losses = []
sample_count = 0  # running sample number; stays 0 if the test set is empty

for sample_count, batch in enumerate(tqdm(test_batches.as_numpy_iterator()), start=1):
    past, actuals = batch[0], batch[3]
    _, forecasts = tfm.forecast(list(past), [0] * past.shape[0])
    forecasts = forecasts[:, : actuals.shape[1], 5]

    # MAE of the current sample.
    sample_mae = np.mean(np.abs(forecasts - actuals))
    mae_losses.append(sample_mae)

    # Dump the sample; arrays are converted to lists for readable output.
    print(f"样本{sample_count}：")
    print("历史数据（past）：", past.tolist())
    print("实际值（actuals）：", actuals.tolist())
    print("预测值（forecasts）：", forecasts.tolist())
    print(f"样本{sample_count} MAE: {sample_mae}")
    print("-" * 50)  # separator line

# Mean MAE over all samples.
mae_mean = np.mean(mae_losses)
print(f"所有样本平均 MAE: MAE_mean: {mae_mean}")


## There is around a __9%__ reduction in MAE from finetuning.