In [1]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from sklearn.metrics import f1_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

import subprocess
import pickle

# This notebook reproduces the DeepARG model using short read DNA sequences.

In [2]:
with open('sr_dna_feature_matrix.pkl', 'rb') as handle:
    feature_matrix = pickle.load(handle)

In [3]:
dna_data = pd.read_csv('all_df_v2.csv')
uniprot_data = dna_data[dna_data['db'] == 'UNIPROT']

In [4]:
def contains_invalid_dna_bases(sequence):
    valid_bases = {'A', 'T', 'C', 'G'}
    return any(base not in valid_bases for base in sequence.upper())

uniprot_data = uniprot_data[~uniprot_data['dna_seq'].apply(contains_invalid_dna_bases)]

In [5]:
def split_to_short_reads(fasta_file, output_file, read_length=100):
    short_reads = []
    read_ids = []
    types = []
    for record in SeqIO.parse(fasta_file, "fasta"):
        sequence = str(record.seq)
        arg_type = '_'.join(record.id.split('_')[1:])
        # Generate short reads
        for i in range(0, len(sequence), read_length):
            if i + read_length <= len(sequence):  # Ensure we don't exceed the sequence length
                short_read = Seq(sequence[i:i + read_length])
                read_id = f"{record.id}_pos_{i}"
                short_reads.append(SeqRecord(short_read, id=read_id, description=""))
                read_ids.append(read_id)
                types.append(arg_type)

    # Write the short reads to a new fasta file
    SeqIO.write(short_reads, output_file, "fasta")
    return short_reads,read_ids,types

input_fasta = "uniprot_dna_sequences.fasta"
output_fasta = "dna_short_reads.fasta"
short_reads,read_ids,types = split_to_short_reads(input_fasta, output_fasta)

In [6]:
X_train, X_test, y_train, y_test = train_test_split(feature_matrix, types, test_size=0.3, random_state=123)

label_encoder = LabelEncoder()
y_train_encoded = label_encoder.fit_transform(y_train) 
y_test_encoded = label_encoder.transform(y_test)

In [7]:
X_train_tensor = torch.tensor(X_train, dtype=torch.float32)
y_train_tensor = torch.tensor(y_train_encoded, dtype=torch.long)
X_test_tensor = torch.tensor(X_test, dtype=torch.float32)
y_test_tensor = torch.tensor(y_test_encoded, dtype=torch.long)

In [8]:
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

In [9]:
class DeepARGMLP(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(DeepARGMLP, self).__init__()
        self.fc1 = nn.Linear(input_dim, 2000)
        self.fc2 = nn.Linear(2000, 1000)
        self.fc3 = nn.Linear(1000, 500)
        self.fc4 = nn.Linear(500, 100)
        self.output = nn.Linear(100, output_dim)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout(x)
        x = torch.relu(self.fc3(x))
        x = self.dropout(x)
        x = torch.relu(self.fc4(x))
        x = self.dropout(x)
        x = self.output(x)
        return torch.softmax(x, dim=1)

In [10]:
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32)

In [11]:
input_dim = X_train.shape[1]
num_classes = uniprot_data['type'].nunique()

model = DeepARGMLP(input_dim=input_dim, output_dim=num_classes)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)

In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [13]:
kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
best_model = None
best_f1_score = 0  # Track the best F1 score

