# Classification Baselines
## Random Forest and Logistic Regression

Statistical model baselines (random forest and logistic regression) for LSM (Large Sensor Model) classification tasks.

In [None]:
# @title Imports

import io
import functools
import random
import time
from typing import Any, Callable, Dict, Iterator, Tuple, Optional, Type, Union
import time
from collections import Counter

from absl import logging
from clu import metric_writers
from clu import periodic_actions
from clu import platform

import flax
from flax import jax_utils
import flax.linen as nn
import jax
import jax.numpy as jnp
import jax.profiler

import pandas as pd
import ml_collections
import numpy as np
import optax
import tensorflow as tf
import tensorflow_datasets as tfds
from typing import Optional, Sequence, Union

from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import average_precision_score
import sklearn.metrics as skmetrics
import sklearn.preprocessing as skpreprocessing

import matplotlib
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
%matplotlib inline

from colabtools import adhoc_import
with adhoc_import.Google3():
  from scenic.dataset_lib import dataset_utils
  from scenic.google.xm import xm_utils
  from scenic.model_lib.base_models import base_model
  from scenic.model_lib.base_models import model_utils
  from scenic.model_lib.layers import nn_ops
  from scenic.model_lib.layers import nn_layers
  from scenic.projects.baselines import vit
  from scenic.train_lib import optax as scenic_optax
  from scenic.train_lib import pretrain_utils
  from scenic.train_lib import train_utils

  from scenic.projects.multimask.models import model_utils as mm_model_utils

  from google3.experimental.largesensormodels.scenic.datasets import dataset_constants
  from google3.experimental.largesensormodels.scenic.datasets import lsm_activity_subset_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_mood_vs_activity_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_mood_subj_dependent_preprocessed_40sps_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_tiny_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_combined_pretrain_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_fewshot_mood_vs_activity_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_fewshot_remapped_activity_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_remapped_activity_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_biological_sex_dataset
  from google3.experimental.largesensormodels.scenic.datasets import lsm_binned_age_dataset

  from google3.experimental.largesensormodels.scenic.models import lsm_vit as lsm_vit_mae
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_constants
  from google3.experimental.largesensormodels.scenic.models.lsm_vit_utils import model_utils as lsm_model_utils
  from google3.experimental.largesensormodels.scenic.trainers import lsm_mae_trainer

  from google3.pyglib import gfile


# Shared type aliases (batches, metric/loss callables, LR schedules, patches).

# A patch size is either 2-D (h, w) or 3-D (t, h, w)-style triple of ints.
Patch = Union[Tuple[int, int], Tuple[int, int, int]]

# A batch is a string-keyed dict of arrays.
Batch = Dict[str, jnp.ndarray]

# Learning-rate functions keyed by name, each mapping a step array to an LR.
LrFns = Dict[str, Callable[[jnp.ndarray], jnp.ndarray]]

# (logits, labels, batch) -> {metric_name: (value, count)}.
MetricFn = Callable[
    [jnp.ndarray, jnp.ndarray, Dict[str, jnp.ndarray]],
    Dict[str, Tuple[float, int]],
]

# (logits, batch, optional weights, model params/extra array) -> scalar loss.
LossFn = Callable[
    [jnp.ndarray, Batch, Optional[jnp.ndarray], jnp.ndarray],
    float,
]


In [None]:
# @title mAP Function

def compute_mean_avg_precision(
    targets: Sequence[int],
    logits: Sequence[Sequence[float]],
    n_classes: int,
    return_per_class_ap: bool = False,
):
  """Computes mean average precision (mAP) over class-indexed targets.

  Forked from: google3/third_party/py/scenic/projects/av_mae/evaluation_lib.py

  Each class is scored one-vs-rest: column `i` of `logits` is ranked against a
  0/1 indicator of `targets == i`, and sklearn's average precision is computed
  for that class. The per-class values are then averaged.

  Args:
    targets: Length-num_examples sequence of integer class indices.
    logits: Array-like of shape (num_examples, >= n_classes) holding one score
      per class; only the first `n_classes` columns are used.
    n_classes: Number of classes.
    return_per_class_ap: If True, also add a `class_{i}_AP` entry per class.

  Returns:
    summary: Dictionary with 'nanmean_AP' (mean that skips NaN per-class APs),
      'mAP' (plain mean) and, if requested, the per-class APs.

  Raises:
    ValueError: If `logits` and `targets` have different leading dimensions,
      or `logits` is not 2-D with at least `n_classes` columns.
  """
  targets = np.asarray(targets)
  logits = np.asarray(logits)
  if logits.shape[0] != targets.shape[0]:
    raise ValueError(
        'Predictions and targets have different leading shape\n'
        f'Preds: {logits.shape}\nTargets: {targets.shape}'
    )
  if logits.ndim != 2 or logits.shape[1] < n_classes:
    raise ValueError(
        f'Expected logits of shape (num_examples, >={n_classes}), '
        f'got {logits.shape}'
    )

  # One-hot encode the targets with numpy for every n_classes. label_binarize
  # cannot be used uniformly because it returns a single column when there are
  # only two classes. Targets outside [0, n_classes) become all-zero rows.
  labels = (targets.reshape(-1, 1) == np.arange(n_classes)).astype(np.int64)

  # Get average precision for each class (one-vs-rest).
  average_precisions = []
  summary = {}
  for i in range(n_classes):
    ave_precision = skmetrics.average_precision_score(
        labels[:, i], logits[:, i]
    )
    if return_per_class_ap:
      summary[f'class_{i}_AP'] = ave_precision
    average_precisions.append(ave_precision)

  # Update and return metrics.
  summary['nanmean_AP'] = np.nanmean(average_precisions)
  summary['mAP'] = np.mean(average_precisions)
  return summary

