In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'cpu')

In [2]:


sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset
from lib.ehr.inpatient_interface import Inpatients

In [3]:
import logging
logging.root.level = logging.DEBUG


In [4]:
# Assign the folder of the dataset to `DATA_FILE`.
import dask

HOME = os.environ.get('HOME')
DATA_DIR = f'{HOME}/GP/ehr-data'
SOURCE_DIR = os.path.abspath("..")

with U.modified_environ(DATA_DIR=DATA_DIR), dask.config.set(scheduler='processes'):
    m4inpatient_dataset = load_dataset('M4ICU')
   

DEBUG:root:Loading dataframe files
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/adm_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/dx_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/static_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/obs_df.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/int_input.csv
DEBUG:fsspec.local:open file: /home/asem/GP/ehr-data/mimic4icu-cohort/int_proc.csv
DEBUG:root:[DONE] Loading dataframe files
DEBUG:root:Matching admission_id
DEBUG:root:[DONE] Matching admission_id
DEBUG:root:Time casting..
DEBUG:root:[DONE] Time casting..
DEBUG:absl:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme
DEBUG:absl:Constructing dx_icd10 (<class 'lib.ehr.coding_scheme.DxICD10'>) scheme
DEBUG:absl:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme
DEBUG:absl:Constructing dx_icd9 (<class 'lib.ehr.coding_

In [5]:
splits = m4inpatient_dataset.random_splits([0.1, 0.7], 42, 'subjects')



In [6]:
preprocessing = m4inpatient_dataset.fit_preprocessing(splits[0])

In [7]:

m4inpatient_dataset.apply_preprocessing(preprocessing)

DEBUG:root:Removed 2431596 (0.024) outliers from obs


In [8]:
# from concurrent.futures import ThreadPoolExecutor
# with dask.config.set(pool=ThreadPoolExecutor(12)):
with dask.config.set(scheduler='processes', num_workers=12):
    m4inpatients = Inpatients(m4inpatient_dataset, splits[0], num_workers=12)

DEBUG:root:Loading subjects..
  dob = anchor_date + anchor_age
DEBUG:absl:Constructing mimic4_eth32 (<class 'lib.ehr.coding_scheme.MIMIC4Eth32'>) scheme
DEBUG:absl:Constructing mimic4_eth5 (<class 'lib.ehr.coding_scheme.MIMIC4Eth5'>) scheme
DEBUG:root:Extracting dx codes...
DEBUG:absl:Constructing dx_icd10 (<class 'lib.ehr.coding_scheme.DxICD10'>) scheme
DEBUG:absl:Constructing dx_icd9 (<class 'lib.ehr.coding_scheme.DxICD9'>) scheme
                            dx_icd10->dx_icd9 Unrecognised t_codes
                            (169):
                            ['041.41', '041.42', '041.43', '041.49', '173.00', '173.01', '173.02', '173.09', '173.10', '173.11', '173.12', '173.19', '173.20', '173.21', '173.22', '173.29', '173.30', '173.31', '173.32', '173.39']...
                            dx_icd10->dx_icd9 Unrecognised s_codes
                            (49910):
                            ['E08.3211', 'E08.3212', 'E08.3213', 'E08.3219', 'E08.3291', 'E08.3292', 'E08.3293', 'E08.3299', 

In [9]:
m4inpatients.size_in_bytes() / 1024 ** 3

0.7975713079795241

In [10]:
m4inpatients_jax = m4inpatients.to_jax_arrays(splits[0])

DEBUG:jax._src.xla_bridge:Initializing backend 'cpu'
DEBUG:jax._src.xla_bridge:Backend 'cpu' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'cuda'
DEBUG:jax._src.xla_bridge:Backend 'cuda' initialized
DEBUG:jax._src.xla_bridge:Initializing backend 'rocm'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Interpreter CUDA
DEBUG:jax._src.xla_bridge:Initializing backend 'tpu'
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'


































In [11]:
m4inpatients_jax.size_in_bytes() / 1024 ** 3

0.7911796784028411

In [12]:
len(m4inpatients_jax.subjects)

5083

In [13]:
m4inpatients_jax.n_admissions()

17720

In [14]:
m4inpatients_jax.n_segments()

901721

In [15]:
m4inpatients_jax.n_obs_times()

730030

In [16]:
# import numpy as np
# import matplotlib.pyplot as plt

# a = m4inpatients_jax.obs_coocurrence_matrix
# a = np.array(a)
# plt.imshow(a, cmap='hot', interpolation='nearest')
# plt.show()

In [23]:
s = m4inpatients_jax.subjects[splits[0][6]].admissions[0].interventions.input_
s

InpatientInput(
  index=i32[90],
  rate=f16[90],
  starttime=f32[90],
  endtime=f32[90],
  size=469
)

In [24]:
m4inpatients_jax.intervals_sum(splits[0])

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[161]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling squeeze for with global shapes and types [ShapedArray(float32[1])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling <lambda> for with global shapes and types [ShapedArray(float32[]), ShapedArray(float32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _reduce_sum f

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[60]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[17]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[28]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[121]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[216]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[6]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1745]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[222]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[255]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replic

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[27]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[118]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[56]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[188]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[13]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[313]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicat

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[69]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[101]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[113]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(i

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[144]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[47]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[106]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicat

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[120]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[764]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[16]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(i

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[140]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[172]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[464]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[53]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[45]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[182]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[110]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[871]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[162]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[89]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[156]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[103]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(i

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[409]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[7]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[275]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicate

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[571]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[71]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[254]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(i

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[543]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[199]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[278]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[111]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[130]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[546]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1808]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[88]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[127]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1231]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[179]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[572]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[703]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[104]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1567]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replic

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[143]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[167]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[257]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[290]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[854]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[105]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[168]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[134]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[227]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[85]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[625]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[555]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicat

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[931]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[137]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[149]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[235]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1697]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1140]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({repli

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[495]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[719]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[213]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[146]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[231]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[466]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[196]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[210]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[217]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[240]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[250]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[212]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[178]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[793]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[581]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1320]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1104]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[245]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({repli

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1258]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[291]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[302]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1228]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[745]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[449]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replic

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[153]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[330]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[739]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[269]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[651]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[422]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[355]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[835]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[864]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[825]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[599]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[920]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[322]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1189]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[169]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[380]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[472]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[252]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1732]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[629]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[527]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[337]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[339]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[407]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[403]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[701]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[506]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[477]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[375]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[699]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[284]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[718]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[514]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[360]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[344]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[223]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[570]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[402]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[427]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[456]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[219]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1050]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replic

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[265]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[267]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[715]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[450]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1497]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1018]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({repli

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1448]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[225]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[441]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[815]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[347]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[816]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1042]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[317]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[475]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[509]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[810]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[553]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1111]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[229]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[412]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1328]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[737]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[473]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replic

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[425]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1209]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[800]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[824]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[692]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[270]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replica

DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[453]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[338]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[334]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[1332]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[454]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(float32[619]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replic

Array(2678208.8, dtype=float32)

In [25]:
s

InpatientInput(
  index=i32[90],
  rate=f16[90],
  starttime=f32[90],
  endtime=f32[90],
  size=469
)

### التدريب على نموذج المعادلات التفاضلية الاعتيادية العصبية


In [26]:
from lib.ml.in_icenode import InICENODE, InICENODEDimensions
import jax.random as jrandom

DEBUG:matplotlib:matplotlib data path: /home/asem/GP/env/icenode-dev/lib/python3.9/site-packages/matplotlib/mpl-data
DEBUG:matplotlib:CONFIGDIR=/home/asem/.config/matplotlib
DEBUG:matplotlib:interactive is False
DEBUG:matplotlib:platform is linux
DEBUG:matplotlib:CACHEDIR=/home/asem/.cache/matplotlib
DEBUG:matplotlib.font_manager:Using fontManager instance from /home/asem/.cache/matplotlib/fontlist-v330.json
- the fwd and bwd functions take an extra `perturbed` argument, which     indicates which primals actually need a gradient. You can use this     to skip computing the gradient for any unperturbed value. (You can     also safely just ignore this if you wish.)
- `None` was previously passed to indicate a symbolic zero gradient for     all objects that weren't inexact arrays, but all inexact arrays     always had an array-valued gradient. Now, `None` may also be passed     to indicate that an inexact array has a symbolic zero gradient.


In [27]:
dims = InICENODEDimensions(state_m=15, 
                state_dx_e=10,
                state_obs_e=25,
                input_e=10,
                proc_e=10,
                demo_e=5,
                int_e=15)
key = jrandom.PRNGKey(0)

m = InICENODE(dims=dims, 
              scheme=m4inpatient_dataset.scheme,
              key=key)

DEBUG:jax._src.interpreters.pxla:Compiling _threefry_seed for with global shapes and types [ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _threefry_split for with global shapes and types [ShapedArray(uint32[2])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(uint32[5,2])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _threefry_split for with global shapes and types [ShapedArray(uint32[2])]. Argument mapping: (GSPMDSharding({r

DEBUG:jax._src.interpreters.pxla:Compiling select_n for with global shapes and types [ShapedArray(bool[], weak_type=True), ShapedArray(int32[], weak_type=True), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(uint32[2,2]), ShapedArray(int32[]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling squeeze for with global shapes and types [ShapedArray(uint32[1,2])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_option

DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _threefry_split for with global shapes and types [ShapedArray(uint32[2])]. Argument mapping: (GSPMDSharding({replicated}),).


DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform

DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _threefry_split for with global shapes and types [ShapedArray(uint32[2])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling dynamic_slice for with global shapes and types [ShapedArray(uint32[4,2]), ShapedArray(int32[]), ShapedArray(int32[])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (G

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[250,65])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[250])]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling fn for with global shapes and types [ShapedArray(float32[], weak_type=True), ShapedArray(float32[250,250])]. Argument mapping: (GSPMDSharding({replicat

DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _unstack for with global shapes and types [ShapedArray(uint32[4,2])]. Argument mapping: (GSPMDSharding({replicated}),).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=True)]. Argument mapping: (GSPMDSharding({replicated}), GSPMDSharding({replicated}), GSPMDSharding({replicated})).
DEBUG:jax._src.xla_bridge:get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax._src.interpreters.pxla:Compiling _uniform for with global shapes and types [ShapedArray(key<fry>[]), ShapedArray(float32[], weak_type=True), ShapedArray(float32[], weak_type=Tru

In [31]:
res = m.batch_predict(m4inpatients_jax, splits[0][:100])


Embedding:   0%|                                                                                | 0/100 [00:00<?, ?it/s][A
Embedding:   1%|▋                                                                       | 1/100 [00:00<01:24,  1.18it/s][A
Embedding:   2%|█▍                                                                      | 2/100 [00:02<02:05,  1.28s/it][A
Embedding:   3%|██▏                                                                     | 3/100 [00:06<04:10,  2.59s/it][A
Embedding:   4%|██▉                                                                     | 4/100 [00:06<02:45,  1.72s/it][A
Embedding:   5%|███▌                                                                    | 5/100 [00:08<02:44,  1.73s/it][A
Embedding:   6%|████▎                                                                   | 6/100 [00:08<01:52,  1.20s/it][A
Embedding:   7%|█████                                                                   | 7/100 [00:09<01:27,  1.06it/s][A
Embeddi

Embedding:  71%|██████████████████████████████████████████████████▍                    | 71/100 [01:17<00:28,  1.01it/s][A
Embedding:  72%|███████████████████████████████████████████████████                    | 72/100 [01:17<00:23,  1.21it/s][A
Embedding:  73%|███████████████████████████████████████████████████▊                   | 73/100 [01:18<00:17,  1.51it/s][A
Embedding:  74%|████████████████████████████████████████████████████▌                  | 74/100 [01:22<00:44,  1.71s/it][A
Embedding:  75%|█████████████████████████████████████████████████████▎                 | 75/100 [01:22<00:31,  1.27s/it][A
Embedding:  77%|██████████████████████████████████████████████████████▋                | 77/100 [01:23<00:18,  1.27it/s][A
Embedding:  78%|███████████████████████████████████████████████████████▍               | 78/100 [01:23<00:14,  1.51it/s][A
Embedding:  79%|████████████████████████████████████████████████████████               | 79/100 [01:23<00:12,  1.68it/s][A
Embeddin

Subject 5/100:   6%|███▍                                                   | 3838.03/61689.47 [00:05<00:53, 1090.69it/s][A
Subject 5/100:   6%|███▌                                                   | 3958.82/61689.47 [00:06<00:52, 1090.69it/s][A
Subject 5/100:   7%|███▋                                                    | 4023.47/61689.47 [00:06<01:01, 935.78it/s][A
Subject 5/100:   7%|███▋                                                    | 4023.47/61689.47 [00:06<01:01, 935.78it/s][A
Subject 5/100:   7%|███▋                                                    | 4094.17/61689.47 [00:06<01:01, 935.78it/s][A
Subject 5/100:   7%|███▊                                                    | 4141.12/61689.47 [00:06<01:05, 877.52it/s][A
Subject 5/100:   7%|███▊                                                    | 4141.12/61689.47 [00:06<01:05, 877.52it/s][A
Subject 6/100:   7%|███▊                                                    | 4189.82/61689.47 [00:06<01:05, 877.52it/s][A
Subject 

Subject 17/100:  19%|█████████▉                                           | 11507.88/61689.47 [00:14<00:32, 1559.78it/s][A
Subject 18/100:  19%|█████████▉                                           | 11581.50/61689.47 [00:14<00:32, 1559.78it/s][A
Subject 18/100:  19%|██████████▎                                          | 11990.63/61689.47 [00:14<00:42, 1166.41it/s][A
Subject 19/100:  19%|██████████▎                                          | 11990.63/61689.47 [00:14<00:42, 1166.41it/s][A
Subject 19/100:  20%|██████████▎                                          | 12060.68/61689.47 [00:14<00:42, 1166.41it/s][A
Subject 19/100:  20%|██████████▊                                           | 12295.77/61689.47 [00:15<00:54, 906.10it/s][A
Subject 20/100:  20%|██████████▊                                           | 12295.77/61689.47 [00:15<00:54, 906.10it/s][A
Subject 20/100:  20%|██████████▉                                           | 12455.68/61689.47 [00:15<00:57, 854.35it/s][A
Subject 

Subject 32/100:  34%|██████████████████▌                                   | 21244.13/61689.47 [00:30<00:59, 681.69it/s][A
Subject 32/100:  35%|██████████████████▊                                   | 21508.17/61689.47 [00:30<00:45, 881.76it/s][A
Subject 32/100:  35%|██████████████████▊                                   | 21508.17/61689.47 [00:30<00:45, 881.76it/s][A
Subject 32/100:  35%|██████████████████▉                                   | 21571.62/61689.47 [00:30<00:45, 881.76it/s][A
Subject 32/100:  35%|██████████████████▉                                   | 21639.48/61689.47 [00:30<00:45, 881.76it/s][A
Subject 32/100:  35%|██████████████████▋                                  | 21737.83/61689.47 [00:30<00:38, 1036.43it/s][A
Subject 32/100:  35%|██████████████████▋                                  | 21737.83/61689.47 [00:30<00:38, 1036.43it/s][A
Subject 32/100:  36%|███████████████████                                  | 22246.05/61689.47 [00:30<00:31, 1245.59it/s][A
Subject 

Subject 48/100:  47%|█████████████████████████▌                            | 29203.22/61689.47 [00:44<01:11, 454.51it/s][A
Subject 49/100:  47%|█████████████████████████▌                            | 29203.22/61689.47 [00:44<01:11, 454.51it/s][A
Subject 49/100:  47%|█████████████████████████▌                            | 29229.85/61689.47 [00:44<01:11, 454.51it/s][A
Subject 49/100:  47%|█████████████████████████▌                            | 29249.87/61689.47 [00:44<01:11, 454.51it/s][A
Subject 49/100:  48%|█████████████████████████▋                            | 29336.30/61689.47 [00:44<01:05, 493.95it/s][A
Subject 49/100:  48%|█████████████████████████▋                            | 29336.30/61689.47 [00:44<01:05, 493.95it/s][A
Subject 49/100:  48%|█████████████████████████▋                            | 29346.73/61689.47 [00:44<01:05, 493.95it/s][A
Subject 49/100:  48%|█████████████████████████▊                            | 29453.20/61689.47 [00:44<01:05, 493.95it/s][A
Subject 

Subject 59/100:  58%|██████████████████████████████▉                      | 36041.97/61689.47 [00:49<00:16, 1593.71it/s][A
Subject 59/100:  58%|██████████████████████████████▉                      | 36041.97/61689.47 [00:49<00:16, 1593.71it/s][A
Subject 59/100:  59%|███████████████████████████████                      | 36132.02/61689.47 [00:49<00:16, 1593.71it/s][A
Subject 59/100:  59%|███████████████████████████████▏                     | 36308.52/61689.47 [00:49<00:19, 1322.49it/s][A
Subject 59/100:  59%|███████████████████████████████▏                     | 36308.52/61689.47 [00:49<00:19, 1322.49it/s][A
Subject 59/100:  59%|███████████████████████████████▍                     | 36602.82/61689.47 [00:50<00:22, 1107.52it/s][A
Subject 59/100:  59%|███████████████████████████████▍                     | 36602.82/61689.47 [00:50<00:22, 1107.52it/s][A
Subject 59/100:  60%|███████████████████████████████▊                     | 37014.00/61689.47 [00:50<00:24, 1008.30it/s][A
Subject 

Subject 63/100:  66%|███████████████████████████████████▍                  | 40451.73/61689.47 [00:54<00:26, 807.52it/s][A
Subject 63/100:  66%|███████████████████████████████████▍                  | 40550.83/61689.47 [00:54<00:26, 807.52it/s][A
Subject 63/100:  66%|███████████████████████████████████▌                  | 40650.02/61689.47 [00:54<00:25, 823.63it/s][A
Subject 63/100:  66%|███████████████████████████████████▌                  | 40650.02/61689.47 [00:54<00:25, 823.63it/s][A
Subject 63/100:  66%|███████████████████████████████████▋                  | 40793.80/61689.47 [00:55<00:31, 671.09it/s][A
Subject 63/100:  66%|███████████████████████████████████▋                  | 40793.80/61689.47 [00:55<00:31, 671.09it/s][A
Subject 63/100:  66%|███████████████████████████████████▊                  | 40865.18/61689.47 [00:55<00:31, 671.09it/s][A
Subject 63/100:  66%|███████████████████████████████████▊                  | 40904.62/61689.47 [00:55<00:30, 671.09it/s][A
Subject 

Subject 63/100:  72%|██████████████████████████████████████▋               | 44135.62/61689.47 [00:59<00:27, 631.13it/s][A
Subject 63/100:  72%|██████████████████████████████████████▋               | 44215.60/61689.47 [00:59<00:27, 631.13it/s][A
Subject 63/100:  72%|██████████████████████████████████████▊               | 44356.03/61689.47 [00:59<00:22, 786.61it/s][A
Subject 63/100:  72%|██████████████████████████████████████▊               | 44356.03/61689.47 [00:59<00:22, 786.61it/s][A
Subject 63/100:  72%|██████████████████████████████████████▉               | 44452.00/61689.47 [00:59<00:21, 786.61it/s][A
Subject 63/100:  72%|██████████████████████████████████████▉               | 44551.98/61689.47 [00:59<00:18, 911.74it/s][A
Subject 63/100:  72%|██████████████████████████████████████▉               | 44551.98/61689.47 [00:59<00:18, 911.74it/s][A
Subject 63/100:  72%|███████████████████████████████████████               | 44646.62/61689.47 [00:59<00:18, 911.74it/s][A
Subject 

Subject 73/100:  80%|███████████████████████████████████████████           | 49163.72/61689.47 [01:05<00:15, 803.53it/s][A
Subject 73/100:  80%|███████████████████████████████████████████           | 49163.72/61689.47 [01:05<00:15, 803.53it/s][A
Subject 73/100:  82%|███████████████████████████████████████████▎         | 50357.70/61689.47 [01:05<00:06, 1712.31it/s][A
Subject 74/100:  82%|███████████████████████████████████████████▎         | 50357.70/61689.47 [01:05<00:06, 1712.31it/s][A
Subject 74/100:  84%|█████████████████████████████████████████████▏        | 51614.85/61689.47 [01:08<00:12, 795.74it/s][A
Subject 74/100:  84%|█████████████████████████████████████████████▏        | 51614.85/61689.47 [01:08<00:12, 795.74it/s][A
Subject 74/100:  85%|████████████████████████████████████████████▉        | 52261.72/61689.47 [01:08<00:09, 1045.39it/s][A
Subject 74/100:  85%|████████████████████████████████████████████▉        | 52261.72/61689.47 [01:08<00:09, 1045.39it/s][A
Subject 

Subject 90/100:  94%|██████████████████████████████████████████████████▊   | 57999.05/61689.47 [01:19<00:07, 515.81it/s][A
Subject 90/100:  94%|██████████████████████████████████████████████████▊   | 57999.05/61689.47 [01:19<00:07, 515.81it/s][A
Subject 90/100:  94%|██████████████████████████████████████████████████▉   | 58181.32/61689.47 [01:20<00:06, 527.32it/s][A
Subject 90/100:  94%|██████████████████████████████████████████████████▉   | 58181.32/61689.47 [01:20<00:06, 527.32it/s][A
Subject 90/100:  94%|██████████████████████████████████████████████████▉   | 58207.37/61689.47 [01:20<00:06, 527.32it/s][A
Subject 90/100:  94%|██████████████████████████████████████████████████▉   | 58239.12/61689.47 [01:20<00:06, 527.32it/s][A
Subject 90/100:  95%|███████████████████████████████████████████████████   | 58328.02/61689.47 [01:20<00:06, 527.32it/s][A
Subject 90/100:  95%|███████████████████████████████████████████████████▏  | 58446.17/61689.47 [01:20<00:04, 669.83it/s][A
Subject 