In [11]:
import os
import sys

import torch

# Make the current working directory importable so the project-local packages
# (`model`, and per the original note also `layers`) resolve.
# NOTE(review): this assumes the notebook is launched from the repository root — confirm.
# The membership check keeps re-runs of this cell from piling duplicate entries onto sys.path.
_project_root = os.path.abspath('.')
if _project_root not in sys.path:
    sys.path.append(_project_root)

from model.Testformer import Model

# torchsummary: if missing, `summary` is left undefined and a hint is printed.
try:
    from torchsummary import summary
except ImportError:
    print("torchsummary not found. Please install it using `pip install torchsummary`")

# torchinfo is optional: if missing, `summary_info` is left undefined.
try:
    from torchinfo import summary as summary_info
except ImportError:
    print("torchinfo not found (optional).")

In [12]:
class Config:
    """
    Stand-in for the argument namespace the training script normally builds from the command line.

    The defaults reproduce the DEBUG run arguments. Any attribute can be
    overridden by keyword, e.g. ``Config(pred_len=192, d_model=512)``, so the
    model can be inspected under other settings without editing this class.

    Args:
        **overrides: attribute values that replace the defaults below.

    Raises:
        TypeError: if an override names an attribute that has no default here
            (catches typos that would otherwise be silently ignored by the model).
    """
    def __init__(self, **overrides):
        # --- experiment / data ---
        self.model_id = 'ETTh1_96_96'
        self.model = 'Testformer'
        self.data = 'ETTh1'
        self.features = 'M'
        self.freq = 'h'
        self.embed = 'timeF'

        # --- sequence lengths and channel counts ---
        self.seq_len = 96
        self.label_len = 48
        self.pred_len = 96
        self.enc_in = 7
        self.dec_in = 7
        self.c_out = 7

        # --- architecture ---
        self.d_model = 256
        self.n_heads = 8
        self.e_layers = 2
        self.d_layers = 1
        self.d_ff = 256
        self.moving_avg = 25
        self.factor = 1
        self.distil = True
        self.dropout = 0.1
        self.activation = 'gelu'
        self.class_strategy = 'projection'

        # --- run flags ---
        self.output_attention = False
        self.do_predict = False
        self.use_gpu = torch.cuda.is_available()
        # NOTE(review): not taken from the DEBUG run arguments; True was assumed.
        # Confirm the intended default against the model code.
        self.use_channel_period_flex = True

        # Apply keyword overrides only after all defaults exist, so unknown names can be rejected.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(f"Config got unexpected override(s): {', '.join(unknown)}")
        for name, value in overrides.items():
            setattr(self, name, value)

config = Config()
print("Config loaded.")

Config loaded.


In [13]:
# Instantiate the model and place it on a device chosen from the config.
# `config.use_gpu` is honoured (it is otherwise never read), but CUDA is only
# used if it is actually available, so an override cannot request a missing GPU.
model = Model(config)

use_cuda = bool(config.use_gpu) and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')
model = model.to(device)

if device.type == 'cuda':
    print("Model moved to GPU.")
else:
    print("Model on CPU.")

print(model)

Model moved to GPU.
Model(
  (decomp): Decomp(
    (ema): EMA()
  )
  (normlizer): Normalize()
  (trend_net): TrendFlow(
    (layer1): Sequential(
      (0): Linear(in_features=96, out_features=1024, bias=True)
      (1): GELU(approximate='none')
      (2): Dropout(p=0.1, inplace=False)
      (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
    (layer2): Sequential(
      (0): Linear(in_features=1024, out_features=1024, bias=True)
      (1): GELU(approximate='none')
      (2): Dropout(p=0.1, inplace=False)
      (3): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
    )
    (head): Linear(in_features=1024, out_features=256, bias=True)
  )
  (season_net): SeasonFlow(
    (enc_embedding): DataEmbedding_inverted(
      (value_embedding): Linear(in_features=96, out_features=256, bias=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (data_embedding): DataEmbedding(
      (value_embedding): TokenEmbedding(
        (tokenConv): Conv1d(7, 256, kernel_size=

In [14]:
# Visualize using torchsummary.
# torchsummary prepends a batch dimension to `input_size` itself. Testformer's
# x_enc is [Batch, Seq_Len, Enc_In], so we pass (seq_len, enc_in).
# NOTE(review): only x_enc is supplied; this assumes Model.forward can be
# called with x_enc alone — confirm against the model code.

print("\n--- Model Summary (torchsummary) ---")
try:
    # torchsummary creates its dummy input on its `device` argument, whose
    # default is CUDA. Match the model's real device so a CPU-only kernel
    # does not fail with a device mismatch. A plain string works across
    # torchsummary versions.
    summary_device = "cuda" if next(model.parameters()).is_cuda else "cpu"
    summary(model, (config.seq_len, config.enc_in), device=summary_device)
except Exception as e:
    # Best-effort: a failing summary tool should not stop the notebook.
    # NOTE(review): the recorded run fails with "'NoneType' object has no
    # attribute 'size'". This is plausibly a submodule returning None inside a
    # tuple (config.output_attention is False), which torchsummary's output-shape
    # hook cannot handle. Confirm before relying on this tool for this model.
    print(f"torchsummary failed: {e}")


--- Model Summary (torchsummary) ---
torchsummary failed: 'NoneType' object has no attribute 'size'


In [15]:
# Visualize using torchinfo (if available) - provides more detailed per-layer info.
# Unlike torchsummary, torchinfo takes the full input shape including the batch dimension.
batch_size = 32
print("\n--- Model Summary (torchinfo) ---")

if "summary_info" not in globals():
    # The import cell leaves `summary_info` undefined when torchinfo is missing.
    # Checking explicitly (instead of `except NameError`) avoids hiding real
    # NameErrors raised inside the model's forward pass.
    print("torchinfo not available; skipping.")
else:
    try:
        summary_info(model, input_size=(batch_size, config.seq_len, config.enc_in),
                     col_names=["input_size", "output_size", "num_params", "kernel_size", "mult_adds"],
                     verbose=1)
    except Exception as e:
        print(f"torchinfo execution failed: {e}")
        # torchinfo's message points at "stack traces above", which are hidden
        # because the exception is caught here. If it chained the original
        # error, show that root cause instead.
        if e.__cause__ is not None:
            print(f"Caused by: {type(e.__cause__).__name__}: {e.__cause__}")


--- Model Summary (torchinfo) ---
torchinfo execution failed: Failed to run torchinfo. See above stack traces for more details. Executed layers up to: [Normalize: 1, Mahalanobis_mask: 1, Decomp: 1, EMA: 2, TrendFlow: 1, Sequential: 2, Linear: 3, GELU: 3, Dropout: 3, LayerNorm: 3, Sequential: 2, Linear: 3, GELU: 3, Dropout: 3, LayerNorm: 3, Linear: 2, DataEmbedding_inverted: 2, Linear: 3, Dropout: 3, Linear: 6, Linear: 6, Linear: 6, Dropout: 7]
