In [1]:
import pandas as pd

In [11]:
dat = pd.read_csv('../data/dataset_mock_final.csv', sep=';')

In [12]:
dat.head()

Unnamed: 0,date,severity,mortality_ratio,age,num_proc,ambulatory,origin,expected_length,tip_grd,tip_adm,exitus,dataset
0,2016-07,,0.001193,15603.0,4.0,,,7.0,M,1.0,0,train
1,2016-05,1.0,0.0,14285.0,3.0,,1.0,,M,1.0,0,train
2,2016-01,,0.0,6046.0,2.0,,,2.0,,1.0,0,train
3,2016-01,1.0,0.00406,27340.0,4.0,,2.0,9.0,Q,,0,train
4,2016-05,2.0,0.028365,28685.0,10.0,0.0,,9.0,M,1.0,0,train


In [13]:
dat.drop('date', axis = 1, inplace = True)

In [14]:
cat_var = ['severity', 'ambulatory', 'origin', 'tip_grd', 'tip_adm']
non_cat_var = list(set(dat.columns) - set(cat_var))
num_var = list(set(dat.columns) - set(cat_var) - {'dataset', 'exitus'})

In [15]:
dat.isna().any()

severity            True
mortality_ratio     True
age                 True
num_proc            True
ambulatory          True
origin              True
expected_length     True
tip_grd             True
tip_adm             True
exitus             False
dataset            False
dtype: bool

In [16]:
from sklearn.ensemble import RandomForestRegressor
from fancyimpute import IterativeImputer as MICE

# 3) Define "model"
model = MICE(estimator=RandomForestRegressor())

# 4) Train "model"
model.fit(dat[num_var][dat['dataset'] == 'train'])

# 5) "Predict"
dat[num_var] = model.transform(dat[num_var])
dat.isna().any()



severity            True
mortality_ratio    False
age                False
num_proc           False
ambulatory          True
origin              True
expected_length    False
tip_grd             True
tip_adm             True
exitus             False
dataset            False
dtype: bool

In [17]:
dat[cat_var] = dat[cat_var].astype('str')

In [18]:
dat.loc[dat['dataset'] == 'train', cat_var] = dat.loc[dat['dataset'] == 'train', cat_var].fillna('UNKNOWN')
dat[cat_var][dat['dataset'] == 'train'].isna().sum()

severity      0
ambulatory    0
origin        0
tip_grd       0
tip_adm       0
dtype: int64

In [19]:
dat.isna().any()

severity           False
mortality_ratio    False
age                False
num_proc           False
ambulatory         False
origin             False
expected_length    False
tip_grd            False
tip_adm            False
exitus             False
dataset            False
dtype: bool

In [30]:
from sklearn.preprocessing import OneHotEncoder

ohe = OneHotEncoder(sparse_output = False, drop='first')

# 4) Training model
ohe.fit(dat[cat_var][dat['dataset'] == 'train'])

# 5) Predicting
dat_ohe = pd.DataFrame(ohe.fit_transform(dat[cat_var]))

# Optional
dat_ohe.columns = ohe.get_feature_names_out()
dat = pd.concat((dat[non_cat_var], dat_ohe), axis=1)

In [31]:
100*dat.groupby(['exitus'])['exitus'].agg(['count'])/dat.shape[0]

Unnamed: 0_level_0,count
exitus,Unnamed: 1_level_1
0,96.235664
1,3.764336


In [43]:
def compute_sampling_strategy(frac_minority, minority_count, majority_count):
    synthetic_samples = (frac_minority * majority_count - (1 - frac_minority) * minority_count) / (1 - frac_minority)
    strategy = (minority_count + synthetic_samples) / majority_count
    return strategy

# Assume you have counts for your classes
minority_count = sum(dat['exitus'] == 1)
majority_count = sum(dat['exitus'] == 0)

# For a 10-90 split:
fraction = 0.1
sampling_value = compute_sampling_strategy(fraction, minority_count, majority_count)
print(f"For a {fraction*100}% minority class after oversampling, set sampling_strategy to {sampling_value:.2f} in SMOTE.")

For a 10.0% minority class after oversampling, set sampling_strategy to 0.11 in SMOTE.


In [44]:
from imblearn.over_sampling import SMOTE
sm = SMOTE(sampling_strategy =sampling_value,
           random_state = 0,
           k_neighbors = 5)

X_res, y_res = sm.fit_resample(dat.drop(['exitus', 'dataset'], axis = 1), dat['exitus'])

X_res['exitus'] = y_res

X_res['dataset'] = 'train'

dat_new = pd.concat([X_res, dat[dat['dataset'] == 'val'], dat[dat['dataset'] == 'test']])

# Checking the class distribution after SMOTE
100*X_res.exitus.value_counts()/X_res.shape[0]

exitus
0    90.000865
1     9.999135
Name: count, dtype: float64