# Task 1: Develop a machine learning method to identify RNA modifications from direct RNA-Seq data

Write a computational method that predicts m6A RNA modification from direct RNA-Seq data. The method should be able to train a new new model, and make predictions on unseen test data. Specifically, your method should fullfil the following requirements:

Your method should contain two scripts, one for model training, and one for making predictions. The prediction script will be evaluated by other students.

In [36]:
import json
import os
import sys
import gzip

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, balanced_accuracy_score

#### 1.1 Read M6A Labels

In [37]:
os.listdir("data/")

['data.info.labelled', 'dataset0.json.gz']

In [38]:
# I'm putting these files into my ~/data, if it's different for u, change the path accordingly :)
M6A_FILE_PATH = "data/data.info.labelled"
DIRECT_RNA_SEQ_DATA_FILE_PATH = "data/dataset0.json.gz"

In [39]:
# Read m6a labels
def read_m6A_labels(m6a_file_path):
    m6a_df = pd.read_csv(m6a_file_path, sep=",")
    m6a_df.columns = ["gene_id", "transcript_id", "transcript_position", "label"]
    return m6a_df

In [40]:
m6a_labels_df = read_m6A_labels(M6A_FILE_PATH)
m6a_labels_df.head()

Unnamed: 0,gene_id,transcript_id,transcript_position,label
0,ENSG00000004059,ENST00000000233,244,0
1,ENSG00000004059,ENST00000000233,261,0
2,ENSG00000004059,ENST00000000233,316,0
3,ENSG00000004059,ENST00000000233,332,0
4,ENSG00000004059,ENST00000000233,368,0


#### 1.2 Read direct rna seq data

In [65]:
def convert_nucleotide_to_index(nucleotide):
    nucleotide = nucleotide.lower()
    if nucleotide == "a":
        return 0
    elif nucleotide == "c":
        return 1
    elif nucleotide == "g":
        return 2
    elif nucleotide == "t":
        return 3

In [66]:
def read_direct_rna_seq_data(data_path):
    data = []
    with gzip.open(data_path, 'rt') as f:
        for line in f:
            line_data = json.loads(line)
            for transcript_id, position_data in line_data.items():
                for transcript_position, combined_nucleotides_data in position_data.items():
                    for combined_nucleotide, reads in combined_nucleotides_data.items():
                        nucleotide_1, nucleotide_2, nucleotide_3, nucleotide_4, nucleotide_5, nucleotide_6, nucleotide_7 = combined_nucleotide
                        nucleotide_1_idx, nucleotide_2_idx, nucleotide_3_idx, nucleotide_4_idx, nucleotide_5_idx, nucleotide_6_idx, nucleotide_7_idx = convert_nucleotide_to_index(nucleotide_1), convert_nucleotide_to_index(nucleotide_2), convert_nucleotide_to_index(nucleotide_3), convert_nucleotide_to_index(nucleotide_4), convert_nucleotide_to_index(nucleotide_5), convert_nucleotide_to_index(nucleotide_6), convert_nucleotide_to_index(nucleotide_7)
                        for read_idx, read in enumerate(reads):
                            data.append({
                                'transcript_id': transcript_id,
                                'transcript_position': int(transcript_position),
                                'combined_nucleotide': combined_nucleotide,
                                'nucleotide_1_index': nucleotide_1_idx,
                                'nucleotide_2_index': nucleotide_2_idx,
                                'nucleotide_3_index': nucleotide_3_idx,
                                'nucleotide_4_index': nucleotide_4_idx,
                                'nucleotide_5_index': nucleotide_5_idx,
                                'nucleotide_6_index': nucleotide_6_idx,
                                'nucleotide_7_index': nucleotide_7_idx,
                                'read_id': read_idx,
                                'x_1': read[0],
                                'x_2': read[1],
                                'x_3': read[2],
                                'x_4': read[3],
                                'x_5': read[4],
                                'x_6': read[5],
                                'x_7': read[6],
                                'x_8': read[7],
                                'x_9': read[8]
                            })

    df = pd.DataFrame(data)
    return df


