In [3]:
import sys
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  
#sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import train_test, transformers, classifiers

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import ParameterGrid, ParameterSampler
import datetime
import pandas as pd
import argparse
import random

from pathlib import Path


# --- Run configuration -------------------------------------------------------
# Number of parallel jobs handed to the CV runners. Hard-coded here; this
# notebook does not parse command-line arguments.
n_jobs = 1
# Number of outer and inner folds for nested cross-validation
k_out = 5
k_in = 5

# Timestamp (e.g. "20250605_1426") used to tag the output directory and files
time = datetime.datetime.now().strftime("%Y%m%d_%H%M")

# Which model family to tune; the next cell supports "XGBOOST", "SVM" and "NN".
model_type = "SVM"

# Every path is derived from one project root so the notebook can be moved by
# changing a single value (or setting LEUKEM_AI_BASE_PATH in the environment).
base_path = os.environ.get(
    "LEUKEM_AI_BASE_PATH", "/Users/jsevere2/Documents/AML_PhD/leukem_ai"
)
data_path = str(Path(base_path) / "data")

# Create the output directory if it doesn't exist
output_dir = str(Path(base_path) / "out" / model_type / time)
os.makedirs(output_dir, exist_ok=True)
print(f"Output dir is {output_dir}")

# --- Load and prepare data ---------------------------------------------------
print("Loading and preparing data")

X, y, study_labels = train_test.load_data(data_path)
# NOTE(review): min_n=20 presumably drops classes with fewer than 20 samples;
# the exact semantics live in train_test.filter_data — confirm.
X, y, study_labels = train_test.filter_data(X, y, study_labels, min_n=20)
# Integer-encode the labels; label_mapping is used later to restore names.
y, label_mapping = train_test.encode_labels(y)

# Pick the estimator class and its hyper-parameter search space.
# Unsupported model types are rejected up front; each supported type then
# binds `model` (an estimator class, not an instance) and `param_grid`
# (a mapping of parameter name -> list of candidate values for ParameterGrid).
if model_type not in ("XGBOOST", "SVM", "NN"):
    raise ValueError(f"Model type {model_type} not supported")

if model_type == "XGBOOST":
    model = classifiers.WeightedXGBClassifier
    param_grid = dict(
        n_genes=[2000, 3000, 5000],
        class_weight=[True],
        max_depth=[2, 3, 5],
        learning_rate=[0.05, 0.1],
        n_estimators=[100, 200],
        min_child_weight=[1, 3, 5],
        gamma=[0, 0.1],
        subsample=[0.8],
        colsample_bytree=[0.8],
        reg_alpha=[0, 0.1],
        reg_lambda=[1.0],
    )
elif model_type == "SVM":
    # Imported lazily so sklearn.svm is only loaded when actually needed
    from sklearn.svm import SVC
    model = SVC
    param_grid = dict(
        n_genes=[1000, 2000, 3000],
        C=[0.1, 1, 10, 100, 1000],
        gamma=['auto', 'scale', 0.0001, 0.001, 0.01, 0.1],
        class_weight=["balanced", None],
        probability=[True],
    )
else:  # model_type == "NN"
    model = classifiers.NeuralNet
    # Candidate hidden-layer widths, deepest first
    layer_layouts = [
        [800, 400, 100],
        [400, 200, 50],
        [200, 100, 25],
        [800, 400],
        [400, 200],
        [200, 100],
    ]
    param_grid = dict(
        n_genes=[2000],
        n_neurons=layer_layouts,
        use_batch_norm=[True, False],
        dropout_rate=[0, 0.2, 0.5],
        batch_size=[32],
        patience=[2],
        l2_reg=[0.001, 0],
        class_weight=[True, False],
        min_delta=[0.001],
        learning_rate=[0.0001],
        loss_function=["standard", "focal"],
    )

# Expand the grid into the list of every concrete hyper-parameter combination
full_param_list = list(ParameterGrid(param_grid))

# Batch norm and dropout do not play nicely together; skip those NN
# combinations rather than waste compute on them.
if model_type == "NN":
    full_param_list = [
        params for params in full_param_list
        if not (params['use_batch_norm'] and params['dropout_rate'] > 0)
    ]

# Randomly keep at most `n_downsample` combinations. A dedicated, seeded RNG
# makes the chosen subset identical on every Restart & Run All, without
# touching the global `random` state.
# NOTE(review): n_downsample = 1 evaluates a single combination only;
# raise it for a real hyper-parameter search.
n_downsample = 1
param_seed = 42
if len(full_param_list) > n_downsample:
    param_list = random.Random(param_seed).sample(full_param_list, k=n_downsample)
else:
    param_list = full_param_list


