# FewShotPainAdaptation on Google Colab (T4)
This notebook installs dependencies, verifies that a GPU is available, configures reproducibility (fixed seed, deterministic ops), and runs leave-one-subject-out (LOSO) few-shot training.

Before running, set the Colab runtime type to **GPU (T4)** (Runtime → Change runtime type).

In [3]:
!nvidia-smi
!pip -q install -U pip
!pip -q install tensorflow==2.18.1 cloudpickle matplotlib numpy pandas scikit-learn scipy seaborn pydantic

Thu Feb 26 09:53:39 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82.07              Driver Version: 580.82.07      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   50C    P8             10W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [6]:
# Mount Google Drive so that DATA_DIR (configured further below) can point at a Drive folder.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Mounted at /content/drive


In [4]:
REPO_URL = "https://github.com/hhihn/FewShotPainAdaptation.git"
PROJECT_DIR = "/content/FewShotPainAdaptation"
import os
if not os.path.isdir(PROJECT_DIR):
    !git clone $REPO_URL $PROJECT_DIR
    %cd $PROJECT_DIR
else:
    %cd $PROJECT_DIR
    !git pull


/content/FewShotPainAdaptation
Already up to date.


In [5]:
import tensorflow as tf
from tensorflow.keras import mixed_precision

# Set to False to train in plain float32 (e.g. to rule out float16 numerical issues).
USE_MIXED_PRECISION = True

# Required: fail fast if the Colab runtime has no GPU attached.
gpus = tf.config.list_physical_devices('GPU')
print('Visible GPUs:', gpus)
if not gpus:
    raise RuntimeError('No GPU detected. In Colab, set Runtime -> Change runtime type -> GPU.')

# Optional, helps speed on T4: the mixed_float16 policy runs most Keras layer computations
# in float16. It only affects layers constructed after this call, so it must run before
# the model is built.
if USE_MIXED_PRECISION:
    mixed_precision.set_global_policy('mixed_float16')
print('Mixed precision policy:', mixed_precision.global_policy())

Visible GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Mixed precision policy: <DTypePolicy "mixed_float16">


In [9]:
import os

# Point DATA_DIR at a folder containing: X_pre.npy, y_heater.npy, subjects.npy
# Example if your data is in the repo: /content/FewShotPainAdaptation/data
# Example if it is in Drive:           /content/drive/MyDrive/FewShotPainAdaptation/data
DATA_DIR = "/content/drive/MyDrive/PainData"
required = ['X_pre.npy', 'y_heater.npy', 'subjects.npy']

# Validate up front so a wrong path fails here with a clear message rather than deep
# inside the data loader.
if not os.path.isdir(DATA_DIR):
    raise FileNotFoundError(
        f'DATA_DIR={DATA_DIR} does not exist (is Google Drive mounted and the path correct?)'
    )
# isfile (not exists) so a directory that happens to share a required name is rejected.
missing = [f for f in required if not os.path.isfile(os.path.join(DATA_DIR, f))]
if missing:
    raise FileNotFoundError(f'Missing files in DATA_DIR={DATA_DIR}: {missing}')
print('Using data directory:', DATA_DIR)

Using data directory: /content/drive/MyDrive/PainData


In [11]:
import json
import logging
import numpy as np
from data_loaders.pain_ds_config import PainDatasetConfig
from learner.few_shot_pain_learner import FewShotPainLearner
from utils.logger import setup_logger

logger = setup_logger('FewShotPainLearner', level=logging.INFO)

# Reproducible config
config = PainDatasetConfig(seed=42, deterministic_ops=True, k_shot=3, q_query=3,)

# Keep these small first; increase after sanity check.
NUM_EPOCHS = 3
EPISODES_PER_EPOCH = 10
VAL_EPISODES = 5
# NOTE(review): the saved output of this cell comes from an earlier run with
# fusion_method='concat'. It ended in a GPU OOM inside a Softmax over a float16
# tensor of shape (18, 4, 2500, 2500). That looks like attention across all
# 2500 time steps for 18 = 6-way x 3 samples per episode. Expect to need a
# shorter sequence or smaller episodes on a T4. TODO confirm where that
# attention layer lives.
FUSION_METHODS = ['attention']
# Relative path: lands in the current working directory (the repo checkout under
# /content, which does not survive a runtime reset). Use a Drive path to keep results.
RESULTS_PATH = 'colab_run_results.json'


