<a href="https://colab.research.google.com/github/bennyy117/Salmonella_AMR_Project/blob/main/SalmonellaAMR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

In [1]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multioutput import MultiOutputClassifier
from sklearn.metrics import classification_report, accuracy_score
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import os

data_path = '/content/drive/MyDrive/Salmonella_AMR_Project/'
os.makedirs(data_path, exist_ok=True)
print("Setup complete!")

Setup complete!


# Load data

In [3]:
file_name = 'salmonella_amr_data.csv'
df = pd.read_csv(data_path + file_name, low_memory=False)

print(df.shape)
print(df.columns)
df.head()

(211700, 21)
Index(['#Organism group', 'Strain', 'Isolate identifiers', 'Serovar',
       'Isolate', 'Create date', 'Location', 'Isolation source',
       'Isolation type', 'Food origin', 'SNP cluster', 'Min-same', 'Min-diff',
       'BioSample', 'Assembly', 'AMR genotypes', 'Computed types',
       'Unnamed: 17', 'Unnamed: 18', 'Unnamed: 19', 'Unnamed: 20'],
      dtype='object')


Unnamed: 0,#Organism group,Strain,Isolate identifiers,Serovar,Isolate,Create date,Location,Isolation source,Isolation type,Food origin,...,Min-same,Min-diff,BioSample,Assembly,AMR genotypes,Computed types,Unnamed: 17,Unnamed: 18,Unnamed: 19,Unnamed: 20
0,Salmonella enterica,SQ0227,"""93-6175B"",""SQ0227"",""SRS426868""",enteritidis,PDT000000002.3,2014-01-04T17:03:07Z,USA: Western Region,Ocean,environmental/other,,...,0,,SAMN02147118,,"ant(2'')-Ia=COMPLETE,aph(3')-Ia=COMPLETE,blaTE...","antigen_formula=9:g,m:-,serotype=Enteritidis",,,,
1,Salmonella enterica,SQ0228,"""93-2836A"",""SQ0228"",""SRS426867""",enteritidis,PDT000000003.3,2014-01-04T17:03:07Z,USA: Western Region,Ocean,environmental/other,,...,0,,SAMN02147119,,"mdsA=COMPLETE,mdsB=COMPLETE","antigen_formula=9:g,m:-,serotype=Enteritidis",,,,
2,Salmonella enterica,SQ0229,"""93-7741"",""SQ0229"",""SRS426869""",enteritidis,PDT000000004.3,2014-01-04T17:03:07Z,USA: Western Region,Ocean,environmental/other,,...,0,25.0,SAMN02147120,,"mdsA=COMPLETE,mdsB=COMPLETE","antigen_formula=9:g,m:-,serotype=Enteritidis",,,,
3,Salmonella enterica,Gen_001782,"""Gen_001782"",""SRS426891""",Heidelberg,PDT000000005.4,2014-01-04T17:03:07Z,USA,food,environmental/other,,...,2,2.0,SAMN02147121,GCA_010121905.1,"fosA7=COMPLETE,mdsA=COMPLETE,mdsB=COMPLETE","antigen_formula=4:r:1,2,serotype=Heidelberg",,,,
4,Salmonella enterica,Gen_001783,"""Gen_001783"",""SRS426892""",Heidelberg,PDT000000006.4,2014-01-04T17:03:07Z,USA,food,environmental/other,,...,2,12.0,SAMN02147122,GCA_010121865.1,"blaCMY-2=COMPLETE,fosA7=COMPLETE,mdsA=COMPLETE...","antigen_formula=4:r:1,2,serotype=Heidelberg",,,,


# Data processing

In [4]:
# Extract unique AMR genes and create binary features
def parse_amr_genotypes(row):
    if pd.isna(row): return {}
    genes = {}
    for item in str(row).split(','):
        item = item.strip()
        if '=' in item:
            gene, status = item.split('=')
            genes[gene] = 1 if status == 'COMPLETE' else 0
    return genes

