In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'cpu')
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme, Dataset
from lib.ehr.interface import Patients
from lib.ehr.concepts import CPRDDemographicVectorConfig, DemographicVectorConfig


In [3]:
import logging
logging.root.level = logging.DEBUG

In [4]:
# Assign the folder of the dataset to `DATA_FILE`.

HOME = os.environ.get('HOME')
DATA_DIR = f'{HOME}/GP/ehr-data'
SOURCE_DIR = os.path.abspath("..")
cache_to_disk = None #'cached_inteface/m4inpatients_8000'
use_cached =  False # 'cached_inteface/m4inpatients_8000' #cache_to_disk        # False # 
cprd_interface_scheme = dict(dx='DxLTC212FlatCodes',
                             outcome='dx_cprd_ltc212',
#                              outcome='dx_cprd_ltc9809',
                            ethnicity='CPRDEthnicity5')

m3_interface_scheme = dict(dx='DxICD9',
#                            dx='DxCCS',
                           outcome='dx_flatccs_filter_v1',
#                            outcome='dx_icd9_filter_v3_groups',
                           ethnicity='MIMIC4Eth5'
#                            ethnicity='MIMIC3Eth37'
                          )

In [None]:
if use_cached:
    cprd_patients = Patients.load(use_cached)
    splits = cprd_patients.dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')

else:
    with U.modified_environ(DATA_DIR=DATA_DIR), dask.config.set(scheduler='processes', num_workers=12):
        
        # Load dataset
        cprd_dataset = load_dataset('M4', sample=1000)
        # Use training-split for fitting the outlier_remover and the scalers.
        splits = cprd_dataset.random_splits([0.8, 0.9], random_seed=42, balanced='admissions')

        # Demographic vector attributes
        demographic_vector_conf = DemographicVectorConfig(age=True,
                                                          gender=True,
                                                          ethnicity=True)
        # Load interface
        cprd_patients = Patients(cprd_dataset, demographic_vector_conf,
                                **m3_interface_scheme).load_subjects(num_workers=12)

        # Cache to disk
#         cprd_patients.save(cache_to_disk, overwrite=True)

DEBUG:root:Loading dataframe files
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4-cohort/adm_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4-cohort/dx_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4-cohort/static_df.csv


In [13]:
len(cprd_patients.subjects)

999

In [14]:
from lib.ml import (ICENODE, ICENODEDimensions, OutpatientEmbeddingDimensions, 
                    Trainer, TrainerReporting, OptimizerConfig, WarmupConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

In [15]:
emb_dims = OutpatientEmbeddingDimensions(dx=30, demo=5)
dims = ICENODEDimensions(mem=15, emb=emb_dims)
key = jrandom.PRNGKey(0)

m = ICENODE(dims=dims, 
              scheme=cprd_patients.dataset.scheme,
              demographic_vector_config=cprd_patients.demographic_vector_config,
              key=key)

In [16]:
splits = cprd_patients.random_splits([0.9, 0.95], 
                                    balanced='admissions')

trainer = Trainer(optimizer_config=OptimizerConfig(opt='adam', lr=1e-3),
                    reg_hyperparams=None,
                    epochs=80,
                    batch_size=128,
                    dx_loss='balanced_focal_softmax_bce')

warmup = WarmupConfig(epochs=0.1, 
                      batch_size=8,
                      opt='adam', lr=1e-3, 
                      decay_rate=0.5)

loss_metric =  LossMetric(cprd_patients, 
                          dx_loss=('softmax_bce', 'balanced_focal_softmax_bce', 
                                   'balanced_focal_bce', 'allpairs_exp_rank', 'allpairs_hard_rank', 
                                   'allpairs_sigmoid_rank'))

metrics = [CodeAUC(cprd_patients), 
           AdmissionAUC(cprd_patients), 
           CodeGroupTopAlarmAccuracy(cprd_patients, n_partitions=5, 
                                     top_k_list=[3, 5, 10, 15, 20],
                                     train_split=splits[0]), 
           loss_metric]


reporting = TrainerReporting(output_dir='dx_icenode',
                             metrics=metrics,
                             console=True,
                             parameter_snapshots=True,
                             config_json=True)

In [17]:
res = trainer(m, cprd_patients, 
              splits=splits,
              reporting=reporting,
              n_evals=100,
              warmup_config=warmup,
              continue_training=False)

INFO:root:Warming up...


Loading to device: 0subject [00:00, ?subject/s]

  0%|          | 0/1 [00:00<?, ?Epoch/s]

DEBUG:jax._src.interpreters.pxla:Compiling _shuffle for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(int32[89])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


  0%|          | 0/47 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/2 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/2 [00:00<?, ?subject/s]

  0%|          | 0.00/1257.91 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling _integrate for with global shapes and types [ShapedArray(float32[225,50]), ShapedArray(float32[225]), ShapedArray(float32[225,225]), ShapedArray(float32[225]), ShapedArray(float32[45,225]), ShapedArray(float32[45]), ShapedArray(float32[45]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _decode for with global shapes and types [ShapedArray(float32[150,30]), ShapedArray(float32[150]), ShapedArray(float32[2081,150]), ShapedArray(float32[2081]), ShapedArray(float32[30])]. Argument mapping: (GSP

DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[150,17375])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[150,17375]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[150,17375]), ShapedArray(float32[150,17375])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._sr

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[30,150]), ShapedArray(float32[30,150])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[30,150])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[30,150]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling 

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[30]), ShapedArray(float32[30])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[25,7])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[25,7]), ShapedArray(float32[25,7])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({repli

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[25])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[25]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[25]), ShapedArray(float32[25])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_opt

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[5]), ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[5]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 n

DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[150,30]), ShapedArray(float32[150,30])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[150,30]), ShapedArray(float32[150,30])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[2081,150])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 dev

DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[2081]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[2081])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[2081]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Comp

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[225]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[225])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[225]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_op

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[45,225])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[45,225]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[45,225])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignmen

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[30,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[30,60]), ShapedArray(float32[30,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[30,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compili

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[45,30]), ShapedArray(float32[45,30])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[45,30]), ShapedArray(float32[45,30])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[45,15])]. Argument mapping: (GSPMDSharding({replicated}), GS

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[15]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[15]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_optio

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[25,7])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[25])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[25])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling i

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[225])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[225])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[225,225])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/2 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/2 [00:00<?, ?subject/s]

  0%|          | 0.00/1641.27 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[8])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with glo

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/1 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/1 [00:00<?, ?subject/s]

  0%|          | 0.00/4293.07 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[11])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[11])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[11])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax

Loading to device:   0%|          | 0/1 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/1 [00:00<?, ?subject/s]

  0%|          | 0.00/2516.93 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[10])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_parti

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/6168.25 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[6])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[6]), ShapedArray(float32[]), ShapedArray(float32[])]. Argument mappin

Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/1008.39 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/1 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/1 [00:00<?, ?subject/s]

  0%|          | 0.00/604.97 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[7])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[7]), ShapedArra

Loading to device:   0%|          | 0/4 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/4 [00:00<?, ?subject/s]

  0%|          | 0.00/10702.48 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[3])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[35])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replica

Loading to device:   0%|          | 0/4 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/4 [00:00<?, ?subject/s]

  0%|          | 0.00/3356.18 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/4 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/4 [00:00<?, ?subject/s]

  0%|          | 0.00/10537.49 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/427.76 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[4])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[4]), ShapedArray(float32[]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.x

Loading to device:   0%|          | 0/1 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/1 [00:00<?, ?subject/s]

  0%|          | 0.00/2016.02 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[4])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[20])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[20]), ShapedArray(float32[]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBU

Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/4970.67 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/1 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/1 [00:00<?, ?subject/s]

  0%|          | 0.00/3202.06 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[12])]. Argument mapp

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[12])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[12])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[12])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG

Loading to device:   0%|          | 0/4 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/4 [00:00<?, ?subject/s]

  0%|          | 0.00/5691.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/5763.47 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/1 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/1 [00:00<?, ?subject/s]

  0%|          | 0.00/367.93 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[9])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/1 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/1 [00:00<?, ?subject/s]

  0%|          | 0.00/1994.31 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[2])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[18])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[18]), ShapedArray(float32[]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBU

Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/5550.54 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/2242.85 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/2 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/2 [00:00<?, ?subject/s]

  0%|          | 0.00/4661.96 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/4 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/4 [00:00<?, ?subject/s]

  0%|          | 0.00/1920.88 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/2 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/2 [00:00<?, ?subject/s]

  0%|          | 0.00/5113.61 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/3031.25 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shap

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[13])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[13])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[13])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[13])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax

Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/6042.51 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/5 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/5 [00:00<?, ?subject/s]

  0%|          | 0.00/8040.38 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[11])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[11]), ShapedArray(float32[]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/2 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/2 [00:00<?, ?subject/s]

  0%|          | 0.00/1411.56 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/5 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/5 [00:00<?, ?subject/s]

  0%|          | 0.00/3720.51 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/1 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/1 [00:00<?, ?subject/s]

  0%|          | 0.00/883.04 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/2 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/2 [00:00<?, ?subject/s]

  0%|          | 0.00/4410.88 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/1 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/1 [00:00<?, ?subject/s]

  0%|          | 0.00/2806.42 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/1 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/1 [00:00<?, ?subject/s]

  0%|          | 0.00/246.91 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/2 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/2 [00:00<?, ?subject/s]

  0%|          | 0.00/1796.82 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/1 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/1 [00:00<?, ?subject/s]

  0%|          | 0.00/179.24 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/1208.05 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[6])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[22])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[22]), ShapedArray(float32[]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBU

Loading to device:   0%|          | 0/4 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/4 [00:00<?, ?subject/s]

  0%|          | 0.00/2255.50 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
INFO:root:[DONE] Warming up.
INFO:root:HPs: {'opt_config': {'opt': 'adam', 'lr': 0.001, 'decay_rate': None, 'reverse_schedule': False}, 'reg_hyperparams': None, 'epochs': 80, 'batch_size': 128}


Loading to device:   0%|          | 0/58 [00:00<?, ?subject/s]

  0%|          | 0/80 [00:00<?, ?Epoch/s]

DEBUG:jax._src.interpreters.pxla:Compiling _shuffle for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(int32[890])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


  0%|          | 0/23 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/43 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/43 [00:00<?, ?subject/s]

  0%|          | 0.00/46950.89 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[3])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[115])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for

