In [1]:
import numpy as np
import pandas as pd
from sklearn.utils import shuffle
from sklearn.linear_model import LogisticRegressionCV
from sklearn.metrics import classification_report

### Data from **Machine Learning Identifies Candidates for Drug Repurposing in Alzheimer's Disease**

https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE164788

Will be used as inputs for the trained model. Most differentially expressed gene lists will be compared to randomly generated gene lists.

In [2]:
data_treatments = pd.read_csv("GSE164788_normalized_counts.csv")

In [3]:
data_treatments

Unnamed: 0,gene_name,sample,count
0,A1BG,dge1_A01,4.204721
1,A1CF,dge1_A01,0.000000
2,A2M,dge1_A01,0.000000
3,A2ML1,dge1_A01,0.840944
4,A4GALT,dge1_A01,0.000000
...,...,...,...
12061867,ZXDC,dge2_P24,1.718219
12061868,ZYG11A,dge2_P24,3.436439
12061869,ZYG11B,dge2_P24,5.154658
12061870,ZYX,dge2_P24,6.872878


Features: gene names, samples (drug treatment and concentration), normalized RNA counts from RNA-seq

### Data from **Expression data from post mortem Alzheimer's disease brains**

https://www.ncbi.nlm.nih.gov/geo/query/acc.cgi?acc=GSE36980

GEO data imported through pipeline in R -> saved as csv files -> continued processing in Python

In [5]:
data_ad = pd.read_csv("data_ad.csv")
data_non_ad = pd.read_csv("data_non_ad.csv")
data_platform = pd.read_csv("data_platform.csv")

In [6]:
data_platform = data_platform[["ID", "gene_assignment"]]

In [7]:
gene_info = data_platform['gene_assignment'].str.split(' // ', 1, expand=True)[[1]]

In [8]:
gene = gene_info[1].str.split(' // ', 1, expand=True)[[0]]

In [9]:
data_platform = pd.concat([data_platform[["ID"]], gene], axis=1).dropna()

In [10]:
data_platform.columns = ["ID_REF", "gene_name"]

In [11]:
data_platform

Unnamed: 0,ID_REF,gene_name
1,7896738,OR4G2P
2,7896740,OR4F4
3,7896742,LOC728323
4,7896744,OR4F29
5,7896746,MT-TM
...,...,...
28864,8180111,RXRB
28865,8180123,VPS52
28866,8180144,RGL2
28867,8180166,TAPBP


In [12]:
data_ad = data_platform.merge(data_ad, on = "ID_REF", how = "inner").drop(["ID_REF"], axis = 1).sort_values("gene_name")

In [13]:
data_ad_genes = data_ad["gene_name"].tolist()

In [14]:
data_ad = data_ad.drop("gene_name", axis = 1)

In [15]:
data_ad.index = data_ad_genes

In [16]:
data_ad = data_ad.transpose()

In [17]:
data_ad[["AD"]] = 1

In [18]:
data_ad

Unnamed: 0,A1BG,A1CF,A2M,A2ML1,A3GALT2,A4GALT,A4GNT,AAAS,AACS,AADAC,...,ZXDA,ZXDB,ZXDC,ZXDC.1,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3,AD
792,7.96624,5.81419,9.70892,7.16942,6.51408,7.48697,5.80606,8.65507,9.46209,3.80073,...,9.62911,8.48052,8.98204,6.84168,6.01203,10.1674,9.80742,9.04197,9.5493,1
793,7.78709,6.05618,9.97573,7.43889,6.51426,7.70994,5.60887,8.5725,9.18525,3.37144,...,9.79006,8.52298,8.98228,6.73092,5.6583,9.81976,9.73623,8.80349,9.39174,1
794,7.81763,5.84426,10.9047,7.34895,6.65532,7.55734,5.56829,8.73213,9.40789,3.25177,...,9.35647,8.52178,8.90193,6.50045,5.59253,10.2355,9.76895,8.83038,9.38623,1
795,8.13618,6.04908,10.0327,6.63615,6.9053,7.94987,5.85624,8.35117,9.52586,3.48771,...,9.92919,8.60574,9.00606,6.96981,5.9437,9.91802,9.91226,8.80348,9.31199,1
796,8.07354,5.83877,11.3582,7.05852,6.71011,7.70073,5.66704,8.24658,9.15813,3.63294,...,9.74196,8.58287,8.99106,6.93115,5.55914,10.1234,10.0026,8.74404,9.38609,1
797,7.89438,6.05758,11.2578,7.25293,6.49411,7.84509,5.88339,8.57513,8.80024,3.7934,...,9.55649,8.15799,9.03498,6.85667,6.03614,9.52649,9.54855,8.7074,9.25519,1
798,7.94045,5.87928,11.0025,6.91588,6.72065,7.69354,5.63377,8.59376,9.47513,3.55845,...,9.55927,8.45047,9.08179,6.99131,5.91537,9.7016,9.813,8.97597,9.16524,1
799,7.96769,5.73791,10.6846,7.04073,6.52764,7.52497,5.63406,8.36354,9.43966,3.06018,...,9.57313,8.59194,8.99513,6.93191,5.85266,10.0987,9.77681,8.98655,9.39685,1
800,7.8263,5.76587,10.9889,7.3622,6.54795,7.5139,5.44143,8.52656,9.36753,3.26465,...,9.11827,8.39342,9.13759,6.89274,5.58874,10.0459,9.59356,8.93453,9.44252,1
801,7.92924,5.82255,9.11586,6.70876,6.7341,7.72401,5.63236,8.61012,9.39841,3.4618,...,9.71504,8.53373,8.87045,6.71998,5.684,10.1767,9.44181,8.8696,9.31358,1


