In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the local ms2_model / ms2_dataset files) are reloaded before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import itertools
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader, Dataset, SubsetRandomSampler

from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, precision_score, recall_score, roc_auc_score, confusion_matrix
from sklearn.metrics import RocCurveDisplay
from sklearn.model_selection import KFold

import ms2_model
from ms2_model import Net256
from ms2_dataset import EmbedVec256Dataset

import matplotlib.pyplot as plt
import seaborn

In [3]:
# Seed PyTorch's global random number generator for reproducibility.
torch.manual_seed(764)

<torch._C.Generator at 0x148f7f867250>

In [4]:
def balance_train_data(df, random_state):
  """Downsample the DMSO rows of ``df`` to the mean size of the other MoA classes.

  The target DMSO row count is the mean number of non-null
  ``Metadata_Compound`` values per non-DMSO ``Metadata_MoA`` group, rounded
  to the nearest integer. If fewer DMSO rows exist than the target, all of
  them are kept instead of raising. If there are no DMSO rows, or no other
  MoA groups to balance against, the data is returned unchanged (apart from
  the index reset).

  Args:
    df: DataFrame with ``Metadata_MoA`` and ``Metadata_Compound`` columns.
    random_state: Seed forwarded to ``DataFrame.sample`` for the DMSO draw.

  Returns:
    A new DataFrame holding every non-DMSO row followed by the sampled DMSO
    rows, with a fresh 0..n-1 index.
  """
  # Select by boolean mask rather than by index label, so a non-unique index
  # cannot cause unrelated rows to be dropped.
  is_dmso = df["Metadata_MoA"] == "DMSO"
  df_dmso = df[is_dmso]
  df_other = df[~is_dmso]

  # Per-class row counts, excluding the DMSO class itself.
  moa_counts = (
      df.groupby("Metadata_MoA")["Metadata_Compound"]
      .count()
      .drop("DMSO", errors="ignore")
  )
  if df_dmso.empty or moa_counts.empty:
    # Nothing to downsample, or nothing to balance against.
    return df.reset_index(drop=True)

  mean_count = int(moa_counts.mean().round())
  # Never request more rows than exist; sample() without replacement would raise.
  n_dmso = min(mean_count, len(df_dmso))
  df_dmso = df_dmso.sample(n=n_dmso, random_state=random_state)

  return pd.concat([df_other, df_dmso], axis=0).reset_index(drop=True)

In [5]:
# Root directory of the input data. "~" is expanded explicitly so the path
# also works with APIs that do not expand the home directory themselves.
data_dir = os.path.expanduser("~/siads696/data")

# Seed used for the DMSO downsampling and the K-fold split; number of CV folds.
random_state = 764
cv_splits = 5

# Default training hyperparameters (a later cell overrides some of these).
learning_rate = 0.001
n_epochs = 14
batch_size = 32
chunk_print = 10

# Use a CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [6]:
# Load the embedding table (the file name suggests batch-corrected 256-d image
# embeddings). An alternative input used during development was
# "well_grouped_256.parquet" in the same directory.
df_data = pd.read_parquet(
    os.path.join(data_dir, "bbbc021_image_embed_batchcorr_256.parquet")
)

print(f"Embedding vector dataset shape: {df_data.shape}")
# Number of rows that carry a non-null MoA annotation.
print(f"Embedding vectors MoA assigned: {df_data['Metadata_MoA'].notnull().sum()}")

Embedding vector dataset shape: (13200, 265)
Embedding vectors MoA assigned: 6160


In [7]:
# Compound-level table whose "in_testset" column marks the held-out compounds.
df_train_test = pd.read_csv(
    os.path.join(data_dir, "compound_moas_trainVtest.csv")
)
# Partition the compound table on the "in_testset" flag.
df_train = df_train_test.loc[~df_train_test["in_testset"]]
df_test = df_train_test.loc[df_train_test["in_testset"]]

In [8]:
# Report the size of the compound table and of each partition.
for split_label, split_df in (
    ("Train/test set", df_train_test),
    ("Training set", df_train),
    ("Test set", df_test),
):
    print(f"{split_label} shape: {split_df.shape}")