Loading to device:   0%|          | 0/43 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/43 [00:00<?, ?subject/s]

  0%|          | 0.00/41277.51 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 d

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[127])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:ja

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/25 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/25 [00:00<?, ?subject/s]

  0%|          | 0.00/26175.97 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[128])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean fo

Loading to device:   0%|          | 0/27 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/27 [00:00<?, ?subject/s]

  0%|          | 0.00/20631.87 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.in

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[142])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[142])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[14])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[14])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:j

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[14])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/34 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/34 [00:00<?, ?subject/s]

  0%|          | 0.00/36210.48 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[4])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[148])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[148])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/54 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/54 [00:00<?, ?subject/s]

  0%|          | 0.00/43068.03 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[3])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[131])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]

Loading to device:   0%|          | 0/33 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/33 [00:00<?, ?subject/s]

  0%|          | 0.00/26614.58 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[133])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]

Loading to device:   0%|          | 0/39 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/39 [00:00<?, ?subject/s]

  0%|          | 0.00/60962.31 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[7])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[119])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for

Loading to device:   0%|          | 0/42 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/42 [00:00<?, ?subject/s]

  0%|          | 0.00/32825.45 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[144])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]

Loading to device:   0%|          | 0/38 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/38 [00:00<?, ?subject/s]

  0%|          | 0.00/40194.61 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[4])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[132])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]

Loading to device:   0%|          | 0/42 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/42 [00:00<?, ?subject/s]

  0%|          | 0.00/34905.49 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/37 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/37 [00:00<?, ?subject/s]

  0%|          | 0.00/40189.32 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/48 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/48 [00:00<?, ?subject/s]

  0%|          | 0.00/47542.12 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/33 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/33 [00:00<?, ?subject/s]

  0%|          | 0.00/43524.11 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/42 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/42 [00:00<?, ?subject/s]

  0%|          | 0.00/44763.72 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[9])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[137])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]

Loading to device:   0%|          | 0/39 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/39 [00:00<?, ?subject/s]

  0%|          | 0.00/39402.32 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[7])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[135])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]

Loading to device:   0%|          | 0/39 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/39 [00:00<?, ?subject/s]

  0%|          | 0.00/38302.82 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[6])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[134])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]

Loading to device:   0%|          | 0/47 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/47 [00:00<?, ?subject/s]

  0%|          | 0.00/50919.22 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/20 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/20 [00:00<?, ?subject/s]

  0%|          | 0.00/17691.64 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[101])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[101]), 

