In [11]:
# !uv pip install braindecode moabb  # uncomment to install; moabb is needed for dataset downloads

[2mUsing Python 3.12.3 environment at: /opt/venv[0m
[2mAudited [1m1 package[0m [2min 28ms[0m[0m


In [1]:
# Core libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW

# Transformers and PEFT
from transformers import AutoTokenizer, AutoModel
from peft import LoraConfig, get_peft_model

# Data processing and visualization
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
# from sklearn.metrics.pairwise import cosine_similarity
# from sklearn.linear_model import LinearRegression
# from sklearn.metrics import mean_squared_error

# Utilities
import gc
# from tqdm.auto import tqdm

# Global plot defaults: 100-dpi figures drawn on seaborn's white-grid theme.
plt.rcParams["figure.dpi"] = 100
sns.set_style(style="whitegrid")

# Reaching this line means every import in this cell succeeded.
print("✓ All libraries imported successfully!")

✓ All libraries imported successfully!


In [2]:
# Device setup: use the GPU when available, otherwise the CPU

In [3]:
# Select the compute device: the CUDA GPU if PyTorch can see one, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"✓ Using device: {device}")
if device.type == "cuda":
    # Describe GPU 0: its name and total memory (reported in units of 1e9 bytes).
    print(f"  GPU: {torch.cuda.get_device_name(0)}")
    print(f"  Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

✓ Using device: cuda
  GPU: NVIDIA A100-SXM4-40GB
  Memory: 42.29 GB


In [4]:
# Load data: BCI Competition IV Dataset 4 (ECoG) via braindecode

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
# Silence every Python warning for the rest of the kernel session. Note that this
# also hides DeprecationWarning/FutureWarning from the libraries used below.
warnings.filterwarnings('ignore')

# Optional dependencies: probe each import and remember the failure, if any.
# The exception is kept because an ImportError raised *inside* an installed
# package (e.g. a broken transitive dependency) would otherwise be misreported
# as "not installed".
try:
    from braindecode.datasets import BCICompetitionIVDataset4
    BRAINDECODE_AVAILABLE = True
    BRAINDECODE_IMPORT_ERROR = None
except ImportError as exc:
    BRAINDECODE_AVAILABLE = False
    # `exc` is unbound once the except clause ends, so keep a reference.
    BRAINDECODE_IMPORT_ERROR = exc
    print(f"Warning: braindecode could not be imported ({exc}).")
    print("Install with: uv add braindecode (or pip install braindecode)")

try:
    import timesfm
    TIMESFM_AVAILABLE = True
    TIMESFM_IMPORT_ERROR = None
except ImportError as exc:
    TIMESFM_AVAILABLE = False
    TIMESFM_IMPORT_ERROR = exc
    print(f"Warning: timesfm could not be imported ({exc}).")
    print("Install with: uv add \"git+https://github.com/google-research/timesfm.git\"")
    print("Or with pip: pip install git+https://github.com/google-research/timesfm.git")
    print("Note: TimesFM does not support ARM architectures (Apple Silicon)")

# Plot defaults for this section: large 15x10-inch figures on a white grid.
plt.rcParams['figure.figsize'] = (15, 10)
sns.set_style(style="whitegrid")


def load_bci_dataset(subject_ids=None):
    """
    Load the BCI Competition IV Dataset 4 (ECoG) through braindecode.

    Downloads the data first if it is not already available, then builds
    the braindecode dataset for the requested subject(s).

    Parameters:
    -----------
    subject_ids : list of int or int or None
        Subject(s) to load. If None, loads all available subjects (1-3).

    Returns:
    --------
    dataset : BaseConcatDataset
        Loaded dataset containing ECoG recordings

    Raises:
    -------
    ImportError
        If braindecode cannot be imported, or if the download step fails
        because moabb is missing. The original error is chained
        (``raise ... from``) so the real cause shows in the traceback.
    """
    print("=" * 80)
    print("Loading BCI Competition IV Dataset 4")
    print("=" * 80)

    # Import at call time instead of trusting a module-level availability flag:
    # such a flag is frozen when its setup cell ran, and checking it discards
    # the reason the import failed (missing package vs. broken dependency).
    try:
        from braindecode.datasets import BCICompetitionIVDataset4
    except ImportError as e:
        raise ImportError(
            "braindecode is required. Install with: uv add braindecode moabb\n"
            "Note: moabb is also required for dataset downloads."
        ) from e

    # Download dataset if not already available
    print("\nDownloading dataset (if not already available)...")
    try:
        BCICompetitionIVDataset4.download()
    except ModuleNotFoundError as e:
        # The download path needs moabb; turn that specific failure into an
        # actionable message and re-raise anything else unchanged.
        if 'moabb' in str(e):
            raise ImportError(
                "moabb is required for dataset downloads. Install with: uv add moabb\n"
                "Or: pip install moabb"
            ) from e
        raise

    # Test for None explicitly so that an empty list is not reported as "all".
    subjects_label = 'all' if subject_ids is None else subject_ids
    print(f"\nLoading dataset for subjects: {subjects_label}")
    dataset = BCICompetitionIVDataset4(subject_ids=subject_ids)

    print("Dataset loaded successfully!")
    print(f"Number of recordings: {len(dataset.datasets)}")

    return dataset


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import warnings
# NOTE(review): this cell repeats the previous setup cell verbatim; delete one copy.

# Silence every Python warning for the rest of the kernel session. Note that this
# also hides DeprecationWarning/FutureWarning from the libraries used below.
warnings.filterwarnings('ignore')

# Optional dependencies: probe each import and remember the failure, if any.
# The exception is kept because an ImportError raised *inside* an installed
# package (e.g. a broken transitive dependency) would otherwise be misreported
# as "not installed".
try:
    from braindecode.datasets import BCICompetitionIVDataset4
    BRAINDECODE_AVAILABLE = True
    BRAINDECODE_IMPORT_ERROR = None
except ImportError as exc:
    BRAINDECODE_AVAILABLE = False
    # `exc` is unbound once the except clause ends, so keep a reference.
    BRAINDECODE_IMPORT_ERROR = exc
    print(f"Warning: braindecode could not be imported ({exc}).")
    print("Install with: uv add braindecode (or pip install braindecode)")

try:
    import timesfm
    TIMESFM_AVAILABLE = True
    TIMESFM_IMPORT_ERROR = None
except ImportError as exc:
    TIMESFM_AVAILABLE = False
    TIMESFM_IMPORT_ERROR = exc
    print(f"Warning: timesfm could not be imported ({exc}).")
    print("Install with: uv add \"git+https://github.com/google-research/timesfm.git\"")
    print("Or with pip: pip install git+https://github.com/google-research/timesfm.git")
    print("Note: TimesFM does not support ARM architectures (Apple Silicon)")

# Plot defaults for this section: large 15x10-inch figures on a white grid.
plt.rcParams['figure.figsize'] = (15, 10)
sns.set_style(style="whitegrid")


# NOTE(review): duplicate of the load_bci_dataset defined in the previous cell;
# whichever cell runs last wins. Keep a single definition.
def load_bci_dataset(subject_ids=None):
    """
    Load the BCI Competition IV Dataset 4 (ECoG) through braindecode.

    Downloads the data first if it is not already available, then builds
    the braindecode dataset for the requested subject(s).

    Parameters:
    -----------
    subject_ids : list of int or int or None
        Subject(s) to load. If None, loads all available subjects (1-3).

    Returns:
    --------
    dataset : BaseConcatDataset
        Loaded dataset containing ECoG recordings

    Raises:
    -------
    ImportError
        If braindecode cannot be imported, or if the download step fails
        because moabb is missing. The original error is chained
        (``raise ... from``) so the real cause shows in the traceback.
    """
    print("=" * 80)
    print("Loading BCI Competition IV Dataset 4")
    print("=" * 80)

    # Import at call time instead of trusting a module-level availability flag:
    # such a flag is frozen when its setup cell ran, and checking it discards
    # the reason the import failed (missing package vs. broken dependency).
    try:
        from braindecode.datasets import BCICompetitionIVDataset4
    except ImportError as e:
        raise ImportError(
            "braindecode is required. Install with: uv add braindecode moabb\n"
            "Note: moabb is also required for dataset downloads."
        ) from e

    # Download dataset if not already available
    print("\nDownloading dataset (if not already available)...")
    try:
        BCICompetitionIVDataset4.download()
    except ModuleNotFoundError as e:
        # The download path needs moabb; turn that specific failure into an
        # actionable message and re-raise anything else unchanged.
        if 'moabb' in str(e):
            raise ImportError(
                "moabb is required for dataset downloads. Install with: uv add moabb\n"
                "Or: pip install moabb"
            ) from e
        raise

    # Test for None explicitly so that an empty list is not reported as "all".
    subjects_label = 'all' if subject_ids is None else subject_ids
    print(f"\nLoading dataset for subjects: {subjects_label}")
    dataset = BCICompetitionIVDataset4(subject_ids=subject_ids)

    print("Dataset loaded successfully!")
    print(f"Number of recordings: {len(dataset.datasets)}")

    return dataset


In [28]:
# Load the recordings of subject 3 only (a bare int selects a single subject).
data = load_bci_dataset(3)

Loading BCI Competition IV Dataset 4


ImportError: braindecode is required. Install with: uv add braindecode moabb
Note: moabb is also required for dataset downloads.

In [None]:
# Load the pretrained TimesFM model and run a zero-shot forecast (smoke test on dummy inputs)

In [None]:
import torch
import numpy as np
import timesfm

# Let float32 matmuls use faster reduced-precision kernels (e.g. TF32) where the
# hardware supports them; see torch.set_float32_matmul_precision.
torch.set_float32_matmul_precision("high")

# Pretrained TimesFM 2.5 (200M parameters, PyTorch weights).
model = timesfm.TimesFM_2p5_200M_torch.from_pretrained("google/timesfm-2.5-200m-pytorch")

# Forecast options; this appears to follow the TimesFM 2.5 README quick-start.
# NOTE(review): re-check normalize_inputs / infer_is_positive before reusing
# these settings on real ECoG signals.
model.compile(
    timesfm.ForecastConfig(
        max_context=1024,
        max_horizon=256,
        normalize_inputs=True,
        use_continuous_quantile_head=True,
        force_flip_invariance=True,
        infer_is_positive=True,
        fix_quantile_crossing=True,
    )
)

# Smoke test on two dummy series of different lengths: a ramp and a sine wave.
FORECAST_HORIZON = 12
dummy_inputs = [
    np.linspace(0, 1, 100),
    np.sin(np.linspace(0, 20, 67)),
]
point_forecast, quantile_forecast = model.forecast(
    horizon=FORECAST_HORIZON,
    inputs=dummy_inputs,
)

# Check the shapes rather than leaving bare expressions: a notebook cell only
# displays its last expression, so a standalone `point_forecast.shape` line
# is silently discarded.
assert point_forecast.shape == (len(dummy_inputs), FORECAST_HORIZON)
assert tuple(quantile_forecast.shape[:2]) == (len(dummy_inputs), FORECAST_HORIZON)

# Expected per the TimesFM README: (2, 12) and (2, 12, 10) = mean, then 10th to 90th quantiles.
point_forecast.shape, quantile_forecast.shape