<a href="https://colab.research.google.com/github/odddkidout/basic-term-GPU-model/blob/master/jaxtuary%20m1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Methodology for reserve calculations for a life insurance product**

Using JAX and pymort

In [7]:
!pip install pymort
from jax import numpy as jnp
from pymort import getIdGroup, MortXML



For this we use the `2017_CSO loaded preferred_structure gender_distinct ANB` study. We can get the [pymort](https://github.com/actuarialopensource/pymort) object that represents this collection by referencing any of the [table ids](https://mort.soa.org/) that belong to the collection.

In [8]:
print(getIdGroup(3299))

IdGroup(study='2017_CSO', grouping='loaded preferred_structure gender_distinct ANB', ids=(3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307, 3308), genders=('male', 'male', 'male', 'female', 'female', 'female', 'male', 'male', 'female', 'female'), risks=('nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual', 'nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual', 'smoker preferred', 'smoker residual', 'smoker preferred', 'smoker residual'))
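The `IdGroup` lists genders and risk classes in the same order as its `ids`, and the tables below are stacked in that order, so a small lookup makes the meaning of each table index explicit. This is a sketch: the tuples are copied by hand from the printed output rather than read from pymort, and `table_index` is a hypothetical helper, not part of pymort.

```python
# Tuples copied from the getIdGroup(3299) output above.
ids = (3299, 3300, 3301, 3302, 3303, 3304, 3305, 3306, 3307, 3308)
genders = ('male', 'male', 'male', 'female', 'female', 'female',
           'male', 'male', 'female', 'female')
risks = ('nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual',
         'nonsmoker super_preferred', 'nonsmoker preferred', 'nonsmoker residual',
         'smoker preferred', 'smoker residual', 'smoker preferred', 'smoker residual')

# Map (gender, risk) -> position along the first (table) axis of the stacked arrays.
lookup = {(g, r): i for i, (g, r) in enumerate(zip(genders, risks))}

def table_index(gender, risk):
    """Hypothetical helper: table-axis index for a gender/risk class."""
    return lookup[(gender, risk)]

i = table_index('male', 'nonsmoker preferred')
print(i, ids[i])
```

Under this ordering, the `mortality_table_index = [0, 1, 2]` used for the policyholders below selects the male nonsmoker super_preferred, preferred, and residual tables (ids 3299 to 3301).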


Load the mortality tables into tensor format

In [9]:
ids = getIdGroup(3299).ids
# Select rates: one (issue_age, duration) grid per table id, stacked in `ids` order
select = jnp.array([MortXML(table_id).Tables[0].Values.unstack().values for table_id in ids])
# Ultimate rates: one attained-age vector per table id, stacked in `ids` order
ultimate = jnp.array([MortXML(table_id).Tables[1].Values.unstack().values for table_id in ids])
print(f"select.shape: {select.shape}") # tableIds [3299, 3308], issue_ages [18, 95], durations [1, 25]
print(f"ultimate.shape: {ultimate.shape}") # tableIds [3299, 3308], attained_ages [18, 120]



select.shape: (10, 78, 25)
ultimate.shape: (10, 103)
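The shapes follow from the age ranges in the comments: issue ages 18 to 95 give 95 - 18 + 1 = 78 rows, and attained ages 18 to 120 give 120 - 18 + 1 = 103 entries. Because `duration` below counts completed policy years starting at 0, it can be used directly as the select-table column (column 0 is select duration 1). The helper names here are hypothetical, written only to spell out the index arithmetic used when looking up `q`.

```python
MIN_AGE = 18            # youngest age in both tables
MAX_ISSUE_AGE = 95      # last issue age in the select table
MAX_ATTAINED_AGE = 120  # last attained age in the ultimate table
SELECT_YEARS = 25       # select durations 1..25

def select_position(issue_age, duration):
    """(row, column) in a select table; duration 0 is the first policy year (select duration 1)."""
    return issue_age - MIN_AGE, duration

def ultimate_position(issue_age, duration):
    """Position in an ultimate table, which is indexed by attained age."""
    return issue_age + duration - MIN_AGE

n_issue_ages = MAX_ISSUE_AGE - MIN_AGE + 1
n_attained_ages = MAX_ATTAINED_AGE - MIN_AGE + 1
print(n_issue_ages, n_attained_ages, SELECT_YEARS)  # 78 103 25
```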


Policyholder attributes

In [11]:
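mortality_table_index = jnp.array([0,1,2])  # position along the table axis of `select` / `ultimate`
issue_age = jnp.array([30, 40, 50])
duration = jnp.array([0, 0, 0]) # completed policy years; 0 = new business
face = jnp.array([1000*x for x in [100, 500, 250]])  # face amounts: 100,000, 500,000, 250,000
ann_prem = jnp.array([20.070742, 224.05084 , 322.29498])  # annual premium per policy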

Traditional actuarial models often do these calculations recursively, one time step at a time. In contrast, we compute the cashflows for all points in time simultaneously as array operations, which allows parallelization over the time dimension on the GPU.
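To make the contrast concrete, the sketch below computes the probability of a policy still being in force at the start of each year in both styles, using made-up mortality rates (not values from the 2017 CSO tables). The recursive version loops over years in Python; the vectorized version is the same cumulative product this notebook uses for `npx`, so JAX sees the whole projection as one array computation.

```python
import jax.numpy as jnp

# Made-up one-year death probabilities: 5 policy years (rows) x 3 model points (columns).
toy_q = jnp.array([
    [0.0010, 0.0020, 0.0040],
    [0.0012, 0.0025, 0.0050],
    [0.0015, 0.0030, 0.0060],
    [0.0018, 0.0036, 0.0072],
    [0.0020, 0.0040, 0.0080],
])

def survival_recursive(q):
    """Traditional style: build the in-force probability one year at a time."""
    rows = [jnp.ones(q.shape[1])]
    for t in range(1, q.shape[0]):
        rows.append(rows[-1] * (1 - q[t - 1]))
    return jnp.stack(rows)

def survival_vectorized(q):
    """Array style: every year at once via a cumulative product along the time axis."""
    return jnp.concatenate([jnp.ones((1, q.shape[1])), jnp.cumprod(1 - q, axis=0)[:-1]])

print(jnp.allclose(survival_recursive(toy_q), survival_vectorized(toy_q)))  # True
```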

Take the initial `duration` vector of shape `(modelpoints, )` and turn it into a `duration_projected` matrix of shape `(timesteps, modelpoints)`, where each row represents a different timestep.

We use broadcasting to do this. Broadcasting is explained in detail [here](https://numpy.org/doc/stable/user/basics.broadcasting.html#general-broadcasting-rules); that page discusses this exact problem.

In [12]:
timesteps = 5 # The policy is a 5-year policy
print(f"duration: \n {duration}")
time_axis = jnp.arange(timesteps)[:, jnp.newaxis]
print(f"time_axis: \n {time_axis}")
duration_projected = time_axis + duration
print(f"duration_projected: \n {duration_projected}")

duration: 
 [0 0 0]
time_axis: 
 [[0]
 [1]
 [2]
 [3]
 [4]]
duration_projected: 
 [[0 0 0]
 [1 1 1]
 [2 2 2]
 [3 3 3]
 [4 4 4]]


In [17]:
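# Once the duration passes the end of the select period, use the ultimate table.
# The (3,) index vectors broadcast against the (5, 3) duration matrix, so q has shape (timesteps, modelpoints).
q = jnp.where(
    duration_projected < select.shape[-1],
    select[mortality_table_index, issue_age - 18, duration_projected],
    ultimate[mortality_table_index, (issue_age - 18) + duration_projected],
)

# npx[t]: probability of being in force at the start of year t (npx[0] = 1)
npx = jnp.concatenate([jnp.ones((1, q.shape[1])), jnp.cumprod(1-q, axis=0)[:-1]])

claims = face * npx * q        # expected death benefits in year t
premiums = ann_prem * npx      # expected premiums received at the start of year t
net_cashflow = premiums - claims
print("premiums: \n", premiums)
print("claims: \n", claims)
print("net_cashflow: \n", net_cashflow)

premiums: 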
 [[ 20.070742 224.05084  322.29498 ]
 [ 20.06773  224.00827  322.0565  ]
 [ 20.06452  223.92987  321.7473  ]
 [ 20.060307 223.8179   321.31296 ]
 [ 20.05549  223.68584  320.7892  ]]
claims: 
 [[ 15.000001  95.       185.      ]
 [ 15.9976   174.96675  239.82239 ]
 [ 20.99349  249.86502  336.9265  ]
 [ 23.987522 294.6933   406.25833 ]
 [ 26.979483 339.4461   487.71072 ]]
net_cashflow: 
 [[   5.0707407   129.05084     137.29498   ]
 [   4.0701303    49.04152      82.2341    ]
 [  -0.92897034  -25.93515     -15.179199  ]
 [  -3.9272156   -70.8754      -84.94537   ]
 [  -6.923992   -115.76027    -166.92151   ]]
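`jnp.where` evaluates both of its branches for every element, so the select-table lookup is also computed for durations past the 25-year select period. With the 5-year horizon here every index is in range, but with a longer horizon the select branch would be indexed past its last column. JAX does not raise an error for out-of-bounds array indices the way NumPy does, so rather than relying on its default out-of-bounds behavior, one option is to clip the indices explicitly. Below is a sketch of that idea on tiny made-up tables (one table, issue ages 18 to 20, a 2-year select period); `lookup_q` is a hypothetical helper, not part of pymort or JAX.

```python
import jax.numpy as jnp

def lookup_q(select, ultimate, table_idx, issue_age, duration_projected, min_age=18):
    """Hypothetical helper: select rate inside the select period, ultimate rate after it.

    select:   (tables, issue_ages, select_years)
    ultimate: (tables, attained_ages)
    """
    n_select = select.shape[-1]
    age_idx = issue_age - min_age
    # Clip so neither branch indexes out of bounds; jnp.where discards the unused values.
    sel_col = jnp.clip(duration_projected, 0, n_select - 1)
    att_idx = jnp.clip(age_idx + duration_projected, 0, ultimate.shape[-1] - 1)
    q_select = select[table_idx, age_idx, sel_col]
    q_ultimate = ultimate[table_idx, att_idx]
    return jnp.where(duration_projected < n_select, q_select, q_ultimate)

# Tiny made-up tables: 1 table, issue ages 18-20, 2 select years, attained ages 18-24.
toy_select = jnp.array([[[0.001, 0.002], [0.003, 0.004], [0.005, 0.006]]])
toy_ultimate = jnp.array([[0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.07]])

table_idx = jnp.array([0, 0])
toy_issue_age = jnp.array([18, 20])
toy_duration = jnp.arange(4)[:, jnp.newaxis] + jnp.array([0, 0])  # 4 years x 2 policies

print(lookup_q(toy_select, toy_ultimate, table_idx, toy_issue_age, toy_duration))
```

The first policy (issue age 18) gets 0.001 and 0.002 from the select table, then the ultimate rates at attained ages 20 and 21 (0.03 and 0.04).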


In [19]:
discount_factor = 1/(1.02)  # v = 1/(1 + i) with i = 2% annual interest
# discount factors for payments by the policyholder, starting at t = 0 (premiums are paid at the start of each year)
discounts = discount_factor ** jnp.arange(timesteps)[:, jnp.newaxis]
# discount factors for payments from the insurer, starting at t = 1 (claims are paid at the end of the year of death)
discounts_lagged = discounts * discount_factor
print(discounts)
print(discounts_lagged)
discounted_expected_claims = face * npx * q * discounts_lagged
print("#### INPUT ####")
print("face shape: ", face.shape)
print("npx shape: ", npx.shape)
print("q shape: ", q.shape)
print("discounts_lagged shape: ", discounts_lagged.shape)
print("#### OUTPUT ####")
print("discounted_expected_claims = face * npx * q * discounts_lagged")
print("discounted_expected_claims shape: ", discounted_expected_claims.shape)

[[1.        ]
 [0.98039216]
 [0.96116877]
 [0.9423223 ]
 [0.9238454 ]]
[[0.98039216]
 [0.96116877]
 [0.9423223 ]
 [0.9238454 ]
 [0.9057308 ]]
#### INPUT ####
face shape:  (3,)
npx shape:  (5, 3)
q shape:  (5, 3)
discounts_lagged shape:  (5, 1)
#### OUTPUT ####
discounted_expected_claims = face * npx * q * discounts_lagged
discounted_expected_claims shape:  (5, 3)
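In formula form, for one model point with face amount $F$ (`face`), annual premium $P$ (`ann_prem`), discount factor $v$ (`discount_factor`), issue age $x$, and $q_t$ the mortality rate in row $t$ of `q` (policy year $t+1$), the arrays correspond to the following. The notation is ours, chosen to mirror the array names, with $T$ = `timesteps` = 5:

$$
\begin{aligned}
v &= \frac{1}{1.02}, \\
{}_t p_x &= \prod_{k=0}^{t-1} \left(1 - q_k\right), \qquad {}_0 p_x = 1, \\
\text{PV(premiums)} &= \sum_{t=0}^{T-1} P \, {}_t p_x \, v^{t}, \\
\text{PV(claims)} &= \sum_{t=0}^{T-1} F \, {}_t p_x \, q_t \, v^{t+1}.
\end{aligned}
$$

Here ${}_t p_x$ is `npx[t]`, $v^t$ is `discounts[t]`, and $v^{t+1}$ is `discounts_lagged[t]`. Choosing $P$ so that the two present values are equal is the equivalence principle.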


In [25]:
discounted_expected_premiums = premiums * discounts
discounted_expected_claims = face * npx * q * discounts_lagged
print(discounted_expected_premiums)
print(discounted_expected_claims)
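# Summing over axis=1 (model points) gives the portfolio's discounted net cashflow in each year.
# These are cashflows, not reserves; over all years they sum to ~0, consistent with ann_prem being a net level premium.
print("Discounted net cashflow (premiums - claims) at each timestep \n", jnp.sum(discounted_expected_premiums-discounted_expected_claims,axis=1))

[[ 20.070742 224.05084  322.29498 ]
 [ 19.674246 219.61595  315.74167 ]
 [ 19.285389 215.23439  309.25348 ]
 [ 18.903275 210.9086   302.78036 ]
 [ 18.528173 206.65114  296.35965 ]]
[[ 14.705883  93.13725  181.37254 ]
 [ 15.376393 168.17258  230.50978 ]
 [ 19.782635 235.45338  317.49338 ]
 [ 22.160763 272.25104  375.3199  ]
 [ 24.436148 307.44678  441.73462 ]]
Discounted net cashflow (premiums - claims) at each timestep 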
 [ 277.20087   140.97311   -28.956139 -137.13945  -252.07858 ]
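The printed values are discounted cashflows summed across policies, not reserves. A prospective reserve for a policy still in force at the start of year $t$ is the value at time $t$ of its future claims minus its future premiums. Summing the discounted premiums and the discounted claims above over all five years gives nearly equal totals for each policy (about 96.46 of each for the first one), which suggests `ann_prem` is a net level premium set by the equivalence principle. The sketch below computes both the premium and the reserves with a reverse cumulative sum over time; it uses made-up mortality rates, and `net_premium_and_reserves` is a hypothetical helper, not part of the original notebook.

```python
import jax.numpy as jnp

def net_premium_and_reserves(q, face, interest=0.02):
    """Hypothetical helper.

    q:    (timesteps, modelpoints) death probability in each policy year
    face: (modelpoints,) death benefit paid at the end of the year of death
    Returns the net level annual premium (modelpoints,) and the prospective reserve
    at the start of each year, before that year's premium (timesteps, modelpoints).
    """
    v = 1.0 / (1.0 + interest)
    t = jnp.arange(q.shape[0])[:, jnp.newaxis]
    # probability of being in force at the start of each policy year
    npx = jnp.concatenate([jnp.ones((1, q.shape[1])), jnp.cumprod(1 - q, axis=0)[:-1]])
    disc_prem = v ** t         # premiums at the start of the year
    disc_claim = v ** (t + 1)  # claims at the end of the year
    pv_claims = jnp.sum(face * npx * q * disc_claim, axis=0)
    pv_annuity = jnp.sum(npx * disc_prem, axis=0)  # PV of 1 per year while in force
    premium = pv_claims / pv_annuity               # equivalence principle
    # expected net outgo (claims minus premiums) in each year, discounted to time 0
    net_outgo = face * npx * q * disc_claim - premium * npx * disc_prem
    # reverse cumulative sum: total discounted net outgo from year t to the end
    future = jnp.cumsum(net_outgo[::-1], axis=0)[::-1]
    # revalue at time t, conditional on the policy being in force then
    reserves = future / (npx * disc_prem)
    return premium, reserves

# Made-up rates for a 3-year term: rows are policy years, columns are model points.
toy_q = jnp.array([[0.0010, 0.0020],
                   [0.0012, 0.0025],
                   [0.0015, 0.0030]])
toy_face = jnp.array([100_000.0, 500_000.0])

premium, reserves = net_premium_and_reserves(toy_q, toy_face)
print("net premium:", premium)
print("reserves:\n", reserves)
```

With the equivalence premium, the reserve at issue (`reserves[0]`) is zero up to rounding, and the reserve one year before expiry reduces to $F q_{T-1} v - P$. Applied to this notebook's arrays, the same reverse cumulative sum over `discounted_expected_claims - discounted_expected_premiums`, divided by `npx * discounts`, would give per-policy reserves.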
