# Comparing M6A ML models (HIFI read models)

In [1]:
# Get all the requisite imports
import torch
import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics as metrics
import pandas as pd
import requests
import io
import xgboost as xgb 
from m6a_calling import M6ANet
from torchsummary import summary

from sklearn.metrics import (confusion_matrix,
                             accuracy_score,
                             precision_score,
                             recall_score,
                             f1_score,
                             balanced_accuracy_score,
                             matthews_corrcoef,
                             roc_auc_score,
                             average_precision_score,
                             roc_curve,
                             precision_recall_curve
                             )

### Read in the HiFi test data

In [17]:
# HiFi data

data_path = "../data/deprecated-data-2022-10-19/m6A_test_other_half_hifi.npz"

train_val_data = np.load(data_path, allow_pickle=True)

# Get the dictionary from the containing relevant data
train_val_data = train_val_data['save_data_dict'][()]

# Load test features and labels
X_test = train_val_data['X_test']
y_test = train_val_data['y_test']

### Predict using the CNN model

In [18]:
m6a_model = torch.load("models/m6ANet_other_half_hifi.3.best.torch", map_location=torch.device('cpu'))

X_test = torch.Tensor(X_test)
test_pred_cnn = m6a_model.predict(X_test, device='cpu').detach().numpy()

### Predict using the XGBoost model

In [19]:
bst2 = xgb.Booster({'nthread': 4})  # init model
bst2.load_model('models/xgboost.otherhalf.hifi.100rds.json')  # load data

X_test_2d = X_test.reshape(X_test.shape[0], X_test.shape[1] * X_test.shape[2])

dtest = xgb.DMatrix(X_test_2d)
test_pred_xgb = bst2.predict(dtest)

### Plot ROC curve for XGBoost and CNN models

In [None]:
plt.figure(figsize=(10,5))
plt.axline([0,0], slope=1, c="black")

fpr, tpr, thresholds = metrics.roc_curve(y_test[:, 0], test_pred_cnn[:, 0], pos_label=1)
print(f"AUC-ROC CNN: {metrics.auc(fpr, tpr):.03f}")
plt.plot(fpr, tpr, c='gray', label=f"CNN-AUROC: {metrics.auc(fpr, tpr):.04f}")

fpr, tpr, thresholds = metrics.roc_curve(y_test[:, 0], test_pred_xgb, pos_label=1)
print(f"AUC-ROC XGBoost: {metrics.auc(fpr, tpr):.03f}")
plt.plot(fpr, tpr, c='blue', label=f"XGBoost-AUROC: {metrics.auc(fpr, tpr):.04f}")


plt.legend(bbox_to_anchor=(1.1, 1.05))
plt.title("ROC curve")
plt.xlabel("FPR")
plt.ylabel("TPR")
#plt.xlim([-0.1, 1.1])
#plt.ylim([0, 1.1])
plt.tight_layout()
plt.savefig("figures/roc_compare_hifi.png")
plt.plot()


### Focus on 1% FPR with the ROC curve

In [None]:
plt.figure(figsize=(10,5))
#plt.axline([0,0], slope=1, c="black")

fpr, tpr, thresholds = metrics.roc_curve(y_test[:, 0], test_pred_cnn[:, 0], pos_label=1)

idx_fpr = np.where(fpr <= 0.010)[0]

thr_last = thresholds[idx_fpr][-1]

print(f"AUC-ROC CNN: {metrics.auc(fpr, tpr):.04f}")
plt.plot(fpr[idx_fpr], tpr[idx_fpr], c='gray', label=f"CNN-AUROC: {metrics.auc(fpr, tpr):.04f}")

fpr, tpr, thresholds = metrics.roc_curve(y_test[:, 0], test_pred_xgb, pos_label=1)
idx_fpr = np.where(fpr <= 0.010)[0]

thr_last = thresholds[idx_fpr][-1]

print(f"AUC-ROC XGBoost: {metrics.auc(fpr, tpr):.04f}")
plt.plot(fpr[idx_fpr], tpr[idx_fpr], c='blue', label=f"XGBoost-AUROC: {metrics.auc(fpr, tpr):.04f}")


plt.legend(bbox_to_anchor=(1.1, 1.05))
plt.title("ROC curve")
plt.xlabel("FPR")
plt.ylabel("TPR")
#plt.xlim([-0.1, 1.1])
#plt.ylim([0, 1.1])
plt.tight_layout()
plt.savefig("figures/roc_compare_hifi_fdr_1.png")
plt.plot()