In [68]:
rna_seq_data_df = read_direct_rna_seq_data(DIRECT_RNA_SEQ_DATA_FILE_PATH)

In [69]:
# rna_seq_data.head()
rna_seq_data_df.head(100)

Unnamed: 0,transcript_id,transcript_position,combined_nucleotide,nucleotide_1_index,nucleotide_2_index,nucleotide_3_index,nucleotide_4_index,nucleotide_5_index,nucleotide_6_index,nucleotide_7_index,read_id,x_1,x_2,x_3,x_4,x_5,x_6,x_7,x_8,x_9
0,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,0,0.00299,2.06,125.0,0.01770,10.40,122.0,0.00930,10.90,84.1
1,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,1,0.00631,2.53,125.0,0.00844,4.67,126.0,0.01030,6.30,80.9
2,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,2,0.00465,3.92,109.0,0.01360,12.00,124.0,0.00498,2.13,79.6
3,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,3,0.00398,2.06,125.0,0.00830,5.01,130.0,0.00498,3.78,80.4
4,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,4,0.00664,2.92,120.0,0.00266,3.94,129.0,0.01300,7.15,82.2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,95,0.00590,5.57,126.0,0.01200,11.20,127.0,0.00564,9.24,87.3
96,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,96,0.01200,3.73,124.0,0.02520,14.40,123.0,0.00510,4.16,81.2
97,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,97,0.01350,4.09,126.0,0.00540,5.71,127.0,0.00396,3.48,81.4
98,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,98,0.00830,4.17,121.0,0.00973,5.68,124.0,0.00316,1.60,81.8


#### 1.3 Merge direct rna seq data with labels

In [70]:
rna_seq_data_with_labels_df = rna_seq_data_df.merge(m6a_labels_df, on=["transcript_id", "transcript_position"], how="left")

In [71]:
rna_seq_data_with_labels_df.head()

Unnamed: 0,transcript_id,transcript_position,combined_nucleotide,nucleotide_1_index,nucleotide_2_index,nucleotide_3_index,nucleotide_4_index,nucleotide_5_index,nucleotide_6_index,nucleotide_7_index,...,x_2,x_3,x_4,x_5,x_6,x_7,x_8,x_9,gene_id,label
0,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,...,2.06,125.0,0.0177,10.4,122.0,0.0093,10.9,84.1,ENSG00000004059,0
1,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,...,2.53,125.0,0.00844,4.67,126.0,0.0103,6.3,80.9,ENSG00000004059,0
2,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,...,3.92,109.0,0.0136,12.0,124.0,0.00498,2.13,79.6,ENSG00000004059,0
3,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,...,2.06,125.0,0.0083,5.01,130.0,0.00498,3.78,80.4,ENSG00000004059,0
4,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,...,2.92,120.0,0.00266,3.94,129.0,0.013,7.15,82.2,ENSG00000004059,0


In [73]:
print(f"Number of rows with no labels: {len(rna_seq_data_with_labels_df[rna_seq_data_with_labels_df['label'].isnull()])}")

Number of rows with no labels: 0


##### Strategy 1: Get Aggregate of all the reads

In [75]:
aggregated_df = rna_seq_data_df.drop(columns=['read_id']).groupby(['transcript_id', 'transcript_position', 'combined_nucleotide','nucleotide_1_index', 'nucleotide_2_index', 'nucleotide_3_index', 'nucleotide_4_index', 'nucleotide_5_index', 'nucleotide_6_index', 'nucleotide_7_index']).mean().reset_index()
aggregated_with_labels_df = aggregated_df.merge(m6a_labels_df, on=["transcript_id", "transcript_position"], how="left")
aggregated_with_labels_df.head(100)