# List the distinct MoAs, then the distinct compounds, of each partition.
for title, column in (("MoA", "MoA"), ("Compounds", "Compound")):
    for split_name, split_df in (("training", df_train), ("test", df_test)):
        print(f"{title} in {split_name} set:\n {split_df[column].unique().tolist()}")

Train/test set shape: (39, 4)
Training set shape: (29, 4)
Test set shape: (10, 4)
MoA in training set:
 ['Actin disruptors', 'Aurora kinase inhibitors', 'Cholesterol-lowering', 'DMSO', 'DNA damage', 'DNA replication', 'Eg5 inhibitors', 'Epithelial', 'Kinase inhibitors', 'Microtubule destabilizers', 'Microtubule stabilizers', 'Protein degradation', 'Protein synthesis']
MoA in test set:
 ['Actin disruptors', 'Aurora kinase inhibitors', 'DNA damage', 'DNA replication', 'Epithelial', 'Kinase inhibitors', 'Microtubule destabilizers', 'Microtubule stabilizers', 'Protein degradation', 'Protein synthesis']
Compounds in training set:
 ['cytochalasin B', 'cytochalasin D', 'AZ-A', 'AZ258', 'mevinolin/lovastatin', 'simvastatin', 'DMSO', 'chlorambucil', 'cisplatin', 'etoposide', 'camptothecin', 'floxuridine', 'methotrexate', 'AZ-C', 'AZ138', 'AZ-J', 'AZ-U', 'PD-169316', 'alsterpaullone', 'colchicine', 'demecolcine', 'nocodazole', 'docetaxel', 'epothilone B', 'ALLN', 'MG-132', 'lactacystin', 'anisom

In [9]:
# Names of the embedding feature columns (those whose name starts with "PC").
data_cols = [col for col in df_data.columns if col.startswith("PC")]

# Keep only the embedding rows whose compound belongs to the training partition.
df_data_train = pd.merge(
    df_data,
    df_train,
    how="inner",
    left_on="Metadata_Compound",
    right_on="Compound",
)
print(f"Training data shape: {df_data_train.shape}")

# Downsample DMSO so it does not dominate the training classes.
df_data_train = balance_train_data(df_data_train, random_state)
print(f"Training data balanced shape: {df_data_train.shape}")

# Same join for the held-out compounds; the test rows are not rebalanced.
df_data_test = pd.merge(
    df_data,
    df_test,
    how="inner",
    left_on="Metadata_Compound",
    right_on="Compound",
)

Training data shape: (3944, 269)
Training data balanced shape: (2843, 269)


In [10]:
# Row counts of the test frame and of training + test combined.
print(f"Test data shape: {df_data_test.shape}")
print(f"Total embedding vectors in training/test set: {len(df_data_train) + len(df_data_test)}")

Test data shape: (2216, 269)
Total embedding vectors in training/test set: 5059


In [11]:
# Distinct non-null MoA names, in order of first appearance in df_data.
# The integer label of each MoA therefore depends on the row order of the input file.
moa_list = df_data["Metadata_MoA"].dropna().unique().tolist()
# Map each MoA name to its integer class index.
moa_dict = {moa: idx for idx, moa in enumerate(moa_list)}
n_classes = len(moa_dict)

In [12]:
# Show the label mapping and the resulting number of classes.
print(
    f"MoA label dictionary:\n{moa_dict}",
    f"Number of classes: {n_classes}",
    sep="\n",
)

MoA label dictionary:
{'Protein degradation': 0, 'Kinase inhibitors': 1, 'Protein synthesis': 2, 'DNA replication': 3, 'DNA damage': 4, 'Microtubule destabilizers': 5, 'Actin disruptors': 6, 'Microtubule stabilizers': 7, 'Cholesterol-lowering': 8, 'Epithelial': 9, 'Eg5 inhibitors': 10, 'Aurora kinase inhibitors': 11, 'DMSO': 12}
Number of classes: 13


In [13]:
# Wrap the training frame in the project dataset class and shuffle its batches.
train_dataset = EmbedVec256Dataset(df_data_train, "Metadata_MoA", "PC", moa_dict)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)

