# Setup

In [None]:
import gc
import numpy as np
import os
import sklearn.metrics
import tensorflow as tf

import models
import util

SEED = 2021
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Stay in the top-level project directory so relative paths such as
# 'models/...' resolve the same way no matter where Jupyter was started.
# Only step up when the current directory itself is named 'src': a substring
# test on the whole path ('/src' in cwd) also matches ancestors such as
# '~/src/<repo>', which would move above the project root -- and move one level
# higher every time this cell is re-run.
if os.path.basename(os.getcwd()) == 'src':
    os.chdir('..')

In [None]:
# The following functions read xtrain, ytrain, xtest, ytest, xval, and yval from the notebook's global namespace, so set those variables before calling any of them.

def train_logreg(name):
    """Build, fit, pickle, and report on the ``models.build_logreg`` model.

    Fits on the notebook globals ``xtrain``/``ytrain`` and reports on
    ``xtest``/``ytest``; those globals must exist before calling.

    Args:
        name: Suffix for the pickle path, ``models/lr_{name}``.
    """
    model = models.build_logreg()
    models.fit_logreg(model, xtrain, ytrain)
    models.save_pickle(model, f'models/lr_{name}')

    # Test-set score, then a per-class precision/recall/F1 report.
    test_score = model.score(xtest, ytest)
    print(test_score)
    report = sklearn.metrics.classification_report(ytest, model.predict(xtest), digits=4)
    print(report)

    # Drop this function's reference before collecting so the model can be freed now.
    del model
    gc.collect()

def train_gb(name):
    """Build, fit, pickle, and report on the ``models.build_gbdt`` model.

    Fits on the notebook globals ``xtrain``/``ytrain``. It also passes
    ``xval``/``yval`` to ``models.fit_gbdt``, presumably for early stopping
    (confirm in ``models``). It reports on ``xtest``/``ytest``.

    Args:
        name: Suffix for the pickle path, ``models/gb_{name}``.
    """
    model = models.build_gbdt()
    models.fit_gbdt(model, xtrain, ytrain, xval, yval)
    models.save_pickle(model, f'models/gb_{name}')

    # Test-set score, then a per-class precision/recall/F1 report.
    test_score = model.score(xtest, ytest)
    print(test_score)
    report = sklearn.metrics.classification_report(ytest, model.predict(xtest), digits=4)
    print(report)

    # Drop this function's reference before collecting so the model can be freed now.
    del model
    gc.collect()

def train_selu(name):
    """Build, fit, save, and evaluate the ``models.build_NN_selu`` network.

    Reads these notebook globals, which must be set before calling:
    - ``xtrain``/``ytrain``: used for training.
    - ``xval``/``yval``: passed to ``models.fit_NN_selu``.
    - ``xtest``/``ytest``: used for evaluation.

    Args:
        name: Suffix for the save path, ``models/selu_{name}``.

    Returns:
        The result of ``evaluate`` on the test split. The original note says
        this is ``[loss, accuracy, auc]``; the order follows the metrics the
        model was compiled with (TODO confirm in models.build_NN_selu).
        Returning it lets a cell display it and lets callers collect it.
    """
    # Assumes xtrain is 2-D, (n_samples, n_features): the input width is the
    # number of feature columns.
    selu = models.build_NN_selu(input_len=xtrain.shape[1])
    selu.summary()
    models.fit_NN_selu(selu, xtrain, ytrain, xval, yval)
    models.save_NN(selu, f'models/selu_{name}')
    metrics = selu.evaluate(xtest, ytest)
    del selu
    # del + gc.collect() alone does not release Keras's global model-building
    # state. That state grows with every model built in this session, and this
    # notebook builds several.
    tf.keras.backend.clear_session()
    gc.collect()
    return metrics

def train_lrelu(name):
    """Build, fit, save, and evaluate the ``models.build_NN_lrelu`` network.

    Reads these notebook globals, which must be set before calling:
    - ``xtrain``/``ytrain``: used for training.
    - ``xval``/``yval``: passed to ``models.fit_NN_lrelu``.
    - ``xtest``/``ytest``: used for evaluation.

    Args:
        name: Suffix for the save path, ``models/lrelu_{name}``.

    Returns:
        The result of ``evaluate`` on the test split. The original note says
        this is ``[loss, accuracy, auc]``; the order follows the metrics the
        model was compiled with (TODO confirm in models.build_NN_lrelu).
        Returning it lets a cell display it and lets callers collect it.
    """
    # Assumes xtrain is 2-D, (n_samples, n_features): the input width is the
    # number of feature columns.
    lrelu = models.build_NN_lrelu(input_len=xtrain.shape[1])
    lrelu.summary()
    models.fit_NN_lrelu(lrelu, xtrain, ytrain, xval, yval)
    models.save_NN(lrelu, f'models/lrelu_{name}')
    metrics = lrelu.evaluate(xtest, ytest)
    del lrelu
    # del + gc.collect() alone does not release Keras's global model-building
    # state. That state grows with every model built in this session, and this
    # notebook builds several.
    tf.keras.backend.clear_session()
    gc.collect()
    return metrics

# Stillbirth

## Race-Aware

In [None]:
# Load the stillbirth dataset: features and class labels for the train, test,
# and validation splits. The `*0` names hold the original class labels, which
# the next cells binarize per outcome.
(xtrain, ytrain0,
 xtest, ytest0,
 xval, yval0) = util.load_preg_data_final(datafile='stillbirth')

### Early Stillbirth

In [None]:
# Binarize the class labels for one outcome.
# util.outcome_to_binary accepts 'early stillbirth', 'late stillbirth', or 'preterm'.
outcome = 'early stillbirth'
ytrain = util.outcome_to_binary(ytrain0, outcome=outcome)
ytest = util.outcome_to_binary(ytest0, outcome=outcome)
yval = util.outcome_to_binary(yval0, outcome=outcome)
gc.collect()