# Cross-validation loop
for fold, (train_index, val_index) in enumerate(kf.split(X_train_tensor, y_train_tensor)):
    print(f"Starting Fold {fold + 1}")
    
    # Split data into train and validation sets for this fold
    X_fold_train, X_fold_val = X_train_tensor[train_index], X_train_tensor[val_index]
    y_fold_train, y_fold_val = y_train_tensor[train_index], y_train_tensor[val_index]
    
    # Move data to the specified device
    X_fold_train = X_fold_train.to(device)
    X_fold_val = X_fold_val.to(device)
    y_fold_train = y_fold_train.to(device)
    y_fold_val = y_fold_val.to(device)
    
    # Create DataLoaders for this fold
    train_dataset = TensorDataset(X_fold_train, y_fold_train)
    val_dataset = TensorDataset(X_fold_val, y_fold_val)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=64)
    
    # Initialize model, loss, and optimizer
    model = DeepARGMLP(input_dim=X_train.shape[1], output_dim=num_classes).to(device)  # Move model to device
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, nesterov=True)
    
    # Train the model for each fold with 100 epochs
    epochs = 100
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)  # Move batch to device
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        
        # Evaluate on the fold validation set
        model.eval()
        val_loss = 0.0
        all_predictions = []
        all_labels = []
        with torch.no_grad():
            for X_val_batch, y_val_batch in val_loader:
                X_val_batch, y_val_batch = X_val_batch.to(device), y_val_batch.to(device)  # Move batch to device
                outputs = model(X_val_batch)
                loss = criterion(outputs, y_val_batch)
                val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                all_predictions.extend(predicted.cpu().numpy())  # Move to CPU for metric calculation
                all_labels.extend(y_val_batch.cpu().numpy())
        
        # Calculate F1 score for this epoch
        fold_f1_score = f1_score(all_labels, all_predictions, average='weighted')
        print(f"Fold {fold + 1} Epoch {epoch + 1}/{epochs}, "
              f"Loss: {running_loss/len(train_loader):.4f}, "
              f"Val Loss: {val_loss/len(val_loader):.4f}, "
              f"F1 Score: {fold_f1_score:.4f}")
    
    # Save the best model based on F1 score
    if fold_f1_score > best_f1_score:
        best_f1_score = fold_f1_score
        best_model = model 



Starting Fold 1
Fold 1 Epoch 1/100, Loss: 2.5513, Val Loss: 2.3823, F1 Score: 0.3503
Fold 1 Epoch 2/100, Loss: 2.3064, Val Loss: 2.2708, F1 Score: 0.4318
Fold 1 Epoch 3/100, Loss: 2.2717, Val Loss: 2.2653, F1 Score: 0.4320
Fold 1 Epoch 4/100, Loss: 2.2688, Val Loss: 2.2656, F1 Score: 0.4314
Fold 1 Epoch 5/100, Loss: 2.2684, Val Loss: 2.2640, F1 Score: 0.4323
Fold 1 Epoch 6/100, Loss: 2.2668, Val Loss: 2.2639, F1 Score: 0.4341
Fold 1 Epoch 7/100, Loss: 2.2666, Val Loss: 2.2651, F1 Score: 0.4324
Fold 1 Epoch 8/100, Loss: 2.2670, Val Loss: 2.2644, F1 Score: 0.4324
Fold 1 Epoch 9/100, Loss: 2.2666, Val Loss: 2.2651, F1 Score: 0.4317
Fold 1 Epoch 10/100, Loss: 2.2667, Val Loss: 2.2636, F1 Score: 0.4329
Fold 1 Epoch 11/100, Loss: 2.2661, Val Loss: 2.2615, F1 Score: 0.4354
Fold 1 Epoch 12/100, Loss: 2.2659, Val Loss: 2.2619, F1 Score: 0.4339
Fold 1 Epoch 13/100, Loss: 2.2653, Val Loss: 2.2611, F1 Score: 0.4349
Fold 1 Epoch 14/100, Loss: 2.2651, Val Loss: 2.2611, F1 Score: 0.4337
Fold 1 Epoch 

Fold 2 Epoch 18/100, Loss: 2.2011, Val Loss: 2.1966, F1 Score: 0.5476
Fold 2 Epoch 19/100, Loss: 2.1929, Val Loss: 2.1840, F1 Score: 0.5483
Fold 2 Epoch 20/100, Loss: 2.1609, Val Loss: 2.1476, F1 Score: 0.6406
Fold 2 Epoch 21/100, Loss: 2.1468, Val Loss: 2.1440, F1 Score: 0.6402
Fold 2 Epoch 22/100, Loss: 2.1434, Val Loss: 2.1434, F1 Score: 0.6392
Fold 2 Epoch 23/100, Loss: 2.1418, Val Loss: 2.1427, F1 Score: 0.6404
Fold 2 Epoch 24/100, Loss: 2.1413, Val Loss: 2.1425, F1 Score: 0.6393
Fold 2 Epoch 25/100, Loss: 2.1408, Val Loss: 2.1426, F1 Score: 0.6400
Fold 2 Epoch 26/100, Loss: 2.1404, Val Loss: 2.1423, F1 Score: 0.6399
Fold 2 Epoch 27/100, Loss: 2.1401, Val Loss: 2.1421, F1 Score: 0.6405
Fold 2 Epoch 28/100, Loss: 2.1399, Val Loss: 2.1422, F1 Score: 0.6402
Fold 2 Epoch 29/100, Loss: 2.1399, Val Loss: 2.1419, F1 Score: 0.6412
Fold 2 Epoch 30/100, Loss: 2.1396, Val Loss: 2.1418, F1 Score: 0.6401
Fold 2 Epoch 31/100, Loss: 2.1394, Val Loss: 2.1418, F1 Score: 0.6401
Fold 2 Epoch 32/100,

