In [None]:
from imports import *
from functions_uc import *

In [None]:
timesfm_backend = "gpu"  # @param

# Load the pretrained TimesFM 2.0 (500M, JAX) checkpoint from the Hugging Face
# hub. `prediction_length` is not defined in this notebook; it presumably comes
# from the star imports in the first cell — TODO confirm.
# NOTE(review): the model context is fixed at 32 points here, while the data
# loader in the next cell uses `hist_len=context_len` from the star imports.
# If the two differ, fine-tuning and `tfm.forecast` may see different contexts.
tfm = timesfm.TimesFm(
    hparams=timesfm.TimesFmHparams(
        backend=timesfm_backend,
        context_len=32,
        horizon_len=prediction_length,
        num_layers=50,
        use_positional_embedding=False,
    ),
    checkpoint=timesfm.TimesFmCheckpoint(
        huggingface_repo_id="google/timesfm-2.0-500m-jax",
    ),
)

In [None]:
# Dataset registry: CSV path, the boundaries that split train/val/test, and the
# sampling frequency. `frequency` is not defined in this notebook; presumably it
# comes from the star imports — TODO confirm.
DATA_DICT = {
    "data_ume": {
        "boundaries": [8164, 10398, 13654],
        "data_path": "./data/retrain_timesffm.csv",
        "freq": frequency,
    },
}
dataset = "data_ume"
data_path = DATA_DICT[dataset]["data_path"]
freq = DATA_DICT[dataset]["freq"]
int_freq = timesfm.freq_map(freq)
boundaries = DATA_DICT[dataset]["boundaries"]

# Hand pandas the path so it opens *and closes* the file itself; the previous
# `pd.read_csv(open(data_path, "r"))` left the file handle open.
data_df = pd.read_csv(data_path)

# Every column other than `ds` (the datetime column) and `unique_id` is treated
# as a target series.
ts_cols = [col for col in data_df.columns if col not in ["ds", "unique_id"]]

num_cov_cols = None
cat_cov_cols = None

# `context_len` and `prediction_length` are expected to come from the star
# imports (the former `context_len = context_len` line was a no-op).
pred_len = prediction_length

num_ts = len(ts_cols)
# Number of loader samples stacked into one training step via `.batch` below.
batch_size = 6

# Windowed loader over the CSV; train/val/test are the consecutive spans
# delimited by `boundaries`. The loader's own `batch_size` is `num_ts`, i.e.
# presumably each sample carries a window for every series — TODO confirm.
dtl = data_loader.TimeSeriesdata(
    data_path=data_path,
    datetime_col="ds",
    num_cov_cols=num_cov_cols,
    cat_cov_cols=cat_cov_cols,
    ts_cols=np.array(ts_cols),
    train_range=[0, boundaries[0]],
    val_range=[boundaries[0], boundaries[1]],
    test_range=[boundaries[1], boundaries[2]],
    hist_len=context_len,
    pred_len=pred_len,
    batch_size=num_ts,
    freq=freq,
    normalize=True,
    epoch_len=None,
    holiday=False,
    permute=True,
)
# `shift` is presumably the stride between consecutive windows: 1 for training,
# one full horizon for val/test.
train_batches = dtl.tf_dataset(mode="train", shift=1).batch(batch_size)
val_batches = dtl.tf_dataset(mode="val", shift=pred_len)
test_batches = dtl.tf_dataset(mode="test", shift=pred_len)

# PAX shortcuts: short module-level aliases for praxis/paxml names.
NestedMap, InstantiableParams = py_utils.NestedMap, py_utils.InstantiableParams
JTensor, NpTensor, WeightedScalars = (
    pytypes.JTensor,
    pytypes.NpTensor,
    pytypes.WeightedScalars,
)
WeightInit, WeightHParams, AuxLossStruct = (
    base_layer.WeightInit,
    base_layer.WeightHParams,
    base_layer.AuxLossStruct,
)
instantiate = base_hyperparams.instantiate
LayerTpl = pax_fiddle.Config[base_layer.BaseLayer]

AUX_LOSS, template_field = base_layer.AUX_LOSS, base_layer.template_field

