In [1]:
import argparse

from tqdm import tqdm

In [2]:
%load_ext autoreload
%autoreload 2

from load_data import load_data
from model import Optimizer, nihgcn
from myutils import *
from sampler import NewSampler
from sklearn.model_selection import KFold

In [3]:
class Args:
    """Container for the run configuration (device, dataset and hyperparameters).

    All values are fixed in ``__init__``; instantiate with ``Args()`` and read
    the attributes directly.
    """

    def __init__(self):
        # Use the GPU when torch can see one, otherwise fall back to the CPU.
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(backend)

        # Dataset name. The earlier note listed "gdsc or ccle" as the options,
        # but the value used here is "ctrp" -- TODO confirm the supported names.
        self.data = "ctrp"

        # Optimisation settings.
        self.lr = 0.001  # learning rate
        self.wd = 1e-5  # weight decay (L2 regularisation)
        self.epochs = 1000  # number of training epochs

        # Model settings.
        self.layer_size = [1024, 1024]  # output size of every layer
        self.alpha = 0.25  # scale balancing the GCN and NI parts
        self.gamma = 8  # scale applied to the sigmoid


# Single shared configuration object used by the rest of the notebook.
args = Args()

In [4]:
# Load the inputs for the configured run. load_data returns five values;
# pos_num is not used anywhere else in the visible part of this notebook.
res, drug_finger, exprs, null_mask, pos_num = load_data(args)
# Totals of `res` along each axis; process_iteration skips any target whose
# total is below 10. Axis 0 is treated as the cell axis and axis 1 as the drug
# axis (see target_dim below), so cell_sum is per row and drug_sum per column.
cell_sum = np.sum(res, axis=1)
drug_sum = np.sum(res, axis=0)

# Axes of `res` whose entries are iterated over as targets in the main loop.
# Uncomment 0 to also run the cell axis.
target_dim = [
    # 0,  # Cell
    1  # Drug
]

In [5]:
def nihgcn_new(
    cell_exprs,
    drug_finger,
    res_mat,
    null_mask,
    target_dim,
    target_index,
    evaluate_fun,
    args,
    seed,
):
    """Build, train and evaluate one model for a single target.

    A ``NewSampler`` is created from ``res_mat``/``null_mask`` for the given
    ``target_dim``/``target_index``/``seed``; its train split feeds the
    ``nihgcn`` model and its test split is handed to ``Optimizer`` together
    with ``evaluate_fun``.

    Args:
        cell_exprs: Passed to the model as ``cell_exprs``.
        drug_finger: Passed to the model as ``drug_finger``.
        res_mat: Passed to the sampler as its first argument.
        null_mask: Passed to the sampler as its second argument.
        target_dim: Axis of the target (forwarded to the sampler).
        target_index: Index of the target along ``target_dim``.
        evaluate_fun: Metric callable forwarded to ``Optimizer``.
        args: Object exposing ``layer_size``, ``alpha``, ``gamma``, ``device``,
            ``lr``, ``wd`` and ``epochs``.
        seed: Seed forwarded to the sampler.

    Returns:
        ``(true_data, predict_data)`` as produced by calling the optimizer, or
        ``(None, None)`` when the held-out labels contain fewer than two
        distinct values.
    """
    split = NewSampler(res_mat, null_mask, target_dim, target_index, seed)

    # Bail out before building the model when the held-out labels are all the
    # same value (presumably because the metric is undefined in that case).
    held_out_labels = split.test_data[split.test_mask]
    if np.unique(held_out_labels).size < 2:
        print(f"Target {target_index} skipped: Validation set has only one class.")
        return None, None

    network = nihgcn(
        split.train_data,
        cell_exprs=cell_exprs,
        drug_finger=drug_finger,
        layer_size=args.layer_size,
        alpha=args.alpha,
        gamma=args.gamma,
        device=args.device,
    )
    trainer = Optimizer(
        network,
        split.train_data,
        split.test_data,
        split.test_mask,
        split.train_mask,
        evaluate_fun,
        lr=args.lr,
        wd=args.wd,
        epochs=args.epochs,
        device=args.device,
    )

    # Calling the optimizer runs it and yields the (truth, prediction) pair.
    truth, scores = trainer()
    return truth, scores