Fold 3 Epoch 35/100, Loss: 2.1327, Val Loss: 2.1260, F1 Score: 0.6700
Fold 3 Epoch 36/100, Loss: 2.1254, Val Loss: 2.1224, F1 Score: 0.6714
Fold 3 Epoch 37/100, Loss: 2.1227, Val Loss: 2.1207, F1 Score: 0.6732
Fold 3 Epoch 38/100, Loss: 2.1207, Val Loss: 2.1203, F1 Score: 0.6734
Fold 3 Epoch 39/100, Loss: 2.1199, Val Loss: 2.1197, F1 Score: 0.6740
Fold 3 Epoch 40/100, Loss: 2.1196, Val Loss: 2.1195, F1 Score: 0.6741
Fold 3 Epoch 41/100, Loss: 2.1193, Val Loss: 2.1194, F1 Score: 0.6741
Fold 3 Epoch 42/100, Loss: 2.1190, Val Loss: 2.1194, F1 Score: 0.6742
Fold 3 Epoch 43/100, Loss: 2.1188, Val Loss: 2.1194, F1 Score: 0.6738
Fold 3 Epoch 44/100, Loss: 2.1184, Val Loss: 2.1190, F1 Score: 0.6737
Fold 3 Epoch 45/100, Loss: 2.1183, Val Loss: 2.1189, F1 Score: 0.6739
Fold 3 Epoch 46/100, Loss: 2.1181, Val Loss: 2.1189, F1 Score: 0.6743
Fold 3 Epoch 47/100, Loss: 2.1181, Val Loss: 2.1190, F1 Score: 0.6745
Fold 3 Epoch 48/100, Loss: 2.1179, Val Loss: 2.1187, F1 Score: 0.6744
Fold 3 Epoch 49/100,

Fold 4 Epoch 52/100, Loss: 2.1199, Val Loss: 2.1131, F1 Score: 0.6790
Fold 4 Epoch 53/100, Loss: 2.1199, Val Loss: 2.1132, F1 Score: 0.6790
Fold 4 Epoch 54/100, Loss: 2.1197, Val Loss: 2.1127, F1 Score: 0.6795
Fold 4 Epoch 55/100, Loss: 2.1193, Val Loss: 2.1127, F1 Score: 0.6793
Fold 4 Epoch 56/100, Loss: 2.1192, Val Loss: 2.1124, F1 Score: 0.6796
Fold 4 Epoch 57/100, Loss: 2.1192, Val Loss: 2.1125, F1 Score: 0.6792
Fold 4 Epoch 58/100, Loss: 2.1191, Val Loss: 2.1122, F1 Score: 0.6798
Fold 4 Epoch 59/100, Loss: 2.1189, Val Loss: 2.1123, F1 Score: 0.6799
Fold 4 Epoch 60/100, Loss: 2.1190, Val Loss: 2.1120, F1 Score: 0.6804
Fold 4 Epoch 61/100, Loss: 2.1190, Val Loss: 2.1123, F1 Score: 0.6797
Fold 4 Epoch 62/100, Loss: 2.1190, Val Loss: 2.1124, F1 Score: 0.6801
Fold 4 Epoch 63/100, Loss: 2.1187, Val Loss: 2.1123, F1 Score: 0.6798
Fold 4 Epoch 64/100, Loss: 2.1188, Val Loss: 2.1124, F1 Score: 0.6795
Fold 4 Epoch 65/100, Loss: 2.1186, Val Loss: 2.1122, F1 Score: 0.6797
Fold 4 Epoch 66/100,

