In [None]:
# This mounts your Google Drive to the Colab VM so the project code and data
# stored there are reachable from the notebook.
import os
import sys

from google.colab import drive

drive.mount('/content/drive', force_remount=True)

# Project root on Drive. Colab exposes the Drive root as "My Drive" (with a
# space); inside a Python string the space needs no quoting or escaping --
# the previous literal contained a stray '"' that made the path invalid.
foldername = '/content/drive/My Drive/trans/transformer-glucose/'
# A string literal is never None, so the old `is not None` check could not
# fail; verify the folder really exists so a wrong path fails here instead of
# at the first project import.
assert os.path.isdir(foldername), "[!] Enter the foldername."

# add path to .py code (makes gluformer / gludata / utils importable)
sys.path.append(foldername)

In [None]:
# load libraries
import os
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

from gluformer.attention import *
from gluformer.encoder import *
from gluformer.decoder import *
from gluformer.embed import *
from gluformer.model import *

from gludata.data_loader import *
from utils.train import *
from utils.test import *

In [None]:
# load data
# NOTE(review): PATH is resolved against the kernel's working directory, not
# `foldername`; this only works when the notebook runs from the project root
# (or the data was copied into the cwd) -- confirm.
PATH = os.getcwd() + '/gludata/data/'
assert os.path.isdir(PATH), f"[!] data folder not found: {PATH}"
BATCH_SIZE = 32
LEN_PRED = 12
LEN_LABEL = 60
LEN_SEQ = 180


def make_loader(split, shuffle):
  """Build the CGMData dataset for `split` and wrap it in a DataLoader.

  All splits share the batch size, window lengths ([LEN_SEQ, LEN_LABEL,
  LEN_PRED]), single-process loading, incomplete-last-batch dropping and the
  project collate function; only `shuffle` differs between them.

  Returns:
    (dataset, loader) tuple.
  """
  dataset = CGMData(PATH, split, [LEN_SEQ, LEN_LABEL, LEN_PRED])
  loader = DataLoader(dataset,
                      batch_size=BATCH_SIZE,
                      shuffle=shuffle,
                      num_workers=0,
                      drop_last=True,
                      collate_fn=collate_fn_custom)
  return dataset, loader


train_data, train_data_loader = make_loader('train', shuffle=True)
# Validation is NOT shuffled: the validation metric is a median over per-batch
# errors and drives early stopping, so it must be computed over the same
# batches every epoch rather than a fresh random grouping.
val_data, val_data_loader = make_loader('val', shuffle=False)
test_data, test_data_loader = make_loader('test', shuffle=False)

In [None]:
# Horizon the model itself is built to output. predict_batch receives both
# this and LEN_PRED, so the two are allowed to differ.
# LEN_LABEL / LEN_SEQ are defined once in the data-loading cell; they are not
# redefined here so the two cells cannot silently drift apart.
LEN_PRED_MODEL = 12

# Seed torch and numpy so weight initialisation and shuffling are repeatable
# across runs (GPU kernels may still introduce nondeterminism).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

# define model
D_MODEL = 512
D_FCN = 2048
# NOTE(review): D_MODEL (512) is not divisible by N_HEADS (12); the per-head
# width then depends on how gluformer.attention splits d_model -- confirm
# this head count is intended.
N_HEADS = 12
R_DROP = 0.3
ACTIV = "relu"
NUM_ENC_LAYERS = 2
NUM_DEC_LAYERS = 1
DISTIL = True
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

model = Gluformer(d_model=D_MODEL,
                  n_heads=N_HEADS,
                  d_fcn=D_FCN,
                  r_drop=R_DROP,
                  activ=ACTIV,
                  num_enc_layers=NUM_ENC_LAYERS,
                  num_dec_layers=NUM_DEC_LAYERS,
                  distil=DISTIL,
                  len_pred=LEN_PRED_MODEL)
model.train()
model = model.to(DEVICE)

In [None]:
# Loss and optimizer for training.
lr = 2e-4
# NOTE(review): num_samples is hard-coded to 4 here, while the training cell
# reshapes predictions with NUM_SAMPLES, which this notebook never defines
# (presumably it arrives through one of the star imports) -- confirm the two
# values agree.
criterion = ExpLikeliLoss(num_samples=4)
# beta1 = 0 turns off Adam's first-moment (momentum) averaging; only the
# second-moment estimate (beta2 = 0.9) is kept.
model_optim = torch.optim.Adam(params=model.parameters(), lr=lr, betas=(0, 0.9))

