# In [None]:  (notebook cell marker left over from the Jupyter export)
import argparse
import os
import time
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import math
from sklearn.preprocessing import StandardScaler
from accelerate import Accelerator, DistributedDataParallelKwargs

import transformers
from transformers import BertConfig, BertModel, BertTokenizer
transformers.logging.set_verbosity_error()


#########################################
# Helper functions
#########################################
def time_features(dates, freq='h'):
    """
    Build calendar features for a sequence of timestamps.

    Returns a float array of shape (4, num_dates) whose rows are, in order,
    month, day of month, weekday and hour. `freq` is accepted for interface
    compatibility but is not used here.
    """
    stamps = pd.to_datetime(dates)
    field_names = ('month', 'day', 'weekday', 'hour')
    rows = [getattr(stamps, name).values.astype(float) for name in field_names]
    return np.stack(rows, axis=0)


#########################################
# Dataset definition (loads both CSV and video features .pt)
#########################################
class Dataset_Custom(Dataset):
    """
    Sliding-window dataset over a CSV time series plus per-row video features.

    Every CSV column is treated as its own univariate series, so the dataset
    length is (number of windows) * (number of CSV columns). Each item is the
    tuple ``(seq_x_csv, seq_x_pt, seq_y, seq_x_mark, seq_y_mark)``.
    """

    def __init__(self, root_path, flag='train', size=None,
                 data_path='PKU_Skeleton_Renew_preprocessed_0002-L.csv',
                 scale=True, timeenc=0, freq='h', percent=100, seasonal_patterns=None):
        """
        Args:
            root_path: Directory holding the CSV file and "video-2_features.pt".
            flag: One of 'train', 'val', 'test'.
            size: Optional ``[seq_len, label_len, pred_len]``; defaults to
                ``[384, 96, 96]`` when omitted.
            data_path: CSV file name relative to ``root_path``.
            scale: Standardize CSV and video features with training statistics.
            timeenc: 0 for integer calendar columns, 1 for ``time_features``.
            freq: Forwarded to ``time_features`` when ``timeenc == 1``.
            percent: Percentage of the training range to keep.
            seasonal_patterns: Accepted for interface compatibility; unused.
        """
        # If size is not specified, fall back to default window lengths
        if size is None:
            self.seq_len = 24 * 4 * 4
            self.label_len = 24 * 4
            self.pred_len = 24 * 4
        else:
            self.seq_len, self.label_len, self.pred_len = size

        assert flag in ['train', 'test', 'val']
        type_map = {'train': 0, 'val': 1, 'test': 2}
        self.set_type = type_map[flag]
        self.scale = scale
        self.timeenc = timeenc
        self.freq = freq
        self.percent = percent
        self.root_path = root_path
        self.data_path = data_path
        self.__read_data__()

        # Samples iterate over every CSV feature column (enc_in) and window
        self.enc_in = self.data_x.shape[-1]
        self.tot_len = len(self.data_x) - self.seq_len - self.pred_len + 1

    @staticmethod
    def _project_video_features(video_features, out_dim=200, seed=0):
        """
        Project raw video features to ``out_dim`` with a fixed random linear map.

        The weights are drawn from a private, seeded generator (same uniform
        range as the default ``nn.Linear`` initialisation) so that every
        dataset instance (train/val/test) applies the SAME projection and the
        global RNG stream is left untouched. The projection is not trained.
        """
        generator = torch.Generator().manual_seed(seed)
        in_dim = video_features.shape[-1]
        bound = 1.0 / math.sqrt(in_dim)
        weight = torch.empty(out_dim, in_dim).uniform_(-bound, bound, generator=generator)
        bias = torch.empty(out_dim).uniform_(-bound, bound, generator=generator)
        projected = video_features.detach().float() @ weight.t() + bias
        return projected.numpy()

    def __read_data__(self):
        """Load CSV and video features, split by set type, scale and build time marks."""
        # Read CSV file and parse dates
        df_raw = pd.read_csv(os.path.join(self.root_path, self.data_path))
        df_raw['date'] = pd.to_datetime(df_raw['date'])
        # CSV data: all columns except date
        df_csv = df_raw.drop('date', axis=1)

        # Load video features .pt file (must have same number of rows as CSV)
        # NOTE(review): torch.load unpickles the file; only load trusted files.
        pt_path = os.path.join(self.root_path, "video-2_features.pt")
        video_features = torch.load(pt_path, map_location=torch.device('cpu'))
        assert video_features.shape[0] == len(df_raw), "Number of rows in video features does not match CSV dates!"
        # Fixed random linear mapping of each video feature row to 200 dims
        mapped_features = self._project_video_features(video_features)  # [T, 200]

        # 70% train / 20% test / remainder validation; val and test windows
        # start seq_len rows early so their first sample has a full history.
        self.scaler_csv = StandardScaler()
        self.scaler_pt = StandardScaler()
        num_train = int(len(df_raw) * 0.7)
        num_test = int(len(df_raw) * 0.2)
        num_vali = len(df_raw) - num_train - num_test
        border1s = [0, num_train - self.seq_len, len(df_raw) - num_test - self.seq_len]
        border2s = [num_train, num_train + num_vali, len(df_raw)]
        border1 = border1s[self.set_type]
        border2 = border2s[self.set_type]
        if self.set_type == 0:
            border2 = (border2 - self.seq_len) * self.percent // 100 + self.seq_len

        if self.scale:
            # Standardize CSV data and video features separately, both fitted
            # on the training range only.
            self.scaler_csv.fit(df_csv.iloc[border1s[0]:border2s[0]].values)
            csv_data = self.scaler_csv.transform(df_csv.values)
            self.scaler_pt.fit(mapped_features[border1s[0]:border2s[0]])
            pt_data = self.scaler_pt.transform(mapped_features)
        else:
            csv_data = df_csv.values
            pt_data = mapped_features

        # Generate timestamp features (copy: columns are added to the slice)
        df_stamp = df_raw[['date']].iloc[border1:border2].copy()
        if self.timeenc == 0:
            df_stamp['month'] = df_stamp['date'].apply(lambda row: row.month)
            df_stamp['day'] = df_stamp['date'].apply(lambda row: row.day)
            df_stamp['weekday'] = df_stamp['date'].apply(lambda row: row.weekday())
            df_stamp['hour'] = df_stamp['date'].apply(lambda row: row.hour)
            data_stamp = df_stamp.drop(['date'], axis=1).values
        elif self.timeenc == 1:
            data_stamp = time_features(df_stamp['date'].values, freq=self.freq)
            data_stamp = data_stamp.transpose(1, 0)
        else:
            raise ValueError(f"Unsupported timeenc: {self.timeenc!r} (expected 0 or 1)")

        # Only the CSV part is used as labels (forecast output is CSV only)
        self.data_x = csv_data[border1:border2]  # CSV inputs, [T, num_csv_features]
        self.data_y = csv_data[border1:border2]
        self.data_pt = pt_data[border1:border2]  # video inputs, [T, 200]
        self.data_stamp = data_stamp

    def __getitem__(self, index):
        """
        Return one univariate sample.

        ``index`` enumerates (feature, window start) pairs: the feature id is
        ``index // tot_len`` and the window start is ``index % tot_len``.

        Raises:
            IndexError: If ``index`` is outside ``[0, len(self))``.
        """
        if not 0 <= index < len(self):
            raise IndexError(f"index {index} out of range for dataset of length {len(self)}")
        feat_id, s_begin = divmod(index, self.tot_len)
        s_end = s_begin + self.seq_len
        r_begin = s_end - self.label_len
        r_end = r_begin + self.label_len + self.pred_len
        feat_slice = slice(feat_id, feat_id + 1)
        # CSV part: take one variable column
        seq_x_csv = self.data_x[s_begin:s_end, feat_slice]
        seq_y = self.data_y[r_begin:r_end, feat_slice]
        # NOTE(review): the video column is selected with the CSV feature id,
        # so only the first enc_in video dimensions are ever used — confirm
        # this pairing is intended.
        seq_x_pt = self.data_pt[s_begin:s_end, feat_slice]
        seq_x_mark = self.data_stamp[s_begin:s_end]
        seq_y_mark = self.data_stamp[r_begin:r_end]
        return seq_x_csv, seq_x_pt, seq_y, seq_x_mark, seq_y_mark

    def __len__(self):
        """Number of (feature, window) samples; 0 when the data is shorter than one window."""
        return max(0, self.tot_len) * self.enc_in

    def inverse_transform(self, data):
        """Undo the CSV standardization; a no-op when scaling was disabled."""
        if not self.scale:
            return data
        return self.scaler_csv.inverse_transform(data)

