In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd
import dask

from tqdm import tqdm
import jax
# Select 'gpu' as the JAX platform for this session.
# NOTE(review): presumably this requires a GPU-enabled jaxlib install and a
# visible device — confirm before running on a CPU-only machine.
jax.config.update('jax_platform_name', 'gpu')
# Optional JAX debugging switches, disabled by default; uncomment as needed.
# jax.config.update('jax_log_compiles', True)
# jax.config.update("jax_debug_nans", True)
# jax.config.update("jax_enable_x64", True)

In [2]:


# Extend the module search path with the directory two levels above the
# current working directory. The entry is relative, so it is resolved against
# whatever the working directory is when the imports below execute.
# NOTE(review): presumably this is the repository root that holds the `lib`
# package — confirm the notebook is always launched from its own folder.
sys.path.append("../..")

# Project-local modules; these must come after the path extension above.
from lib import utils as U
from lib.ehr.dataset import load_dataset, load_dataset_scheme, Dataset
from lib.ehr.interface import Patients
from lib.ehr.concepts import DemographicVectorConfig


In [3]:
import logging

# Set the root logger threshold to INFO through the public setter.
# Assigning `logging.root.level` directly skips the level validation and the
# cache invalidation that `Logger.setLevel` performs, so loggers that were
# already used could keep a stale "is enabled" answer for DEBUG records.
logging.getLogger().setLevel(logging.INFO)

In [4]:
# Notebook configuration: dataset location and on-disk cache of the interface.

# Assign the folder of the dataset to `DATA_DIR`.
# `os.path.expanduser('~')` is used instead of `os.environ.get('HOME')` so the
# path never silently becomes 'None/GP/ehr-data' when HOME is not set.
HOME = os.path.expanduser('~')
DATA_DIR = f'{HOME}/GP/ehr-data'
SOURCE_DIR = os.path.abspath("..")

# Path where the built interface is saved (`m4inpatients.save(cache_to_disk, ...)`
# in the data-loading cell). It has to be defined even when the cache is being
# read, otherwise the rebuild branch fails with a NameError once `use_cached`
# is switched off.
cache_to_disk = 'cached_inteface/m4inpatients_8000'

# Path to load the cached interface from; set to a falsy value (e.g. False)
# to rebuild the interface from the raw dataset and re-save it.
use_cached = cache_to_disk


In [5]:
# Produce `m4inpatients` and `splits`, either by rebuilding the interface from
# the raw dataset (when `use_cached` is falsy) or by loading a saved copy.
if not use_cached:
    # NOTE(review): this branch reads `cache_to_disk`, which is not assigned in
    # this cell; it must be defined earlier in the notebook before running it.
    with U.modified_environ(DATA_DIR=DATA_DIR):
        with dask.config.set(scheduler='processes', num_workers=12):

            # Load dataset
            m4icu_dataset = load_dataset('M4ICU', sample=8000)

            # The first split is used for fitting the outlier remover and the
            # scalers below.
            splits = m4icu_dataset.random_splits(
                [0.8, 0.9],
                random_seed=42,
                balanced='admissions',
            )

            # Outlier removal
            outlier_remover = m4icu_dataset.fit_outlier_remover(splits[0])
            m4icu_dataset = m4icu_dataset.remove_outliers(outlier_remover)

            # Scale
            scalers = m4icu_dataset.fit_scalers(splits[0])
            m4icu_dataset = m4icu_dataset.apply_scalers(scalers)

            # Demographic vector attributes
            demographic_vector_conf = DemographicVectorConfig(
                age=True,
                gender=True,
                ethnicity=True,
            )

            # Load interface
            m4inpatients = Patients(
                m4icu_dataset,
                demographic_vector_conf,
            ).load_subjects(num_workers=12)

            # Cache to disk
            m4inpatients.save(cache_to_disk, overwrite=True)
else:
    # Cached path: `use_cached` doubles as the location to load from.
    m4inpatients = Patients.load(use_cached)
    splits = m4inpatients.dataset.random_splits(
        [0.8, 0.9],
        random_seed=42,
        balanced='admissions',
    )

DEBUG:root:Loading dataframe files
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/adm_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/dx_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/static_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/obs_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/int_input.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/int_proc.csv
DEBUG:root:[DONE] Loading dataframe files
DEBUG:root:Preprocess admissions
DEBUG:root:Removing subjects with at least one negative adm_interval: 16
DEBUG:root:adm: Merging overlapping admissions
DEBUG:root:adm: Merged 56 overlapping admissions
DEBUG:root:[DONE] Preprocess admissions
DEBUG:root:Matching admission_id
DEBUG:root:[DONE] Matching admission_id
DEBUG:root:Time casting..
DEBUG:root:[DONE] Time casting..
DEBUG:root:Constructing dx_icd9 (<class 'lib.ehr.coding_

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]


In [6]:
# Display the number of subjects held by the loaded interface.
len(m4inpatients.subjects)

7984

In [7]:
# m4inpatients.size_in_bytes() / 1024 ** 3

In [8]:
# val_batch = m4inpatients.device_batch(splits[1])

In [9]:
# tst_batch = m4inpatients.device_batch(splits[2])

In [10]:
# val_batch.size_in_bytes() / 1024 ** 3, tst_batch.size_in_bytes() / 1024 ** 3

In [11]:
# batch = m4inpatients.device_batch(splits[0][:32])

In [12]:
# batch.size_in_bytes() / 1024 ** 3

In [13]:
# len(batch.subjects)

In [14]:
# batch.n_admissions()

In [15]:
# batch.n_segments()

In [16]:
# batch.n_obs_times()

In [17]:
# import numpy as np
# import matplotlib.pyplot as plt

# a = m4inpatients_jax.obs_coocurrence_matrix
# a = np.array(a)
# plt.imshow(a, cmap='hot', interpolation='nearest')
# plt.show()

In [18]:
# s = batch.subjects[splits[0][6]].admissions[0]
# s.observables[0].value

In [19]:
# batch.interval_hours(splits[0][:10])

### التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية


In [20]:
from lib.ml import (InICENODE, InICENODEDimensions, InpatientEmbeddingDimensions, 
                    InTrainer, TrainerReporting, OptimizerConfig, WarmupConfig)
from lib.metric import  (CodeAUC, UntilFirstCodeAUC, AdmissionAUC,
                      CodeGroupTopAlarmAccuracy, LossMetric, ObsCodeLevelLossMetric)

import jax.random as jrandom

DEBUG:matplotlib:matplotlib data path: /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/matplotlib/mpl-data
DEBUG:matplotlib:CONFIGDIR=/home/asem/.config/matplotlib
DEBUG:matplotlib:interactive is False
DEBUG:matplotlib:platform is linux
DEBUG:matplotlib:CACHEDIR=/home/asem/.cache/matplotlib
DEBUG:matplotlib.font_manager:Using fontManager instance from /home/asem/.cache/matplotlib/fontlist-v330.json


In [21]:
# PRNG key (seed 0) handed to the model constructor.
key = jrandom.PRNGKey(0)

# Dimension settings: per-input embedding sizes first, then the model-level
# sizes that wrap them.
emb_dims = InpatientEmbeddingDimensions(
    dx=10,
    inp=10,
    proc=10,
    demo=5,
    inp_proc_demo=15,
)
dims = InICENODEDimensions(mem=15, obs=25, emb=emb_dims)

# Instantiate the model against the coding scheme and demographic vector
# configuration carried by the loaded interface.
m = InICENODE(
    dims=dims,
    scheme=m4inpatients.dataset.scheme,
    demographic_vector_config=m4inpatients.demographic_vector_config,
    key=key,
)

DEBUG:jax._src.interpreters.pxla:Compiling _threefry_seed for with global shapes and types [ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _threefry_split for with global shapes and types [ShapedArray(uint32[2])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(uint32[5,2])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _threefry_split for with global shapes and types [ShapedArray(uint32[2])]. Argument mapping: (GSPMDSharding({replicated}),).
DEB

DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_t

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _threefry_split for with g

DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_t

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global s

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global s

DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_t

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]


In [22]:
# res = m.batch_predict(m4inpatients.device_batch(), leave_pbar=True)

In [23]:
# Main training configuration.
trainer = InTrainer(
    optimizer_config=OptimizerConfig(opt='adam', lr=1e-3),
    reg_hyperparams=None,
    epochs=20,
    batch_size=512,
    dx_loss='balanced_focal_bce',
    obs_loss='mse',
)

# Warm-up configuration, passed to the trainer call as `warmup_config`.
warmup = WarmupConfig(
    epochs=0.1,
    batch_size=32,
    opt='adam',
    lr=1e-3,
    decay_rate=0.5,
)

# Metrics reported during training.
loss_metric = LossMetric(
    m4inpatients,
    dx_loss=(
        'softmax_bce',
        'balanced_focal_softmax_bce',
        'balanced_focal_bce',
        'allpairs_sigmoid_rank',
    ),
    obs_loss=('mse', 'rms', 'mae'),
)
obs_code_loss_metric = ObsCodeLevelLossMetric(
    m4inpatients,
    obs_loss=('mse', 'rms', 'mae'),
)

# Disabled for now: CodeAUC(m4inpatients), AdmissionAUC(m4inpatients)
metrics = [obs_code_loss_metric, loss_metric]

# Reporting options; results are written under the 'inicenode' directory.
reporting = TrainerReporting(
    output_dir='inicenode',
    metrics=metrics,
    console=True,
    parameter_snapshots=True,
    config_json=True,
)

In [None]:
# Disabled progress-bar description helper, kept for reference:
# pbar_metric = LossMetric(m4inpatients, 
#                           dx_loss=('allpairs_sigmoid_rank',),
#                           obs_loss=('mse', ))
# rank_loss_extractor = pbar_metric.value_extractor({'field': 'allpairs_sigmoid_rank'})
# obs_loss_extractor = pbar_metric.value_extractor({'field': 'mse'})
# def pbar_desc_update(predictions):
#     df = pbar_metric.to_df(0, predictions)
#     return f'dx_rank_loss={rank_loss_extractor(df).item(): .3f}, obs_mse_loss={obs_loss_extractor(df).item(): .3f}'


# Re-split the subjects for training; this replaces the `splits` computed in
# the data-loading cell.
# NOTE(review): no `random_seed` is passed here, unlike the earlier split
# (random_seed=42) — confirm the default keeps this run reproducible.
splits = m4inpatients.random_splits(
    [0.9, 0.95],
    balanced='admissions',
)

# Launch training with the warm-up, reporting and metrics configured above.
res = trainer(
    m,
    m4inpatients,
    splits=splits,
    reporting=reporting,
    n_evals=100,
    warmup_config=warmup,
)

INFO:root:Warming up...
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping:

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadc

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replica

Loading to device: 0subject [00:00, ?subject/s]

  0%|          | 0/1 [00:00<?, ?Epoch/s]

DEBUG:jax._src.interpreters.pxla:Compiling _shuffle for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(int32[721])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]


  0%|          | 0/73 [00:00<?, ?Batch/s]

Loading to device:   0%|          | 0/14 [00:00<?, ?subject/s]





Embedding:   0%|          | 0/14 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling atleast_1d for with global shapes and types [ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling convert_element_type for with global shapes and types [ShapedArray(float32[1], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling convert_element_type for with global shapes and types [ShapedArray(bool[6])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float16[1]), ShapedArray(f

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50,17375]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[25,7]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(float32[5]), ShapedArray(float32[32,119]), ShapedArray(float32[50,197]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[50,51]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[75,25]), ShapedArray(float32[75]), ShapedArray(float32[15,75]), ShapedArray(float32[15]), ShapedArray(float16[7]), ShapedArray(bool[17375]), ShapedArray(bool[99,318]), ShapedArray(bool[99,51])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSP

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50,17375]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[25,7]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(float32[5]), ShapedArray(float32[32,119]), ShapedArray(float32[50,197]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[50,51]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[75,25]), ShapedArray(float32[75]), ShapedArray(float32[15,75]), ShapedArray(float32[15]), ShapedArray(float16[7]), ShapedArray(bool[17375]), ShapedArray(float16[799,318]), ShapedArray(bool[799,51])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})

  0%|          | 0.00/290.94 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling join_state for with global shapes and types [ShapedArray(float32[10])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[100])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(bool[100])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[60]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({rep

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[3])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float16[3,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(bool[3,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global

DEBUG:jax._src.interpreters.pxla:Compiling _safe_integrate for with global shapes and types [ShapedArray(float32[250,65]), ShapedArray(float32[250]), ShapedArray(float32[250,250]), ShapedArray(float32[250]), ShapedArray(float32[50,250]), ShapedArray(float32[50]), ShapedArray(float32[]), ShapedArray(float32[50]), ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling split_state for with global shapes and types [ShapedArray(float32[50])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gp

DEBUG:jax._src.interpreters.pxla:Compiling __call__ for with global shapes and types [ShapedArray(float32[300,60]), ShapedArray(float32[300]), ShapedArray(float32[60,300]), ShapedArray(float32[60]), ShapedArray(float32[150,60]), ShapedArray(float32[150,50]), ShapedArray(float32[150]), ShapedArray(float32[50]), ShapedArray(float32[50]), ShapedArray(float32[60]), ShapedArray(float16[60]), ShapedArray(bool[60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling atleast_2d for with global shapes and types [ShapedArray(float32

DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float16[4,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(bool[4,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DE

DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(bool[9,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(i

DEBUG:jax._src.interpreters.pxla:Compiling relu for with global shapes and types [ShapedArray(float32[50])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[50]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling matmul for with global shapes and types [ShapedArray(float32[2081,50]), ShapedArray(float32[50])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global sha

DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(bool[7,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60]), ShapedArray(float32[1,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[97]), Sha

DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[203]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[203]), ShapedArray(int32[203]), ShapedArray(int32[203])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[203])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpr

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[200]), ShapedArray(int32[142,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[142]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[141]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._s

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[430]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[499,15]), ShapedArray(int32[]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._s

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[20]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[19]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[25])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global 

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[4]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[3]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[17])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_part

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[182]), ShapedArray(int32[182]), ShapedArray(int32[182])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[182])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[200]), ShapedArray(int32[182,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._s

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[47]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[47]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[47]), ShapedArray(int32[47]), ShapedArray(int32[47])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMD

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float16[37,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(bool[37,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16,60]), ShapedArray(float32[16,60]), ShapedArray(float32[5,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_repl

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[717]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[717]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[717]), ShapedArray(int32[717]), ShapedArray(int32[717])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[117]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[117]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[117]), ShapedArray(int32[117]), ShapedArray(int32[117])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[62]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[61]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[700])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global sh

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[76]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[76]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[76]), ShapedArray(int32[76]), ShapedArray(int32[76])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[48]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[93]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[93]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bri

DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[3,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[3,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[0])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[0])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[4])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[4,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[4,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global sh

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[9,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[10])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[10])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global sha

DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[15,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[15,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[25])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.

DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[13,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[14])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[14])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[14,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.

DEBUG:jax._src.interpreters.pxla:Compiling balanced_focal_bce for with global shapes and types [ShapedArray(bool[2081]), ShapedArray(float32[2081]), ShapedArray(bool[2081])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(float32[1]), ShapedArray(f

DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[30])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _where for with global shapes and types [ShapedArray(bool[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _where for with global shap

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[14])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[14])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[14])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[14])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:g

DEBUG:jax._src.interpreters.pxla:Compiling split_state for with global shapes and types [ShapedArray(float32[10])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _safe_integrate for with global shapes and types [ShapedArray(float32[2,50]), ShapedArray(float32[2]), ShapedArray(float32[15]), ShapedArray(float32[250,65]), ShapedArray(float32[250]), ShapedArray(float32[250,250]), ShapedArray(float32[250]), ShapedArray(float32[50,250]), ShapedArray(float32[50]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(float32[50])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), 

DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[125]), ShapedArray(float32[125]), ShapedArray(float32[125])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling matmul for with global shapes and types [ShapedArray(float32[25]), ShapedArray(float32[125,25]), ShapedArray(float32[125])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling split_state for with global shapes and types [ShapedArray(float32[25])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[150,50]), ShapedArray(float32[150,50])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[150]), ShapedArray(float32[150])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[60,125]), ShapedArray(float32[60,125])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._sr

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[2081]), ShapedArray(float32[2081])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[2081,50]), ShapedArray(float32[2081,50])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[50,10]), ShapedArray(float32[50,10])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._

DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[75,25]), ShapedArray(float32[75,25])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[75]), ShapedArray(float32[75])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[15,75]), ShapedArray(float32[15,75])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.inte

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_update_slice for with global shapes and types [ShapedArray(float32[799,15]), ShapedArray(float32[1,15]), ShapedArray(int32[]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[799,15]), ShapedArray(float32[799,15])]. Argument mapping: (

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[16])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and 

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[499,15]), ShapedArray(float32[499,15])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(bool[50]), ShapedArray(float32[17375]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(bool[25]), ShapedArray(float32[7]), ShapedArray(float32[499,75]), ShapedArray(float32[15,75]), ShapedArray(bool[499,75]), ShapedArray(float32[499,25]), ShapedArray(float32[75,25]), ShapedArray(float32[499,50]), ShapedArray(float32[10,50]), ShapedArray(bool[499,50]), ShapedArray(float32[499,

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(int32[], weak_type=True), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[50,17375]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_r

DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[50]), ShapedArray(float32[50])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[10,50])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[10,50]), ShapedArray(float32[10,50])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[10]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[10]), ShapedArray(float32[10])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[10]), ShapedArray(float32[10])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG

DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[25])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[25]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[25])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[25]), ShapedArray(float32

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[5]), ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitio

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[32,119]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[32,119]), ShapedArray(float32[32,119])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[32,119]), ShapedArray(float32[32,119])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment

DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[50,51])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[50,51]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[50,51])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[50,51]), ShapedA

DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[75,25]), ShapedArray(float32[75,25])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[75])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[75]), ShapedArray(float32[75])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:ja

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[15,75]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[15,75]), ShapedArray(float32[15,75])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[15,75]), ShapedArray(float32[15,75])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({repli

DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[50,10])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[50,10]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[50,10])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[50,10]), ShapedA

DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[2081,50]), ShapedArray(float32[2081,50])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[2081])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[2081]), ShapedArray(float32[2081])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[125,25]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[125,25]), ShapedArray(float32[125,25])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[125,25]), ShapedArray(float32[125,25])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[60,125])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[60,125]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[60,125])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DE

DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[60]), ShapedArray(float32[60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[250,65]), ShapedArray(float32[250,65])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[250,65])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with gl

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[250]), ShapedArray(float32[250])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[250,250]), ShapedArray(float32[250,250])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[250,250])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[50,250]), ShapedArray(float32[50,250])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[50,250]), ShapedArray(float32[50,250])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[300,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({

DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[300]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[300])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[300]), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[150,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[150,60]), ShapedArray(float32[150,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling square for with global shapes and types [ShapedArray(float32[150,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling true_divide for with global shapes and types [ShapedArray(float32[150,50]), ShapedArray(float32[150,50])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[150,50]), ShapedArray(float32[150,50])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[150])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({rep

DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[50])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[10,50])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[10,50])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[25,7])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._sr

DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[75,25])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[75])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[75])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[15,75])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.

DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[60,125])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[60,125])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._sr

DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[300])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[60,300])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[60,300])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[150,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:ja

Loading to device:   0%|          | 0/5 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/5 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50,17375]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[25,7]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(float32[5]), ShapedArray(float32[32,119]), ShapedArray(float32[50,197]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[50,51]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[75,25]), ShapedArray(float32[75]), ShapedArray(float32[15,75]), ShapedArray(float32[15]), ShapedArray(float16[7]), ShapedArray(bool[17375]), ShapedArray(float16[599,318]), ShapedArray(bool[599,51])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replica

  0%|          | 0.00/176.76 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[165]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[165]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[165]), ShapedArray(int32[165]), ShapedArray(int32[165])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[54]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[53]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[24]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax.

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[518]), ShapedArray(int32[518]), ShapedArray(int32[518])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[518])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[600]), ShapedArray(int32[518,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._s

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[60]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[33]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[33]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bri

DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[942]), ShapedArray(int32[942]), ShapedArray(int32[942])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[942])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[1000]), ShapedArray(int32[942,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16,60]), ShapedArray(float32[16,60]), ShapedArray(float32[16,60]), ShapedArray(float32[10,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[141]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[141]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:j

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[200]), ShapedArray(int32[191,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[191]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[190]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._s

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[12])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[28])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[28]), ShapedArray(float32[]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interp

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[12])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[12])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[12])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global 

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]


Loading to device:   0%|          | 0/15 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/15 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50,17375]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[25,7]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(float32[5]), ShapedArray(float32[32,119]), ShapedArray(float32[50,197]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[50,51]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[75,25]), ShapedArray(float32[75]), ShapedArray(float32[15,75]), ShapedArray(float32[15]), ShapedArray(float16[7]), ShapedArray(bool[17375]), ShapedArray(float16[1199,318]), ShapedArray(bool[1199,51])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({repli

  0%|          | 0.00/298.35 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[10]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[10]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[10]), ShapedArray(int32[10]), ShapedArray(int32[10])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[478]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[477]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[35])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with globa

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[125]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[125]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[125]), ShapedArray(int32[125]), ShapedArray(int32[125])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), 

DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[232]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[232]), ShapedArray(int32[232]), ShapedArray(int32[232])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[232])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpr

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[17]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[19])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float16[19,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(bool[19,60])]. Argumen

DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[200]), ShapedArray(int32[143,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[143]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[142]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._s

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16,60]), ShapedArray(float32[6,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[1200])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(bool[1200])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:j

DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[441]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[441]), ShapedArray(int32[441]), ShapedArray(int32[441])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[441])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpr

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[85]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[55]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[55]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpr

DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[30,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[30]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[29]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[112])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[200]), ShapedArray(int32[112,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[112]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1

DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[12])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[12,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[12,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[27])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.

DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[18,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[22])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[22])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[22,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[15])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:g

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_update_slice for with global shapes and types [ShapedArray(float32[1199,15]), ShapedArray(float32[1,15]), ShapedArray(int32[]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[1199,15]), ShapedArray(float32[1199,15])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(f

Loading to device:   0%|          | 0/14 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/14 [00:00<?, ?subject/s]

  0%|          | 0.00/199.39 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[186]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[186]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[186]), ShapedArray(int32[186]), ShapedArray(int32[186])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[36]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[35]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[88]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax.

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[21])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[21,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[21]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for wit

DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[43]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[43]), ShapedArray(int32[43]), ShapedArray(int32[43])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[43])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[102]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[73]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[73]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interp

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[34]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[33]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[28]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax.

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[676])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[700]), ShapedArray(int32[676,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[676]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[25]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[25]), ShapedArray(int32[25]), ShapedArray(int32[25])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[25])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:g

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[40])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[40])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[8])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and t

Loading to device:   0%|          | 0/12 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/12 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50,17375]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[25,7]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(float32[5]), ShapedArray(float32[32,119]), ShapedArray(float32[50,197]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[50,51]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[75,25]), ShapedArray(float32[75]), ShapedArray(float32[15,75]), ShapedArray(float32[15]), ShapedArray(float16[7]), ShapedArray(bool[17375]), ShapedArray(float16[399,318]), ShapedArray(bool[399,51])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replica

  0%|          | 0.00/271.91 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[400])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(bool[400])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[352]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[352]), ShapedArray(int32

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16,60]), ShapedArray(float32[5,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[17]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[17]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.

DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[51,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[51]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[50]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[100]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[100]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[100]), ShapedArray(int32[100]), ShapedArray(int32[100])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), 

DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[158]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[158]), ShapedArray(int32[158]), ShapedArray(int32[158])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[158])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpr

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[81]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[151]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[151]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.inter

DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[200]), ShapedArray(int32[121,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[121]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[120]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._s

DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[78]), ShapedArray(int32[78]), ShapedArray(int32[78])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[78])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[78,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.in

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[32])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[32,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[32,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for 

Loading to device:   0%|          | 0/15 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/15 [00:00<?, ?subject/s]

  0%|          | 0.00/306.69 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[149]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[149]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[149]), ShapedArray(int32[149]), ShapedArray(int32[149])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[311]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[310]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[53]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:ja

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[89])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[89,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[89]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for wit

DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[110]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[110]), ShapedArray(int32[110]), ShapedArray(int32[110])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[110])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpr

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[622]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[20])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float16[20,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(bool[20,60])]. Argume

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[63]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[63]), ShapedArray(int32[63]), ShapedArray(int32[63])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[63])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:g

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[213]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[258]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[258]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[74,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[74]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[73]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.

DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(bool[63,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16,60]), ShapedArray(float32[16,60]), ShapedArray(float32[16,60]), ShapedArray(float32[15,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[96]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_r

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[35])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[35,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[35]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for wit

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[20,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[20,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[63])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with glo

Loading to device:   0%|          | 0/15 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/15 [00:00<?, ?subject/s]

  0%|          | 0.00/207.16 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[119]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[119]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[119]), ShapedArray(int32[119]), ShapedArray(int32[119])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[46]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[45]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[40]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax.

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[52])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[52,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[52]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for wit

DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[307]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[307]), ShapedArray(int32[307]), ShapedArray(int32[307])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[307])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpr

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[14]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[29])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float16[29,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(bool[29,60])]. Argumen

DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[300]), ShapedArray(int32[215,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[215]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[214]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._s

DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[56]), ShapedArray(int32[56]), ShapedArray(int32[56])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[56])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[56,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.in

DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[29])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[29])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[29,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[29,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[36])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[36])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[4])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and t

Loading to device:   0%|          | 0/11 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/11 [00:00<?, ?subject/s]

  0%|          | 0.00/173.62 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[57]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[57]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[57]), ShapedArray(int32[57]), ShapedArray(int32[57])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[108]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[107]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[1134]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[66])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[66,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[66]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for wit

DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[132]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[132]), ShapedArray(int32[132]), ShapedArray(int32[132])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[132])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpr

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_c

Loading to device:   0%|          | 0/11 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/11 [00:00<?, ?subject/s]

  0%|          | 0.00/236.30 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[9]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[9]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[9]), ShapedArray(int32[9]), ShapedArray(int32[9])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assig

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[42]), ShapedArray(int32[42]), ShapedArray(int32[42])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[42])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[42,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xl

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[31]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[31]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[31]), ShapedArray(int32[31]), ShapedArray(int32[31])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMD

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[154]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[153]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[187]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:j

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[209])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[300]), ShapedArray(int32[209,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[209]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[42]), ShapedArray(float32[]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[42])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[42])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 devi

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]


Loading to device:   0%|          | 0/16 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/16 [00:00<?, ?subject/s]

  0%|          | 0.00/150.47 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[70]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[70]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[70]), ShapedArray(int32[70]), ShapedArray(int32[70])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[22]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[21]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[426]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[38]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[38]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[38]), ShapedArray(int32[38]), ShapedArray(int32[38])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMD

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[106]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[105]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[83])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_par

Loading to device:   0%|          | 0/8 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/8 [00:00<?, ?subject/s]

  0%|          | 0.00/137.69 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[12]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[12]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[12]), ShapedArray(int32[12]), ShapedArray(int32[12])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[993]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[992]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[6]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[23])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[23,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[23]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for wit

DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[3])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[35])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[35]), ShapedArray(float32[]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1

Loading to device:   0%|          | 0/13 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/13 [00:00<?, ?subject/s]

  0%|          | 0.00/106.08 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[123]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[123]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[123]), ShapedArray(int32[123]), ShapedArray(int32[123])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[247]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[246]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[122]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:j

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[109])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[200]), ShapedArray(int32[109,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[109]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for 

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[6])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[6])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[6])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[6])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_c

Loading to device:   0%|          | 0/8 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/8 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50,17375]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[25,7]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(float32[5]), ShapedArray(float32[32,119]), ShapedArray(float32[50,197]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[50,51]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[75,25]), ShapedArray(float32[75]), ShapedArray(float32[15,75]), ShapedArray(float32[15]), ShapedArray(float16[7]), ShapedArray(bool[17375]), ShapedArray(float16[1099,318]), ShapedArray(bool[1099,51])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({repli

  0%|          | 0.00/384.86 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[1100])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(bool[1100])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[1030]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[1030]), ShapedArray(i

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[2397]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[2397]), ShapedArray(int32[2397]), ShapedArray(int32[2397])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[2397])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xl

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[300]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[299]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[1]), ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_update_slice for with global shapes and types [ShapedArray(float32[2399,15]), ShapedArray(float32[1,15]), ShapedArray(int32[]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[2399,15]), ShapedArray(float32[2399,15])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(bool[50]), ShapedArray(float32[17375]), ShapedArray(float

Loading to device:   0%|          | 0/9 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/9 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50,17375]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[25,7]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(float32[5]), ShapedArray(float32[32,119]), ShapedArray(float32[50,197]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[50,51]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[75,25]), ShapedArray(float32[75]), ShapedArray(float32[15,75]), ShapedArray(float32[15]), ShapedArray(float16[7]), ShapedArray(bool[17375]), ShapedArray(float16[899,318]), ShapedArray(bool[899,51])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replica

  0%|          | 0.00/183.52 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[5]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[5]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[5]), ShapedArray(int32[5]), ShapedArray(int32[5])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assig

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[591]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[590]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[900])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global 

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[107]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[107]), ShapedArray(int32[107]), ShapedArray(int32[107])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[107])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bri

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[13]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[114]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[114]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_b

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[100]), ShapedArray(int32[64,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[64]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[63]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[9])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[9])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[9])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and typ

DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]


Loading to device:   0%|          | 0/15 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/15 [00:00<?, ?subject/s]

  0%|          | 0.00/192.05 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[77]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[77]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[77]), ShapedArray(int32[77]), ShapedArray(int32[77])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[39]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[38]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[138]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[115])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[200]), ShapedArray(int32[115,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[115]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for 

DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[255]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[255]), ShapedArray(int32[255]), ShapedArray(int32[255])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[255])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpr

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[605]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[105]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[105]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.inte

DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[32])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]


Loading to device:   0%|          | 0/12 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/12 [00:00<?, ?subject/s]

  0%|          | 0.00/148.39 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[59]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[59]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[59]), ShapedArray(int32[59]), ShapedArray(int32[59])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[13]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[12]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[116]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[233])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[300]), ShapedArray(int32[233,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[233]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for 

Loading to device:   0%|          | 0/6 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/6 [00:00<?, ?subject/s]

  0%|          | 0.00/260.74 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[560]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[560]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[560]), ShapedArray(int32[560]), ShapedArray(int32[560])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(bool[54,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16,60]), ShapedArray(float32[16,60]), ShapedArray(float32[16,60]), ShapedArray(float32[6,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[34])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]


DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[170]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[170]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[170]), ShapedArray(int32[170]), ShapedArray(int32[170])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), 

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[54,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[54,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[34])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with glo

Loading to device:   0%|          | 0/11 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/11 [00:00<?, ?subject/s]

  0%|          | 0.00/178.43 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[407]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[407]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[407]), ShapedArray(int32[407]), ShapedArray(int32[407])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[298]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[297]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[194]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:j

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[32]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[32]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[32]), ShapedArray(int32[32]), ShapedArray(int32[32])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMD

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[157]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[156]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[345]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:j

DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[33])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[33]), ShapedArray(float32[]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[33])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types

Loading to device:   0%|          | 0/7 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/7 [00:00<?, ?subject/s]

  0%|          | 0.00/108.12 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[68]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[68]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[68]), ShapedArray(int32[68]), ShapedArray(int32[68])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[44]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[43]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[113]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[365])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[400]), ShapedArray(int32[365,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[365]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for 

Loading to device:   0%|          | 0/6 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/6 [00:00<?, ?subject/s]

  0%|          | 0.00/243.10 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[41]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[41]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[41]), ShapedArray(int32[41]), ShapedArray(int32[41])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[71]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[70]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[1041]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:ja

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[159])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[200]), ShapedArray(int32[159,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[159]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for 

Loading to device:   0%|          | 0/11 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/11 [00:00<?, ?subject/s]

  0%|          | 0.00/128.65 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[11]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[11]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[11]), ShapedArray(int32[11]), ShapedArray(int32[11])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[128]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[127]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[201]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:j

Loading to device:   0%|          | 0/7 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/7 [00:00<?, ?subject/s]

  0%|          | 0.00/162.16 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[29]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[29]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[29]), ShapedArray(int32[29]), ShapedArray(int32[29])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1076]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1075]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[155]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[312])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[400]), ShapedArray(int32[312,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[312]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for 

Loading to device:   0%|          | 0/3 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/3 [00:00<?, ?subject/s]

  0%|          | 0.00/182.93 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[349]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[349]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[349]), ShapedArray(int32[349]), ShapedArray(int32[349])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[137]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[136]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16]), ShapedArray(float32[16]), ShapedArray(float32[5])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_

Loading to device:   0%|          | 0/9 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/9 [00:00<?, ?subject/s]

  0%|          | 0.00/218.29 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[166]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[166]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[166]), ShapedArray(int32[166]), ShapedArray(int32[166])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[373]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[372]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[41])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with globa

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[102]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[102]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[102]), ShapedArray(int32[102]), ShapedArray(int32[102])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), 

Loading to device:   0%|          | 0/13 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/13 [00:00<?, ?subject/s]

  0%|          | 0.00/381.47 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[535]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[535]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[535]), ShapedArray(int32[535]), ShapedArray(int32[535])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[893]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[892]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[521]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:j

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[472])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[500]), ShapedArray(int32[472,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[472]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for 

DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float16[96,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(bool[96,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16,60]), ShapedArray(float32[16,60]), ShapedArray(float32[16,60]), ShapedArray(float32[16,60]), ShapedArray(float32[16,60]), ShapedArray(float32[16,60])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replic

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[96,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]


Loading to device:   0%|          | 0/7 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/7 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50,17375]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[25,7]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(float32[5]), ShapedArray(float32[32,119]), ShapedArray(float32[50,197]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[50,51]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[75,25]), ShapedArray(float32[75]), ShapedArray(float32[15,75]), ShapedArray(float32[15]), ShapedArray(float16[7]), ShapedArray(bool[17375]), ShapedArray(float16[2299,318]), ShapedArray(bool[2299,51])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({repli

  0%|          | 0.00/283.90 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[193]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[193]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[193]), ShapedArray(int32[193]), ShapedArray(int32[193])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[2280])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[2300]), ShapedArray(int32[2280,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[2280]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice 

DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[38])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[38])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[38,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[38,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling integer_pow for with global shapes and types [ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]


Loading to device:   0%|          | 0/11 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/11 [00:00<?, ?subject/s]

  0%|          | 0.00/207.39 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[79]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[79]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[79]), ShapedArray(int32[79]), ShapedArray(int32[79])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[353]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[352]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[37]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:ja

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[664])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[700]), ShapedArray(int32[664,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[664]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for 

DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[83]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[83]), ShapedArray(int32[83]), ShapedArray(int32[83])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[83])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[382]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[80]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[80]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interp

Loading to device:   0%|          | 0/9 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/9 [00:00<?, ?subject/s]

  0%|          | 0.00/296.98 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[43])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float16[43,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(bool[43,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling concatenate for with global shapes and types [ShapedArray(float32[16,60]), ShapedArray(float32[16,60]), ShapedArray(float32[11,60])]

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[248]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[247]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[39])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with globa

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[43])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global shapes and types [ShapedArray(bool[43])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[43,60])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_any for with global

Loading to device:   0%|          | 0/9 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/9 [00:00<?, ?subject/s]

  0%|          | 0.00/192.18 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[140]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[140]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[140]), ShapedArray(int32[140]), ShapedArray(int32[140])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[264]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[263]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(float32[26])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with globa

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[390]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[390]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[390]), ShapedArray(int32[390]), ShapedArray(int32[390])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), 

Loading to device:   0%|          | 0/6 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/6 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50,17375]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[25,7]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(float32[5]), ShapedArray(float32[32,119]), ShapedArray(float32[50,197]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[50,51]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[75,25]), ShapedArray(float32[75]), ShapedArray(float32[15,75]), ShapedArray(float32[15]), ShapedArray(float16[7]), ShapedArray(bool[17375]), ShapedArray(float16[1699,318]), ShapedArray(bool[1699,51])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({repli

  0%|          | 0.00/158.59 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[1700])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(bool[1700])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[1677]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[1677]), ShapedArray(i

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[160])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[200]), ShapedArray(int32[160,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[160]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1

Loading to device:   0%|          | 0/11 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/11 [00:00<?, ?subject/s]

  0%|          | 0.00/141.08 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[72]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[72]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[72]), ShapedArray(int32[72]), ShapedArray(int32[72])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[574]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[573]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[168]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:j

DEBUG:jax._src.interpreters.pxla:Compiling broadcast_in_dim for with global shapes and types [ShapedArray(int32[216])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling gather for with global shapes and types [ShapedArray(float32[300]), ShapedArray(int32[216,1])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[216]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for 

Loading to device:   0%|          | 0/15 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/15 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50,17375]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[25,7]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(float32[5]), ShapedArray(float32[32,119]), ShapedArray(float32[50,197]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[50,51]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[75,25]), ShapedArray(float32[75]), ShapedArray(float32[15,75]), ShapedArray(float32[15]), ShapedArray(float16[7]), ShapedArray(bool[17375]), ShapedArray(float16[1299,318]), ShapedArray(bool[1299,51])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({repli

  0%|          | 0.00/216.74 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[1300])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(bool[1300])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[1225]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[1225]), ShapedArray(i

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(bool[50]), ShapedArray(float32[17375]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(bool[25]), ShapedArray(float32[7]), ShapedArray(float32[1299,75]), ShapedArray(float32[15,75]), ShapedArray(bool[1299,75]), ShapedArray(float32[1299,25]), ShapedArray(float32[75,25]), ShapedArray(float32[1299,50]), ShapedArray(float32[10,50]), ShapedArray(bool[1299,50]), ShapedArray(float32[1299,197]), ShapedArray(float32[50,197]), ShapedArray(float32[32,119]), ShapedArray(float32[1299,119]), ShapedArray(float32[1299,50]), ShapedArray(float32[10,50]), ShapedArray(bool[1299,50]), ShapedArray(float32[1299,51]), ShapedArray(float32[10]), ShapedArray(float32[1299,15])]. Argument mapping: (GSPMDSharding({replicated}), G

Loading to device:   0%|          | 0/7 [00:00<?, ?subject/s]



Embedding:   0%|          | 0/7 [00:00<?, ?subject/s]

DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(float32[50,17375]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[25,7]), ShapedArray(float32[25]), ShapedArray(float32[5,25]), ShapedArray(float32[5]), ShapedArray(float32[32,119]), ShapedArray(float32[50,197]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[50,51]), ShapedArray(float32[50]), ShapedArray(float32[10,50]), ShapedArray(float32[10]), ShapedArray(float32[75,25]), ShapedArray(float32[75]), ShapedArray(float32[15,75]), ShapedArray(float32[15]), ShapedArray(float16[7]), ShapedArray(bool[17375]), ShapedArray(float16[2799,318]), ShapedArray(bool[2799,51])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({repli

  0%|          | 0.00/213.45 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[385]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[385]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[385]), ShapedArray(int32[385]), ShapedArray(int32[385])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[842]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[841]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling isnan for with global shapes and types [ShapedArray(float32[2800])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_update_slice for with global shapes and types [ShapedArray(float32[2799,15]), ShapedArray(float32[1,15]), ShapedArray(int32[]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(float32[2799,15]), ShapedArray(float32[2799,15])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _embed_admission for with global shapes and types [ShapedArray(f

Loading to device:   0%|          | 0/9 [00:00<?, ?subject/s]

Embedding:   0%|          | 0/9 [00:00<?, ?subject/s]

  0%|          | 0.00/168.19 [00:00<?, ?odeint-days/s]

DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[252]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling add for with global shapes and types [ShapedArray(int32[252]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[252]), ShapedArray(int32[252]), ShapedArray(int32[252])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 de

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[262]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[261]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(int32[85]), ShapedArray(int32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:ja

DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(float32[29])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling nanmean for with global shapes and types [ShapedArray(bool[29]), ShapedArray(float32[]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[29])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[13])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[13])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling slice for with global shapes and types [ShapedArray(float32[13])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[gpu(id=0)]]


In [None]:
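# Debugging aid (disabled): walk the pytree `m` and, for every array leaf,
# report its shape and whether it contains any NaN; non-array leaves map to None.
# NOTE(review): `m` is presumably the model built in an earlier cell - confirm before re-enabling.
# import jax.tree_util as jtu
# import jax.numpy as jnp
# import equinox as eqx

# jtu.tree_map(lambda x: f'{x.shape} {jnp.any(jnp.isnan(x)).item()}' if eqx.is_array(x) else None , m)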

In [None]:
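# Debugging aid (disabled): build a dict mapping each subject key in the device batch to `m.f_emb` applied to that subject.
# emb_subj = {i: m.f_emb(s) for i, s in m4inpatients.device_batch().subjects.items()}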