In [1]:
# Process-level configuration. Every variable below is written to os.environ
# before torch / transformers / datasets are imported further down in this cell.
import os
# NOTE(review): presumably points Hugging Face Hub downloads at a mirror — confirm it is reachable and intended.
os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
# NOTE(review): these two look like NCCL transport switches (peer-to-peer / InfiniBand) — confirm they are still needed.
os.environ['NCCL_P2P_DISABLE'] = '1'
os.environ['NCCL_IB_DISABLE'] = '1'
# Restrict the process to the GPU with index 0.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# NOTE(review): presumably silences the tokenizers fork/parallelism warning — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, Seq2SeqTrainer, EarlyStoppingCallback, get_scheduler
from transformers import GPT2Tokenizer, GPT2Model, TrainingArguments, Trainer
import torch
import warnings
import numpy as np
warnings.filterwarnings("ignore")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  

In [None]:
# Load the pre-built train/test archives; paths are relative to the working directory.
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can
# execute arbitrary code — only load .npz files that come from a trusted source.
ds_train = np.load("./solar_1_48_24/train_dataset.npz", allow_pickle=True)
ds_test = np.load("./solar_1_48_24/test_dataset.npz", allow_pickle=True)

In [3]:
# Pull every training array out of the archive in one pass.
# Each archive key is spelled exactly like the variable it is bound to.
(
    train_input,
    train_input_time,
    train_retrival,
    train_retrival_time,
    train_retrival_label,
    train_retrival_label_time,
    train_label,
    train_label_time,
) = (
    ds_train[key]
    for key in (
        'train_input',
        'train_input_time',
        'train_retrival',
        'train_retrival_time',
        'train_retrival_label',
        'train_retrival_label_time',
        'train_label',
        'train_label_time',
    )
)

In [4]:
# Pull every test array out of the archive in one pass.
# Each archive key is spelled exactly like the variable it is bound to.
(
    test_input,
    test_input_time,
    test_retrival,
    test_retrival_time,
    test_retrival_label,
    test_retrival_label_time,
    test_label,
    test_label_time,
) = (
    ds_test[key]
    for key in (
        'test_input',
        'test_input_time',
        'test_retrival',
        'test_retrival_time',
        'test_retrival_label',
        'test_retrival_label_time',
        'test_label',
        'test_label_time',
    )
)

In [None]:
import torch  
import numpy as np  

# Materialise every raw value array as a float32 tensor (torch.tensor copies the data).
(
    train_input_tensor,
    train_retrival_tensor,
    train_retrival_label_tensor,
    train_label_tensor,
    test_input_tensor,
    test_retrival_tensor,
    test_retrival_label_tensor,
    test_label_tensor,
) = (
    torch.tensor(raw, dtype=torch.float32)
    for raw in (
        train_input,
        train_retrival,
        train_retrival_label,
        train_label,
        test_input,
        test_retrival,
        test_retrival_label,
        test_label,
    )
)

# Report the shape of each tensor, one line per tensor.
print("\n".join(
    f"{label} shape: {tensor.shape}"
    for label, tensor in (
        ("train_input_tensor", train_input_tensor),
        ("train_retrival_tensor", train_retrival_tensor),
        ("train_retrival_label_tensor", train_retrival_label_tensor),
        ("train_label_tensor", train_label_tensor),
        ("test_input_tensor", test_input_tensor),
        ("test_retrival_tensor", test_retrival_tensor),
        ("test_retrival_label_tensor", test_retrival_label_tensor),
        ("test_label_tensor", test_label_tensor),
    )
))

train_input_tensor shape: torch.Size([5192, 48])
train_retrival_tensor shape: torch.Size([5192, 48])
train_retrival_label_tensor shape: torch.Size([5192, 24])
train_label_tensor shape: torch.Size([5192, 24])
test_input_tensor shape: torch.Size([3438, 48])
test_retrival_tensor shape: torch.Size([3438, 48])
test_retrival_label_tensor shape: torch.Size([3438, 24])
test_label_tensor shape: torch.Size([3438, 24])