class Dataset_ETT_hour(Dataset_Custom):
    """Alias of Dataset_Custom under the name used by data_provider; adds no behavior."""

def data_provider(args, flag):
    """
    Build the dataset and its DataLoader for the given split.

    The loader shuffles for every split except 'test' and always drops the
    last incomplete batch. Returns ``(data_set, data_loader)``.
    """
    use_time_features = args.embed == 'timeF'
    window = [args.seq_len, args.label_len, args.pred_len]
    data_set = Dataset_ETT_hour(
        root_path=args.root_path,
        data_path=args.data_path,
        flag=flag,
        size=window,
        scale=args.scale,
        timeenc=1 if use_time_features else 0,
        freq=args.freq,
        percent=args.percent,
        seasonal_patterns=args.seasonal_patterns,
    )
    data_loader = DataLoader(
        data_set,
        batch_size=args.batch_size,
        shuffle=(flag != 'test'),
        num_workers=args.num_workers,
        drop_last=True,
    )
    return data_set, data_loader


#########################################
# Model components
#########################################
# FlattenHead: flatten and linearly map to the prediction window
class FlattenHead(nn.Module):
    """Flatten the last two dimensions and project them to the target window."""

    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super(FlattenHead, self).__init__()
        self.n_vars = n_vars
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        """Apply flatten -> linear -> dropout to ``x``."""
        flat = self.flatten(x)
        return self.dropout(self.linear(flat))

# PositionalEmbedding: positional encoding
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding stored as a non-trainable buffer."""

    def __init__(self, d_model, max_len=5000):
        """
        Args:
            d_model: Embedding width (odd widths are supported).
            max_len: Number of positions precomputed.
        """
        super(PositionalEmbedding, self).__init__()
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() *
                    -(math.log(10000.0) / d_model)).exp()
        # A freshly created tensor does not require grad, and buffers are
        # never optimised, so no explicit requires_grad flag is needed.
        pe = torch.zeros(max_len, d_model).float()
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions, shape [1, L, d_model]."""
        return self.pe[:, :x.size(1)]

# TokenEmbedding: 1D convolutional embedding
class TokenEmbedding(nn.Module):
    """Embed a sequence with a bias-free, circularly padded 1D convolution (kernel size 3)."""

    def __init__(self, c_in, d_model):
        super().__init__()
        # NOTE(review): this compares version strings; presumably it relies on
        # torch's version object comparing numerically — confirm for old builds.
        padding = 1 if torch.__version__ >= '1.5.0' else 2
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=padding,
            padding_mode='circular',
            bias=False,
        )
        # The convolution is the only Conv1d in this module, so initialise it directly.
        nn.init.kaiming_normal_(self.tokenConv.weight, mode='fan_in',
                                nonlinearity='leaky_relu')

    def forward(self, x):
        """Convolve over dimension 1 of ``x`` by moving it last and back again."""
        channels_first = x.permute(0, 2, 1)
        return self.tokenConv(channels_first).transpose(1, 2)

