# 5d. Parameter learning using Expectation Maximization (EM)
This notebook shows how parameter estimation is implemented in Thomas.

In [1]:
# Run the shared preamble notebook in this kernel. Later cells use `pd` and `np`
# without importing them, so those names are expected to come from this preamble.
%run '_preamble.ipynb'

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload

available imports:
  import os
  import logging
  import pandas as pd
  import numpy as np

connect to this kernel with:
  jupyter console --existing 481401e2-4b20-4d15-b767-66e4f1af2fd7

Logging to: "/Users/melle/software-development/thomas-master/logs/5d. Parameter learning using Expectation Maximization (EM).log"
Current date/time: 04-07-2020, 16:28
Current working directory: "/Users/melle/software-development/thomas-master/notebooks"


In [2]:
import functools

from IPython.display import display, HTML

import thomas.core
from thomas.core import examples
from thomas.core import BayesianNetwork, Factor, CPT, JPT
from thomas.core.bayesiannetwork import DiscreteNetworkNode
from thomas.core.reader import oobn
from thomas.jupyter import BayesianNetworkWidget

## Initialization

### Load the BN

In [None]:
# Load the JSON version of the lung cancer network. It is used only as the
# source of the nodes' positions, which later cells copy onto other networks.
lc = examples.get_lungcancer_network()
positions = dict((node.RV, node.position) for node in lc.nodes.values())
positions

In [19]:
# Read the untrained network from the OOBN file that ships with the package.
# `thomas.core` and `oobn` are imported in the import cell at the top.
filename = thomas.core.get_pkg_data('TNM_explains_death_before_training_with_TNM7_priors.oobn')
bn = oobn.read(filename)

# Copy the positions collected from the lung cancer example network onto the
# nodes of the freshly read network (`positions` maps RV name -> position).
for RV, position in positions.items():
    bn[RV].position = position

In [20]:
# Wrap `bn` in a BayesianNetworkWidget and keep a handle to it in `widget`;
# ending the cell with the bare name displays the widget.
widget = BayesianNetworkWidget(bn, height=600)
widget

