# Imports and Environment Setup

In [None]:
import argparse, time, os, random, faulthandler, torch
import torch.nn as nn
from types import SimpleNamespace
from torch import optim
import numpy as np
from torch.utils.data import DataLoader, Subset, random_split
from data_provider.data_loader_emb import Dataset_ETT_hour, Dataset_ETT_minute, Dataset_Custom
from models.T3Time import TriModal
from utils.metrics import MSE, MAE, metric
# The CUDA caching allocator reads PYTORCH_CUDA_ALLOC_CONF when it is first
# initialised, so the setting is written before any CUDA call in this cell.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:150"

# Print Python tracebacks if the process dies inside native code (e.g. a segfault).
faulthandler.enable()

# Drop cached GPU blocks left over from an earlier run of this notebook.
torch.cuda.empty_cache()

# Helper Functions

In [None]:
class trainer:
    """Bundle a ``TriModal`` model with its optimizer, LR schedule and losses.

    Args:
        scaler: Scaler of the training split. Accepted for interface
            compatibility; this class does not use it.
        channel, num_nodes, seq_len, pred_len, dropout_n, d_llm, e_layer,
        d_layer, head: Forwarded unchanged to ``TriModal``.
        lrate: AdamW learning rate.
        wdecay: AdamW weight decay.
        device: Device the model is placed on; batches passed to ``train`` /
            ``eval`` must already live there.
        epochs: Number of training epochs; the cosine schedule uses
            ``T_max=min(epochs, 50)``.
    """

    def __init__(
        self,
        scaler,
        channel,
        num_nodes,
        seq_len,
        pred_len,
        dropout_n,
        d_llm,
        e_layer,
        d_layer,
        head,
        lrate,
        wdecay,
        device,
        epochs
    ):
        self.model = TriModal(
            device=device, channel=channel, num_nodes=num_nodes, seq_len=seq_len, pred_len=pred_len,
            dropout_n=dropout_n, d_llm=d_llm, e_layer=e_layer, d_layer=d_layer, head=head
        )
        # Ensure the parameters live on `device` (a no-op if TriModal already
        # placed itself there). Done before the optimizer is built.
        self.model.to(device)
        self.epochs = epochs
        self.optimizer = optim.AdamW(self.model.parameters(), lr=lrate, weight_decay=wdecay)
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(self.optimizer, T_max=min(epochs, 50), eta_min=1e-6)
        self.loss = MSE
        self.MAE = MAE
        self.clip = 5  # max global gradient norm; None disables clipping
        print("The number of trainable parameters: {}".format(self.model.count_trainable_params()))
        print("The number of parameters: {}".format(self.model.param_num()))

    def train(self, input, mark, embeddings, real):
        """Run one optimisation step on a batch.

        Returns:
            ``(loss, mae)`` for this batch as Python floats, measured on the
            predictions made before the parameter update.
        """
        self.model.train()
        self.optimizer.zero_grad()
        predict = self.model(input, mark, embeddings)
        loss = self.loss(predict, real)
        loss.backward()
        if self.clip is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)
        self.optimizer.step()
        # MAE is only reported, never back-propagated: skip building a graph.
        with torch.no_grad():
            mae = self.MAE(predict, real)
        return loss.item(), mae.item()

    def eval(self, input, mark, embeddings, real_val):
        """Evaluate a batch without gradients; returns ``(loss, mae)`` floats.

        Leaves the model in eval mode.
        """
        self.model.eval()
        with torch.no_grad():
            predict = self.model(input, mark, embeddings)
            loss = self.loss(predict, real_val)
            mae = self.MAE(predict, real_val)
        return loss.item(), mae.item()

