In [1]:
import argparse
import time

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import (accuracy_score, average_precision_score, f1_score,
                             precision_score, recall_score, roc_auc_score)
from sklearn.model_selection import KFold

In [2]:
%load_ext autoreload
%autoreload 2

from data_load import data_load
from data_process import process
from load_data import load_data
from model import *
from my_utiils import *
from sampler import Sampler



OSError: /gpfs/gsfs12/users/inouey2/conda/envs/genex/lib/python3.10/site-packages/torch_scatter/_scatter_cuda.so: undefined symbol: _ZN2at23SavedTensorDefaultHooks11set_tracingEb

In [None]:
# ---- Training hyper-parameters -------------------------------------------
# Loss weights: alph scales the `dgi_pos` term and beta the `dgi_neg` term in
# train(); the masked BCE term gets the remaining weight (1 - alph - beta).
alph, beta = 0.30, 0.30

# Number of training epochs run for each cross-validation fold.
epoch = 350

# Layer widths handed to GraphCDR / Encoder / Summary / NodeRepresentation.
hidden_channels, output_channels = 256, 100

In [None]:
# Dataset identifier; passed to both loaders and used later to pick model options.
data = "gdsc1"

# Drug/omics features come from data_load; the response table `res`, the
# expression data, `null_mask` (handed to Sampler) and `pos_num` (the number of
# items split by KFold) come from load_data.
drug_feature, exp, mutation, methylation, nb_celllines, nb_drugs = data_load(data)
res, exprs, null_mask, pos_num = load_data(data)

# Positional index -> identifier lookups: rows of `res` are cell lines,
# columns of `res` are drugs.
cells = dict(enumerate(res.index))
drugs = dict(enumerate(res.columns))

# The counts returned by data_load are replaced by the dimensions of `res`.
nb_celllines, nb_drugs = len(cells), len(drugs)

In [None]:
def train():
    """Run one training epoch over the paired drug / cell-line batches.

    Reads notebook globals assigned by the cross-validation cell (``model``,
    ``optimizer``, ``myloss``, ``drug_set``, ``cellline_set``, ``train_edge``,
    ``label_pos``, ``train_mask``) as well as the loss weights ``alph`` and
    ``beta``. Each optimisation step minimises

        (1 - alph - beta) * myloss(train entries) + alph * dgi_pos + beta * dgi_neg

    where ``dgi_pos`` / ``dgi_neg`` are the two ``model.loss`` terms. The loss
    summed over all steps is printed at the end.
    """
    model.train()
    running_loss = 0
    print("Training batch:", end=" ")
    for step, (drug_batch, omics) in enumerate(zip(drug_set, cellline_set), start=1):
        print(f"{step}", end=" ")
        optimizer.zero_grad()

        # Optional views: index 1 feeds mutation_data, index 2 methylation_data.
        n_views = len(omics)
        mutation_view = omics[1] if n_views > 1 else None
        methylation_view = omics[2] if n_views > 2 else None

        pos_z, neg_z, summary_pos, summary_neg, pos_adj = model(
            drug_feature=drug_batch.x,
            drug_adj=drug_batch.edge_index,
            ibatch=drug_batch.batch,
            gexpr_data=omics[0],
            mutation_data=mutation_view,
            methylation_data=methylation_view,
            edge=train_edge,
        )
        dgi_pos = model.loss(pos_z, neg_z, summary_pos)
        dgi_neg = model.loss(neg_z, pos_z, summary_neg)

        # Supervised term only over the training entries of the matrix.
        bce_term = myloss(pos_adj[train_mask], label_pos[train_mask])
        bce_weight = 1 - alph - beta
        loss = bce_weight * bce_term + alph * dgi_pos + beta * dgi_neg

        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print("\nTrain loss: ", str(round(running_loss, 4)))