Output dir is /Users/jsevere2/Documents/AML_PhD/leukem_ai/out/SVM/20250605_1426
Loading and preparing data


  studies_series: 2834
  X_df: (60660, 2834)
  y_series: 2834
  Studies: 2834
  X shape: (2834, 60660)
  y: 2834


  Studies: 2268
  X shape: (2268, 60660)
  y: 2268


In [4]:
# Preprocessing applied ahead of the classifier, in order:
# transformers.DESeq2RatioNormalizer -> transformers.FeatureSelection2 -> StandardScaler
pipe = Pipeline([
    ('DEseq2', transformers.DESeq2RatioNormalizer()),
    ('feature_selection', transformers.FeatureSelection2()),
    ('scaler', StandardScaler())
])
print("Pipeline set up")

print("Starting inner cross-validation process.")
# Multiclass strategies to evaluate:
#   standard: the classifier's own multiclass handling
#   OvO: one-vs-one - a binary classifier for each pair of classes
#   OvR: one-vs-rest - a binary classifier for each class against all others
# The NN is only evaluated with its native multiclass handling. This
# conditional is the single source of truth for multi_types; do not
# reassign it unconditionally afterwards, or the NN restriction is lost.
if model_type == "NN":
    multi_types = ["standard"]
else:
    multi_types = ["standard", "OvO", "OvR"]

Pipeline set up
Starting inner cross-validation process.


In [5]:
# Fold scheme for the inner cross-validation:
#   "CV":   train_test.run_inner_cv with k_out outer / k_in inner folds
#   "loso": train_test.run_inner_cv_loso (presumably leave-one-study-out,
#           using study_labels - confirm in train_test)
fold_type = "CV"

# Resolve the runner, its fold-specific kwargs and the file tag once, so the
# run/restore/save loop below is written a single time for both schemes.
if fold_type == "CV":
    run_cv = train_test.run_inner_cv
    fold_kwargs = {"k_out": k_out, "k_in": k_in}
    file_tag = "inner_cv"
elif fold_type == "loso":
    run_cv = train_test.run_inner_cv_loso
    fold_kwargs = {}
    file_tag = "inner_cv_loso"
else:
    raise ValueError(f"Fold type {fold_type} not supported.")

for multi_type in multi_types:
    df = run_cv(
        X, y, study_labels, model, param_list, n_jobs, pipe,
        multi_type=multi_type,
        model_type=model_type,
        **fold_kwargs,
    )

    # Convert encoded labels back to original class names
    df = train_test.restore_labels(df, label_mapping)

    # Save results to CSV file with model type, fold scheme, strategy and timestamp
    df.to_csv(f"{output_dir}/{model_type}_{file_tag}_{multi_type}_{time}.csv")

print("Cross-validation process finished.")

[1000]
outer_fold
0
inner_fold
0




