In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'cpu')
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


sys.path.append("../..")

from lib import utils as U
from lib.ehr import load_dataset, load_dataset_scheme, load_dataset_config, Dataset
from lib.ehr.interface import Patients, InterfaceConfig
from lib.ehr.concepts import DemographicVectorConfig


In [3]:
import logging
logging.root.level = logging.DEBUG

In [4]:
tag = 'M4'
PATH = f'{os.environ.get("HOME")}/GP/ehr-data/mimic4-cohort'
sample = 100
cache =  f'cached_inteface/patients_{tag}_{sample or ""}'
dataset_config = load_dataset_config(tag, 
                                     sample=sample,
                                     path=PATH)

##### Possible Interface Scheme Configurations

In [5]:
import json
dataset_scheme = load_dataset_scheme(tag)
interface_schem_options = dataset_scheme.supported_target_scheme_options
print(json.dumps(interface_schem_options, sort_keys=True, indent=4))

DEBUG:root:Constructing mimic4_eth32 (<class 'lib.ehr.coding_scheme.MIMIC4Eth32'>) scheme
DEBUG:root:Constructing gender (<class 'lib.ehr.coding_scheme.Gender'>) scheme
DEBUG:root:Constructing dx_icd10 (<class 'lib.ehr.coding_scheme.DxICD10'>) scheme
DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme


{
    "dx": [
        "DxICD10",
        "DxCCS",
        "DxICD9",
        "DxFlatCCS"
    ],
    "ethnicity": [
        "MIMIC4Eth32",
        "MIMIC4Eth5"
    ],
    "gender": [
        "Gender"
    ],
    "outcome": [
        "dx_icd9_filter_v3_groups",
        "dx_flatccs_mlhc_groups",
        "dx_icd9_filter_v2_groups",
        "dx_flatccs_filter_v1",
        "dx_icd9_filter_v1"
    ]
}


In [6]:
interface_scheme = dataset_scheme.make_target_scheme_config(dx='DxICD9',
                                                            outcome='dx_icd9_filter_v3_groups',
                                                            ethnicity='MIMIC4Eth5')


# Demographic vector attributes
demographic_vector_conf = DemographicVectorConfig(
    age=False, 
    gender=False, 
    ethnicity=False
)
interface_config = InterfaceConfig(scheme=interface_scheme,
                                   dataset_scheme=dataset_scheme,
                                   demographic_vector=demographic_vector_conf,
                                   cache=cache)

In [7]:
from lib.ml import (ICENODE, ICENODEConfig, OutpatientEmbeddingConfig,  SplitConfig,
                    Trainer, TrainerConfig, TrainerReporting, OptimizerConfig, WarmupConfig, ReportingConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC, CodeLevelMetricConfig, MetricLevelsConfig,
                         LossMetricConfig,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric, CodeGroupTopAlarmAccuracyConfig)
from lib.ml import Experiment, ExperimentConfig, SplitConfig

import jax.random as jrandom

In [8]:
emb_dims = OutpatientEmbeddingConfig(dx=30, demo=0)
model_config = ICENODEConfig(mem=15, emb=emb_dims)
model_classname = ICENODE.__name__

In [9]:
trainer_config = TrainerConfig(optimizer=OptimizerConfig(opt='adam', lr=1e-3),
                          epochs=80,
                          batch_size=128,
                          dx_loss='balanced_focal_bce',
                          obs_loss='mse',
                          lead_loss='mse')

warmup = WarmupConfig(epochs=0.1, 
                      batch_size=8,
                      opt='adam', lr=1e-3, 
                      decay_rate=0.5)




In [10]:
dx_loss = ["softmax_bce", "balanced_focal_softmax_bce", "balanced_focal_bce",
          "allpairs_exp_rank", "allpairs_hard_rank", "allpairs_sigmoid_rank"]
obs_loss =  ["mse", "mae", "rms"]
                
metrics_conf = [
    (CodeAUC, CodeLevelMetricConfig(aggregate_level=True, code_level=True)),
    (AdmissionAUC, MetricLevelsConfig(admission=False, aggregate=True, subject_aggregate=False)),
    (CodeGroupTopAlarmAccuracy, CodeGroupTopAlarmAccuracyConfig(n_partitions=5, top_k_list=[3, 5, 10, 15, 20])),
    (LossMetric, LossMetricConfig(dx_loss=dx_loss))
]
metrics_conf = [m.export_module_class(c) for m, c in metrics_conf]

In [11]:
reporting_conf = ReportingConfig(output_dir='icenode',
                                 console=True,
                                 model_stats=False,
                                 parameter_snapshots=True,
                                 config_json=True)