In [None]:
def test(threshold=0.5):
    """Evaluate the current ``model`` on the held-out (``test_mask``) entries.

    Reads notebook globals assigned by the cross-validation cell: ``model``,
    ``drug_set``, ``cellline_set``, ``train_edge``, ``label_pos`` and
    ``test_mask``. The model is called with ``edge=train_edge`` and its score
    matrix is read out at the ``test_mask`` positions.

    Args:
        threshold: score cut-off used to binarise predictions for accuracy,
            precision, recall and F1 (AUC/AUPR use the raw scores).
            Defaults to 0.5, the previously hard-coded value.

    Returns:
        ``(AUC, AUPR, f1, accuracy, ytest, yp)`` for the last batch evaluated.
        ``ytest`` is ``label_pos[test_mask]`` exactly as indexed, and ``yp`` is a
        NumPy array of predicted scores.

    Raises:
        RuntimeError: if ``zip(drug_set, cellline_set)`` yields no batches.
    """
    model.eval()
    print("Testing...")
    result = None
    with torch.no_grad():
        for drug, cell in zip(drug_set, cellline_set):
            _, _, _, _, pre_adj = model(
                drug_feature=drug.x,
                drug_adj=drug.edge_index,
                ibatch=drug.batch,
                gexpr_data=cell[0],
                mutation_data=cell[1] if len(cell) > 1 else None,
                methylation_data=cell[2] if len(cell) > 2 else None,
                edge=train_edge,
            )

            # .cpu() first so the conversion also works if the model is on a GPU.
            yp = pre_adj[test_mask].detach().cpu().numpy()
            ytest = label_pos[test_mask]
            # Host-side copy for sklearn; `ytest` itself is returned unchanged.
            y_true = (
                ytest.detach().cpu().numpy()
                if torch.is_tensor(ytest)
                else np.asarray(ytest)
            )

            # Threshold-free ranking metrics on the raw scores.
            AUC = roc_auc_score(y_true, yp)
            AUPR = average_precision_score(y_true, yp)

            # Binarise the scores for the threshold-dependent metrics.
            yp_binary = (yp > threshold).astype(int)

            # zero_division=0 gives the same 0.0 that sklearn's default returns,
            # without emitting UndefinedMetricWarning when a ratio is undefined
            # (e.g. no positive predictions).
            accuracy = accuracy_score(y_true, yp_binary)
            precision = precision_score(y_true, yp_binary, zero_division=0)
            recall = recall_score(y_true, yp_binary, zero_division=0)
            f1 = f1_score(y_true, yp_binary, zero_division=0)

            print("Test metrics:")
            print(f"  Accuracy: {accuracy:.4f}")
            print(f"  Precision: {precision:.4f}")
            print(f"  Recall: {recall:.4f}")
            print(f"  F1: {f1:.4f}")
            print(f"  AUC: {AUC:.4f}")
            print(f"  AUPR: {AUPR:.4f}")

            result = (AUC, AUPR, f1, accuracy, ytest, yp)

    if result is None:
        # Without this guard an empty loader surfaced as an UnboundLocalError on `AUC`.
        raise RuntimeError("test(): drug_set/cellline_set yielded no batches")
    return result

In [None]:
# Mutation-feature width (`dim_mut`, passed to NodeRepresentation) per supported dataset.
_MUTATION_DIMS = {"nci": 510, "gdsc1": 1020, "gdsc2": 1020, "ctrp": 1020}


def _omics_config(dataset):
    """Return ``(dim_mut, use_mutation, use_methylation)`` for ``dataset``.

    Raises:
        NotImplementedError: if ``dataset`` is not a key of ``_MUTATION_DIMS``.
            The original ``else`` branch only evaluated the exception class
            and never raised it.
    """
    if dataset not in _MUTATION_DIMS:
        raise NotImplementedError(f"Unsupported dataset: {dataset!r}")
    # ctrp is run without the mutation and methylation views.
    extra_omics = dataset != "ctrp"
    return _MUTATION_DIMS[dataset], extra_omics, extra_omics


def _mask_to_edge_frame(labels, mask):
    """Flatten the masked entries of a (cell line x drug) matrix into a long table.

    Args:
        labels: matrix shaped like ``res`` holding the label of every pair.
        mask: matrix shaped like ``res``; its non-zero entries are kept.

    Returns:
        DataFrame with columns ``Cell``, ``Drug``, ``labels`` (int). It has one
        row per non-zero mask entry, in row-major order. Matrix positions are
        mapped to identifiers through the ``cells`` / ``drugs`` lookups.
    """
    # Wrapping with res' index/columns validates the shape, as before.
    label_arr = pd.DataFrame(labels, index=res.index, columns=res.columns).to_numpy()
    mask_arr = pd.DataFrame(mask, index=res.index, columns=res.columns).to_numpy()
    rows, cols = mask_arr.nonzero()
    return pd.DataFrame(
        {
            "Cell": [cells[r] for r in rows.tolist()],
            "Drug": [drugs[c] for c in cols.tolist()],
            "labels": label_arr[rows, cols].astype(int),
        }
    )


