# Measure Sampling Time - Numpyro


# Select Data, IRT model, and Device


In [1]:
# 0 -> bone
# 1 (others) -> brain

#DATA = 0 
DATA = 1



In [2]:
# 0 -> 1PL-IRT
# 1 (others) -> 2PL-IRT 

MODEL = 0
#MODEL = 1

In [3]:
# 0 -> CPU
# 1 -> GPU
# 2 -> GPU vectorized

#DEVICE = 0
DEVICE = 1

###########################
#DEVICE = 2 # do not work
###########################

In [4]:
num_chains = 2

if DEVICE == 2:
    chain_method = 'vectorized'
else:
    chain_method = 'parallel'

# Prepare

In [5]:
! cat /proc/cpuinfo

processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 85
model name	: Intel(R) Xeon(R) CPU @ 2.00GHz
stepping	: 3
microcode	: 0xffffffff
cpu MHz		: 2000.154
cache size	: 39424 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa mmi

In [6]:
#! pip install -q "jax[cuda11_cudnn805]"==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
! pip install -q numpyro==0.10.1 arviz==0.12.1

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/292.7 KB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[90m╺[0m [32m286.7/292.7 KB[0m [31m17.7 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m292.7/292.7 KB[0m [31m6.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [7]:
import numpy as np
import pandas as pd
import datetime as dt
import time

import matplotlib.pyplot as plt
import seaborn as sns

In [8]:
import numpyro
import numpyro.distributions as dist

import jax
import arviz as az
import jax.numpy as jnp

In [9]:
if DEVICE == 0:
    numpyro.set_platform('cpu')
    numpyro.set_host_device_count(num_chains)
else:
    numpyro.set_platform('gpu')
    n = jax.device_count()
    print("number of GPU", n)
    if n < 1:
        raise Exception("no GPU")
    else:
        ! nvidia-smi

number of GPU 1
Thu Mar  9 03:32:24 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   45C    P0    26W /  70W |    105MiB / 15360MiB |      6%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------

## Import Data

In [10]:
fff = "idata_ppc_for_data%s_model%s.nc" % (DATA, MODEL)

! wget https://filedn.com/lpAczQGgeBjkX6l7SpI5JJy/__ws/stan_irt_nrm_rad/{fff} -O idata_ppc.nc

--2023-03-09 03:32:24--  https://filedn.com/lpAczQGgeBjkX6l7SpI5JJy/__ws/stan_irt_nrm_rad/idata_ppc_for_data1_model0.nc
Resolving filedn.com (filedn.com)... 23.109.93.100
Connecting to filedn.com (filedn.com)|23.109.93.100|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 38241044 (36M) [application/x-netcdf]
Saving to: ‘idata_ppc.nc’


2023-03-09 03:32:26 (25.1 MB/s) - ‘idata_ppc.nc’ saved [38241044/38241044]



In [11]:
idata_ppc = az.from_netcdf('idata_ppc.nc')

In [12]:
y_ppc = idata_ppc.posterior_predictive['obs'].to_numpy()
y_ppc.shape

(6, 3000, 42, 14)

# Run Numpyro

## Define Model

In [13]:
def model_1pl(y=None, num_cases=0, num_doctors=0):
    with numpyro.plate('doctor', num_doctors):
      theta = numpyro.sample('theta', dist.Normal(0, 2))

    with numpyro.plate('case', num_cases, dim=-2):
        beta = numpyro.sample('beta', dist.Normal(0, 2))
        with numpyro.plate('doctor', num_doctors):
            mu = theta - beta
            numpyro.sample('obs', dist.Bernoulli(logits=mu), obs=y)

In [14]:
def model_2pl(y=None, num_cases=0, num_doctors=0):
    with numpyro.plate('doctor', num_doctors):
      theta = numpyro.sample('theta', dist.Normal(0, 2))

    with numpyro.plate('case', num_cases, dim=-2):
        beta = numpyro.sample('beta', dist.Normal(0, 2))
        log_d = numpyro.sample('log_d', dist.Normal(0.5, 1))
        with numpyro.plate('doctor', num_doctors):
            mu = jnp.exp(log_d)*(theta - beta)
            numpyro.sample('obs', dist.Bernoulli(logits=mu), obs=y)

In [15]:
model = model_1pl if MODEL == 0 else model_2pl

In [16]:
nuts = numpyro.infer.NUTS(model)

#mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=3000, num_chains=num_chains)
mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=3000, num_chains=num_chains, chain_method=chain_method)


  mcmc = numpyro.infer.MCMC(nuts, num_warmup=500, num_samples=3000, num_chains=num_chains, chain_method=chain_method)


## Measure Inference Time

In [17]:
key = jax.random.PRNGKey(0)

factors = [1, 1, 2, 5, 10, 20, 50, 100, 200, 500, 1000] 
#factors = [1, 50, 20, 10, 5, 2, 1] 

num_doctors = y_ppc.shape[-1]

In [18]:
lines = []

for k, factor in enumerate(factors):
    y_simulated = y_ppc[0, :factor, :, :].reshape(-1, num_doctors)
    print("simulated data shape:", y_simulated.shape)

    num_cases = y_simulated.shape[0]
    start_time = dt.datetime.now()

    mcmc.run(key, y=y_simulated, num_cases=num_cases, num_doctors=num_doctors)

    end_time = dt.datetime.now()
    elapsed_time = (end_time - start_time).total_seconds()
    lines.append(f'{num_cases},{start_time},{end_time},{elapsed_time}')

    time.sleep(1) # sleep 1 sec

simulated data shape: (42, 14)


sample: 100%|██████████| 3500/3500 [00:22<00:00, 153.21it/s, 7 steps of size 4.26e-01. acc. prob=0.86]
sample: 100%|██████████| 3500/3500 [00:20<00:00, 172.27it/s, 7 steps of size 4.48e-01. acc. prob=0.85]


simulated data shape: (42, 14)


sample: 100%|██████████| 3500/3500 [00:23<00:00, 149.56it/s, 7 steps of size 4.26e-01. acc. prob=0.86]
sample: 100%|██████████| 3500/3500 [00:18<00:00, 188.06it/s, 7 steps of size 4.48e-01. acc. prob=0.85]


simulated data shape: (84, 14)


sample: 100%|██████████| 3500/3500 [00:26<00:00, 132.21it/s, 15 steps of size 4.16e-01. acc. prob=0.84]
sample: 100%|██████████| 3500/3500 [00:28<00:00, 121.56it/s, 15 steps of size 3.85e-01. acc. prob=0.87]


simulated data shape: (210, 14)


sample: 100%|██████████| 3500/3500 [00:32<00:00, 108.25it/s, 15 steps of size 3.06e-01. acc. prob=0.87]
sample: 100%|██████████| 3500/3500 [00:29<00:00, 118.44it/s, 15 steps of size 3.49e-01. acc. prob=0.84]


simulated data shape: (420, 14)


sample: 100%|██████████| 3500/3500 [00:33<00:00, 103.73it/s, 15 steps of size 2.83e-01. acc. prob=0.86]
sample: 100%|██████████| 3500/3500 [00:31<00:00, 110.00it/s, 15 steps of size 2.72e-01. acc. prob=0.87]


simulated data shape: (840, 14)


sample: 100%|██████████| 3500/3500 [00:35<00:00, 97.38it/s, 15 steps of size 2.49e-01. acc. prob=0.85] 
sample: 100%|██████████| 3500/3500 [00:33<00:00, 103.08it/s, 15 steps of size 2.32e-01. acc. prob=0.87]


simulated data shape: (2100, 14)


sample: 100%|██████████| 3500/3500 [00:57<00:00, 61.31it/s, 15 steps of size 2.10e-01. acc. prob=0.84]
sample: 100%|██████████| 3500/3500 [01:00<00:00, 58.26it/s, 31 steps of size 1.71e-01. acc. prob=0.89]


simulated data shape: (4200, 14)


sample: 100%|██████████| 3500/3500 [01:03<00:00, 55.08it/s, 31 steps of size 1.64e-01. acc. prob=0.86]
sample: 100%|██████████| 3500/3500 [01:01<00:00, 57.36it/s, 31 steps of size 1.73e-01. acc. prob=0.84]


simulated data shape: (8400, 14)


sample: 100%|██████████| 3500/3500 [01:12<00:00, 48.60it/s, 31 steps of size 1.44e-01. acc. prob=0.85]
sample: 100%|██████████| 3500/3500 [01:11<00:00, 49.16it/s, 31 steps of size 1.65e-01. acc. prob=0.80]


simulated data shape: (21000, 14)


sample: 100%|██████████| 3500/3500 [01:17<00:00, 44.96it/s, 31 steps of size 1.25e-01. acc. prob=0.83]
sample: 100%|██████████| 3500/3500 [01:14<00:00, 46.69it/s, 31 steps of size 1.38e-01. acc. prob=0.79]


simulated data shape: (42000, 14)


sample: 100%|██████████| 3500/3500 [02:25<00:00, 24.06it/s, 63 steps of size 9.87e-02. acc. prob=0.84]
sample: 100%|██████████| 3500/3500 [02:23<00:00, 24.31it/s, 63 steps of size 9.82e-02. acc. prob=0.84]


## Export Data

In [19]:
path = "time_measured_numpyro_data%s_model%s_device%s.csv" % (DATA, MODEL, DEVICE)

header = ['num_cases,start_time,end_time,elapsed_time']

with open(path, mode='w') as f:
    f.write('\n'.join(header + lines))

In [20]:
from google.colab import files

files.download(path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>