In [6]:
from datetime import datetime
import torch.nn as nn
class TimeEncoder(nn.Module):
    """Turn epoch-second timestamps into six cyclical calendar features.

    For every timestamp the encoder emits, in this order:
    ``hour_sin, hour_cos, day_of_year_sin, day_of_year_cos, weekday_sin,
    weekday_cos``. The module has no learnable parameters.

    Args:
        d_model: Stored on the instance; it is not used by the encoding itself.
    """

    def __init__(self, d_model=768):
        super().__init__()
        self.d_model = d_model
        self.time_feature_dim = 6

    def encode_timestamps(self, timestamps):
        """Encode a ``(batch, seq_len)`` tensor of epoch seconds.

        Timestamps are decoded with ``datetime.fromtimestamp``, i.e. in the
        local timezone of the running process.

        Args:
            timestamps: 2-D tensor of epoch seconds (any numeric dtype).

        Returns:
            float32 tensor of shape ``(batch, seq_len, 6)`` on the same device
            as ``timestamps``.
        """
        batch_size, seq_len = timestamps.shape
        device = timestamps.device

        # Work in float64 and hand plain Python floats to datetime: elements of
        # a float32 array are np.float32, which is not a Python float and is
        # not reliably accepted by datetime.fromtimestamp.
        flat = timestamps.detach().cpu().numpy().astype(np.float64).reshape(-1)

        # Columns: hour of day, day of year, weekday (Monday == 0).
        calendar = np.empty((flat.shape[0], 3), dtype=np.float64)
        for index, stamp in enumerate(flat):
            dt = datetime.fromtimestamp(float(stamp))
            calendar[index, 0] = dt.hour
            calendar[index, 1] = dt.timetuple().tm_yday
            calendar[index, 2] = dt.weekday()

        # Period of each calendar field; 366 keeps leap-year day 366 in range.
        periods = np.array([24.0, 366.0, 7.0])
        angles = 2 * np.pi * calendar / periods

        # Interleave sin/cos so the layout is (sin, cos) per calendar field.
        features = np.empty((flat.shape[0], self.time_feature_dim), dtype=np.float64)
        features[:, 0::2] = np.sin(angles)
        features[:, 1::2] = np.cos(angles)

        # Build the result once and move it to the device in a single transfer
        # instead of creating one small tensor per element.
        return torch.tensor(
            features.reshape(batch_size, seq_len, self.time_feature_dim),
            dtype=torch.float32,
            device=device,
        )

    def forward(self, timestamps):
        """Return the ``(batch, seq_len, 6)`` calendar features for ``timestamps``."""
        return self.encode_timestamps(timestamps)
    


In [7]:
import torch
import torch.nn as nn
import numpy as np

class DualStreamPatchModel(nn.Module):
    """Fuse patch embeddings of a PV series with patch embeddings of its timestamps.

    Each stream is patch-embedded separately, the two streams attend to each
    other through a single shared multi-head attention module, and the two
    attended results are concatenated and projected back to ``d_model``.

    Args:
        patch_size: Number of time steps per patch.
        d_model: Embedding width of both streams and of the output.
        num_heads: Number of attention heads.
    """

    def __init__(self, patch_size=8, d_model=128, num_heads=8):
        super().__init__()
        self.patch_size = patch_size

        # One embedding stream per modality.
        self.pv_stream = PVPatchStream(patch_size, d_model)
        self.time_stream = TimePatchStream(patch_size, d_model)

        # The same attention weights serve both attention directions.
        self.cross_modal_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_heads,
            batch_first=True,
        )

        # Projects the concatenated (2 * d_model) features back to d_model.
        self.fusion_layer = nn.Linear(d_model * 2, d_model)

    def forward(self, pv_data, time_features):
        """Return fused patch features for one batch.

        Args:
            pv_data: Input forwarded to the PV patch stream.
            time_features: Input forwarded to the time patch stream.

        Returns:
            Output of the fusion layer applied to the concatenation of the two
            cross-attended streams.
        """
        pv_tokens = self.pv_stream(pv_data)
        time_tokens = self.time_stream(time_features)

        attend = self.cross_modal_attention
        # PV queries time, then time queries PV; attention weights are discarded.
        pv_context = attend(pv_tokens, time_tokens, time_tokens)[0]
        time_context = attend(time_tokens, pv_tokens, pv_tokens)[0]

        combined = torch.cat((pv_context, time_context), dim=-1)
        return self.fusion_layer(combined)