Fold 5 Epoch 69/100, Loss: 2.1180, Val Loss: 2.1164, F1 Score: 0.6788
Fold 5 Epoch 70/100, Loss: 2.1181, Val Loss: 2.1163, F1 Score: 0.6785
Fold 5 Epoch 71/100, Loss: 2.1181, Val Loss: 2.1166, F1 Score: 0.6782
Fold 5 Epoch 72/100, Loss: 2.1180, Val Loss: 2.1166, F1 Score: 0.6782
Fold 5 Epoch 73/100, Loss: 2.1180, Val Loss: 2.1168, F1 Score: 0.6782
Fold 5 Epoch 74/100, Loss: 2.1180, Val Loss: 2.1168, F1 Score: 0.6783
Fold 5 Epoch 75/100, Loss: 2.1178, Val Loss: 2.1165, F1 Score: 0.6784
Fold 5 Epoch 76/100, Loss: 2.1179, Val Loss: 2.1162, F1 Score: 0.6788
Fold 5 Epoch 77/100, Loss: 2.1179, Val Loss: 2.1164, F1 Score: 0.6785
Fold 5 Epoch 78/100, Loss: 2.1180, Val Loss: 2.1164, F1 Score: 0.6788
Fold 5 Epoch 79/100, Loss: 2.1179, Val Loss: 2.1163, F1 Score: 0.6788
Fold 5 Epoch 80/100, Loss: 2.1178, Val Loss: 2.1165, F1 Score: 0.6786
Fold 5 Epoch 81/100, Loss: 2.1178, Val Loss: 2.1162, F1 Score: 0.6789
Fold 5 Epoch 82/100, Loss: 2.1178, Val Loss: 2.1166, F1 Score: 0.6782
Fold 5 Epoch 83/100,

Fold 6 Epoch 86/100, Loss: 2.1172, Val Loss: 2.1229, F1 Score: 0.6717
Fold 6 Epoch 87/100, Loss: 2.1173, Val Loss: 2.1226, F1 Score: 0.6718
Fold 6 Epoch 88/100, Loss: 2.1173, Val Loss: 2.1227, F1 Score: 0.6719
Fold 6 Epoch 89/100, Loss: 2.1170, Val Loss: 2.1229, F1 Score: 0.6715
Fold 6 Epoch 90/100, Loss: 2.1172, Val Loss: 2.1227, F1 Score: 0.6716
Fold 6 Epoch 91/100, Loss: 2.1172, Val Loss: 2.1228, F1 Score: 0.6719
Fold 6 Epoch 92/100, Loss: 2.1173, Val Loss: 2.1231, F1 Score: 0.6714
Fold 6 Epoch 93/100, Loss: 2.1173, Val Loss: 2.1228, F1 Score: 0.6717
Fold 6 Epoch 94/100, Loss: 2.1173, Val Loss: 2.1225, F1 Score: 0.6718
Fold 6 Epoch 95/100, Loss: 2.1173, Val Loss: 2.1226, F1 Score: 0.6719
Fold 6 Epoch 96/100, Loss: 2.1171, Val Loss: 2.1224, F1 Score: 0.6720
Fold 6 Epoch 97/100, Loss: 2.1172, Val Loss: 2.1226, F1 Score: 0.6720
Fold 6 Epoch 98/100, Loss: 2.1172, Val Loss: 2.1226, F1 Score: 0.6718
Fold 6 Epoch 99/100, Loss: 2.1172, Val Loss: 2.1222, F1 Score: 0.6724
Fold 6 Epoch 100/100

Fold 8 Epoch 3/100, Loss: 2.2707, Val Loss: 2.2688, F1 Score: 0.4300
Fold 8 Epoch 4/100, Loss: 2.2681, Val Loss: 2.2684, F1 Score: 0.4341
Fold 8 Epoch 5/100, Loss: 2.2673, Val Loss: 2.2679, F1 Score: 0.4297
Fold 8 Epoch 6/100, Loss: 2.2664, Val Loss: 2.2675, F1 Score: 0.4317
Fold 8 Epoch 7/100, Loss: 2.2662, Val Loss: 2.2678, F1 Score: 0.4305
Fold 8 Epoch 8/100, Loss: 2.2657, Val Loss: 2.2673, F1 Score: 0.4330
Fold 8 Epoch 9/100, Loss: 2.2656, Val Loss: 2.2682, F1 Score: 0.4312
Fold 8 Epoch 10/100, Loss: 2.2653, Val Loss: 2.2677, F1 Score: 0.4300
Fold 8 Epoch 11/100, Loss: 2.2651, Val Loss: 2.2671, F1 Score: 0.4302
Fold 8 Epoch 12/100, Loss: 2.2649, Val Loss: 2.2663, F1 Score: 0.4320
Fold 8 Epoch 13/100, Loss: 2.2647, Val Loss: 2.2660, F1 Score: 0.4322
Fold 8 Epoch 14/100, Loss: 2.2643, Val Loss: 2.2656, F1 Score: 0.4333
Fold 8 Epoch 15/100, Loss: 2.2638, Val Loss: 2.2651, F1 Score: 0.4310
Fold 8 Epoch 16/100, Loss: 2.2624, Val Loss: 2.2615, F1 Score: 0.4325
Fold 8 Epoch 17/100, Loss: 