# Seed shared by the fold split and the per-fold torch RNG so reruns are reproducible.
cv_seed = 42
kfold = KFold(n_splits=5, shuffle=True, random_state=cv_seed)
# Loop-invariant and validated up front, before any fold is trained.
dim_mut, use_mutation, use_methylation = _omics_config(data)

# Per-fold label / prediction columns; concatenated once after the loop.
fold_true = []
fold_pred = []

for fold, (train_index, test_index) in enumerate(kfold.split(np.arange(pos_num))):
    # Model initialisation is otherwise unseeded, so fold results would vary run to run.
    torch.manual_seed(cv_seed + fold)

    # ---- Long-format train/test tables for this fold.
    sampler = Sampler(res, train_index, test_index, null_mask, fold)
    train_df = _mask_to_edge_frame(sampler.train_data, sampler.train_mask)
    test_df = _mask_to_edge_frame(sampler.test_data, sampler.test_mask)

    # train()/test() read these as notebook globals, so the names must stay as-is.
    drug_set, cellline_set, train_edge, label_pos, train_mask, test_mask, atom_shape = (
        process(
            drug_feature,
            mutation,
            exprs,
            methylation,
            train_df,
            test_df,
            nb_celllines,
            nb_drugs,
        )
    )

    # ---- Fresh model, optimizer and loss for every fold.
    model = GraphCDR(
        hidden_channels=hidden_channels,
        encoder=Encoder(output_channels, hidden_channels),
        summary=Summary(output_channels, hidden_channels),
        feat=NodeRepresentation(
            atom_shape,
            # NOTE(review): this width comes from `exp` (data_load), but the cell
            # features given to process() come from `exprs` (load_data). Confirm
            # that the two have the same last dimension.
            exp.shape[-1],
            methylation.shape[-1],
            dim_mut,
            output_channels,
            use_mutation=use_mutation,
            use_methylation=use_methylation,
        ),
        index=nb_celllines,
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=0)
    myloss = torch.nn.BCELoss()

    # ---- Train, keeping the metrics and predictions of the best-AUC epoch.
    # -inf guarantees the first evaluated epoch is recorded, so final_* is always set.
    final_AUC = float("-inf")
    final_AUPR = final_F1 = final_ACC = 0
    final_ytest = final_yp = None
    # Loop variable is `ep`, not `epoch`. Reusing `epoch` overwrote the configured
    # count, so every later fold trained one epoch fewer than the previous one.
    for ep in range(epoch):
        print("\nepoch: " + str(ep))
        train()
        AUC, AUPR, F1, ACC, ytest, yp = test()
        if AUC > final_AUC:
            final_AUC, final_AUPR, final_F1, final_ACC = AUC, AUPR, F1, ACC
            final_ytest, final_yp = ytest, yp

    print(
        "Final_AUC: "
        + str(round(final_AUC, 4))
        + "  Final_AUPR: "
        + str(round(final_AUPR, 4))
        + "  Final_F1: "
        + str(round(final_F1, 4))
        + "  Final_ACC: "
        + str(round(final_ACC, 4))
    )
    print("---------------------------------------")

    # Store the best-epoch predictions (not the last epoch's) so they line up
    # with final_ytest and the metrics printed above.
    fold_true.append(pd.DataFrame(final_ytest))
    fold_pred.append(pd.DataFrame(final_yp))

# One column per fold; folds of different length are NaN-padded by concat.
true_datas = pd.concat(fold_true, ignore_index=True, axis=1)
predict_datas = pd.concat(fold_pred, ignore_index=True, axis=1)

true_datas.to_csv(f"true_{data}.csv")
predict_datas.to_csv(f"pred_{data}.csv")