Features: normalized counts from Affymetrix microarray, row names are sample IDs of AD tissue, columns are gene names

In [19]:
data_non_ad = data_platform.merge(data_non_ad, on = "ID_REF", how = "inner").drop(["ID_REF"], axis = 1).sort_values("gene_name")

In [20]:
data_non_ad_genes = data_non_ad["gene_name"].tolist()

In [21]:
data_non_ad = data_non_ad.drop("gene_name", axis = 1)

In [22]:
data_non_ad.index = data_non_ad_genes

In [23]:
data_non_ad = data_non_ad.transpose()

In [24]:
data_non_ad[["AD"]] = 0

In [25]:
data_non_ad

Unnamed: 0,A1BG,A1CF,A2M,A2ML1,A3GALT2,A4GALT,A4GNT,AAAS,AACS,AADAC,...,ZXDA,ZXDB,ZXDC,ZXDC.1,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3,AD
807,8.02615,5.74977,10.7661,7.02421,6.52383,7.49452,5.78059,8.5646,9.30087,3.67653,...,9.82848,8.56778,8.91728,6.85554,5.70159,10.2985,9.63687,8.84568,9.48304,0
808,8.13425,6.00997,11.0566,7.2891,6.79921,7.68277,5.82445,8.75197,9.48541,3.53126,...,9.88864,8.71116,9.13661,6.99588,5.80611,9.93258,9.70307,9.10864,9.43455,0
809,8.13383,5.84701,10.3378,7.06735,6.65486,7.73712,5.7371,8.72561,9.39478,3.29071,...,10.0911,8.75463,8.97811,6.92161,5.90811,10.0605,9.86397,8.83065,9.35658,0
810,8.1405,5.69996,10.231,7.11331,6.75592,7.57879,5.93671,8.52862,9.47191,3.66251,...,9.79375,8.50611,9.08523,6.82001,5.89166,10.2839,9.70793,8.93402,9.5653,0
811,7.78678,5.79046,10.572,6.47287,6.52379,7.27927,5.51362,8.40032,9.36997,3.18567,...,9.55614,8.65715,8.92187,6.74308,5.70151,10.3732,9.38336,8.809,9.71948,0
812,7.97243,5.8645,11.1779,7.17213,6.68574,7.51031,5.6259,8.52586,9.77563,3.8146,...,9.49517,8.41368,9.08575,7.01267,5.99378,9.92172,9.73643,8.89044,9.19186,0
813,7.76442,6.08599,10.2614,7.11016,6.55102,7.60013,5.59712,8.53054,9.59434,3.45778,...,9.30092,8.4961,8.88493,6.95741,5.9534,10.3064,9.74294,8.91555,9.36529,0
814,7.99574,6.0239,10.2982,6.81346,6.70091,7.65518,5.74789,8.46447,9.46181,3.84386,...,9.79907,8.83179,8.81901,6.72516,5.87278,10.3745,9.6358,8.78451,9.3054,0
815,7.78947,5.90262,11.871,6.92067,6.75427,7.6544,5.74288,8.56502,9.18488,3.35018,...,9.56516,8.48031,9.10166,6.34041,5.67535,9.92013,10.4288,8.70833,9.33149,0
816,7.75073,5.84965,10.3017,7.0234,6.65364,7.44928,5.45419,8.545,9.48812,3.355,...,9.58885,8.67606,8.93337,6.70914,5.97534,10.2419,9.83267,8.91534,9.23855,0


Features: normalized counts from Affymetrix microarray, row names are sample IDs of non-AD tissue, columns are gene names