Fold 9 Epoch 21/100, Loss: 2.1449, Val Loss: 2.1363, F1 Score: 0.6446
Fold 9 Epoch 22/100, Loss: 2.1431, Val Loss: 2.1354, F1 Score: 0.6434
Fold 9 Epoch 23/100, Loss: 2.1418, Val Loss: 2.1351, F1 Score: 0.6446
Fold 9 Epoch 24/100, Loss: 2.1414, Val Loss: 2.1345, F1 Score: 0.6449
Fold 9 Epoch 25/100, Loss: 2.1409, Val Loss: 2.1344, F1 Score: 0.6451
Fold 9 Epoch 26/100, Loss: 2.1406, Val Loss: 2.1340, F1 Score: 0.6462
Fold 9 Epoch 27/100, Loss: 2.1405, Val Loss: 2.1341, F1 Score: 0.6449
Fold 9 Epoch 28/100, Loss: 2.1403, Val Loss: 2.1338, F1 Score: 0.6457
Fold 9 Epoch 29/100, Loss: 2.1401, Val Loss: 2.1336, F1 Score: 0.6459
Fold 9 Epoch 30/100, Loss: 2.1402, Val Loss: 2.1338, F1 Score: 0.6451
Fold 9 Epoch 31/100, Loss: 2.1400, Val Loss: 2.1340, F1 Score: 0.6451
Fold 9 Epoch 32/100, Loss: 2.1400, Val Loss: 2.1341, F1 Score: 0.6450
Fold 9 Epoch 33/100, Loss: 2.1400, Val Loss: 2.1344, F1 Score: 0.6449
Fold 9 Epoch 34/100, Loss: 2.1400, Val Loss: 2.1340, F1 Score: 0.6446
Fold 9 Epoch 35/100,

Fold 10 Epoch 38/100, Loss: 2.1332, Val Loss: 2.1422, F1 Score: 0.6559
Fold 10 Epoch 39/100, Loss: 2.1253, Val Loss: 2.1354, F1 Score: 0.6577
Fold 10 Epoch 40/100, Loss: 2.1201, Val Loss: 2.1335, F1 Score: 0.6591
Fold 10 Epoch 41/100, Loss: 2.1188, Val Loss: 2.1334, F1 Score: 0.6592
Fold 10 Epoch 42/100, Loss: 2.1182, Val Loss: 2.1332, F1 Score: 0.6595
Fold 10 Epoch 43/100, Loss: 2.1180, Val Loss: 2.1331, F1 Score: 0.6592
Fold 10 Epoch 44/100, Loss: 2.1174, Val Loss: 2.1333, F1 Score: 0.6589
Fold 10 Epoch 45/100, Loss: 2.1173, Val Loss: 2.1331, F1 Score: 0.6598
Fold 10 Epoch 46/100, Loss: 2.1171, Val Loss: 2.1332, F1 Score: 0.6588
Fold 10 Epoch 47/100, Loss: 2.1167, Val Loss: 2.1331, F1 Score: 0.6589
Fold 10 Epoch 48/100, Loss: 2.1169, Val Loss: 2.1330, F1 Score: 0.6591
Fold 10 Epoch 49/100, Loss: 2.1166, Val Loss: 2.1328, F1 Score: 0.6595
Fold 10 Epoch 50/100, Loss: 2.1166, Val Loss: 2.1325, F1 Score: 0.6602
Fold 10 Epoch 51/100, Loss: 2.1166, Val Loss: 2.1324, F1 Score: 0.6602
Fold 1

In [14]:
from sklearn.metrics import precision_recall_fscore_support

all_predictions = []
all_labels = []
with torch.no_grad():
    for X_val_batch, y_val_batch in val_loader:
        outputs = best_model(X_val_batch)
        _, predicted = torch.max(outputs, 1)
        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(y_val_batch.cpu().numpy())

# Calculate macro precision, recall, and F1 score, as well as per-class metrics
precision, recall, f1, support = precision_recall_fscore_support(
    all_labels, all_predictions, average=None, labels=range(num_classes)
)