def load_data(args):
    """Build the train/val/test datasets and DataLoaders for ``args.data_path``.

    ETT names (``ETTh1``/``ETTh2``/``ETTm1``/``ETTm2``) map to their dedicated
    dataset classes; any other name falls back to ``Dataset_Custom``. Every
    split is created with ``scale=True`` and ``size=[seq_len, 0, pred_len]``.

    Reads ``args.seq_len``, ``args.pred_len``, ``args.data_path``,
    ``args.batch_size``, ``args.num_workers`` and, optionally,
    ``args.drop_last_eval`` (default ``True``).

    Returns:
        ``(train_set, val_set, test_set, train_loader, val_loader,
        test_loader, scaler)``, where ``scaler`` is ``train_set.scaler``.
    """
    data_map = {
        'ETTh1': Dataset_ETT_hour,
        'ETTh2': Dataset_ETT_hour,
        'ETTm1': Dataset_ETT_minute,
        'ETTm2': Dataset_ETT_minute,
    }
    data_class = data_map.get(args.data_path, Dataset_Custom)

    def make_split(flag):
        # A fresh `size` list per split so no two datasets share one object.
        return data_class(flag=flag, scale=True, size=[args.seq_len, 0, args.pred_len], data_path=args.data_path)

    # Built in train -> val -> test order, as before.
    train_set = make_split('train')
    val_set = make_split('val')
    test_set = make_split('test')
    scaler = train_set.scaler

    # drop_last=True on val/test silently discards the final partial batch, so
    # metrics are computed on a subset of the windows. Kept as the default for
    # backward compatibility; set args.drop_last_eval = False to evaluate every
    # window (presumably the model must then accept a smaller last batch).
    drop_last_eval = getattr(args, 'drop_last_eval', True)

    def make_loader(dataset, drop_last):
        return DataLoader(dataset, batch_size=args.batch_size, shuffle=False,
                          drop_last=drop_last, num_workers=args.num_workers)

    train_loader = make_loader(train_set, True)
    val_loader = make_loader(val_set, drop_last_eval)
    test_loader = make_loader(test_set, drop_last_eval)

    return train_set, val_set, test_set, train_loader, val_loader, test_loader, scaler

def seed_it(seed):
    """Seed Python, NumPy and PyTorch (CPU and every CUDA device) RNGs.

    Also makes cuDNN select deterministic kernels. ``PYTHONHASHSEED`` is
    exported so that interpreters started afterwards inherit it; Python reads
    it at startup, so it cannot change hashing in the running process.
    """
    random.seed(seed)
    # PYTHONHASHSEED is the variable Python honours; "PYTHONSEED" (used
    # previously) is not recognised by the interpreter.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = True
    torch.manual_seed(seed)

def seed_everything(seed):
    """Seed every RNG used here and request single-threaded, deterministic math.

    Exports PYTHONHASHSEED, OMP_NUM_THREADS=1 and MKL_NUM_THREADS=1. Python
    reads PYTHONHASHSEED only at interpreter startup, and thread-pool variables
    only matter to libraries that have not initialised yet, so these mainly
    affect processes started afterwards. Then seeds ``random``, NumPy and
    PyTorch (CPU plus all CUDA devices) and pins cuDNN to deterministic,
    non-benchmarked kernels.
    """
    env_settings = (
        ("PYTHONHASHSEED", str(seed)),
        ("OMP_NUM_THREADS", "1"),
        ("MKL_NUM_THREADS", "1"),
    )
    for name, value in env_settings:
        os.environ[name] = value

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


# Few-Shot Forecasting

## Training

In [None]:
args = SimpleNamespace(
    device='cuda', data_path='ETTm1',
    channel=128, batch_size=64, dropout_n=0.5,
    e_layer=1, d_layer=3,
    num_nodes=7, seq_len=512, pred_len=96,
    learning_rate=1e-3, d_llm=768,
    head=8, weight_decay=1e-3,
    num_workers=10, model_name='gpt2',
    epochs=40, seed=2024,
    es_patience=10, save='./logs/custom-save-path-'
)

# ------------------------------------------ Few-shot / selection settings -----------
percent = 0.1             # Fraction of training windows used for few-shot forecasting (0.1 = 10%)
few_shot_mode = "random"  # Which windows to keep: "first", "last" or "random"
warmup_epochs = 10        # Up to this epoch, checkpoints are chosen on validation loss only
# NOTE(review): after `warmup_epochs`, the protocol below keeps a checkpoint only
# if its *test* MSE improves, i.e. the test set takes part in model selection and
# the final test score is optimistically biased. Set to False to choose
# checkpoints on validation loss for every epoch.
select_on_test = True