In [6]:
from joblib import Parallel, delayed
from tqdm import tqdm

# Number of times process_iteration runs nihgcn_new for each target.
n_kfold = 1
n_jobs = 50  # degree of parallelism (n_jobs passed to joblib.Parallel)


def process_iteration(dim, target_index, seed, args):
    """Run all folds for one target and collect the usable results.

    Args:
        dim: Target axis; truthy selects the drug totals (``drug_sum``), falsy
            the cell totals (``cell_sum``).
        target_index: Index of the target along ``dim``.
        seed: Seed forwarded to ``nihgcn_new``.
        args: Run configuration forwarded to ``nihgcn_new``.

    Returns:
        A list of ``(true_data, predict_data)`` tuples, one per fold that
        produced a result. The list is empty when the target is skipped
        (its total is below 10) or when every fold was skipped by
        ``nihgcn_new``. Always returning a list keeps the result safe to
        iterate; the previous ``(None, None)`` skip marker was not ``None``
        and broke tuple-unpacking in the consuming loop.
    """
    # Skip targets with too little signal along the chosen axis.
    totals = drug_sum if dim else cell_sum
    if totals[target_index] < 10:
        return []

    fold_results = []
    # NOTE(review): every fold receives the same seed, so with n_kfold > 1 the
    # sampler is presumably asked for the same split each time -- confirm
    # whether a per-fold seed was intended.
    for _ in range(n_kfold):
        true_data, predict_data = nihgcn_new(
            cell_exprs=exprs,
            drug_finger=drug_finger,
            res_mat=res,
            null_mask=null_mask,
            target_dim=dim,
            target_index=target_index,
            evaluate_fun=roc_auc,
            args=args,
            seed=seed,
        )
        # nihgcn_new signals a skipped fold with (None, None); do not pass
        # those placeholders on to downstream result translation.
        if true_data is None or predict_data is None:
            continue
        fold_results.append((true_data, predict_data))

    return fold_results


# Run every target in parallel and gather the translated results.
# Frames are collected in lists and concatenated once at the end; growing a
# DataFrame with pd.concat inside the loop copies all previous rows each time.
true_frames = []
predict_frames = []

for dim in target_dim:
    # Build the full task list up front; the enumerate counter doubles as the
    # seed, so each target index is paired with an equal-valued seed.
    tasks = [
        (dim, target_index, seed, args)
        for seed, target_index in enumerate(np.arange(res.shape[dim]))
    ]

    # Parallel execution (thread-based) with a progress bar over submissions.
    results = Parallel(n_jobs=n_jobs, verbose=0, prefer="threads")(
        delayed(process_iteration)(*task)
        for task in tqdm(tasks, desc=f"Processing dim {dim}")
    )

    # Merge the per-target results, tolerating every "skipped" shape:
    # None / empty list for a skipped target, a bare (None, None) tuple
    # (which iterates as None entries), or a fold whose data is None.
    for fold_results in results:
        if not fold_results:
            continue
        for fold in fold_results:
            if fold is None:
                continue
            true_data, predict_data = fold
            if true_data is None or predict_data is None:
                continue
            true_frames.append(translate_result(true_data))
            predict_frames.append(translate_result(predict_data))

# pd.concat raises on an empty list, so fall back to an empty frame.
true_data_s = (
    pd.concat(true_frames, ignore_index=True) if true_frames else pd.DataFrame()
)
predict_data_s = (
    pd.concat(predict_frames, ignore_index=True)
    if predict_frames
    else pd.DataFrame()
)

Processing dim 1:   0%|          | 0/460 [00:00<?, ?it/s]

