# Get started

### Imports

In [None]:
# Optional one-time environment setup. Every install command below is left
# commented out; uncomment the ones you need to install the packages this
# notebook imports.
#! pip install -U ipywidgets
#! pip install matplotlib 
#! pip install scikit-learn 
#! pip install ray 
#! pip install fsspec 
#! pip install pyarrow 
#! pip install sqlalchemy
#! pip install torchinfo

In [20]:
import sys

# Put the project checkout at the end of the module search path so that the
# `inverse_modelling_tfo` package imported below can be resolved.
# NOTE(review): absolute, user-specific path; adjust for your own machine.
sys.path.extend([r'/home/rlfowler/Documents/research/tfo_inverse_modelling'])

from pathlib import Path
from torch.optim import Adam, SGD
import torch.nn as nn
import torch
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
from sklearn import preprocessing
from inverse_modelling_tfo.data import (
    generate_data_loaders,
    config_based_normalization,
)
from inverse_modelling_tfo.data.intensity_interpolation import (
    interpolate_exp,
    get_interpolate_fit_params,
    exp_piecewise_affine,
)
from inverse_modelling_tfo.data.interpolation_function_zoo import *
from inverse_modelling_tfo.models import RandomSplit, ValidationMethod, HoldOneOut, CVSplit, CombineMethods
from inverse_modelling_tfo.models.custom_models import (
    SplitChannelCNN,
    PerceptronReLU,
    PerceptronBN,
    PerceptronDO,
    PerceptronBD,
)
from inverse_modelling_tfo.features.build_features import (
    FetalACFeatureBuilder,
    RowCombinationFeatureBuilder,
    TwoColumnOperationFeatureBuilder,
    FetalACbyDCFeatureBuilder,
    LogTransformFeatureBuilder,
    ConcatenateFeatureBuilder,
)
from inverse_modelling_tfo.features.data_transformations import (
    LongToWideIntensityTransformation,
    ToFittingParameterTransformation,
)
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler
import torchinfo
from inverse_modelling_tfo.misc.misc_training import set_seed

# Restrict CUDA to the device with index 1 for this process.
# NOTE(review): the index is machine-specific; presumably this must run
# before any CUDA context is created for it to take effect (confirm).
os.environ.update({"CUDA_VISIBLE_DEVICES": "1"})

In [21]:
# Locations of the JSON config and of the pickled intensity table it belongs to.
# NOTE(review): both are absolute, user-specific paths; adjust for your machine.
CONFIG_PATH = Path(r'/home/rraiyan/simulations/tfo_sim/data/compiled_intensity/dan_iccps_pencil.json')
DATA_PATH = r'/home/rraiyan/simulations/tfo_sim/data/compiled_intensity/dan_iccps_pencil.pkl'

# Read the pickled DataFrame from disk (only unpickle files you trust).
data = pd.read_pickle(DATA_PATH)

# Report the size of the table and preview its first rows.
print(data.shape)
data.head()

(19906560, 9)


Unnamed: 0,Wave Int,SDD,Uterus Thickness,Maternal Wall Thickness,Maternal Hb Concentration,Maternal Saturation,Fetal Hb Concentration,Fetal Saturation,Intensity
0,2.0,10,5.0,14.0,11.0,0.9,11.0,0.1,1910493.0
1,2.0,15,5.0,14.0,11.0,0.9,11.0,0.1,207826.6
2,2.0,19,5.0,14.0,11.0,0.9,11.0,0.1,37728.55
3,2.0,24,5.0,14.0,11.0,0.9,11.0,0.1,5132.866
4,2.0,28,5.0,14.0,11.0,0.9,11.0,0.1,881.2075


In [22]:
# Normalize `data` using the settings stored in the JSON file at CONFIG_PATH.
# The return value is discarded and, in the recorded run, the Intensity values
# differ after this call, so the helper evidently modifies `data` in place.
config_based_normalization(data, CONFIG_PATH) # TODO: may need to change this for my own code

# Drop the Uterus Thickness column for now.
# NOTE: this cell is not idempotent. Running it a second time applies the
# normalization again and then raises KeyError, because the column is already gone.
data = data.drop(columns="Uterus Thickness")

# Report the size of the table and preview its first rows.
print(data.shape)
data.head()

(19906560, 8)


Unnamed: 0,Wave Int,SDD,Maternal Wall Thickness,Maternal Hb Concentration,Maternal Saturation,Fetal Hb Concentration,Fetal Saturation,Intensity
0,2.0,10,14.0,11.0,0.9,11.0,0.1,1.382111e-05
1,2.0,15,14.0,11.0,0.9,11.0,0.1,1.033645e-06
2,2.0,19,14.0,11.0,0.9,11.0,0.1,1.429687e-07
3,2.0,24,14.0,11.0,0.9,11.0,0.1,1.571002e-08
4,2.0,28,14.0,11.0,0.9,11.0,0.1,2.262073e-09


This step typically takes 6 to 7 minutes to finish. It reshapes the data into a wide format of shape (number of samples, number of detectors × number of wavelengths + number of basic-info columns).

Typically, the basic info consists of five columns: maternal wall thickness, maternal Hb concentration, maternal saturation, fetal Hb concentration, and fetal saturation. In other words: fetal depth, the fetal and maternal concentrations, and the fetal and maternal saturations.

For example, this run produces a table of shape (497664, 45).

In [23]:
# Two transformations that are both applied to the long-format intensity table.
data_transformer = LongToWideIntensityTransformation()
fitting_param_transformer = ToFittingParameterTransformation()


# Order matters: fitting_params is computed from the long-format `data`
# because `data` is rebound to the transformed (wide) frame on the next line.
# NOTE(review): fitting_params is not used in the cells visible here.
fitting_params = fitting_param_transformer.transform(data)
data = data_transformer.transform(data)
# Column-name lists reported by the long-to-wide transformer; they are passed
# to the feature builder further down.
labels = data_transformer.get_label_names()
intensity_columns = data_transformer.get_feature_names()

# Report the size of the table and preview its first rows.
print(data.shape)
data.head()

(497664, 45)


Unnamed: 0,Maternal Wall Thickness,Maternal Hb Concentration,Maternal Saturation,Fetal Hb Concentration,Fetal Saturation,10_2.0,15_2.0,19_2.0,24_2.0,28_2.0,...,55_1.0,59_1.0,64_1.0,68_1.0,72_1.0,77_1.0,81_1.0,86_1.0,90_1.0,94_1.0
0,4.0,11.0,0.9,10.45,0.1,1e-05,5.546016e-07,5.976895e-08,5.521151e-09,8.930319e-10,...,3.105321e-10,1.751115e-10,1.024528e-10,5.965252e-11,3.538759e-11,2.100689e-11,1.131124e-11,7.575306e-12,4.827689e-12,2.934189e-12
1,4.0,11.0,0.9,10.45,0.145455,1e-05,5.545093e-07,5.973192e-08,5.509972e-09,8.893708e-10,...,3.171048e-10,1.787795e-10,1.045478e-10,6.089661e-11,3.612448e-11,2.142066e-11,1.155682e-11,7.74093e-12,4.928378e-12,2.987896e-12
2,4.0,11.0,0.9,10.45,0.190909,1e-05,5.544174e-07,5.969527e-08,5.498937e-09,8.857842e-10,...,3.240113e-10,1.826264e-10,1.067464e-10,6.22013e-11,3.689672e-11,2.185401e-11,1.181459e-11,7.914433e-12,5.033968e-12,3.044133e-12
3,4.0,11.0,0.9,10.45,0.236364,1e-05,5.543305e-07,5.965898e-08,5.488053e-09,8.822532e-10,...,3.312745e-10,1.866678e-10,1.090565e-10,6.357169e-11,3.770683e-11,2.230864e-11,1.208545e-11,8.096426e-12,5.144832e-12,3.103095e-12
4,4.0,11.0,0.9,10.45,0.281818,1e-05,5.542409e-07,5.962281e-08,5.477293e-09,8.78797e-10,...,3.389246e-10,1.909189e-10,1.11489e-10,6.50129e-11,3.855791e-11,2.278616e-11,1.237057e-11,8.287542e-12,5.261411e-12,3.165011e-12


In [24]:
# Discard every row in which at least one column is missing. The defaults are
# spelled out (axis=0 selects rows, how='any' means a single NaN is enough),
# and the frame is modified in place, so `data` keeps referring to the same object.
data.dropna(axis=0, how='any', inplace=True)

# Report the size of the table and preview its first rows.
print(data.shape)
data.head()

(497664, 45)


Unnamed: 0,Maternal Wall Thickness,Maternal Hb Concentration,Maternal Saturation,Fetal Hb Concentration,Fetal Saturation,10_2.0,15_2.0,19_2.0,24_2.0,28_2.0,...,55_1.0,59_1.0,64_1.0,68_1.0,72_1.0,77_1.0,81_1.0,86_1.0,90_1.0,94_1.0
0,4.0,11.0,0.9,10.45,0.1,1e-05,5.546016e-07,5.976895e-08,5.521151e-09,8.930319e-10,...,3.105321e-10,1.751115e-10,1.024528e-10,5.965252e-11,3.538759e-11,2.100689e-11,1.131124e-11,7.575306e-12,4.827689e-12,2.934189e-12
1,4.0,11.0,0.9,10.45,0.145455,1e-05,5.545093e-07,5.973192e-08,5.509972e-09,8.893708e-10,...,3.171048e-10,1.787795e-10,1.045478e-10,6.089661e-11,3.612448e-11,2.142066e-11,1.155682e-11,7.74093e-12,4.928378e-12,2.987896e-12
2,4.0,11.0,0.9,10.45,0.190909,1e-05,5.544174e-07,5.969527e-08,5.498937e-09,8.857842e-10,...,3.240113e-10,1.826264e-10,1.067464e-10,6.22013e-11,3.689672e-11,2.185401e-11,1.181459e-11,7.914433e-12,5.033968e-12,3.044133e-12
3,4.0,11.0,0.9,10.45,0.236364,1e-05,5.543305e-07,5.965898e-08,5.488053e-09,8.822532e-10,...,3.312745e-10,1.866678e-10,1.090565e-10,6.357169e-11,3.770683e-11,2.230864e-11,1.208545e-11,8.096426e-12,5.144832e-12,3.103095e-12
4,4.0,11.0,0.9,10.45,0.281818,1e-05,5.542409e-07,5.962281e-08,5.477293e-09,8.78797e-10,...,3.389246e-10,1.909189e-10,1.11489e-10,6.50129e-11,3.855791e-11,2.278616e-11,1.237057e-11,8.287542e-12,5.261411e-12,3.165011e-12


## Build Features

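What is the fetal concentration group? In the cell below, each of the 36 fetal Hb concentration values is assigned to one of 12 groups (0–11), three adjacent values per group, and the group index is stored in the `FconcCenters` column.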

In [25]:
# Ordinal grouping of the fetal Hb concentration values: the 36 distinct
# concentrations (rounded to 2 decimals) are binned three at a time, in
# ascending order, into the groups 0-11.
fetal_conc_group_mapping = {
    10.45: 0,
    10.88: 0,
    11.0: 0,
    11.31: 1,
    11.45: 1,
    11.55: 1,
    11.75: 2,
    11.91: 2,
    12.03: 2,
    12.18: 3,
    12.36: 3,
    12.5: 3,
    12.61: 4,
    12.82: 4,
    12.98: 4,
    13.04: 5,
    13.27: 5,
    13.46: 5,
    13.47: 6,
    13.73: 6,
    13.9: 6,
    13.94: 7,
    14.18: 7,
    14.34: 7,
    14.41: 8,
    14.64: 8,
    14.77: 8,
    14.89: 9,
    15.09: 9,
    15.2: 9,
    15.37: 10,
    15.55: 10,
    15.85: 10,
    16.0: 11,
    16.32: 11,
    16.8: 11,
}

# Round first so that stored values such as 10.881818 match the 2-decimal keys.
fetal_conc_rounded = data['Fetal Hb Concentration'].round(2)
fetal_conc_group = fetal_conc_rounded.map(fetal_conc_group_mapping)

# Series.map leaves NaN for every value that has no key in the mapping, which
# would silently produce a float column with missing group labels. Fail loudly
# instead, listing the offending concentrations.
unmapped_mask = fetal_conc_group.isna()
if unmapped_mask.any():
    unmapped_values = fetal_conc_rounded[unmapped_mask].unique().tolist()
    raise ValueError(
        f"Fetal Hb Concentration values without a group in fetal_conc_group_mapping: {unmapped_values}"
    )

# Every row has a group at this point, so the column can be stored as integers.
data['FconcCenters'] = fetal_conc_group.astype('int64')
print(data.shape)
data.head()

(497664, 46)


Unnamed: 0,Maternal Wall Thickness,Maternal Hb Concentration,Maternal Saturation,Fetal Hb Concentration,Fetal Saturation,10_2.0,15_2.0,19_2.0,24_2.0,28_2.0,...,59_1.0,64_1.0,68_1.0,72_1.0,77_1.0,81_1.0,86_1.0,90_1.0,94_1.0,FconcCenters
0,4.0,11.0,0.9,10.45,0.1,1e-05,5.546016e-07,5.976895e-08,5.521151e-09,8.930319e-10,...,1.751115e-10,1.024528e-10,5.965252e-11,3.538759e-11,2.100689e-11,1.131124e-11,7.575306e-12,4.827689e-12,2.934189e-12,0
1,4.0,11.0,0.9,10.45,0.145455,1e-05,5.545093e-07,5.973192e-08,5.509972e-09,8.893708e-10,...,1.787795e-10,1.045478e-10,6.089661e-11,3.612448e-11,2.142066e-11,1.155682e-11,7.74093e-12,4.928378e-12,2.987896e-12,0
2,4.0,11.0,0.9,10.45,0.190909,1e-05,5.544174e-07,5.969527e-08,5.498937e-09,8.857842e-10,...,1.826264e-10,1.067464e-10,6.22013e-11,3.689672e-11,2.185401e-11,1.181459e-11,7.914433e-12,5.033968e-12,3.044133e-12,0
3,4.0,11.0,0.9,10.45,0.236364,1e-05,5.543305e-07,5.965898e-08,5.488053e-09,8.822532e-10,...,1.866678e-10,1.090565e-10,6.357169e-11,3.770683e-11,2.230864e-11,1.208545e-11,8.096426e-12,5.144832e-12,3.103095e-12,0
4,4.0,11.0,0.9,10.45,0.281818,1e-05,5.542409e-07,5.962281e-08,5.477293e-09,8.78797e-10,...,1.909189e-10,1.11489e-10,6.50129e-11,3.855791e-11,2.278616e-11,1.237057e-11,8.287542e-12,5.261411e-12,3.165011e-12,0


In [29]:
# Text dump of the group-index column (all rows, single column selected by label).
print(data.loc[:, 'FconcCenters'])

0          0
1          0
2          0
3          0
4          0
          ..
497659    11
497660    11
497661    11
497662    11
497663    11
Name: FconcCenters, Length: 497664, dtype: int64


Typically takes 10 or 11 minutes to run.

In [31]:
# Configure the AC-by-DC feature builder with the 'FconcCenters' column name,
# the mode strings 'perm' and "max", and the current intensity/label column lists.
# NOTE(review): the meaning of each positional argument is defined in
# FetalACbyDCFeatureBuilder and is not visible here. Presumably 'FconcCenters'
# is the grouping column and 'perm' pairs rows within a group; confirm against
# the builder's documentation.
fb1 = FetalACbyDCFeatureBuilder('FconcCenters', 'perm', intensity_columns, labels, "max")
# Calling the builder yields the feature table that replaces `data`; in the
# recorded run the shape changes from (497664, 46) to (995328, 47).
data = fb1(data)

# Report the size of the table and preview its first rows.
print(data.shape)
data.head()

(995328, 47)


Unnamed: 0,Maternal Wall Thickness,Maternal Hb Concentration,Maternal Saturation,Fetal Saturation,FconcCenters,Fetal Hb Concentration 0,Fetal Hb Concentration 1,MAX_ACbyDC_WV1_0,MAX_ACbyDC_WV2_0,MAX_ACbyDC_WV1_1,...,MAX_ACbyDC_WV1_15,MAX_ACbyDC_WV2_15,MAX_ACbyDC_WV1_16,MAX_ACbyDC_WV2_16,MAX_ACbyDC_WV1_17,MAX_ACbyDC_WV2_17,MAX_ACbyDC_WV1_18,MAX_ACbyDC_WV2_18,MAX_ACbyDC_WV1_19,MAX_ACbyDC_WV2_19
0,4.0,11.0,0.9,0.1,0.0,10.45,10.881818,4.3e-05,0.000895,0.000597,...,0.015192,0.025316,0.014195,0.027839,0.017197,0.028094,0.017509,0.026767,0.022373,0.02353
1,4.0,11.0,0.9,0.1,0.0,10.45,11.0,5.4e-05,0.001128,0.000757,...,0.019226,0.031988,0.01796,0.035162,0.021767,0.035498,0.022154,0.03381,0.028258,0.029731
2,4.0,11.0,0.9,0.1,0.0,10.881818,10.45,-4.3e-05,-0.000895,-0.000597,...,-0.015192,-0.025316,-0.014195,-0.027839,-0.017197,-0.028094,-0.017509,-0.026767,-0.022373,-0.02353
3,4.0,11.0,0.9,0.1,0.0,10.881818,11.0,1.1e-05,0.000233,0.00016,...,0.004096,0.006846,0.003819,0.007533,0.00465,0.007618,0.004728,0.007237,0.00602,0.006351
4,4.0,11.0,0.9,0.1,0.0,11.0,10.45,-5.4e-05,-0.001128,-0.000757,...,-0.019226,-0.031988,-0.01796,-0.035162,-0.021767,-0.035498,-0.022154,-0.03381,-0.028258,-0.029731


In [32]:
# Rebind `labels` to the label column names reported by the feature builder,
# so the list matches the columns of the frame the builder produced.
labels = fb1.get_label_names()
print(labels)

['Maternal Wall Thickness', 'Maternal Hb Concentration', 'Maternal Saturation', 'Fetal Saturation', 'FconcCenters', 'Fetal Hb Concentration 0', 'Fetal Hb Concentration 1']


In [41]:
# Regression target first, then the model inputs as reported by the feature builder.
y_columns = ["Fetal Saturation"]
x_columns = fb1.get_feature_names()
# A single print with a newline separator emits the same two lines as before:
# the list of feature names, then how many there are.
print(x_columns, len(x_columns), sep="\n")

['MAX_ACbyDC_WV1_0', 'MAX_ACbyDC_WV1_1', 'MAX_ACbyDC_WV1_2', 'MAX_ACbyDC_WV1_3', 'MAX_ACbyDC_WV1_4', 'MAX_ACbyDC_WV1_5', 'MAX_ACbyDC_WV1_6', 'MAX_ACbyDC_WV1_7', 'MAX_ACbyDC_WV1_8', 'MAX_ACbyDC_WV1_9', 'MAX_ACbyDC_WV1_10', 'MAX_ACbyDC_WV1_11', 'MAX_ACbyDC_WV1_12', 'MAX_ACbyDC_WV1_13', 'MAX_ACbyDC_WV1_14', 'MAX_ACbyDC_WV1_15', 'MAX_ACbyDC_WV1_16', 'MAX_ACbyDC_WV1_17', 'MAX_ACbyDC_WV1_18', 'MAX_ACbyDC_WV1_19', 'MAX_ACbyDC_WV2_0', 'MAX_ACbyDC_WV2_1', 'MAX_ACbyDC_WV2_2', 'MAX_ACbyDC_WV2_3', 'MAX_ACbyDC_WV2_4', 'MAX_ACbyDC_WV2_5', 'MAX_ACbyDC_WV2_6', 'MAX_ACbyDC_WV2_7', 'MAX_ACbyDC_WV2_8', 'MAX_ACbyDC_WV2_9', 'MAX_ACbyDC_WV2_10', 'MAX_ACbyDC_WV2_11', 'MAX_ACbyDC_WV2_12', 'MAX_ACbyDC_WV2_13', 'MAX_ACbyDC_WV2_14', 'MAX_ACbyDC_WV2_15', 'MAX_ACbyDC_WV2_16', 'MAX_ACbyDC_WV2_17', 'MAX_ACbyDC_WV2_18', 'MAX_ACbyDC_WV2_19']
40


In [34]:
# Standardise target and feature columns, each with its own StandardScaler.
# fit() returns the scaler itself, and fitting followed by transforming the same
# columns is what fit_transform() does in a single call.
# NOTE(review): the statistics are computed over the whole frame, and running
# this cell again rescales columns that were already scaled.

## Target column(s)
y_scaler = preprocessing.StandardScaler().fit(data[y_columns])
data[y_columns] = y_scaler.transform(data[y_columns])

## Feature columns
x_scaler = preprocessing.StandardScaler().fit(data[x_columns])
data[x_columns] = x_scaler.transform(data[x_columns])

StandardScaler()
StandardScaler()


In [35]:
# Model input/output widths are simply the lengths of the two column lists.
IN_FEATURES, OUT_FEATURES = map(len, (x_columns, y_columns))
print("In Features :", IN_FEATURES)
print("Out Features:", OUT_FEATURES)

In Features : 40
Out Features: 1


The processed data is stored to disk to save time on later runs.

In [36]:
# Cache the processed frame in the current working directory.
# NOTE(review): only `data` is written here; the fitted x_scaler / y_scaler
# objects are not saved by this cell.
data.to_pickle(path='rishad_data.pkl')


## Load data just processed

In [37]:
# Restore the cached frame written by the caching cell and preview its first rows.
data = pd.read_pickle(filepath_or_buffer='rishad_data.pkl')
data.head()

Unnamed: 0,Maternal Wall Thickness,Maternal Hb Concentration,Maternal Saturation,Fetal Saturation,FconcCenters,Fetal Hb Concentration 0,Fetal Hb Concentration 1,MAX_ACbyDC_WV1_0,MAX_ACbyDC_WV2_0,MAX_ACbyDC_WV1_1,...,MAX_ACbyDC_WV1_15,MAX_ACbyDC_WV2_15,MAX_ACbyDC_WV1_16,MAX_ACbyDC_WV2_16,MAX_ACbyDC_WV1_17,MAX_ACbyDC_WV2_17,MAX_ACbyDC_WV1_18,MAX_ACbyDC_WV2_18,MAX_ACbyDC_WV1_19,MAX_ACbyDC_WV2_19
0,4.0,11.0,0.9,-1.593255,0.0,10.45,10.881818,4.065299,4.941459,4.430259,...,2.037649,1.805695,1.300163,1.648643,0.802957,1.698323,1.521571,1.645864,1.362108,1.319674
1,4.0,11.0,0.9,-1.593255,0.0,10.45,11.0,5.11694,6.229803,5.61621,...,2.578698,2.281599,1.644967,2.082326,1.016351,2.145934,1.925202,2.078976,1.720388,1.667477
2,4.0,11.0,0.9,-1.593255,0.0,10.881818,10.45,-4.065299,-4.941459,-4.430259,...,-2.037649,-1.805695,-1.300163,-1.648643,-0.802957,-1.698323,-1.521571,-1.645864,-1.362108,-1.319674
3,4.0,11.0,0.9,-1.593255,0.0,10.881818,11.0,1.051686,1.289497,1.18666,...,0.549395,0.488265,0.34977,0.446102,0.217127,0.460549,0.410825,0.445024,0.366479,0.356184
4,4.0,11.0,0.9,-1.593255,0.0,11.0,10.45,-5.11694,-6.229803,-5.61621,...,-2.578698,-2.281599,-1.644967,-2.082326,-1.016351,-2.145934,-1.925202,-2.078976,-1.720388,-1.667477


In [43]:
# Rebuild the feature/target column lists from the reloaded frame itself, so
# this cell does not need the feature builder object.
#
# Features are picked by their 'MAX_ACbyDC_' name prefix instead of by position
# (previously `data.columns[7:]`), so adding or removing a label column cannot
# silently shift a label into the model inputs or a feature out of them.
FEATURE_PREFIX = "MAX_ACbyDC_"
x_columns = [col for col in data.columns if str(col).startswith(FEATURE_PREFIX)]
if not x_columns:
    # An empty feature list would only surface much later as a shape error.
    raise ValueError(f"No feature columns starting with {FEATURE_PREFIX!r} found in the loaded data")

# NOTE(review): the list follows the DataFrame column order (WV1_0, WV2_0,
# WV1_1, ...), which differs from the order printed for fb1.get_feature_names()
# (all WV1 columns, then all WV2 columns). Confirm which order the model and
# x_scaler expect.
y_columns = ["Fetal Saturation"]
print(x_columns)
print(len(x_columns))

['MAX_ACbyDC_WV1_0', 'MAX_ACbyDC_WV2_0', 'MAX_ACbyDC_WV1_1', 'MAX_ACbyDC_WV2_1', 'MAX_ACbyDC_WV1_2', 'MAX_ACbyDC_WV2_2', 'MAX_ACbyDC_WV1_3', 'MAX_ACbyDC_WV2_3', 'MAX_ACbyDC_WV1_4', 'MAX_ACbyDC_WV2_4', 'MAX_ACbyDC_WV1_5', 'MAX_ACbyDC_WV2_5', 'MAX_ACbyDC_WV1_6', 'MAX_ACbyDC_WV2_6', 'MAX_ACbyDC_WV1_7', 'MAX_ACbyDC_WV2_7', 'MAX_ACbyDC_WV1_8', 'MAX_ACbyDC_WV2_8', 'MAX_ACbyDC_WV1_9', 'MAX_ACbyDC_WV2_9', 'MAX_ACbyDC_WV1_10', 'MAX_ACbyDC_WV2_10', 'MAX_ACbyDC_WV1_11', 'MAX_ACbyDC_WV2_11', 'MAX_ACbyDC_WV1_12', 'MAX_ACbyDC_WV2_12', 'MAX_ACbyDC_WV1_13', 'MAX_ACbyDC_WV2_13', 'MAX_ACbyDC_WV1_14', 'MAX_ACbyDC_WV2_14', 'MAX_ACbyDC_WV1_15', 'MAX_ACbyDC_WV2_15', 'MAX_ACbyDC_WV1_16', 'MAX_ACbyDC_WV2_16', 'MAX_ACbyDC_WV1_17', 'MAX_ACbyDC_WV2_17', 'MAX_ACbyDC_WV1_18', 'MAX_ACbyDC_WV2_18', 'MAX_ACbyDC_WV1_19', 'MAX_ACbyDC_WV2_19']
40


## Create Model