Unnamed: 0,transcript_id,transcript_position,combined_nucleotide,nucleotide_1_index,nucleotide_2_index,nucleotide_3_index,nucleotide_4_index,nucleotide_5_index,nucleotide_6_index,nucleotide_7_index,...,x_2,x_3,x_4,x_5,x_6,x_7,x_8,x_9,gene_id,label
0,ENST00000000233,244,AAGACCA,0,0,2,0,1,1,0,...,4.223784,123.702703,0.009373,7.382162,125.913514,0.007345,4.386989,80.570270,ENSG00000004059,0
1,ENST00000000233,261,CAAACTG,1,0,0,0,1,3,2,...,3.216424,109.681395,0.006813,3.226535,107.889535,0.007710,3.016599,94.290698,ENSG00000004059,0
2,ENST00000000233,316,GAAACAG,2,0,0,0,1,0,2,...,2.940541,105.475676,0.007416,3.642703,98.947027,0.007555,2.087146,89.364324,ENSG00000004059,0
3,ENST00000000233,332,AGAACAT,0,2,0,0,1,0,3,...,6.476350,129.355000,0.008632,2.899200,97.836500,0.006101,2.236520,89.154000,ENSG00000004059,0
4,ENST00000000233,368,AGGACAA,0,2,2,0,1,0,0,...,6.415051,117.924242,0.011479,5.870303,121.954545,0.010019,4.260253,85.178788,ENSG00000004059,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
95,ENST00000001008,2242,GAAACAA,2,0,0,0,1,0,0,...,3.393768,104.312077,0.009123,3.659807,99.084541,0.007017,2.583527,88.815942,ENSG00000004478,0
96,ENST00000002165,35,AGGACAG,0,2,2,0,1,0,2,...,7.024259,116.333333,0.009167,7.397778,119.327778,0.006937,2.754444,83.496296,ENSG00000001036,0
97,ENST00000002165,54,GGGACAT,2,2,2,0,1,0,3,...,3.403200,119.960000,0.008390,6.730200,121.180000,0.006244,3.840600,83.824000,ENSG00000001036,0
98,ENST00000002165,207,TTGACCA,3,3,2,0,1,1,0,...,3.563962,101.907547,0.005667,7.683396,116.830189,0.006573,4.287358,77.590566,ENSG00000001036,0


### 2. Train Test Split

In [76]:
features = aggregated_with_labels_df[['nucleotide_1_index', 'nucleotide_2_index', 'nucleotide_3_index', 'nucleotide_4_index', 'nucleotide_5_index', 'nucleotide_6_index', 'nucleotide_7_index','x_1', 'x_2', 'x_3', 'x_4', 'x_5', 'x_6', 'x_7', 'x_8', 'x_9']]
labels = aggregated_with_labels_df['label']

In [77]:
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.3, random_state=42)

In [78]:
print(f"X_train shape: {X_train.shape}")
print(f"X_test shape: {X_test.shape}")
print(f"y_train shape: {y_train.shape}")
print(f"y_test shape: {y_test.shape}")

X_train shape: (85286, 16)
X_test shape: (36552, 16)
y_train shape: (85286,)
y_test shape: (36552,)


From the numbers below, we can see that the labels data is extremely inbalanced

In [79]:
y_train.value_counts()

label
0    81427
1     3859
Name: count, dtype: int64

### 3. Model

In [80]:
model = LogisticRegression(class_weight='balanced') # added 'balanced' to class weights to mitigate the effects of imbalanced dataset
model.fit(X_train, y_train)

y_pred = model.predict(X_test)

STOP: TOTAL NO. of ITERATIONS REACHED LIMIT.

Increase the number of iterations (max_iter) or scale the data as shown in:
    https://scikit-learn.org/stable/modules/preprocessing.html