### Plot the PR curve using XGBoost and CNN predictions on HiFi data

In [None]:
plt.figure(figsize=(8,5))
plt.axline([0,0], slope=1, c="white")
plt.axline([0,1], slope=-1, c="black")
precision, recall, thresholds = metrics.precision_recall_curve(y_test[:, 0], test_pred_cnn[:, 0], pos_label=1)
print(f"AU-PR CNN: {metrics.average_precision_score(y_test[:, 0], test_pred_cnn[:, 0])}")
plt.plot(recall, precision, c='gray', label=f"CNN-AUPR: {metrics.average_precision_score(y_test[:, 0], test_pred_cnn[:, 0]):.03f}")

precision, recall, thresholds = metrics.precision_recall_curve(y_test[:, 0], test_pred_xgb, pos_label=1)
print(f"AU-PR XGBoost: {metrics.average_precision_score(y_test[:, 0], test_pred_xgb)}")
plt.plot(recall, precision, c='blue', label=f"XGBoost-AUPR: {metrics.average_precision_score(y_test[:, 0], test_pred_xgb):.03f}")


plt.legend(bbox_to_anchor=(1.1, 1.05))
plt.title("PR curve")
plt.xlabel("Recall")
plt.ylabel("Precision")
#plt.xlim([0, 1])
#plt.ylim([0, 1])
plt.tight_layout()
plt.savefig("figures/pr_compare_hifi.png")
plt.plot()

### Save model for RUST

In [None]:
m6a_model = torch.load("models/m6ANet_PS00075.best.torch", map_location=torch.device('cpu'))
example = torch.rand(1, 6, 15)
traced_script_module = torch.jit.trace(m6a_model, example)
traced_script_module.save("models/m6ANet_PS00075.best.torch.pt")


In [None]:
m6a_model = torch.load("models/m6ANet_other_half_hifi.3.best.torch", map_location=torch.device('cpu'))
example = torch.rand(1, 6, 15)
traced_script_module = torch.jit.trace(m6a_model, example)
traced_script_module.save("models/m6ANet_other_half_hifi.3.best.torch_nn.pt")


## Test on different data

In [13]:
val_path = "/net/noble/vol4/noble/user/anupamaj/proj/m6A-calling/data/PS00075_2_2022-10-17.npz"
input_size = 6
# Load validation data
val_data = np.load(val_path, allow_pickle=True)
X_val = val_data['features'][:, 0:input_size, :]
y_val = val_data['labels']
    
y_val_ohe = np.zeros((len(y_val), 2))
y_val_ohe[np.where(y_val == 1)[0], 0] = 1
y_val_ohe[np.where(y_val == 0)[0], 1] = 1

In [15]:
m6a_model = torch.load("models/m6ANet_PS00075_no_init.3.best.torch")
#, map_location=torch.device('cpu')
# Print model architecture summary
summary_str = summary(m6a_model, input_size=(6, 15))

X_val = torch.Tensor(X_val)
val_pred_cnn = m6a_model.predict(X_val, device='cpu').detach().numpy()

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv1d-1               [-1, 30, 11]             930
            Conv1d-2                [-1, 10, 7]           1,510
            Conv1d-3                 [-1, 5, 5]             155
            Linear-4                    [-1, 5]             130
            Linear-5                    [-1, 2]              12
Total params: 2,737
Trainable params: 2,737
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.00
Params size (MB): 0.01
Estimated Total Size (MB): 0.01
----------------------------------------------------------------


RuntimeError: Input type (torch.FloatTensor) and weight type (torch.cuda.FloatTensor) should be the same or input should be a MKLDNN tensor and weight is a dense tensor

In [None]:
bst2 = xgb.Booster({'nthread': 4})  # init model
bst2.load_model('models/xgboost.otherhalf.hifi.100rds.json')  # load data

X_val_2d = X_val.reshape(X_val.shape[0], X_val.shape[1] * X_val.shape[2])

dval = xgb.DMatrix(X_val_2d)
val_pred_xgb = bst2.predict(dval)

In [None]:
plt.figure(figsize=(10,5))
plt.axline([0,0], slope=1, c="black")