# Standard PRNG key names.
PARAMS, RANDOM = base_layer.PARAMS, base_layer.RANDOM

# Root PRNG key for the main fine-tuning run (fixed seed).
key = jax.random.PRNGKey(seed=1234)

# Fine-tuning wrapper whose core layer template is the loaded TimesFM model's,
# so its parameter tree can receive the pretrained weights.
model = pax_fiddle.Config(
    patched_decoder.PatchedDecoderFinetuneModel,
    core_layer_tpl=tfm.model_p,
    name='patched_decoder_finetune',
)

@pax_fiddle.auto_config
def build_learner_finetune(learning_rate=1e-3, clip_threshold=50, total_steps=40000) -> learners.Learner:
    """Build the learner config used to fine-tune the TimesFM decoder.

    Args:
      learning_rate: Peak learning rate for Adam. A cosine schedule decays it to
        ``0.1 * learning_rate`` over ``total_steps`` steps.
      clip_threshold: Gradient clipping threshold passed to Adam.
      total_steps: Length of the cosine decay, in optimizer steps.

    Returns:
      A ``pax_fiddle.Config`` for ``learners.Learner`` that optimizes
      ``avg_qloss`` and excludes the stacked transformer layers from
      back-propagation (linear probing).
    """
    return pax_fiddle.Config(
        learners.Learner,
        name='learner',
        loss_name='avg_qloss',
        optimizer=optimizers.Adam(
            epsilon=1e-7,
            clip_threshold=clip_threshold,
            learning_rate=learning_rate,
            # praxis multiplies `learning_rate` by the schedule's value, so the
            # schedule must be a unitless multiplier. Passing `learning_rate`
            # here as well (as before) squared the effective LR (1e-3 -> 1e-6).
            lr_schedule=pax_fiddle.Config(
                schedules.Cosine,
                initial_value=1.0,
                final_value=0.1,
                total_steps=total_steps,
            ),
            ema_decay=0.9999,
        ),
        # Linear probing: hold the transformer stack fixed.
        bprop_variable_exclusion=['.*/stacked_transformer_layer/.*'],
    )

# Single task: the fine-tuning model plus the default learner
# (learning rate 1e-3, clip threshold 50).
task_p = tasks_lib.SingleTask(
    name='ts-learn',
    model=model,
    train=tasks_lib.SingleTask.Train(
        learner=build_learner_finetune(learning_rate=1e-3, clip_threshold=50.0),
    ),
)

# 1x1x1 mesh. Note that `reshape([1, 1, 1])` below only succeeds when exactly
# one JAX device is visible.
task_p.model.ici_mesh_shape = [1, 1, 1]
task_p.model.mesh_axis_names = ['replica', 'data', 'mdl']

DEVICES = np.array(jax.devices()).reshape([1, 1, 1])
MESH = jax.sharding.Mesh(DEVICES, ['replica', 'data', 'mdl'])

num_devices = jax.local_device_count()
print(f'num_devices: {num_devices}')
print(f'device kind: {jax.local_devices()[0].device_kind}')

jax_task = task_p
key, init_key = jax.random.split(key)




def process_train_batch(batch):
    """Flatten a batched training element into a praxis NestedMap.

    ``batch[0]`` is used as the context window and ``batch[3]`` as the target
    window. Assumes the last axis of both is time and every leading axis
    (stacked samples, series) should be merged into rows — TODO confirm.

    Returns:
      ``NestedMap(input_ts=..., actual_ts=...)`` with 2-D arrays.
    """
    past_ts = batch[0]
    actual_ts = batch[3]
    # Derive the row count from the array rather than the global
    # `batch_size * num_ts`: `.batch()` without drop_remainder emits a smaller
    # final batch, which the fixed-size reshape could not handle.
    return NestedMap(
        input_ts=past_ts.reshape(-1, past_ts.shape[-1]),
        actual_ts=actual_ts.reshape(-1, actual_ts.shape[-1]),
    )


def process_eval_batch(batch):
    """Wrap an eval element as a praxis NestedMap without reshaping.

    The context window is ``batch[0]`` and the target window is ``batch[3]``.
    """
    return NestedMap(input_ts=batch[0], actual_ts=batch[3])


