# Measure Sampling Time - NumPyro


# Select Data, IRT model, and Device


In [21]:
# Data set selector:
#   0               -> bone
#   any other value -> brain

#DATA = 0
DATA = 1


In [22]:
# IRT model selector:
#   0               -> 1PL-IRT
#   any other value -> 2PL-IRT

MODEL = 0
#MODEL = 1


In [23]:
# Compute device selector:
#   0 -> CPU
#   1 -> GPU
#   2 -> GPU, with the chains run through the 'vectorized' chain method

DEVICE = 0
#DEVICE = 1

###########################
#DEVICE = 2 # may not work
###########################

In [24]:
# Number of MCMC chains to run.
num_chains = 2

# DEVICE == 2 requests vectorized chains; every other setting uses parallel chains.
chain_method = 'vectorized' if DEVICE == 2 else 'parallel'

# Prepare

In [25]:
! cat /proc/cpuinfo

processor	: 0
vendor_id	: GenuineIntel
cpu family	: 6
model		: 79
model name	: Intel(R) Xeon(R) CPU @ 2.20GHz
stepping	: 0
microcode	: 0xffffffff
cpu MHz		: 2199.998
cache size	: 56320 KB
physical id	: 0
siblings	: 2
core id		: 0
cpu cores	: 1
apicid		: 0
initial apicid	: 0
fpu		: yes
fpu_exception	: yes
cpuid level	: 13
wp		: yes
flags		: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single ssbd ibrs ibpb stibp fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx smap xsaveopt arat md_clear arch_capabilities
bugs		: cpu_meltdown spectre_v1 spectre_v2 spec_store_bypass l1tf mds swapgs taa mmio_stale_data retbleed
bogomips	: 4399.99
clflush size	: 64
cache_alignment	: 64
addres

In [26]:
#! pip install -q "jax[cuda11_cudnn805]"==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
! pip install -q numpyro==0.10.1 arviz==0.12.1

In [27]:
import numpy as np
import pandas as pd
import datetime as dt
import time

import matplotlib.pyplot as plt
import seaborn as sns

In [28]:
import numpyro
import numpyro.distributions as dist

import jax
import arviz as az
import jax.numpy as jnp

In [29]:
if DEVICE == 0:
    numpyro.set_platform('cpu')
    numpyro.set_host_device_count(num_chains)
else:
    numpyro.set_platform('gpu')
    n = jax.device_count()
    print("number of GPU", n)
    if n < 1:
        raise Exception("no GPU")
    else:
        ! nvidia-smi

## Import Data

In [30]:
fff = "idata_ppc_for_data%s_model%s.nc" % (DATA, MODEL)

! wget https://filedn.com/lpAczQGgeBjkX6l7SpI5JJy/__ws/stan_irt_nrm_rad/{fff} -O idata_ppc.nc

--2023-03-15 22:19:21--  https://filedn.com/lpAczQGgeBjkX6l7SpI5JJy/__ws/stan_irt_nrm_rad/idata_ppc_for_data1_model0.nc
Resolving filedn.com (filedn.com)... 74.120.9.25
Connecting to filedn.com (filedn.com)|74.120.9.25|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 38241044 (36M) [application/x-netcdf]
Saving to: ‘idata_ppc.nc’


2023-03-15 22:19:36 (2.52 MB/s) - ‘idata_ppc.nc’ saved [38241044/38241044]



In [31]:
# Load the downloaded NetCDF file with ArviZ.
idata_ppc = az.from_netcdf('idata_ppc.nc')

In [32]:
# Extract the 'obs' variable of the posterior-predictive group as a NumPy array.
# NOTE(review): the leading axes are presumably (chain, draw) and the trailing
# two (case, doctor) — confirm against how the file was produced.
y_ppc = idata_ppc.posterior_predictive['obs'].to_numpy()
y_ppc.shape

(6, 3000, 42, 14)

# Run NumPyro

## Define Model

In [33]:
def model_1pl(y=None, num_cases=0, num_doctors=0):
    """1PL IRT model: Bernoulli responses with logits ``theta - beta``.

    Args:
        y: Observed responses passed to the 'obs' site, or None to leave
            that site unobserved.
        num_cases: Size of the 'case' plate.
        num_doctors: Size of the 'doctor' plate.
    """
    # One `theta` per doctor.
    with numpyro.plate('doctor', num_doctors):
        theta = numpyro.sample('theta', dist.Normal(0, 2))

    # One `beta` per case; the case plate is pinned to dim -2 so that the
    # nested doctor plate occupies the last axis.
    with numpyro.plate('case', num_cases, dim=-2):
        beta = numpyro.sample('beta', dist.Normal(0, 2))
        with numpyro.plate('doctor', num_doctors):
            logits = theta - beta
            numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)