In [None]:
# @title Sample Model Config

r"""A config to train a Tiny ViT MAE on LSM 5M dataset.

Forked from google3/third_party/py/scenic/projects/multimask/configs/mae_cifar10_tiny.py

To run on XManager:
gxm third_party/py/scenic/google/xm/launch_xm.py -- \
--binary //experimental/largesensormodels/scenic:main \
--config=experimental/largesensormodels/scenic/configs/mae_lsm_tiny.py \
--platform=vlp_4x4 \
--exp_name=lsm_mae_tier2_TinyShallow_10_5_res \
--workdir=/cns/dz-d/home/xliucs/lsm/xm/\{xid\} \
--xm_resource_alloc=group:mobile-dynamic/h2o-ai-gqm-quota \
--priority=200

To run locally:
./third_party/py/scenic/google/runlocal.sh \
--uptc="" \
--binary=//experimental/largesensormodels/scenic:main \
--config=$(pwd)/experimental/largesensormodels/scenic/configs/mae_lsm_tiny.py:runlocal
"""


# To set constants.
# NOTE: later cells reassign DATASET_NAME, TRAIN_DATA_SIZE and BATCH_SIZE
# before calling get_config(). The derived constants at the bottom of this
# cell are computed once here and are NOT refreshed by those reassignments.
DATASET_NAME = 'lsm_300min_pretraining_165K_n10'
CACHE_DATASET = False
TRAIN_DATA_SIZE = 1000000  # Cap of 1M train samples; clipped to the split below.
BATCH_SIZE = 8
NUMBER_OF_EPOCH = 100
REPEAT_DATA = False

# Model variant / patch H (time steps) / patch W (features)
VARIANT = 'TiShallow/10/5'
LRS = [1e-3]  # LRS[0] is the default base LR; get_hyper() sweeps all of them.
TOKEN_MASK_PROB = 'constant_0.8'
LOSS_ONLY_MASKED_TOKENS = True
USE_DATETIME_FEATURES = True
USE_TRAIN_AUGMENTATIONS = True
TRAIN_AUGMENTATIONS = ['stretch', 'flip', 'noise']
OHE_LABELS = True  # Forwarded to config.dataset_configs.ohe_labels.

# Derived constants.
TRAIN_DATA_SIZE = min(
    TRAIN_DATA_SIZE,
    dataset_constants.lsm_dataset_constants[DATASET_NAME]['num_train_examples']
)

STEPS_PER_EPOCH = max(1, int(TRAIN_DATA_SIZE / BATCH_SIZE))
NUM_TRAIN_STEPS = int(NUMBER_OF_EPOCH * STEPS_PER_EPOCH)

LOG_EVAL_SUMMARY_STEPS = STEPS_PER_EPOCH  # Eval / summaries once per epoch.
LOG_CHECKPOINT_STEPS = LOG_EVAL_SUMMARY_STEPS * 5  # Checkpoint every 5 epochs.
# Keep at least one checkpoint: with fewer than 5 epochs of training the ratio
# below would otherwise truncate to 0.
MAX_NUM_CHECKPOINTS = max(1, int(NUM_TRAIN_STEPS / LOG_CHECKPOINT_STEPS))