def summarize_results(results_by_method):
    """Summarize test metrics per fusion method as plain Python floats.

    Args:
        results_by_method: mapping fusion_method -> the dict returned by
            ``FewShotPainLearner.train``; only its 'test_accuracies' and
            'test_losses' entries are read.

    Returns:
        {fusion_method: {'avg_test_acc', 'std_test_acc', 'avg_test_loss'}}
    """
    return {
        fm: {
            'avg_test_acc': float(np.mean(res['test_accuracies'])),
            'std_test_acc': float(np.std(res['test_accuracies'])),
            'avg_test_loss': float(np.mean(res['test_losses'])),
        }
        for fm, res in results_by_method.items()
    }


def _to_jsonable(obj):
    """``json.dumps`` fallback: turn NumPy scalars/arrays (and eager tensors) into Python types."""
    if hasattr(obj, 'numpy'):  # e.g. eager TensorFlow tensors -> NumPy
        obj = obj.numpy()
    if hasattr(obj, 'tolist'):  # NumPy scalars and arrays -> Python scalars / lists
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')


all_results = {}
for fusion_method in FUSION_METHODS:
    logger.info(f'Training with fusion method: {fusion_method}')
    learner = FewShotPainLearner(
        config=config,
        data_dir=DATA_DIR,
        learning_rate=1e-3,
        fusion_method=fusion_method,
        seed=config.seed,
        deterministic_ops=config.deterministic_ops,
    )
    cv_results = learner.train(
        num_epochs=NUM_EPOCHS,
        episodes_per_epoch=EPISODES_PER_EPOCH,
        val_episodes=VAL_EPISODES,
    )
    all_results[fusion_method] = cv_results

    # Re-summarize and save after every fusion method so results already computed
    # survive if a later method crashes (e.g. a GPU OOM).
    summary = summarize_results(all_results)
    print(json.dumps(summary, indent=2))
    # Serialize fully before opening the file so a serialization error cannot
    # truncate a previously saved results file.
    payload = json.dumps({'summary': summary, 'full': all_results}, indent=2, default=_to_jsonable)
    with open(RESULTS_PATH, 'w') as f:
        f.write(payload)
    print(f'Saved results to {RESULTS_PATH}')

2026-02-26 10:02:16 │ INFO     │ FewShotPainLearner:17	 │ Training with fusion method: concat


INFO:FewShotPainLearner:Training with fusion method: concat


2026-02-26 10:02:16 │ DEBUG    │ PainMetaDataset:57	 │ Data directory: /content/drive/MyDrive/PainData


DEBUG:PainMetaDataset:Data directory: /content/drive/MyDrive/PainData


2026-02-26 10:02:16 │ INFO     │ PainMetaDataset:73	 │ Loading data from /content/drive/MyDrive/PainData...


INFO:PainMetaDataset:Loading data from /content/drive/MyDrive/PainData...


2026-02-26 10:02:17 │ INFO     │ PainMetaDataset:79	 │ X.shape: (2495, 2500, 3, 1)


INFO:PainMetaDataset:X.shape: (2495, 2500, 3, 1)


2026-02-26 10:02:17 │ INFO     │ PainMetaDataset:81	 │ y_onehot.shape: (2495, 6)


INFO:PainMetaDataset:y_onehot.shape: (2495, 6)


2026-02-26 10:02:17 │ INFO     │ PainMetaDataset:83	 │ subjects.shape: [ 0  0  0 ... 51 51 51]


INFO:PainMetaDataset:subjects.shape: [ 0  0  0 ... 51 51 51]


2026-02-26 10:02:17 │ DEBUG    │ PainMetaDataset:94	 │ Unique subjects: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51]


DEBUG:PainMetaDataset:Unique subjects: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
 48 49 50 51]


2026-02-26 10:02:17 │ INFO     │ PainMetaDataset:96	 │ Number of subjects: 52


INFO:PainMetaDataset:Number of subjects: 52


2026-02-26 10:02:17 │ INFO     │ PainMetaDataset:97	 │   Data shape: (2495, 2500, 3)


INFO:PainMetaDataset:  Data shape: (2495, 2500, 3)


2026-02-26 10:02:17 │ INFO     │ PainMetaDataset:98	 │   Labels shape: (2495,)


INFO:PainMetaDataset:  Labels shape: (2495,)


2026-02-26 10:02:17 │ INFO     │ PainMetaDataset:99	 │   Number of subjects: 52


INFO:PainMetaDataset:  Number of subjects: 52


2026-02-26 10:02:17 │ INFO     │ PainMetaDataset:100	 │   Samples per subject: ~47