class PVPatchStream(nn.Module):
    """Split a 2-D series into fixed-size patches and embed each patch.

    Args:
        patch_size: Number of consecutive values per patch.
        d_model: Width of each patch embedding.
    """

    def __init__(self, patch_size, d_model):
        super().__init__()
        self.patch_size = patch_size
        # Linear map from one raw patch to a d_model vector.
        self.patch_embedding = nn.Linear(patch_size, d_model)
        # Learned positional table with room for 1000 patches.
        self.pos_embedding = nn.Parameter(torch.randn(1, 1000, d_model))

    def forward(self, pv_data):
        """Embed ``pv_data`` of shape ``(batch, seq_len)``.

        Trailing values that do not fill a whole patch are dropped.

        Returns:
            Tensor of shape ``(batch, seq_len // patch_size, d_model)``.
        """
        n_rows, n_steps = pv_data.shape
        patch_count = n_steps // self.patch_size
        usable_steps = patch_count * self.patch_size

        windows = pv_data[:, :usable_steps].reshape(n_rows, patch_count, self.patch_size)
        tokens = self.patch_embedding(windows)
        return tokens + self.pos_embedding[:, :patch_count, :]

class TimePatchStream(nn.Module):
    """Encode timestamps into calendar features, patch them and embed each patch.

    Args:
        patch_size: Number of consecutive time steps per patch.
        d_model: Width of each patch embedding.
    """

    def __init__(self, patch_size, d_model):
        super().__init__()
        self.patch_size = patch_size
        # Each patch is flattened to patch_size * 6 values before projection.
        self.patch_embedding = nn.Linear(patch_size * 6, d_model)
        # Learned positional table with room for 1000 patches.
        self.pos_embedding = nn.Parameter(torch.randn(1, 1000, d_model))
        self.time_encoder = TimeEncoder(d_model=d_model)

    def forward(self, time_features):
        """Embed a batch of timestamps.

        Args:
            time_features: Raw timestamps; they are passed through the time
                encoder first, which yields a 3-D feature tensor.

        Returns:
            Tensor of shape ``(batch, seq_len // patch_size, d_model)``.
        """
        encoded = self.time_encoder(time_features)
        n_rows, n_steps, n_feats = encoded.shape
        patch_count = n_steps // self.patch_size
        usable_steps = patch_count * self.patch_size

        # Flatten every patch (steps x features) into a single vector.
        windows = encoded[:, :usable_steps, :].reshape(
            n_rows, patch_count, self.patch_size * n_feats
        )
        tokens = self.patch_embedding(windows)
        return tokens + self.pos_embedding[:, :patch_count, :]
    


In [None]:
class SimpleTimeSeriesBranch(nn.Module):
    """Two-layer MLP that maps an input window directly to a forecast.

    Args:
        input_len: Length of the input window.
        pred_len: Length of the produced forecast.
    """

    def __init__(self, input_len=48, pred_len=24):
        super().__init__()
        hidden_width = pred_len * 2
        layers = [
            nn.Linear(input_len, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, pred_len),
        ]
        self.shortcut = nn.Sequential(*layers)

    def forward(self, hist_data):
        """Return the MLP output for ``hist_data`` (last dimension ``input_len``)."""
        forecast = self.shortcut(hist_data)
        return forecast
    
class ConstrainedReLU(nn.Module):
    """ReLU followed by a smooth tanh saturation at ``max_val``.

    Negative inputs map to 0; positive inputs approach ``max_val`` asymptotically.

    Args:
        max_val: Upper bound the output saturates towards.
    """

    def __init__(self, max_val=50):
        super().__init__()
        self.max_val = max_val

    def forward(self, x):
        """Apply ``max_val * tanh(relu(x) / max_val)`` element-wise."""
        ceiling = self.max_val
        rectified = x.relu()
        return ceiling * (rectified / ceiling).tanh()