In [36]:
data = pd.concat([data_ad, data_non_ad])

In [37]:
data = shuffle(data)

In [38]:
X = data.drop("AD", axis = 1)

In [39]:
y = data[["AD"]].values.flatten()

In [40]:
data

Unnamed: 0,A1BG,A1CF,A2M,A2ML1,A3GALT2,A4GALT,A4GNT,AAAS,AACS,AADAC,...,ZXDA,ZXDB,ZXDC,ZXDC.1,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3,AD
812,7.97243,5.86450,11.17790,7.17213,6.68574,7.51031,5.62590,8.52586,9.77563,3.81460,...,9.49517,8.41368,9.08575,7.01267,5.99378,9.92172,9.73643,8.89044,9.19186,0
823,7.75861,5.90429,8.99811,6.38328,6.74589,7.91683,5.70141,8.60434,9.28019,3.60985,...,9.56863,8.89361,9.02360,6.67531,5.66780,10.18030,9.83523,8.66083,9.41963,0
849,7.92667,5.81062,10.46130,7.08007,6.75429,7.37385,5.78590,8.76387,9.31409,2.98598,...,9.61889,8.72223,9.57071,7.00156,5.50150,10.32230,9.32383,8.91440,9.58261,0
836,7.74680,5.63290,10.77490,7.24241,6.71802,7.50321,5.48292,8.80156,9.40642,2.95792,...,9.45014,8.38797,9.23303,6.83409,5.44034,10.52100,9.54982,8.89388,9.74650,0
819,7.79413,5.82167,11.05090,6.73459,6.35388,7.61721,5.45341,8.55889,9.45785,3.74217,...,9.54045,8.54850,8.88678,6.67159,5.74930,10.22050,9.95074,9.09799,9.72213,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
818,7.70396,5.72859,11.04700,7.36176,6.55526,7.59661,5.41214,8.71350,9.33606,3.78922,...,9.67071,8.41037,9.01571,6.88975,5.83852,10.30090,9.79697,9.01822,9.48706,0
817,7.82954,5.70127,11.08630,7.28409,6.26765,7.54374,5.38357,8.61731,9.54965,3.18586,...,9.28835,8.55236,8.99323,6.85139,5.64595,10.23380,9.91210,9.06998,9.41443,0
847,7.81142,5.79635,10.38380,6.53037,6.97836,7.62315,5.80709,8.95384,9.58933,2.97836,...,9.65624,8.53238,9.29357,6.79219,5.46777,10.26770,9.81958,9.17279,9.64197,0
832,7.80219,5.86953,11.35590,6.89824,6.81467,7.83714,5.68698,8.84148,9.24835,3.08457,...,9.49469,8.64804,9.31855,6.95699,5.51119,10.25420,9.61903,9.00393,9.33654,1


In [41]:
X_train = X.iloc[15:, :]
X_test = X.iloc[:15, :]

y_train = y[15:]
y_test = y[:15]

In [42]:
X_train

Unnamed: 0,A1BG,A1CF,A2M,A2ML1,A3GALT2,A4GALT,A4GNT,AAAS,AACS,AADAC,...,ZWINT,ZXDA,ZXDB,ZXDC,ZXDC.1,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3
801,7.92924,5.82255,9.11586,6.70876,6.73410,7.72401,5.63236,8.61012,9.39841,3.46180,...,7.22634,9.71504,8.53373,8.87045,6.71998,5.68400,10.17670,9.44181,8.86960,9.31358
833,7.92100,5.72441,10.44290,6.52614,6.62071,7.61405,5.81759,8.57349,9.13741,2.82196,...,6.89277,9.87911,8.78080,9.14828,6.85675,5.51991,10.40610,9.47958,8.95714,9.71963
826,7.55784,5.77249,11.13160,7.23214,6.49752,7.64233,5.65273,8.94328,9.43865,3.03265,...,6.95936,9.42565,8.97898,9.22223,6.82039,5.64283,10.70340,9.50480,8.96762,9.92429
856,7.66434,6.21755,10.84400,7.09940,6.41047,7.38079,5.95650,8.96273,8.61018,3.59680,...,7.09434,9.26256,8.75391,9.44916,7.25978,5.82429,9.68041,9.07117,8.97729,9.81956
858,7.69497,6.13298,10.30670,7.02095,6.56698,7.62902,5.59535,8.39134,7.78271,3.86024,...,7.03825,9.30355,8.74910,9.06021,7.05573,5.43965,10.17190,9.49717,9.10534,9.61384
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
818,7.70396,5.72859,11.04700,7.36176,6.55526,7.59661,5.41214,8.71350,9.33606,3.78922,...,7.12195,9.67071,8.41037,9.01571,6.88975,5.83852,10.30090,9.79697,9.01822,9.48706
817,7.82954,5.70127,11.08630,7.28409,6.26765,7.54374,5.38357,8.61731,9.54965,3.18586,...,7.23461,9.28835,8.55236,8.99323,6.85139,5.64595,10.23380,9.91210,9.06998,9.41443
847,7.81142,5.79635,10.38380,6.53037,6.97836,7.62315,5.80709,8.95384,9.58933,2.97836,...,7.29903,9.65624,8.53238,9.29357,6.79219,5.46777,10.26770,9.81958,9.17279,9.64197
832,7.80219,5.86953,11.35590,6.89824,6.81467,7.83714,5.68698,8.84148,9.24835,3.08457,...,7.04663,9.49469,8.64804,9.31855,6.95699,5.51119,10.25420,9.61903,9.00393,9.33654


