In [None]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from sklearn.metrics import f1_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

Data prep

In [2]:
def fasta_to_dataframe(fasta_file):
    """Load a FASTA file into a DataFrame with one row per record.

    Each record id is split on '|'; fields 0, 2 and 3 of the split are stored
    in the ``id``, ``db`` and ``type`` columns, and the record's sequence is
    stored as a plain string in ``sequence``. A record id with fewer than four
    '|'-separated fields raises IndexError.
    """
    def _to_row(entry):
        # Split the header once and pick the positional fields we keep.
        header_fields = entry.id.split('|')
        return {
            "id": header_fields[0],
            "db": header_fields[2],
            "type": header_fields[3],
            "sequence": str(entry.seq),
        }

    return pd.DataFrame([_to_row(entry) for entry in SeqIO.parse(fasta_file, "fasta")])

# Parse the features FASTA into `data` (columns: id, db, type, sequence).
data = fasta_to_dataframe("../data/database/v1/features.fasta")

In [3]:
# Partition records by source database. The UNIPROT rows are later passed as
# the query set and the CARD/ARDB rows as the reference set when the feature
# matrix is built.
uniprot_data = data[data['db'] == 'UNIPROT']
card_ardb_data = data[data['db'].isin(['CARD', 'ARDB'])]

In [4]:
def generate_feature_matrix(scoring_path, query_df, reference_df, normalize=True):
    """Build a query-by-reference matrix of alignment bit scores.

    Rows follow the row order of ``query_df`` and columns follow the row order
    of ``reference_df``. Every sequence is keyed as ``"<id>_<type>"``; those
    keys must match the query/subject names that appear in the alignment file.
    Pairs with no reported hit keep a score of 0.0. When a pair is reported
    more than once, the highest bit score is kept.

    Parameters
    ----------
    scoring_path : str or path-like
        Whitespace-delimited tabular alignment output with at least 12 fields
        per line: query id first, subject id second, bit score twelfth.
        Blank lines and lines starting with '#' are skipped.
    query_df, reference_df : pandas.DataFrame
        Frames providing ``id`` and ``type`` columns.
    normalize : bool, default True
        If True, min-max scale every column with ``MinMaxScaler`` fitted on
        all query rows (the original behaviour). If False, return the raw
        bit scores so the caller can fit a scaler on the training split only.

    Returns
    -------
    numpy.ndarray
        Float matrix of shape ``(len(query_df), len(reference_df))``.

    Raises
    ------
    ValueError
        If a non-blank, non-comment line has fewer than 12 fields or its
        twelfth field is not a number.
    """
    query_ids = [f"{seq_id}_{seq_type}"
                 for seq_id, seq_type in zip(query_df["id"], query_df["type"])]
    target_ids = [f"{seq_id}_{seq_type}"
                  for seq_id, seq_type in zip(reference_df["id"], reference_df["type"])]

    # Score storage is indexed by *unique* key; duplicated keys share one
    # row/column and are expanded back out at the end, so repeated ids get
    # identical scores exactly as they did with the dict-of-dicts version.
    query_rows = {key: pos for pos, key in enumerate(dict.fromkeys(query_ids))}
    target_cols = {key: pos for pos, key in enumerate(dict.fromkeys(target_ids))}
    scores = np.zeros((len(query_rows), len(target_cols)), dtype=float)

    with open(scoring_path) as handle:
        for line_number, line in enumerate(handle, start=1):
            fields = line.split()
            if not fields or fields[0].startswith("#"):
                continue
            if len(fields) < 12:
                raise ValueError(
                    f"{scoring_path}: line {line_number} has {len(fields)} fields, expected at least 12"
                )
            row = query_rows.get(fields[0])
            col = target_cols.get(fields[1])
            if row is None or col is None:
                # Hit involves a sequence outside the requested query/reference sets.
                continue
            bit_score = float(fields[11])
            # Keep the best hit per pair instead of letting a later, weaker
            # hit overwrite a stronger one.
            if bit_score > scores[row, col]:
                scores[row, col] = bit_score

    # Expand back to one row per query record and one column per reference record.
    row_index = np.array([query_rows[key] for key in query_ids], dtype=np.intp)
    col_index = np.array([target_cols[key] for key in target_ids], dtype=np.intp)
    feature_matrix = scores[np.ix_(row_index, col_index)]

    if not normalize:
        return feature_matrix

    # NOTE(review): the scaler is fitted on every query row, so if the caller
    # splits into train/validation afterwards the validation rows influence
    # the scaling of the training features. Use normalize=False and fit the
    # scaler on the training split to avoid that.
    return MinMaxScaler().fit_transform(feature_matrix)

In [5]:
# Rows = UniProt queries, columns = CARD/ARDB references, values = scaled bit
# scores read from 'out.tsv'.
# NOTE(review): 'out.tsv' is not produced anywhere in the visible notebook —
# confirm how it was generated and that its ids use the "<id>_<type>" form.
feature_matrix = generate_feature_matrix('out.tsv', uniprot_data, card_ardb_data)


In [6]:
# 70/30 train/validation split with a fixed seed. The split is not stratified,
# so very rare classes may land entirely in one side.
X_train, X_val, y_train, y_val = train_test_split(feature_matrix, uniprot_data['type'], test_size=0.3, random_state=42)

In [7]:
# Map class names to integer codes. The encoder is fitted on the training
# labels only, so its classes_ are exactly the classes the model is trained on.
# NOTE(review): transform() will fail if the validation split contains a label
# that never occurs in the training split — possible here because the split is
# unstratified and some classes are very rare.
label_encoder = LabelEncoder()
y_train_encoded = label_encoder.fit_transform(y_train) 
y_val_encoded = label_encoder.transform(y_val)

In [8]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score

Fit a single logistic regression on the training split and evaluate it on the held-out validation split (no cross-validation)

In [9]:
# Baseline classifier: logistic regression on the scaled bit-score features.
logistic_model = LogisticRegression(max_iter=1000, random_state=42)
logistic_model.fit(X_train, y_train_encoded)

# Predict the held-out validation split and report overall accuracy.
y_val_pred = logistic_model.predict(X_val)
accuracy = accuracy_score(y_val_encoded, y_val_pred)
print(f"Validation Accuracy: {accuracy:.4f}")

# Restrict the report to classes that actually occur in the validation labels
# so that `labels` and `target_names` stay aligned one-to-one.
present_classes = sorted(set(y_val_encoded))

# zero_division=0: a class that is never predicted has undefined precision.
# It is reported as 0.00 (the same number the default prints) but without
# emitting an UndefinedMetricWarning for each affected metric.
class_report = classification_report(
    y_val_encoded,
    y_val_pred,
    labels=present_classes,
    target_names=label_encoder.inverse_transform(present_classes),
    zero_division=0,
)
print("Classification Report:\n")
print(class_report)

Validation Accuracy: 0.9925
Classification Report:

                                     precision    recall  f1-score   support

                     aminoglycoside       1.00      0.99      0.99       156
                         bacitracin       1.00      1.00      1.00      1206
                        beta_lactam       1.00      1.00      1.00      1071
                    chloramphenicol       1.00      0.98      0.99       122
                         fosfomycin       1.00      1.00      1.00        68
                       glycopeptide       0.00      0.00      0.00         9
macrolide-lincosamide-streptogramin       0.91      1.00      0.96       258
                          multidrug       1.00      0.96      0.98        47
                          mupirocin       0.00      0.00      0.00         1
                          polymyxin       1.00      1.00      1.00       246
                          quinolone       0.00      0.00      0.00         1
                       

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


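Feature weights behave as expected: for a given class, the reference sequences of that same type carry the largest positive coefficients, while the other references average close to zero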

In [34]:
# Coefficient row the model learned for the 'bacitracin' class. Indexing
# coef_ by the encoder position relies on the encoder and the model sharing
# the same class order.
bacitracin_index = label_encoder.classes_.tolist().index('bacitracin')
coefficients = logistic_model.coef_[bacitracin_index]

# One row per reference sequence (= one model feature): a positional name,
# the reference's type, and its coefficient; strongest positive weight first.
feature_weights = (
    pd.DataFrame({
        "Feature": [
            f"{reference_type}_{position}"
            for position, reference_type in enumerate(card_ardb_data.type.values)
        ],
        'arg_type': list(card_ardb_data.type.values),
        "Coefficient": coefficients,
    })
    .sort_values(by="Coefficient", ascending=False)
)

# Head of the sorted table = largest coefficients; tail = smallest.
print("Top positive weights for 'bacitracin':")
print(feature_weights.head(20))

print("\nTop negative weights for 'bacitracin':")
print(feature_weights.tail(10))

Top positive weights for 'bacitracin':
              Feature    arg_type  Coefficient
1312  bacitracin_1312  bacitracin     0.313136
1625  bacitracin_1625  bacitracin     0.300130
3057  bacitracin_3057  bacitracin     0.299529
1582  bacitracin_1582  bacitracin     0.298782
302    bacitracin_302  bacitracin     0.289606
1593  bacitracin_1593  bacitracin     0.284403
2256  bacitracin_2256  bacitracin     0.282740
1313  bacitracin_1313  bacitracin     0.278951
2357  bacitracin_2357  bacitracin     0.278115
1483  bacitracin_1483  bacitracin     0.276539
3630  bacitracin_3630  bacitracin     0.276062
1470  bacitracin_1470  bacitracin     0.275800
3546  bacitracin_3546  bacitracin     0.275439
128    bacitracin_128  bacitracin     0.274780
3495  bacitracin_3495  bacitracin     0.272342
1651  bacitracin_1651  bacitracin     0.272140
1014  bacitracin_1014  bacitracin     0.269386
2776  bacitracin_2776  bacitracin     0.268890
1324  bacitracin_1324  bacitracin     0.268694
527    bacitracin_527

In [35]:
# Mean 'bacitracin'-class coefficient over the features whose arg_type is 'bacitracin'.
feature_weights[feature_weights['arg_type'] == 'bacitracin'].Coefficient.mean()

0.19577667367490856

In [36]:
# Mean 'bacitracin'-class coefficient over all features of any other arg_type.
feature_weights[feature_weights['arg_type'] != 'bacitracin'].Coefficient.mean()

-0.0025781555416145934

In [25]:
# Locate the 'beta_lactam' row of the coefficient matrix via the encoder's
# class order (assumed to match the model's class order).
encoder_classes = label_encoder.classes_
beta_lactam_index = list(encoder_classes).index('beta_lactam')
del encoder_classes
coefficients = logistic_model.coef_[beta_lactam_index]

# Build the per-feature table: positional feature name, reference type and
# learned weight, ordered from most positive to most negative.
feature_weights = pd.DataFrame(
    {
        "Feature": [
            "{}_{}".format(ref_type, idx)
            for idx, ref_type in enumerate(card_ardb_data.type.values)
        ],
        'arg_type': [ref_type for ref_type in card_ardb_data.type.values],
        "Coefficient": coefficients,
    }
)
feature_weights = feature_weights.sort_values(by="Coefficient", ascending=False)

# The first rows hold the largest weights, the last rows the smallest.
print("Top positive weights for 'beta_lactam':")
print(feature_weights.head(20))

print("\nTop negative weights for 'beta_lactam':")
print(feature_weights.tail(10))

Top positive weights for 'beta_lactam':
               Feature     arg_type  Coefficient
11      beta_lactam_11  beta_lactam     0.756679
2993  beta_lactam_2993  beta_lactam     0.674931
1591  beta_lactam_1591  beta_lactam     0.674307
1668  beta_lactam_1668  beta_lactam     0.672002
3353  beta_lactam_3353  beta_lactam     0.659623
2891  beta_lactam_2891  beta_lactam     0.476856
970    beta_lactam_970  beta_lactam     0.471357
869    beta_lactam_869  beta_lactam     0.471275
1109  beta_lactam_1109  beta_lactam     0.469572
3279  beta_lactam_3279  beta_lactam     0.469572
2621  beta_lactam_2621  beta_lactam     0.443920
3713  beta_lactam_3713  beta_lactam     0.437421
3014  beta_lactam_3014  beta_lactam     0.429227
1896  beta_lactam_1896  beta_lactam     0.426277
182    beta_lactam_182  beta_lactam     0.417296
2110  beta_lactam_2110  beta_lactam     0.416372
3766  beta_lactam_3766  beta_lactam     0.411353
1578  beta_lactam_1578  beta_lactam     0.358287
328    beta_lactam_328  beta_

In [27]:
# Mean 'beta_lactam'-class coefficient over the features whose arg_type is 'beta_lactam'.
feature_weights[feature_weights['arg_type'] == 'beta_lactam'].Coefficient.mean()

0.09147049442144657

In [28]:
# Mean 'beta_lactam'-class coefficient over all features of any other arg_type.
feature_weights[feature_weights['arg_type'] != 'beta_lactam'].Coefficient.mean()

-0.009917753953551605

# Repeat for the short-read model (sequences cut into fixed-length reads)

In [24]:
def split_to_short_reads(fasta_file, output_file, read_length=33):
    """Cut every FASTA record into consecutive, non-overlapping reads.

    Windows start at offsets 0, read_length, 2*read_length, ...; only windows
    of exactly ``read_length`` characters are kept, so a shorter trailing
    remainder is discarded. Each read is named ``"<record id>_pos_<offset>"``
    and labelled with the part of the record id that follows its first
    underscore (empty string if the id has no underscore). All reads are
    written to ``output_file`` in FASTA format.

    Returns
    -------
    tuple
        ``(short_reads, read_ids, types)`` — three parallel lists holding the
        SeqRecord objects, their ids, and their labels.
    """
    short_reads, read_ids, types = [], [], []

    for record in SeqIO.parse(fasta_file, "fasta"):
        residues = str(record.seq)
        # Everything after the first '_' in the record id is the class label.
        label = record.id.partition('_')[2]

        # Stop early enough that every window is full length.
        for start in range(0, len(residues) - read_length + 1, read_length):
            read_id = f"{record.id}_pos_{start}"
            window = Seq(residues[start:start + read_length])
            short_reads.append(SeqRecord(window, id=read_id, description=""))
            read_ids.append(read_id)
            types.append(label)

    # Persist the reads so they can be aligned outside the notebook.
    SeqIO.write(short_reads, output_file, "fasta")
    return short_reads, read_ids, types

In [27]:
# Split the sequences in uniprot_sequences.fasta into reads (default length of
# the function defined above) and write them to short_reads.fasta.
# `types` holds one label per generated read.
input_fasta = "uniprot_sequences.fasta"
output_fasta = "short_reads.fasta"
short_reads,read_ids,types = split_to_short_reads(input_fasta, output_fasta)

In [28]:
import pickle
# Replace `feature_matrix` with a precomputed matrix loaded from disk.
# NOTE(review): presumably the short-read feature matrix — the code that
# created feature_matrix.pkl is not in the visible notebook; confirm.
# NOTE(security): unpickling can execute arbitrary code; only load pickle
# files that you produced yourself.
with open('feature_matrix.pkl', 'rb') as handle:
    feature_matrix = pickle.load(handle)

In [29]:
# 70/30 unstratified split of the loaded matrix against the per-read labels.
# NOTE(review): assumes row i of feature_matrix corresponds to types[i] —
# confirm the pickle was built from the same reads in the same order.
X_train, X_test, y_train, y_test = train_test_split(feature_matrix, types, test_size=0.3, random_state=42)

In [30]:
# Re-create the encoder for this experiment (overwrites the earlier
# `label_encoder`); it is fitted on the training labels only.
# NOTE(review): transform() will fail if y_test contains a label absent from y_train.
label_encoder = LabelEncoder()
y_train_encoded = label_encoder.fit_transform(y_train) 
y_test_encoded = label_encoder.transform(y_test)

In [31]:
# Same baseline classifier, refitted on the short-read features
# (overwrites the earlier `logistic_model`).
logistic_model = LogisticRegression(max_iter=1000, random_state=42)
logistic_model.fit(X_train, y_train_encoded)

# Predict the held-out test split and report overall accuracy.
y_test_pred = logistic_model.predict(X_test)
accuracy = accuracy_score(y_test_encoded, y_test_pred)
print(f"Validation Accuracy: {accuracy:.4f}")

# Restrict the report to classes that actually occur in the test labels so
# that `labels` and `target_names` stay aligned one-to-one.
present_classes = sorted(set(y_test_encoded))

# zero_division=0: report undefined precision/recall as 0.00 (what the default
# prints anyway) without emitting UndefinedMetricWarning.
class_report = classification_report(
    y_test_encoded,
    y_test_pred,
    labels=present_classes,
    target_names=label_encoder.inverse_transform(present_classes),
    zero_division=0,
)
print("Classification Report:\n")
print(class_report)

Validation Accuracy: 0.7982
Classification Report:

                                     precision    recall  f1-score   support

                     aminoglycoside       1.00      0.68      0.81      1063
                         bacitracin       0.61      1.00      0.76      9402
                        beta_lactam       1.00      0.77      0.87      9510
                    chloramphenicol       1.00      0.65      0.79       678
                         fosfomycin       1.00      0.79      0.89       292
                       fosmidomycin       1.00      1.00      1.00         3
                       glycopeptide       1.00      0.04      0.07        78
macrolide-lincosamide-streptogramin       1.00      0.47      0.64      4324
                          multidrug       1.00      0.79      0.88       585
                          mupirocin       1.00      0.13      0.23        39
                          polymyxin       1.00      0.82      0.90      3790
                       

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


# Repeat for DNA models

In [32]:
# Replace `feature_matrix` with the precomputed matrix stored in
# lr_dna_feature_matrix.pkl (its generation is not shown in the visible notebook).
# NOTE(security): unpickling can execute arbitrary code; only load trusted files.
with open('lr_dna_feature_matrix.pkl', 'rb') as handle:
    feature_matrix = pickle.load(handle)

In [33]:
# Load the DNA-level table and keep the UNIPROT rows. This rebinds
# `uniprot_data`, which previously held the protein-level frame.
dna_data = pd.read_csv('all_df_v2.csv')
uniprot_data = dna_data[dna_data['db'] == 'UNIPROT']

In [34]:
def contains_invalid_dna_bases(sequence):
    """Return True if ``sequence`` contains any character other than A, T, C or G.

    The check is case-insensitive; an empty sequence returns False.
    """
    unexpected = set(sequence.upper()) - {'A', 'T', 'C', 'G'}
    return bool(unexpected)

# Drop rows whose dna_seq contains anything other than A/T/C/G.
# NOTE(review): the feature matrix loaded above must have been built from this
# same filtered set, in the same row order, for the labels to line up — confirm.
uniprot_data = uniprot_data[~uniprot_data['dna_seq'].apply(contains_invalid_dna_bases)]

In [35]:
# 70/30 unstratified split of the DNA feature matrix against the per-sequence
# type labels. Note the seed here is 123, whereas the protein experiments use 42.
X_train, X_test, y_train, y_test = train_test_split(feature_matrix, uniprot_data.type, test_size=0.3, random_state=123)

# Fresh encoder fitted on the training labels only.
# NOTE(review): transform() will fail if y_test contains a label absent from y_train.
label_encoder = LabelEncoder()
y_train_encoded = label_encoder.fit_transform(y_train) 
y_test_encoded = label_encoder.transform(y_test)

In [36]:
# Same baseline classifier, refitted on the DNA features
# (overwrites the earlier `logistic_model`).
logistic_model = LogisticRegression(max_iter=1000, random_state=42)
logistic_model.fit(X_train, y_train_encoded)

# Predict the held-out test split and report overall accuracy.
y_test_pred = logistic_model.predict(X_test)
accuracy = accuracy_score(y_test_encoded, y_test_pred)
print(f"Validation Accuracy: {accuracy:.4f}")

# Restrict the report to classes that actually occur in the test labels so
# that `labels` and `target_names` stay aligned one-to-one.
present_classes = sorted(set(y_test_encoded))

# zero_division=0: report undefined precision/recall as 0.00 (what the default
# prints anyway) without emitting UndefinedMetricWarning.
class_report = classification_report(
    y_test_encoded,
    y_test_pred,
    labels=present_classes,
    target_names=label_encoder.inverse_transform(present_classes),
    zero_division=0,
)
print("Classification Report:\n")
print(class_report)

Validation Accuracy: 0.8995
Classification Report:

                                     precision    recall  f1-score   support

                     aminoglycoside       0.86      0.87      0.87       187
                         bacitracin       0.93      0.93      0.93      1145
                        beta_lactam       0.92      0.92      0.92      1044
                    chloramphenicol       0.84      0.83      0.84       124
                         fosfomycin       0.83      0.84      0.84        58
                       glycopeptide       0.00      0.00      0.00         6
macrolide-lincosamide-streptogramin       0.85      0.88      0.87       258
                          multidrug       0.82      0.77      0.80        48
                          mupirocin       0.00      0.00      0.00         1
                          polymyxin       0.81      0.78      0.80       242
                       tetracycline       0.00      0.00      0.00         1
                       

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [37]:
# Replace `feature_matrix` with the precomputed matrix stored in
# sr_dna_feature_matrix.pkl (its generation is not shown in the visible notebook).
# NOTE(security): unpickling can execute arbitrary code; only load trusted files.
with open('sr_dna_feature_matrix.pkl', 'rb') as handle:
    feature_matrix = pickle.load(handle)

In [38]:
def split_to_short_reads(fasta_file, output_file, read_length=100):
    """Split each FASTA record into back-to-back reads of ``read_length`` characters.

    Only complete windows are emitted; leftover characters at the end of a
    record that do not fill a whole window are dropped. A read's id is
    ``"<record id>_pos_<start offset>"`` and its label is the text after the
    first underscore of the record id. The reads are also written to
    ``output_file`` as FASTA.

    NOTE(review): this notebook defines a function with the same name earlier
    (default read_length=33); this definition replaces it for all later cells.

    Returns
    -------
    tuple
        ``(short_reads, read_ids, types)`` as three index-aligned lists.
    """
    short_reads = []
    read_ids = []
    types = []

    for record in SeqIO.parse(fasta_file, "fasta"):
        sequence = str(record.seq)
        total_length = len(sequence)
        arg_type = record.id.split('_', 1)[1] if '_' in record.id else ''

        # Candidate offsets step by read_length; keep those whose window fits.
        full_window_starts = [
            offset
            for offset in range(0, total_length, read_length)
            if offset + read_length <= total_length
        ]
        for offset in full_window_starts:
            read_id = f"{record.id}_pos_{offset}"
            fragment = SeqRecord(
                Seq(sequence[offset:offset + read_length]),
                id=read_id,
                description="",
            )
            short_reads.append(fragment)
            read_ids.append(read_id)
            types.append(arg_type)

    # Write all reads out in one go.
    SeqIO.write(short_reads, output_file, "fasta")
    return short_reads, read_ids, types

# Split the sequences in uniprot_dna_sequences.fasta into reads using the
# default length of the definition directly above, writing them to
# dna_short_reads.fasta. Rebinds short_reads/read_ids/types from the earlier run.
input_fasta = "uniprot_dna_sequences.fasta"
output_fasta = "dna_short_reads.fasta"
short_reads,read_ids,types = split_to_short_reads(input_fasta, output_fasta)

In [39]:
# 70/30 unstratified split of the loaded matrix against the per-read labels.
# NOTE(review): assumes row i of feature_matrix corresponds to types[i] —
# confirm the pickle was built from dna_short_reads.fasta in the same order.
X_train, X_test, y_train, y_test = train_test_split(feature_matrix, types, test_size=0.3, random_state=123)

# Fresh encoder fitted on the training labels only.
# NOTE(review): transform() will fail if y_test contains a label absent from y_train.
label_encoder = LabelEncoder()
y_train_encoded = label_encoder.fit_transform(y_train) 
y_test_encoded = label_encoder.transform(y_test)

In [40]:
# Same baseline classifier, refitted on the DNA short-read features
# (overwrites the earlier `logistic_model`).
logistic_model = LogisticRegression(max_iter=1000, random_state=42)
logistic_model.fit(X_train, y_train_encoded)

# Predict the held-out test split and report overall accuracy.
y_test_pred = logistic_model.predict(X_test)
accuracy = accuracy_score(y_test_encoded, y_test_pred)
print(f"Validation Accuracy: {accuracy:.4f}")

# Restrict the report to classes that actually occur in the test labels so
# that `labels` and `target_names` stay aligned one-to-one.
present_classes = sorted(set(y_test_encoded))

# zero_division=0: report undefined precision/recall as 0.00 (what the default
# prints anyway) without emitting UndefinedMetricWarning.
class_report = classification_report(
    y_test_encoded,
    y_test_pred,
    labels=present_classes,
    target_names=label_encoder.inverse_transform(present_classes),
    zero_division=0,
)
print("Classification Report:\n")
print(class_report)

Validation Accuracy: 0.7083
Classification Report:

                                     precision    recall  f1-score   support

                     aminoglycoside       0.92      0.56      0.69      1161
                         bacitracin       0.56      0.93      0.70      9415
                        beta_lactam       0.92      0.69      0.79      9351
                    chloramphenicol       0.82      0.40      0.53       761
                         fosfomycin       0.87      0.52      0.65       307
                       glycopeptide       1.00      0.06      0.12        32
macrolide-lincosamide-streptogramin       0.82      0.40      0.53      3897
                          multidrug       0.80      0.64      0.71       501
                          mupirocin       1.00      0.03      0.06        31
                          polymyxin       0.81      0.66      0.73      3370
                          quinolone       0.00      0.00      0.00         5
                       

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