def get_config(runlocal=''):
  """Returns the ViT-MAE experiment configuration.

  The config is built from the module-level constants (DATASET_NAME,
  TRAIN_DATA_SIZE, BATCH_SIZE, VARIANT, ...) as they are when this function is
  called, so a cell that reassigns e.g. BATCH_SIZE and then calls get_config()
  gets that batch size. Derived constants such as NUM_TRAIN_STEPS are computed
  once at module level and are not recomputed here.

  Args:
    runlocal: Any truthy value switches the model version to 'Deb' and sets
      config.count_flops = False.

  Returns:
    An ml_collections.ConfigDict with dataset, model, masked-loss, training,
    schedule, optimizer and logging settings.

  Raises:
    ValueError: If VARIANT is not of the form 'Name/patch_h/patch_w' or
      'Name/patch'.
  """

  runlocal = bool(runlocal)

  # Experiment.
  config = ml_collections.ConfigDict()
  config.experiment_name = f'electrodes-mae-{DATASET_NAME}-{TRAIN_DATA_SIZE}'
  config.dataset_name = f'lsm_prod/{DATASET_NAME}'

  # Dataset.
  config.data_dtype_str = 'float32'
  config.dataset_configs = ml_collections.ConfigDict()
  config.dataset_configs.dataset = f'lsm_prod/{DATASET_NAME}'
  # config.dataset_configs.num_classes = NUM_CLASSES
  config.dataset_configs.train_split = 'train'  # train data split
  config.dataset_configs.train_num_samples = TRAIN_DATA_SIZE  # train sample
  # eval data split - note: this split is used for validation and test.
  config.dataset_configs.eval_split = 'test'
  config.dataset_configs.cache_dataset = CACHE_DATASET
  config.dataset_configs.prefetch_to_device = 2
  # Shuffle_buffer_size is per host, so small-ish is ok.
  config.dataset_configs.shuffle_buffer_size = 250_000
  config.dataset_configs.repeat_data = REPEAT_DATA
  config.dataset_configs.ohe_labels = OHE_LABELS

  # Model.
  # VARIANT is 'Name/patch_h/patch_w' or 'Name/patch' (square patches), where
  # patch_h spans time steps and patch_w spans features.
  variant_parts = VARIANT.split('/')
  if len(variant_parts) == 3:
    version, patch_h, patch_w = variant_parts
  elif len(variant_parts) == 2:
    version, patch_h = variant_parts
    patch_w = patch_h
  else:
    raise ValueError(f'Invalid model variant: {VARIANT}')

  version = 'Deb' if runlocal else version
  config.model_name = 'lsm_vit_mae'
  config.model = ml_collections.ConfigDict()
  # encoder
  config.model.hidden_size = model_constants.HIDDEN_SIZES[version]
  config.model.patches = ml_collections.ConfigDict()
  config.model.patches.size = (int(patch_h), int(patch_w))
  config.model.num_heads = model_constants.NUM_HEADS[version]
  config.model.mlp_dim = model_constants.MLP_DIMS[version]
  config.model.num_layers = model_constants.NUM_LAYERS[version]
  config.model.dropout_rate = 0.0
  config.model.classifier = 'none'  # Has to be "none" for the autoencoder
  config.model.representation_size = None
  config.model.positional_embedding = 'sinusoidal_2d'
  config.model.positional_embedding_decoder = 'sinusoidal_2d'
  # decoder
  config.model.decoder_config = ml_collections.ConfigDict()
  config.model.decoder_config.hidden_size = (
      model_constants.DECODER_HIDDEN_SIZES[version]
  )
  config.model.decoder_config.mlp_dim = model_constants.DECODER_MLP_DIMS[
      version
  ]
  config.model.decoder_config.num_layers = model_constants.DECODER_NUM_LAYERS[
      version
  ]
  config.model.decoder_config.num_heads = model_constants.DECODER_NUM_HEADS[
      version
  ]
  config.model.decoder_config.dropout_rate = 0.0
  config.model.decoder_config.attention_dropout_rate = 0.0

  config.masked_feature_loss = ml_collections.ConfigDict()
  config.masked_feature_loss.targets_type = 'rgb'
  config.masked_feature_loss.token_mask_probability = TOKEN_MASK_PROB
  config.masked_feature_loss.loss_only_masked_tokens = LOSS_ONLY_MASKED_TOKENS
  config.masked_feature_loss.loss_type = 'squared'  # 'squared' or 'absolute'

  # Datetime features.
  config.use_datetime_features = USE_DATETIME_FEATURES

  # Training.
  config.trainer_name = 'lsm_mae_trainer'
  config.batch_size = BATCH_SIZE
  config.num_training_steps = NUM_TRAIN_STEPS
  config.log_eval_steps = LOG_EVAL_SUMMARY_STEPS
  config.log_summary_steps = LOG_EVAL_SUMMARY_STEPS
  config.rng_seed = 42
  config.use_train_augmentations = USE_TRAIN_AUGMENTATIONS
  config.train_augmentations = TRAIN_AUGMENTATIONS
  sched = ml_collections.ConfigDict()
  sched.re = '(.*)'
  sched.lr_configs = ml_collections.ConfigDict()
  sched.lr_configs.learning_rate_schedule = 'compound'
  sched.lr_configs.factors = 'constant * cosine_decay * linear_warmup'
  sched.lr_configs.total_steps = NUM_TRAIN_STEPS
  sched.lr_configs.steps_per_cycle = sched.lr_configs.total_steps
  # Warm up over the first 5% of all training steps.
  sched.lr_configs.warmup_steps = STEPS_PER_EPOCH * NUMBER_OF_EPOCH * 0.05
  sched.lr_configs.base_learning_rate = LRS[0]
  config.schedule = ml_collections.ConfigDict({'all': sched})

  # *Single* optimizer.
  optim = ml_collections.ConfigDict()
  optim.optax_name = 'scale_by_adam'
  # optim.optax = dict(mu_dtype='bfloat16')
  optim.optax_configs = ml_collections.ConfigDict({  # Optimizer settings.
      'b1': 0.9,
      'b2': 0.95,
  })
  # NOTE(review): this sets `config.optax`, whereas the commented-out line
  # above set `optim.optax` — confirm which key the trainer actually reads.
  config.optax = dict(mu_dtype='bfloat16')
  optim.max_grad_norm = 1.0

  optim.weight_decay = 1e-4
  optim.weight_decay_decouple = True
  config.optimizer = optim

  # Logging.
  config.write_summary = True
  config.xprof = True  # Profile using xprof.
  config.checkpoint = True  # Do checkpointing.
  config.checkpoint_steps = LOG_CHECKPOINT_STEPS
  config.debug_train = False  # Debug mode during training.
  config.debug_eval = False  # Debug mode during eval.
  config.max_checkpoints_to_keep = MAX_NUM_CHECKPOINTS
  # BEGIN GOOGLE-INTERNAL
  if runlocal:
    # Current implementation fails with UPTC.
    config.count_flops = False
  # END GOOGLE-INTERNAL

  return config