# The test loader is not shuffled, so batches follow the dataset's own order.
test_dataset = EmbedVec256Dataset(df_data_test, "Metadata_MoA", "PC", moa_dict)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [14]:
# Instantiate the project network for n_classes outputs.
# NOTE(review): 256 and 52 are presumably the input and hidden widths —
# confirm against ms2_model.Net256. The model is not moved to `device` here.
model = Net256(n_classes, 256, 52)

In [15]:
# Cross-entropy loss for the multiclass MoA classification task.
loss_func = nn.CrossEntropyLoss()

# Adam optimizer over all model parameters, using the multi-tensor (foreach) implementation.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    foreach=True,
)



In [95]:
# print("Begin Training")

# for epoch in range(n_epochs):
#     running_loss = 0.0
#     loss, accuracy = ms2_model.train_model(model, optimizer, loss_func, train_dataloader)
#     print(f"Epoch {epoch+1}  Loss: {loss:>5f}  Accuracy: {accuracy:>5f}")
# print("Stop Training")

In [96]:
# y_label = df_data_test["Metadata_MoA"].map(moa_dict).tolist()

# avg_loss, yhat_label = ms2_model.test_model(model, loss_func, test_dataloader)

# print(avg_loss)
# print(accuracy_score(y_label, yhat_label))
# print(f1_score(y_label, yhat_label, average="weighted"))
# print(precision_score(y_label, yhat_label, average=None))
# print(recall_score(y_label, yhat_label, zero_division=np.nan, average=None))

In [97]:
def enumerate_params(param_dict):
    """Expand a parameter grid into every combination of its values.

    Args:
        param_dict: Mapping of parameter name to an iterable of candidate values.

    Returns:
        A pair ``(names, combinations)``: ``names`` is the list of parameter
        names in dictionary order, and ``combinations`` is the list of value
        tuples (Cartesian product) in that same order.
    """
    names = [*param_dict]
    value_lists = [param_dict[name] for name in names]
    return names, list(itertools.product(*value_lists))

In [98]:
def do_kfold_crossval(cv_splits, dataset, random_state, parameters, verbose=True):
    """Cross-validate a freshly initialised ``Net256`` with a shuffled K-fold split.

    Reads the notebook-level names ``n_classes`` and ``device``. The folds are
    drawn over individual dataset rows. For each fold a new model is trained
    for the requested number of epochs and only the final epoch's metrics are
    kept.

    Args:
        cv_splits: Number of folds.
        dataset: Dataset indexed by the per-fold samplers.
        random_state: Seed for the shuffled ``KFold`` split.
        parameters: Sequence read positionally as
            (learning rate, batch size, number of epochs).
        verbose: When true, print per-fold metrics and the summary.

    Returns:
        A list of eight values: mean and standard deviation across folds of
        the training loss, training accuracy, validation loss and validation
        accuracy, in that order.
    """
    splitter = KFold(n_splits=cv_splits, shuffle=True, random_state=random_state)
    # One row per fold: train loss, train accuracy, valid loss, valid accuracy.
    fold_metrics = np.zeros((cv_splits, 4))

    lr = parameters[0]
    batch_size = parameters[1]
    n_epochs = parameters[2]

    row_indices = np.arange(len(dataset))
    for fold, (train_idx, valid_idx) in enumerate(splitter.split(row_indices)):
        train_dataloader = DataLoader(
            dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_idx)
        )
        valid_dataloader = DataLoader(
            dataset, batch_size=batch_size, sampler=SubsetRandomSampler(valid_idx)
        )
        n_train = len(train_dataloader.sampler)
        n_valid = len(valid_dataloader.sampler)

        # NOTE(review): other cells build Net256(n_classes, 256, 52); here the
        # extra arguments are left to Net256's defaults — confirm they match.
        model = Net256(n_classes).to(device)
        optimizer = torch.optim.Adam(model.parameters(), foreach=True, lr=lr)
        loss_func = nn.CrossEntropyLoss()

        epoch_metrics = np.zeros((n_epochs, 4))
        for epoch in range(n_epochs):
            train_loss, train_acc = ms2_model.train_model(
                model, optimizer, loss_func, train_dataloader, device
            )
            valid_loss, valid_acc = ms2_model.valid_model(
                model, loss_func, valid_dataloader, device
            )
            # Divide the returned totals by the number of samples in the fold part.
            tl, ta = train_loss / n_train, train_acc / n_train
            vl, va = valid_loss / n_valid, valid_acc / n_valid
            for col, value in enumerate((tl, ta, vl, va)):
                epoch_metrics[epoch, col] = value

        # Keep only the last epoch's metrics for this fold.
        fold_metrics[fold] = epoch_metrics[-1]

        if verbose:
            print(f"Fold {fold+1}: Train loss: {tl:>5f}  Train accuracy: {ta:>5f}  Valid loss: {vl:>5f}  Valid accuracy: {va:>5f}")

    # Interleave mean and std for each metric column, in column order.
    summary = []
    for col in range(4):
        summary.append(fold_metrics[:, col].mean())
        summary.append(fold_metrics[:, col].std())

    if verbose:
        print(f"Training - loss mean: {summary[0]:>5f}  loss std: {summary[1]:>5f}")
        print(f"Training - accuracy mean: {summary[2]:>5f}  accuracy std: {summary[3]:>5f}")
        print(f"Validation - loss mean: {summary[4]:>5f}  loss std: {summary[5]:>5f}")
        print(f"Validation - accuracy mean: {summary[6]:>5f}  accuracy std: {summary[7]:>5f}")

    return summary

