In [1]:
# %pip install numpy pandas xgboost lightgbm scikit-learn openpyxl

In [2]:
# %pip install cupy-cuda12x

In [3]:
# # MLP
# %pip install torch torchvision torchaudio

In [4]:
import numpy as np
import pandas as pd

from utils.utils import load_dataset
from utils.workflow import train_eval, case_study, benchmark_all_cases

from function.ITRPCA import ITRPCA_F
from function.VDA_GKSBMF import A_VDA_GMSBMF
from function.DRPADC import WeightImputeLogFactorization, GetProbability
from function.MVL import MVL_F
from function.BMC import BMC_F
from function.MSBMF import MSBMF_F
from function.HGIMC import fBMC, fGRB, fHGI
from function.AMVL import AMVL

Note: You have installed the 'manylinux2014' variant of XGBoost. Certain features such as GPU algorithms or federated learning are not available. To use these features, please upgrade to a recent Linux distro with glibc 2.28+, and install the 'manylinux_2_28' variant.


## HN-DREP Benchmark

In [5]:
# Load the published HN-DREP evaluation results and preview the first rows.
benchmark = pd.read_excel('data/Benchmark/evaluation-results.xlsx')
benchmark.head(n=5)

Unnamed: 0,id,dataset,auc,aupr,f1,time(s),memory(kb),overall(auc+aupr+f1)
0,HINGRL,Fdataset,0.9432,0.9515,0.8789,18.0,655120.0,2.7736
1,DRPADC,iDrug,0.9746,0.9731,0.7,1177.0,1706124.0,2.6477
2,VDA-GKSBMF,Ydataset,0.9699,0.9751,0.6981,111.0,979672.0,2.6431
3,MLMC,Cdataset,0.9706,0.9747,0.6945,57.0,889936.0,2.6398
4,ITRPCA,Cdataset,0.9654,0.973,0.696,18.0,1328172.0,2.6344


### Note: Explanation of High Average F1 Score in the HINGRL Algorithm

The HINGRL algorithm achieved a notably high average F1 score (0.8789) on Fdataset. This is well above the other top-ranked entries in the benchmark table, whose F1 scores are around 0.70. However, the score is likely influenced by how negative samples are generated and how the model is trained and evaluated. The key points are outlined below:

1. **Negative Sample Generation Strategy:**
   - The algorithm generates an equal number of negative samples by randomly selecting drug-disease pairs that are not labeled as positive in the existing dataset. While this ensures a balanced number of positive and negative samples, the generated negatives may not represent true non-associated pairs—they could be potential positives that are simply undiscovered.
   - This simple random sampling method does not account for biological similarities, functional annotations, or other prior knowledge, potentially making the negatives less representative of real-world scenarios. As a result, the model may perform well on the training set but have limited generalization capability in real applications.

2. **Balanced Data Leading to Inflated F1 Scores:**
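   - The training set is perfectly balanced, with equal numbers of positive and negative samples. If the evaluation data are balanced the same way, a given false-positive rate yields far fewer false positives than it would among the many negatives of a realistic candidate list. Precision, and therefore F1, is then higher than it would be under a realistic class ratio.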
   - In real-world applications, negative samples usually vastly outnumber positive ones. This artificially balanced training data does not reflect real-world distributions, causing the F1 score to be overestimated.

3. **Use of Random Forest for Training:**
   - The HINGRL algorithm uses a Random Forest classifier. A random forest averages many decision trees to reduce variance (overfitting) and tends to perform well when the classes are balanced.
   - While it achieves high F1 scores on balanced data, its performance may degrade on real, imbalanced datasets where the proportion of negative samples is much higher.

**Conclusion:**

The high average F1 score achieved by the HINGRL algorithm is partly due to the simplistic negative sample generation and training on a balanced dataset. While the results appear promising on the current dataset, its generalization to more complex and imbalanced real-world drug-disease prediction scenarios may be limited.

## Data
- **Fdataset**
- **Cdataset**
- **Ydataset**

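These are the datasets on which the top-ranked algorithms in the benchmark table above (HINGRL, VDA-GKSBMF, MLMC) achieved their highest scores. iDrug, where DRPADC scored highest, is evaluated separately in the case study below.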

In [6]:
# Benchmark datasets shared by every algorithm evaluated in this notebook.
datasets = [
    'Fdataset',
    'Cdataset',
    'Ydataset',
]

In [7]:
# Load one dataset for the manual experiments and sanity checks below.
# The unpacking order must match the return order of utils.load_dataset exactly.
# Naming, as suggested by the variable names (TODO confirm against utils.load_dataset):
#   Wrd            - drug/disease association matrix
#   Wrr_* / Wdd_*  - lists of drug / disease similarity views; the word is the view count,
#                    the suffix names the combined view families (llms / kgs / geps)
#   Trr_* / Tdd_*  - the same views stacked into one array (the shape check below shows
#                    Trr_six_llms as (n_drugs, n_drugs, n_views))
#   *_embeddings   - LLM embeddings (embedding_type='llm')
dataset_name = 'Fdataset'
(drug_name, disease_name, 
            Wrd, 
            Wrr_eight, 
            Wrr_seven_llms_kgs, Wrr_seven_llms_geps, Wrr_seven_kgs_geps, 
            Wrr_six_geps, Wrr_six_kgs, Wrr_six_llms, 
            Wrr_five, 
            Wdd_three, 
            Wdd_two, 
            Trr_eight, 
            Trr_seven_llms_kgs, Trr_seven_llms_geps, Trr_seven_kgs_geps, 
            Trr_six_geps, Trr_six_kgs, Trr_six_llms, Trr_five, 
            Tdd_three, 
            Tdd_two, 
            drug_embeddings, disease_embeddings) = load_dataset(dataset_name, embedding_type='llm')

In [8]:
# Peek at the first three drug and disease identifiers.
tuple(names[:3] for names in (drug_name, disease_name))

(array(['DB00007', 'DB00010', 'DB00014'], dtype='<U7'),
 array(['D102100', 'D102300', 'D102400'], dtype='<U7'))

In [9]:
# Number of similarity views in each drug / disease list.
tuple(map(len, (Wrr_eight, Wdd_three, Wrr_five, Wdd_two)))

(8, 3, 5, 2)

In [10]:
# Number of drug views in the seven-view `llms + kgs` combination.
len(Wrr_seven_llms_kgs)

7

In [11]:
# View count of the six-view LLM list next to the shape of its stacked counterpart.
(len(Wrr_six_llms), np.shape(Trr_six_llms))

(6, (593, 593, 6))

In [12]:
# Embedding matrix shapes; the row counts should match the numbers of drugs / diseases.
tuple(emb.shape for emb in (drug_embeddings, disease_embeddings))

((593, 1024), (313, 1024))

## AMVL

### Benchmark

In [14]:
# Cross-validated AMVL benchmark over every drug/disease view combination.
# Reuse the shared `datasets` list so all algorithms are evaluated on the same datasets.
amvl_benchmark = benchmark_all_cases(AMVL, datasets=datasets)
amvl_benchmark

################## Benchmarking AMVL ##################


################## Benchmarking AMVL on Fdataset ##################


################## Loading Fdataset ##################


################## 8 + 3 ##################

Precomputing drug and disease similarity features...
Retrieving positive and negative sample indices...
Assigning positive samples to 10-fold cross-validation...

Test set prepared, positive interactions masked in Wrd matrix...
>>> Step 1: Starting BMC matrix completion...
>>> BMC converged in 165 iterations.
>>> Calculating GIP similarity matrices for disease-disease and drug-drug interactions...
>>> Removing diagonal elements to eliminate self-interactions...
>>> Step 2: Performing Matrix Factorization with similarity regularization...
>>> Matrix factorization completed in 322 iterations.
>>> Thresholding the recovered matrix based on the threshold value...
>>> Step 3: Starting multi-view learning...
>>> Performing multi-view learning with 8 drug views and 3

combination,model,dataset,metric,8 + 3,8 + 2,7 + 3 (llms + kgs),7 + 3 (llms + geps),7 + 3 (kgs + geps),7 + 2 (llms + kgs),7 + 2 (llms + geps),7 + 2 (kgs + geps),6 + 3 (llms),6 + 3 (kgs),6 + 3 (geps),6 + 2 (llms),6 + 2 (kgs),6 + 2 (geps),5 + 3,5 + 2
0,AMVL,Fdataset,auc,0.9610 (0.9557 - 0.9662),0.9611 (0.9552 - 0.9671),0.9611 (0.9560 - 0.9661),0.9607 (0.9553 - 0.9661),0.9589 (0.9533 - 0.9644),0.9603 (0.9548 - 0.9657),0.9614 (0.9554 - 0.9675),0.9605 (0.9545 - 0.9664),0.9586 (0.9528 - 0.9644),0.9585 (0.9531 - 0.9639),0.9578 (0.9515 - 0.9641),0.9612 (0.9550 - 0.9673),0.9594 (0.9532 - 0.9656),0.9589 (0.9530 - 0.9648),0.9582 (0.9530 - 0.9634),0.9583 (0.9518 - 0.9647)
1,AMVL,Fdataset,aupr,0.9675 (0.9630 - 0.9720),0.9671 (0.9621 - 0.9722),0.9666 (0.9618 - 0.9713),0.9671 (0.9627 - 0.9715),0.9651 (0.9590 - 0.9712),0.9669 (0.9626 - 0.9713),0.9673 (0.9619 - 0.9726),0.9670 (0.9623 - 0.9717),0.9643 (0.9578 - 0.9707),0.9653 (0.9604 - 0.9703),0.9656 (0.9603 - 0.9709),0.9676 (0.9616 - 0.9736),0.9654 (0.9599 - 0.9709),0.9646 (0.9578 - 0.9715),0.9660 (0.9614 - 0.9705),0.9655 (0.9593 - 0.9717)
2,AMVL,Fdataset,f1,0.7215 (0.7152 - 0.7277),0.7186 (0.7136 - 0.7236),0.7194 (0.7144 - 0.7245),0.7234 (0.7190 - 0.7278),0.7202 (0.7138 - 0.7267),0.7207 (0.7163 - 0.7251),0.7169 (0.7107 - 0.7231),0.7183 (0.7135 - 0.7230),0.7185 (0.7125 - 0.7244),0.7192 (0.7131 - 0.7254),0.7181 (0.7126 - 0.7236),0.7171 (0.7107 - 0.7234),0.7142 (0.7080 - 0.7203),0.7184 (0.7134 - 0.7233),0.7163 (0.7108 - 0.7218),0.7118 (0.7056 - 0.7180)
3,AMVL,Cdataset,auc,0.9710 (0.9647 - 0.9773),0.9712 (0.9646 - 0.9779),0.9713 (0.9651 - 0.9776),0.9710 (0.9642 - 0.9777),0.9694 (0.9635 - 0.9753),0.9717 (0.9657 - 0.9778),0.9709 (0.9642 - 0.9776),0.9701 (0.9633 - 0.9769),0.9716 (0.9656 - 0.9776),0.9706 (0.9640 - 0.9772),0.9696 (0.9630 - 0.9761),0.9718 (0.9657 - 0.9779),0.9700 (0.9638 - 0.9762),0.9701 (0.9638 - 0.9764),0.9702 (0.9643 - 0.9761),0.9703 (0.9641 - 0.9764)
4,AMVL,Cdataset,aupr,0.9755 (0.9705 - 0.9805),0.9754 (0.9700 - 0.9809),0.9762 (0.9713 - 0.9811),0.9758 (0.9706 - 0.9810),0.9735 (0.9696 - 0.9774),0.9763 (0.9715 - 0.9811),0.9753 (0.9701 - 0.9805),0.9747 (0.9693 - 0.9801),0.9765 (0.9721 - 0.9810),0.9753 (0.9700 - 0.9806),0.9744 (0.9691 - 0.9796),0.9764 (0.9717 - 0.9811),0.9741 (0.9686 - 0.9795),0.9750 (0.9702 - 0.9798),0.9755 (0.9709 - 0.9801),0.9753 (0.9705 - 0.9800)
5,AMVL,Cdataset,f1,0.7296 (0.7243 - 0.7350),0.7256 (0.7195 - 0.7317),0.7306 (0.7264 - 0.7349),0.7305 (0.7256 - 0.7353),0.7271 (0.7219 - 0.7322),0.7253 (0.7196 - 0.7311),0.7244 (0.7182 - 0.7307),0.7211 (0.7149 - 0.7272),0.7278 (0.7227 - 0.7328),0.7260 (0.7211 - 0.7309),0.7252 (0.7200 - 0.7304),0.7232 (0.7169 - 0.7296),0.7212 (0.7160 - 0.7264),0.7216 (0.7155 - 0.7278),0.7247 (0.7188 - 0.7306),0.7198 (0.7143 - 0.7252)
6,AMVL,Ydataset,auc,0.9634 (0.9575 - 0.9693),0.9693 (0.9638 - 0.9747),0.9661 (0.9608 - 0.9714),0.9659 (0.9601 - 0.9718),0.9690 (0.9633 - 0.9748),0.9718 (0.9679 - 0.9757),0.9721 (0.9688 - 0.9754),0.9714 (0.9680 - 0.9747),0.9697 (0.9651 - 0.9743),0.9708 (0.9671 - 0.9744),0.9706 (0.9673 - 0.9739),0.9715 (0.9678 - 0.9753),0.9714 (0.9680 - 0.9749),0.9711 (0.9679 - 0.9742),0.9705 (0.9668 - 0.9742),0.9710 (0.9677 - 0.9744)
7,AMVL,Ydataset,aupr,0.9588 (0.9480 - 0.9695),0.9698 (0.9609 - 0.9787),0.9661 (0.9566 - 0.9757),0.9647 (0.9542 - 0.9752),0.9713 (0.9623 - 0.9802),0.9753 (0.9719 - 0.9788),0.9756 (0.9728 - 0.9785),0.9750 (0.9720 - 0.9779),0.9730 (0.9679 - 0.9782),0.9748 (0.9720 - 0.9777),0.9751 (0.9723 - 0.9778),0.9755 (0.9725 - 0.9785),0.9750 (0.9721 - 0.9779),0.9746 (0.9718 - 0.9775),0.9751 (0.9722 - 0.9780),0.9752 (0.9725 - 0.9779)
8,AMVL,Ydataset,f1,0.7683 (0.7471 - 0.7895),0.7414 (0.7246 - 0.7583),0.7511 (0.7325 - 0.7698),0.7526 (0.7347 - 0.7705),0.7378 (0.7270 - 0.7485),0.7316 (0.7286 - 0.7347),0.7286 (0.7261 - 0.7311),0.7269 (0.7225 - 0.7313),0.7380 (0.7255 - 0.7504),0.7326 (0.7287 - 0.7364),0.7286 (0.7256 - 0.7317),0.7267 (0.7238 - 0.7296),0.7265 (0.7237 - 0.7293),0.7248 (0.7213 - 0.7282),0.7277 (0.7243 - 0.7310),0.7237 (0.7208 - 0.7266)


### Manual

#### Baseline

In [None]:
# Baseline: 5 drug views + 2 disease views, 10-fold cross-validation.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_five, Trr=Trr_five,
    Wdd_list=Wdd_two, Tdd=Tdd_two,
    folds=10,
)

#### 8+3

In [None]:
# 8 drug views + 3 disease views, evaluated with AMVL (the imported model).
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_eight,
    Wdd_list=Wdd_three,
    Trr=Trr_eight,
    Tdd=Tdd_three,
    folds=10
)

#### 8+2

In [None]:
# 8 + 2
train_eval(
    algorithm_func=AdaMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_eight,
    Wdd_list=Wdd_two,
    Trr=Trr_eight,
    Tdd=Tdd_two,
    folds=10
)

#### 7+3

In [None]:
# 7 (llms + kgs) drug views + 3 disease views, evaluated with AMVL.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_seven_llms_kgs,
    Wdd_list=Wdd_three,
    Trr=Trr_seven_llms_kgs,
    Tdd=Tdd_three,
    folds=10
)

In [None]:
# 7 (llms + geps) drug views + 3 disease views, evaluated with AMVL.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_seven_llms_geps,
    Wdd_list=Wdd_three,
    Trr=Trr_seven_llms_geps,
    Tdd=Tdd_three,
    folds=10
)

In [None]:
# 7 (kgs + geps) drug views + 3 disease views, evaluated with AMVL.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_seven_kgs_geps,
    Wdd_list=Wdd_three,
    Trr=Trr_seven_kgs_geps,
    Tdd=Tdd_three,
    folds=10
)

#### 7+2

In [None]:
# 7 (llms + kgs) drug views + 2 disease views, evaluated with AMVL.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_seven_llms_kgs,
    Wdd_list=Wdd_two,
    Trr=Trr_seven_llms_kgs,
    Tdd=Tdd_two,
    folds=10
)

In [None]:
# 7 (llms + geps) drug views + 2 disease views, evaluated with AMVL.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_seven_llms_geps,
    Wdd_list=Wdd_two,
    Trr=Trr_seven_llms_geps,
    Tdd=Tdd_two,
    folds=10
)

In [None]:
# 7 (kgs + geps) drug views + 2 disease views, evaluated with AMVL.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_seven_kgs_geps,
    Wdd_list=Wdd_two,
    Trr=Trr_seven_kgs_geps,
    Tdd=Tdd_two,
    folds=10
)

#### 6+3

In [None]:
# 6 (llms) drug views + 3 disease views, evaluated with AMVL.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_six_llms,
    Wdd_list=Wdd_three,
    Trr=Trr_six_llms,
    Tdd=Tdd_three,
    folds=10
)

In [None]:
# 6 (kgs) drug views + 3 disease views, evaluated with AMVL.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_six_kgs,
    Wdd_list=Wdd_three,
    Trr=Trr_six_kgs,
    Tdd=Tdd_three,
    folds=10
)

In [None]:
# 6 (geps) drug views + 3 disease views, evaluated with AMVL.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_six_geps,
    Wdd_list=Wdd_three,
    Trr=Trr_six_geps,
    Tdd=Tdd_three,
    folds=10
)

#### 6+2

In [None]:
# 6 (llms) drug views + 2 disease views, evaluated with AMVL.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_six_llms,
    Wdd_list=Wdd_two,
    Trr=Trr_six_llms,
    Tdd=Tdd_two,
    folds=10
)

In [None]:
# 6 (kgs) drug views + 2 disease views, evaluated with AMVL.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_six_kgs,
    Wdd_list=Wdd_two,
    Trr=Trr_six_kgs,
    Tdd=Tdd_two,
    folds=10
)

In [None]:
# 6 (geps) drug views + 2 disease views, evaluated with AMVL.
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_six_geps,
    Wdd_list=Wdd_two,
    Trr=Trr_six_geps,
    Tdd=Tdd_two,
    folds=10
)

#### 5+3

In [None]:
# 5 drug views + 3 disease views, evaluated with AMVL (the imported model).
train_eval(
    algorithm_func=AMVL,
    Wrd=Wrd,
    Wrr_list=Wrr_five,
    Wdd_list=Wdd_three,
    Trr=Trr_five,
    Tdd=Tdd_three,
    folds=10
)

### Case Study

In [13]:
# AMVL benchmark on the iDrug case-study dataset (results written under 'amvl_idrug').
# Kept in its own variable so the benchmark-dataset results in `amvl_benchmark`
# are not overwritten.
amvl_idrug_benchmark = benchmark_all_cases(AMVL, datasets=['iDrug'], filename='amvl_idrug')
amvl_idrug_benchmark

################## Benchmarking AMVL ##################


################## Benchmarking AMVL on iDrug ##################


################## Loading iDrug ##################


################## 8 + 3 ##################

Precomputing drug and disease similarity features...
Retrieving positive and negative sample indices...
Assigning positive samples to 10-fold cross-validation...

Test set prepared, positive interactions masked in Wrd matrix...
>>> Step 1: Starting BMC matrix completion...
>>> BMC converged in 129 iterations.
>>> Calculating GIP similarity matrices for disease-disease and drug-drug interactions...
>>> Removing diagonal elements to eliminate self-interactions...
>>> Step 2: Performing Matrix Factorization with similarity regularization...
>>> Matrix factorization completed in 303 iterations.
>>> Thresholding the recovered matrix based on the threshold value...
>>> Step 3: Starting multi-view learning...
>>> Performing multi-view learning with 8 drug views and 3 disea

combination,model,dataset,metric,8 + 3,8 + 2,7 + 3 (llms + kgs),7 + 3 (llms + geps),7 + 3 (kgs + geps),7 + 2 (llms + kgs),7 + 2 (llms + geps),7 + 2 (kgs + geps),6 + 3 (llms),6 + 3 (kgs),6 + 3 (geps),6 + 2 (llms),6 + 2 (kgs),6 + 2 (geps),5 + 3,5 + 2
0,AMVL,iDrug,auc,0.9687 (0.9682 - 0.9693),0.9672 (0.9666 - 0.9677),0.9681 (0.9675 - 0.9687),0.9680 (0.9674 - 0.9686),0.9682 (0.9676 - 0.9688),0.9667 (0.9660 - 0.9674),0.9666 (0.9661 - 0.9671),0.9666 (0.9660 - 0.9672),0.9673 (0.9665 - 0.9681),0.9671 (0.9664 - 0.9677),0.9668 (0.9664 - 0.9672),0.9659 (0.9650 - 0.9667),0.9657 (0.9652 - 0.9662),0.9656 (0.9651 - 0.9661),0.9662 (0.9656 - 0.9668),0.9653 (0.9646 - 0.9660)
1,AMVL,iDrug,aupr,0.9707 (0.9703 - 0.9711),0.9699 (0.9695 - 0.9704),0.9704 (0.9698 - 0.9710),0.9704 (0.9698 - 0.9709),0.9703 (0.9698 - 0.9708),0.9696 (0.9691 - 0.9701),0.9696 (0.9693 - 0.9700),0.9697 (0.9692 - 0.9701),0.9701 (0.9694 - 0.9708),0.9699 (0.9694 - 0.9704),0.9698 (0.9695 - 0.9702),0.9691 (0.9685 - 0.9698),0.9690 (0.9685 - 0.9695),0.9689 (0.9685 - 0.9694),0.9694 (0.9689 - 0.9700),0.9688 (0.9682 - 0.9693)
2,AMVL,iDrug,f1,0.7374 (0.7365 - 0.7383),0.7202 (0.7194 - 0.7210),0.7336 (0.7327 - 0.7346),0.7282 (0.7268 - 0.7297),0.7308 (0.7301 - 0.7314),0.7177 (0.7167 - 0.7188),0.7142 (0.7138 - 0.7145),0.7166 (0.7157 - 0.7176),0.7253 (0.7238 - 0.7267),0.7265 (0.7249 - 0.7280),0.7231 (0.7227 - 0.7235),0.7115 (0.7110 - 0.7120),0.7123 (0.7115 - 0.7131),0.7107 (0.7099 - 0.7115),0.7201 (0.7193 - 0.7210),0.7093 (0.7087 - 0.7100)


In [None]:
# Case study: top-100 AMVL predictions using the full 8 + 3 view combination.
# NOTE(review): Wrd / drug_name / disease_name come from the load cell above
# (dataset_name = 'Fdataset'); load 'iDrug' there instead if this case study is
# meant to match the iDrug benchmark in this section.
case_study(Wrd, Wrr_eight, Wdd_three, Trr_eight, Tdd_three, algorithm_func=AMVL, drug_names=drug_name, disease_names=disease_name, top=100)

## MLMC

In [13]:
def MLMC(Wrr_list, Wdd_list, Wrd):
    """Multi-view learning combined with BMC matrix completion.

    MVL_F is run twice with the same diagonal-free similarity views:
      1. on the observed association matrix ``Wrd``;
      2. on a BMC-completed version of ``Wrd.T`` in which only entries above
         the threshold are kept, transposed back before MVL_F.
    The element-wise maximum of the two score matrices is returned.

    Parameters
    ----------
    Wrr_list, Wdd_list : list of array
        Drug / disease similarity views. Each is copied before its diagonal
        is zeroed, so the caller's matrices are not modified.
    Wrd : array
        Observed association matrix.

    Returns
    -------
    array
        ``np.maximum`` of the two MVL_F score matrices.
    """
    # BMC hyperparameters.
    alpha_bmc, beta_bmc = 10, 10
    threshold_bmc = 0.8
    max_iter_bmc = 300
    tol1_bmc, tol2_bmc = 2 * 1e-3, 1 * 1e-5

    def _without_diagonal(views):
        # Private copies with self-similarity removed (fill_diagonal works in place).
        stripped = []
        for view in views:
            view = view.copy()
            np.fill_diagonal(view, 0)
            stripped.append(view)
        return stripped

    drug_views = _without_diagonal(Wrr_list)
    disease_views = _without_diagonal(Wdd_list)

    # Pass 1: multi-view learning on the observed associations.
    _, _, scores_observed = MVL_F(drug_views, disease_views, Wrd, 0.1, 0.1)

    # Pass 2: BMC completion of the transposed matrix, observed entries as the mask.
    Wdr = Wrd.T
    observed_mask = (Wdr != 0).astype(float)
    completed, _n_iter = BMC_F(alpha_bmc, beta_bmc, Wdr, observed_mask, tol1_bmc, tol2_bmc, max_iter_bmc, 0, 1)
    # Keep only confident completions, then rerun MVL_F in the original orientation.
    completed_confident = completed * (completed > threshold_bmc)
    _, _, scores_completed = MVL_F(drug_views, disease_views, completed_confident.T, 0.1, 0.1)

    return np.maximum(scores_observed, scores_completed)

In [None]:
# Benchmark MLMC over all view combinations on the shared `datasets` list.
mlmc_benchmark = benchmark_all_cases(MLMC, datasets)
mlmc_benchmark

## MSBMF

In [15]:
def MSBMF(Wrr_list, Wdd_list, Wrd):
    """Score drug/disease pairs with MSBMF_F using all similarity views.

    All drug views and all disease views are concatenated side by side
    (``np.hstack``, i.e. ``[W1, W2, ...]``). ``Wrd.T`` is then factorized
    into ``U`` and ``V`` with MSBMF_F. The reconstruction ``U @ V.T`` is
    transposed back to the orientation of ``Wrd``.
    """
    # Hyperparameters: lambda3 mirrors lambda2; the latent rank is 70% of the smaller dimension.
    lambda1, lambda2 = 0.1, 0.01
    lambda3 = lambda2
    rank = int(min(Wrd.shape) * 0.7)
    max_iter = 300
    tol1, tol2 = 2 * 1e-3, 1 * 1e-4

    drug_sim = np.hstack(Wrr_list)
    disease_sim = np.hstack(Wdd_list)

    U, V, _n_iter = MSBMF_F(Wrd.T, disease_sim, drug_sim, lambda1, lambda2, lambda3, rank, tol1, tol2, max_iter)
    return (U @ V.T).T

In [None]:
# Benchmark MSBMF over all view combinations on the shared `datasets` list.
msbmf_benchmark = benchmark_all_cases(MSBMF, datasets)
msbmf_benchmark

## ITRPCA

In [17]:
def ITRPCA(Wrd, Trr, Tdd):
    """Score drug/disease pairs with ITRPCA_F using the stacked similarity arrays.

    ITRPCA_F is applied to ``Wrd.T``, and its result is transposed back to the
    orientation of ``Wrd``.
    """
    p, K = 0.9, 30
    rat1, rat2 = 0.1, 0.2
    scores_transposed = ITRPCA_F(Trr, Tdd, Wrd.T, p, K, rat1, rat2)
    return scores_transposed.T

In [None]:
# Benchmark ITRPCA over all view combinations with 10 folds.
# ITRPCA takes the stacked Trr / Tdd arrays rather than the view lists.
itrpca_benchmark = benchmark_all_cases(ITRPCA, datasets, folds=10)
itrpca_benchmark

## HGIMC

In [19]:
def HGIMC(Wrr_list, Wdd_list, Wrd):
    """HGIMC pipeline: fBMC completion, fGRB on averaged similarities, then fHGI.

    Steps:
      1. Average all drug views and all disease views element-wise. The
         single-view "Base" variant used ``Wrr_list[0]`` / ``Wdd_list[1]``.
      2. Complete ``Wrd.T`` with fBMC, using its nonzero entries as the
         observation mask, and keep only entries above ``threshold``.
      3. Transform each averaged similarity with ``fGRB(., 0.5)`` and combine
         everything with fHGI.

    Returns the fHGI result transposed back to the orientation of ``Wrd``.
    """
    A_DR = Wrd.T

    drug_sim_avg = np.mean(Wrr_list, axis=0)
    disease_sim_avg = np.mean(Wdd_list, axis=0)

    alpha, beta, gamma = 10, 10, 0.1
    threshold = 0.1
    maxiter = 300
    tol1, tol2 = 2 * 1e-3, 1 * 1e-5

    observed_mask = (A_DR != 0).astype(float)
    A_bmc, _n_iter = fBMC(alpha, beta, A_DR, observed_mask, tol1, tol2, maxiter, 0, 1)
    A_DR_confident = A_bmc * (A_bmc > threshold)

    A_RR = fGRB(drug_sim_avg, 0.5)
    A_DD = fGRB(disease_sim_avg, 0.5)

    return fHGI(gamma, A_DD, A_RR, A_DR_confident).T


In [None]:
# Benchmark HGIMC over all view combinations on the shared `datasets` list.
hgimc_benchmark = benchmark_all_cases(HGIMC, datasets)
hgimc_benchmark

## DRPADC

In [21]:
def DRPADC(Wrd, Wrr_list, Wdd_list):
    """Score drug/disease pairs with weighted-imputation logistic factorization.

    The benchmark setting averages all similarity views. The paper setting used
    ``Wdd_list[1]`` (semantic similarity) and ``Wrr_list[0]`` (chemical
    similarity). Returns ``GetProbability`` applied to ``F @ G.T``.
    """
    rank = 150
    rnseed = 0
    lR, lM, lN = 0.1, 1.0, 0.1
    num_iter = 550
    learn_rate = 0.09

    # NOTE(review): if Wrd only holds 0/1 entries, this weight matrix is all ones
    # (uniform weighting) -- confirm that is intended.
    W = np.maximum(1, Wrd)

    disease_sim = np.mean(Wdd_list, axis=0)
    drug_sim = np.mean(Wrr_list, axis=0)

    F, G = WeightImputeLogFactorization(Wrd, drug_sim, disease_sim, W, rank, lR, lM, lN, num_iter, learn_rate, rnseed)
    return GetProbability(F @ G.T)

In [None]:
# Benchmark DRPADC over all view combinations on the shared `datasets` list.
drpadc_benchmark = benchmark_all_cases(DRPADC, datasets)
drpadc_benchmark

## VDA-GKSBMF

In [23]:
def VDA_GKSBMF(Wrd, Wrr_list, Wdd_list):
    """Score drug/disease pairs with A_VDA_GMSBMF using averaged similarity views.

    The benchmark setting averages all views. The paper setting used
    ``Wdd_list[1]`` / ``Wrr_list[0]``. ``Wrd.T`` is factorized, and the
    recovered matrix is transposed back to the orientation of ``Wrd``.
    """
    gm, w = 0.5, 0.3
    lambda1 = lambda2 = lambda3 = 1
    # Tolerances this tiny are presumably never reached, so maxiter bounds the run.
    tol1, tol2 = 2 * 1e-30, 2 * 1e-40
    maxiter = 400

    # Latent rank: 70% of the smaller matrix dimension.
    k = int(min(Wrd.shape) * 0.7)

    disease_sim = np.mean(Wdd_list, axis=0)
    drug_sim = np.mean(Wrr_list, axis=0)

    recovered = A_VDA_GMSBMF(Wrd.T, disease_sim, drug_sim, gm, w, lambda1, lambda2, lambda3, k, tol1, tol2, maxiter)
    return recovered.T

In [None]:
# Benchmark VDA-GKSBMF over all view combinations on the shared `datasets` list.
vda_gksbmf_benchmark = benchmark_all_cases(VDA_GKSBMF, datasets)
vda_gksbmf_benchmark

## Machine Learning

In [None]:
# Machine-learning baselines (ml_benchmark=True) on the LLM embeddings of the loaded dataset.
# The full 8 + 3 view combination is passed -- TODO confirm this is the intended one.
train_eval(
    drug_embeddings=drug_embeddings,
    disease_embeddings=disease_embeddings,
    Wrd=Wrd,
    Wrr_list=Wrr_eight,
    Wdd_list=Wdd_three,
    Trr=Trr_eight,
    Tdd=Tdd_three,
    ml_benchmark=True
)