# BEGIN GOOGLE-INTERNAL
def get_hyper(hyper):
  """Builds the grid-search sweep over base learning rates.

  Args:
    hyper: Sweep builder exposing `sweep(key, values)` and `product(sweeps)`.

  Returns:
    The product of a single sweep over every learning rate in LRS.
  """
  lr_sweep = hyper.sweep(
      'config.schedule.all.lr_configs.base_learning_rate', LRS
  )
  return hyper.product([lr_sweep])


In [None]:
# @title Set random seeds

# Seed the global Python, NumPy and TensorFlow RNGs with one shared value so
# repeated runs draw the same random numbers; change 42 to use another seed.
tf.random.set_seed(seed=42)
np.random.seed(seed=42)
random.seed(a=42)


# Activity Recognition

In [None]:
# @title Load Train and Test Datasets

# Constants
DATASET_NAME = 'lsm_300min_600_activities_remapped_8class'
TRAIN_DATA_SIZE = None  # Forwarded to config.dataset_configs.train_num_samples.

# Load Train Dataset
# get_config() reads the module-level BATCH_SIZE (and DATASET_NAME), so the
# batch size is set to the full split size before the config is built; the
# single next() call below is then meant to return the whole split.
# Sizes are looked up via DATASET_NAME so the name is defined in one place.
BATCH_SIZE = dataset_constants.lsm_dataset_constants[DATASET_NAME][
    'num_train_examples'
]

config = get_config(runlocal=False)  # must be false to get full dataset
rng = jax.random.PRNGKey(config.rng_seed)
data_rng, rng = jax.random.split(rng)
dataset = lsm_remapped_activity_dataset.get_dataset(config, data_rng)

# Get Batch and Format to Desired Shape
# Fold the two leading axes (presumably devices x per-device batch) into one
# example axis and flatten the remaining axes into one feature vector each.
data_batch = next(dataset.train_iter)
data_keys = data_batch.keys()
p, b, t, f, c = data_batch['input_signal'].shape
X_train = jnp.reshape(data_batch['input_signal'], [p*b, -1])
y_train = jnp.reshape(data_batch['label'], [p*b, -1])

# Load Test Dataset
BATCH_SIZE = dataset_constants.lsm_dataset_constants[DATASET_NAME][
    'num_test_examples'
]

config = get_config(runlocal=False)  # must be false to get full dataset
rng = jax.random.PRNGKey(config.rng_seed)
data_rng, rng = jax.random.split(rng)
dataset = lsm_remapped_activity_dataset.get_dataset(config, data_rng)

# Get Batch and Format to Desired Shape
# Test examples come from valid_iter (the config's eval_split is 'test').
data_batch = next(dataset.valid_iter)
data_keys = data_batch.keys()
p, b, t, f, c = data_batch['input_signal'].shape
X_test = jnp.reshape(data_batch['input_signal'], [p*b, -1])
y_test = jnp.reshape(data_batch['label'], [p*b, -1])

# Print
print('Data keys: ', data_keys)
print('X_train shape: ', X_train.shape)
print('y_train shape: ', y_train.shape)
print('X_test shape: ', X_test.shape)
print('y_test shape: ', y_test.shape)

In [None]:
# @title Random Forest
# Random-forest hyper-parameters.
n_estimators = 500
max_samples = 0.75  # Fraction of training rows drawn for each tree.
max_depth = 10
clf = RandomForestClassifier(
    n_estimators=n_estimators,
    max_depth=max_depth, random_state=42,
    max_samples=max_samples, n_jobs=5,
)

# sklearn expects integer class ids; labels are assumed one-hot (the config
# sets ohe_labels=True), so recover the index with argmax.
y_train_rf = jnp.argmax(y_train, axis=1)
y_test_rf = jnp.argmax(y_test, axis=1)

# Fit Model
clf.fit(X_train, y_train_rf)

# Eval
# Accuracy
acc = clf.score(X_test, y_test_rf)

# Mean Average Precision
# predict_proba only has columns for clf.classes_ (classes seen during fit).
# Scatter them into a (num_examples, num_classes) matrix aligned with the
# one-hot columns of y_test, so a class missing from the train split cannot
# cause a shape mismatch; unseen classes get a score of 0.
y_probs = np.zeros((X_test.shape[0], y_test.shape[1]))
y_probs[:, np.asarray(clf.classes_)] = clf.predict_proba(X_test)
mean_average_precision = average_precision_score(y_test, y_probs, average='macro')