In [12]:
expt_config = ExperimentConfig(dataset=dataset_config,
                              interface=interface_config,
                              split=SplitConfig(train=0.8, val=0.1, test=0.1, balanced='admissions'),
                              trainer=trainer_config,
                              metrics=metrics_conf,
                              reporting=reporting_conf,
                              model=model_config,
                              model_classname=model_classname,
                              n_evals=100,
                              continue_training=True,
                              warmup=None,
                              reg_hyperparams=None)

In [13]:
experiment = Experiment(expt_config)

In [14]:
expt_config

ExperimentConfig(
  dataset=DatasetConfig(
    path='/home/asem/GP/ehr-data/mimic4-cohort',
    scheme=DatasetSchemeConfig(
      dx={'10': 'DxICD10', '9': 'DxICD9'},
      ethnicity='MIMIC4Eth32',
      gender='Gender',
      outcome=None
    ),
    scheme_classname='MIMIC4DatasetScheme',
    colname={
      'adm':
      {
        'admittime':
        'admittime',
        'dischtime':
        'dischtime',
        'index':
        'hadm_id',
        'subject_id':
        'subject_id'
      },
      'dx':
      {'admission_id': 'hadm_id', 'code': 'icd_code', 'version': 'icd_version'},
      'static':
      {
        'anchor_age':
        'anchor_age',
        'anchor_year':
        'anchor_year',
        'ethnicity':
        'race',
        'gender':
        'gender',
        'index':
        'subject_id'
      }
    },
    files={
      'adm':
      'adm_df.csv.gz',
      'dx':
      'dx_df.csv.gz',
      'static':
      'static_df.csv.gz'
    },
    sample=100,
    meta_fpath='',
    

In [16]:
result = experiment.run()

INFO:root:Cache does not match config, ignoring cache.
INFO:root:Loading subjects from scratch.
DEBUG:root:Constructing mimic4_eth32 (<class 'lib.ehr.coding_scheme.MIMIC4Eth32'>) scheme
DEBUG:root:Constructing gender (<class 'lib.ehr.coding_scheme.Gender'>) scheme
DEBUG:root:Constructing dx_icd10 (<class 'lib.ehr.coding_scheme.DxICD10'>) scheme
DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme
DEBUG:root:Loading dataframe files
Please ensure that each individual file can fit in memory and
use the keyword ``blocksize=None to remove this message``
Setting ``blocksize=None``
  warn(
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4-cohort/adm_df.csv.gz
Please ensure that each individual file can fit in memory and
use the keyword ``blocksize=None to remove this message``
Setting ``blocksize=None``
  warn(
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4-cohort/dx_df.csv.gz
Please ensure that each individual file can fit in memory and
use t

ERROR:root:Code '5990' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010.

ERROR:root:Code '42731' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010

ERROR:root:Code '5569' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010.

ERROR:root:Code '1970' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010.

ERROR:root:Code 'V5811' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010

ERROR:root:Code '43311' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010

ERROR:root:Code '80702' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010

ERROR:root:Code '7802' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010.

ERROR:root:Code '42731' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010

ERROR:root:Code '78079' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010

ERROR:root:Code '0479' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010.

ERROR:root:Code '25013' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010

ERROR:root:Code '6826' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010.

ERROR:root:Code '3229' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010.

ERROR:root:Code '25013' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010

ERROR:root:Code '34550' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010

ERROR:root:Code '34982' is missing. Accepted keys: dict_keys(['001', '001.0', '001.1', '001.9', '002', '002.0', '002.1', '002.2', '002.3', '002.9', '003', '003.0', '003.1', '003.2', '003.20', '003.21', '003.22', '003.23', '003.24', '003.29', '003.8', '003.9', '004', '004.0', '004.1', '004.2', '004.3', '004.8', '004.9', '005', '005.0', '005.1', '005.2', '005.3', '005.4', '005.8', '005.81', '005.89', '005.9', '006', '006.0', '006.1', '006.2', '006.3', '006.4', '006.5', '006.6', '006.8', '006.9', '007', '007.0', '007.1', '007.2', '007.3', '007.4', '007.5', '007.8', '007.9', '008', '008.0', '008.00', '008.01', '008.02', '008.03', '008.04', '008.09', '008.1', '008.2', '008.3', '008.4', '008.41', '008.42', '008.43', '008.44', '008.45', '008.46', '008.47', '008.49', '008.5', '008.6', '008.61', '008.62', '008.63', '008.64', '008.65', '008.66', '008.67', '008.69', '008.8', '009', '009.0', '009.1', '009.2', '009.3', '010', '010.0', '010.00', '010.01', '010.02', '010.03', '010.04', '010.05', '010

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[C

Loading to device:   0%|          | 0/10 [00:00<?, ?subject/s]



  0%|          | 0/80 [00:00<?, ?Epoch/s]

  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[150,17375]), ShapedArray(float32[150]), ShapedArray(float32[30,150]), ShapedArray(float32[30]), ShapedArray(bool[17375])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[15]), ShapedArray(float32[30])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling convert_element_type for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


DEBUG:jax._src.interpreters.pxla:Compiling _integrate for with global shapes and types [ShapedArray(float32[225,45]), ShapedArray(float32[225]), ShapedArray(float32[225,225]), ShapedArray(float32[225]), ShapedArray(float32[45,225]), ShapedArray(float32[45]), ShapedArray(float32[45]), ShapedArray(float32[], weak_type=True), ShapedArray(float16[0])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[45])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assig

DEBUG:jax._src.interpreters.pxla:Compiling balanced_focal_bce for with global shapes and types [ShapedArray(bool[2081]), ShapedArray(float32[2081]), ShapedArray(bool[2081])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), S

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _where for with global shapes and types [ShapedArray(bool[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _where for with global shapes and types [ShapedArray(bool[]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[209]), ShapedArray(float32[]), ShapedArray(float32[])]. Argume

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[209])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[209])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[209])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling reduce_sum for with global shapes and types [ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}),).
DE

DEBUG:jax._src.interpreters.pxla:Compiling _integrate for with global shapes and types [ShapedArray(float32[2,45]), ShapedArray(float32[2]), ShapedArray(float16[0]), ShapedArray(float32[225,45]), ShapedArray(float32[225]), ShapedArray(float32[225,225]), ShapedArray(float32[225]), ShapedArray(float32[45,225]), ShapedArray(float32[45]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(float32[45])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(f

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for

DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[45,225]), ShapedArray(float32[45,225])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[45]), ShapedArray(float32[45])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[150,17375]), ShapedArray(float32[150,17375])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[150,17375]), ShapedArray(float32[150,17375])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[150,17375])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [Shaped

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[150]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[150]), ShapedArray(float32[150])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[150]), ShapedArray(float32[150])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDShardin

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[30])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[30]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[30])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevic

DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[150,30]), ShapedArray(float32[150,30])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[2081,150])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[2081,150]), ShapedArray(float32[2081,150])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_a

DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[2081])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[2081]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[2081]), ShapedArray(float32[2081])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[225])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[225]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[225]), ShapedArray(float32[225])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[45,225]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[45,225])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[45,225]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_c

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[30,60]), ShapedArray(float32[30,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[30,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[30,60]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lam

DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[45,30]), ShapedArray(float32[45,30])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[45,30]), ShapedArray(float32[45,30])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[45,15])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_ass

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[15]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[15]), ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_opt

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[2081,150])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[2081,150])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[2081])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Co

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[30,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[45,30])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[45,30])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Comp

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[150,17375]), ShapedArray(float32[150]), ShapedArray(float32[30,150]), ShapedArray(float32[30]), ShapedArray(bool[17375])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling _integrate for with global shapes and types [ShapedArray(float32[225,45]), ShapedArray(float32[225]), ShapedArray(float32[225,225]), ShapedArray(float32[225]), ShapedArray(float32[45,225]), ShapedArray(float32[45]), ShapedArray(float32[45]), ShapedArray(float32[], weak_type=True), ShapedArray(float16[0])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _decode for with global shapes and types [ShapedArray(float32[150,30]), ShapedArray(float32[150]), ShapedArray(float32[2081,150]), ShapedArray(float32[2081]), ShapedArray(float32[30])]. Argument mapping: (GSP

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling allpairs_exp_rank for with global shapes and types [ShapedArray(bool[2081]), ShapedArray(float32[2081]), ShapedArray(bool[2081])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling allpairs_hard_rank for with global shapes and types [ShapedArray(bool[2081]), ShapedArray(float32[2081]), ShapedArray(bool[2081])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling allpairs_sigmoid_rank for with global shape

Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[8])]. Argument mapping: (GSPMDShardin

  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


Embedding:   0%|          | 0/10 [00:00<?, ?subject/s]

  0%|          | 0.00/3016.61 [00:00<?, ?longitudinal-days/s]

  row.append(agg_f(field_vals))
  return onp.nansum(A * weights, axis=axis) / (
  row.append(agg_f(field_vals))
  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


  0%|          | 0/1 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/75 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/75 [00:00<?, ?subject/s]

  0%|          | 0.00/58917.20 [00:00<?, ?longitudinal-days/s]

  group_alarm_acc[k] = group_tp.sum() / group_true.sum()


KeyboardInterrupt: 

###### 