def _batch_to_device(batch, device):
    """Unpack one loader batch and move it to `device`.

    Returns ``(x, y, x_mark, embeddings)``; the batch's ``y_mark`` is unused.
    """
    x, y, x_mark, _y_mark, embeddings = batch
    return (
        torch.Tensor(x).to(device),
        torch.Tensor(y).to(device),
        torch.Tensor(x_mark).to(device),
        torch.Tensor(embeddings).to(device),
    )


def _evaluate_on_test(model, loader, device, pred_len):
    """Predict every batch of `loader` and score each forecast horizon.

    Returns ``(test_pre, test_real, amse, amae)``: concatenated predictions and
    targets, plus the per-horizon MSE / MAE from `metric` for the first
    `pred_len` steps.
    """
    model.eval()  # explicit, rather than relying on the mode left by validation
    outputs, targets = [], []
    with torch.no_grad():
        for batch in loader:
            x, y, x_mark, emb = _batch_to_device(batch, device)
            outputs.append(model(x, x_mark, emb))
            targets.append(y)
    test_pre = torch.cat(outputs, dim=0)
    test_real = torch.cat(targets, dim=0)

    amse, amae = [], []
    for j in range(pred_len):
        horizon_metrics = metric(test_pre[:, j], test_real[:, j])
        amse.append(horizon_metrics[0])
        amae.append(horizon_metrics[1])
    return test_pre, test_real, amse, amae


seed_everything(args.seed)
g1 = torch.Generator().manual_seed(args.seed)

# ------------------------------------------ Few-shot data loading -------------------
train_set, val_set, test_set, train_loader, val_loader, test_loader, scaler = load_data(args)

num_train = len(train_set)
few_shot_size = int(num_train * percent)
rest_size = num_train - few_shot_size

if few_shot_mode == "first":
    few_shot_set = Subset(train_set, list(range(few_shot_size)))
    shuffle_train = False
elif few_shot_mode == "last":
    few_shot_set = Subset(train_set, list(range(rest_size, num_train)))
    shuffle_train = False
elif few_shot_mode == "random":
    few_shot_set, _ = random_split(train_set, [few_shot_size, rest_size], generator=g1)
    shuffle_train = True
else:
    raise ValueError(f"few_shot_mode must be 'first', 'last' or 'random', got {few_shot_mode!r}")

g2 = torch.Generator().manual_seed(args.seed)
train_loader = DataLoader(
    few_shot_set, batch_size=args.batch_size, shuffle=shuffle_train, drop_last=True,
    num_workers=args.num_workers, generator=g2,
    worker_init_fn=lambda worker_id: np.random.seed(args.seed + worker_id),
)
if len(train_loader) == 0:
    raise ValueError(
        f"Few-shot subset has {len(few_shot_set)} windows, fewer than "
        f"batch_size={args.batch_size}; with drop_last=True there is nothing to train on."
    )

print("Data ready for few shot learning")

# ------------------------------------------ Training --------------------------------
print()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

loss = 9999999          # best validation loss seen so far
test_log = 999999       # best test MSE seen so far (only used when select_on_test)
epochs_since_best_mse = 0
bestid = None           # epoch of the saved checkpoint

path = os.path.join(args.save, args.data_path,
                    f"{args.pred_len}_{args.channel}_{args.e_layer}_{args.d_layer}_{args.learning_rate}_{args.dropout_n}_{args.seed}/")
os.makedirs(path, exist_ok=True)

his_train_loss = []
his_train_mae = []
his_loss = []
his_loss_mae = []
val_time = []
train_time = []
print(args)

engine = trainer(
    scaler=scaler,
    channel=args.channel,
    num_nodes=args.num_nodes,
    seq_len=args.seq_len,
    pred_len=args.pred_len,
    dropout_n=args.dropout_n,
    d_llm=args.d_llm,
    e_layer=args.e_layer,
    d_layer=args.d_layer,
    head=args.head,
    lrate=args.learning_rate,
    wdecay=args.weight_decay,
    device=device,
    epochs=args.epochs
)

print("Start training...", flush=True)