INFO:PainMetaDataset:  Samples per subject: ~47


2026-02-26 10:02:17 │ INFO     │ PainMetaDataset:101	 │   Classes: [0 1 2 3 4 5]


INFO:PainMetaDataset:  Classes: [0 1 2 3 4 5]


2026-02-26 10:02:17 │ INFO     │ PainMetaDataset:135	 │   Minimum samples per (subject, class): 7


INFO:PainMetaDataset:  Minimum samples per (subject, class): 7


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:108	 │ Initialized TCN with 3 blocks


INFO:TemporalConvolutionalNetwork:Initialized TCN with 3 blocks


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:109	 │ Filters: [32, 64, 128]


INFO:TemporalConvolutionalNetwork:Filters: [32, 64, 128]


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:110	 │ Dilation rates: [1, 2, 4]


INFO:TemporalConvolutionalNetwork:Dilation rates: [1, 2, 4]


2026-02-26 10:02:17 │ INFO     │ MultimodalPrototypicalNetwork:81	 │ Built CNN encoder with EDA


INFO:MultimodalPrototypicalNetwork:Built CNN encoder with EDA


2026-02-26 10:02:17 │ INFO     │ MultimodalPrototypicalNetwork:82	 │ None


INFO:MultimodalPrototypicalNetwork:None


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:108	 │ Initialized TCN with 3 blocks


INFO:TemporalConvolutionalNetwork:Initialized TCN with 3 blocks


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:109	 │ Filters: [32, 64, 128]


INFO:TemporalConvolutionalNetwork:Filters: [32, 64, 128]


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:110	 │ Dilation rates: [1, 2, 4]


INFO:TemporalConvolutionalNetwork:Dilation rates: [1, 2, 4]


2026-02-26 10:02:17 │ INFO     │ MultimodalPrototypicalNetwork:81	 │ Built CNN encoder with ECG


INFO:MultimodalPrototypicalNetwork:Built CNN encoder with ECG


2026-02-26 10:02:17 │ INFO     │ MultimodalPrototypicalNetwork:82	 │ None


INFO:MultimodalPrototypicalNetwork:None


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:108	 │ Initialized TCN with 3 blocks


INFO:TemporalConvolutionalNetwork:Initialized TCN with 3 blocks


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:109	 │ Filters: [32, 64, 128]


INFO:TemporalConvolutionalNetwork:Filters: [32, 64, 128]


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:110	 │ Dilation rates: [1, 2, 4]


INFO:TemporalConvolutionalNetwork:Dilation rates: [1, 2, 4]


2026-02-26 10:02:17 │ INFO     │ MultimodalPrototypicalNetwork:81	 │ Built CNN encoder with EMG


INFO:MultimodalPrototypicalNetwork:Built CNN encoder with EMG


2026-02-26 10:02:17 │ INFO     │ MultimodalPrototypicalNetwork:82	 │ None


INFO:MultimodalPrototypicalNetwork:None


2026-02-26 10:02:17 │ INFO     │ MultimodalPrototypicalNetwork:68	 │ Initialized MultimodalPrototypicalNetwork with 3 modalities


INFO:MultimodalPrototypicalNetwork:Initialized MultimodalPrototypicalNetwork with 3 modalities


2026-02-26 10:02:17 │ INFO     │ MultimodalPrototypicalNetwork:71	 │ Fusion method: concat, Final embedding dim: 192


INFO:MultimodalPrototypicalNetwork:Fusion method: concat, Final embedding dim: 192


2026-02-26 10:02:17 │ INFO     │ few_shot_pain_learner:90	 │ Run config: {"data_dir": "/content/drive/MyDrive/PainData", "deterministic_ops": true, "fusion_method": "concat", "k_shot": 3, "learning_rate": 0.001, "modality_names": ["EDA", "ECG", "EMG"], "n_way": 6, "q_query": 3, "seed": 42, "sensor_idx": [1, 4, 5], "sequence_length": 2500}


INFO:few_shot_pain_learner:Run config: {"data_dir": "/content/drive/MyDrive/PainData", "deterministic_ops": true, "fusion_method": "concat", "k_shot": 3, "learning_rate": 0.001, "modality_names": ["EDA", "ECG", "EMG"], "n_way": 6, "q_query": 3, "seed": 42, "sensor_idx": [1, 4, 5], "sequence_length": 2500}


2026-02-26 10:02:17 │ INFO     │ few_shot_pain_learner:92	 │ Initialized FewShotPainLearner with 52 subjects