# Initialize model/optimizer state from one sample batch, then overwrite the
# freshly initialized core layer with the pretrained TimesFM weights.
first_train_batch = next(train_batches.as_numpy_iterator())
tbatch = process_train_batch(first_train_batch)

jax_model_states, _ = trainer_lib.initialize_model_state(
    jax_task, init_key, tbatch,
    checkpoint_type=checkpoint_types.CheckpointType.GDA,
)
jax_model_states.mdl_vars['params']['core_layer'] = tfm._train_state.mdl_vars['params']
jax_vars = jax_model_states.mdl_vars
gc.collect()

jax_task = task_p


def train_step(states, prng_key, inputs):
  """Run one training step of the global `jax_task` on `inputs`."""
  step_fn = trainer_lib.train_step_single_learner
  return step_fn(jax_task, states, prng_key, inputs)


def eval_step(states, prng_key, inputs):
  """Evaluate the global `jax_task` on `inputs` using the eval view of `states`."""
  eval_states = states.to_eval_state()
  return trainer_lib.eval_step_single_learner(jax_task, eval_states, prng_key, inputs)

# Training-loop settings.
NUM_EPOCHS = 10
PATIENCE = 5  # evals without improvement before early stopping
TRAIN_STEPS_PER_EVAL = 1000
CHECKPOINT_DIR='./timesfm_finetune'
best_eval_loss = 1e7
step_count = 0
patience = 0

# One PRNG key per local device for the pmapped train / eval steps.
key, train_key, eval_key = jax.random.split(key, 3)
train_prng_seed = jax.random.split(train_key, num=jax.local_device_count())
eval_prng_seed = jax.random.split(eval_key, num=jax.local_device_count())

# Data-parallel step functions; inputs carry a leading device axis.
p_train_step = jax.pmap(train_step, axis_name='batch')
p_eval_step = jax.pmap(eval_step, axis_name='batch')

# Replicate the model state for pmap.
replicated_jax_states = trainer_lib.replicate_model_state(jax_model_states)
replicated_jax_vars = replicated_jax_states.mdl_vars

def reshape_batch_for_pmap(batch, num_devices):
  """Split the leading axis of every leaf into ``[num_devices, per_device]``.

  Args:
    batch: Any pytree of arrays sharing a leading batch axis.
    num_devices: Size of the new leading (device) axis.

  Returns:
    The same pytree structure with each leaf reshaped to
    ``(num_devices, batch // num_devices, *rest)``.
  """

  def _split_leading_axis(leaf):
    per_device = leaf.shape[0] // num_devices
    new_shape = (num_devices, per_device, *leaf.shape[1:])
    return jnp.reshape(leaf, new_shape)

  return jax.tree_util.tree_map(_split_leading_axis, batch)

