In [3]:
import argparse
import torch
from torch import nn, optim
from torch.optim import lr_scheduler
from tqdm import tqdm
import time
import random
import numpy as np
import os
import TimeLLM
from TimeSeries import data_provider
from tools import EarlyStopping, adjust_learning_rate, vali, load_content
import sys
import warnings
warnings.filterwarnings("ignore")

In [4]:
# Reproducibility: seed Python's `random`, PyTorch and NumPy with one shared value.
fix_seed = 2021
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    # Same call order as before: random -> torch -> numpy.
    seed_fn(fix_seed)

In [5]:
# Run configuration, kept in an argparse.Namespace so it can be passed around
# as a single `args` object (to the model, the data providers and `vali`).
args = argparse.Namespace(
    # --- data source -------------------------------------------------------
    root_path='./',        # assumes the CSV sits in the current directory
    data_path='NVDA.csv',  # the CSV file to load
    # --- window sizes ------------------------------------------------------
    # label_len / pred_len are the slice lengths the training loop uses to
    # build the decoder input and the loss target.
    seq_len=384,
    label_len=96,
    pred_len=96,
    # --- dataset options ---------------------------------------------------
    features='M',          # 'M' = multivariate, 'S' = univariate
    target=None,           # TODO: replace with the target column name from the CSV
    embed='timeF',         # 'timeF' selects time-feature encoding
    scale=True,
    percent=100,
    num_workers=0,
    batch_size=2,
    freq='d',
    # --- run identity ------------------------------------------------------
    model_id='test',
    model_comment='none',
    model='TimeLLM',
    seed=2021,
    # Handed to EarlyStopping as its save path and read back in the test step.
    checkpoints='./checkpoints/',
    # --- network dimensions ------------------------------------------------
    d_model=512,
    n_heads=8,
    e_layers=2,
    d_layers=1,
    d_ff=4096,
    dropout=0.1,
    activation='gelu',
    # --- optimisation ------------------------------------------------------
    eval_batch_size=1,
    patience=10,           # forwarded to EarlyStopping
    learning_rate=0.0001,  # used as the Adam lr and as OneCycleLR's max_lr
    train_epochs=1,
    loss='MSE',
    lradj='type1',
    use_amp=False,         # switches the training loop to autocast + GradScaler
    # --- LLM backbone / task -----------------------------------------------
    llm_layers=12,
    task_name='long_term_forecast',
    llm_model='LLAMA',
    llm_dim=4096,
    # NOTE(review): the data providers report 6 columns for this CSV — confirm
    # that enc_in=7 is intended.
    enc_in=7,
    patch_len=16,
    stride=8,
    prompt_domain=False,
    content=None,
)

In [6]:
# Build the Time-LLM model from the notebook configuration ...
model = TimeLLM.Model(args)
# ... then cast its floating-point parameters and buffers to float32.
model = model.float()

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [7]:
# One (dataset, loader) pair per split, requested in train -> val -> test order.
(train_data, train_loader), (vali_data, vali_loader), (test_data, test_loader) = (
    data_provider(args, split) for split in ('train', 'val', 'test')
)

data_x shape: (870, 6)
data_y shape: (870, 6)
data_x shape: (509, 6)
data_y shape: (509, 6)
data_x shape: (632, 6)
data_y shape: (632, 6)


In [8]:
# Use the GPU when CUDA is available, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Move the model to that device; as the cell's last expression this also
# displays the module structure.
model.to(device)

Model(
  (llm_model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-11): 12 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): LlamaRMSNorm()
  )
  (d

In [9]:
# Release cached GPU memory that PyTorch is holding but not currently using,
# before the optimizer and training loop allocate more.
torch.cuda.empty_cache()

In [10]:
# Loss: mean squared error between the forecast and the target window.
criterion = nn.MSELoss()

# Adam over every model parameter, starting from the configured learning rate.
optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)

# One-cycle schedule sized for `train_epochs` epochs of `len(train_loader)`
# batches each, peaking at the configured learning rate.
scheduler = lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=args.learning_rate,
    epochs=args.train_epochs,
    steps_per_epoch=len(train_loader),
)