# Apply parsing
amr_parsed = df['AMR genotypes'].apply(parse_amr_genotypes)

# Get all unique genes
all_genes = set()
for genes in amr_parsed:
    all_genes.update(genes.keys())
all_genes = sorted(all_genes)
print(f"Total unique AMR genes: {len(all_genes)}")

# Create binary dataframe
amr_df = pd.DataFrame(0, index=df.index, columns=all_genes)
for i, genes in enumerate(amr_parsed):
    for gene, val in genes.items():
        amr_df.at[i, gene] = val

# Combine with original (keep only features)
features_df = amr_df
print(features_df.sum().sort_values(ascending=False).head(20))  # Top common genes

Total unique AMR genes: 429
mdsA           204068
mdsB           203541
aph(6)-Id       25822
aph(3'')-Ib     25688
tet(A)          24231
sul2            21844
blaTEM-1        19646
tet(B)          16573
sul1            15455
floR            11705
aadA1            9409
aadA2            7331
aph(3')-Ia       6653
blaCMY-2         6492
fosA7            5350
dfrA14           3773
blaCARB-2        3597
aac(3)-IVa       3457
aph(4)-Ia        3392
fosA7.2          3296
dtype: int64


In [5]:
# Define key genes for antibiotics (same as before)
antibiotic_genes = {
    'Ampicillin': ['blaTEM-1', 'blaCMY-2', 'blaCARB-2'],
    'Tetracycline': ['tet(A)', 'tet(B)'],
    'Streptomycin': ['aph(3\'\')-Ib', 'aph(6)-Id', 'aadA1', 'aadA2', 'aph(3\')-Ia'],
    'Sulfonamides': ['sul1', 'sul2'],
    'Chloramphenicol': ['floR', 'catA1']
}

# Create labels (resistant if at least one key gene present)
labels_df = pd.DataFrame(0, index=features_df.index, columns=antibiotic_genes.keys())
for ab, key_genes_list in antibiotic_genes.items():
    available = [g for g in key_genes_list if g in features_df.columns]
    if available:
        labels_df[ab] = features_df[available].max(axis=1)
    resistant_count = labels_df[ab].sum()
    prevalence = resistant_count / len(features_df) * 100
    print(f"{ab}: {resistant_count} resistant ({prevalence:.2f}%)")

# Remove key genes from features to avoid direct leakage
all_key_genes = set()
for genes in antibiotic_genes.values():
    all_key_genes.update(genes)
key_genes_in_data = [g for g in all_key_genes if g in features_df.columns]

features_clean = features_df.drop(columns=key_genes_in_data)
print(f"\nRemoved {len(key_genes_in_data)} key genes from features.")
print(f"Remaining features: {features_clean.shape[1]} genes (including common ones like mdsA, mdsB)")

# Final data
X = features_clean.values.astype(np.float32)
y = labels_df.values.astype(np.float32)

Ampicillin: 28929 resistant (13.67%)
Tetracycline: 40003 resistant (18.90%)
Streptomycin: 40747 resistant (19.25%)
Sulfonamides: 33027 resistant (15.60%)
Chloramphenicol: 14621 resistant (6.91%)

Removed 14 key genes from features.
Remaining features: 415 genes (including common ones like mdsA, mdsB)


In [6]:
# Subsample data to speed up training
from sklearn.model_selection import train_test_split

sample_size = 20000

# Stratify on total resistant labels to keep balance
X_sub, _, y_sub, _ = train_test_split(X, y, train_size=sample_size, random_state=42, stratify=y.sum(axis=1))

print(f"Subsampled to {X_sub.shape[0]} samples.")
print("Resistant prevalence in subsample:")
print(pd.DataFrame(y_sub, columns=antibiotic_genes.keys()).mean() * 100)

X = X_sub
y = y_sub

Subsampled to 20000 samples.
Resistant prevalence in subsample:
Ampicillin         13.680000
Tetracycline       18.925001
Streptomycin       19.250000
Sulfonamides       15.705000
Chloramphenicol     6.755000
dtype: float32


# Random Forest with K-Fold



In [7]:
# K-Fold
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
import numpy as np

# 5-fold CV for Random Forest (multi-label)
n_folds = 5
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

# Since multi-label, stratify on sum of labels
fold_accuracies = {ab: [] for ab in antibiotic_genes.keys()}

for fold, (train_idx, test_idx) in enumerate(skf.split(X, y.sum(axis=1))):
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]

    rf = MultiOutputClassifier(RandomForestClassifier(n_estimators=200, random_state=42))
    rf.fit(X_train, y_train)

    y_pred = rf.predict(X_test)

    print(f"\nFold {fold+1}:")
    for i, ab in enumerate(antibiotic_genes.keys()):
        acc = accuracy_score(y_test[:, i], y_pred[:, i])
        fold_accuracies[ab].append(acc)
        print(f"{ab} Accuracy: {acc:.4f}")