print('Acc:', acc)
print('mAP:', mean_average_precision)

In [None]:
# @title Logistic Regression

# Logistic regression fitting one binary problem per class ('ovr').
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases;
# when upgrading, wrap the estimator in OneVsRestClassifier instead.
clf = LogisticRegression(random_state=42, n_jobs=1, multi_class='ovr', max_iter=1000, solver='saga')

# sklearn expects integer class ids; recover them from the one-hot labels.
y_train_lr = jnp.argmax(y_train, axis=1)
y_test_lr = jnp.argmax(y_test, axis=1)

# Fit Model (perf_counter is monotonic, unlike wall-clock time.time()).
start_t = time.perf_counter()
clf.fit(X_train, y_train_lr)
end_t = time.perf_counter()

print('\nFit time', end_t - start_t)

# Eval
# Accuracy
acc = clf.score(X_test, y_test_lr)

# Mean Average Precision
# Align predict_proba's clf.classes_ columns with y_test's one-hot columns so
# a class absent from the train split cannot cause a shape mismatch.
y_probs = np.zeros((X_test.shape[0], y_test.shape[1]))
y_probs[:, np.asarray(clf.classes_)] = clf.predict_proba(X_test)
mean_average_precision = average_precision_score(y_test, y_probs, average='macro')

print('\nAcc:', acc)
print('mAP:', mean_average_precision)

# Exercise Detection

In [None]:
# @title Load Train and Test Datasets

# Constants
DATASET_NAME = 'lsm_300min_mood_vs_activity'
TRAIN_DATA_SIZE = None

# ---- Train split ----
# get_config() reads the module-level BATCH_SIZE, so it is set to the sum of
# the balanced mood and activity train-split sizes (presumably the two sources
# of this dataset) before the config is built.
mood_samples = dataset_constants.lsm_dataset_constants[
    'lsm_300min_2000_mood_balanced']['num_train_examples']
act_samples = dataset_constants.lsm_dataset_constants[
    'lsm_300min_600_activities_balanced']['num_train_examples']
BATCH_SIZE = act_samples + mood_samples

config = get_config(runlocal=False)  # must be false to get full dataset
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_mood_vs_activity_dataset.get_dataset(config, data_rng)

# Flatten one batch into 2-D arrays: the two leading axes (presumably devices
# x per-device batch) become the example axis, everything else the features.
data_batch = next(dataset.train_iter)
data_keys = data_batch.keys()
p, b, t, f, c = data_batch['input_signal'].shape
X_train = jnp.reshape(data_batch['input_signal'], [p * b, -1])
y_train = jnp.reshape(data_batch['label'], [p * b, -1])

# ---- Test split ----
# Same procedure with the combined test-split size; test examples come from
# valid_iter (the config's eval_split is 'test').
mood_samples = dataset_constants.lsm_dataset_constants[
    'lsm_300min_2000_mood_balanced']['num_test_examples']
act_samples = dataset_constants.lsm_dataset_constants[
    'lsm_300min_600_activities_balanced']['num_test_examples']
BATCH_SIZE = act_samples + mood_samples

config = get_config(runlocal=False)  # must be false to get full dataset
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_mood_vs_activity_dataset.get_dataset(config, data_rng)

data_batch = next(dataset.valid_iter)
data_keys = data_batch.keys()
p, b, t, f, c = data_batch['input_signal'].shape
X_test = jnp.reshape(data_batch['input_signal'], [p * b, -1])
y_test = jnp.reshape(data_batch['label'], [p * b, -1])

# Report the flattened shapes.
print('Data keys: ', data_keys)
print('X_train shape: ', X_train.shape)
print('y_train shape: ', y_train.shape)
print('X_test shape: ', X_test.shape)
print('y_test shape: ', y_test.shape)

In [None]:
# @title Random Forest
# Random-forest hyper-parameters.
n_estimators = 500
max_samples = 0.75  # Fraction of training rows drawn for each tree.
max_depth = 10
clf = RandomForestClassifier(
    n_estimators=n_estimators,
    max_depth=max_depth, random_state=42,
    max_samples=max_samples, n_jobs=5,
)

# sklearn expects integer class ids; labels are assumed one-hot (the config
# sets ohe_labels=True), so recover the index with argmax.
y_train_rf = jnp.argmax(y_train, axis=1)
y_test_rf = jnp.argmax(y_test, axis=1)

# Fit Model
clf.fit(X_train, y_train_rf)

# Eval
# Accuracy
acc = clf.score(X_test, y_test_rf)

# Mean Average Precision
# predict_proba only has columns for clf.classes_ (classes seen during fit).
# Scatter them into a (num_examples, num_classes) matrix aligned with the
# one-hot columns of y_test, so a class missing from the train split cannot
# cause a shape mismatch; unseen classes get a score of 0.
y_probs = np.zeros((X_test.shape[0], y_test.shape[1]))
y_probs[:, np.asarray(clf.classes_)] = clf.predict_proba(X_test)
mean_average_precision = average_precision_score(y_test, y_probs, average='macro')

print('Acc:', acc)
print('mAP:', mean_average_precision)


