In [4]:
import torch
import torch.nn as nn
import numpy as np
import time
from prettytable import PrettyTable

# Copy of the TST model definition
class TimeSeriesTransformer(nn.Module):
    """Patch-based Transformer encoder for time-series classification.

    The input window is cut into non-overlapping patches of ``patch_size``
    time steps. Each patch is flattened and linearly projected to
    ``d_model``; a learnable CLS token is prepended, a learnable position
    embedding is added, and the sequence is passed through a stack of
    ``nn.TransformerEncoderLayer`` blocks. The layer-normalised CLS output
    is mapped to class logits by a final linear layer.

    Args:
        input_dim: Number of features per time step.
        num_classes: Number of output logits.
        seq_len: Window length in time steps; determines the patch count
            as ``seq_len // patch_size``.
        d_model: Embedding width used throughout the encoder.
        n_heads: Number of attention heads per encoder layer.
        depth: Number of stacked encoder layers.
        d_ff: Hidden width of each encoder layer's feed-forward network.
        dropout: Dropout probability, used after the position embedding
            and inside every encoder layer.
        patch_size: Number of consecutive time steps per patch.

    Raises:
        ValueError: If ``seq_len`` does not contain at least one full patch.
    """

    def __init__(self, input_dim, num_classes, seq_len, d_model=64,
                 n_heads=4, depth=4, d_ff=256, dropout=0.1, patch_size=25):
        super().__init__()

        self.patch_size = patch_size
        self.n_patches = seq_len // patch_size
        self.d_model = d_model

        # With zero patches the encoder would only ever see the CLS token and
        # the logits would not depend on the input at all, so fail loudly.
        if self.n_patches < 1:
            raise ValueError(
                f"seq_len ({seq_len}) must contain at least one full patch "
                f"of patch_size ({patch_size})"
            )

        # Patch embedding: one flattened patch -> one d_model vector
        self.patch_embedding = nn.Linear(input_dim * patch_size, d_model)

        # CLS token
        self.cls_token = nn.Parameter(torch.randn(1, 1, d_model))

        # Position embedding (one slot per patch plus one for the CLS token)
        self.pos_embedding = nn.Parameter(torch.randn(1, self.n_patches + 1, d_model))

        # Dropout
        self.dropout = nn.Dropout(dropout)

        # Transformer encoder
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=d_ff,
            dropout=dropout,
            activation='relu',
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)

        # Output layers
        self.norm = nn.LayerNorm(d_model)
        self.fc_out = nn.Linear(d_model, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Re-initialise the embeddings and the classification head.

        Xavier-uniform for the two linear weight matrices, N(0, 0.02) for the
        CLS token and position embedding, zeros for the output bias. The
        patch-embedding bias and the encoder layers keep their defaults.
        """
        nn.init.xavier_uniform_(self.patch_embedding.weight)
        nn.init.xavier_uniform_(self.fc_out.weight)
        nn.init.normal_(self.cls_token, std=0.02)
        nn.init.normal_(self.pos_embedding, std=0.02)
        nn.init.zeros_(self.fc_out.bias)

    def forward(self, x):
        """Compute class logits for a batch of windows.

        Args:
            x: Tensor of shape ``(batch, time, features)``. Time steps beyond
                ``n_patches * patch_size`` are discarded; a shorter window
                cannot be reshaped into patches and fails in ``reshape``.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised logits.
        """
        B, L, D = x.shape

        # Create patches: drop the tail that does not fill a whole patch,
        # then flatten each patch's (patch_size, D) block into one vector.
        x = x[:, :self.n_patches * self.patch_size, :]
        x = x.reshape(B, self.n_patches, self.patch_size * D)

        # Patch embedding
        x = self.patch_embedding(x)

        # Add CLS token (expand shares the single parameter across the batch)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add position embedding
        x = x + self.pos_embedding
        x = self.dropout(x)

        # Transformer
        x = self.transformer(x)

        # Get CLS output (position 0 of every sequence)
        cls_output = x[:, 0]
        cls_output = self.norm(cls_output)

        # Classification
        logits = self.fc_out(cls_output)

        return logits


def count_parameters(model):
    """Return ``(total, trainable)`` parameter element counts for ``model``.

    ``total`` sums ``numel()`` over every tensor yielded by
    ``model.parameters()``; ``trainable`` only includes those whose
    ``requires_grad`` flag is set.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable


def measure_latency(model, input_shape, device, warmup=50, num_runs=200):
    """Measure single-sample inference latency of ``model`` in milliseconds.

    The model is switched to eval mode and moved to ``device``; both changes
    persist on the caller's model after the call. A random input of shape
    ``(1, *input_shape)`` is run ``warmup`` times untimed and then
    ``num_runs`` times with each forward pass timed individually. On CUDA the
    timing uses CUDA events around each pass; otherwise ``time.perf_counter``.

    Args:
        model: The ``nn.Module`` to time.
        input_shape: Shape of one sample, without the batch dimension.
        device: A ``torch.device`` or anything ``torch.device()`` accepts,
            such as the string ``"cpu"``.
        warmup: Number of untimed forward passes executed first.
        num_runs: Number of timed forward passes; must be at least 1.

    Returns:
        dict with the keys ``mean``, ``std``, ``median``, ``min``, ``max``,
        ``p95`` and ``p99``, all in milliseconds.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    if num_runs < 1:
        raise ValueError(f"num_runs must be at least 1, got {num_runs}")

    # Accept plain strings as well as torch.device objects.
    device = torch.device(device)
    on_cuda = device.type == 'cuda'

    model.eval()
    model = model.to(device)

    # Create test input directly on the target device
    test_input = torch.randn(1, *input_shape, device=device)

    times = []
    with torch.no_grad():
        # Warm-up
        print(f"Warming up ({warmup} runs)…")
        for _ in range(warmup):
            model(test_input)

        # Make sure queued warm-up kernels are finished before timing starts
        if on_cuda:
            torch.cuda.synchronize()

        # Measure latency
        print(f"Measuring latency ({num_runs} runs)…")
        for _ in range(num_runs):
            if on_cuda:
                torch.cuda.synchronize()
                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)

                start.record()
                model(test_input)
                end.record()

                # elapsed_time is only valid once both events have completed
                torch.cuda.synchronize()
                times.append(start.elapsed_time(end))  # milliseconds
            else:
                start = time.perf_counter()
                model(test_input)
                end = time.perf_counter()
                times.append((end - start) * 1000)  # convert to milliseconds

    times = np.array(times)
    return {
        'mean': np.mean(times),
        'std': np.std(times),
        'median': np.median(times),
        'min': np.min(times),
        'max': np.max(times),
        'p95': np.percentile(times, 95),
        'p99': np.percentile(times, 99)
    }


def calculate_flops(model, input_shape):
    """Estimate the cost of one single-sample forward pass with ``thop``.

    The profiling runs on a CPU copy of ``model`` so the caller's model stays
    on its current device and is not touched by the profiler.

    Args:
        model: The ``nn.Module`` to profile.
        input_shape: Shape of one sample, without the batch dimension; a
            random tensor of shape ``(1, *input_shape)`` is profiled.

    Returns:
        The first value returned by ``thop.profile`` divided by 1e6, or
        ``None`` (after printing an install hint) when ``thop`` is missing.
        NOTE(review): thop's README calls this value MACs; confirm whether
        it should be doubled before being reported as FLOPs.
    """
    import copy

    try:
        from thop import profile
    except ImportError:
        print("Please install the 'thop' library to compute FLOPs: pip install thop")
        return None

    # Work on a detached copy: moving the original to the CPU would silently
    # change the device of the model the caller keeps using afterwards.
    profiled_model = copy.deepcopy(model).cpu()
    dummy_input = torch.randn(1, *input_shape)
    flops, _ = profile(profiled_model, inputs=(dummy_input,), verbose=False)
    return flops / 1e6  # convert to MFLOPs


def _time_batched_inference(model, test_input, device, warmup=10, num_runs=100):
    """Return the mean wall-clock time of one forward pass in milliseconds."""
    with torch.no_grad():
        # Untimed passes so one-off setup cost for this batch size is not
        # charged to the measurement.
        for _ in range(warmup):
            model(test_input)

        if device.type == 'cuda':
            torch.cuda.synchronize()

        start_time = time.perf_counter()
        for _ in range(num_runs):
            model(test_input)

        # CUDA launches are asynchronous; wait before reading the clock.
        if device.type == 'cuda':
            torch.cuda.synchronize()

        return (time.perf_counter() - start_time) * 1000 / num_runs


def test_tst_performance():
    """Benchmark TST model performance.

    For each window/patch/d_ff configuration this builds a
    ``TimeSeriesTransformer``, prints its parameter count, patch layout and
    (when ``thop`` is available) FLOPs, then times inference: detailed
    latency statistics at batch size 1 and per-sample time for the larger
    batch sizes. A summary table and a parameter memory estimate are printed
    at the end. Nothing is returned; all results go to stdout.
    """
    print("=" * 80)
    print("TST Model Performance Benchmark")
    print("=" * 80)

    # Device setup
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Test device: {device}")

    # Evaluation configurations (identical to the original code)
    configs = [
        {"window_size": 200, "patch_size": 20, "d_ff": 256},
        {"window_size": 500, "patch_size": 25, "d_ff": 256},
        {"window_size": 800, "patch_size": 40, "d_ff": 320}
    ]

    # Fixed parameters
    input_dim = 52   # Feature dimension for the PAMAP2 dataset
    num_classes = 12  # Number of activity classes in PAMAP2
    batch_sizes = [1, 16, 32, 64]  # Test multiple batch sizes
    # Divisor used to print window lengths in seconds
    # (presumably the sensor sampling rate in Hz; confirm against the data).
    samples_per_second = 100

    results = []

    for config in configs:
        window_size = config['window_size']
        patch_size = config['patch_size']
        d_ff = config['d_ff']
        input_shape = (window_size, input_dim)

        print(f"\n{'=' * 60}")
        print(f"Test configuration: Window={window_size} ({window_size / samples_per_second:.1f}s), "
              f"Patch={patch_size}, d_ff={d_ff}")
        print('=' * 60)

        # Instantiate model
        model = TimeSeriesTransformer(
            input_dim=input_dim,
            num_classes=num_classes,
            seq_len=window_size,
            d_model=64,
            n_heads=4,
            depth=4,
            d_ff=d_ff,
            dropout=0.1,
            patch_size=patch_size
        )
        # Benchmark inference behaviour explicitly instead of relying on a
        # helper to disable dropout as a side effect.
        model.eval()

        # Parameter count
        total_params, trainable_params = count_parameters(model)
        print("\nParameter count:")
        print(f"  - Total parameters: {total_params:,}")
        print(f"  - Trainable parameters: {trainable_params:,}")
        print(f"  - Parameter size: {total_params * 4 / 1024 / 1024:.2f} MB (FP32)")

        # Theoretical computational cost
        n_patches = window_size // patch_size
        print("\nModel structure:")
        print(f"  - Number of patches: {n_patches}")
        print(f"  - Sequence length (incl. CLS): {n_patches + 1}")

        # FLOPs (None means the profiler is unavailable; 0.0 is a valid value)
        flops = calculate_flops(model, input_shape)
        if flops is not None:
            print(f"  - FLOPs: {flops:.2f} M")

        # The FLOPs helper may leave the model on the CPU, so place it on the
        # benchmark device only afterwards.
        model = model.to(device)

        # Inference latency tests
        print("\nInference latency test:")
        for batch_size in batch_sizes:
            # Detailed latency statistics (only for batch_size = 1)
            if batch_size == 1:
                latency_stats = measure_latency(model, input_shape, device)

                results.append({
                    'window_size': window_size,
                    'patch_size': patch_size,
                    'd_ff': d_ff,
                    'total_params': total_params,
                    'flops_m': flops,
                    'latency_mean': latency_stats['mean'],
                    'latency_std': latency_stats['std'],
                    'latency_p95': latency_stats['p95']
                })

                print(f"\n  Batch Size = {batch_size}:")
                print(f"    - Mean latency: {latency_stats['mean']:.3f} ± {latency_stats['std']:.3f} ms")
                print(f"    - Median: {latency_stats['median']:.3f} ms")
                print(f"    - P95: {latency_stats['p95']:.3f} ms")
                print(f"    - P99: {latency_stats['p99']:.3f} ms")
                print(f"    - Range: [{latency_stats['min']:.3f}, {latency_stats['max']:.3f}] ms")
                continue

            # Batched inference: report the amortised cost per sample
            test_input = torch.randn(batch_size, *input_shape, device=device)
            batch_time = _time_batched_inference(model, test_input, device)
            per_sample_time = batch_time / batch_size
            print(f"  Batch Size = {batch_size}: {per_sample_time:.3f} ms/sample")

    # Summary table
    print(f"\n{'=' * 80}")
    print("Performance Summary")
    print('=' * 80)

    table = PrettyTable()
    table.field_names = ["Window(s)", "Patches", "Params", "FLOPs(M)", "Latency(ms)", "Throughput(Hz)"]

    for r in results:
        table.add_row([
            f"{r['window_size'] / samples_per_second:.1f}",
            f"{r['window_size'] // r['patch_size']}",
            f"{r['total_params']:,}",
            f"{r['flops_m']:.2f}" if r['flops_m'] is not None else "N/A",
            f"{r['latency_mean']:.2f}±{r['latency_std']:.2f}",
            # Inferences per second derived from the mean latency in ms
            f"{1000 / r['latency_mean']:.1f}"
        ])

    print(table)

    # Memory footprint estimate (4 bytes per FP32 parameter)
    print("\nMemory footprint estimate (FP32):")
    for r in results:
        model_size = r['total_params'] * 4 / 1024 / 1024  # MB
        print(f"  - Window {r['window_size'] / samples_per_second:.1f}s: {model_size:.2f} MB")


# Run the full benchmark only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    test_tst_performance()

TST Model Performance Benchmark
Test device: cpu

Test configuration: Window=200 (2.0s), Patch=20, d_ff=256

Parameter count:
  - Total parameters: 268,236
  - Trainable parameters: 268,236
  - Parameter size: 1.02 MB (FP32)

Model structure:
  - Number of patches: 10
  - Sequence length (incl. CLS): 11
Please install the 'thop' library to compute FLOPs: pip install thop

Inference latency test:
Warming up (50 runs)…
Measuring latency (200 runs)…

  Batch Size = 1:
    - Mean latency: 1.462 ± 0.642 ms
    - Median: 1.283 ms
    - P95: 2.551 ms
    - P99: 3.668 ms
    - Range: [1.242, 7.072] ms
  Batch Size = 16: 0.286 ms/sample
  Batch Size = 32: 0.322 ms/sample
  Batch Size = 64: 0.364 ms/sample

Test configuration: Window=500 (5.0s), Patch=25, d_ff=256

Parameter count:
  - Total parameters: 285,516
  - Trainable parameters: 285,516
  - Parameter size: 1.09 MB (FP32)

Model structure:
  - Number of patches: 20
  - Sequence length (incl. CLS): 21
Please install the 'thop' library to c