# Average accuracies
print("\nAverage Accuracies across folds:")
for ab, accs in fold_accuracies.items():
    avg_acc = np.mean(accs)
    std_acc = np.std(accs)
    print(f"{ab}: {avg_acc:.4f} ± {std_acc:.4f}")


Fold 1:
Ampicillin Accuracy: 0.9147
Tetracycline Accuracy: 0.8730
Streptomycin Accuracy: 0.8942
Sulfonamides Accuracy: 0.9100
Chloramphenicol Accuracy: 0.9760

Fold 2:
Ampicillin Accuracy: 0.9183
Tetracycline Accuracy: 0.8690
Streptomycin Accuracy: 0.8892
Sulfonamides Accuracy: 0.9077
Chloramphenicol Accuracy: 0.9690

Fold 3:
Ampicillin Accuracy: 0.9097
Tetracycline Accuracy: 0.8648
Streptomycin Accuracy: 0.8938
Sulfonamides Accuracy: 0.9117
Chloramphenicol Accuracy: 0.9712

Fold 4:
Ampicillin Accuracy: 0.9105
Tetracycline Accuracy: 0.8748
Streptomycin Accuracy: 0.8888
Sulfonamides Accuracy: 0.9095
Chloramphenicol Accuracy: 0.9688

Fold 5:
Ampicillin Accuracy: 0.9160
Tetracycline Accuracy: 0.8672
Streptomycin Accuracy: 0.8922
Sulfonamides Accuracy: 0.9130
Chloramphenicol Accuracy: 0.9695

Average Accuracies across folds:
Ampicillin: 0.9139 ± 0.0033
Tetracycline: 0.8698 ± 0.0037
Streptomycin: 0.8916 ± 0.0023
Sulfonamides: 0.9104 ± 0.0018
Chloramphenicol: 0.9709 ± 0.0027


# PyTorch with K-Fold

In [8]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
import numpy as np

class AMRDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.float32)

    def __len__(self): return len(self.X)
    def __getitem__(self, idx): return self.X[idx], self.y[idx]

class AMRNet(nn.Module):
    def __init__(self, input_size, num_labels):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_labels),
            nn.Sigmoid()  # For multi-label
        )

    def forward(self, x):
        return self.net(x)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# 5-fold CV for PyTorch
n_folds = 5
skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=42)

fold_accuracies_pt = {ab: [] for ab in antibiotic_genes.keys()}