In [99]:
%%time
# Hyperparameter grid. Key order matters: do_kfold_crossval reads each
# combination positionally as (learning_rate, batch_size, epochs).
param_dict = {"learning_rate": [0.0001],
              "batch_size": [16],
              "epochs": [12]}

parameters, combinations = enumerate_params(param_dict)

# Names of the eight summary values returned by do_kfold_crossval, in order.
value_cols = ["training_loss mean",
              "training_loss std",
              "training_acc mean",
              "training_acc std",
              "validation_loss mean",
              "validation_loss std",
              "validation_acc mean",
              "validation_acc std"]
# Build the column list as a new list so `parameters` itself is not mutated.
result_cols = [*parameters, *value_cols]

# One record per parameter combination: the parameter values followed by the CV summary.
gridsearch_records = []
for combination in combinations:
    print(f"learning rate: {combination[0]}  batch_size: {combination[1]}  epochs: {combination[2]}")
    grid_results = do_kfold_crossval(cv_splits, train_dataset, random_state, combination, verbose=True)
    gridsearch_records.append((*combination, *grid_results))

learning rate: 0.0001  batch_size: 8  epochs: 12
Fold 1: Train loss: 0.901042  Train accuracy: 0.798593  Valid loss: 0.910947  Valid accuracy: 0.778559
Fold 2: Train loss: 0.934162  Train accuracy: 0.799033  Valid loss: 0.791910  Valid accuracy: 0.824253
Fold 3: Train loss: 0.898099  Train accuracy: 0.802111  Valid loss: 0.872259  Valid accuracy: 0.785589
Fold 4: Train loss: 0.923677  Train accuracy: 0.795604  Valid loss: 0.933004  Valid accuracy: 0.783451
Fold 5: Train loss: 0.899785  Train accuracy: 0.806593  Valid loss: 0.852362  Valid accuracy: 0.804577
Training - loss mean: 0.911353  loss std: 0.014751
Training - accuracy mean: 0.800387  accuracy std: 0.003726
Validation - loss mean: 0.872097  loss std: 0.049068
Validation - accuracy mean: 0.795286  accuracy std: 0.016967
CPU times: user 1min 57s, sys: 427 ms, total: 1min 57s
Wall time: 1min 59s


In [100]:
# Collect the grid-search records into a table, one row per parameter combination.
df_gridsearch = pd.DataFrame.from_records(gridsearch_records, columns=result_cols)
# Optional: persist the results next to the input data.
# df_gridsearch.to_parquet(os.path.join(data_dir, "ann_Net256_cv_results.parquet"))
df_gridsearch

Unnamed: 0,learning_rate,batch_size,epochs,trianing_loss mean,training_loss std,training_acc mean,training_acc std,validation_loss mean,validation_loss std,validaton_acc mean,validation_acc std
0,0.0001,8,12,0.911353,0.014751,0.800387,0.003726,0.872097,0.049068,0.795286,0.016967