for i in range(1, args.epochs + 1):
    # --- training pass ---
    t1 = time.time()
    train_loss = []
    train_mae = []
    for batch in train_loader:
        trainx, trainy, trainx_mark, train_embedding = _batch_to_device(batch, device)
        batch_loss, batch_mae = engine.train(trainx, trainx_mark, train_embedding, trainy)
        train_loss.append(batch_loss)
        train_mae.append(batch_mae)
    t2 = time.time()
    print("Epoch: {:03d}, Training Time: {:.4f} secs".format(i, (t2 - t1)))
    train_time.append(t2 - t1)

    # --- validation pass ---
    val_loss = []
    val_mae = []
    s1 = time.time()
    for batch in val_loader:
        valx, valy, valx_mark, val_embedding = _batch_to_device(batch, device)
        batch_loss, batch_mae = engine.eval(valx, valx_mark, val_embedding, valy)
        val_loss.append(batch_loss)
        val_mae.append(batch_mae)
    s2 = time.time()
    print("Epoch: {:03d}, Validation Time: {:.4f} secs".format(i, (s2 - s1)))
    val_time.append(s2 - s1)

    mtrain_loss = np.mean(train_loss)
    mtrain_mae = np.mean(train_mae)
    mvalid_loss = np.mean(val_loss)
    mvalid_mae = np.mean(val_mae)

    his_train_loss.append(mtrain_loss)
    his_train_mae.append(mtrain_mae)
    his_loss.append(mvalid_loss)
    his_loss_mae.append(mvalid_mae)

    print("-----------------------")
    print("Epoch: {:03d}, Train Loss: {:.4f}, Train MAE: {:.4f} ".format(i, mtrain_loss, mtrain_mae), flush=True)
    print("Epoch: {:03d}, Valid Loss: {:.4f}, Valid MAE: {:.4f}".format(i, mvalid_loss, mvalid_mae), flush=True)

    # --- checkpoint selection ---
    if mvalid_loss < loss:
        print("###Update tasks appear###")
        if i <= warmup_epochs or not select_on_test:
            loss = mvalid_loss
            torch.save(engine.model.state_dict(), path + "best_model.pth")
            bestid = i
            epochs_since_best_mse = 0
            print("Updating! Valid Loss:{:.4f}".format(mvalid_loss), end=", ")
            print("epoch: ", i)
        else:
            test_pre, test_real, amse, amae = _evaluate_on_test(engine.model, test_loader, device, args.pred_len)
            mean_test_mse = np.mean(amse)
            print("On average horizons, Test MSE: {:.4f}, Test MAE: {:.4f}".format(mean_test_mse, np.mean(amae)))

            if mean_test_mse < test_log:
                test_log = mean_test_mse
                loss = mvalid_loss
                torch.save(engine.model.state_dict(), path + "best_model.pth")
                epochs_since_best_mse = 0
                print("Test low! Updating! Test Loss: {:.4f}".format(mean_test_mse), end=", ")
                print("Test low! Updating! Valid Loss: {:.4f}".format(mvalid_loss), end=", ")
                bestid = i
                print("epoch: ", i)
            else:
                # `loss` is deliberately not updated here, so a later epoch with
                # a similar validation loss triggers another test evaluation.
                epochs_since_best_mse += 1
                print("No update")
    else:
        epochs_since_best_mse += 1
        print("No update")

    engine.scheduler.step()

    # Early stop only once at least half of the epoch budget has run.
    if epochs_since_best_mse >= args.es_patience and i >= args.epochs // 2:
        break

# Output consumption
print("Average Training Time: {:.4f} secs/epoch".format(np.mean(train_time)))
print("Average Validation Time: {:.4f} secs".format(np.mean(val_time)))

# ------------------------------------------ Test ------------------------------------
print("Training ends")
if bestid is None:
    raise RuntimeError(
        "No checkpoint was saved: the validation loss never improved on its "
        "initial value (is it NaN?)."
    )
print("The epoch of the best result：", bestid)
print("The valid loss of the best model", str(round(his_loss[bestid - 1], 4)))

engine.model.load_state_dict(torch.load(path + "best_model.pth", map_location=device))

test_pre, test_real, amse, amae = _evaluate_on_test(engine.model, test_loader, device, args.pred_len)
print("On average horizons, Test MSE: {:.4f}, Test MAE: {:.4f}".format(np.mean(amse), np.mean(amae)))