avg_precision, avg_recall, avg_f1, avg_support = precision_recall_fscore_support(
    all_labels, all_predictions, average='micro'
)

# Calculate macro-averaged metrics (ignoring class imbalance)
macro_precision, macro_recall, macro_f1, _ = precision_recall_fscore_support(
    all_labels, all_predictions, average='macro'
)

weighted_precision, weighted_recall, weighted_f1, weighted_support = precision_recall_fscore_support(
    all_labels, all_predictions, average='weighted'
)

# Print macro-averaged metrics
print(f"Precision: {avg_precision}")
print(f"Recall: {avg_recall}")
print(f"F1 Score: {avg_f1}")

print(f"Macro Precision: {macro_precision}")
print(f"Macro Recall: {macro_recall}")
print(f"Macro F1 Score: {macro_f1}")

print(f"Weighted Precision: {weighted_precision}")
print(f"Weighted Recall: {weighted_recall}")
print(f"Weighted F1 Score: {weighted_f1}")

Precision: 0.6852897473997028
Recall: 0.6852897473997028
F1 Score: 0.6852897473997028
Macro Precision: 0.2859253353467196
Macro Recall: 0.25402642337706566
Macro F1 Score: 0.25928584016933554
Weighted Precision: 0.7018668175167255
Weighted Recall: 0.6852897473997028
Weighted F1 Score: 0.6623364565238284


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [15]:
class_names = label_encoder.classes_

for i, class_name in enumerate(class_names):
    print(f"Class '{class_name}': Precision: {precision[i]}, Recall: {recall[i]}, F1 Score: {f1[i]}, Support: {support[i]}")

Class 'aminoglycoside': Precision: 0.676595744680851, Recall: 0.6115384615384616, F1 Score: 0.6424242424242425, Support: 260
Class 'bacitracin': Precision: 0.5572498662386303, Recall: 0.9433876811594203, F1 Score: 0.7006390850992263, Support: 2208
Class 'beta_lactam': Precision: 0.9100985221674877, Recall: 0.6789159393661002, F1 Score: 0.7776900815574849, Support: 2177
Class 'chloramphenicol': Precision: 0.0, Recall: 0.0, F1 Score: 0.0, Support: 178
Class 'fosfomycin': Precision: 0.0, Recall: 0.0, F1 Score: 0.0, Support: 74
Class 'fosmidomycin': Precision: 0.0, Recall: 0.0, F1 Score: 0.0, Support: 2
Class 'glycopeptide': Precision: 0.0, Recall: 0.0, F1 Score: 0.0, Support: 7
Class 'macrolide-lincosamide-streptogramin': Precision: 0.7829977628635347, Recall: 0.3875968992248062, F1 Score: 0.5185185185185185, Support: 903
Class 'multidrug': Precision: 0.0, Recall: 0.0, F1 Score: 0.0, Support: 116
Class 'mupirocin': Precision: 0.0, Recall: 0.0, F1 Score: 0.0, Support: 6
Class 'polymyxin': 

In [16]:
model_path = "models/best_dna_sr_model.pth"

# Save the trained model
torch.save(best_model.state_dict(), model_path)
print(f"Model saved to {model_path}")

Model saved to models/best_dna_sr_model.pth


In [17]:
best_model = best_model.to(device)

output_path = "results/sr_dna_model_predictions.csv"

# After evaluating the best model on the holdout set and collecting predictions
all_predictions = []
all_labels = []
with torch.no_grad():
    for X_test_batch, y_test_batch in test_loader:
        # Move input batch and labels to the same device as the model
        X_test_batch = X_test_batch.to(device)
        y_test_batch = y_test_batch.to(device)
        
        outputs = best_model(X_test_batch)
        _, predicted = torch.max(outputs, 1)
        
        # Move predictions and labels to CPU for saving
        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(y_test_batch.cpu().numpy())

# Decode labels and predictions to their original class names
true_labels = label_encoder.inverse_transform(all_labels)
predicted_labels = label_encoder.inverse_transform(all_predictions)

# Include the ID (index) from y_test
ids = range(len(y_test))

# Create a DataFrame to store the outputs
outputs_df = pd.DataFrame({
    "ID": ids,
    "True Label": true_labels,
    "Predicted Label": predicted_labels
})

# Save the outputs to a CSV file
outputs_df.to_csv(output_path, index=False)
print(f"Predictions saved to {output_path}")

Predictions saved to results/sr_dna_model_predictions.csv