In [None]:
from transformers import GPT2Model, GPT2Tokenizer

class LLMPVForecaster(nn.Module):
    """PV forecaster that feeds prompt text and numeric patches through GPT-2 blocks.

    The input sequence interleaves embedded text prompts with patch embeddings
    of the history, a retrieved similar window and that window's continuation.
    The hidden state at the last position is regressed to ``pred_len`` values
    and blended with a small MLP branch applied to ``retrieval_next``.

    Args:
        patch_len: Number of time steps per patch.
        pred_len: Number of values to predict.
        freeze_llm: If True, every GPT-2 parameter is frozen (including the
            embeddings of the newly added special tokens).
        residual_hidden_dim: Accepted for interface compatibility; unused.
    """

    # Prompts interleaved with the patch groups, in order. The first three are
    # each followed by one patch group; the last one closes the sequence.
    TEXT_PROMPTS = (
        "<hist> Historical PV power data sequence:",
        "<retrieval> Similar historical pattern found:",
        "<predict> Expected continuation pattern:",
        "Based on the above information, predict the next 24 hours PV output:",
    )

    def __init__(self, patch_len=8, pred_len=24, freeze_llm=True, residual_hidden_dim=256):
        super().__init__()
        self.patch_len = patch_len
        self.pred_len = pred_len

        self.tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.llm = GPT2Model.from_pretrained('gpt2')

        # Register the marker tokens used in the prompts and grow the embedding
        # table accordingly.
        special_tokens = ['<hist>', '<retrieval>', '<predict>', '<sep>', '<pad>']
        self.tokenizer.add_special_tokens({'additional_special_tokens': special_tokens})
        self.llm.resize_token_embeddings(len(self.tokenizer))

        if freeze_llm:
            for param in self.llm.parameters():
                param.requires_grad = False

        d_model = self.llm.config.hidden_size

        self.dual_stream_patch_embedding = DualStreamPatchModel(patch_len, d_model)

        # Learned positional table added to the interleaved sequence; the
        # sequence may not be longer than its 1000 positions.
        self.pos_embedding = nn.Parameter(torch.randn(1, 1000, d_model))

        self.llm_regression_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(d_model // 2, pred_len)
        )

        # NOTE: input_len is fixed at 24 and the branch is fed `retrieval_next`
        # in forward(), so `retrieval_next` must have 24 time steps.
        self.ts_branch = SimpleTimeSeriesBranch(input_len=24, pred_len=pred_len)

        # Learnable blend weights for the LLM head and the MLP branch.
        self.alpha = nn.Parameter(torch.tensor(0.8), requires_grad=True)
        self.beta = nn.Parameter(torch.tensor(0.2), requires_grad=True)
        # Only `final_activation` is used in forward(); the other two are kept
        # so the module's attributes stay unchanged.
        self.final_sigmoid = nn.Sigmoid()
        self.final_relu = nn.ReLU()
        self.final_activation = ConstrainedReLU(max_val=50)

    def _embed_prompt(self, prompt, batch_size, device):
        """Tokenize one prompt and return its embeddings as (batch, tokens, d_model)."""
        token_ids = self.tokenizer.encode(prompt, return_tensors='pt').to(device)
        # Every batch row gets the same prompt: embed once, then broadcast.
        return self.llm.wte(token_ids).expand(batch_size, -1, -1)

    def create_interleaved_input(self, hist_data, hist_timestamps, retrieval_data, retrieval_timestamps, retrieval_next, retrieval_next_timestamps):
        """Build the interleaved text/patch embedding sequence.

        Layout along the sequence axis:
        prompt 1, history patches, prompt 2, retrieval patches, prompt 3,
        retrieval-continuation patches, prompt 4.

        Returns:
            Tensor of shape ``(batch, total_len, d_model)``.
        """
        batch_size = hist_data.shape[0]
        device = hist_data.device

        patch_groups = (
            self.dual_stream_patch_embedding(pv_data=hist_data, time_features=hist_timestamps),
            self.dual_stream_patch_embedding(pv_data=retrieval_data, time_features=retrieval_timestamps),
            self.dual_stream_patch_embedding(pv_data=retrieval_next, time_features=retrieval_next_timestamps),
        )

        segments = []
        for index, prompt in enumerate(self.TEXT_PROMPTS):
            segments.append(self._embed_prompt(prompt, batch_size, device))
            # The closing prompt has no patch group after it.
            if index < len(patch_groups):
                segments.append(patch_groups[index])

        return torch.cat(segments, dim=1)

    def forward(self, hist_data, hist_timestamps, retrieval_data, retrieval_timestamps, retrieval_next, retrieval_next_timestamps):
        """Predict ``pred_len`` values for each batch row.

        Returns:
            Tensor of shape ``(batch, pred_len)`` after the final activation.

        Raises:
            ValueError: If the interleaved sequence is longer than the
                positional-embedding table.
        """
        all_embeddings = self.create_interleaved_input(hist_data, hist_timestamps, retrieval_data, retrieval_timestamps, retrieval_next, retrieval_next_timestamps)

        seq_len = all_embeddings.shape[1]
        max_len = self.pos_embedding.shape[1]
        if seq_len > max_len:
            raise ValueError(
                f"Interleaved sequence length {seq_len} exceeds the positional "
                f"embedding table size {max_len}"
            )
        # The (1, seq_len, d_model) slice broadcasts over the batch axis.
        hidden_states = all_embeddings + self.pos_embedding[:, :seq_len, :]

        # Run the transformer blocks directly (the model's own positional
        # embedding and final layer norm are not applied here).
        for block in self.llm.h:
            outputs = block(hidden_states)
            # A block may return a tuple (hidden_states, ...) or a bare tensor;
            # indexing a bare tensor with [0] would drop the batch axis.
            hidden_states = outputs[0] if isinstance(outputs, (tuple, list)) else outputs

        final_representation = hidden_states[:, -1, :]
        llm_predictions = self.llm_regression_head(final_representation)

        ts_predictions = self.ts_branch(retrieval_next)

        final_predictions = self.alpha * llm_predictions + self.beta * ts_predictions
        return self.final_activation(final_predictions)