INFO:few_shot_pain_learner:Initialized FewShotPainLearner with 52 subjects


2026-02-26 10:02:17 │ INFO     │ few_shot_pain_learner:95	 │ Data shape: (sequence_length=2500, num_sensors=3)


INFO:few_shot_pain_learner:Data shape: (sequence_length=2500, num_sensors=3)


2026-02-26 10:02:17 │ INFO     │ few_shot_pain_learner:98	 │ Modalities: ('EDA', 'ECG', 'EMG')


INFO:few_shot_pain_learner:Modalities: ('EDA', 'ECG', 'EMG')


2026-02-26 10:02:17 │ INFO     │ few_shot_pain_learner:99	 │ Fusion method: concat


INFO:few_shot_pain_learner:Fusion method: concat


2026-02-26 10:02:17 │ INFO     │ few_shot_pain_learner:151	 │ 


INFO:few_shot_pain_learner:


2026-02-26 10:02:17 │ INFO     │ few_shot_pain_learner:152	 │ Fold 1/52: Test subject = 0


INFO:few_shot_pain_learner:Fold 1/52: Test subject = 0






2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:108	 │ Initialized TCN with 3 blocks


INFO:TemporalConvolutionalNetwork:Initialized TCN with 3 blocks


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:109	 │ Filters: [32, 64, 128]


INFO:TemporalConvolutionalNetwork:Filters: [32, 64, 128]


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:110	 │ Dilation rates: [1, 2, 4]


INFO:TemporalConvolutionalNetwork:Dilation rates: [1, 2, 4]


2026-02-26 10:02:17 │ INFO     │ MultimodalPrototypicalNetwork:81	 │ Built CNN encoder with EDA


INFO:MultimodalPrototypicalNetwork:Built CNN encoder with EDA


2026-02-26 10:02:17 │ INFO     │ MultimodalPrototypicalNetwork:82	 │ None


INFO:MultimodalPrototypicalNetwork:None


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:108	 │ Initialized TCN with 3 blocks


INFO:TemporalConvolutionalNetwork:Initialized TCN with 3 blocks


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:109	 │ Filters: [32, 64, 128]


INFO:TemporalConvolutionalNetwork:Filters: [32, 64, 128]


2026-02-26 10:02:17 │ INFO     │ TemporalConvolutionalNetwork:110	 │ Dilation rates: [1, 2, 4]


INFO:TemporalConvolutionalNetwork:Dilation rates: [1, 2, 4]


2026-02-26 10:02:17 │ INFO     │ MultimodalPrototypicalNetwork:81	 │ Built CNN encoder with ECG


INFO:MultimodalPrototypicalNetwork:Built CNN encoder with ECG


2026-02-26 10:02:17 │ INFO     │ MultimodalPrototypicalNetwork:82	 │ None


INFO:MultimodalPrototypicalNetwork:None


2026-02-26 10:02:18 │ INFO     │ TemporalConvolutionalNetwork:108	 │ Initialized TCN with 3 blocks


INFO:TemporalConvolutionalNetwork:Initialized TCN with 3 blocks


2026-02-26 10:02:18 │ INFO     │ TemporalConvolutionalNetwork:109	 │ Filters: [32, 64, 128]


INFO:TemporalConvolutionalNetwork:Filters: [32, 64, 128]


2026-02-26 10:02:18 │ INFO     │ TemporalConvolutionalNetwork:110	 │ Dilation rates: [1, 2, 4]


INFO:TemporalConvolutionalNetwork:Dilation rates: [1, 2, 4]


2026-02-26 10:02:18 │ INFO     │ MultimodalPrototypicalNetwork:81	 │ Built CNN encoder with EMG


INFO:MultimodalPrototypicalNetwork:Built CNN encoder with EMG


2026-02-26 10:02:18 │ INFO     │ MultimodalPrototypicalNetwork:82	 │ None


INFO:MultimodalPrototypicalNetwork:None


2026-02-26 10:02:18 │ INFO     │ MultimodalPrototypicalNetwork:68	 │ Initialized MultimodalPrototypicalNetwork with 3 modalities


INFO:MultimodalPrototypicalNetwork:Initialized MultimodalPrototypicalNetwork with 3 modalities


2026-02-26 10:02:18 │ INFO     │ MultimodalPrototypicalNetwork:71	 │ Fusion method: concat, Final embedding dim: 192


