In [40]:
!pip install pandas scikit-learn matplotlib seaborn



In [41]:
import pandas as pd
from sklearn.impute import KNNImputer
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.metrics import classification_report


In [42]:
finland_dti = pd.read_csv('/content/finland.native.dwi.tract.bundles.tissue.whole.voxel.harmz.map.dti_FA_mean.csv')
melb_dti = pd.read_csv('/content/melb.native.dwi.tract.bundles.tissue.whole.voxel.harmz.map.dti_FA_mean.csv')
ucla_dti = pd.read_csv('/content/ucla.native.dwi.tract.bundles.tissue.whole.voxel.harmz.map.dti_FA_mean.csv')
finland_labels = pd.read_csv('/content/pte.csv')

print(finland_dti.head())
print(melb_dti.head())
print(ucla_dti.head())
print(finland_labels.head())


          subject    id  tp  fimbria_left  vofb_left  slfc_right  ifof_right  \
0  Kuopio_2d_1008  1008  2d     -1.483774  -1.871713   -0.006680    0.525094   
1  Kuopio_2d_1010  1010  2d     -1.759198  -2.879283   -0.903091   -0.113726   
2  Kuopio_2d_1012  1012  2d     -1.201728  -1.432469   -0.422760    0.024351   
3  Kuopio_2d_1017  1017  2d     -0.581287  -0.272424   -0.422857    0.150489   
4  Kuopio_2d_1018  1018  2d     -1.196942  -2.552879   -0.499006    0.330303   

   slfc_left    cc_ant  cerped_left  ...  slfd_left  opticrad_left  chip_left  \
0  -1.509330  0.141561    -0.674018  ...        NaN            NaN        NaN   
1        NaN -0.824232    -1.054373  ...  -1.511104            NaN        NaN   
2  -0.674610 -0.348533    -0.745228  ...        NaN      -1.657730  -1.549084   
3  -0.753275 -0.679531    -0.145942  ...   0.105596      -0.603258  -0.235710   
4  -2.164975 -0.367490    -0.371646  ...  -1.012411      -3.063264  -3.085489   

   amyhypdor_left    cc_mid  vof

In [43]:
# append the Finnish labels to the main Finnish dataset
finland_merged = pd.merge(finland_dti, finland_labels, on='id')
print(finland_merged.head())

           subject    id   tp  fimbria_left  vofb_left  slfc_right  \
0   Kuopio_2d_1008  1008   2d     -1.483774  -1.871713   -0.006680   
1  Kuopio_30d_1008  1008  30d     -1.047364   0.390436   -0.788371   
2  Kuopio_5mo_1008  1008  5mo           NaN   3.138381   -0.099741   
3   Kuopio_9d_1008  1008   9d     -0.864658  -1.104843   -0.207132   
4   Kuopio_2d_1012  1012   2d     -1.201728  -1.432469   -0.422760   

   ifof_right  slfc_left    cc_ant  cerped_left  ...  opticrad_left  \
0    0.525094  -1.509330  0.141561    -0.674018  ...            NaN   
1   -0.376300  -1.166285  0.089390    -0.507800  ...      -1.329183   
2    0.844863   0.360937  0.098217     0.058027  ...            NaN   
3    0.174978  -0.888778  0.420657    -0.068937  ...      -2.130471   
4    0.024351  -0.674610 -0.348533    -0.745228  ...      -1.657730   

   chip_left  amyhypdor_left  cc_mid  vofa_left  thalsmed_right  cc_temp  \
0        NaN             NaN     NaN        NaN             NaN      NaN   


In [44]:
# congregate all the datasets into one
combined_data = pd.concat([finland_merged, melb_dti, ucla_dti], ignore_index=True)
print(combined_data.head())

           subject    id   tp  fimbria_left  vofb_left  slfc_right  \
0   Kuopio_2d_1008  1008   2d     -1.483774  -1.871713   -0.006680   
1  Kuopio_30d_1008  1008  30d     -1.047364   0.390436   -0.788371   
2  Kuopio_5mo_1008  1008  5mo           NaN   3.138381   -0.099741   
3   Kuopio_9d_1008  1008   9d     -0.864658  -1.104843   -0.207132   
4   Kuopio_2d_1012  1012   2d     -1.201728  -1.432469   -0.422760   

   ifof_right  slfc_left    cc_ant  cerped_left  ...  chip_left  \
0    0.525094  -1.509330  0.141561    -0.674018  ...        NaN   
1   -0.376300  -1.166285  0.089390    -0.507800  ...   0.708240   
2    0.844863   0.360937  0.098217     0.058027  ...   0.679290   
3    0.174978  -0.888778  0.420657    -0.068937  ...  -1.069747   
4    0.024351  -0.674610 -0.348533    -0.745228  ...  -1.549084   

   amyhypdor_left  cc_mid  vofa_left  thalsmed_right  cc_temp  \