In [None]:
def visualize_input_structure(model, hist_data, retrieval_data, retrieval_next,
                              hist_timestamps=None, retrieval_timestamps=None,
                              retrieval_next_timestamps=None):
    """Print the position layout of the model's interleaved input sequence.

    For every text prompt the occupied token positions are printed, and for
    each of the first three prompts the positions of the patch group that
    follows it. Patch counts are derived from the actual input lengths and
    ``model.patch_len``.

    Args:
        model: Object exposing ``tokenizer``, ``patch_len`` and
            ``create_interleaved_input``.
        hist_data: History values, ``(batch, hist_len)``.
        retrieval_data: Retrieved window values, ``(batch, retrieval_len)``.
        retrieval_next: Continuation of the retrieved window, ``(batch, next_len)``.
        hist_timestamps: Timestamps matching ``hist_data``.
        retrieval_timestamps: Timestamps matching ``retrieval_data``.
        retrieval_next_timestamps: Timestamps matching ``retrieval_next``.

    Returns:
        The interleaved embedding tensor produced by the model.

    Raises:
        TypeError: If any of the three timestamp tensors is missing; the
            model's ``create_interleaved_input`` needs all six inputs.
    """
    if hist_timestamps is None or retrieval_timestamps is None or retrieval_next_timestamps is None:
        raise TypeError(
            "visualize_input_structure() requires hist_timestamps, "
            "retrieval_timestamps and retrieval_next_timestamps"
        )

    all_embeddings = model.create_interleaved_input(
        hist_data, hist_timestamps,
        retrieval_data, retrieval_timestamps,
        retrieval_next, retrieval_next_timestamps,
    )

    text_prompts = [
        "<hist> Historical PV power data sequence:",
        "<retrieval> Similar historical pattern found:",
        "<predict> Expected continuation pattern:",
        "Based on the above information, predict the next 24 hours PV output:"
    ]

    # One patch group follows each of the first three prompts.
    patch_counts = [
        series.shape[1] // model.patch_len
        for series in (hist_data, retrieval_data, retrieval_next)
    ]

    current_pos = 0

    for i, prompt in enumerate(text_prompts):
        text_len = len(model.tokenizer.encode(prompt))

        print(f"position {current_pos}-{current_pos + text_len - 1}: text prompt{i+1} - '{prompt}'")
        current_pos += text_len

        if i < len(patch_counts):
            data_patches = patch_counts[i]
            print(f"position {current_pos}-{current_pos + data_patches - 1}: data patch{i+1} ({data_patches} patches)")
            current_pos += data_patches

    print(f"\ntotal length: {all_embeddings.shape[1]}")
    print(f"dimension of embeddings: {all_embeddings.shape[2]}")

    return all_embeddings