# FixedEmbedding: fixed sinusoidal embedding
class FixedEmbedding(nn.Module):
    """Embedding lookup whose table is a frozen sinusoidal encoding."""

    def __init__(self, c_in, d_model):
        """
        Args:
            c_in: Number of distinct indices (rows of the table).
            d_model: Embedding width (odd widths are supported).
        """
        super(FixedEmbedding, self).__init__()
        position = torch.arange(0, c_in).float().unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2).float() *
                    -(math.log(10000.0) / d_model)).exp()
        w = torch.zeros(c_in, d_model).float()
        w[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        w[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        self.emb = nn.Embedding(c_in, d_model)
        # Freezing is done here via requires_grad=False on the parameter.
        self.emb.weight = nn.Parameter(w, requires_grad=False)

    def forward(self, x):
        """Look up the fixed encodings for integer indices ``x`` (detached)."""
        return self.emb(x).detach()

# TemporalEmbedding: time feature embedding
class TemporalEmbedding(nn.Module):
    """
    Sum of per-field embeddings for calendar marks.

    The last dimension of the input is read as month (0), day (1),
    weekday (2), hour (3) and, only when ``freq == 't'``, minute (4).
    """

    def __init__(self, d_model, embed_type='fixed', freq='h'):
        super().__init__()
        embed_cls = FixedEmbedding if embed_type == 'fixed' else nn.Embedding
        # Table sizes per calendar field.
        if freq == 't':
            self.minute_embed = embed_cls(4, d_model)
        self.hour_embed = embed_cls(24, d_model)
        self.weekday_embed = embed_cls(7, d_model)
        self.day_embed = embed_cls(32, d_model)
        self.month_embed = embed_cls(13, d_model)

    def forward(self, x):
        """Embed each calendar field of ``x`` (cast to long) and add them up."""
        marks = x.long()
        month_x = self.month_embed(marks[:, :, 0])
        day_x = self.day_embed(marks[:, :, 1])
        weekday_x = self.weekday_embed(marks[:, :, 2])
        hour_x = self.hour_embed(marks[:, :, 3])
        if hasattr(self, 'minute_embed'):
            minute_x = self.minute_embed(marks[:, :, 4])
        else:
            minute_x = 0.
        return hour_x + weekday_x + day_x + month_x + minute_x

# TimeFeatureEmbedding: embedding for timeF type features
class TimeFeatureEmbedding(nn.Module):
    """Bias-free linear projection of continuous time features to ``d_model``."""

    def __init__(self, d_model, embed_type='timeF', freq='h'):
        super().__init__()
        # Number of input time features expected for each frequency code.
        inputs_per_freq = {'h': 4, 't': 5, 's': 6, 'm': 1, 'a': 1, 'w': 2, 'd': 3, 'b': 3}
        self.embed = nn.Linear(inputs_per_freq[freq], d_model, bias=False)

    def forward(self, x):
        """Project the time features in the last dimension of ``x``."""
        return self.embed(x)

# DataEmbedding: combine value, positional, and time embeddings
class DataEmbedding(nn.Module):
    """Sum of value, temporal and positional embeddings, followed by dropout."""

    def __init__(self, c_in, d_model, embed_type='fixed', freq='h', dropout=0.1):
        super().__init__()
        self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
        self.position_embedding = PositionalEmbedding(d_model=d_model)
        if embed_type == 'timeF':
            temporal = TimeFeatureEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        else:
            temporal = TemporalEmbedding(
                d_model=d_model, embed_type=embed_type, freq=freq)
        self.temporal_embedding = temporal
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, x_mark):
        """Embed ``x``; the temporal term is skipped when ``x_mark`` is None."""
        if x_mark is None:
            embedded = self.value_embedding(x) + self.position_embedding(x).to(x.device)
        else:
            embedded = (self.value_embedding(x) +
                        self.temporal_embedding(x_mark) +
                        self.position_embedding(x))
        return self.dropout(embedded)

# ReplicationPad1d: replicate padding
class ReplicationPad1d(nn.Module):
    """Right-pad the last dimension by repeating the final value ``padding[-1]`` times."""

    def __init__(self, padding) -> None:
        super().__init__()
        self.padding = padding

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Append copies of the last time step; only ``padding[-1]`` is used."""
        pad_right = self.padding[-1]
        last_step = input[:, :, -1].unsqueeze(-1)
        tail = last_step.repeat(1, 1, pad_right)
        return torch.cat((input, tail), dim=-1)

# PatchEmbedding: divide input into patches and embed
class PatchEmbedding(nn.Module):
    """Split each channel into overlapping patches and embed every patch."""

    def __init__(self, d_model, patch_len, stride, dropout):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride
        # Pad the end by one stride so the trailing values fall into a patch.
        self.padding_patch_layer = ReplicationPad1d((0, stride))
        # The token embedding sees one patch as its input channels.
        self.value_embedding = TokenEmbedding(patch_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Embed ``x`` of shape [B, C, T].

        Returns ``(embedding, n_vars)`` where the embedding has B*C rows and
        ``n_vars`` is C.
        """
        n_vars = x.shape[1]
        padded = self.padding_patch_layer(x)
        patches = padded.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        # Fold channels into the batch: [B, C, L, patch_len] -> [B*C, L, patch_len]
        batch, channels, num_patches, patch_len = patches.shape
        patches = torch.reshape(patches, (batch * channels, num_patches, patch_len))
        embedded = self.value_embedding(patches)
        return self.dropout(embedded), n_vars

# Normalize: normalization and denormalization module
class Normalize(nn.Module):
    """
    Per-instance normalization with a matching de-normalization step.

    ``forward(x, 'norm')`` stores the statistics of ``x`` on the module and
    returns the normalized tensor; ``forward(x, 'denorm')`` reverses it using
    the statistics from the most recent 'norm' call.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=False,
                 subtract_last=False, non_norm=False):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.non_norm = non_norm
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        """Dispatch on ``mode`` ('norm' or 'denorm'); anything else raises NotImplementedError."""
        if mode == 'norm':
            self._get_statistics(x)
            return self._normalize(x)
        if mode == 'denorm':
            return self._denormalize(x)
        raise NotImplementedError

    def _init_params(self):
        """Create the learnable per-feature scale and shift."""
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        """Record the centre (mean or last step) and standard deviation of ``x``."""
        # Reduce over every dimension except the first and the last.
        reduce_dims = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=reduce_dims, keepdim=True).detach()
        variance = torch.var(x, dim=reduce_dims, keepdim=True, unbiased=False)
        self.stdev = torch.sqrt(variance + self.eps).detach()

    def _normalize(self, x):
        """Centre and scale ``x`` with the stored statistics, then apply the affine map."""
        if self.non_norm:
            return x
        centre = self.last if self.subtract_last else self.mean
        x = x - centre
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight + self.affine_bias
        return x

    def _denormalize(self, x):
        """Invert ``_normalize`` using the stored statistics."""
        if self.non_norm:
            return x
        if self.affine:
            # eps**2 keeps the division finite if a learned weight reaches zero.
            x = (x - self.affine_bias) / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        centre = self.last if self.subtract_last else self.mean
        return x + centre

# ReprogrammingLayer: aligns patch embeddings with pretrained model embeddings
class ReprogrammingLayer(nn.Module):
    """
    Multi-head cross-attention from patch embeddings onto a set of source embeddings.

    Queries come from the patch (target) embeddings; keys and values come from
    the source embeddings. The output width is ``d_llm``.
    """

    def __init__(self, d_model, n_heads, d_keys=None, d_llm=None,
                 attention_dropout=0.1):
        super().__init__()
        if not d_keys:
            d_keys = d_model // n_heads
        inner_dim = d_keys * n_heads
        self.query_projection = nn.Linear(d_model, inner_dim)
        self.key_projection = nn.Linear(d_llm, inner_dim)
        self.value_projection = nn.Linear(d_llm, inner_dim)
        self.out_projection = nn.Linear(inner_dim, d_llm)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, target_embedding, source_embedding, value_embedding):
        """Project inputs per head, attend, and map the result to ``d_llm``."""
        batch, length, _ = target_embedding.shape
        num_sources, _ = source_embedding.shape
        heads = self.n_heads
        queries = self.query_projection(target_embedding).view(batch, length, heads, -1)
        keys = self.key_projection(source_embedding).view(num_sources, heads, -1)
        values = self.value_projection(value_embedding).view(num_sources, heads, -1)
        attended = self.reprogramming(queries, keys, values)
        return self.out_projection(attended.reshape(batch, length, -1))

    def reprogramming(self, target_embedding, source_embedding, value_embedding):
        """Scaled dot-product attention over the source axis, with dropout on the weights."""
        _, _, _, head_dim = target_embedding.shape
        scale = 1. / math.sqrt(head_dim)
        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
        weights = self.dropout(torch.softmax(scale * scores, dim=-1))
        return torch.einsum("bhls,she->blhe", weights, value_embedding)


#########################################
# TimeLLM model using fixed BERT
#########################################
class Model(nn.Module):
    """
    Time-LLM style forecaster built on a frozen BERT backbone.

    CSV and video inputs are patch-embedded, reprogrammed onto a reduced set of
    BERT word embeddings, concatenated after a textual prompt built from input
    statistics, and passed through BERT. Only the positions belonging to the
    CSV patches are projected to the forecast.
    """

    def __init__(self, configs, patch_len=16, stride=8):
        """
        Args:
            configs: Namespace providing task_name, pred_len, seq_len, d_ff,
                llm_dim, llm_layers, prompt_domain, content, dropout, d_model,
                n_heads and enc_in. ``patch_len`` / ``stride`` are read from it
                when present.
            patch_len: Patch length used when ``configs`` has no ``patch_len``.
            stride: Patch stride used when ``configs`` has no ``stride``.

        Raises:
            NotImplementedError: If ``configs.task_name`` is not a forecast task.
        """
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len
        self.seq_len = configs.seq_len
        self.d_ff = configs.d_ff
        self.top_k = 5
        self.d_llm = configs.llm_dim  # match BERT hidden size, default 768
        # Prefer the config values; fall back to the constructor arguments.
        self.patch_len = getattr(configs, 'patch_len', patch_len)
        self.stride = getattr(configs, 'stride', stride)

        # Force use of BERT
        configs.llm_model = 'BERT'
        self.bert_config = BertConfig.from_pretrained("bert-base-uncased")
        self.bert_config.num_hidden_layers = configs.llm_layers
        self.bert_config.output_attentions = True
        self.bert_config.output_hidden_states = True

        # Try the local cache first; only fall back to a download when the
        # files are missing.
        try:
            self.llm_model = BertModel.from_pretrained(
                "bert-base-uncased", config=self.bert_config,
                local_files_only=True)
        except EnvironmentError:
            print("Local model files not found, attempting download...")
            self.llm_model = BertModel.from_pretrained(
                "bert-base-uncased", config=self.bert_config,
                local_files_only=False)

        try:
            self.tokenizer = BertTokenizer.from_pretrained(
                "bert-base-uncased", local_files_only=True)
        except EnvironmentError:
            print("Local tokenizer not found, attempting download...")
            self.tokenizer = BertTokenizer.from_pretrained(
                "bert-base-uncased", local_files_only=False)

        if self.tokenizer.eos_token:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        else:
            pad_token = '[PAD]'
            self.tokenizer.add_special_tokens({'pad_token': pad_token})
            self.tokenizer.pad_token = pad_token

        # Freeze BERT parameters
        for param in self.llm_model.parameters():
            param.requires_grad = False

        if configs.prompt_domain:
            self.description = configs.content
        else:
            self.description = (
                'The 3D spatial positions of the 25 keypoints of the human body '
                'serve as essential benchmarks for human pose trajectory prediction.')

        self.dropout = nn.Dropout(configs.dropout)
        # Separate patch embeddings for CSV and video features
        self.csv_patch_embedding = PatchEmbedding(
            configs.d_model, self.patch_len, self.stride, configs.dropout)
        self.pt_patch_embedding = PatchEmbedding(
            configs.d_model, self.patch_len, self.stride, configs.dropout)
        self.word_embeddings = self.llm_model.get_input_embeddings().weight
        self.vocab_size = self.word_embeddings.shape[0]
        self.num_tokens = 1000
        self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)
        self.reprogramming_layer = ReprogrammingLayer(
            configs.d_model, configs.n_heads, self.d_ff, self.d_llm)
        # Number of patches per series (after the one-stride end padding)
        self.patch_nums = int((configs.seq_len - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums
        if self.task_name in ['long_term_forecast', 'short_term_forecast']:
            self.output_projection = FlattenHead(
                configs.enc_in, self.head_nf, self.pred_len,
                head_dropout=configs.dropout)
        else:
            raise NotImplementedError

        self.normalize_layers = Normalize(configs.enc_in, affine=False)
        # The video input gets its own (parameter-free) normalizer so that its
        # statistics do not overwrite the CSV statistics needed for de-normalization.
        self.pt_normalize_layers = Normalize(configs.enc_in, affine=False)

    def forward(self, x_csv, x_pt, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Return the last ``pred_len`` forecast steps, or None for non-forecast tasks."""
        if self.task_name in ['long_term_forecast', 'short_term_forecast']:
            dec_out = self.forecast(x_csv, x_pt, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]
        return None

    def forecast(self, x_csv, x_pt, x_mark_enc, x_dec, x_mark_dec):
        """
        Forecast the CSV series.

        Args:
            x_csv: CSV input, [B, T, csv_dim].
            x_pt: Video input, [B, T, pt_dim].
            x_mark_enc, x_dec, x_mark_dec: Accepted for interface
                compatibility; not used by this method.

        Returns:
            De-normalized forecast for the CSV variables.
        """
        # Normalize CSV and video inputs with independent statistics
        x_csv = self.normalize_layers(x_csv, 'norm')
        x_pt = self.pt_normalize_layers(x_pt, 'norm')
        B, T, _ = x_csv.size()

        # Per-series statistics of the CSV input for the prompt
        csv_for_stats = x_csv.permute(0, 2, 1).contiguous().reshape(
            B * x_csv.shape[-1], T, 1)
        min_values = torch.min(csv_for_stats, dim=1)[0].tolist()
        max_values = torch.max(csv_for_stats, dim=1)[0].tolist()
        medians = torch.median(csv_for_stats, dim=1).values.tolist()
        lags = self.calcute_lags(csv_for_stats).tolist()
        trend_is_up = (csv_for_stats.diff(dim=1).sum(dim=1) > 0).flatten().tolist()

        prompt = []
        for b in range(csv_for_stats.shape[0]):
            min_str = str(min_values[b][0])
            max_str = str(max_values[b][0])
            median_str = str(medians[b][0])
            lags_str = str(lags[b])
            prompt_ = (
                f"<|start_prompt|>Dataset description: {self.description}. "
                f"Task: Forecast the next {self.pred_len} steps given the previous "
                f"{self.seq_len} steps; Input statistics: min value {min_str}, "
                f"max value {max_str}, median value {median_str}, "
                f"trend is {'upward' if trend_is_up[b] else 'downward'}, "
                f"top 5 lags: {lags_str}<|<end_prompt>|>"
            )
            prompt.append(prompt_)

        # Tokenize prompt
        # NOTE(review): max_length=2048 may exceed BERT's position limit once
        # the patch tokens are appended — confirm for long prompts.
        prompt_tokens = self.tokenizer(
            prompt, return_tensors="pt", padding=True,
            truncation=True, max_length=2048).input_ids
        prompt_embeddings = self.llm_model.get_input_embeddings()(
            prompt_tokens.to(x_csv.device))

        # CSV patch embedding
        csv_input = x_csv.permute(0, 2, 1).contiguous()
        csv_emb, n_vars_csv = self.csv_patch_embedding(csv_input)

        # Video patch embedding
        pt_input = x_pt.permute(0, 2, 1).contiguous()
        pt_emb, n_vars_pt = self.pt_patch_embedding(pt_input)

        # Reprogram CSV and video embeddings onto the reduced vocabulary
        source_embeddings = (
            self.mapping_layer(self.word_embeddings.permute(1, 0))
            .permute(1, 0))
        csv_emb = self.reprogramming_layer(
            csv_emb, source_embeddings, source_embeddings)
        pt_emb = self.reprogramming_layer(
            pt_emb, source_embeddings, source_embeddings)

        # Concatenate prompt, video, and CSV embeddings along the sequence axis
        combined = torch.cat([prompt_embeddings, pt_emb, csv_emb], dim=1)
        total_prompt = prompt_embeddings.shape[1]
        total_pt = pt_emb.shape[1]
        # Pass through frozen LLM
        llm_out = self.llm_model(inputs_embeds=combined).last_hidden_state
        # Keep only the CSV positions and the first d_ff hidden channels
        csv_out = llm_out[:, total_prompt + total_pt:, :]
        csv_out = csv_out[:, :, :self.d_ff]
        csv_out = csv_out.view(-1, n_vars_csv, csv_out.size(-2), csv_out.size(-1))
        csv_out = csv_out.permute(0, 1, 3, 2).contiguous()
        csv_out = self.output_projection(
            csv_out[:, :, :, -self.patch_nums:])
        csv_out = csv_out.permute(0, 2, 1).contiguous()
        # De-normalize with the CSV statistics recorded above
        csv_out = self.normalize_layers(csv_out, 'denorm')
        return csv_out

    def calcute_lags(self, x_enc):
        """
        Return the ``top_k`` lags with the highest autocorrelation.

        The autocorrelation is computed through the FFT of ``x_enc`` and
        averaged over the channel dimension before the top-k selection.
        """
        spectrum = torch.fft.rfft(
            x_enc.permute(0, 2, 1).contiguous(), dim=-1)
        # Query and key are the same series, so one FFT is enough.
        res = spectrum * torch.conj(spectrum)
        corr = torch.fft.irfft(res, dim=-1)
        mean_value = torch.mean(corr, dim=1)
        _, lags = torch.topk(mean_value, self.top_k, dim=-1)
        return lags


#########################################
# Utility functions and simple implementations
#########################################
def del_files(path):
    """Remove every regular file directly inside `path`; subdirectories are left alone."""
    candidates = (os.path.join(path, name) for name in os.listdir(path))
    for file_path in filter(os.path.isfile, candidates):
        os.remove(file_path)

class EarlyStopping:
    """
    Track the best validation loss and signal when training should stop.

    A checkpoint is written whenever the loss is at least as good as the best
    one seen so far. After ``patience`` consecutive calls without such a loss,
    ``early_stop`` becomes True.
    """

    def __init__(self, accelerator, patience=10):
        """
        Args:
            accelerator: Stored on the instance; not used by this class itself.
            patience: Number of consecutive non-improving calls tolerated.
        """
        self.accelerator = accelerator
        self.patience = patience
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, loss, model, path):
        """Record ``loss``; save ``model`` into ``path`` when it is a new best."""
        # NaN compares unequal to itself. Without this check a NaN loss would
        # be taken as an improvement and overwrite the best checkpoint.
        is_nan = bool(loss != loss)
        improved = not is_nan and (self.best_loss is None or not loss > self.best_loss)
        if improved:
            self.best_loss = loss
            self.save_checkpoint(model, path)
            self.counter = 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

    def save_checkpoint(self, model, path):
        """Write the model's state dict to ``<path>/checkpoint.pt``, creating ``path`` if needed."""
        os.makedirs(path, exist_ok=True)
        torch.save(model.state_dict(), os.path.join(path, 'checkpoint.pt'))