In [None]:
# @title Logistic Regression

# Logistic regression fitting one binary problem per class ('ovr').
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases;
# when upgrading, wrap the estimator in OneVsRestClassifier instead.
clf = LogisticRegression(random_state=42, n_jobs=1, multi_class='ovr', max_iter=1000, solver='saga')

# sklearn expects integer class ids; recover them from the one-hot labels.
y_train_lr = jnp.argmax(y_train, axis=1)
y_test_lr = jnp.argmax(y_test, axis=1)

# Fit Model (perf_counter is monotonic, unlike wall-clock time.time()).
start_t = time.perf_counter()
clf.fit(X_train, y_train_lr)
end_t = time.perf_counter()

print('\nFit time', end_t - start_t)

# Eval
# Accuracy
acc = clf.score(X_test, y_test_lr)

# Mean Average Precision
# Align predict_proba's clf.classes_ columns with y_test's one-hot columns so
# a class absent from the train split cannot cause a shape mismatch.
y_probs = np.zeros((X_test.shape[0], y_test.shape[1]))
y_probs[:, np.asarray(clf.classes_)] = clf.predict_proba(X_test)
mean_average_precision = average_precision_score(y_test, y_probs, average='macro')

print('\nAcc:', acc)
print('mAP:', mean_average_precision)

# Subject-Dependent Mood

In [None]:
# @title Load Train and Test Datasets

# Constants
DATASET_NAME = 'lsm_300min_2000_mood_subject_dependent_preprocessed_40sps'
TRAIN_DATA_SIZE = None

# ---- Train split ----
# get_config() reads the module-level BATCH_SIZE, so it is set before the
# config is built.
# NOTE(review): split sizes are hard-coded here instead of being looked up in
# dataset_constants (as the activity cells do) — confirm they match the data.
BATCH_SIZE = 3553

config = get_config(runlocal=False)  # must be false to get full dataset
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = (
    lsm_mood_subj_dependent_preprocessed_40sps_dataset
    .get_preprocessed_dataset(config, data_rng)
)

# Flatten one batch into 2-D arrays: the two leading axes (presumably devices
# x per-device batch) become the example axis, everything else the features.
data_batch = next(dataset.train_iter)
data_keys = data_batch.keys()
p, b, t, f, c = data_batch['input_signal'].shape
X_train = jnp.reshape(data_batch['input_signal'], [p * b, -1])
y_train = jnp.reshape(data_batch['label'], [p * b, -1])

# ---- Test split ----
# Same procedure with the test-split size; test examples come from
# valid_iter (the config's eval_split is 'test').
BATCH_SIZE = 1154

config = get_config(runlocal=False)  # must be false to get full dataset
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = (
    lsm_mood_subj_dependent_preprocessed_40sps_dataset
    .get_preprocessed_dataset(config, data_rng)
)

data_batch = next(dataset.valid_iter)
data_keys = data_batch.keys()
p, b, t, f, c = data_batch['input_signal'].shape
X_test = jnp.reshape(data_batch['input_signal'], [p * b, -1])
y_test = jnp.reshape(data_batch['label'], [p * b, -1])

# Report the flattened shapes.
print('Data keys: ', data_keys)
print('X_train shape: ', X_train.shape)
print('y_train shape: ', y_train.shape)
print('X_test shape: ', X_test.shape)
print('y_test shape: ', y_test.shape)

In [None]:
# @title Random Forest
# Random-forest hyper-parameters.
n_estimators = 500
max_samples = 0.75  # Fraction of training rows drawn for each tree.
max_depth = 10
clf = RandomForestClassifier(
    n_estimators=n_estimators,
    max_depth=max_depth, random_state=42,
    max_samples=max_samples, n_jobs=5,
)

# sklearn expects integer class ids; labels are assumed one-hot (the config
# sets ohe_labels=True), so recover the index with argmax.
y_train_rf = jnp.argmax(y_train, axis=1)
y_test_rf = jnp.argmax(y_test, axis=1)

# Fit Model
clf.fit(X_train, y_train_rf)

# Eval
# Accuracy
acc = clf.score(X_test, y_test_rf)

# Mean Average Precision
# predict_proba only has columns for clf.classes_ (classes seen during fit).
# Scatter them into a (num_examples, num_classes) matrix aligned with the
# one-hot columns of y_test, so a class missing from the train split cannot
# cause a shape mismatch; unseen classes get a score of 0.
y_probs = np.zeros((X_test.shape[0], y_test.shape[1]))
y_probs[:, np.asarray(clf.classes_)] = clf.predict_proba(X_test)
mean_average_precision = average_precision_score(y_test, y_probs, average='macro')

print('Acc:', acc)
print('mAP:', mean_average_precision)


In [None]:
# @title Logistic Regression
max_iter = 1000

# Logistic regression fitting one binary problem per class ('ovr').
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases;
# when upgrading, wrap the estimator in OneVsRestClassifier instead.
clf = LogisticRegression(random_state=42, n_jobs=1, multi_class='ovr', max_iter=max_iter, solver='saga')