def _batch_loss(model, batch, criterion, device):
    """Move one 8-tensor batch to ``device``, run the model and return the loss."""
    (hist, hist_timestamps, retrieval, retrieval_timestamps,
     retrieval_next, retrieval_next_timestamps, target, _target_timestamps) = batch

    hist = hist.to(device).float()
    retrieval = retrieval.to(device).float()
    retrieval_next = retrieval_next.to(device).float()
    target = target.to(device).float()

    # Timestamps keep the dtype the loader provides: float32 cannot represent
    # present-day epoch seconds to better than about two minutes, so forcing
    # them to float32 here would discard precision.
    hist_timestamps = hist_timestamps.to(device)
    retrieval_timestamps = retrieval_timestamps.to(device)
    retrieval_next_timestamps = retrieval_next_timestamps.to(device)

    predictions = model(hist,
                        hist_timestamps,
                        retrieval,
                        retrieval_timestamps,
                        retrieval_next,
                        retrieval_next_timestamps)
    return criterion(predictions, target)


def train_model(model, train_loader, val_loader, num_epochs=100, lr=5e-5, 
                patience=10, min_delta=1e-6, save_path='./pth/solar_1_48_24_gpt2_resnet.pth'):
    """Train ``model`` with AdamW, cosine LR decay and early stopping.

    The state dict of the best epoch (lowest mean validation MSE) is written
    to ``save_path``. The model object itself is left in its last-epoch state.

    Args:
        model: Module called with six tensors (history, history timestamps,
            retrieval, retrieval timestamps, retrieval continuation and its
            timestamps).
        train_loader: Iterable of 8-tensor batches; must support ``len()``.
        val_loader: Iterable of 8-tensor batches; must support ``len()``.
        num_epochs: Maximum number of epochs.
        lr: Initial learning rate.
        patience: Epochs without improvement before stopping.
        min_delta: Minimum decrease of validation loss that counts as an
            improvement.
        save_path: File the best state dict is saved to; its parent directory
            is created if missing.

    Returns:
        The best mean validation loss as a float.

    Raises:
        ValueError: If either loader yields no batches.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each contain at least one batch")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Make sure the checkpoint directory exists before the first save.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # Optimise only the parameters that are not frozen.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr, weight_decay=1e-5)
    criterion = nn.MSELoss()
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    # Early-stopping state.
    best_val_loss = float('inf')
    epochs_without_improvement = 0
    early_stop = False
    epochs_run = 0

    print(f"training begin, patience: {patience} epochs")
    print(f"min_delta: {min_delta}")
    print("-" * 60)

    for epoch in range(num_epochs):
        epochs_run = epoch + 1

        model.train()
        train_loss = 0.0

        for batch_idx, batch in enumerate(train_loader):
            optimizer.zero_grad()

            loss = _batch_loss(model, batch, criterion, device)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=1.0)
            optimizer.step()

            train_loss += loss.item()

            if batch_idx % 100 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx}], Loss: {loss.item():.6f}')

        model.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_loader:
                val_loss += _batch_loss(model, batch, criterion, device).item()

        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)

        print(f'Epoch [{epoch+1}/{num_epochs}]')
        print(f'Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}')

        improvement = best_val_loss - avg_val_loss

        if improvement > min_delta:
            best_val_loss = avg_val_loss
            epochs_without_improvement = 0
            torch.save(model.state_dict(), save_path)
            print(f'best model saved, best_val_loss: {best_val_loss:.6f}')
        else:
            epochs_without_improvement += 1
            print(f'Loss does not improve ({epochs_without_improvement}/{patience})')

            if epochs_without_improvement >= patience:
                print(f'Our of patience at {patience}, early stopping...')
                early_stop = True

        scheduler.step()

        current_lr = optimizer.param_groups[0]['lr']
        print(f'learning rate: {current_lr:.2e}')
        print('-' * 60)

        # Leave the loop right away so `epochs_run` is the number of epochs
        # that were actually trained.
        if early_stop:
            print(f"early stopping in {epochs_run} epochs")
            break

    if early_stop:
        print(f"🛑 Training stopped early at epoch {epochs_run}")
    else:
        print(f"✅ Training completed for {num_epochs} epochs")

    print(f"🏆 Best validation loss: {best_val_loss:.6f}")

    print("📁 Best model saved as: " + save_path)

    return best_val_loss

In [11]:
# Build the forecaster with the backbone frozen and print its module tree.
model = LLMPVForecaster(patch_len=8, pred_len=24, freeze_llm=True)
print(model)

# Tally parameter counts in a single pass over the parameters.
param_stats = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_stats)
trainable_params = sum(count for count, needs_grad in param_stats if needs_grad)

for caption, amount in (
    ("Total parameters", total_params),
    ("Trainable parameters", trainable_params),
    ("Frozen parameters", total_params - trainable_params),
):
    print(f"{caption}: {amount:,}")

KeyboardInterrupt: 

In [None]:
import random

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch random generators with ``seed``.

    When CUDA is available the CUDA generators are seeded too and cuDNN is
    switched to deterministic, non-benchmarking mode. ``PYTHONHASHSEED`` is
    written to the environment last.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        cudnn = torch.backends.cudnn
        cudnn.deterministic = True
        cudnn.benchmark = False

    os.environ['PYTHONHASHSEED'] = f"{seed}"

# Seed every random generator once, before the data loaders are built and
# training is launched further down.
seed = 2025  
set_seed(seed)  

def worker_init_fn(worker_id):
    """Seed NumPy and ``random`` from the current torch seed.

    The torch seed is reduced modulo 2**32 before use. ``worker_id`` is
    accepted for the DataLoader callback signature and is not used.
    """
    derived_seed = torch.initial_seed() % (2 ** 32)
    random.seed(derived_seed)
    np.random.seed(derived_seed)

In [None]:
from torch.utils.data import DataLoader, TensorDataset

def create_data_loader(train_inputs, train_inputs_time, train_retrivals, train_retrivals_time, 
                      train_retrivals_next, train_retrivals_next_time, train_targets, train_targets_time,
                      batch_size=32, if_shuffle=True, worker_init_fn=None):
    """Bundle four value arrays and their timestamps into a DataLoader.

    Value arrays are copied into float32 tensors; timestamp inputs are
    converted with ``prepare_timestamp_data``. Each batch yields eight tensors
    in the order: inputs, inputs_time, retrivals, retrivals_time,
    retrivals_next, retrivals_next_time, targets, targets_time.

    Raises:
        ValueError: If the eight tensors do not share the same first dimension.
    """
    value_tensors = [
        torch.tensor(values, dtype=torch.float32)
        for values in (train_inputs, train_retrivals, train_retrivals_next, train_targets)
    ]
    time_tensors = [
        prepare_timestamp_data(stamps)
        for stamps in (train_inputs_time, train_retrivals_time,
                       train_retrivals_next_time, train_targets_time)
    ]

    # Interleave as (values, timestamps) pairs to get the batch order above.
    columns = [tensor for pair in zip(value_tensors, time_tensors) for tensor in pair]

    shapes = [column.shape[0] for column in columns]
    if any(size != shapes[0] for size in shapes):
        raise ValueError(f"Tensor sizes do not match! First dimension sizes: {shapes}")

    return DataLoader(
        TensorDataset(*columns),
        batch_size=batch_size,
        shuffle=if_shuffle,
        worker_init_fn=worker_init_fn,
    )

def prepare_timestamp_data(time_data):
    """Convert timestamps to a float64 tensor of epoch seconds.

    float64 is used on purpose: float32 has a 24-bit significand, so for
    present-day epoch values (around 1.7e9) it can only represent multiples
    of 128 seconds, which moves hourly timestamps across hour boundaries.

    Args:
        time_data: One of
            * a tensor — returned cast to float64;
            * a list or NumPy array of any shape whose elements are
              ``'%Y-%m-%d %H:%M:%S'`` strings, ``datetime`` objects or
              numbers — converted element by element, shape preserved;
            * anything else accepted by ``torch.tensor``.
            Strings and naive ``datetime`` objects are interpreted in the
            local timezone of the process (``datetime.timestamp``).

    Returns:
        float64 tensor with the same shape as the input.

    Raises:
        ValueError: If a string does not match ``'%Y-%m-%d %H:%M:%S'``.
    """
    def to_epoch_seconds(value):
        """Convert one element to epoch seconds as a Python float."""
        if isinstance(value, str):
            return datetime.strptime(value, '%Y-%m-%d %H:%M:%S').timestamp()
        if isinstance(value, datetime):
            return value.timestamp()
        return float(value)

    if isinstance(time_data, torch.Tensor):
        return time_data.double()

    if isinstance(time_data, (list, np.ndarray)):
        raw = np.asarray(time_data)
        # Flatten, convert element-wise, then restore the original shape so
        # that every dimensionality goes through the same code path.
        epoch = np.array(
            [to_epoch_seconds(value) for value in raw.reshape(-1)],
            dtype=np.float64,
        )
        return torch.tensor(epoch.reshape(raw.shape), dtype=torch.float64)

    return torch.tensor(time_data, dtype=torch.float64)

In [None]:
# Convert the raw time arrays to epoch-second tensors. This overwrites the
# original names, so re-running the cell feeds tensors (not the raw arrays)
# back into prepare_timestamp_data.
train_input_time = prepare_timestamp_data(train_input_time)
train_retrival_time = prepare_timestamp_data(train_retrival_time)
train_retrival_label_time = prepare_timestamp_data(train_retrival_label_time)
train_label_time = prepare_timestamp_data(train_label_time)

test_input_time = prepare_timestamp_data(test_input_time)
test_retrival_time = prepare_timestamp_data(test_retrival_time)
test_retrival_label_time = prepare_timestamp_data(test_retrival_label_time)
test_label_time = prepare_timestamp_data(test_label_time)

# Shuffled loader over the training split. create_data_loader runs
# prepare_timestamp_data on the time inputs again, so the conversions above
# are not strictly required for this call.
train_loader = create_data_loader(train_input, 
                                  train_input_time,
                                  train_retrival, 
                                  train_retrival_time,
                                  train_retrival_label, 
                                  train_retrival_label_time,
                                  train_label, 
                                  train_label_time,
                                  batch_size=32,
                                  if_shuffle=True,
                                  worker_init_fn=worker_init_fn)

# NOTE(review): the validation loader is built from the *test* arrays, so
# early stopping and best-checkpoint selection in train_model are driven by
# test data — confirm this is intended or hold out a separate validation split.
val_loader = create_data_loader(test_input, 
                                test_input_time,
                                test_retrival, 
                                test_retrival_time,
                                test_retrival_label, 
                                test_retrival_label_time,
                                test_label, 
                                test_label_time,
                                batch_size=32,
                                if_shuffle=False,
                                worker_init_fn=worker_init_fn) 


# Checkpoint file for the best state dict.
# NOTE(review): assumes the ./pth directory is available — confirm.
save_path = './pth/github_1.pth'

# Train for at most 1000 epochs; patience and min_delta use train_model's defaults.
train_model(model, 
            train_loader, 
            val_loader,
            lr=1e-4, 
            num_epochs=1000,
            save_path=save_path,)


开始训练，早停容忍度: 10 epochs
最小改善阈值: 1e-06
------------------------------------------------------------
hidden_states shape: torch.Size([32, 48, 768])


TypeError: int() argument must be a string, a bytes-like object or a number, not 'list'