In [11]:
# Create a gradient scaler only when mixed precision is enabled; the training
# loop touches `scaler` only under the same `args.use_amp` check, so the name
# is intentionally left undefined otherwise.
if args.use_amp:
    scaler = torch.cuda.amp.GradScaler()

In [12]:
# Early-stopping helper configured with `args.patience`. After each epoch it is
# called with the validation loss, the model and the checkpoint path, and its
# `early_stop` flag is checked to break out of training.
early_stopping = EarlyStopping(patience=args.patience)

In [None]:
# Training: run up to `args.train_epochs` epochs, validate after each one and
# stop early when `early_stopping` raises its flag.

# `args.checkpoints` is handed to EarlyStopping as its save path; make sure the
# directory exists so a checkpoint write cannot fail on a missing directory.
os.makedirs(args.checkpoints, exist_ok=True)

for epoch in range(args.train_epochs):
    model.train()
    train_loss = []
    epoch_start_time = time.time()

    for i, (batch_x, batch_y, batch_x_mark, batch_y_mark) in enumerate(tqdm(train_loader)):
        optimizer.zero_grad()
        batch_x, batch_y = batch_x.float().to(device), batch_y.float().to(device)
        batch_x_mark, batch_y_mark = batch_x_mark.float().to(device), batch_y_mark.float().to(device)

        # Decoder input: the first `label_len` steps of batch_y followed by
        # zeros standing in for the `pred_len` steps to be predicted.
        dec_inp = torch.zeros_like(batch_y[:, -args.pred_len:, :]).float()
        dec_inp = torch.cat([batch_y[:, :args.label_len, :], dec_inp], dim=1).to(device)

        if args.use_amp:
            # Mixed precision: forward pass under autocast, then scaled
            # backward pass and optimizer step through the GradScaler.
            with torch.cuda.amp.autocast():
                outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
                loss = criterion(outputs, batch_y[:, -args.pred_len:, :])
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(batch_x, batch_x_mark, dec_inp, batch_y_mark)
            loss = criterion(outputs, batch_y[:, -args.pred_len:, :])
            loss.backward()
            optimizer.step()

        # The OneCycleLR schedule was sized with steps_per_epoch=len(train_loader),
        # so it has to advance once per batch. Stepping it once per epoch would
        # leave the learning rate near its initial value for the whole run.
        scheduler.step()

        train_loss.append(loss.item())

    # Validation and Early Stopping
    train_loss = np.mean(train_loss)
    vali_loss, _ = vali(args, model, vali_data, vali_loader, criterion)
    early_stopping(vali_loss, model, path=args.checkpoints)

    print(f"Epoch: {epoch+1}, Train Loss: {train_loss:.6f}, Vali Loss: {vali_loss:.6f}, Time: {time.time()-epoch_start_time:.2f}s")

    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

# Testing phase
# `args.checkpoints` is the directory handed to EarlyStopping as its save path,
# while torch.load() needs a file, so resolve the checkpoint file inside that
# directory before loading.
checkpoint_path = args.checkpoints
if os.path.isdir(checkpoint_path):
    saved_files = [
        os.path.join(checkpoint_path, name) for name in os.listdir(checkpoint_path)
    ]
    saved_files = [path for path in saved_files if os.path.isfile(path)]
    if not saved_files:
        raise FileNotFoundError(f"No checkpoint file found in {checkpoint_path!r}")
    # NOTE(review): the file name EarlyStopping writes is not visible in this
    # notebook, so the most recently modified file is taken as the saved
    # checkpoint — confirm against tools.EarlyStopping.
    checkpoint_path = max(saved_files, key=os.path.getmtime)

# Restore the saved weights, then score the model on the held-out test split
# with the same `vali` helper used for validation.
model.load_state_dict(torch.load(checkpoint_path))
test_loss, _ = vali(args, model, test_data, test_loader, criterion)
print(f"Test Loss: {test_loss:.6f}")

 34%|████████████████████████████████████████████████████████▎                                                                                                           | 403/1173 [4:52:31<9:30:00, 44.42s/it]