# sklearn expects integer class ids; recover them from the one-hot labels.
y_train_lr = jnp.argmax(y_train, axis=1)
y_test_lr = jnp.argmax(y_test, axis=1)

# Fit Model (perf_counter is monotonic, unlike wall-clock time.time()).
start_t = time.perf_counter()
clf.fit(X_train, y_train_lr)
end_t = time.perf_counter()

print('\nFit time', end_t - start_t)

# Eval
# Accuracy
acc = clf.score(X_test, y_test_lr)

# Mean Average Precision
# Align predict_proba's clf.classes_ columns with y_test's one-hot columns so
# a class absent from the train split cannot cause a shape mismatch.
y_probs = np.zeros((X_test.shape[0], y_test.shape[1]))
y_probs[:, np.asarray(clf.classes_)] = clf.predict_proba(X_test)
mean_average_precision = average_precision_score(y_test, y_probs, average='macro')

print('\nAcc:', acc)
print('mAP:', mean_average_precision)

# Biological Sex

In [None]:
# @title Load Train and Test Datasets

# Constants
DATASET_NAME = 'lsm_300min_600_biological_sex'
TRAIN_DATA_SIZE = None

# ---- Train split ----
# get_config() reads the module-level BATCH_SIZE, so it is set before the
# config is built.
# NOTE(review): split sizes are hard-coded here instead of being looked up in
# dataset_constants (as the activity cells do) — confirm they match the data.
BATCH_SIZE = 14203

config = get_config(runlocal=False)  # must be false to get full dataset
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_biological_sex_dataset.get_dataset(config, data_rng)

# Flatten one batch into 2-D arrays: the two leading axes (presumably devices
# x per-device batch) become the example axis, everything else the features.
data_batch = next(dataset.train_iter)
data_keys = data_batch.keys()
p, b, t, f, c = data_batch['input_signal'].shape
X_train = jnp.reshape(data_batch['input_signal'], [p * b, -1])
y_train = jnp.reshape(data_batch['label'], [p * b, -1])

# ---- Test split ----
# Same procedure with the test-split size; test examples come from
# valid_iter (the config's eval_split is 'test').
BATCH_SIZE = 3250

config = get_config(runlocal=False)  # must be false to get full dataset
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_biological_sex_dataset.get_dataset(config, data_rng)

data_batch = next(dataset.valid_iter)
data_keys = data_batch.keys()
p, b, t, f, c = data_batch['input_signal'].shape
X_test = jnp.reshape(data_batch['input_signal'], [p * b, -1])
y_test = jnp.reshape(data_batch['label'], [p * b, -1])

# Report the flattened shapes.
print('Data keys: ', data_keys)
print('X_train shape: ', X_train.shape)
print('y_train shape: ', y_train.shape)
print('X_test shape: ', X_test.shape)
print('y_test shape: ', y_test.shape)

In [None]:
# @title Random Forest
# Random-forest hyper-parameters.
n_estimators = 500
max_samples = 0.75  # Fraction of training rows drawn for each tree.
max_depth = 10
clf = RandomForestClassifier(
    n_estimators=n_estimators,
    max_depth=max_depth, random_state=42,
    max_samples=max_samples, n_jobs=5,
)

# sklearn expects integer class ids; labels are assumed one-hot (the config
# sets ohe_labels=True), so recover the index with argmax.
y_train_rf = jnp.argmax(y_train, axis=1)
y_test_rf = jnp.argmax(y_test, axis=1)

# Fit Model
clf.fit(X_train, y_train_rf)

# Eval
# Accuracy
acc = clf.score(X_test, y_test_rf)

# Mean Average Precision
# predict_proba only has columns for clf.classes_ (classes seen during fit).
# Scatter them into a (num_examples, num_classes) matrix aligned with the
# one-hot columns of y_test, so a class missing from the train split cannot
# cause a shape mismatch; unseen classes get a score of 0.
y_probs = np.zeros((X_test.shape[0], y_test.shape[1]))
y_probs[:, np.asarray(clf.classes_)] = clf.predict_proba(X_test)
mean_average_precision = average_precision_score(y_test, y_probs, average='macro')

print('Acc:', acc)
print('mAP:', mean_average_precision)


In [None]:
# @title Logistic Regression

# Logistic regression fitting one binary problem per class ('ovr').
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases;
# when upgrading, wrap the estimator in OneVsRestClassifier instead.
clf = LogisticRegression(random_state=42, n_jobs=1, multi_class='ovr', max_iter=1000, solver='saga')

# sklearn expects integer class ids; recover them from the one-hot labels.
y_train_lr = jnp.argmax(y_train, axis=1)
y_test_lr = jnp.argmax(y_test, axis=1)

# Fit Model (perf_counter is monotonic, unlike wall-clock time.time()).
start_t = time.perf_counter()
clf.fit(X_train, y_train_lr)
end_t = time.perf_counter()

print('\nFit time', end_t - start_t)

# Eval
# Accuracy
acc = clf.score(X_test, y_test_lr)

