In [None]:
# 1. Data loading

import pandas as pd
# Read the three input CSVs, in the same order as before: drugs, diseases,
# then the drug-disease association table.
drugs, diseases, assoc = (
    pd.read_csv(csv_path)
    for csv_path in (
        'data/drugs.csv',
        'data/diseases.csv',
        'data/associations.csv',
    )
)

In [None]:
# 2. Feature Engineering

import numpy as np
from rdkit import Chem
from rdkit import DataStructs
from rdkit.Chem import AllChem
from sklearn.feature_extraction.text import TfidfVectorizer

# Drug fingerprint generation (1024-bit Morgan fingerprint, radius 2)
def mol_to_fp(smiles, radius=2, n_bits=1024):
    """Encode a SMILES string as a Morgan fingerprint in a NumPy array.

    Parameters
    ----------
    smiles : str
        SMILES representation of the molecule. Non-string values (for
        example the NaN pandas uses for a missing CSV cell) are treated
        the same way as an unparseable SMILES.
    radius : int, default 2
        Morgan fingerprint radius.
    n_bits : int, default 1024
        Length of the fingerprint.

    Returns
    -------
    numpy.ndarray
        1-D ``int8`` array of length ``n_bits`` holding the fingerprint
        bits. It is all zeros when ``smiles`` is not a string or RDKit
        cannot parse it, so every call returns the same type and shape.
    """
    fp_array = np.zeros((n_bits,), dtype=np.int8)

    # Missing values would make MolFromSmiles raise instead of returning None.
    if not isinstance(smiles, str):
        return fp_array

    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles signals a parse failure by returning None.
    if mol is None:
        return fp_array

    bit_vect = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=n_bits)
    # Copy the bits into the pre-allocated array so both the success and the
    # fallback path hand back a plain ndarray rather than an RDKit object.
    DataStructs.ConvertToNumpyArray(bit_vect, fp_array)
    return fp_array

# One fingerprint row per entry in drugs['smiles'].
# Fixes two typos that stopped this line from running: the helper is named
# `mol_to_fp`, and the column key is 'smiles'.
drug_fps = np.array([np.array(mol_to_fp(s)) for s in drugs['smiles']])

# Disease text encoding: fit a TF-IDF vectoriser (vocabulary capped at 300
# features) on the disease descriptions, then densify the sparse result.
vectoriser = TfidfVectorizer(max_features=300)
disease_tfidf = vectoriser.fit_transform(diseases['description'])
disease_vecs = disease_tfidf.toarray()

In [None]:
# 3. Merge and Prepare for training

# Attach drug and disease attributes to every association row.
# DataFrame.merge defaults to an inner join, so association rows whose
# drug_name / disease_name has no match are dropped.
merged = assoc.merge(drugs, on='drug_name').merge(diseases, on='disease_name')

if merged.empty:
    raise ValueError(
        "Merging associations with drugs and diseases produced no rows; "
        "check that 'drug_name' and 'disease_name' values match across the CSV files."
    )

# Drug half of each feature row: one fingerprint per merged row.
# (The column is 'smiles'; the previous 'smails' key raised a KeyError.)
drug_part = np.array([np.array(mol_to_fp(s)) for s in merged['smiles']])

# Disease half: a single batched transform with the already-fitted
# vectoriser replaces the former one-call-per-row loop.
disease_part = vectoriser.transform(merged['description']).toarray()

# Named X / y because those are the names the train/test split cell reads;
# the previous code appended to an undefined `x` and stored the labels as `Y`.
X = np.hstack([drug_part, disease_part])
y = merged['label']

assert len(X) == len(y), "feature matrix and labels must have one row per association"

In [None]:
# 4. Train Model

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split
# repeatable between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# 200-tree random forest; fit() returns the estimator, so construction and
# training are chained.
model = RandomForestClassifier(n_estimators=200, random_state=42).fit(X_train, y_train)

# Hard class predictions feed the accuracy figure.
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
print("Accuracy : ", accuracy)

# ROC-AUC is computed from column 1 of predict_proba rather than from the
# hard predictions.
positive_scores = model.predict_proba(X_test)[:, 1]
print("ROC-AUC : ", roc_auc_score(y_test, positive_scores))

In [None]:
# 5. Evaluation Plots

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve

# ROC curve on the held-out split, using column 1 of predict_proba as the score.
# The thresholds get a real name so IPython's `_` (last output) is not clobbered.
fpr, tpr, _thresholds = roc_curve(y_test, model.predict_proba(X_test)[:, 1])
fig, ax = plt.subplots()
ax.plot(fpr, tpr)
# Diagonal reference line: the ROC of a classifier that guesses at random.
ax.plot([0, 1], [0, 1], linestyle='--', color='grey')
ax.set(title='ROC Curve', xlabel='False positive rate', ylabel='True positive rate')
plt.show()

# Confusion matrix of the hard predictions.
# (`y_pred` is the prediction array; the previous `y_prod` raised a NameError.)
fig, ax = plt.subplots()
sns.heatmap(confusion_matrix(y_test, y_pred), annot=True, fmt='d', cmap='Blues', ax=ax)
# scikit-learn puts true labels on rows and predicted labels on columns.
ax.set(title='Confusion Matrix', xlabel='Predicted label', ylabel='True label')
plt.show()