Please also refer to the documentation for alternative solver options:
    https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression
  n_iter_i = _check_optimize_result(


In [81]:
y_pred = model.predict(X_test)

In [82]:
import pickle

with open('y_pred.pkl', 'wb') as file:
    pickle.dump(y_pred, file)

### 4. Evaluation

Just going to use a basic balanced accuracy here for now. We can explore the usage of ROC next.

In [83]:
balanced_accuracy = balanced_accuracy_score(y_test, y_pred)
report = classification_report(y_test, y_pred)

In [84]:
print(f"Balanced Accuracy: {balanced_accuracy}")
print("Classification Report:")
print(report)

Balanced Accuracy: 0.7579866692588655
Classification Report:
              precision    recall  f1-score   support

           0       0.99      0.76      0.86     34936
           1       0.13      0.76      0.22      1616

    accuracy                           0.76     36552
   macro avg       0.56      0.76      0.54     36552
weighted avg       0.95      0.76      0.83     36552



# XGBoost Model

we first try an XGBoost model to see its performance

In [85]:
import xgboost as xgb

scale_pos_weight = len(y_train) / (sum(y_train == 1))

# Create an XGBoost Classifier
model = xgb.XGBClassifier(scale_pos_weight=scale_pos_weight, random_state=42)

# Train the model
model.fit(X_train, y_train)

# Make predictions
y_pred = model.predict(X_test)

In [86]:
# Evaluate the model
print(classification_report(y_test, y_pred))

balanced_accuracy = balanced_accuracy_score(y_test, y_pred)

print(balanced_accuracy)



              precision    recall  f1-score   support

           0       0.99      0.91      0.95     34936
           1       0.27      0.70      0.39      1616

    accuracy                           0.90     36552
   macro avg       0.63      0.81      0.67     36552
weighted avg       0.95      0.90      0.92     36552

0.8078679798080564


Now, we try gridsearch to optimize the parameters used during training

In [87]:
import xgboost as xgb
from sklearn.model_selection import RandomizedSearchCV
from sklearn.metrics import classification_report, balanced_accuracy_score
import numpy as np

# Define the parameter distribution
param_dist = {
    'n_estimators': [100, 200, 300],
    'max_depth': [3, 5, 7, 9],
    'learning_rate': np.arange(0.01, 0.3, 0.05),
    'subsample': np.arange(0.5, 1.0, 0.1),
    'colsample_bytree': np.arange(0.5, 1.0, 0.1),
    'gamma': [0, 0.1, 0.2, 0.3],
    'reg_alpha': [0, 0.01, 0.1, 1],
    'reg_lambda': [0, 0.01, 0.1, 1]
}

# Create the XGBoost classifier
xgb_model = xgb.XGBClassifier(scale_pos_weight=scale_pos_weight, random_state=42)

# Create the RandomizedSearchCV object
random_search = RandomizedSearchCV(estimator=xgb_model,
                                   param_distributions=param_dist,
                                   scoring='balanced_accuracy',
                                   n_iter=100,  # Number of parameter settings to sample
                                   cv=3,
                                   n_jobs=-1,
                                   random_state=42)

# Fit the random search to the data
random_search.fit(X_train, y_train)

# Get the best parameters
best_params = random_search.best_params_
print("Best Parameters:", best_params)

# Use the best parameters to create a new model
best_model = xgb.XGBClassifier(**best_params, scale_pos_weight=scale_pos_weight, random_state=42)

# Train the best model
best_model.fit(X_train, y_train)

# Make predictions
y_pred_best = best_model.predict(X_test)

# Evaluate the best model
print(classification_report(y_test, y_pred_best))
balanced_accuracy_best = balanced_accuracy_score(y_test, y_pred_best)
print("Balanced Accuracy:", balanced_accuracy_best)


Best Parameters: {'subsample': 0.7999999999999999, 'reg_lambda': 0.1, 'reg_alpha': 1, 'n_estimators': 300, 'max_depth': 3, 'learning_rate': 0.21000000000000002, 'gamma': 0.1, 'colsample_bytree': 0.7999999999999999}
              precision    recall  f1-score   support

           0       0.99      0.86      0.92     34936
           1       0.20      0.78      0.32      1616

    accuracy                           0.85     36552
   macro avg       0.59      0.82      0.62     36552
weighted avg       0.95      0.85      0.89     36552

Balanced Accuracy: 0.8164444829243629


The performance is significantly higher than the initial logistics regression model, so i think this can be our benchmark model