# Main fine-tuning loop: evaluate every TRAIN_STEPS_PER_EVAL steps (including
# step 0 as a baseline), checkpoint on improvement, and stop once PATIENCE
# consecutive evals fail to improve.
train_losses = []  # losses since the last report; must persist across steps
for epoch in range(NUM_EPOCHS):
    print(f"__________________Epoch: {epoch}__________________", flush=True)
    train_its = train_batches.as_numpy_iterator()
    if patience >= PATIENCE:
        print("Early stopping.", flush=True)
        break
    for batch in tqdm(train_its):
        if patience >= PATIENCE:
            print("Early stopping.", flush=True)
            break
        tbatch = process_train_batch(batch)
        tbatch = reshape_batch_for_pmap(tbatch, num_devices)
        replicated_jax_states, step_fun_out = p_train_step(
            replicated_jax_states, train_prng_seed, tbatch
        )
        # Loss from the first device of the pmapped output.
        train_losses.append(step_fun_out.loss[0])
        if step_count % TRAIN_STEPS_PER_EVAL == 0:
            # Mean over every step since the previous report. The list used to
            # be reset each step, so only the latest step's loss was reported.
            print(
                f"Train loss at step {step_count}: {np.mean(train_losses)}",
                flush=True,
            )
            train_losses = []
            print("Starting eval.", flush=True)
            val_its = val_batches.as_numpy_iterator()
            eval_losses = []
            for ev_batch in tqdm(val_its):
                ebatch = process_eval_batch(ev_batch)
                ebatch = reshape_batch_for_pmap(ebatch, num_devices)
                _, step_fun_out = p_eval_step(
                    replicated_jax_states, eval_prng_seed, ebatch
                )
                eval_losses.append(step_fun_out.loss[0])
            mean_loss = np.mean(eval_losses)
            print(f"Eval loss at step {step_count}: {mean_loss}", flush=True)
            # NaN is never an improvement. Previously a NaN loss was saved as
            # "best", after which `mean_loss < best_eval_loss` was always False
            # and genuine improvements were never checkpointed again.
            if not np.isnan(mean_loss) and mean_loss < best_eval_loss:
                best_eval_loss = mean_loss
                print("Saving checkpoint.")
                jax_state_for_saving = py_utils.maybe_unreplicate_for_fully_replicated(
                    replicated_jax_states
                )
                checkpoints.save_checkpoint(
                    jax_state_for_saving, CHECKPOINT_DIR, overwrite=True
                )
                patience = 0
                del jax_state_for_saving
                gc.collect()
            else:
                if np.isnan(mean_loss):
                    print("Eval loss is NaN; not saving a checkpoint.", flush=True)
                patience += 1
                print(f"patience: {patience}")
        step_count += 1

# Restore the best saved checkpoint and push its core-layer weights into the
# TimesFM wrapper used for inference.
train_state = checkpoints.restore_checkpoint(jax_model_states, CHECKPOINT_DIR)
print(train_state.step)
tfm._train_state.mdl_vars['params'] = train_state.mdl_vars['params']['core_layer']
tfm.jit_decode()  # presumably re-compiles decoding with the new params

# Test-set MAE. The loader was built with normalize=True, so this is presumably
# on the normalized scale — TODO confirm.
mae_losses = []
for batch in tqdm(test_batches.as_numpy_iterator()):
    past, actuals = batch[0], batch[3]
    # NOTE(review): frequency id 0 is passed for every series although
    # `int_freq` is computed from the dataset config; confirm this matches the
    # frequency the fine-tuned model was trained with.
    _, forecasts = tfm.forecast(list(past), [0] * len(past))
    # Channel 5 of the second forecast output, trimmed to the target horizon;
    # presumably the median (0.5 quantile) — confirm against the timesfm docs.
    forecasts = forecasts[:, :actuals.shape[1], 5]
    mae_losses.append(np.mean(np.abs(forecasts - actuals)))

print(f"MAE: {np.mean(mae_losses)}")