In [None]:
# Early stillbirth, race-aware features: logistic model -> models/lr_early_aware
train_logreg('early_aware')

In [None]:
# models.build_gbdt model -> models/gb_early_aware
train_gb('early_aware')

In [None]:
# SELU network -> models/selu_early_aware
train_selu('early_aware')

In [None]:
# Leaky-ReLU network -> models/lrelu_early_aware
train_lrelu('early_aware')

### Late Stillbirth

In [None]:
# Binarize the class labels for one outcome.
# util.outcome_to_binary accepts 'early stillbirth', 'late stillbirth', or 'preterm'.
outcome = 'late stillbirth'
ytrain = util.outcome_to_binary(ytrain0, outcome=outcome)
ytest = util.outcome_to_binary(ytest0, outcome=outcome)
yval = util.outcome_to_binary(yval0, outcome=outcome)
gc.collect()

In [None]:
# Late stillbirth, race-aware features: logistic model -> models/lr_late_aware
train_logreg('late_aware')

In [None]:
# models.build_gbdt model -> models/gb_late_aware
train_gb('late_aware')

In [None]:
# SELU network -> models/selu_late_aware
train_selu('late_aware')

In [None]:
# Leaky-ReLU network -> models/lrelu_late_aware
train_lrelu('late_aware')

## Race-Unaware

In [None]:
# Drop the one-hot race columns from every split so the following models are
# race-unaware. This cell is not idempotent: re-running it raises because the
# columns are already gone.
race_columns = ['race_AmeriIndian', 'race_AsianPI', 'race_Black', 'race_White']
xtrain = xtrain.drop(columns=race_columns)
xval = xval.drop(columns=race_columns)
xtest = xtest.drop(columns=race_columns)

### Early Stillbirth

In [None]:
# Binarize the class labels for one outcome.
# util.outcome_to_binary accepts 'early stillbirth', 'late stillbirth', or 'preterm'.
outcome = 'early stillbirth'
ytrain = util.outcome_to_binary(ytrain0, outcome=outcome)
ytest = util.outcome_to_binary(ytest0, outcome=outcome)
yval = util.outcome_to_binary(yval0, outcome=outcome)
gc.collect()

In [None]:
# Early stillbirth, race columns dropped: logistic model -> models/lr_early_unaware
train_logreg('early_unaware')

In [None]:
# models.build_gbdt model -> models/gb_early_unaware
train_gb('early_unaware')

In [None]:
# SELU network -> models/selu_early_unaware
train_selu('early_unaware')

In [None]:
# Leaky-ReLU network -> models/lrelu_early_unaware
train_lrelu('early_unaware')

### Late Stillbirth

In [None]:
# Binarize the class labels for one outcome.
# util.outcome_to_binary accepts 'early stillbirth', 'late stillbirth', or 'preterm'.
outcome = 'late stillbirth'
ytrain = util.outcome_to_binary(ytrain0, outcome=outcome)
ytest = util.outcome_to_binary(ytest0, outcome=outcome)
yval = util.outcome_to_binary(yval0, outcome=outcome)
gc.collect()

In [None]:
# Late stillbirth, race columns dropped: logistic model -> models/lr_late_unaware
train_logreg('late_unaware')

In [None]:
# models.build_gbdt model -> models/gb_late_unaware
train_gb('late_unaware')

In [None]:
# SELU network -> models/selu_late_unaware
train_selu('late_unaware')

In [None]:
# Leaky-ReLU network -> models/lrelu_late_unaware
train_lrelu('late_unaware')

# Preterm Birth

## Race-Aware

In [None]:
# Load the preterm dataset (features plus class labels for the train, test,
# and validation splits). This replaces the stillbirth splits and restores the
# race columns.
(xtrain, ytrain0,
 xtest, ytest0,
 xval, yval0) = util.load_preg_data_final(datafile='preterm')
# Binarize the class labels for one outcome.
# util.outcome_to_binary accepts 'early stillbirth', 'late stillbirth', or 'preterm'.
outcome = 'preterm'
ytrain = util.outcome_to_binary(ytrain0, outcome=outcome)
ytest = util.outcome_to_binary(ytest0, outcome=outcome)
yval = util.outcome_to_binary(yval0, outcome=outcome)
gc.collect()

In [None]:
# Preterm birth, race-aware features: logistic model -> models/lr_preterm_aware
train_logreg('preterm_aware')

In [None]:
# models.build_gbdt model -> models/gb_preterm_aware
train_gb('preterm_aware')

In [None]:
# SELU network -> models/selu_preterm_aware
train_selu('preterm_aware')

In [None]:
# Leaky-ReLU network -> models/lrelu_preterm_aware
train_lrelu('preterm_aware')

## Race-Unaware

In [None]:
# Drop the one-hot race columns from every split so the following models are
# race-unaware. This cell is not idempotent: re-running it raises because the
# columns are already gone.
race_columns = ['race_AmeriIndian', 'race_AsianPI', 'race_Black', 'race_White']
xtrain = xtrain.drop(columns=race_columns)
xval = xval.drop(columns=race_columns)
xtest = xtest.drop(columns=race_columns)

In [None]:
# Preterm birth, race columns dropped: logistic model -> models/lr_preterm_unaware
train_logreg('preterm_unaware')

In [None]:
# models.build_gbdt model -> models/gb_preterm_unaware
train_gb('preterm_unaware')

In [None]:
# SELU network -> models/selu_preterm_unaware
train_selu('preterm_unaware')

In [None]:
# Leaky-ReLU network -> models/lrelu_preterm_unaware
train_lrelu('preterm_unaware')