epoch:   0 loss:0.697310 auc:0.5137
epoch:   0 loss:0.700717 auc:0.6475
epoch:   0 loss:0.695436 auc:0.4598
epoch:   0 loss:0.705157 auc:0.5000
epoch:   0 loss:0.699093 auc:0.4590
epoch:   0 loss:0.701050 auc:0.5790
epoch:   0 loss:0.696832 auc:0.4725
epoch:   0 loss:0.699978 auc:0.4839
epoch:   0 loss:0.700585 auc:0.4544
epoch:   0 loss:0.702219 auc:0.4307
epoch:   0 loss:0.700091 auc:0.5405
epoch:   0 loss:0.702225 auc:0.5345
epoch:   0 loss:0.699388 auc:0.4475
epoch:   0 loss:0.702984 auc:0.2778
epoch:   0 loss:0.696174 auc:0.5763
epoch:   0 loss:0.698586 auc:0.6300
epoch:   0 loss:0.701285 auc:0.4486
epoch:   0 loss:0.704084 auc:0.5500
epoch:   0 loss:0.702291 auc:0.5686
epoch:   0 loss:0.698783 auc:0.6019
epoch:   0 loss:0.700363 auc:0.4234
epoch:   0 loss:0.703336 auc:0.4421
epoch:   0 loss:0.703573 auc:0.4583
epoch:   0 loss:0.700141 auc:0.4880
epoch:   0 loss:0.698600 auc:0.4150
epoch:   0 loss:0.699424 auc:0.5979
epoch:   0 loss:0.702226 auc:0.6094
epoch:   0 loss:0.700616 auc

epoch: 240 loss:0.225055 auc:0.7268
epoch: 280 loss:0.222799 auc:0.7294
epoch: 180 loss:0.233903 auc:0.7704
epoch: 240 loss:0.231603 auc:0.6650
epoch:  60 loss:0.309749 auc:0.7400
epoch: 120 loss:0.245704 auc:0.7793
epoch:  60 loss:0.307562 auc:0.6253
epoch: 320 loss:0.224103 auc:0.7438
epoch: 280 loss:0.223097 auc:0.6187
epoch: 300 loss:0.221925 auc:0.7460
epoch: 220 loss:0.227815 auc:0.9022
epoch: 140 loss:0.244512 auc:0.9514
epoch: 240 loss:0.227754 auc:0.6821
epoch: 100 loss:0.261579 auc:0.8166
epoch: 300 loss:0.221497 auc:0.7394
epoch: 260 loss:0.224231 auc:0.7223
epoch: 180 loss:0.232645 auc:0.7403
epoch:  80 loss:0.279769 auc:0.8643
epoch: 340 loss:0.220244 auc:0.7512
epoch: 160 loss:0.237772 auc:0.7576
epoch: 260 loss:0.224921 auc:0.6855
epoch: 200 loss:0.231538 auc:0.7738
epoch: 300 loss:0.221380 auc:0.6014
epoch: 320 loss:0.223794 auc:0.7188
epoch: 240 loss:0.224752 auc:0.8978
epoch:  80 loss:0.279208 auc:0.7295
epoch: 320 loss:0.221211 auc:0.7371
epoch: 360 loss:0.223971 auc

KeyboardInterrupt: 

epoch: 300 loss:0.221835 auc:0.8711
epoch: 340 loss:0.220391 auc:0.6621
epoch: 400 loss:0.217863 auc:0.7433
epoch: 320 loss:0.220234 auc:0.6833
epoch: 120 loss:0.248842 auc:0.8223
epoch: 420 loss:0.219251 auc:0.7279
epoch: 240 loss:0.226776 auc:0.7351
epoch: 160 loss:0.233928 auc:0.7786
epoch: 400 loss:0.219931 auc:0.5817
epoch: 460 loss:0.220270 auc:0.7256


In [None]:
# Write the collected ground-truth and prediction frames to CSV files in the
# current working directory; the dataset name from args.data is embedded in
# each file name.
# NOTE(review): the "new_drug" prefix is hard-coded, although target_dim can
# also include the cell axis -- confirm the naming if that axis is enabled.
true_data_s.to_csv(f"new_drug_true_{args.data}.csv")
predict_data_s.to_csv(f"new_drug_pred_{args.data}.csv")