In [1]:
!git clone https://github.com/google-research/google-research.git

Cloning into 'google-research'...
remote: Enumerating objects: 73342, done.[K
remote: Counting objects: 100% (458/458), done.[K
remote: Compressing objects: 100% (336/336), done.[K
remote: Total 73342 (delta 143), reused 403 (delta 104), pack-reused 72884[K
Receiving objects: 100% (73342/73342), 596.49 MiB | 29.80 MiB/s, done.
Resolving deltas: 100% (44830/44830), done.
Updating files: 100% (18352/18352), done.


In [2]:
!mv /content/google-research/tsmixer/tsmixer_basic/data_loader.py /content

In [3]:
!mv /content/google-research/tsmixer/tsmixer_basic/models /content

In [4]:
import argparse
import os
import time
import pandas as pd
import numpy as np
import tensorflow as tf
import glob
from sklearn.model_selection import train_test_split
from data_loader import TSFDataLoader  # You may need to adapt this import
import models  # You may need to adapt this import
from models import tsmixer

In [6]:
args_data = 'word_embedding_covid_pred'
args_model = 'tsmixer'
args_delete_checkpoint = False
args_train_epochs = 100
args_batch_size = 32
args_seq_len = 5
args_pred_len = 5
args_feature_type = 'M'
args_target = 'num_patients'
args_norm_type = 'B'
args_activation = 'relu'
args_dropout = 0.05
args_n_block = 2
args_ff_dim = 2048
args_kernel_size = 4
args_learning_rate = 0.0001
args_checkpoint_dir = './checkpoints/'
args_patience = 5
args_result_path = 'result.csv'

In [9]:
def main():
  exp_id = f'{args_data}_{args_feature_type}_{args_model}_sl{args_seq_len}_pl{args_pred_len}_lr{args_learning_rate}_nt{args_norm_type}_{args_activation}_nb{args_n_block}_dp{args_dropout}_fd{args_ff_dim}'

  # load datasets
  data_loader = TSFDataLoader(
      args_data,
      args_batch_size,
      args_seq_len,
      args_pred_len,
      args_feature_type,
      args_target,
  )
  train_data = data_loader.get_train()
  val_data = data_loader.get_val()
  test_data = data_loader.get_test()

  # train model
  if 'tsmixer' in args_model:
    build_model = tsmixer.build_model
    model = build_model(
        input_shape=(args_seq_len, data_loader.n_feature),
        pred_len=args_pred_len,
        norm_type=args_norm_type,
        activation=args_activation,
        dropout=args_dropout,
        n_block=args_n_block,
        ff_dim=args_ff_dim,
        target_slice=data_loader.target_slice,
    )
  elif args_model == 'full_linear':
    model = models.full_linear.Model(
        n_channel=data_loader.n_feature,
        pred_len=args_pred_len,
    )
  elif args_model == 'cnn':
    model = models.cnn.Model(
        n_channel=data_loader.n_feature,
        pred_len=args_pred_len,
        kernel_size=args_kernel_size,
    )
  else:
    raise ValueError(f'Model not supported: {args_model}')

  optimizer = tf.keras.optimizers.Adam(learning_rate=args_learning_rate)
  model.compile(optimizer=optimizer, loss='mse', metrics=['mae'])
  checkpoint_path = os.path.join(args_checkpoint_dir, f'{exp_id}_best')
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath=checkpoint_path,
      verbose=1,
      save_best_only=True,
      save_weights_only=True,
  )
  early_stop_callback = tf.keras.callbacks.EarlyStopping(
      monitor='val_loss', patience=args_patience
  )
  start_training_time = time.time()
  history = model.fit(
      train_data,
      epochs=args_train_epochs,
      validation_data=val_data,
      callbacks=[checkpoint_callback, early_stop_callback],
  )
  end_training_time = time.time()
  elasped_training_time = end_training_time - start_training_time
  print(f'Training finished in {elasped_training_time} secconds')

  # evaluate best model
  best_epoch = np.argmin(history.history['val_loss'])
  model.load_weights(checkpoint_path)
  test_result = model.evaluate(test_data)
  if args_delete_checkpoint:
    for f in glob.glob(checkpoint_path + '*'):
      os.remove(f)

  # save result to csv
  data = {
      'data': [args_data],
      'model': [args_model],
      'seq_len': [args_seq_len],
      'pred_len': [args_pred_len],
      'lr': [args_learning_rate],
      'mse': [test_result[0]],
      'mae': [test_result[1]],
      'val_mse': [history.history['val_loss'][best_epoch]],
      'val_mae': [history.history['val_mae'][best_epoch]],
      'train_mse': [history.history['loss'][best_epoch]],
      'train_mae': [history.history['mae'][best_epoch]],
      'training_time': elasped_training_time,
      'norm_type': args_norm_type,
      'activation': args_activation,
      'n_block': args_n_block,
      'dropout': args_dropout,
  }
  if 'TSMixer' in args_model:
    data['ff_dim'] = args_ff_dim

  df = pd.DataFrame(data)
  if os.path.exists(args_result_path):
    df.to_csv(args_result_path, mode='a', index=False, header=False)
  else:
    df.to_csv(args_result_path, mode='w', index=False, header=True)

In [11]:
if __name__ == '__main__':
    main()

Epoch 1/100
Epoch 1: val_loss improved from inf to 13.28625, saving model to ./checkpoints/word_embedding_covid_pred_M_tsmixer_sl5_pl5_lr0.0001_ntB_relu_nb2_dp0.05_fd2048_best
Epoch 2/100
Epoch 2: val_loss improved from 13.28625 to 11.05750, saving model to ./checkpoints/word_embedding_covid_pred_M_tsmixer_sl5_pl5_lr0.0001_ntB_relu_nb2_dp0.05_fd2048_best
Epoch 3/100
Epoch 3: val_loss improved from 11.05750 to 9.31535, saving model to ./checkpoints/word_embedding_covid_pred_M_tsmixer_sl5_pl5_lr0.0001_ntB_relu_nb2_dp0.05_fd2048_best
Epoch 4/100
Epoch 4: val_loss improved from 9.31535 to 7.96931, saving model to ./checkpoints/word_embedding_covid_pred_M_tsmixer_sl5_pl5_lr0.0001_ntB_relu_nb2_dp0.05_fd2048_best
Epoch 5/100
Epoch 5: val_loss improved from 7.96931 to 6.90897, saving model to ./checkpoints/word_embedding_covid_pred_M_tsmixer_sl5_pl5_lr0.0001_ntB_relu_nb2_dp0.05_fd2048_best
Epoch 6/100
Epoch 6: val_loss improved from 6.90897 to 6.07041, saving model to ./checkpoints/word_embedd

In [12]:
df = pd.read_csv('/content/result.csv')
df.head()

Unnamed: 0,data,model,seq_len,pred_len,lr,mse,mae,val_mse,val_mae,train_mse,train_mae,training_time,norm_type,activation,n_block,dropout
0,word_embedding_covid_pred,tsmixer,5,5,0.0001,2.724948,1.295467,1.519811,0.997578,0.569863,0.553838,65.881088,B,relu,2,0.05


In [None]:
/content/checkpoints/word_embedding_covid_pred_M_tsmixer_sl5_pl5_lr0.0001_ntB_relu_nb2_dp0.05_fd2048_best.data-00000-of-00001
/content/checkpoints/word_embedding_covid_pred_M_tsmixer_sl5_pl5_lr0.0001_ntB_relu_nb2_dp0.05_fd2048_best.index
/content/checkpoints/checkpoint