fpr, tpr, thresholds = metrics.roc_curve(y_val_ohe[:, 0], val_pred_cnn[:, 0], pos_label=1)
print(f"AUC-ROC CNN: {metrics.auc(fpr, tpr):.03f}")
plt.plot(fpr, tpr, c='gray', label=f"CNN-AUROC: {metrics.auc(fpr, tpr):.04f}")


plt.legend(bbox_to_anchor=(1.1, 1.05))
plt.title("ROC curve")
plt.xlabel("FPR")
plt.ylabel("TPR")
#plt.xlim([-0.1, 1.1])
#plt.ylim([0, 1.1])
plt.tight_layout()
plt.savefig("figures/roc_compare_hifi_PS00075.png")
plt.plot()


In [None]:
plt.figure(figsize=(10,5))
#plt.axline([0,0], slope=1, c="black")

fpr, tpr, thresholds = metrics.roc_curve(y_val_ohe[:, 0], val_pred_cnn[:, 0], pos_label=1)

idx_fpr = np.where(fpr <= 0.050)[0]

thr_last = thresholds[idx_fpr][-1]
print(thresholds[idx_fpr])

print(thr_last)

print(f"AUC-ROC CNN: {metrics.auc(fpr, tpr):.04f}")
plt.plot(fpr[idx_fpr], tpr[idx_fpr], c='gray', label=f"CNN-AUROC: {metrics.auc(fpr, tpr):.04f}")


plt.legend(bbox_to_anchor=(1.1, 1.05))
plt.title("ROC curve")
plt.xlabel("FPR")
plt.ylabel("TPR")
#plt.xlim([-0.1, 1.1])
#plt.ylim([0, 1.1])
plt.tight_layout()
plt.savefig("figures/roc_compare_hifi_fdr_1_PS00075.png")
plt.plot()


In [None]:
plt.figure(figsize=(8,5))
plt.axline([0,0], slope=1, c="white")
plt.axline([0,1], slope=-1, c="black")
precision, recall, thresholds = metrics.precision_recall_curve(y_val_ohe[:, 0], val_pred_cnn[:, 0], pos_label=1)
print(f"AU-PR CNN: {metrics.average_precision_score(y_val_ohe[:, 0], val_pred_cnn[:, 0])}")
plt.plot(recall, precision, c='gray', label=f"CNN-AUPR: {metrics.average_precision_score(y_val_ohe[:, 0], val_pred_cnn[:, 0]):.03f}")


plt.legend(bbox_to_anchor=(1.1, 1.05))
plt.title("PR curve")
plt.xlabel("Recall")
plt.ylabel("Precision")
#plt.xlim([0, 1])
#plt.ylim([0, 1])
plt.tight_layout()
plt.savefig("figures/pr_compare_hifi_PS00075.png")
plt.plot()

In [None]:
plt.figure(figsize=(8,5))
#plt.axline([0,0], slope=1, c="white")
#plt.axline([0,1], slope=-1, c="black")
precision, recall, thresholds = metrics.precision_recall_curve(y_val_ohe[:, 0], val_pred_cnn[:, 0], pos_label=1)
print(len(precision), len(thresholds))
print(f"AU-PR CNN: {metrics.average_precision_score(y_val_ohe[:, 0], val_pred_cnn[:, 0])}")
plt.plot(thresholds, precision[1:], c='gray', label=f"threshold vs precision")
plt.plot(thresholds, recall[1:], c='red', label=f"threshold vs recall")


plt.legend(bbox_to_anchor=(1.1, 1.05))
#plt.title("threshold vs precision")
plt.xlabel("thresholds")
#plt.ylabel("Precision")
#plt.xlim([0, 1])
#plt.ylim([0, 1])
plt.tight_layout()
plt.savefig("figures/pr_compare_hifi_PS00075.png")
plt.plot()

In [None]:
m6a_model = torch.load("models/m6ANet_PS00075_no_init.3.best.torch", map_location=torch.device('cpu'))
example = torch.rand(1, 6, 15)
traced_script_module = torch.jit.trace(m6a_model, example)
traced_script_module.save("models/m6ANet_PS00075_no_init.3.best.torch.pt")


In [4]:
m6a_model = torch.load("models/m6ANet_PS00075_semi_supervised.3.best.torch", map_location=torch.device('cpu'))
example = torch.rand(1, 6, 15)
traced_script_module = torch.jit.trace(m6a_model, example)
traced_script_module.save("models/m6ANet_PS00075_semi_supervised.3.best.torch.pt")