def adjust_learning_rate(accelerator, optimizer, scheduler, epoch, args, printout=True):
    """
    Optionally report the current learning rate, then advance the scheduler.

    The rate printed is that of the first parameter group before the step.
    `epoch` and `args` are accepted but not used.
    """
    if printout:
        current_lr = optimizer.param_groups[0]['lr']
        accelerator.print(f"Current learning rate: {current_lr:.10f}")
    scheduler.step()

def vali(args, accelerator, model, data, data_loader, criterion, mae_metric):
    """
    Evaluate `model` on `data_loader` without gradients.

    Args:
        args: Namespace providing `pred_len` and `label_len`.
        accelerator: Object whose `device` attribute receives the batches.
        model: Module called as
            `model(batch_csv, batch_pt, batch_x_mark, dec_inp, batch_y_mark)`.
        data: Accepted for interface compatibility; not used.
        data_loader: Iterable of 5-tuples
            `(batch_csv, batch_pt, batch_y, batch_x_mark, batch_y_mark)`.
        criterion: Loss function applied to `(outputs, targets)`.
        mae_metric: Second metric applied to `(outputs, targets)`.

    Returns:
        `(mean_loss, mean_mae)` averaged over batches, or `(nan, nan)` when the
        loader yields no batch. The model's train/eval mode is restored on exit.
    """
    was_training = model.training
    model.eval()
    device = accelerator.device
    total_loss = 0.0
    total_mae = 0.0
    count = 0
    try:
        with torch.no_grad():
            for batch in data_loader:
                batch_csv, batch_pt, batch_y, batch_x_mark, batch_y_mark = (
                    tensor.float().to(device) for tensor in batch)
                # Decoder input: the known label segment followed by zeros
                # standing in for the steps to be predicted.
                dec_inp = torch.zeros_like(batch_y[:, -args.pred_len:, :]).float().to(device)
                dec_inp = torch.cat([batch_y[:, :args.label_len, :], dec_inp], dim=1)
                outputs = model(batch_csv, batch_pt, batch_x_mark, dec_inp, batch_y_mark)
                outputs = outputs[:, -args.pred_len:, 0:]
                batch_y = batch_y[:, -args.pred_len:, 0:]
                total_loss += criterion(outputs, batch_y).item()
                total_mae += mae_metric(outputs, batch_y).item()
                count += 1
    finally:
        # Do not leave the caller's model stuck in eval mode.
        model.train(was_training)
    if count == 0:
        # Empty loader (e.g. fewer samples than one batch with drop_last).
        return float('nan'), float('nan')
    return total_loss / count, total_mae / count