for fold, (train_idx, test_idx) in enumerate(skf.split(X, y.sum(axis=1))):
    print(f"\nFold {fold+1}:")

    X_train_fold = X[train_idx]
    X_test_fold = X[test_idx]
    y_train_fold = y[train_idx]
    y_test_fold = y[test_idx]

    train_dataset = AMRDataset(X_train_fold, y_train_fold)
    test_dataset = AMRDataset(X_test_fold, y_test_fold)
    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    model = AMRNet(X.shape[1], y.shape[1]).to(device)
    criterion = nn.BCELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    epochs = 20
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for batch_x, batch_y in train_loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            optimizer.zero_grad()
            outputs = model(batch_x)
            loss = criterion(outputs, batch_y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        avg_loss = total_loss / len(train_loader)
        if (epoch + 1) % 5 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

    model.eval()
    y_pred_pt = []
    y_true_pt = []
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.to(device)
            outputs = model(batch_x)
            y_pred_pt.append((outputs.cpu().numpy() > 0.5).astype(int))
            y_true_pt.append(batch_y.numpy())
    y_pred_pt = np.vstack(y_pred_pt)
    y_true_pt = np.vstack(y_true_pt)

    for i, ab in enumerate(antibiotic_genes.keys()):
        acc = accuracy_score(y_true_pt[:, i], y_pred_pt[:, i])
        fold_accuracies_pt[ab].append(acc)
        print(f"{ab} Accuracy: {acc:.4f}")

# Average accuracies
print("\nAverage Accuracies across folds:")
for ab, accs in fold_accuracies_pt.items():
    avg_acc = np.mean(accs)
    std_acc = np.std(accs)
    print(f"{ab}: {avg_acc:.4f} ± {std_acc:.4f}")

Using device: cuda

Fold 1:
Epoch 5/20, Loss: 0.2778
Epoch 10/20, Loss: 0.2713
Epoch 15/20, Loss: 0.2682
Epoch 20/20, Loss: 0.2668
Ampicillin Accuracy: 0.9143
Tetracycline Accuracy: 0.8712
Streptomycin Accuracy: 0.8958
Sulfonamides Accuracy: 0.9075
Chloramphenicol Accuracy: 0.9750

Fold 2:
Epoch 5/20, Loss: 0.2785
Epoch 10/20, Loss: 0.2711
Epoch 15/20, Loss: 0.2679
Epoch 20/20, Loss: 0.2662
Ampicillin Accuracy: 0.9187
Tetracycline Accuracy: 0.8695
Streptomycin Accuracy: 0.8902
Sulfonamides Accuracy: 0.9050
Chloramphenicol Accuracy: 0.9667

Fold 3:
Epoch 5/20, Loss: 0.2762
Epoch 10/20, Loss: 0.2697
Epoch 15/20, Loss: 0.2668
Epoch 20/20, Loss: 0.2644
Ampicillin Accuracy: 0.9085
Tetracycline Accuracy: 0.8640
Streptomycin Accuracy: 0.8915
Sulfonamides Accuracy: 0.9125
Chloramphenicol Accuracy: 0.9712

Fold 4:
Epoch 5/20, Loss: 0.2763
Epoch 10/20, Loss: 0.2702
Epoch 15/20, Loss: 0.2663
Epoch 20/20, Loss: 0.2645
Ampicillin Accuracy: 0.9075
Tetracycline Accuracy: 0.8752
Streptomycin Accuracy:

# Model training

In [9]:
final_dataset = AMRDataset(X, y)
final_loader = DataLoader(final_dataset, batch_size=64, shuffle=True)

final_model = AMRNet(X.shape[1], y.shape[1]).to(device)
optimizer = optim.Adam(final_model.parameters(), lr=0.001)
criterion = nn.BCELoss()

for epoch in range(30):
    final_model.train()
    total_loss = 0
    for batch_x, batch_y in final_loader:
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        optimizer.zero_grad()
        outputs = final_model(batch_x)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    if (epoch + 1) % 10 == 0:
        print(f"Final Epoch {epoch+1}/30, Loss: {total_loss/len(final_loader):.4f}")

torch.save(final_model.state_dict(), data_path + 'final_salmonella_amr_pytorch_model.pth')
print("Model saved.")

Final Epoch 10/30, Loss: 0.2700
Final Epoch 20/30, Loss: 0.2655
Final Epoch 30/30, Loss: 0.2641
Model saved.
