In [1]:
import os
import random

import numpy as np
import pandas as pd
import torch
from tqdm import tqdm

# Reproducibility: seed PyTorch, the stdlib RNG and NumPy with the same value.
SEED = 0
for seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    seed_fn(SEED)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
# Enable IPython's autoreload extension; mode 2 re-imports modules before each
# cell runs, so edits to the DeepDSC package are picked up without a restart.
%load_ext autoreload
%autoreload 2

# NOTE(review): star import. The names main() uses that are not imported
# elsewhere in the visible cells (prepare_data, GeneExpressionDataset,
# DataLoader, AE, DF, train_autoencoder, train_df_model, evaluate_model, ...)
# presumably all come from here - confirm, and prefer explicit imports.
from DeepDSC.DeepDSC import *

In [5]:
def main(PATH):
    """Run one end-to-end DeepDSC experiment and return its test metrics.

    The pipeline is:
      1. load and normalise the gene-expression matrix,
      2. train an autoencoder on it and extract the encoder's features,
      3. build drug fingerprints from the drug-response data,
      4. load the pre-split train/val/test tables and labels,
      5. train the DF model and evaluate it on the test split.

    Args:
        PATH: Directory holding ``gene_exp_part1.csv.gz``,
            ``gene_exp_part2.csv.gz``, ``{train,val,test}.csv`` and
            ``{train,val,test}_labels.npy``. A trailing path separator is
            optional; file names are joined with ``os.path.join``.

    Returns:
        Whatever ``print_binary_classification_metrics`` returns for the test
        split (the calling cell feeds it to ``pd.concat``).
    """

    def data_file(name):
        # Build every path the same way so PATH works with or without a
        # trailing separator (the old code mixed `PATH + "/x"` and `PATH + "x"`).
        return os.path.join(PATH, name)

    normalized_gene_exp_tensor, gene_exp = prepare_data(
        data1=data_file("gene_exp_part1.csv.gz"),
        data2=data_file("gene_exp_part2.csv.gz"),
    )
    normalized_gene_exp_dataset = GeneExpressionDataset(normalized_gene_exp_tensor)
    normalized_gene_exp_dataloader = DataLoader(
        normalized_gene_exp_dataset, batch_size=10000, shuffle=True
    )

    # Train the autoencoder.
    autoencoder = AE(normalized_gene_exp_tensor.shape[1]).to(device)
    train_autoencoder(autoencoder, normalized_gene_exp_dataloader)

    # Extract the compressed features. This is inference only, so skip
    # building an autograd graph for the whole expression matrix.
    # NOTE(review): if AE contains dropout/batch-norm layers, the encoder
    # should also be in eval() mode here - confirm against the AE definition.
    with torch.no_grad():
        compressed_features_tensor = autoencoder.encoder(normalized_gene_exp_tensor)
    compressed_features = pd.DataFrame(
        compressed_features_tensor.cpu().detach().numpy(), index=gene_exp.columns
    )

    # Prepare the drug-response data and the Morgan fingerprints.
    drug_response, nsc_sm = prepare_drug_data(is_nsc=False, is_gdsc=True, is_1=False)
    mfp = calculate_morgan_fingerprints(drug_response, nsc_sm)

    train_data = pd.read_csv(data_file("train.csv"))
    val_data = pd.read_csv(data_file("val.csv"))
    test_data = pd.read_csv(data_file("test.csv"))

    train_labels = torch.tensor(np.load(data_file("train_labels.npy"))).to(device)
    val_labels = torch.tensor(np.load(data_file("val_labels.npy"))).to(device)
    test_labels = torch.tensor(np.load(data_file("test_labels.npy"))).to(device)

    # Prepare the training, validation and test inputs.
    train_data, val_data, test_data = prepare_train_val_test_data(
        train_data, val_data, test_data, compressed_features, mfp
    )

    # Train the DF model and score it on the held-out test split.
    df_model = DF().to(device)
    train_df_model(df_model, train_data, val_data, train_labels, val_labels)
    test_results = evaluate_model(df_model, test_data, test_labels)
    return print_binary_classification_metrics(test_labels, test_results)

In [6]:
PATH = "../gdsc2_data/"
N_RUNS = 5  # number of repeated end-to-end runs

# Collect each run's metrics in a list and concatenate once: growing a
# DataFrame with pd.concat inside the loop is quadratic and starts from an
# empty frame. ignore_index=True numbers the rows 0..N_RUNS-1 instead of
# leaving every row with index 0.
# NOTE(review): the recorded output shows identical metrics for all five runs,
# so these repetitions do not measure run-to-run variance - confirm whether
# the DeepDSC helpers reseed their RNGs internally.
run_metrics = [main(PATH) for _ in tqdm(range(N_RUNS), desc="runs")]
res = pd.concat(run_metrics, ignore_index=True)

100%|██████████| 5/5 [06:05<00:00, 73.19s/it]


In [7]:
# Show the metrics gathered from the repeated calls to main().
res

Unnamed: 0,Accuracy,Precision,Recall,F1 Score,AUROC,AUPR
0,0.641304,0.716287,0.467963,0.56609,0.58861,0.622377
0,0.641304,0.716287,0.467963,0.56609,0.58861,0.622377
0,0.641304,0.716287,0.467963,0.56609,0.58861,0.622377
0,0.641304,0.716287,0.467963,0.56609,0.58861,0.622377
0,0.641304,0.716287,0.467963,0.56609,0.58861,0.622377
