In [88]:
from transformers import AutoformerConfig, AutoformerModel
import os
import sys
sys.path.insert(1, '../src/')
from config import raw_data_path, univariate_data_path, processed_data_path, models_path
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import AutoformerConfig, AutoformerForPrediction
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split


In [89]:
data_file = os.path.join(univariate_data_path, 'merged_univariate.npy')
data = np.load(data_file, allow_pickle=True)
print(data)

[{'record_name': 'ice001_l_1of1', 'signal': array([[-1.7358303 ],
        [-0.30347557],
        [-0.40749874],
        ...,
        [-3.09738299],
        [-2.90981482],
        [-3.22768386]]), 'metadata': {'fs': 20, 'sig_len': 100000, 'n_sig': 16, 'base_date': None, 'base_time': None, 'units': ['mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV', 'mV'], 'comments': ['Info:', 'ID:ice001', 'Record type:labour', 'Record number:1/1', 'Age(years):31', 'BMI before pregnancy:23.3', 'BMI at recording:27.6', 'Gravidity:3', 'Parity:2', 'Previous caesarean:No', 'Placental position:Fundus', 'Gestational age at recording(w/d):39/3', 'Gestational age at delivery:39/3', 'Mode of delivery:Vaginal', 'Synthetic oxytocin use in labour:No', 'Epidural during labour:No', 'Comments for recording:', 'Electrodes placed 5-10 mins prior to beginning of recording.', 'Baby born 20 minutes after the end of the recording.']}}
 {'record_name': 'ice002_p_1of3', 'signal': array([

In [90]:
def create_windows(time_series, window_size, forecast_horizon, stride=1):
    """
    Create sliding windows for a single time series.

    Parameters:
    - time_series: A numpy array of shape [sequence_length, 1]
    - window_size: The size of each input window
    - forecast_horizon: The number of steps ahead to forecast
    - stride: The step size for sliding windows

    Returns:
    - windows: The input windows
    - forecasts: The target forecasts
    """
    windows = []
    forecasts = []

    # Ensure that time_series is a 1D array
    time_series = time_series.squeeze()

    # Slide the window across the time series
    for start in range(0, len(time_series) - window_size - forecast_horizon + 1, stride):
        end = start + window_size
        window = time_series[start:end]  # Extract the window
        forecast = time_series[end:end + forecast_horizon]  # Forecast horizon
        
        windows.append(window)
        forecasts.append(forecast)

    return torch.tensor(windows), torch.tensor(forecasts)

def create_all_windows(data, window_size, forecast_horizon, stride=1):
    """
    Create sliding windows for all instances in the dataset.

    Parameters:
    - data: List of time series instances, where each instance is a dictionary with keys:
      - 'record_name': The record name
      - 'signal': A numpy array of shape [sequence_length, 1]
    - window_size: The size of each input window
    - forecast_horizon: The number of steps ahead to forecast
    - stride: The step size for sliding windows

    Returns:
    - all_windows: List of input windows
    - all_forecasts: List of target forecasts
    """
    all_windows = []
    all_forecasts = []
    
    for instance in data:
        time_series = instance['signal']  # shape [sequence_length, 1]
        windows, forecasts = create_windows(time_series, window_size, forecast_horizon, stride)
        
        all_windows.append(windows)
        all_forecasts.append(forecasts)
    
    # Stack the results
    all_windows = torch.cat(all_windows, dim=0)
    all_forecasts = torch.cat(all_forecasts, dim=0)
    
    return all_windows, all_forecasts

def split_data(windows, forecasts, train_size=0.7, val_size=0.15, test_size=0.15):
    assert train_size + val_size + test_size == 1, "Sizes should add up to 1"
    
    # Random split based on indices
    total_size = len(windows)
    train_end = int(train_size * total_size)
    val_end = train_end + int(val_size * total_size)
    
    X_train, X_val, X_test = windows[:train_end], windows[train_end:val_end], windows[val_end:]
    y_train, y_val, y_test = forecasts[:train_end], forecasts[train_end:val_end], forecasts[val_end:]
    
    return X_train, X_val, X_test, y_train, y_val, y_test

class TimeSeriesDataset(Dataset):
    def __init__(self, windows, forecasts):
        self.windows = windows
        self.forecasts = forecasts

    def __len__(self):
        return len(self.windows)

    def __getitem__(self, idx):
        return self.windows[idx], self.forecasts[idx]

# Define parameters
window_size = 12000  
forecast_horizon = 1200  
stride = 12000  

# Create sliding windows for all instances
all_windows, all_forecasts = create_all_windows(data, window_size, forecast_horizon, stride)

# Split data into train, validation, and test sets
X_train, X_val, X_test, y_train, y_val, y_test = split_data(all_windows, all_forecasts)
print('train size: ', len(X_train))
print('test size: ', len(X_test))
print('val size: ', len(X_val))
print('input shape: ', X_train[0].shape)
print('output shape: ', y_train[0].shape)


# Create datasets and DataLoaders
train_loader = torch.utils.data.DataLoader(TimeSeriesDataset(X_train, y_train), batch_size=64, shuffle=True)
val_loader = torch.utils.data.DataLoader(TimeSeriesDataset(X_val, y_val), batch_size=64, shuffle=False)
test_loader = torch.utils.data.DataLoader(TimeSeriesDataset(X_test, y_test), batch_size=64, shuffle=False)


train size:  1069
test size:  230
val size:  229
input shape:  torch.Size([12000])
output shape:  torch.Size([1200])


In [91]:


# class DecompositionBlock(nn.Module):
#     def __init__(self, window_size):
#         super(DecompositionBlock, self).__init__()
#         self.window_size = window_size
        
#         # Convolution layer to compute the moving average (smoothing) for trend extraction
#         self.avg_pool = nn.AvgPool1d(kernel_size=window_size, stride=1, padding=window_size // 2)

#     def forward(self, x):
#         batch_size, time_steps = x.shape  # x should have shape [batch_size, time_steps]
        
#         # Apply the moving average to extract the trend
#         trend = self.avg_pool(x.unsqueeze(1)).squeeze(1)  # Add and remove the extra channel dimension
        
#         # Trim the trend to match the original signal length (if necessary)
#         trend = trend[:, :time_steps]
        
#         # Compute the seasonality as the difference between the original signal and the trend
#         seasonality = x - trend
        
#         return trend, seasonality


In [92]:
import torch
import torch.nn as nn

class DecompositionLayer(nn.Module):
    """
    Decomposes the time series into its trend and seasonal parts.
    """

    def __init__(self, kernel_size):
        super(DecompositionLayer, self).__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=1, padding=0)  # Moving average

    def forward(self, x):
        """Input shape: Batch x Time x EMBED_DIM"""
        # Padding on both ends of the time series
        num_of_pads = (self.kernel_size - 1) // 2
        front = x[:, 0:1, :].repeat(1, num_of_pads, 1)
        end = x[:, -1:, :].repeat(1, num_of_pads, 1)
        x_padded = torch.cat([front, x, end], dim=1)

        # Calculate the trend and seasonal part of the series
        x_trend = self.avg(x_padded.permute(0, 2, 1)).permute(0, 2, 1)  # Moving average for trend
        x_seasonal = x - x_trend  # Seasonality is the residual
        return x_seasonal, x_trend


In [93]:
# # Instantiate the decomposition layer
# kernel_size = 25  # Feel free to tweak
# decomp = DecompositionLayer(kernel_size=kernel_size)

# # Select a sample from the training set and reshape to [1, seq_len, 1]
# sample = X_train[0].unsqueeze(0).unsqueeze(-1)  # Shape: [1, 12000, 1]

# # Pass through decomposition
# seasonal, trend = decomp(sample)

# # Convert tensors to numpy arrays
# original_signal = sample.squeeze().detach().cpu().numpy()
# trend_signal = trend.squeeze().detach().cpu().numpy()
# seasonal_signal = seasonal.squeeze().detach().cpu().numpy()

# # Plot each component
# plt.figure(figsize=(14, 8))

# plt.subplot(3, 1, 1)
# plt.plot(original_signal, label='Original Signal', color='tab:blue')
# plt.title("Original Signal")
# plt.ylabel("Amplitude")
# plt.grid(True)

# plt.subplot(3, 1, 2)
# plt.plot(trend_signal, label='Trend (Moving Avg)', color='tab:green')
# plt.title("Trend Component")
# plt.ylabel("Amplitude")
# plt.grid(True)

# plt.subplot(3, 1, 3)
# plt.plot(seasonal_signal, label='Seasonal Component', color='tab:orange')
# plt.title("Seasonality Component")
# plt.xlabel("Time")
# plt.ylabel("Amplitude")
# plt.grid(True)

# plt.tight_layout()
# plt.show()



In [4]:
# def autocorrelation(query_states, key_states):
#     """
#     Compute frequency-domain autocorrelation as a replacement for QK^T in attention.
    
#     Args:
#         query_states: Tensor of shape [batch_size, seq_len, embed_dim]
#         key_states: Tensor of shape [batch_size, seq_len, embed_dim]
#         normalize: Whether to normalize the autocorrelation output (default: True)

#     Returns:
#         attn_weights: Tensor of shape [batch_size, seq_len, embed_dim]
#     """
#     # Perform FFT along time dimension
#     query_fft = torch.fft.rfft(query_states, dim=1)
#     key_fft = torch.fft.rfft(key_states, dim=1)

#     # Compute element-wise product in frequency domain with conjugate
#     freq_corr = query_fft * torch.conj(key_fft)

#     # Inverse FFT to go back to time domain
#     attn_weights = torch.fft.irfft(freq_corr, n=query_states.shape[1], dim=1)

#     # Optionally added as parameter passed to function
#     # if normalize:
#     #     attn_weights = attn_weights / (query_states.shape[1] ** 0.5)

#     return attn_weights


import torch
import torch.nn as nn
import torch
import torch.nn as nn
import torch.fft

import torch
import torch.nn as nn
import torch
import torch.nn as nn

class AutoCorrelation(nn.Module):
    def __init__(self, k):
        super(AutoCorrelation, self).__init__()
        self.k = k

    def forward(self, queries, keys, values, attn_mask=None):
        """
        Forward pass for the AutoCorrelation layer, which performs a kind of autocorrelation
        operation using shifts over the time dimension (dimension 1).
        
        Args:
            queries: Tensor of shape (batch_size, seq_len, num_features)
            keys: Tensor of shape (batch_size, seq_len, num_features)
            values: Tensor of shape (batch_size, seq_len, num_features)
            attn_mask: Optional tensor for attention mask (not used in this code)
        
        Returns:
            Tensor of shape (batch_size, seq_len, num_features)
        """
        # Ensure that queries, keys, and values are all the same shape
        assert queries.shape == keys.shape == values.shape, "Queries, keys, and values must have the same shape."

        batch_size, seq_len, num_features = queries.shape

        # Initialize an output tensor (same shape as queries)
        output = torch.zeros_like(queries)

        for i in range(self.k):
            # Randomly select a shift value for each batch (an integer between 1 and seq_len)
            shifts = torch.randint(1, seq_len, (batch_size, 1)).to(queries.device)  # Each sample in batch has its own shift
            shifts = shifts.squeeze(1)  # Remove the second dimension (size 1)
            
            # Perform the rolling shift operation on queries, keys, and values
            shifted_queries = torch.stack([torch.roll(q, shift, dims=1) for q, shift in zip(queries, shifts)])
            shifted_keys = torch.stack([torch.roll(k, shift, dims=1) for k, shift in zip(keys, shifts)])
            shifted_values = torch.stack([torch.roll(v, shift, dims=1) for v, shift in zip(values, shifts)])

            # Compute correlation (here I will sum the shifted values as a placeholder)
            output += shifted_values

        # Optionally apply attention mask or other operations
        if attn_mask is not None:
            output = output * attn_mask

        return output



# Test example (you can ignore this in production code)
if __name__ == "__main__":
    queries = torch.randn(32, 100, 64)  # [batch_size, time_length, d_model]
    keys = torch.randn(32, 100, 64)  # Same shape as queries
    values = torch.randn(32, 100, 64)  # Same shape as queries
    
    auto_corr = AutoCorrelation(k=3)
    output = auto_corr(queries, keys, values)
    print(output.shape)  # Should print [32, 100, 64] (same shape as values)




RuntimeError: `shifts` required

In [95]:
# import torch
# import matplotlib.pyplot as plt

# # 1. Pick one EHG forecasting window
# sample_signal = X_train[0]  # shape: (window_size,)
# seq_len = sample_signal.shape[0]

# # 2. Convert to torch tensor and reshape to [batch_size, seq_len, embed_dim]
# signal_tensor = torch.tensor(sample_signal, dtype=torch.float32).unsqueeze(0).unsqueeze(-1)  # [1, seq_len, 1]


# # 4. Compute autocorrelation
# autocorr_output = autocorrelation(signal_tensor, signal_tensor).squeeze().detach()  # shape: (seq_len,)
# print('autocorr output:', autocorr_output.shape)
# # 5. Plot
# plt.figure(figsize=(12, 5))

# plt.subplot(2, 1, 1)
# plt.plot(sample_signal, label='EHG Signal')
# plt.title('EHG Forecasting Window')
# plt.grid(True)

# plt.subplot(2, 1, 2)
# plt.plot(autocorr_output, label='Autocorrelation')
# plt.title('Autocorrelation of EHG Signal')
# plt.grid(True)

# plt.tight_layout()
# plt.show()


In [96]:
import torch
import math

def time_delay_aggregation(attn_weights, value_states, autocorrelation_factor=2):
    """
    Aggregates value_states using autocorrelations via time-delay rolling.
    Replaces the dot-product attention step in Transformers.
    
    attn_weights: [batch_size, time_length]
    value_states: [batch_size, time_length, embed_dim]
    autocorrelation_factor: scaling factor for the time delay window size
    
    Returns:
        attn_output: Aggregated tensor of the same shape as input [batch_size, time_length, embed_dim]
    """
    bsz, time_length = attn_weights.shape  # Now only batch_size and time_length
    _, _, embed_dim = value_states.shape  # Get embed_dim from value_states
    
    # Step 1: Mean over embedding dim (attn_weights is already [batch_size, time_length], no need for mean)
    autocorr_scores = attn_weights  # Using the weights as is
    
    # Step 2: Get top-k time delays per batch element (using autocorrelation_factor)
    top_k = int(autocorrelation_factor * math.log(time_length))
    
    # Step 3: Apply time-delay aggregation (this logic will need to be adjusted based on your task)
    # For now, let's assume we simply weight the value states using autocorrelation scores
    weighted_values = value_states * autocorr_scores.unsqueeze(-1)  # Shape: [batch_size, time_length, embed_dim]
    
    # Step 4: Aggregate the weighted values over the time axis (e.g., sum or average)
    attn_output = weighted_values.sum(dim=1)  # Sum over time dimension (axis 1)
    
    return attn_output


In [97]:
import torch.nn as nn

class SeriesEmbedding(nn.Module):
    def __init__(self, d_model):
        super(SeriesEmbedding, self).__init__()
        self.linear = nn.Linear(1, d_model)

    def forward(self, x):
        # x: [B, T] → [B, T, 1]
        x = x.unsqueeze(-1)
        return self.linear(x)  # [B, T, d_model]


In [98]:
# class AutoformerEncoderLayer(nn.Module):
#     def __init__(self, d_model, kernel_size):
#         super(AutoformerEncoderLayer, self).__init__()
#         self.decomp = DecompositionLayer(kernel_size)
#         self.proj = nn.Linear(d_model, d_model)
#         self.norm = nn.LayerNorm(d_model)
#         self.activation = nn.ReLU()

#     def forward(self, x):
#         # x: [B, T, d_model]
#         seasonal, trend = self.decomp(x)
#         attn_weights = autocorrelation(seasonal, seasonal)
#         seasonal = time_delay_aggregation(attn_weights.mean(dim=-1), seasonal)
#         seasonal = self.proj(seasonal)
#         seasonal = self.norm(seasonal)
#         seasonal = self.activation(seasonal)
#         return seasonal + trend  # Residual connection

# class AutoformerDecoderLayer(nn.Module):
#     def __init__(self, d_model, kernel_size):
#         super(AutoformerDecoderLayer, self).__init__()
#         self.decomp = DecompositionLayer(kernel_size)
#         self.proj = nn.Linear(d_model, d_model)
#         self.norm = nn.LayerNorm(d_model)
#         self.activation = nn.ReLU()

#     def forward(self, x, enc_output):
#         # x: [B, T_out, d_model], enc_output: [B, T_in, d_model]
#         seasonal, trend = self.decomp(x)
#         attn_weights = autocorrelation(seasonal, enc_output)
#         seasonal = time_delay_aggregation(attn_weights.mean(dim=-1), seasonal)
#         seasonal = self.proj(seasonal)
#         seasonal = self.norm(seasonal)
#         seasonal = self.activation(seasonal)
#         return seasonal + trend


In [1]:
import torch.nn as nn

class EncoderLayer(nn.Module):
    """
    Autoformer Encoder Layer with integrated AutoCorrelation and SeriesDecomp (DecompositionLayer).
    """

    def __init__(self, d_model, kernel_size=25, dropout=0.1):
        super(EncoderLayer, self).__init__()

        self.auto_corr = AutoCorrelation()                    # Our custom autocorrelation module
        self.series_decomp = DecompositionLayer(kernel_size) # Our custom decomposition module

        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        # === Step 1: Auto-Correlation + Residual ===
        residual = x
        x = self.auto_corr(x, x)
        x = x + residual

        # === Step 2: Series Decomposition (1st time) ===
        seasonal_part, _ = self.series_decomp(x)

        # === Step 3: FeedForward + Residual ===
        residual = seasonal_part
        x = self.feedforward(seasonal_part)
        x = x + residual

        # === Step 4: Series Decomposition (2nd time) ===
        seasonal_part, _ = self.series_decomp(x)

        return seasonal_part


In [None]:
class Encoder(nn.Module):
    def __init__(self, layers, d_model, kernel_size):
        """
        Args:
            layers: number of encoder layers
            d_model: embedding dimension
            kernel_size: for decomposition
        """
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([
            EncoderLayer(d_model, kernel_size)
            for _ in range(layers)
        ])

    def forward(self, x):
        """
        Args:
            x: input time series [B, T, D]
        Returns:
            x: output from the final encoder layer [B, T, D]
        """
        for layer in self.layers:
            x = layer(x)
        return x


In [None]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model, kernel_size, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.auto_corr = AutoCorrelation()
        self.series_decomp = DecompositionLayer(kernel_size)
        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model),
        )
        # Trend projectors Wl,1, Wl,2, Wl,3
        self.proj_trend1 = nn.Linear(d_model, d_model)
        self.proj_trend2 = nn.Linear(d_model, d_model)
        self.proj_trend3 = nn.Linear(d_model, d_model)

    def forward(self, x, cross, trend):
        # x: seasonal component from previous layer (or embedding) [B, T, D]
        # cross: encoder output (X_N_en) [B, T_encoder, D]
        # trend: accumulated trend (X_det) [B, T, D]

        # 1. Self AutoCorrelation
        x1 = self.auto_corr(x, x) + x
        s1, t1 = self.series_decomp(x1)

        # 2. Cross AutoCorrelation (seasonal + encoder)
        x2 = self.auto_corr(s1, cross) + s1
        s2, t2 = self.series_decomp(x2)

        # 3. FeedForward + Decomp
        x3 = self.feedforward(s2) + s2
        s3, t3 = self.series_decomp(x3)

        # 4. Accumulate trend
        trend = trend + self.proj_trend1(t1) + self.proj_trend2(t2) + self.proj_trend3(t3)

        return s3, trend


In [None]:
class Decoder(nn.Module):
    def __init__(self, layers, d_model, kernel_size, output_dim):
        """
        Args:
            layers: number of decoder layers
            d_model: embedding dimension
            kernel_size: for decomposition
            output_dim: number of output features per timestep
        """
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([
            DecoderLayer(d_model, kernel_size)
            for _ in range(layers)
        ])
        self.projection = nn.Linear(d_model, output_dim)  # WS: final projection for seasonal component

    def forward(self, seasonal_init, trend_init, cross):
        """
        Args:
            seasonal_init: X_des (seasonal decoder input) [B, T, D]
            trend_init: X_det (trend decoder input) [B, T, D]
            cross: encoder output [B, T_enc, D]
        Returns:
            Final prediction: [B, T, output_dim]
        """
        seasonal, trend = seasonal_init, trend_init

        for layer in self.layers:
            seasonal, trend = layer(seasonal, cross, trend)

        # Final output: projected seasonal + trend
        output = self.projection(seasonal) + trend
        return output


In [104]:
import torch
import torch.nn as nn

# class Autoformer(nn.Module):
#     def __init__(self, input_dim, kernel_size=25):  # Add kernel_size to the constructor
#         super(Autoformer, self).__init__()
#         self.kernel_size = kernel_size  # Define the kernel_size attribute
        
#         # Example layers
#         self.avg = nn.AvgPool1d(kernel_size=self.kernel_size, stride=1, padding=self.kernel_size // 2)
#         # Add other layers here as needed

#     def forward(self, x):
#         """Input shape: Batch x Time x EMBED_DIM"""
#         print(x.shape)  # This will show the shape of your input tensor

#         num_of_pads = (self.kernel_size - 1) // 2
#         front = x[:, 0:1, :].repeat(1, num_of_pads, 1)
#         end = x[:, -1:, :].repeat(1, num_of_pads, 1)
#         x_padded = torch.cat([front, x, end], dim=1)

#         # Calculate the trend part of the series
#         x_trend = self.avg(x_padded.permute(0, 2, 1)).permute(0, 2, 1)  # Moving average for trend

#         # Ensure x_trend is the same size as x by removing extra padding
#         trend = x_trend[:, num_of_pads:-num_of_pads, :]  # Remove padding from the trend tensor

#         # Calculate the seasonal part of the series
#         x_seasonal = x - trend  # Seasonality is the residual

#         # Ensure seasonal and trend have the same size
#         seasonal = x_seasonal[:, num_of_pads:-num_of_pads, :]  # Remove padding from seasonal tensor

#         return seasonal, trend

class Autoformer(nn.Module):
    def __init__(self, input_dim, kernel_size, d_model=64, n_layers=3):
        super(Autoformer, self).__init__()
        
        self.input_dim = input_dim
        self.kernel_size = kernel_size
        self.d_model = d_model
        self.n_layers = n_layers

        # Initialize encoder and decoder layers
        self.encoder_layers = nn.ModuleList([AutoformerEncoderLayer(d_model, kernel_size) for _ in range(n_layers)])
        self.decoder_layers = nn.ModuleList([AutoformerDecoderLayer(d_model, kernel_size) for _ in range(n_layers)])

        # Final linear layer to transform the output to the desired shape
        self.fc = nn.Linear(d_model, input_dim)

        # Decomposition layer
        self.decomp = DecompositionLayer(kernel_size)

    def forward(self, x):
        # Decompose the input into seasonal and trend components
        seasonal, trend = self.decomp(x)

        # Process the seasonal component through encoder layers
        for encoder in self.encoder_layers:
            seasonal = encoder(seasonal)

        # You can modify the decoder logic as needed, but let's assume you're passing the seasonal 
        # through the decoder layers with the trend as the additional input
        for decoder in self.decoder_layers:
            seasonal = decoder(seasonal, trend)

        # Combine the processed seasonal with the trend component (or modify as needed)
        output = seasonal + trend  # Both should have the same shape now

        # Apply the final fully connected layer
        output = self.fc(output)
        
        return output



In [105]:
criterion = nn.MSELoss()
print(list(model.parameters()))

model = Autoformer(input_dim=64, kernel_size=25).to(device)

# Check if the model has parameters
print("Model Parameters:", list(model.parameters()))  # This should print the model parameters

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# Assuming 'device' is already defined as:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model to the correct device
model = Autoformer(input_dim=64, kernel_size=25).to(device)

# Training loop
for batch_x, batch_y in train_loader:
    # Move data to the same device as the model
    batch_x, batch_y = batch_x.to(device).float(), batch_y.to(device).float()

    # Forward pass
    preds = model(batch_x)
    
    # Compute loss
    loss = criterion(preds, batch_y)
    
    # Backward pass and optimization
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()



[Parameter containing:
tensor([[ 0.0868,  0.0213, -0.0896,  ..., -0.0554,  0.0406, -0.0390],
        [ 0.0192, -0.0398, -0.0659,  ..., -0.0083,  0.0142, -0.0855],
        [ 0.0869, -0.0775, -0.1037,  ..., -0.0465, -0.0637,  0.0893],
        ...,
        [-0.1094,  0.0942,  0.0096,  ..., -0.1063,  0.0665,  0.0787],
        [-0.0731,  0.0343,  0.0951,  ..., -0.0111,  0.0590,  0.0850],
        [ 0.0981,  0.0497,  0.0061,  ...,  0.0418, -0.0491,  0.0660]],
       device='cuda:0', requires_grad=True), Parameter containing:
tensor([-0.0029,  0.0576,  0.0928,  0.0440,  0.0950, -0.0924,  0.0455,  0.0524,
        -0.0065,  0.0770, -0.0506,  0.0709, -0.1112,  0.0852,  0.0180,  0.0126,
         0.1152,  0.0557,  0.0309, -0.0612,  0.0796, -0.0143,  0.0452,  0.0329,
         0.0317,  0.0471,  0.0175,  0.1045, -0.0781,  0.0455, -0.0973,  0.0450,
         0.0340, -0.0483, -0.0679, -0.0220, -0.1059, -0.1185, -0.0588, -0.0251,
         0.0361,  0.1133,  0.0602, -0.0113,  0.0609,  0.0264,  0.0334, -0.10

IndexError: too many indices for tensor of dimension 2