In [43]:
X_test

Unnamed: 0,A1BG,A1CF,A2M,A2ML1,A3GALT2,A4GALT,A4GNT,AAAS,AACS,AADAC,...,ZWINT,ZXDA,ZXDB,ZXDC,ZXDC.1,ZYG11A,ZYG11B,ZYX,ZZEF1,ZZZ3
812,7.97243,5.8645,11.1779,7.17213,6.68574,7.51031,5.6259,8.52586,9.77563,3.8146,...,7.27065,9.49517,8.41368,9.08575,7.01267,5.99378,9.92172,9.73643,8.89044,9.19186
823,7.75861,5.90429,8.99811,6.38328,6.74589,7.91683,5.70141,8.60434,9.28019,3.60985,...,7.19372,9.56863,8.89361,9.0236,6.67531,5.6678,10.1803,9.83523,8.66083,9.41963
849,7.92667,5.81062,10.4613,7.08007,6.75429,7.37385,5.7859,8.76387,9.31409,2.98598,...,7.07002,9.61889,8.72223,9.57071,7.00156,5.5015,10.3223,9.32383,8.9144,9.58261
836,7.7468,5.6329,10.7749,7.24241,6.71802,7.50321,5.48292,8.80156,9.40642,2.95792,...,7.00744,9.45014,8.38797,9.23303,6.83409,5.44034,10.521,9.54982,8.89388,9.7465
819,7.79413,5.82167,11.0509,6.73459,6.35388,7.61721,5.45341,8.55889,9.45785,3.74217,...,7.18326,9.54045,8.5485,8.88678,6.67159,5.7493,10.2205,9.95074,9.09799,9.72213
800,7.8263,5.76587,10.9889,7.3622,6.54795,7.5139,5.44143,8.52656,9.36753,3.26465,...,7.14124,9.11827,8.39342,9.13759,6.89274,5.58874,10.0459,9.59356,8.93453,9.44252
821,7.76722,5.58866,11.3484,6.80574,6.67476,7.56103,5.35727,8.37129,9.41421,3.12052,...,7.16346,9.32228,8.77451,8.99542,6.64253,5.80383,9.9117,10.1245,9.04579,9.1517
857,7.94139,6.39821,10.977,7.98247,6.64475,7.91133,5.71189,8.73875,8.70071,3.76634,...,7.17552,9.93567,8.58789,9.39396,7.28761,5.84564,9.53175,9.26043,9.01381,9.60276
850,7.71957,5.8368,10.3368,6.93213,6.60016,7.53139,5.71384,8.51584,9.38061,3.37167,...,7.00744,9.4802,8.78541,9.25338,6.82307,5.52976,10.4377,9.51859,8.82372,9.82282
863,7.5707,5.96271,10.7944,6.27664,6.56874,7.48597,5.70317,8.63163,9.43466,3.25902,...,7.11466,9.10395,8.70903,9.15704,7.16399,5.82269,10.6067,8.80216,8.92054,9.48334


In [44]:
y_train

array([1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 1, 1,
       0, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1],
      dtype=int64)

In [45]:
y_test

array([0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 1, 0], dtype=int64)

### Logistic regression with cross-validation

In [46]:
clf = LogisticRegressionCV(cv=5, random_state=0, max_iter=1000).fit(X_train, y_train)

In [47]:
print(classification_report(y_test, clf.predict(X_test)))

              precision    recall  f1-score   support

           0       1.00      0.91      0.95        11
           1       0.80      1.00      0.89         4

    accuracy                           0.93        15
   macro avg       0.90      0.95      0.92        15
weighted avg       0.95      0.93      0.94        15