def load_content(args):
    """Return the prompt description text.

    ``args`` is accepted but not consulted: the same placeholder string is
    returned on every call.
    """
    content = "Example prompt content describing the domain."
    return content


#########################################
# Main training loop
#########################################
if __name__ == '__main__':
    # Script entry point: parse the command line, build the data loaders,
    # model, optimizer and scheduler, then train for ``--itr`` independent
    # runs with per-epoch validation/test evaluation and early stopping.

    def _str2bool(value):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string (including "False")
        into True, so the text is interpreted explicitly instead.
        """
        if isinstance(value, bool):
            return value
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            f"Expected a boolean value, got {value!r}")

    # NOTE(review): an empty CURL_CA_BUNDLE is commonly used to bypass TLS
    # certificate verification for downloads - confirm this is intended.
    os.environ['CURL_CA_BUNDLE'] = ''
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"

    parser = argparse.ArgumentParser(description='Time-LLM Time Series Forecasting')
    fix_seed = 2021
    random.seed(fix_seed)
    torch.manual_seed(fix_seed)
    np.random.seed(fix_seed)

    # Required arguments
    parser.add_argument('--task_name', type=str, default='long_term_forecast', help='Task name')
    parser.add_argument('--is_training', type=int, default=1, help='Training flag, 1 means training')
    parser.add_argument('--model_id', type=str, default='ETTh1_512_96', help='Model ID')
    parser.add_argument('--model_comment', type=str, default='TimeLLM-ETTh1', help='Model comment')
    parser.add_argument('--model', type=str, default='TimeLLM', help='Model name; only TimeLLM supported')

    # Data arguments
    parser.add_argument('--data', type=str, default='PKU_Skeleton_Renew_preprocessed_0002-L.csv', help='Dataset name (only ETTh1)')
    parser.add_argument('--root_path', type=str, default='/', help='Root directory for data')
    parser.add_argument('--data_path', type=str, default='PKU_Skeleton_Renew_preprocessed_0002-L.csv', help='Data file name')
    parser.add_argument('--freq', type=str, default='h', help='Time feature encoding frequency')
    parser.add_argument('--checkpoints', type=str, default='./checkpoints/', help='Directory to save checkpoints')
    parser.add_argument('--scale', type=_str2bool, default=True, help='Whether to standardize data')

    # Forecasting task arguments
    parser.add_argument('--seq_len', type=int, default=512, help='Input sequence length')
    parser.add_argument('--label_len', type=int, default=48, help='Label length')
    parser.add_argument('--pred_len', type=int, default=30, help='Prediction length')
    parser.add_argument('--seasonal_patterns', type=str, default='Monthly', help='Seasonal patterns')

    # Model architecture arguments
    parser.add_argument('--enc_in', type=int, default=75, help='Encoder input size')
    parser.add_argument('--dec_in', type=int, default=75, help='Decoder input size')
    parser.add_argument('--c_out', type=int, default=75, help='Output size')
    parser.add_argument('--d_model', type=int, default=32, help='Model dimension')
    parser.add_argument('--n_heads', type=int, default=8, help='Number of attention heads')
    parser.add_argument('--e_layers', type=int, default=2, help='Number of encoder layers')
    parser.add_argument('--d_layers', type=int, default=1, help='Number of decoder layers')
    parser.add_argument('--d_ff', type=int, default=128, help='Feedforward dimension')
    parser.add_argument('--moving_avg', type=int, default=25, help='Moving average window size')
    parser.add_argument('--factor', type=int, default=3, help='Attention factor')
    parser.add_argument('--dropout', type=float, default=0.1, help='Dropout ratio')
    parser.add_argument('--embed', type=str, default='timeF', help='Time feature encoding type')
    parser.add_argument('--activation', type=str, default='gelu', help='Activation function')
    parser.add_argument('--output_attention', action='store_true', help='Whether to output encoder attention')
    parser.add_argument('--patch_len', type=int, default=16, help='Patch length')
    parser.add_argument('--stride', type=int, default=8, help='Patch stride')
    parser.add_argument('--prompt_domain', type=int, default=0, help='Whether to use custom prompt domain')
    parser.add_argument('--llm_model', type=str, default='BERT', help='LLM model, fixed to BERT')
    parser.add_argument('--llm_dim', type=int, default=768, help='LLM hidden dimension')
    parser.add_argument('--llm_layers', type=int, default=12, help='LLM number of layers')

    # Optimization arguments
    parser.add_argument('--num_workers', type=int, default=10, help='Data loader workers')
    parser.add_argument('--itr', type=int, default=1, help='Number of experiments')
    parser.add_argument('--train_epochs', type=int, default=20, help='Number of training epochs')
    parser.add_argument('--align_epochs', type=int, default=10, help='Number of alignment epochs')
    parser.add_argument('--batch_size', type=int, default=24, help='Training batch size')
    parser.add_argument('--eval_batch_size', type=int, default=8, help='Evaluation batch size')
    parser.add_argument('--patience', type=int, default=10, help='Early stopping patience')
    parser.add_argument('--learning_rate', type=float, default=0.01, help='Learning rate')
    parser.add_argument('--des', type=str, default='Exp', help='Experiment description')
    parser.add_argument('--loss', type=str, default='MSE', help='Loss function')
    parser.add_argument('--lradj', type=str, default='type1', help='Learning rate adjustment method')
    parser.add_argument('--pct_start', type=float, default=0.2, help='OneCycleLR pct_start')
    parser.add_argument('--use_amp', action='store_true', default=False, help='Use mixed precision training')
    parser.add_argument('--percent', type=int, default=100, help='Percentage of data usage')

    # Use parse_known_args() to ignore extra Colab args (e.g., -f)
    args, _unknown = parser.parse_known_args()

    # Initialize Accelerator without DeepSpeed
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(kwargs_handlers=[ddp_kwargs])

    # On the local main process, mirror everything passed to
    # accelerator.print into a log file inside the checkpoint directory.
    log_f = None
    original_print = accelerator.print
    if accelerator.is_local_main_process:
        os.makedirs(args.checkpoints, exist_ok=True)
        log_file = os.path.join(args.checkpoints, "training_log.txt")
        log_f = open(log_file, "w", encoding="utf-8")

        def log_print(*print_args, **print_kwargs):
            # Parameter names deliberately differ from the parsed ``args``
            # namespace to avoid shadowing it.
            msg = " ".join(str(arg) for arg in print_args)
            original_print(*print_args, **print_kwargs)
            log_f.write(msg + "\n")
            log_f.flush()
        accelerator.print = log_print

    try:
        for ii in range(args.itr):
            setting = (
                f"{args.task_name}_{args.model_id}_{args.model}"
                f"_sl{args.seq_len}_ll{args.label_len}_pl{args.pred_len}"
                f"_dm{args.d_model}_nh{args.n_heads}_el{args.e_layers}"
                f"_dl{args.d_layers}_df{args.d_ff}_fc{args.factor}"
                f"_eb{args.embed}_{args.des}_{ii}"
            )

            train_data, train_loader = data_provider(args, 'train')
            vali_data, vali_loader = data_provider(args, 'val')
            test_data, test_loader = data_provider(args, 'test')

            # Set the prompt text before the model is built so the first run
            # sees the same ``args`` as every later run.
            # NOTE(review): whether Model reads args.content is not visible
            # here - confirm.
            args.content = load_content(args)

            if args.model == 'TimeLLM':
                model = Model(args).float()
            else:
                raise NotImplementedError("Only TimeLLM model is supported")

            path = os.path.join(args.checkpoints, setting + '-' + args.model_comment)
            if accelerator.is_local_main_process:
                # exist_ok avoids the check-then-create race.
                os.makedirs(path, exist_ok=True)

            time_now = time.time()
            train_steps = len(train_loader)
            early_stopping = EarlyStopping(accelerator=accelerator, patience=args.patience)

            # Only parameters that require gradients are optimized.
            trained_parameters = [p for p in model.parameters() if p.requires_grad]
            model_optim = optim.Adam(trained_parameters, lr=args.learning_rate)

            if args.lradj == 'COS':
                scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    model_optim, T_max=20, eta_min=1e-8)
            else:
                scheduler = lr_scheduler.OneCycleLR(
                    optimizer=model_optim,
                    steps_per_epoch=train_steps,
                    pct_start=args.pct_start,
                    epochs=args.train_epochs,
                    max_lr=args.learning_rate)

            criterion = nn.MSELoss()
            mae_metric = nn.L1Loss()

            # Let Accelerator wrap the loaders, model, optimizer and scheduler.
            train_loader, vali_loader, test_loader, model, model_optim, scheduler = accelerator.prepare(
                train_loader, vali_loader, test_loader, model, model_optim, scheduler)

            if args.use_amp:
                scaler = torch.cuda.amp.GradScaler()

            for epoch in range(args.train_epochs):
                iter_count = 0
                train_loss = []
                model.train()
                epoch_time = time.time()
                for i, (batch_csv, batch_pt, batch_y, batch_x_mark, batch_y_mark) in tqdm(enumerate(train_loader)):
                    iter_count += 1
                    model_optim.zero_grad()
                    batch_csv = batch_csv.float().to(accelerator.device)
                    batch_pt = batch_pt.float().to(accelerator.device)
                    batch_y = batch_y.float().to(accelerator.device)
                    batch_x_mark = batch_x_mark.float().to(accelerator.device)
                    batch_y_mark = batch_y_mark.float().to(accelerator.device)
                    # Decoder input: the label window followed by zeros for
                    # the horizon that has to be predicted.
                    dec_inp = torch.zeros_like(batch_y[:, -args.pred_len:, :]).float().to(accelerator.device)
                    dec_inp = torch.cat([batch_y[:, :args.label_len, :], dec_inp], dim=1)
                    if args.use_amp:
                        with torch.cuda.amp.autocast():
                            outputs = model(
                                batch_csv, batch_pt, batch_x_mark,
                                dec_inp, batch_y_mark)
                            outputs = outputs[:, -args.pred_len:, 0:]
                            batch_y = batch_y[:, -args.pred_len:, 0:]
                            loss = criterion(outputs, batch_y)
                            train_loss.append(loss.item())
                    else:
                        outputs = model(
                            batch_csv, batch_pt, batch_x_mark,
                            dec_inp, batch_y_mark)
                        outputs = outputs[:, -args.pred_len:, 0:]
                        batch_y = batch_y[:, -args.pred_len:, 0:]
                        loss = criterion(outputs, batch_y)
                        train_loss.append(loss.item())

                    if (i + 1) % 100 == 0:
                        accelerator.print(
                            f"\titers: {i+1}, epoch: {epoch+1} | "
                            f"loss: {loss.item():.7f}")
                        # Average seconds per iteration since the last report,
                        # extrapolated over the remaining iterations.
                        speed = (time.time() - time_now) / iter_count
                        left_time = speed * (
                            (args.train_epochs - epoch) * train_steps - i)
                        accelerator.print(
                            f'\tspeed: {speed:.4f}s/iter; remaining time: {left_time:.4f}s')
                        iter_count = 0
                        time_now = time.time()

                    if args.use_amp:
                        scaler.scale(loss).backward()
                        scaler.step(model_optim)
                        scaler.update()
                    else:
                        accelerator.backward(loss)
                        model_optim.step()

                    if args.lradj == 'TST':
                        # adjust_learning_rate() already advances the
                        # scheduler by one step; stepping again here would
                        # move it twice per batch and exhaust OneCycleLR's
                        # step budget halfway through training.
                        adjust_learning_rate(
                            accelerator, model_optim, scheduler,
                            epoch+1, args, printout=False)

                accelerator.print(
                    f"Epoch: {epoch+1} time: {time.time()-epoch_time:.2f}s")
                train_loss_avg = np.mean(train_loss)
                vali_loss, vali_mae_loss = vali(
                    args, accelerator, model, vali_data,
                    vali_loader, criterion, mae_metric)
                test_loss, test_mae_loss = vali(
                    args, accelerator, model, test_data,
                    test_loader, criterion, mae_metric)
                accelerator.print(
                    f"Epoch: {epoch+1} | Train Loss: {train_loss_avg:.7f} | "
                    f"Val Loss: {vali_loss:.7f} | Test Loss: {test_loss:.7f} | "
                    f"MAE Loss: {vali_mae_loss:.7f}")
                early_stopping(vali_loss, model, path)
                if early_stopping.early_stop:
                    accelerator.print("Early stopping triggered")
                    break

                if args.lradj != 'TST':
                    if args.lradj == 'COS':
                        scheduler.step()
                        accelerator.print(
                            f"Current learning rate: "
                            f"{model_optim.param_groups[0]['lr']:.10f}")
                    else:
                        if epoch == 0:
                            # args.learning_rate is intentionally left
                            # untouched: overwriting it with the scheduler's
                            # warm-up rate would shrink the base and peak
                            # rate of every later --itr run.
                            accelerator.print(
                                f"Current learning rate: "
                                f"{model_optim.param_groups[0]['lr']:.10f}")
                        # NOTE(review): in this branch the OneCycleLR schedule
                        # (sized for one step per batch) is advanced only once
                        # per epoch - confirm that this is intended.
                        adjust_learning_rate(
                            accelerator, model_optim, scheduler,
                            epoch+1, args, printout=True)
                else:
                    accelerator.print(
                        f"Updated learning rate: {scheduler.get_last_lr()[0]}")

            accelerator.wait_for_everyone()
            accelerator.print(
                'Training complete, all files retained.')
    finally:
        # Close the log even when training aborts, and stop routing prints
        # into the closed file.
        if log_f is not None:
            accelerator.print = original_print
            log_f.close()