In [None]:
# Constants of the inverse glucose scaling used below:
#   x -> (x + SCALE_1) / (SCALE_1 * SCALE_2) * (UPPER - LOWER) + LOWER
# NOTE(review): presumably UPPER/LOWER are the glucose range used when gludata
# scaled the series -- confirm they match the forward transform.
UPPER = 402
LOWER = 38
SCALE_1 = 5
SCALE_2 = 2

# define params for training
PATH_MODEL = os.path.join(foldername, "model_best.pth")
EPOCHS = 100
TRAIN_STEPS = len(train_data_loader)
# First argument appears to be the patience (the log prints "... / 20").
early_stop = EarlyStop(20, 0)


def _unscale(x):
  """Map model-space values back to the original glucose scale."""
  return (x + SCALE_1) / (SCALE_1 * SCALE_2) * (UPPER - LOWER) + LOWER


def _split_samples(x):
  """Rearrange a rank-3 array so its last axis has NUM_SAMPLES entries.

  Same transpose/reshape/transpose sequence as before.
  NOTE(review): assumes predict_batch lays out the NUM_SAMPLES draws of each
  series so that this reshape separates them; NUM_SAMPLES is not defined in
  this notebook and presumably comes from a star import -- confirm.
  """
  return x.transpose((1, 0, 2)).reshape((x.shape[1], -1, NUM_SAMPLES)).transpose((1, 0, 2))


def _predict_loader(loader):
  """Run predict_batch over every batch of `loader`.

  Returns:
    list of (true, pred) numpy arrays, one pair per batch, both on the
    original glucose scale. `pred` is averaged over the sample axis; `true`
    keeps sample index 0 (presumably all sample copies of the target are
    identical).
  """
  outputs = []
  for subj_id, batch_x, batch_y, batch_x_mark, batch_y_mark in loader:
    # no_grad only skips autograd bookkeeping (memory/time); it does not
    # change the model's train/eval mode, so stochastic sampling is unaffected.
    with torch.no_grad():
      pred, true = predict_batch(subj_id=subj_id,
                                 batch_x=batch_x,
                                 batch_y=batch_y,
                                 batch_x_mark=batch_x_mark,
                                 batch_y_mark=batch_y_mark,
                                 len_pred=LEN_PRED,
                                 len_pred_model=LEN_PRED_MODEL,
                                 len_label=LEN_LABEL,
                                 model=model,
                                 device=DEVICE)
    pred = _unscale(pred.detach().cpu().numpy())
    true = _unscale(true.detach().cpu().numpy())
    outputs.append((_split_samples(true)[:, :, 0],
                    np.mean(_split_samples(pred), axis=2)))
  return outputs


def _ape(true, pred, horizon=None):
  """Mean absolute percentage error over the first `horizon` steps (all if None)."""
  true, pred = true[:, :horizon], pred[:, :horizon]
  return np.mean(np.abs(true - pred) / true)