INFO:MultimodalPrototypicalNetwork:Fusion method: concat, Final embedding dim: 192


2026-02-26 10:02:18 │ DEBUG    │ data_loaders.loso_cross_validator:68	 │ Train subjects: [np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9), np.int64(10), np.int64(11), np.int64(12), np.int64(13), np.int64(14), np.int64(15), np.int64(16), np.int64(17), np.int64(18), np.int64(19), np.int64(20), np.int64(21), np.int64(22), np.int64(23), np.int64(24), np.int64(25), np.int64(26), np.int64(27), np.int64(28), np.int64(29), np.int64(30), np.int64(31), np.int64(32), np.int64(33), np.int64(34), np.int64(35), np.int64(36), np.int64(37), np.int64(38), np.int64(39), np.int64(40), np.int64(41), np.int64(42), np.int64(43), np.int64(44), np.int64(45), np.int64(46), np.int64(47), np.int64(48), np.int64(49), np.int64(50), np.int64(51)]


DEBUG:data_loaders.loso_cross_validator:Train subjects: [np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9), np.int64(10), np.int64(11), np.int64(12), np.int64(13), np.int64(14), np.int64(15), np.int64(16), np.int64(17), np.int64(18), np.int64(19), np.int64(20), np.int64(21), np.int64(22), np.int64(23), np.int64(24), np.int64(25), np.int64(26), np.int64(27), np.int64(28), np.int64(29), np.int64(30), np.int64(31), np.int64(32), np.int64(33), np.int64(34), np.int64(35), np.int64(36), np.int64(37), np.int64(38), np.int64(39), np.int64(40), np.int64(41), np.int64(42), np.int64(43), np.int64(44), np.int64(45), np.int64(46), np.int64(47), np.int64(48), np.int64(49), np.int64(50), np.int64(51)]


2026-02-26 10:02:18 │ DEBUG    │ data_loaders.loso_cross_validator:71	 │ n_val: 5


DEBUG:data_loaders.loso_cross_validator:n_val: 5


2026-02-26 10:02:18 │ DEBUG    │ data_loaders.loso_cross_validator:73	 │ val_subjects: [np.int64(47), np.int64(48), np.int64(49), np.int64(50), np.int64(51)]


DEBUG:data_loaders.loso_cross_validator:val_subjects: [np.int64(47), np.int64(48), np.int64(49), np.int64(50), np.int64(51)]


2026-02-26 10:02:18 │ DEBUG    │ data_loaders.loso_cross_validator:75	 │ train_subjects_final: [np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9), np.int64(10), np.int64(11), np.int64(12), np.int64(13), np.int64(14), np.int64(15), np.int64(16), np.int64(17), np.int64(18), np.int64(19), np.int64(20), np.int64(21), np.int64(22), np.int64(23), np.int64(24), np.int64(25), np.int64(26), np.int64(27), np.int64(28), np.int64(29), np.int64(30), np.int64(31), np.int64(32), np.int64(33), np.int64(34), np.int64(35), np.int64(36), np.int64(37), np.int64(38), np.int64(39), np.int64(40), np.int64(41), np.int64(42), np.int64(43), np.int64(44), np.int64(45), np.int64(46)]


DEBUG:data_loaders.loso_cross_validator:train_subjects_final: [np.int64(1), np.int64(2), np.int64(3), np.int64(4), np.int64(5), np.int64(6), np.int64(7), np.int64(8), np.int64(9), np.int64(10), np.int64(11), np.int64(12), np.int64(13), np.int64(14), np.int64(15), np.int64(16), np.int64(17), np.int64(18), np.int64(19), np.int64(20), np.int64(21), np.int64(22), np.int64(23), np.int64(24), np.int64(25), np.int64(26), np.int64(27), np.int64(28), np.int64(29), np.int64(30), np.int64(31), np.int64(32), np.int64(33), np.int64(34), np.int64(35), np.int64(36), np.int64(37), np.int64(38), np.int64(39), np.int64(40), np.int64(41), np.int64(42), np.int64(43), np.int64(44), np.int64(45), np.int64(46)]


ResourceExhaustedError: Exception encountered when calling Softmax.call().

[1m{{function_node __wrapped__Softmax_device_/job:localhost/replica:0/task:0/device:GPU:0}} OOM when allocating tensor with shape[18,4,2500,2500] and type half on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:Softmax] name: [0m

Arguments received by Softmax.call():
  • inputs=tf.Tensor(shape=(18, 4, 2500, 2500), dtype=float16)
  • mask=None