# Mean Average Precision
# Align predict_proba's clf.classes_ columns with y_test's one-hot columns so
# a class absent from the train split cannot cause a shape mismatch.
y_probs = np.zeros((X_test.shape[0], y_test.shape[1]))
y_probs[:, np.asarray(clf.classes_)] = clf.predict_proba(X_test)
mean_average_precision = average_precision_score(y_test, y_probs, average='macro')

print('\nAcc:', acc)
print('mAP:', mean_average_precision)

# Binned Age

In [None]:
# @title Load Train and Test Datasets

# Constants
# NOTE(review): 'binnned' (three n's) — confirm this matches the registered
# dataset name before "fixing" the spelling.
DATASET_NAME = 'lsm_300min_600_binnned_age'
TRAIN_DATA_SIZE = None

# ---- Train split ----
# get_config() reads the module-level BATCH_SIZE, so it is set before the
# config is built.
# NOTE(review): split sizes are hard-coded here instead of being looked up in
# dataset_constants (as the activity cells do) — confirm they match the data.
BATCH_SIZE = 14372

config = get_config(runlocal=False)  # must be false to get full dataset
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_binned_age_dataset.get_dataset(config, data_rng)

# Flatten one batch into 2-D arrays: the two leading axes (presumably devices
# x per-device batch) become the example axis, everything else the features.
data_batch = next(dataset.train_iter)
data_keys = data_batch.keys()
p, b, t, f, c = data_batch['input_signal'].shape
X_train = jnp.reshape(data_batch['input_signal'], [p * b, -1])
y_train = jnp.reshape(data_batch['label'], [p * b, -1])

# ---- Test split ----
# Same procedure with the test-split size; test examples come from
# valid_iter (the config's eval_split is 'test').
BATCH_SIZE = 3262

config = get_config(runlocal=False)  # must be false to get full dataset
data_rng, rng = jax.random.split(jax.random.PRNGKey(config.rng_seed))
dataset = lsm_binned_age_dataset.get_dataset(config, data_rng)

data_batch = next(dataset.valid_iter)
data_keys = data_batch.keys()
p, b, t, f, c = data_batch['input_signal'].shape
X_test = jnp.reshape(data_batch['input_signal'], [p * b, -1])
y_test = jnp.reshape(data_batch['label'], [p * b, -1])

# Report the flattened shapes.
print('Data keys: ', data_keys)
print('X_train shape: ', X_train.shape)
print('y_train shape: ', y_train.shape)
print('X_test shape: ', X_test.shape)
print('y_test shape: ', y_test.shape)

In [None]:
# @title Random Forest
# Random-forest hyper-parameters.
n_estimators = 500
max_samples = 0.75  # Fraction of training rows drawn for each tree.
max_depth = 10
clf = RandomForestClassifier(
    n_estimators=n_estimators,
    max_depth=max_depth, random_state=42,
    max_samples=max_samples, n_jobs=5,
)

# sklearn expects integer class ids; labels are assumed one-hot (the config
# sets ohe_labels=True), so recover the index with argmax.
y_train_rf = jnp.argmax(y_train, axis=1)
y_test_rf = jnp.argmax(y_test, axis=1)

# Fit Model
clf.fit(X_train, y_train_rf)

# Eval
# Accuracy
acc = clf.score(X_test, y_test_rf)

# Mean Average Precision
# predict_proba only has columns for clf.classes_ (classes seen during fit).
# Scatter them into a (num_examples, num_classes) matrix aligned with the
# one-hot columns of y_test, so a class missing from the train split cannot
# cause a shape mismatch; unseen classes get a score of 0.
y_probs = np.zeros((X_test.shape[0], y_test.shape[1]))
y_probs[:, np.asarray(clf.classes_)] = clf.predict_proba(X_test)
mean_average_precision = average_precision_score(y_test, y_probs, average='macro')

print('Acc:', acc)
print('mAP:', mean_average_precision)


In [None]:
# @title Logistic Regression

# Logistic regression fitting one binary problem per class ('ovr').
# NOTE(review): `multi_class` is deprecated in recent scikit-learn releases;
# when upgrading, wrap the estimator in OneVsRestClassifier instead.
clf = LogisticRegression(random_state=42, n_jobs=1, multi_class='ovr', max_iter=1000, solver='saga')

# sklearn expects integer class ids; recover them from the one-hot labels.
y_train_lr = jnp.argmax(y_train, axis=1)
y_test_lr = jnp.argmax(y_test, axis=1)

# Fit Model (perf_counter is monotonic, unlike wall-clock time.time()).
start_t = time.perf_counter()
clf.fit(X_train, y_train_lr)
end_t = time.perf_counter()

print('\nFit time', end_t - start_t)

# Eval
# Accuracy
acc = clf.score(X_test, y_test_lr)

# Mean Average Precision
# Align predict_proba's clf.classes_ columns with y_test's one-hot columns so
# a class absent from the train split cannot cause a shape mismatch.
y_probs = np.zeros((X_test.shape[0], y_test.shape[1]))
y_probs[:, np.asarray(clf.classes_)] = clf.predict_proba(X_test)
mean_average_precision = average_precision_score(y_test, y_probs, average='macro')

print('\nAcc:', acc)
print('mAP:', mean_average_precision)