for epoch in range(EPOCHS):
  # Re-assert training mode every epoch in case predict_batch (defined
  # outside this notebook) switches the model's mode during evaluation.
  model.train()
  iter_count = 0
  train_losses = []

  epoch_time = time.time()
  curr_time = time.time()

  for i, (subj_id, batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(train_data_loader):
    iter_count += 1
    # zero-out grad
    model_optim.zero_grad()

    pred, true = process_batch(subj_id=subj_id,
                               batch_x=batch_x,
                               batch_y=batch_y,
                               batch_x_mark=batch_x_mark,
                               batch_y_mark=batch_y_mark,
                               len_pred=LEN_PRED,
                               len_label=LEN_LABEL,
                               model=model,
                               device=DEVICE)
    loss = criterion(pred, true)
    train_losses.append(loss.item())

    # print every 100
    if (i + 1) % 100 == 0:
      print("\titers: {0}, epoch: {1} | loss: {2:.7f}".format(i + 1, epoch + 1, loss.item()))
      speed = (time.time() - curr_time) / iter_count
      # i + 1 iterations of this epoch are done; the rest of this epoch and
      # all later epochs remain.
      left_time = speed * ((EPOCHS - epoch) * TRAIN_STEPS - (i + 1))
      print('\tspeed: {:.4f}s/iter; left time: {:.4f}s'.format(speed, left_time))
      iter_count = 0
      curr_time = time.time()

    loss.backward()
    model_optim.step()

  # compute average train loss
  train_loss = np.average(train_losses)

  # validation loss: median over batches of the per-batch mean APE
  val_batches = _predict_loader(val_data_loader)
  val_loss = np.median([_ape(true, pred) for true, pred in val_batches])

  # test loss at 15 / 30 / 45 / 60 minutes (3 / 6 / 9 / all 12 points),
  # each the median over batches of the per-batch mean APE
  test_batches = _predict_loader(test_data_loader)
  test_loss_3, test_loss_6, test_loss_9, test_loss_12 = [
      np.median([_ape(true, pred, horizon) for true, pred in test_batches])
      for horizon in (3, 6, 9, None)]

  # update lr
  # adjust_learning_rate(model_optim, epoch, lr)

  # Report before the early-stopping check so the final epoch's metrics are
  # printed too.
  print("Epoch: {} cost time: {}".format(epoch + 1, time.time() - epoch_time))
  print("Epoch: {0}, Steps: {1} | Train Loss: {2:.7f} Val Loss: {3:.7f}".format(
      epoch + 1, TRAIN_STEPS, train_loss, val_loss))
  print("Test loss for 15 mins: {0:.7f}, for 30 mins: {1:.7f}, for 45 mins: {2:.7f}, for 60 mins: {3:.7f}".format(
      test_loss_3, test_loss_6, test_loss_9, test_loss_12))

  # check early stopping
  early_stop(val_loss, model, PATH_MODEL)
  if early_stop.stop:
    print("Early stopping...")
    break

  return torch.max_pool1d(input, kernel_size, stride, padding, dilation, ceil_mode)


	iters: 100, epoch: 1 | loss: 11.6646633
	speed: 0.2554s/iter; left time: 216290.3350s
	iters: 200, epoch: 1 | loss: 9.1925278
	speed: 0.2497s/iter; left time: 211438.9984s
	iters: 300, epoch: 1 | loss: 4.9972897
	speed: 0.2499s/iter; left time: 211634.9073s
	iters: 400, epoch: 1 | loss: 6.4039292
	speed: 0.2500s/iter; left time: 211695.2129s
	iters: 500, epoch: 1 | loss: 19.2153320
	speed: 0.2510s/iter; left time: 212535.8972s
	iters: 600, epoch: 1 | loss: 5.7895002
	speed: 0.2510s/iter; left time: 212437.2104s
	iters: 700, epoch: 1 | loss: 4.5519948
	speed: 0.2516s/iter; left time: 212916.9317s
	iters: 800, epoch: 1 | loss: 5.1600418
	speed: 0.2512s/iter; left time: 212590.4390s
	iters: 900, epoch: 1 | loss: 5.0389805
	speed: 0.2516s/iter; left time: 212863.8264s
	iters: 1000, epoch: 1 | loss: 5.0200205
	speed: 0.2514s/iter; left time: 212716.7962s
	iters: 1100, epoch: 1 | loss: 4.4239936
	speed: 0.2517s/iter; left time: 212941.0563s
	iters: 1200, epoch: 1 | loss: 7.7816019
	speed: 0

	iters: 900, epoch: 2 | loss: 6.4135180
	speed: 0.2526s/iter; left time: 211601.3404s
	iters: 1000, epoch: 2 | loss: 7.7852879
	speed: 0.2524s/iter; left time: 211446.2603s
	iters: 1100, epoch: 2 | loss: 6.6175232
	speed: 0.2522s/iter; left time: 211261.1609s
	iters: 1200, epoch: 2 | loss: 10.2038193
	speed: 0.2523s/iter; left time: 211279.2725s
	iters: 1300, epoch: 2 | loss: 4.4126577
	speed: 0.2523s/iter; left time: 211243.2787s
	iters: 1400, epoch: 2 | loss: 8.6446419
	speed: 0.2527s/iter; left time: 211569.0668s
	iters: 1500, epoch: 2 | loss: 8.2923050
	speed: 0.2522s/iter; left time: 211124.4397s
	iters: 1600, epoch: 2 | loss: 4.0145264
	speed: 0.2524s/iter; left time: 211234.8770s
	iters: 1700, epoch: 2 | loss: 8.3509369
	speed: 0.2524s/iter; left time: 211234.7009s
	iters: 1800, epoch: 2 | loss: 7.7461362
	speed: 0.2526s/iter; left time: 211392.3858s
	iters: 1900, epoch: 2 | loss: 13.3950977
	speed: 0.2522s/iter; left time: 211047.5605s
	iters: 2000, epoch: 2 | loss: 7.4901338
	

	iters: 1700, epoch: 3 | loss: 11.2926702
	speed: 0.2524s/iter; left time: 209067.3578s
	iters: 1800, epoch: 3 | loss: 4.0195723
	speed: 0.2525s/iter; left time: 209160.1715s
	iters: 1900, epoch: 3 | loss: 6.8791289
	speed: 0.2523s/iter; left time: 208939.7972s
	iters: 2000, epoch: 3 | loss: 6.3821392
	speed: 0.2525s/iter; left time: 209140.5032s
	iters: 2100, epoch: 3 | loss: 2.7927985
	speed: 0.2524s/iter; left time: 208989.3854s
	iters: 2200, epoch: 3 | loss: 4.8479366
	speed: 0.2521s/iter; left time: 208725.3650s
	iters: 2300, epoch: 3 | loss: 9.4943161
	speed: 0.2523s/iter; left time: 208828.5294s
	iters: 2400, epoch: 3 | loss: 5.6303530
	speed: 0.2523s/iter; left time: 208840.3849s
	iters: 2500, epoch: 3 | loss: 6.7934537
	speed: 0.2526s/iter; left time: 209057.0021s
	iters: 2600, epoch: 3 | loss: 12.9634495
	speed: 0.2521s/iter; left time: 208590.4341s
	iters: 2700, epoch: 3 | loss: 8.3113441
	speed: 0.2521s/iter; left time: 208608.9797s
	iters: 2800, epoch: 3 | loss: 5.1617451


	iters: 2500, epoch: 4 | loss: 6.3385692
	speed: 0.2524s/iter; left time: 206735.2841s
	iters: 2600, epoch: 4 | loss: 5.5605979
	speed: 0.2526s/iter; left time: 206932.2202s
	iters: 2700, epoch: 4 | loss: 4.2199860
	speed: 0.2527s/iter; left time: 206961.7841s
	iters: 2800, epoch: 4 | loss: 2.4934599
	speed: 0.2525s/iter; left time: 206782.4864s
	iters: 2900, epoch: 4 | loss: 7.6837444
	speed: 0.2525s/iter; left time: 206740.8095s
	iters: 3000, epoch: 4 | loss: 5.7435746
	speed: 0.2526s/iter; left time: 206815.9912s
	iters: 3100, epoch: 4 | loss: 6.3491983
	speed: 0.2526s/iter; left time: 206802.9175s
	iters: 3200, epoch: 4 | loss: 6.2992010
	speed: 0.2527s/iter; left time: 206864.1179s
	iters: 3300, epoch: 4 | loss: 8.4590597
	speed: 0.2524s/iter; left time: 206558.2416s
	iters: 3400, epoch: 4 | loss: 4.3203564
	speed: 0.2526s/iter; left time: 206682.6886s
	iters: 3500, epoch: 4 | loss: 5.8724308
	speed: 0.2525s/iter; left time: 206585.0631s
	iters: 3600, epoch: 4 | loss: 3.6441824
	s

	iters: 3300, epoch: 5 | loss: 7.0855598
	speed: 0.2526s/iter; left time: 204551.5179s
	iters: 3400, epoch: 5 | loss: 5.9640393
	speed: 0.2530s/iter; left time: 204868.6118s
	iters: 3500, epoch: 5 | loss: 4.2628021
	speed: 0.2527s/iter; left time: 204617.5861s
	iters: 3600, epoch: 5 | loss: 10.4803181
	speed: 0.2528s/iter; left time: 204675.3371s
	iters: 3700, epoch: 5 | loss: 4.4719663
	speed: 0.2528s/iter; left time: 204622.6230s
	iters: 3800, epoch: 5 | loss: 8.4763079
	speed: 0.2526s/iter; left time: 204459.0853s
	iters: 3900, epoch: 5 | loss: 7.4861751
	speed: 0.2528s/iter; left time: 204609.8149s
	iters: 4000, epoch: 5 | loss: 5.3428488
	speed: 0.2525s/iter; left time: 204361.2621s
	iters: 4100, epoch: 5 | loss: 4.4351163
	speed: 0.2526s/iter; left time: 204387.9254s
	iters: 4200, epoch: 5 | loss: 6.8271160
	speed: 0.2525s/iter; left time: 204278.5805s
	iters: 4300, epoch: 5 | loss: 3.6890774
	speed: 0.2528s/iter; left time: 204480.8135s
	iters: 4400, epoch: 5 | loss: 3.5416951
	

	iters: 4100, epoch: 6 | loss: 5.8266029
	speed: 0.2527s/iter; left time: 202350.1918s
	iters: 4200, epoch: 6 | loss: 7.3996730
	speed: 0.2525s/iter; left time: 202101.9098s
	iters: 4300, epoch: 6 | loss: 6.7615118
	speed: 0.2525s/iter; left time: 202074.7998s
	iters: 4400, epoch: 6 | loss: 8.6384487
	speed: 0.2524s/iter; left time: 202018.7394s
	iters: 4500, epoch: 6 | loss: 8.0238981
	speed: 0.2526s/iter; left time: 202166.3894s
	iters: 4600, epoch: 6 | loss: 5.0439858
	speed: 0.2526s/iter; left time: 202094.5461s
	iters: 4700, epoch: 6 | loss: 7.0267525
	speed: 0.2524s/iter; left time: 201939.9728s
	iters: 4800, epoch: 6 | loss: 4.4853582
	speed: 0.2525s/iter; left time: 201956.8222s
	iters: 4900, epoch: 6 | loss: 4.3770342
	speed: 0.2526s/iter; left time: 202013.3930s
	iters: 5000, epoch: 6 | loss: 4.6079187
	speed: 0.2529s/iter; left time: 202221.1691s
	iters: 5100, epoch: 6 | loss: 7.2847986
	speed: 0.2524s/iter; left time: 201810.4797s
	iters: 5200, epoch: 6 | loss: 5.9221611
	s

	iters: 4900, epoch: 7 | loss: 1.9641713
	speed: 0.2524s/iter; left time: 199704.5748s
	iters: 5000, epoch: 7 | loss: 6.2425694
	speed: 0.2525s/iter; left time: 199815.9614s
	iters: 5100, epoch: 7 | loss: 6.9459915
	speed: 0.2524s/iter; left time: 199692.4116s
	iters: 5200, epoch: 7 | loss: 5.8520794
	speed: 0.2526s/iter; left time: 199823.7381s
	iters: 5300, epoch: 7 | loss: 4.2962751
	speed: 0.2522s/iter; left time: 199497.7372s
	iters: 5400, epoch: 7 | loss: 5.5787745
	speed: 0.2523s/iter; left time: 199545.2447s
	iters: 5500, epoch: 7 | loss: 3.7369604
	speed: 0.2523s/iter; left time: 199511.1130s
	iters: 5600, epoch: 7 | loss: 5.2527161
	speed: 0.2523s/iter; left time: 199456.5149s
	iters: 5700, epoch: 7 | loss: 3.6378188
	speed: 0.2527s/iter; left time: 199747.8861s
	iters: 5800, epoch: 7 | loss: 6.5847688
	speed: 0.2524s/iter; left time: 199510.5792s
	iters: 5900, epoch: 7 | loss: 5.3569021
	speed: 0.2523s/iter; left time: 199431.3932s
	iters: 6000, epoch: 7 | loss: 3.5454493
	s

	iters: 5700, epoch: 8 | loss: 5.4934359
	speed: 0.2530s/iter; left time: 197907.8916s
	iters: 5800, epoch: 8 | loss: 10.1249275
	speed: 0.2529s/iter; left time: 197771.4737s
	iters: 5900, epoch: 8 | loss: 2.3323219
	speed: 0.2531s/iter; left time: 197861.4173s
	iters: 6000, epoch: 8 | loss: 3.2004354
	speed: 0.2529s/iter; left time: 197732.2781s
	iters: 6100, epoch: 8 | loss: 7.2648697
	speed: 0.2529s/iter; left time: 197661.3903s
	iters: 6200, epoch: 8 | loss: 4.6370964
	speed: 0.2528s/iter; left time: 197565.3625s
	iters: 6300, epoch: 8 | loss: 12.2692251
	speed: 0.2530s/iter; left time: 197683.1479s
	iters: 6400, epoch: 8 | loss: 2.9590316
	speed: 0.2533s/iter; left time: 197891.3645s
	iters: 6500, epoch: 8 | loss: 5.2980971
	speed: 0.2529s/iter; left time: 197592.5711s
	iters: 6600, epoch: 8 | loss: 7.1536145
	speed: 0.2529s/iter; left time: 197557.7262s
	iters: 6700, epoch: 8 | loss: 6.5393691
	speed: 0.2528s/iter; left time: 197468.3110s
	iters: 6800, epoch: 8 | loss: 4.6485233


	iters: 6500, epoch: 9 | loss: 5.3326364
	speed: 0.2529s/iter; left time: 195469.9917s
	iters: 6600, epoch: 9 | loss: 6.0827284
	speed: 0.2534s/iter; left time: 195802.1976s
	iters: 6700, epoch: 9 | loss: 4.3326292
	speed: 0.2529s/iter; left time: 195427.6738s
	iters: 6800, epoch: 9 | loss: 2.8810706
	speed: 0.2530s/iter; left time: 195469.0554s
	iters: 6900, epoch: 9 | loss: 12.6349611
	speed: 0.2530s/iter; left time: 195401.3147s
	iters: 7000, epoch: 9 | loss: 20.5208111
	speed: 0.2531s/iter; left time: 195508.7154s
	iters: 7100, epoch: 9 | loss: 5.3361945
	speed: 0.2529s/iter; left time: 195297.7309s
	iters: 7200, epoch: 9 | loss: 6.4034581
	speed: 0.2528s/iter; left time: 195175.7417s
	iters: 7300, epoch: 9 | loss: 6.9882255
	speed: 0.2530s/iter; left time: 195339.5194s
	iters: 7400, epoch: 9 | loss: 4.9361033
	speed: 0.2528s/iter; left time: 195170.7429s
	iters: 7500, epoch: 9 | loss: 6.8008537
	speed: 0.2533s/iter; left time: 195538.5212s
	iters: 7600, epoch: 9 | loss: 3.6993239


	iters: 7200, epoch: 10 | loss: 4.8369322
	speed: 0.2532s/iter; left time: 193355.6347s
	iters: 7300, epoch: 10 | loss: 8.1106672
	speed: 0.2528s/iter; left time: 193056.1082s
	iters: 7400, epoch: 10 | loss: 8.1619720
	speed: 0.2529s/iter; left time: 193064.3187s
	iters: 7500, epoch: 10 | loss: 11.3564243
	speed: 0.2528s/iter; left time: 192976.5338s
	iters: 7600, epoch: 10 | loss: 7.7526703
	speed: 0.2529s/iter; left time: 193044.3292s
	iters: 7700, epoch: 10 | loss: 9.6951580
	speed: 0.2532s/iter; left time: 193202.1464s
	iters: 7800, epoch: 10 | loss: 7.8652220
	speed: 0.2528s/iter; left time: 192869.7102s
	iters: 7900, epoch: 10 | loss: 2.3175101
	speed: 0.2528s/iter; left time: 192900.7338s
	iters: 8000, epoch: 10 | loss: 8.9715557
	speed: 0.2528s/iter; left time: 192818.5985s
	iters: 8100, epoch: 10 | loss: 6.3289957
	speed: 0.2528s/iter; left time: 192808.5387s
	iters: 8200, epoch: 10 | loss: 2.3638635
	speed: 0.2530s/iter; left time: 192933.0149s
	iters: 8300, epoch: 10 | loss:

	iters: 7900, epoch: 11 | loss: 5.0012059
	speed: 0.2527s/iter; left time: 190693.1003s
	iters: 8000, epoch: 11 | loss: 3.6121922
	speed: 0.2526s/iter; left time: 190580.9908s
	iters: 8100, epoch: 11 | loss: 4.6712847
	speed: 0.2526s/iter; left time: 190567.3238s
	iters: 8200, epoch: 11 | loss: 6.6878433
	speed: 0.2526s/iter; left time: 190519.0110s
	iters: 8300, epoch: 11 | loss: 9.8591423
	speed: 0.2526s/iter; left time: 190479.1795s
	iters: 8400, epoch: 11 | loss: 4.0379014
	speed: 0.2530s/iter; left time: 190776.7896s
Validation loss did not decrease 5 / 20
Epoch: 11 cost time: 2230.775738954544
Epoch: 11, Steps: 8471 | Train Loss: 6.4602991 Val Loss: 0.1687631
Test loss for 15 mins: 0.0622406, for 30 mins: 0.0848783, for 45 mins: : 0.1061640, for 60 mins: 0.1257197
	iters: 100, epoch: 12 | loss: 9.8795414
	speed: 0.2513s/iter; left time: 189433.8364s
	iters: 200, epoch: 12 | loss: 5.1988740
	speed: 0.2527s/iter; left time: 190469.6119s
	iters: 300, epoch: 12 | loss: 3.1079314
	spe