[   7    9   10   20   34   37   39   47   52   55   85   89   96  104
  106  114  121  127  141  147  151  156  157  166  174  176  179  181
  184  191  194  196  200  207  211  215  220  224  235  260  261  268
  275  278  280  283  286  288  295  301  307  313  325  330  336  349
  351  368  376  377  379  387  396  400  410  412  413  428  433  434
  437  461  462  466  475  477  483  484  485  489  491  496  507  509
  514  519  528  530  533  536  537  538  541  545  552  555  557  560
  565  572  574  583  596  597  599  609  617  618  620  621  622  627
  632  647  650  652  667  675  682  685  704  708  709  723  725  737
  739  759  764  773  781  785  790  791  793  795  810  813  826  827
  833  834  857  895  897  909  912  921  932  935  937  940  941  943
  945  948  951  954  959  961  963  967  971  973  974  980  984  999
 1006 1015 1025 1038 1048 1054 1057 1061 1064 1065 1068 1079 1084 1105
 1111 1129 1153 1158 1162 1169 1171 1173 1175 1178 1191 1212 1235 1236
 1238 



[   0    2    3    9   18   23   24   32   35   43   45   48   56   60
   65   70   74   77   80   89   95  116  117  125  133  145  146  149
  153  162  183  198  210  212  224  227  230  238  241  242  243  248
  252  259  273  279  285  286  294  300  305  320  324  327  334  347
  349  361  374  375  389  393  397  398  407  410  421  422  429  433
  437  439  442  443  445  450  457  467  469  470  476  491  495  499
  504  525  531  543  544  549  550  551  559  567  571  583  592  595
  596  603  607  609  613  623  625  631  633  643  658  664  665  669
  672  673  679  683  689  702  710  715  720  731  734  745  754  761
  779  785  792  810  814  815  831  836  844  845  846  852  855  860
  862  866  878  889  925  932  958  966  967  969  980  987  999 1000
 1006 1013 1016 1022 1028 1032 1038 1047 1058 1061 1064 1069 1075 1087
 1088 1090 1092 1107 1116 1131 1139 1142 1143 1147 1169 1182 1192 1199
 1201 1218 1239 1241 1249 1263 1264 1274 1275 1276 1277 1281 1283 1292
 1293 



[   0    2   11   23   25   28   40   41   55   58   60   64   75   77
   86  117  122  123  126  127  128  138  139  147  148  156  166  170
  181  190  193  195  213  218  219  225  226  234  237  242  243  246
  262  264  266  268  271  292  297  307  314  315  323  343  358  361
  375  376  380  383  391  394  396  397  398  401  402  406  415  423
  428  429  437  444  445  447  457  463  466  481  490  491  495  501
  503  512  516  517  518  528  530  535  538  540  541  549  554  555
  566  567  576  578  602  604  616  625  630  634  638  641  648  657
  666  668  672  682  690  697  700  704  711  714  721  724  727  730
  751  760  779  793  795  796  797  804  807  808  814  820  828  845
  869  872  880  888  889  902  905  910  911  914  917  921  925  928
  929  931  945  953  976  983  991  996 1001 1013 1036 1038 1059 1060
 1066 1075 1082 1086 1089 1091 1099 1109 1113 1135 1138 1142 1146 1168
 1177 1182 1193 1194 1227 1233 1243 1248 1250 1258 1261 1262 1263 1274
 1280 



[   0    2    3   10   16   19   23   24   38   43   49   53   54   71
   72   73   75   84   85  100  102  121  122  124  129  133  147  149
  153  162  163  179  182  217  222  227  228  230  241  243  248  252
  259  264  270  272  293  300  305  307  316  320  324  334  336  360
  364  370  383  388  390  392  395  398  402  408  409  415  422  436
  438  439  443  446  456  478  481  485  486  487  489  490  492  495
  501  521  525  530  533  538  540  542  548  554  566  573  577  584
  591  596  598  603  610  613  615  616  630  638  651  655  661  662
  682  684  692  695  700  702  707  709  710  728  730  739  756  757
  762  765  767  773  778  788  791  796  797  799  803  809  816  835
  845  850  875  876  893  904  910  925  928  932  938  941  958  961
  972  976  977  992  993 1002 1007 1033 1041 1057 1063 1066 1077 1079
 1081 1093 1098 1102 1107 1123 1124 1125 1137 1138 1154 1162 1164 1166
 1169 1181 1191 1194 1196 1197 1208 1209 1210 1214 1218 1220 1224 1233
 1240 



[   0    3    4   10   11   26   34   36   40   51   52   65   71   85
   86   90   93  100  102  109  128  129  140  143  149  151  157  159
  164  173  174  204  211  212  213  219  225  241  242  245  255  259
  262  270  273  275  279  283  292  294  316  320  325  352  355  356
  369  377  379  387  390  391  392  395  399  402  407  408  409  414
  429  440  458  465  471  474  484  485  489  495  502  505  509  521
  528  529  534  543  544  558  565  592  597  598  604  605  612  614
  618  642  647  657  666  667  683  684  698  702  716  727  732  735
  739  749  753  754  755  756  763  773  781  787  790  794  797  798
  799  800  801  809  814  828  839  842  853  855  859  861  862  863
  867  869  876  889  899  905  917  918  922  923  925  933  960  965
  973  989  990  993  997 1002 1008 1020 1026 1030 1044 1053 1059 1064
 1070 1084 1089 1100 1109 1119 1121 1139 1142 1147 1150 1159 1173 1175
 1177 1183 1189 1191 1193 1200 1203 1207 1212 1217 1221 1237 1239 1245
 1253 



inner_fold
1
inner_fold
2
inner_fold
3
inner_fold
4
outer_fold
1




inner_fold
0
inner_fold
1
inner_fold
2
inner_fold
3
inner_fold
4
outer_fold
2
inner_fold
0




inner_fold
1
inner_fold
2
inner_fold
3
inner_fold
4
outer_fold
3
inner_fold
0




inner_fold
1
inner_fold
2
inner_fold
3
inner_fold
4
outer_fold
4
inner_fold
0




inner_fold
1
inner_fold
2
inner_fold
3
inner_fold
4
[1000]
outer_fold
0
inner_fold
0




inner_fold
1
inner_fold
2
inner_fold
3
inner_fold
4
outer_fold
1




inner_fold
0
inner_fold
1
inner_fold
2
inner_fold
3
inner_fold
4
outer_fold
2




inner_fold
0
inner_fold
1
inner_fold
2
inner_fold
3
inner_fold
4
outer_fold
3




inner_fold
0
inner_fold
1
inner_fold
2
inner_fold
3
inner_fold
4
outer_fold
4




inner_fold
0
inner_fold
1
inner_fold
2
inner_fold
3
inner_fold
4
Cross-validation process finished.