In [34]:
def model_2pl(y=None, num_cases=0, num_doctors=0):
    """2PL IRT model: Bernoulli responses with logits ``exp(log_d) * (theta - beta)``.

    Args:
        y: Observed responses passed to the 'obs' site, or None to leave
            that site unobserved.
        num_cases: Size of the 'case' plate.
        num_doctors: Size of the 'doctor' plate.
    """
    # One `theta` per doctor.
    with numpyro.plate('doctor', num_doctors):
        theta = numpyro.sample('theta', dist.Normal(0, 2))

    # One `beta` and one `log_d` per case; the case plate is pinned to dim -2
    # so that the nested doctor plate occupies the last axis.
    with numpyro.plate('case', num_cases, dim=-2):
        beta = numpyro.sample('beta', dist.Normal(0, 2))
        log_d = numpyro.sample('log_d', dist.Normal(0.5, 1))
        with numpyro.plate('doctor', num_doctors):
            # exp() keeps the per-case multiplier positive.
            logits = jnp.exp(log_d)*(theta - beta)
            numpyro.sample('obs', dist.Bernoulli(logits=logits), obs=y)

In [35]:
# MODEL == 0 selects the 1PL model; any other value selects the 2PL model.
if MODEL == 0:
    model = model_1pl
else:
    model = model_2pl

In [36]:
# NUTS kernel for the selected model.
nuts = numpyro.infer.NUTS(model)

# 500 warm-up iterations followed by 3000 retained samples per chain.
mcmc = numpyro.infer.MCMC(
    nuts,
    num_warmup=500,
    num_samples=3000,
    num_chains=num_chains,
    chain_method=chain_method,
)


## Measure Inference Time

In [37]:
# Fixed PRNG key, reused for every sampling run.
key = jax.random.PRNGKey(0)

# Data-size multipliers to benchmark; the largest one (500) is only run
# for the 1PL model (MODEL == 0).
factors = [1, 1, 2, 5, 10, 20, 50, 100, 200]
if MODEL == 0:
    factors.append(500)

# The last axis of the simulated responses indexes doctors.
num_doctors = y_ppc.shape[-1]

In [38]:
# One CSV row (as text) per timed run.
lines = []

for factor in factors:
    # Stack the first `factor` slices of y_ppc[0] into a single
    # (factor * cases, num_doctors) response matrix.
    y_simulated = y_ppc[0, :factor, :, :].reshape(-1, num_doctors)
    print("simulated data shape:", y_simulated.shape)

    num_cases = y_simulated.shape[0]

    # Wall-clock timestamps are kept for the log only; the duration is taken
    # from the monotonic high-resolution timer, which is unaffected by
    # system clock adjustments.
    start_time = dt.datetime.now()
    timer_start = time.perf_counter()

    mcmc.run(key, y=y_simulated, num_cases=num_cases, num_doctors=num_doctors)
    # JAX dispatches work asynchronously: wait until the samples are actually
    # materialised so that the timer covers the whole computation.
    jax.tree_util.tree_map(lambda x: x.block_until_ready(), mcmc.get_samples())

    elapsed_time = time.perf_counter() - timer_start
    end_time = dt.datetime.now()
    lines.append(f'{num_cases},{start_time},{end_time},{elapsed_time}')

    time.sleep(1) # sleep 1 sec

simulated data shape: (42, 14)


  0%|          | 0/3500 [00:00<?, ?it/s]

  0%|          | 0/3500 [00:00<?, ?it/s]

simulated data shape: (42, 14)


  0%|          | 0/3500 [00:00<?, ?it/s]

  0%|          | 0/3500 [00:00<?, ?it/s]

simulated data shape: (84, 14)


  0%|          | 0/3500 [00:00<?, ?it/s]

  0%|          | 0/3500 [00:00<?, ?it/s]

simulated data shape: (210, 14)


  0%|          | 0/3500 [00:00<?, ?it/s]

  0%|          | 0/3500 [00:00<?, ?it/s]

simulated data shape: (420, 14)


  0%|          | 0/3500 [00:00<?, ?it/s]

  0%|          | 0/3500 [00:00<?, ?it/s]

simulated data shape: (840, 14)


  0%|          | 0/3500 [00:00<?, ?it/s]

  0%|          | 0/3500 [00:00<?, ?it/s]

simulated data shape: (2100, 14)


  0%|          | 0/3500 [00:00<?, ?it/s]

  0%|          | 0/3500 [00:00<?, ?it/s]

simulated data shape: (4200, 14)


  0%|          | 0/3500 [00:00<?, ?it/s]

  0%|          | 0/3500 [00:00<?, ?it/s]

simulated data shape: (8400, 14)


  0%|          | 0/3500 [00:00<?, ?it/s]

  0%|          | 0/3500 [00:00<?, ?it/s]

simulated data shape: (21000, 14)


  0%|          | 0/3500 [00:00<?, ?it/s]

  0%|          | 0/3500 [00:00<?, ?it/s]

## Export Data

In [39]:
# The output file name encodes the data / model / device switches.
path = "time_measured_numpyro_data%s_model%s_device%s.csv" % (DATA, MODEL, DEVICE)

header = ['num_cases,start_time,end_time,elapsed_time']

# Header row followed by one row per timed run, newline-separated
# (no trailing newline).
csv_text = '\n'.join(header + lines)

with open(path, mode='w') as f:
    f.write(csv_text)

In [40]:
# Colab-only: `google.colab` is imported here rather than with the other
# imports because it is only available inside Google Colab.
from google.colab import files

# Hand the file at `path` to the browser as a download.
files.download(path)

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>