Embedding:   0%|          | 0/20 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[150,17375]), ShapedArray(float32[150]), ShapedArray(float32[30,150]), ShapedArray(float32[30]), ShapedArray(float32[25,7]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(float32[5]), ShapedArray(float16[7]), ShapedArray(bool[17375])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


  0%|          | 0.00/17691.64 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling _integrate for with global shapes and types [ShapedArray(float32[225,50]), ShapedArray(float32[225]), ShapedArray(float32[225,225]), ShapedArray(float32[225]), ShapedArray(float32[45,225]), ShapedArray(float32[45]), ShapedArray(float32[45]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _decode for with global shapes and types [ShapedArray(float32[150,30]), ShapedArray(float32[150]), ShapedArray(float32[2081,150]), ShapedArray(float32[2081]), ShapedArray(float32[30])]. Argument mapping: (GSP

DEBUG:jax._src.interpreters.pxla:Compiling allpairs_exp_rank for with global shapes and types [ShapedArray(bool[2081]), ShapedArray(float32[2081]), ShapedArray(bool[2081])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling allpairs_hard_rank for with global shapes and types [ShapedArray(bool[2081]), ShapedArray(float32[2081]), ShapedArray(bool[2081])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling allpairs_sigmoid_rank for with global shapes and types [ShapedArray(bool[2081]), ShapedArray(float32[2081]), ShapedArray(bool[2081])]. Argument mapping: (GSPMDS

Embedding:   0%|          | 0/58 [00:00<?, ?subject/s]

  0%|          | 0.00/52497.25 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[2])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[162])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:

Loading to device:   0%|          | 0/36 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/36 [00:00<?, ?subject/s]

  0%|          | 0.00/34759.10 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[3])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[163])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[163])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[163])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/52 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/52 [00:00<?, ?subject/s]

  0%|          | 0.00/53577.73 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/43 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/43 [00:00<?, ?subject/s]

  0%|          | 0.00/42394.08 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/34 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/34 [00:00<?, ?subject/s]

  0%|          | 0.00/38175.63 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling copy for with global shapes and types [ShapedArray(int32[890])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


  0%|          | 0/23 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/34 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/34 [00:00<?, ?subject/s]

  0%|          | 0.00/37403.22 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[113])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for

Loading to device:   0%|          | 0/21 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/21 [00:00<?, ?subject/s]

  0%|          | 0.00/26229.47 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[7])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[151])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[151])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/41 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/41 [00:00<?, ?subject/s]

  0%|          | 0.00/31088.95 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[6])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[118])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for

Loading to device:   0%|          | 0/49 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/49 [00:00<?, ?subject/s]

  0%|          | 0.00/47620.44 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[2])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[146])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[146])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/43 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/43 [00:00<?, ?subject/s]

  0%|          | 0.00/37492.69 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[2])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[130])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]

Loading to device:   0%|          | 0/27 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/27 [00:00<?, ?subject/s]

  0%|          | 0.00/32267.31 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/38 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/38 [00:00<?, ?subject/s]

  0%|          | 0.00/37218.55 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/71 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/71 [00:00<?, ?subject/s]

  0%|          | 0.00/49964.28 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/28 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/28 [00:00<?, ?subject/s]

  0%|          | 0.00/34119.27 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/44 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/44 [00:00<?, ?subject/s]

  0%|          | 0.00/48058.44 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/29 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/29 [00:00<?, ?subject/s]

  0%|          | 0.00/29730.99 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[12])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[124])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean fo

Loading to device:   0%|          | 0/34 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/34 [00:00<?, ?subject/s]

  0%|          | 0.00/32208.54 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/35 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/35 [00:00<?, ?subject/s]

  0%|          | 0.00/36215.68 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/46 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/46 [00:00<?, ?subject/s]

  0%|          | 0.00/48748.46 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[12])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[140])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/45 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/45 [00:00<?, ?subject/s]

  0%|          | 0.00/43345.18 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[129])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/45 [00:00<?, ?subject/s]

  0%|          | 0.00/43345.18 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[129])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Embedding:   0%|          | 0/58 [00:00<?, ?subject/s]

  0%|          | 0.00/52497.25 [00:00<?, ?longitudinal-days/s]

Loading to device:   0%|          | 0/16 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/16 [00:00<?, ?subject/s]

  0%|          | 0.00/30894.75 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/33 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/33 [00:00<?, ?subject/s]

  0%|          | 0.00/33217.64 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/35 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/35 [00:00<?, ?subject/s]

  0%|          | 0.00/39101.22 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[96])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[96]), ShapedArray(float32[]), ShapedArray(float32[])]. Argumen

Loading to device:   0%|          | 0/33 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/33 [00:00<?, ?subject/s]

  0%|          | 0.00/43688.83 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[6])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[166])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[166])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[166])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/46 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/46 [00:00<?, ?subject/s]

  0%|          | 0.00/48496.65 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/47 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/47 [00:00<?, ?subject/s]

  0%|          | 0.00/37157.20 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[8])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[136])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/48 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/48 [00:00<?, ?subject/s]

  0%|          | 0.00/47866.53 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/47 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/47 [00:00<?, ?subject/s]

  0%|          | 0.00/48924.68 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


  0%|          | 0/23 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/45 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/45 [00:00<?, ?subject/s]

  0%|          | 0.00/48072.10 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/71 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/71 [00:00<?, ?subject/s]

  0%|          | 0.00/57266.49 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


Loading to device:   0%|          | 0/37 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/37 [00:00<?, ?subject/s]

  0%|          | 0.00/42205.59 [00:00<?, ?longitudinal-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]


In [7]:
import copy
class A:
    def __init__(self, x, y):
        self.x = x
        self.y = y

a = A(2, 4)
b = copy.copy(a)
b.x = 6
b.y = 0
a.x

2

In [11]:
a.__class__.__name__

'A'