0             NaN     NaN        NaN             NaN      NaN   
1             NaN     NaN        NaN       -1.

In [45]:
# standarize data for time point column
combined_data['tp'] = combined_data['tp'].replace({'2d': 'baseline', '1mo': '1_month'})
print(combined_data['tp'].unique())

['baseline' '30d' '5mo' '9d' '1_month']


In [46]:
# remove any columns missing more than 10% of data
threshold = 0.10 * len(combined_data)
combined_data = combined_data.dropna(thresh=threshold, axis=1)
print(combined_data.head())

           subject    id        tp  fimbria_left  vofb_left  slfc_right  \
0   Kuopio_2d_1008  1008  baseline     -1.483774  -1.871713   -0.006680   
1  Kuopio_30d_1008  1008       30d     -1.047364   0.390436   -0.788371   
2  Kuopio_5mo_1008  1008       5mo           NaN   3.138381   -0.099741   
3   Kuopio_9d_1008  1008        9d     -0.864658  -1.104843   -0.207132   
4   Kuopio_2d_1012  1012  baseline     -1.201728  -1.432469   -0.422760   

   ifof_right  slfc_left    cc_ant  cerped_left  ...  opticrad_left  \
0    0.525094  -1.509330  0.141561    -0.674018  ...            NaN   
1   -0.376300  -1.166285  0.089390    -0.507800  ...      -1.329183   
2    0.844863   0.360937  0.098217     0.058027  ...            NaN   
3    0.174978  -0.888778  0.420657    -0.068937  ...      -2.130471   
4    0.024351  -0.674610 -0.348533    -0.745228  ...      -1.657730   

   chip_left  amyhypdor_left  cc_mid  vofa_left  thalsmed_right  cc_temp  \
0        NaN             NaN     NaN        Na

In [47]:
# KNN imputation only works with numerical data, so split up data by categorical vs numerical. Apply KNN imputation to numerical, and mode imputation to categorical
numerical_cols = combined_data.select_dtypes(include=['float64', 'int64']).columns
categorical_cols = combined_data.select_dtypes(include=['object']).columns

imputer = KNNImputer(n_neighbors=5)
imputed_numerical_data = pd.DataFrame(imputer.fit_transform(combined_data[numerical_cols]),
                                      columns=numerical_cols)

imputed_categorical_data = combined_data[categorical_cols].fillna(combined_data[categorical_cols].mode().iloc[0])

imputed_data = pd.concat([imputed_numerical_data, imputed_categorical_data], axis=1)

print(imputed_data.head())


       id  fimbria_left  vofb_left  slfc_right  ifof_right  slfc_left  \
0  1008.0     -1.483774  -1.871713   -0.006680    0.525094  -1.509330   
1  1008.0     -1.047364   0.390436   -0.788371   -0.376300  -1.166285   
2  1008.0     -0.615552   3.138381   -0.099741    0.844863   0.360937   
3  1008.0     -0.864658  -1.104843   -0.207132    0.174978  -0.888778   
4  1012.0     -1.201728  -1.432469   -0.422760    0.024351  -0.674610   

     cc_ant  cerped_left  slfb_left  thalslat_right  ...  amyhypdor_left  \
0  0.141561    -0.674018  -0.968109       -1.051188  ...       -0.209276   
1  0.089390    -0.507800  -0.595976       -0.999030  ...       -0.209276   
2  0.098217     0.058027   0.073093       -0.073283  ...       -0.209276   
3  0.420657    -0.068937  -0.186715        0.125790  ...       -0.209276   
4 -0.348533    -0.745228  -0.641502       -0.838921  ...       -0.209276   

     cc_mid  vofa_left  thalsmed_right   cc_temp          subject        tp  \
0 -0.562352  -0.379304   

In [48]:
tp_mapping = {'baseline': 0, '9d': 9, '30d': 30, '1_month': 30, '5mo': 150}
imputed_data['tp'] = imputed_data['tp'].map(tp_mapping)

print(imputed_data['tp'].head())


0      0
1     30
2    150
3      9
4      0
Name: tp, dtype: int64


In [49]:
# prepare X and Y training dataframes - remove unneccessary columns from X and assign labels to Y
X = imputed_data.drop(['subject', 'group', 'status', 'site', 'id'], axis=1)
y = imputed_data['status'].fillna(imputed_data['group'])



print(X.head())
print(y.head())


   fimbria_left  vofb_left  slfc_right  ifof_right  slfc_left    cc_ant  \
