In [1]:
%load_ext autoreload
%autoreload 2
import sys
import os
import glob
import random
from collections import defaultdict
from pathlib import Path

from IPython.display import display

import pandas as pd

from tqdm import tqdm
import jax
jax.config.update('jax_platform_name', 'cpu')

In [2]:


sys.path.append("../..")

from lib import utils as U
from lib.ehr.dataset import load_dataset
from lib.ehr.icu_interface import Inpatients

In [3]:
import logging
logging.root.level = logging.DEBUG


In [4]:
# Assign the folder of the dataset to `DATA_FILE`.

HOME = os.environ.get('HOME')
DATA_DIR = f'{HOME}/GP/ehr-data'
SOURCE_DIR = os.path.abspath("..")

with U.modified_environ(DATA_DIR=DATA_DIR):
    m4inpatient_dataset = load_dataset('M4ICU', max_workers=1)
   

DEBUG:root:Loading dataframe files
DEBUG:root:[DONE] Loading dataframe files
DEBUG:root:Matching admission_id
DEBUG:root:[DONE] Matching admission_id
DEBUG:root:Time casting..
DEBUG:root:[DONE] Time casting..


INFO: Pandarallel will run on 1 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


DEBUG:root:Dataframes validation and time conversion
INFO:root:Unrecognised ICD v10 codes: 3323 (28.74%)
DEBUG:root:
                    Unrecognised <class 'lib.ehr.coding_scheme.DxICD10'> codes (3323)
                    to be removed: ['E08.3513', 'E10.3213', 'E10.3219', 'E10.3291', 'E10.3293', 'E10.3299', 'E10.3313', 'E10.3319', 'E10.3393', 'E10.3399', 'E10.3411', 'E10.3413', 'E10.3491', 'E10.3511', 'E10.3512', 'E10.3513', 'E10.3519', 'E10.3522', 'E10.3523', 'E10.3531', 'E10.3532', 'E10.3559', 'E10.3591', 'E10.3592', 'E10.3593', 'E10.3599', 'E11.3213', 'E11.3219', 'E11.3291', 'E11.3292', 'E11.3293', 'E11.3299', 'E11.3311', 'E11.3313', 'E11.3319', 'E11.3391', 'E11.3393', 'E11.3399', 'E11.3413', 'E11.3419', 'E11.3491', 'E11.3492', 'E11.3493', 'E11.3499', 'E11.3513', 'E11.3519', 'E11.3521', 'E11.3532', 'E11.3542', 'E11.3553', 'E11.3591', 'E11.3592', 'E11.3593', 'E11.3599', 'H34.8112', 'H34.8120', 'H34.8122', 'H34.8192', 'H34.8310', 'H34.8320', 'H35.3110', 'H35.3120', 'H35.3130', 'H35.

INFO:root:Unrecognised ICD v9 codes: 118 (1.63%)
DEBUG:root:
                    Unrecognised <class 'lib.ehr.coding_scheme.DxICD9'> codes (118)
                    to be removed: ['041.49', '173.21', '173.22', '173.30', '173.31', '173.32', '173.40', '173.41', '173.42', '173.50', '173.51', '173.52', '173.59', '173.60', '173.61', '173.62', '173.70', '173.71', '173.72', '173.79', '173.80', '173.81', '173.82', '173.91', '173.92', '173.99', '282.40', '282.43', '282.44', '282.46', '284.11', '284.12', '284.19', '286.52', '286.53', '286.59', '294.20', '294.21', '310.81', '310.89', '331.6', '348.82', '358.30', '365.70', '365.72', '365.73', '414.4', '415.13', '425.11', '425.18', '444.09', '488.81', '488.82', '488.89', '512.2', '512.82', '512.83', '512.84', '512.89', '516.31', '516.32', '516.33', '516.34', '516.35', '516.36', '516.37', '516.4', '516.5', '518.51', '518.52', '518.53', '539.01', '539.09', '539.81', '539.89', '573.5', '596.81', '596.82', '596.83', '596.89', '629.31', '649.81', '704.

In [5]:
splits = m4inpatient_dataset.random_splits(0.8, 0.9, 42)

In [6]:
preprocessing = m4inpatient_dataset.fit_preprocessing(splits[0])

In [7]:
m4inpatient_dataset.apply_preprocessing(preprocessing)

DEBUG:root:Removed 2320851 (0.023) outliers from obs


In [26]:
m4inaptients = Inpatients(m4inpatient_dataset, splits[2][:10])

DEBUG:root:Loading subjects..


INFO: Pandarallel will run on 1 workers.
INFO: Pandarallel will use Memory file system to transfer data between the main process and workers.


DEBUG:root:Extracting dx codes...
DEBUG:root:[DONE] Extracting dx codes
DEBUG:root:Extracting dx codes history...
DEBUG:root:[DONE] Extracting dx codes history
DEBUG:root:Extracting outcome...
DEBUG:root:[DONE] Extracting outcome
DEBUG:root:Extracting procedures...
DEBUG:root:[DONE] Extracting procedures
DEBUG:root:Extracting inputs...
DEBUG:root:[DONE] Extracting inputs
DEBUG:root:Extracting observables...
DEBUG:root:[DONE] Extracting observables
DEBUG:root:[DONE] Loading subjects


In [28]:
m4inaptients.size_in_bytes / 1024 ** 3

0.0014968374744057655

In [20]:
m4inaptients_jax = m4inaptients.to_jax_arrays(splits[2])

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00032019615173339844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140049084490352) for with global shapes and types (ShapedArray(float32[67]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001862049102783203 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140049084490352) for with global shapes and types (ShapedArray(float16[67,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020456314086914062 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140049084493632) for with global shapes and types (ShapedArray(bool[67,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019669532775878906 sec
DEBU

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0005047321319580078 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052425308896) for with global shapes and types (ShapedArray(float16[90,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002143383026123047 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052425308736) for with global shapes and types (ShapedArray(bool[90,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003857612609863281 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052425306336) for with global shapes and types (ShapedArray(int32[70]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00038313865661621094 sec
DEBUG:ja

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00022935867309570312 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140047018710576) for with global shapes and types (ShapedArray(float16[68,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0006961822509765625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140047018711616) for with global shapes and types (ShapedArray(bool[68,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002810955047607422 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140047018709696) for with global shapes and types (ShapedArray(int32[55]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020551681518554688 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00047898292541503906 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053660157248) for with global shapes and types (ShapedArray(float32[131]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00027060508728027344 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053660159648) for with global shapes and types (ShapedArray(float32[53]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020241737365722656 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053660159328) for with global shapes and types (ShapedArray(float16[53,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00041484832763671875 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002033710479736328 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052485327104) for with global shapes and types (ShapedArray(float16[54,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021266937255859375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052485328224) for with global shapes and types (ShapedArray(bool[54,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002651214599609375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054269374192) for with global shapes and types (ShapedArray(float16[11,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019431114196777344 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00029087066650390625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052505322656) for with global shapes and types (ShapedArray(float32[155]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00024366378784179688 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052505322816) for with global shapes and types (ShapedArray(float16[155,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003342628479003906 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052505322816) for with global shapes and types (ShapedArray(bool[155,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019431114196777344 sec
D

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00023055076599121094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052505322736) for with global shapes and types (ShapedArray(float16[48,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00040531158447265625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052505323056) for with global shapes and types (ShapedArray(bool[48,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003504753112792969 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052505323856) for with global shapes and types (ShapedArray(int32[20]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002682209014892578 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0006020069122314453 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053538037168) for with global shapes and types (ShapedArray(float16[8,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00022292137145996094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054607328736) for with global shapes and types (ShapedArray(float32[19]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021529197692871094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054607328336) for with global shapes and types (ShapedArray(float16[19,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00022411346435546875 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00028514862060546875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140047018807920) for with global shapes and types (ShapedArray(int32[33]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00039958953857421875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140047018808320) for with global shapes and types (ShapedArray(float32[33]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00028777122497558594 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140047018811200) for with global shapes and types (ShapedArray(float32[96]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002493858337402344 sec
DEBUG:jax

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002617835998535156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054269458208) for with global shapes and types (ShapedArray(bool[77,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002808570861816406 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054269458208) for with global shapes and types (ShapedArray(int32[26]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002429485321044922 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054269458368) for with global shapes and types (ShapedArray(float32[26]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00039076805114746094 sec
DEBUG:jax.i

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00023484230041503906 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140048263223712) for with global shapes and types (ShapedArray(bool[29,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002269744873046875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140048263223792) for with global shapes and types (ShapedArray(int32[60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002963542938232422 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050247223312) for with global shapes and types (ShapedArray(int32[19]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0005533695220947266 sec
DEBUG:jax.int

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002732276916503906 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053420134544) for with global shapes and types (ShapedArray(bool[33,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00030875205993652344 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053420138144) for with global shapes and types (ShapedArray(float32[140]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00031495094299316406 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053420136464) for with global shapes and types (ShapedArray(float16[140,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00040459632873535156 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002224445343017578 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053420137344) for with global shapes and types (ShapedArray(float16[188,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003497600555419922 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053420137184) for with global shapes and types (ShapedArray(bool[188,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000370025634765625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053420137424) for with global shapes and types (ShapedArray(int32[289]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0006718635559082031 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003025531768798828 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050187058688) for with global shapes and types (ShapedArray(float16[136,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002701282501220703 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050187058688) for with global shapes and types (ShapedArray(bool[136,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00031065940856933594 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050187056848) for with global shapes and types (ShapedArray(int32[209]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002923011779785156 sec
DEBUG

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019502639770507812 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054670058976) for with global shapes and types (ShapedArray(float32[1084]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00024509429931640625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052515237776) for with global shapes and types (ShapedArray(float32[281]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001850128173828125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052515237776) for with global shapes and types (ShapedArray(float16[281,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018453598022460938 sec


DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020766258239746094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054269387456) for with global shapes and types (ShapedArray(bool[66,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003960132598876953 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140048523314256) for with global shapes and types (ShapedArray(float32[106]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00033164024353027344 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052515437760) for with global shapes and types (ShapedArray(float16[106,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00026679039001464844 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002865791320800781 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050327331680) for with global shapes and types (ShapedArray(float16[145,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002651214599609375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050327334000) for with global shapes and types (ShapedArray(bool[145,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00034308433532714844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050327245024) for with global shapes and types (ShapedArray(int32[54]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0005540847778320312 sec
DEBUG:

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00029659271240234375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050307248992) for with global shapes and types (ShapedArray(int32[78]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003268718719482422 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050307212848) for with global shapes and types (ShapedArray(int32[204]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0005116462707519531 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050307214208) for with global shapes and types (ShapedArray(float32[204]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003771781921386719 sec
DEBUG:jax.i

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003135204315185547 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140047834578112) for with global shapes and types (ShapedArray(float16[102,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019240379333496094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140047834575952) for with global shapes and types (ShapedArray(bool[102,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019693374633789062 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140047834576352) for with global shapes and types (ShapedArray(int32[159]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002300739288330078 sec
DEBU

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002799034118652344 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052825091824) for with global shapes and types (ShapedArray(float16[150,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003829002380371094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052825089744) for with global shapes and types (ShapedArray(bool[150,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002467632293701172 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052825091904) for with global shapes and types (ShapedArray(int32[168]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002696514129638672 sec
DEBUG:

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00026726722717285156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054824899584) for with global shapes and types (ShapedArray(float32[418]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00024390220642089844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054824901504) for with global shapes and types (ShapedArray(float16[418,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003192424774169922 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054824901504) for with global shapes and types (ShapedArray(bool[418,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000484466552734375 sec
DEB

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003070831298828125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054597962864) for with global shapes and types (ShapedArray(bool[147,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0006730556488037109 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054597963264) for with global shapes and types (ShapedArray(int32[127]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0004107952117919922 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054597963024) for with global shapes and types (ShapedArray(float32[127]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002636909484863281 sec
DEBUG:jax

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0004992485046386719 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054598030416) for with global shapes and types (ShapedArray(float32[333]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002613067626953125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054598043904) for with global shapes and types (ShapedArray(int32[182]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001926422119140625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054598042464) for with global shapes and types (ShapedArray(float32[182]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00026679039001464844 sec
DEBUG:ja

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001888275146484375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054219159888) for with global shapes and types (ShapedArray(float32[183]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00027370452880859375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054219211696) for with global shapes and types (ShapedArray(float32[98]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00022673606872558594 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054219211776) for with global shapes and types (ShapedArray(float16[98,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00029397010803222656 sec
DEB

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002696514129638672 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054119305920) for with global shapes and types (ShapedArray(float32[605]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00032019615173339844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054185083664) for with global shapes and types (ShapedArray(float32[382]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002162456512451172 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054185083664) for with global shapes and types (ShapedArray(float16[382,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020170211791992188 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003514289855957031 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053800322128) for with global shapes and types (ShapedArray(int32[110]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003094673156738281 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053800322128) for with global shapes and types (ShapedArray(int32[47]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00035500526428222656 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053800323328) for with global shapes and types (ShapedArray(float16[86,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0004551410675048828 sec
DEBUG:jax

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00044226646423339844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053732389328) for with global shapes and types (ShapedArray(float16[194,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0008268356323242188 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053732392048) for with global shapes and types (ShapedArray(bool[194,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00033473968505859375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054628902848) for with global shapes and types (ShapedArray(float32[601]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003902912139892578 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0005323886871337891 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140055071424896) for with global shapes and types (ShapedArray(int32[1189]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00041222572326660156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140055071426656) for with global shapes and types (ShapedArray(float32[1189]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00035071372985839844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140055071426656) for with global shapes and types (ShapedArray(float32[576]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0005192756652832031 sec
DEBUG

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002105236053466797 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052695054832) for with global shapes and types (ShapedArray(float16[276,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002079010009765625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052695055712) for with global shapes and types (ShapedArray(bool[276,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021386146545410156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052695052752) for with global shapes and types (ShapedArray(int32[636]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000286102294921875 sec
DEBUG:

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020384788513183594 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054680784384) for with global shapes and types (ShapedArray(bool[118,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0004227161407470703 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054680805184) for with global shapes and types (ShapedArray(float32[233]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002999305725097656 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054680804784) for with global shapes and types (ShapedArray(float16[233,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003807544708251953 sec
DEB

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020885467529296875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054670124272) for with global shapes and types (ShapedArray(float16[231,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002155303955078125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054670126032) for with global shapes and types (ShapedArray(bool[231,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00030112266540527344 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054670125552) for with global shapes and types (ShapedArray(int32[130]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002009868621826172 sec
DEBU

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003533363342285156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054670101696) for with global shapes and types (ShapedArray(float32[457]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00042819976806640625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054670100816) for with global shapes and types (ShapedArray(float32[385]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00022363662719726562 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054670100096) for with global shapes and types (ShapedArray(float16[385,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00031685829162597656 sec
D

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002143383026123047 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054670032704) for with global shapes and types (ShapedArray(bool[79,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002923011779785156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054311176448) for with global shapes and types (ShapedArray(int32[71]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00030684471130371094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140049937628832) for with global shapes and types (ShapedArray(int32[118]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002868175506591797 sec
DEBUG:jax.in

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002124309539794922 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140055071399344) for with global shapes and types (ShapedArray(float32[247]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003829002380371094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140055071399424) for with global shapes and types (ShapedArray(float32[459]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003199577331542969 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140055071396464) for with global shapes and types (ShapedArray(float16[459,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00023126602172851562 sec
DEB

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003178119659423828 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054208922064) for with global shapes and types (ShapedArray(int32[1019]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002849102020263672 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054208924384) for with global shapes and types (ShapedArray(float32[1019]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003445148468017578 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054208924864) for with global shapes and types (ShapedArray(int32[141]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002040863037109375 sec
DEBUG:jax

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019979476928710938 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052505309072) for with global shapes and types (ShapedArray(float16[387,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002384185791015625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052505307952) for with global shapes and types (ShapedArray(bool[387,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002689361572265625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052505307472) for with global shapes and types (ShapedArray(int32[240]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021767616271972656 sec
DEBU

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000274658203125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140055092876368) for with global shapes and types (ShapedArray(int32[211]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002810955047607422 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052485364688) for with global shapes and types (ShapedArray(float32[226]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020456314086914062 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140055092876768) for with global shapes and types (ShapedArray(float16[226,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019669532775878906 sec
DEBUG:ja

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002868175506591797 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050257155952) for with global shapes and types (ShapedArray(int32[345]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003204345703125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054269398448) for with global shapes and types (ShapedArray(float32[170]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019550323486328125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050257143584) for with global shapes and types (ShapedArray(float16[170,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020766258239746094 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003266334533691406 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050337125616) for with global shapes and types (ShapedArray(int32[107]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00028586387634277344 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050337111152) for with global shapes and types (ShapedArray(float32[532]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003216266632080078 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050337111792) for with global shapes and types (ShapedArray(float16[532,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001800060272216797 sec
DEBUG

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002846717834472656 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050436955568) for with global shapes and types (ShapedArray(float16[560,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001983642578125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050436984640) for with global shapes and types (ShapedArray(bool[560,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002334117889404297 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052415584992) for with global shapes and types (ShapedArray(int32[496]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020766258239746094 sec
DEBUG:ja

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020122528076171875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052535131248) for with global shapes and types (ShapedArray(float32[511]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002548694610595703 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052585439120) for with global shapes and types (ShapedArray(float32[282]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018787384033203125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052585439120) for with global shapes and types (ShapedArray(float16[282,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001888275146484375 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00033593177795410156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052585237936) for with global shapes and types (ShapedArray(int32[1621]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018978118896484375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052585216976) for with global shapes and types (ShapedArray(float32[1621]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00024700164794921875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052585216416) for with global shapes and types (ShapedArray(float16[167,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018358230590820312 sec


DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003764629364013672 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052745265296) for with global shapes and types (ShapedArray(bool[362,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0008420944213867188 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052745268336) for with global shapes and types (ShapedArray(int32[794]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00047516822814941406 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052745266896) for with global shapes and types (ShapedArray(float32[794]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002493858337402344 sec
DEBUG:ja

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003018379211425781 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052815195328) for with global shapes and types (ShapedArray(bool[450,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020503997802734375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052815194528) for with global shapes and types (ShapedArray(int32[687]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021123886108398438 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052815192608) for with global shapes and types (ShapedArray(float32[687]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002830028533935547 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002732276916503906 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052954978240) for with global shapes and types (ShapedArray(float32[311]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001964569091796875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052954978400) for with global shapes and types (ShapedArray(float16[311,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019502639770507812 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052954979280) for with global shapes and types (ShapedArray(bool[311,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021028518676757812 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000286102294921875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050327264688) for with global shapes and types (ShapedArray(float16[610,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020551681518554688 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050327265008) for with global shapes and types (ShapedArray(bool[610,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020647048950195312 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050327264048) for with global shapes and types (ShapedArray(int32[1043]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002090930938720703 sec
DEBU

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003542900085449219 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050407019280) for with global shapes and types (ShapedArray(bool[103,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00029969215393066406 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053039041728) for with global shapes and types (ShapedArray(int32[476]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00017833709716796875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053007451776) for with global shapes and types (ShapedArray(float32[476]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002789497375488281 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020313262939453125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053075326752) for with global shapes and types (ShapedArray(bool[266,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000186920166015625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053075325632) for with global shapes and types (ShapedArray(int32[371]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000186920166015625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053075325632) for with global shapes and types (ShapedArray(float32[371]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002644062042236328 sec
DEBUG:jax.

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019311904907226562 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053175369536) for with global shapes and types (ShapedArray(bool[325,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00024247169494628906 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053175367216) for with global shapes and types (ShapedArray(float16[144,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002105236053466797 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053175365776) for with global shapes and types (ShapedArray(bool[144,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00024771690368652344 sec
D

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002713203430175781 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052435439920) for with global shapes and types (ShapedArray(bool[458,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00024819374084472656 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052435439760) for with global shapes and types (ShapedArray(int32[1242]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0004990100860595703 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052435442160) for with global shapes and types (ShapedArray(float32[1242]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002598762512207031 sec
DEBUG:

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001804828643798828 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052495303440) for with global shapes and types (ShapedArray(bool[731,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018835067749023438 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052495304240) for with global shapes and types (ShapedArray(int32[2975]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018024444580078125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052495304160) for with global shapes and types (ShapedArray(float32[2975]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00027370452880859375 sec
DEBU

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00032138824462890625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052645031584) for with global shapes and types (ShapedArray(int32[262]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003883838653564453 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052645031184) for with global shapes and types (ShapedArray(float32[262]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00028395652770996094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052565293712) for with global shapes and types (ShapedArray(float32[393]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002448558807373047 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020956993103027344 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052644998416) for with global shapes and types (ShapedArray(bool[190,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00028824806213378906 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052965024288) for with global shapes and types (ShapedArray(float32[220]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002052783966064453 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052965025728) for with global shapes and types (ShapedArray(float16[220,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00022077560424804688 sec
D

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00027751922607421875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053266438880) for with global shapes and types (ShapedArray(float16[95,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00026416778564453125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053266437600) for with global shapes and types (ShapedArray(bool[95,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00029730796813964844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053266399456) for with global shapes and types (ShapedArray(int32[215]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002028942108154297 sec
DEBUG

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003399848937988281 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053958731072) for with global shapes and types (ShapedArray(int32[257]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00022411346435546875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053958731632) for with global shapes and types (ShapedArray(float32[257]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003757476806640625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053958744000) for with global shapes and types (ShapedArray(float16[159,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0004687309265136719 sec
DEBUG

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00022149085998535156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052465510816) for with global shapes and types (ShapedArray(bool[438,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020575523376464844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052465509456) for with global shapes and types (ShapedArray(int32[1641]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000270843505859375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052465509056) for with global shapes and types (ShapedArray(float32[1641]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002658367156982422 sec
DEBUG:

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018644332885742188 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052805228240) for with global shapes and types (ShapedArray(float32[654]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00028324127197265625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052805185024) for with global shapes and types (ShapedArray(int32[231]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002815723419189453 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052805182144) for with global shapes and types (ShapedArray(int32[530]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00028133392333984375 sec
DEBUG:ja

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001971721649169922 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054638199008) for with global shapes and types (ShapedArray(bool[396,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002655982971191406 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054638200288) for with global shapes and types (ShapedArray(int32[663]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020384788513183594 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054638200448) for with global shapes and types (ShapedArray(float32[663]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002651214599609375 sec
DEBUG:ja

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018072128295898438 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140048686242528) for with global shapes and types (ShapedArray(int32[142]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003027915954589844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140048686240128) for with global shapes and types (ShapedArray(int32[370]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021147727966308594 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140048686241088) for with global shapes and types (ShapedArray(float32[370]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00026535987854003906 sec
DEBUG:ja

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00028777122497558594 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140048897050720) for with global shapes and types (ShapedArray(float32[389]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00029087066650390625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140048897074016) for with global shapes and types (ShapedArray(float16[801,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002086162567138672 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140048897076496) for with global shapes and types (ShapedArray(bool[801,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000186920166015625 sec
DEB

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00031185150146484375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050446921216) for with global shapes and types (ShapedArray(float32[355]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00027060508728027344 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050446919056) for with global shapes and types (ShapedArray(int32[202]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002090930938720703 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140050446918576) for with global shapes and types (ShapedArray(float32[202]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002701282501220703 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001785755157470703 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140051528614192) for with global shapes and types (ShapedArray(float16[445,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001983642578125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140051528614832) for with global shapes and types (ShapedArray(bool[445,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018215179443359375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140051528614032) for with global shapes and types (ShapedArray(int32[451]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001780986785888672 sec
DEBUG:ja

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002086162567138672 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052755207328) for with global shapes and types (ShapedArray(float32[1633]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002541542053222656 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052755300576) for with global shapes and types (ShapedArray(int32[214]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003046989440917969 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052755312128) for with global shapes and types (ShapedArray(float32[553]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019812583923339844 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020003318786621094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052835205744) for with global shapes and types (ShapedArray(float16[495,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001804828643798828 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052835206704) for with global shapes and types (ShapedArray(bool[495,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00030541419982910156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052835178832) for with global shapes and types (ShapedArray(float16[618,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020122528076171875 se

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002522468566894531 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053230108896) for with global shapes and types (ShapedArray(float16[379,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002257823944091797 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053230112656) for with global shapes and types (ShapedArray(bool[379,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021123886108398438 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053230110016) for with global shapes and types (ShapedArray(int32[753]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00028061866760253906 sec
DEBU

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002689361572265625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053402253088) for with global shapes and types (ShapedArray(int32[1703]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019741058349609375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053402254288) for with global shapes and types (ShapedArray(float32[1703]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00044989585876464844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053402252928) for with global shapes and types (ShapedArray(float32[373]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002357959747314453 sec
DEBUG

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002300739288330078 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054947884864) for with global shapes and types (ShapedArray(float16[1096,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001857280731201172 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054947883024) for with global shapes and types (ShapedArray(bool[1096,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018596649169921875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054947881904) for with global shapes and types (ShapedArray(int32[1527]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001804828643798828 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002448558807373047 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052904979328) for with global shapes and types (ShapedArray(float32[268]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019025802612304688 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052904978928) for with global shapes and types (ShapedArray(float16[268,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00023698806762695312 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052904979248) for with global shapes and types (ShapedArray(bool[268,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00026869773864746094 sec
D

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00029730796813964844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054567433936) for with global shapes and types (ShapedArray(int32[166]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002903938293457031 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053243746960) for with global shapes and types (ShapedArray(int32[266]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00037026405334472656 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140053243747680) for with global shapes and types (ShapedArray(float32[771]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002799034118652344 sec
DEBUG:jax

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00039887428283691406 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052087057840) for with global shapes and types (ShapedArray(int32[969]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002117156982421875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052087057120) for with global shapes and types (ShapedArray(float32[969]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00024580955505371094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140052087049568) for with global shapes and types (ShapedArray(float32[612]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002636909484863281 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00025200843811035156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054525004352) for with global shapes and types (ShapedArray(float32[542]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003237724304199219 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054525004752) for with global shapes and types (ShapedArray(float16[542,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00024509429931640625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054525005312) for with global shapes and types (ShapedArray(bool[542,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002605915069580078 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018715858459472656 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054519180800) for with global shapes and types (ShapedArray(int32[1150]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001842975616455078 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054519181120) for with global shapes and types (ShapedArray(float32[1150]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00026226043701171875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054519180560) for with global shapes and types (ShapedArray(float16[356,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000274658203125 sec
DEBUG:

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020694732666015625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054512919392) for with global shapes and types (ShapedArray(int32[443]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020623207092285156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054512919952) for with global shapes and types (ShapedArray(float32[443]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002868175506591797 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054512918592) for with global shapes and types (ShapedArray(float32[455]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0012078285217285156 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00023865699768066406 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054503716400) for with global shapes and types (ShapedArray(float16[968,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021338462829589844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054503717600) for with global shapes and types (ShapedArray(bool[968,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002505779266357422 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054501293168) for with global shapes and types (ShapedArray(float32[306]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018453598022460938 sec
D

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001895427703857422 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054492906896) for with global shapes and types (ShapedArray(float32[249]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002713203430175781 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054492906096) for with global shapes and types (ShapedArray(int32[328]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019741058349609375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054492905696) for with global shapes and types (ShapedArray(float32[328]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00029087066650390625 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002262592315673828 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054486802656) for with global shapes and types (ShapedArray(float16[539,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00022721290588378906 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054486806416) for with global shapes and types (ShapedArray(bool[539,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000240325927734375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054486803296) for with global shapes and types (ShapedArray(int32[1110]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003097057342529297 sec
DEBUG

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002224445343017578 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054478453792) for with global shapes and types (ShapedArray(float32[1815]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000240325927734375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054477272384) for with global shapes and types (ShapedArray(float32[314]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019788742065429688 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054477272784) for with global shapes and types (ShapedArray(float16[314,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001823902130126953 sec
DEB

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021147727966308594 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054468446384) for with global shapes and types (ShapedArray(float32[558]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002720355987548828 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054468447024) for with global shapes and types (ShapedArray(float16[368,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002536773681640625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054468447584) for with global shapes and types (ShapedArray(bool[368,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00024628639221191406 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003075599670410156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054459074176) for with global shapes and types (ShapedArray(float32[630]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00026679039001464844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054458286864) for with global shapes and types (ShapedArray(float32[510]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020122528076171875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054458287184) for with global shapes and types (ShapedArray(float16[510,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002071857452392578 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002541542053222656 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054455235264) for with global shapes and types (ShapedArray(int32[310]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00026106834411621094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054453155616) for with global shapes and types (ShapedArray(float32[1856]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000270843505859375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054453156016) for with global shapes and types (ShapedArray(float16[1856,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00023603439331054688 sec
DEB

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001914501190185547 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054442559024) for with global shapes and types (ShapedArray(float32[930]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002741813659667969 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054442558784) for with global shapes and types (ShapedArray(float16[302,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002536773681640625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054442560624) for with global shapes and types (ShapedArray(bool[302,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019860267639160156 sec
DEB

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002760887145996094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054430843744) for with global shapes and types (ShapedArray(float32[863]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0005617141723632812 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054430844464) for with global shapes and types (ShapedArray(float16[863,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003304481506347656 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054430844944) for with global shapes and types (ShapedArray(bool[863,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00024008750915527344 sec
DEB

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001838207244873047 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054422174032) for with global shapes and types (ShapedArray(float32[504]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00025153160095214844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054422173792) for with global shapes and types (ShapedArray(int32[498]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002048015594482422 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054422175392) for with global shapes and types (ShapedArray(float32[498]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00027871131896972656 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00025010108947753906 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054409448480) for with global shapes and types (ShapedArray(float32[1107]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00027370452880859375 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054409448800) for with global shapes and types (ShapedArray(float16[1107,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000244140625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054409448560) for with global shapes and types (ShapedArray(bool[1107,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001971721649169922 sec
DEBUG:

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002799034118652344 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054396978176) for with global shapes and types (ShapedArray(float32[410]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020384788513183594 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054396978736) for with global shapes and types (ShapedArray(float16[410,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020122528076171875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054396979536) for with global shapes and types (ShapedArray(bool[410,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020122528076171875 sec
D

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002167224884033203 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054389570592) for with global shapes and types (ShapedArray(int32[868]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020122528076171875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054389571152) for with global shapes and types (ShapedArray(float32[868]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002689361572265625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054389405552) for with global shapes and types (ShapedArray(int32[239]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002491474151611328 sec
DEBUG:jax.

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019025802612304688 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054379945312) for with global shapes and types (ShapedArray(float32[1228]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00029015541076660156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054379206592) for with global shapes and types (ShapedArray(float32[488]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0003197193145751953 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054379206912) for with global shapes and types (ShapedArray(float16[488,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0005545616149902344 sec
D

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019431114196777344 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054365277056) for with global shapes and types (ShapedArray(bool[517,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00023984909057617188 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054364554480) for with global shapes and types (ShapedArray(float32[594]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001842975616455078 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054364554880) for with global shapes and types (ShapedArray(float16[594,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018548965454101562 sec
D

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002636909484863281 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054353159648) for with global shapes and types (ShapedArray(float32[2483]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000244140625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054350875280) for with global shapes and types (ShapedArray(float16[363,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00017547607421875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054350875600) for with global shapes and types (ShapedArray(bool[363,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019407272338867188 sec
DEBUG:jax.i

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001804828643798828 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054334928592) for with global shapes and types (ShapedArray(float32[652]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002872943878173828 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054334929792) for with global shapes and types (ShapedArray(float16[400,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018072128295898438 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054334930112) for with global shapes and types (ShapedArray(bool[400,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00020813941955566406 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00027108192443847656 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054324598800) for with global shapes and types (ShapedArray(int32[898]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00023245811462402344 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054324599360) for with global shapes and types (ShapedArray(float32[898]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00026535987854003906 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054324597200) for with global shapes and types (ShapedArray(int32[919]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021409988403320312 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001823902130126953 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054315199840) for with global shapes and types (ShapedArray(bool[399,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00027942657470703125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054313056192) for with global shapes and types (ShapedArray(int32[570]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002655982971191406 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054313056592) for with global shapes and types (ShapedArray(float32[570]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00023865699768066406 sec
DEBUG:j

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019025802612304688 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054298954304) for with global shapes and types (ShapedArray(int32[1059]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002601146697998047 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054298954624) for with global shapes and types (ShapedArray(float32[1059]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002846717834472656 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054298691280) for with global shapes and types (ShapedArray(float16[249,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00019431114196777344 sec
DE

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018477439880371094 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054292187792) for with global shapes and types (ShapedArray(float16[1155,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021409988403320312 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054292188352) for with global shapes and types (ShapedArray(bool[1155,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001862049102783203 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054292188912) for with global shapes and types (ShapedArray(int32[1890]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018835067749023438 sec


DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001914501190185547 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054284417040) for with global shapes and types (ShapedArray(float16[883,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018906593322753906 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054284418960) for with global shapes and types (ShapedArray(bool[883,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00018668174743652344 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054284417920) for with global shapes and types (ShapedArray(int32[856]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001862049102783203 sec
DEBU

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0001862049102783203 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054274521024) for with global shapes and types (ShapedArray(float32[1014]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002770423889160156 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054273463856) for with global shapes and types (ShapedArray(int32[1519]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00022101402282714844 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054273464256) for with global shapes and types (ShapedArray(float32[1519]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00025081634521484375 sec
DEBU

DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.0002765655517578125 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054263810784) for with global shapes and types (ShapedArray(int32[552]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.000244140625 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054258666608) for with global shapes and types (ShapedArray(float32[619]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00023651123046875 sec
DEBUG:jax.interpreters.pxla:Compiling prim_fun (140054258667008) for with global shapes and types (ShapedArray(float16[619,60]),). Argument mapping: (OpShardingSharding({replicated}),).
DEBUG:jax._src.dispatch:Finished tracing + transforming jit(convert_element_type) in 0.00021338462829589844 sec
DEBUG:jax.inte

In [21]:
m4inaptients_jax.size_in_bytes / 1024 ** 3

0.7668210975825787

In [23]:
m4inaptients_jax.subjects[splits[2][0]]

Inpatient(
  subject_id='18326767',
  static_info=StaticInfo(
    gender='F',
    date_of_birth=Timestamp('2104-01-01 00:00:00'),
    ethnicity=bool[5],
    ethnicity_scheme=<lib.ehr.coding_scheme.MIMIC4Eth5 object at 0x7f60b53dc670>,
    constant_vec=bool[6]
  ),
  admissions=[
    InpatientAdmission(
      admission_id='26809767',
      admission_dates=(
        Timestamp('2153-08-01 21:09:00'),
        Timestamp('2153-08-03 16:18:00')
      ),
      dx_codes=CodesVector(
        vec=bool[17375],
        scheme=<lib.ehr.coding_scheme.DxICD9 object at 0x7f5f11164400>
      ),
      dx_codes_history=CodesVector(
        vec=bool[17375],
        scheme=<lib.ehr.coding_scheme.DxICD9 object at 0x7f5f11164400>
      ),
      outcome=CodesVector(
        vec=bool[2081],
        scheme=<lib.ehr.outcome.OutcomeExtractor object at 0x7f60b53dc0d0>
      ),
      observables=InpatientObservables(
        time=f32[3],
        value=f16[3,60],
        mask=bool[3,60]
      ),
      interventions=I

## TODO

1. Squeeze code vectors.
2. Squeeze float32 to float16.