In [1]:
import sys
# Report the interpreter that backs this kernel, so it is clear which
# environment the installs and imports in this notebook resolve against.
print(sys.executable, end="\n")

/scratch.global/inoue019/conda-envs/myenv/bin/python


In [11]:
!{sys.executable} -m pip install tqdm pubchempy pandas numpy scipy scikit-learn seaborn rdkit

Collecting rdkit
  Downloading rdkit-2024.9.6-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (4.0 kB)
Downloading rdkit-2024.9.6-cp310-cp310-manylinux_2_28_x86_64.whl (34.3 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 34.3/34.3 MB 28.2 MB/s eta 0:00:00
Installing collected packages: rdkit
Successfully installed rdkit-2024.9.6


In [9]:
import argparse

from tqdm import tqdm

In [12]:
%load_ext autoreload
%autoreload 2

from load_data import load_data
from model import Optimizer, nihgcn
from myutils import *
from sampler import NewSampler
from sklearn.model_selection import KFold

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [13]:
class Args:
    """Run settings and hyper-parameters for the NIHGCN experiments in this notebook.

    Acts as an argparse-style namespace: downstream code reads attributes such as
    ``args.data``, ``args.lr`` and ``args.device``. Any default can be overridden by
    keyword, e.g. ``Args(data="gdsc", epochs=200)``. Unknown names raise ``TypeError``
    so that a typo does not silently create an unused setting.
    """

    def __init__(self, **overrides):
        # Device string: "cuda" (or "cuda:<index>" for a specific GPU) or "cpu".
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Dataset name; used in the output filenames and presumably by load_data.
        # This run uses "ctrp" (an earlier comment listed only gdsc/ccle).
        # NOTE(review): confirm the names load_data accepts.
        self.data = "ctrp"
        self.lr = 0.001  # learning rate
        self.wd = 1e-5  # weight decay (L2 regularization strength)
        self.layer_size = [1024, 1024]  # output size of each layer
        self.alpha = 0.25  # scale balancing the GCN and NI terms
        self.gamma = 8  # scale for the sigmoid
        self.epochs = 1000  # number of training epochs

        # Apply caller-supplied overrides only to settings defined above.
        for name, value in overrides.items():
            if not hasattr(self, name):
                raise TypeError(f"Args got an unexpected setting {name!r}")
            setattr(self, name, value)


args = Args()

In [14]:
# Load the response matrix plus side information; `pos_num` is returned by
# load_data but not used in the visible cells.
res, drug_finger, exprs, null_mask, pos_num = load_data(args)

# Marginal totals of the response matrix: summing over axis=1 gives one value per
# row (named as per-cell totals), summing over axis=0 one value per column (per-drug).
cell_sum, drug_sum = np.sum(res, axis=1), np.sum(res, axis=0)

# Dimensions to hold out during evaluation: 0 = cells (rows of `res`).
# Append 1 to also evaluate drugs (columns).
target_dim = [0]

In [15]:
def nihgcn_new(
    cell_exprs,
    drug_finger,
    res_mat,
    null_mask,
    target_dim,
    target_index,
    evaluate_fun,
    args,
    seed,
):
    """Fit one NIHGCN model on a NewSampler split and return its evaluation output.

    The split for ``target_index`` along ``target_dim`` is produced by ``NewSampler``
    (seeded with ``seed``). Model and optimizer settings come from ``args``.

    Returns:
        ``(true_data, predict_data)`` as returned by the ``Optimizer`` call. Returns
        ``(None, None)`` when the masked validation labels contain fewer than two
        distinct values; a one-class set cannot be scored by rank metrics such as
        ROC AUC. In that case a message is printed and no model is built.
    """
    split = NewSampler(res_mat, null_mask, target_dim, target_index, seed)

    # Guard: skip targets whose held-out labels are all identical.
    held_out_labels = split.test_data[split.test_mask]
    if np.unique(held_out_labels).size < 2:
        print(f"Target {target_index} skipped: Validation set has only one class.")
        return None, None

    model_opts = dict(
        cell_exprs=cell_exprs,
        drug_finger=drug_finger,
        layer_size=args.layer_size,
        alpha=args.alpha,
        gamma=args.gamma,
        device=args.device,
    )
    network = nihgcn(split.train_data, **model_opts)

    train_opts = dict(lr=args.lr, wd=args.wd, epochs=args.epochs, device=args.device)
    trainer = Optimizer(
        network,
        split.train_data,
        split.test_data,
        split.test_mask,
        split.train_mask,
        evaluate_fun,
        **train_opts,
    )

    true_data, predict_data = trainer()
    return true_data, predict_data

In [17]:
from joblib import Parallel, delayed
from tqdm import tqdm

n_kfold = 1  # repetitions per target (see the seed note inside process_iteration)
n_jobs = 50  # number of parallel workers


def process_iteration(dim, target_index, seed, args, min_total=10):
    """Train and evaluate NIHGCN with one target held out along ``dim``.

    Args:
        dim: 0 to hold out row ``target_index`` (checked against ``cell_sum``); any
            truthy value to hold out column ``target_index`` (checked against ``drug_sum``).
        target_index: index of the held-out target along ``dim``.
        seed: seed forwarded to ``nihgcn_new``.
        args: run configuration (an ``Args`` instance).
        min_total: targets whose total in ``cell_sum``/``drug_sum`` is below this value
            are skipped. Defaults to the previously hard-coded 10.

    Returns:
        A list of ``(true_data, predict_data)`` pairs, one per repetition that produced
        results. Returns ``None`` when the target is skipped: its total is below
        ``min_total``, or ``nihgcn_new`` skipped every repetition.
    """
    totals = drug_sum if dim else cell_sum
    if totals[target_index] < min_total:
        # Plain None (not a (None, None) pair): the aggregation loop skips results
        # with `is None`. A (None, None) tuple would reach
        # `for true_data, predict_data in fold_results` and fail to unpack.
        return None

    fold_results = []
    for _ in range(n_kfold):
        # NOTE(review): the same `seed` is reused for every repetition, so with
        # n_kfold > 1 the splits are presumably identical. Confirm whether
        # per-repetition seeds were intended.
        true_data, predict_data = nihgcn_new(
            cell_exprs=exprs,
            drug_finger=drug_finger,
            res_mat=res,
            null_mask=null_mask,
            target_dim=dim,
            target_index=target_index,
            evaluate_fun=roc_auc,
            args=args,
            seed=seed,
        )
        if true_data is None and predict_data is None:
            # nihgcn_new signals a single-class validation set this way; keeping the
            # pair would pass None on to translate_result downstream.
            continue
        fold_results.append((true_data, predict_data))

    return fold_results or None


# Run every target of each requested dimension in parallel, then stack the results.
# Start from empty frames so the save cell still finds these names if the run is cut short.
true_data_s = pd.DataFrame()
predict_data_s = pd.DataFrame()
true_frames = []
predict_frames = []

for dim in target_dim:
    # Pre-build one task per target along `dim`; the target's position doubles as its seed.
    tasks = [
        (dim, target_index, seed, args)
        for seed, target_index in enumerate(np.arange(res.shape[dim]))
    ]

    # Thread-based parallel run. tqdm wraps the task iterator, so the bar
    # tracks dispatch to workers rather than completion.
    results = Parallel(n_jobs=n_jobs, verbose=0, prefer="threads")(
        delayed(process_iteration)(*task)
        for task in tqdm(tasks, desc=f"Processing dim {dim}")
    )

    # Collect per-target frames.
    for fold_results in results:
        if fold_results is None:
            continue  # target skipped
        for pair in fold_results:
            # Tolerate skip markers: a bare None or a (None, None) pair means
            # nothing was trained for this target/repetition.
            if pair is None or pair[0] is None:
                continue
            true_data, predict_data = pair
            true_frames.append(translate_result(true_data))
            predict_frames.append(translate_result(predict_data))

# Concatenate once: calling pd.concat inside the loop re-copies every previously
# collected row on each iteration (quadratic in the number of targets).
if true_frames:
    true_data_s = pd.concat(true_frames, ignore_index=True)
if predict_frames:
    predict_data_s = pd.concat(predict_frames, ignore_index=True)


Processing dim 0:   0%|          | 0/823 [00:00<?, ?it/s]

epoch:   0 loss:0.699953 auc:0.5388
epoch:   0 loss:0.698458 auc:0.4644
epoch:   0 loss:0.703540 auc:0.4805
epoch:   0 loss:0.698881 auc:0.5223
epoch:   0 loss:0.697864 auc:0.4389
epoch:   0 loss:0.700997 auc:0.5237
epoch:   0 loss:0.705132 auc:0.5138
epoch:   0 loss:0.703500 auc:0.4763
epoch:   0 loss:0.700175 auc:0.4216
epoch:   0 loss:0.700983 auc:0.4707
epoch:   0 loss:0.698952 auc:0.5197
epoch:   0 loss:0.702697 auc:0.5402
epoch:   0 loss:0.700577 auc:0.5648
epoch:   0 loss:0.701639 auc:0.5107
epoch:   0 loss:0.701718 auc:0.4957
epoch:   0 loss:0.700503 auc:0.4840
epoch:   0 loss:0.699038 auc:0.5136
epoch:   0 loss:0.704676 auc:0.4763
epoch:   0 loss:0.698996 auc:0.4397
epoch:   0 loss:0.702801 auc:0.4713
epoch:   0 loss:0.697613 auc:0.4685
epoch:   0 loss:0.701576 auc:0.5237
epoch:   0 loss:0.698256 auc:0.5406
epoch:   0 loss:0.701663 auc:0.5272
epoch:   0 loss:0.698616 auc:0.5434
epoch:   0 loss:0.700015 auc:0.4764
epoch:   0 loss:0.700123 auc:0.4070
epoch:   0 loss:0.703257 auc

KeyboardInterrupt: 

epoch: 380 loss:0.222023 auc:0.8977
epoch: 360 loss:0.219016 auc:0.8916
epoch: 260 loss:0.226092 auc:0.9317
epoch: 280 loss:0.223057 auc:0.9581
epoch: 380 loss:0.218728 auc:0.8449
epoch: 160 loss:0.237393 auc:0.9408
epoch: 200 loss:0.231040 auc:0.9545
epoch: 240 loss:0.225650 auc:0.9627
epoch: 100 loss:0.264644 auc:0.9531
epoch: 400 loss:0.217976 auc:0.8980
epoch:  40 loss:0.326110 auc:0.9213
epoch: 380 loss:0.218726 auc:0.8975


In [None]:
# Persist the aggregated ground truth and predictions, one CSV each, tagged with the dataset name.
outputs = {
    f"new_cell_true_{args.data}.csv": true_data_s,
    f"new_cell_pred_{args.data}.csv": predict_data_s,
}
for out_path, frame in outputs.items():
    frame.to_csv(out_path)