0     -1.483774  -1.871713   -0.006680    0.525094  -1.509330  0.141561   
1     -1.047364   0.390436   -0.788371   -0.376300  -1.166285  0.089390   
2     -0.615552   3.138381   -0.099741    0.844863   0.360937  0.098217   
3     -0.864658  -1.104843   -0.207132    0.174978  -0.888778  0.420657   
4     -1.201728  -1.432469   -0.422760    0.024351  -0.674610 -0.348533   

   cerped_left  slfb_left  thalslat_right  cing_right  ...  ant_comm  \
0    -0.674018  -0.968109       -1.051188    0.111440  ...  0.432852   
1    -0.507800  -0.595976       -0.999030   -0.423590  ... -0.057724   
2     0.058027   0.073093       -0.073283   -0.061513  ...  0.633985   
3    -0.068937  -0.186715        0.125790   -0.092026  ...  0.405225   
4    -0.745228  -0.641502       -0.838921    0.003674  ... -0.227730   

   slfd_left  opticrad_left  chip_left  amyhypdor_left    cc_mid  vofa_left  \
0   0.279302      -1.678138  -0.51020

In [50]:
# train a logistic regression classifier and SVM model
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.metrics import classification_report

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

log_model = LogisticRegression(max_iter=1000)
log_model.fit(X_train, y_train)

log_predictions = log_model.predict(X_test)
print("Logistic Regression results:")
print(classification_report(y_test, log_predictions))

svm_model = SVC()
svm_model.fit(X_train, y_train)

svm_predictions = svm_model.predict(X_test)
print("SVM results:")
print(classification_report(y_test, svm_predictions))


Logistic Regression results:
              precision    recall  f1-score   support

         PTE       0.80      0.67      0.73         6
        Sham       1.00      0.21      0.35        14
         TBI       0.86      0.99      0.92        80

    accuracy                           0.86       100
   macro avg       0.89      0.62      0.67       100
weighted avg       0.87      0.86      0.83       100

SVM results:
              precision    recall  f1-score   support

         PTE       0.00      0.00      0.00         6
        Sham       0.00      0.00      0.00        14
         TBI       0.80      1.00      0.89        80

    accuracy                           0.80       100
   macro avg       0.27      0.33      0.30       100
weighted avg       0.64      0.80      0.71       100



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [52]:

print(y.value_counts())


status
TBI     418
Sham     43
PTE      36
Name: count, dtype: int64


In [59]:
'''
Logistic Regression:

- does well on tbi but is potentially overfitting
- sham predictions are all correct, but misses many sham entries
- pte is similar to sham, except less predictions correct + catches more pte entries

SVM:
- likely overfit to tbi
- poor for both sham and pte
- probably not the best model to use, focus on Logistic Regression / more complex models instead

Imbalance of TBI cases to Sham/PTE probably leads to overfitting, so we can try rerunning logistic regression / SVM with the 'balanced' keyword to account for the difference


'''

"\nLogistic Regression: \n\n- does well on tbi but is potentially overfitting \n- sham predictions are all correct, but misses many sham entries\n- pte is similar to sham, except less predictions correct + catches more pte entries\n\nSVM:\n- likely overfit to tbi\n- poor for both sham and pte\n- probably not the best model to use, focus on Logistic Regression / more complex models instead\n\nImbalance of TBI cases to Sham/PTE probably leads to overfitting, so we can try rerunning logistic regression / SVM with the 'balanced' keyword to account for the difference\n\n\n"

In [60]:
from imblearn.over_sampling import SMOTE

smote = SMOTE()
X_resampled, y_resampled = smote.fit_resample(X_train, y_train)

In [61]:
log_model_balanced = LogisticRegression(max_iter=1000, class_weight='balanced')
log_model_balanced.fit(X_resampled, y_resampled)

svm_model_balanced = SVC(class_weight='balanced')
svm_model_balanced.fit(X_resampled, y_resampled)


log_predictions = log_model.predict(X_test)
print("Logistic Regression results:")
print(classification_report(y_test, log_predictions))

svm_predictions = svm_model.predict(X_test)
print("SVM results:")
print(classification_report(y_test, svm_predictions))

Logistic Regression results:
              precision    recall  f1-score   support

         PTE       0.80      0.67      0.73         6
        Sham       1.00      0.21      0.35        14
         TBI       0.86      0.99      0.92        80

    accuracy                           0.86       100
   macro avg       0.89      0.62      0.67       100
weighted avg       0.87      0.86      0.83       100

SVM results:
              precision    recall  f1-score   support

         PTE       0.00      0.00      0.00         6
        Sham       0.00      0.00      0.00        14
         TBI       0.80      1.00      0.89        80

    accuracy                           0.80       100
   macro avg       0.27      0.33      0.30       100
weighted avg       0.64      0.80      0.71       100



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