BayesianNetworkWidget(height=600, marginals_and_evidence={'marginals': {'TNM': {'1A': 0.11290445070655371, '1B…

In [5]:
# Inspect the CPT of the 'edition' node of the network read from the OOBN file.
bn['edition'].cpt

edition,TNM 5,TNM 6,TNM 7
,1.0,1.0,1.0


### Load the data

In [6]:
# Ordered categorical dtypes, one per column of the training CSV.
# Each category list ends with the literal string 'NaN'.

# Categories for the 'cT' column.
T_dtype = pd.api.types.CategoricalDtype(
    [
        '0', 'IS',
        '1', '1MI', '1A', '1B', '1C',
        '2', '2A', '2B', '2C',
        '3', '3A', '3B', '3C',
        '4', '4A', '4B', '4C',
        'X', 'NaN',
    ],
    ordered=True,
)

# Categories for the 'cN' column.
N_dtype = pd.api.types.CategoricalDtype(
    [
        '0',
        '1', '1A', '1B', '1C', '1M',
        '2', '2A', '2B', '2C',
        '3', '3A', '3B', '3C',
        'X', 'NaN',
    ],
    ordered=True,
)

# Categories for the 'cM' column.
M_dtype = pd.api.types.CategoricalDtype(
    ['0', '1', '1A', '1B', '1C', 'X', '-', 'NaN'],
    ordered=True,
)

# Categories for the 'cTNM' column.
# NOTE(review): the cTNM CPT displayed later in this notebook has a state 'X',
# which is not listed here ('M' is). Values outside the categories are
# presumably read as missing by read_csv — confirm whether 'X' should be added.
TNM_dtype = pd.api.types.CategoricalDtype(
    [
        '1', '1A', '1A1', '1A2', '1A3', '1B', '1C',
        '2', '2A', '2B', '2C',
        '3', '3A', '3B', '3C',
        '4', '4A', '4B', '4C',
        'M', 'NaN',
    ],
    ordered=True,
)

# Categories for the 'edition' column.
edition_dtype = pd.api.types.CategoricalDtype(
    ['TNM 5', 'TNM 6', 'TNM 7', 'NaN'],
    ordered=True,
)

# Categories for the 'death' column, from shortest to longest interval.
death_dtype = pd.api.types.CategoricalDtype(
    [
        '0-30 days',
        '1-4 months',
        '4-6 months',
        '6-12 months',
        '1-2 years',
        '> 2 years',
        'NaN',
    ],
    ordered=True,
)


In [7]:
# Read the training set that ships with the package, typing each listed column
# with the ordered categorical dtype defined for it.
filename = thomas.core.get_pkg_data('data_training_subset_tnm567_resampled.csv')
df = pd.read_csv(
    filename,
    sep=',',
    dtype=dict(
        cT=T_dtype,
        cN=N_dtype,
        cM=M_dtype,
        cTNM=TNM_dtype,
        edition=edition_dtype,
        death=death_dtype,
    ),
)

# Report the frame's dimensions and the total number of missing cells.
print(f'df.shape: {df.shape[0]} rows x {df.shape[1]} cols')
print(f'This dataset has {df.isna().sum().sum()} NAs')

df.head()

df.shape: 135000 rows x 6 cols
This dataset has 18781 NAs


Unnamed: 0,cT,cN,cM,cTNM,edition,death
0,2,0,0,1B,TNM 5,> 2 years
1,1,0,0,1A,TNM 5,1-2 years
2,2,2,0,3A,TNM 5,4-6 months
3,4,2,1,4,TNM 5,4-6 months
4,4,2,1,4,TNM 5,> 2 years


In [8]:
# Number of rows per value of the 'edition' column.
df['edition'].value_counts()

TNM 7    45000
TNM 6    45000
TNM 5    45000
NaN          0
Name: edition, dtype: int64

## EM learning using the junction tree algorithm

In [9]:
# Load the line_profiler IPython extension. It is only needed for the
# commented-out %lprun profiling line in the EM_learning cell further down.
%load_ext line_profiler

In [None]:
# Learn on a copy so the network read from the OOBN file (`bn`) stays available
# for comparison afterwards. The name `copy` is kept because later cells use it.
copy = bn.copy()
copy.elimination_order = ["TNM", "cTNM", "cM", "cT", "cN", "edition", "death", "N", "T", "M"]

# Reuse `positions` (RV name -> position), built from the lung cancer example
# network in the "Load the BN" section, rather than reloading that network here.
for RV, position in positions.items():
    copy[RV].position = position

In [12]:
# Display the copy in a widget; EM_learning is run on it in the next cell.
BayesianNetworkWidget(copy, height=600)

BayesianNetworkWidget(height=600, marginals_and_evidence={'marginals': {'TNM': {'1A': 0.11290445070655371, '1B…

In [13]:
# Profiling variant of the call below, kept for reference (it needs the
# line_profiler extension to be loaded):
# %lprun -m thomas.core -s -u1 -T profile2.txt copy.EM_learning(df, max_iterations=1)
# Run EM parameter learning on the copy with the training data, max_iterations=7.
copy.EM_learning(df, max_iterations=7)

100%|██████████| 7/7 [01:36<00:00, 13.78s/it]


In [14]:
# Check the type of the first element stored in the cTNM CPT of `bn`.
type(bn['cTNM'].cpt.flat[0])

numpy.float64

In [15]:
# cTNM CPT of the original network `bn`, shown for comparison with the copy:
# EM_learning was called on `copy`, not on `bn`.
bn['cTNM'].cpt

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,cTNM,1A,1B,2A,2B,3A,3B,4,X
edition,cT,cN,cM,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
TNM 5,1,0,0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
TNM 5,1,0,1,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
TNM 5,1,0,1A,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
TNM 5,1,0,1B,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
TNM 5,1,1,0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
...,...,...,...,...,...,...,...,...,...,...,...
TNM 7,X,2,1B,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
TNM 7,X,3,0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
TNM 7,X,3,1,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0
TNM 7,X,3,1A,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


In [16]:
# cTNM CPT of `copy`, the network EM_learning was run on.
copy['cTNM'].cpt

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,cTNM,1A,1B,2A,2B,3A,3B,4,X
edition,cN,cT,cM,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1
TNM 5,0,1,0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
TNM 5,0,1,1,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
TNM 5,0,1,1A,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
TNM 5,0,1,1B,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
TNM 5,0,1A,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...
TNM 7,3,4,1B,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
TNM 7,3,X,0,0.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
TNM 7,3,X,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
TNM 7,3,X,1A,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0


In [17]:
# Query `copy` for the marginals of cTNM.
copy.get_marginals(['cTNM'])

{'cTNM': factor(cTNM)
 cTNM
 1A      0.101
 1B      0.108
 2A      0.020
 2B      0.044
 3A      0.130
 3B      0.149
 4       0.446
 X       0.003
 dtype: float64}

In [18]:
# Deliberate stop so that "Run All" halts here. Every cell below this point is
# unexecuted in the saved notebook and looks like exploratory scratch work.
raise Exception('Intentional stop: the cells below are exploratory and are not part of the EM walkthrough.')

Exception: nooo

In [None]:
# Show the network's variables; the next cell matches them against the dataframe columns.
bn.vars

In [None]:
# Columns of the dataframe that are also variables of the network.
# Iterating over df.columns (instead of converting a set to a list) keeps the
# result in dataframe column order, so it is the same in every kernel session.
network_vars = set(bn.vars)
overlapping_cols = [col for col in df.columns if col in network_vars]
overlapping_cols

In [None]:
# Number of dataframe rows for each combination of values in the overlapping columns.
# NOTE(review): the columns typed in the read_csv cell are categoricals. In pandas
# versions where groupby defaults to observed=False this also lists unobserved
# category combinations with size 0, which the next cell filters out with
# count > 0. Consider passing observed=True — confirm before changing.
sizes = df.groupby(overlapping_cols).size()
sizes

In [None]:
# Turn the group sizes into a flat table with one row per value combination
# that occurs at least once, and a 'count' column holding its frequency.
# Built as one chain so `sizes` from the previous cell is not modified and the
# cell can be re-run safely.
counts = (
    sizes
    .reset_index(name='count')               # index levels -> columns, sizes -> 'count'
    .loc[lambda frame: frame['count'] > 0]   # drop combinations that never occur
    .reset_index(drop=True)                  # renumber rows 0..n-1
)
counts

In [None]:
# Display `counts` with the literal string 'NaN' replaced by a real missing value.
# NOTE(review): the result is only displayed, not assigned, so `counts` itself is
# unchanged and the isna() check in the next cell runs on the un-replaced frame.
# Confirm whether `counts = counts.replace('NaN', np.nan)` was intended.
counts.replace('NaN', np.nan)

In [None]:
# Number of rows in `counts` that contain at least one missing value.
counts.isna().any(axis=1).sum()

In [None]:
# Take the first row of `counts` as an example record and display it.
row = counts.iloc[0]
row

In [None]:
# Remove the 'count' entry from `row` and display the removed value.
# NOTE(review): this mutates `row`, so the cell is not idempotent. A second run
# presumably fails because 'count' is already gone — verify.
row.pop('count')

In [None]:
# Show `row` again, now without the 'count' entry popped in the previous cell.
row

In [None]:
# Inspect the network's node object for 'cN'.
bn.nodes['cN']

In [None]:
# Count the NaN entries among the values of the cN node's CPT.
np.isnan(bn.nodes['cN'].cpt.values).sum()