def objective(trial):
    """Optuna objective: a short fine-tuning run for one hyperparameter sample.

    Samples a learning rate (log-uniform in [1e-5, 1e-2]) and a gradient clip
    threshold (uniform in [1, 100]). Trains for at most ``NUM_STEPS`` steps,
    then returns the mean loss reported by the eval step over the validation set.

    Args:
      trial: The ``optuna.trial.Trial`` being evaluated.

    Returns:
      Mean validation loss as a Python float (lower is better).

    Raises:
      ValueError: If the validation dataset yields no batches.
    """
    # suggest_loguniform / suggest_uniform are deprecated; suggest_float with
    # (log=True) samples the same distributions.
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    clip_threshold = trial.suggest_float("clip_threshold", 1.0, 100.0)

    print(f"\n🔍 Trial {trial.number}: learning_rate={learning_rate:.5e}, clip_threshold={clip_threshold:.2f}")

    # Build a task dedicated to this trial. The old code mutated the global
    # `task_p`, but the module-level `p_train_step`/`p_eval_step` had already
    # been traced and jax.pmap caches those traces. Every trial therefore reused
    # the original learner and ignored the sampled hyperparameters.
    trial_task = tasks_lib.SingleTask(
        name='ts-learn',
        model=model,
        train=tasks_lib.SingleTask.Train(
            learner=build_learner_finetune(
                learning_rate=learning_rate,
                clip_threshold=clip_threshold,
            ),
        ),
    )
    trial_task.model.ici_mesh_shape = [1, 1, 1]
    trial_task.model.mesh_axis_names = ['replica', 'data', 'mdl']

    def _train_step(states, prng_key, inputs):
        return trainer_lib.train_step_single_learner(
            trial_task, states, prng_key, inputs
        )

    def _eval_step(states, prng_key, inputs):
        return trainer_lib.eval_step_single_learner(
            trial_task, states.to_eval_state(), prng_key, inputs
        )

    # Fresh function objects, hence fresh pmap traces bound to `trial_task`.
    p_trial_train_step = jax.pmap(_train_step, axis_name='batch')
    p_trial_eval_step = jax.pmap(_eval_step, axis_name='batch')

    # Per-trial seed with independent sub-keys for init, train and eval.
    key = jax.random.PRNGKey(seed=trial.number)
    init_key, train_key, eval_key = jax.random.split(key, 3)
    train_prng_seed = jax.random.split(train_key, num=jax.local_device_count())
    eval_prng_seed = jax.random.split(eval_key, num=jax.local_device_count())

    # Re-initialize model state for this trial.
    tbatch = process_train_batch(next(train_batches.as_numpy_iterator()))
    jax_model_states, _ = trainer_lib.initialize_model_state(
        trial_task,
        init_key,
        tbatch,
        checkpoint_type=checkpoint_types.CheckpointType.GDA,
    )
    # NOTE(review): when this runs after the main training cell, tfm's params
    # were already replaced by the restored fine-tuned core layer. Trials then
    # start from fine-tuned weights, not the pretrained checkpoint — confirm
    # this is intended.
    jax_model_states.mdl_vars['params']['core_layer'] = tfm._train_state.mdl_vars['params']
    replicated_jax_states = trainer_lib.replicate_model_state(jax_model_states)

    # Train for at most NUM_STEPS steps (or one pass over the data if shorter).
    NUM_STEPS = 1000
    for step_count, batch in enumerate(train_batches.as_numpy_iterator()):
        if step_count >= NUM_STEPS:
            break
        tbatch = reshape_batch_for_pmap(process_train_batch(batch), num_devices)
        replicated_jax_states, _ = p_trial_train_step(
            replicated_jax_states, train_prng_seed, tbatch
        )

    # Evaluate once on the full validation set.
    eval_losses = []
    for val_batch in val_batches.as_numpy_iterator():
        ebatch = reshape_batch_for_pmap(process_eval_batch(val_batch), num_devices)
        _, step_fun_out = p_trial_eval_step(replicated_jax_states, eval_prng_seed, ebatch)
        eval_losses.append(step_fun_out.loss[0])
    if not eval_losses:
        raise ValueError("Validation dataset produced no batches; cannot score trial.")

    mean_eval_loss = float(np.mean(eval_losses))
    # The eval loss is the learner's training loss, not MAE.
    print(f"Trial {trial.number} Eval Loss: {mean_eval_loss:.6f}")

    # Drop per-trial device state before collecting, so memory is actually freed.
    del replicated_jax_states, jax_model_states
    gc.collect()
    return mean_eval_loss


def run_optuna_timesfm(n_trials=10, seed=None):
    """Run an Optuna study minimizing `objective` and print the best trial.

    Args:
      n_trials: Number of trials to run.
      seed: Optional seed for the TPE sampler. ``None`` (the default) keeps
        Optuna's unseeded behaviour; pass an int for reproducible sampling.

    Returns:
      The completed ``optuna.Study``.
    """
    # TPESampler is create_study's default for single-objective studies;
    # building it explicitly only adds the option of seeding it.
    sampler = optuna.samplers.TPESampler(seed=seed)
    study = optuna.create_study(direction="minimize", sampler=sampler)
    study.optimize(objective, n_trials=n_trials)

    print("Best Trial Summary:")
    print(study.best_trial)

    print("Best Hyperparameters:")
    for k, v in study.best_trial.params.items():
        print(f"  {k}: {v}")

    return study

train_timesffm = run_optuna_timesfm(n_trials=10)