In [47]:
# NOTE(review): this cell repeats the preceding one and carries an older,
# out-of-order execution count; its stored output comes from a different
# parameter grid than the one defined above. Consider removing it.
df_gridsearch = pd.DataFrame.from_records(gridsearch_records, columns=result_cols)
df_gridsearch

Unnamed: 0,learning_rate,batch_size,epochs,trianing_loss mean,training_loss std,training_acc mean,training_acc std,validation_loss mean,validation_loss std,validaton_acc mean,validation_acc std
0,0.05,32,12,0.232963,0.071239,0.928689,0.023166,0.671163,0.102071,0.832558,0.023664
1,0.05,32,14,0.217851,0.047445,0.935106,0.013504,0.680141,0.128442,0.847695,0.019913
2,0.01,32,12,0.076482,0.01462,0.97705,0.004677,0.620664,0.103051,0.856831,0.014711
3,0.01,32,14,0.060025,0.022297,0.981709,0.007958,0.599271,0.070952,0.857891,0.02138


In [139]:
class NNTest(nn.Module):
    """Small feed-forward classifier: Linear -> BatchNorm1d -> ReLU -> Linear.

    The defaults reproduce the original fixed architecture (256 -> 65 -> 13).

    Args:
        n_features: Width of the input vectors.
        n_hidden: Width of the single hidden layer.
        n_classes: Number of output logits.
    """

    def __init__(self, n_features=256, n_hidden=65, n_classes=13):
        super().__init__()
        self.linear_stack = nn.Sequential(
            nn.Linear(n_features, n_hidden),
            nn.BatchNorm1d(n_hidden),
            nn.ReLU(),
            nn.Linear(n_hidden, n_classes),
        )

    def forward(self, x):
        """Return the raw (unnormalised) class logits for the batch ``x``."""
        return self.linear_stack(x)

In [46]:
# Train a fresh Net256 on the whole training set, overriding the default
# hyperparameters from the configuration cell.
learning_rate = 0.0005
n_epochs = 4
batch_size = 8
chunk_print = 10
# Re-seed so this run does not depend on how much randomness earlier cells consumed.
torch.manual_seed(764)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Place the model on `device`, as do_kfold_crossval does, because the same
# `device` is handed to train_model / test_model below.
test_model = Net256(n_classes, 256, 52).to(device)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(test_model.parameters(), foreach=True, lr=learning_rate)

for epoch in range(n_epochs):
    loss, accuracy = ms2_model.train_model(test_model, optimizer, loss_func, train_dataloader, device)

# Only the final epoch's values are reported, divided by the training-set size.
train_loss = loss / len(train_dataloader.dataset)
train_acc = accuracy / len(train_dataloader.dataset)
print(f"Loss: {train_loss:>5f}  Accuracy: {train_acc:>5f}")
# Optional: persist the trained model.
# torch.save(test_model, "../models/kc_nn_Net256_test.pt")

Loss: 0.618468  Accuracy: 0.845234


In [47]:
# Integer ground-truth labels in df_data_test row order. The test loader was
# built with shuffle=False; the predictions are presumably returned in that
# same order — confirm against ms2_model.test_model.
y_label = df_data_test["Metadata_MoA"].map(moa_dict).tolist()
test_loss, yhat_label = ms2_model.test_model(
    test_model, loss_func, test_dataloader, device
)

# Test loss divided by the number of test samples, then the overall accuracy.
print(test_loss / len(test_dataloader.dataset))
print(accuracy_score(y_true=y_label, y_pred=yhat_label))

0.9815846410284108
0.7802346570397112


In [156]:
# Evaluate a previously saved model on the same test set.
# SECURITY NOTE: torch.load unpickles the file, and loading a whole pickled
# module can execute arbitrary code — only load checkpoints from a trusted source.
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
# NOTE(review): the stored output for this cell shows accuracy far below the
# freshly trained model; possibly the checkpoint was trained with a different
# label mapping or input table — confirm before relying on it.
trained_model = torch.load("../models/kc_nn_Net256.pt", map_location=device).to(device)
test_loss, yhat_label = ms2_model.test_model(trained_model, loss_func, test_dataloader, device)

print(test_loss / len(test_dataloader.dataset))
print(accuracy_score(y_label, yhat_label))

7.32508356390447
0.032490974729241874
