## ResLens: Antibiotic Resistance Gene Prediction Example

This notebook demonstrates how to use the ResLens models to:
1. Identify sequences that are likely to be antibiotic resistance genes (ARGs)
2. Classify each identified ARG by the antibiotic class it confers resistance to (e.g. tetracycline, beta-lactam, aminoglycoside)

In [1]:
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoConfig
from torch.nn.functional import softmax
import numpy as np

Load the example data

In [27]:
# Read the example sequences into a DataFrame and render the whole table.
df = pd.read_csv("example_data.csv")
display(df)

Unnamed: 0,seq,class
0,ATGAAAATAATTAACTTAGGCATTCTGGCTCACGTTGACGCAGGAA...,tetracycline
1,GTGACATTGAAATCCCCACTGCCACCGCAATCCGTCTCCGCACCCG...,MLS
2,TTGAAAAAATTAATAATTTTAGTCGTGTTAGCGTTGATATTAAGTG...,beta_lactam
3,ATGTTGAAAAGTTCGTGGCGTAAAACCGCCCTGATGGCCGCCGCCG...,beta_lactam
4,ATGCGCGGTAAACACACTGTCATTCTGGGCGCGGCACTGTCGGCGC...,beta_lactam
5,ATGGGCATCATTCGCACATGTAGGCTCGGCCCTGACCAAGTCAAAT...,aminoglycoside
6,ATGACAGAGCAGCAGTGGAATTTCGCGGGTATCGAGGCCGCGGCAA...,non_ARG
7,ATGGCTATCGACGAAAACAAACAGAAAGCGTTGGCGGCAGCACTGG...,non_ARG
8,ATGTTTGAACCAATGGAACTTACCAATGACGCGGTGATTAAAGTCA...,non_ARG
9,ATGAGAGATTTATTATCTAAAAAAAGTCATAGGCAATTAGAATTAT...,non_ARG


Load models

In [28]:
# Use a CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

def load_model_and_tokenizer(model_path, device):
    """Load the tokenizer, config and sequence-classification model for a checkpoint.

    Args:
        model_path: Identifier or path handed to each ``from_pretrained`` call.
        device: Device the model is moved to after loading.

    Returns:
        Tuple ``(tokenizer, model, config)``; the model is on ``device`` and has
        been switched to evaluation mode.
    """
    # Every loader is called with trust_remote_code=True, which lets custom
    # code shipped with the checkpoint run locally.
    remote_code = {"trust_remote_code": True}

    tokenizer = AutoTokenizer.from_pretrained(model_path, **remote_code)
    config = AutoConfig.from_pretrained(model_path, **remote_code)

    classifier = AutoModelForSequenceClassification.from_pretrained(
        model_path, config=config, **remote_code
    )
    classifier = classifier.to(device)
    classifier.eval()

    return tokenizer, classifier, config

# Checkpoint identifiers handed to from_pretrained
# (presumably Hugging Face Hub repo ids - confirm).
binary_model_path = "omicseye/resLens_lr_binary"
multiclass_model_path = "omicseye/resLens_lr_multiclass"

# Binary model: the returned config is discarded.
print("Loading binary model...")
bin_tok, bin_model, _ = load_model_and_tokenizer(binary_model_path, device)

# Multiclass model: the config is kept because its id2label mapping is used
# later to turn class indices into label names.
print("Loading multiclass model...")
multi_tok, multi_model, multi_config = load_model_and_tokenizer(
    multiclass_model_path, device
)

Using device: cpu
Loading binary model...
Loading multiclass model...


Tokenize and make ARG vs non-ARG predictions

In [29]:
def batch_predict(model, tokenizer, sequences, device, batch_size=16):
    """Run ``model`` over ``sequences`` in mini-batches and collect its predictions.

    Args:
        model: Callable invoked as ``model(**encoded)``; its result must expose a
            ``.logits`` tensor whose last dimension indexes the classes.
        tokenizer: Callable invoked with a list of strings plus the ``padding``,
            ``truncation``, ``max_length`` and ``return_tensors`` keywords; its
            result must support ``.to(device)`` and ``**`` unpacking.
        sequences: Iterable of input strings (list, pandas Series, generator, ...).
        device: Device the encoded batch is moved to before the forward pass.
        batch_size: Number of sequences per forward pass; must be at least 1.

    Returns:
        Tuple ``(all_preds, all_probs)``. ``all_preds`` holds one argmax class
        index per input sequence and ``all_probs`` holds the matching list of
        softmax probabilities, both in input order.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

    # Materialise once so that any iterable works and batching is always
    # positional, independent of e.g. a pandas index.
    sequences = list(sequences)

    # transformers reports a huge placeholder (int(1e30)) when a checkpoint does
    # not define a maximum length; passing that through as max_length can
    # overflow in fast tokenizers. Treat implausibly large values as "not set"
    # and let the tokenizer apply its own default.
    unset_length_threshold = 10**9
    maxlen = getattr(tokenizer, "model_max_length", None)
    if maxlen is not None and maxlen > unset_length_threshold:
        maxlen = None

    all_preds = []
    all_probs = []

    for start in range(0, len(sequences), batch_size):
        batch = sequences[start : start + batch_size]
        enc = tokenizer(
            batch,
            padding=True,
            truncation=True,
            max_length=maxlen,
            return_tensors="pt",
        ).to(device)

        # Inference only: skip gradient tracking.
        with torch.no_grad():
            logits = model(**enc).logits
            probs = softmax(logits, dim=-1)
            preds = torch.argmax(logits, dim=-1)

        all_preds.extend(preds.cpu().tolist())
        all_probs.extend(probs.cpu().tolist())

    return all_preds, all_probs

print("Running binary predictions...")
seqs = df["seq"]
bin_preds, bin_probs = batch_predict(bin_model, bin_tok, seqs, device)

# Class index 0 is reported as "ARG"; any other index as "non_ARG".
df["binary_pred"] = ["ARG" if pred == 0 else "non_ARG" for pred in bin_preds]

# NOTE(review): binary_prob is the softmax probability at class index 1, i.e.
# the class labelled "non_ARG" above - so rows predicted "ARG" carry values
# near 0. Confirm this is the intended meaning of the column.
df["binary_prob"] = [class_probs[1] for class_probs in bin_probs]
print("Binary predictions complete")

Running binary predictions...
Binary predictions complete


Tokenize and make ARG class predictions

In [30]:
# Rows the binary model flagged as ARGs; only these go to the multiclass model.
is_arg = df["binary_pred"] == "ARG"
arg_idx = df.index[is_arg].tolist()

# Defaults for rows that are not sent to the multiclass model. The probability
# column starts as NaN so it keeps a float dtype instead of becoming an
# object column.
df["arg_pred"] = "non_ARG"
df["arg_pred_prob"] = np.nan

if arg_idx:
    # Read the sequences from df directly rather than from a variable left
    # behind by an earlier cell.
    arg_seqs = df.loc[is_arg, "seq"].tolist()
    print(f"Running multiclass predictions on {len(arg_seqs)} predicted ARGs...")
    multi_preds, multi_probs = batch_predict(multi_model, multi_tok, arg_seqs, device)

    # Map numeric IDs to label names; unknown IDs fall back to their string form.
    id2label = {int(k): v for k, v in multi_config.id2label.items()}
    mapped = [id2label.get(pred, str(pred)) for pred in multi_preds]

    # Assign through the boolean mask: values are written positionally in row
    # order, so this also works when the index is not unique.
    df.loc[is_arg, "arg_pred"] = mapped
    # Probability of the predicted (highest-scoring) class.
    df.loc[is_arg, "arg_pred_prob"] = [max(probs) for probs in multi_probs]

print("Multiclass predictions complete")

Running multiclass predictions on 6 predicted ARGs...
Multiclass predictions complete


In [31]:
# Show the final table: the input columns plus binary_pred, binary_prob,
# arg_pred and arg_pred_prob added by the prediction cells.
df

Unnamed: 0,seq,class,binary_pred,binary_prob,arg_pred,arg_pred_prob
0,ATGAAAATAATTAACTTAGGCATTCTGGCTCACGTTGACGCAGGAA...,tetracycline,ARG,0.000418,tetracycline,0.991753
1,GTGACATTGAAATCCCCACTGCCACCGCAATCCGTCTCCGCACCCG...,MLS,ARG,0.000545,MLS,0.992611
2,TTGAAAAAATTAATAATTTTAGTCGTGTTAGCGTTGATATTAAGTG...,beta_lactam,ARG,0.000426,beta_lactam,0.999743
3,ATGTTGAAAAGTTCGTGGCGTAAAACCGCCCTGATGGCCGCCGCCG...,beta_lactam,ARG,0.000411,beta_lactam,0.999751
4,ATGCGCGGTAAACACACTGTCATTCTGGGCGCGGCACTGTCGGCGC...,beta_lactam,ARG,0.000431,beta_lactam,0.999747
5,ATGGGCATCATTCGCACATGTAGGCTCGGCCCTGACCAAGTCAAAT...,aminoglycoside,ARG,0.000456,aminoglycoside,0.997469
6,ATGACAGAGCAGCAGTGGAATTTCGCGGGTATCGAGGCCGCGGCAA...,non_ARG,non_ARG,0.99861,non_ARG,
7,ATGGCTATCGACGAAAACAAACAGAAAGCGTTGGCGGCAGCACTGG...,non_ARG,non_ARG,0.998672,non_ARG,
8,ATGTTTGAACCAATGGAACTTACCAATGACGCGGTGATTAAAGTCA...,non_ARG,non_ARG,0.998648,non_ARG,
9,ATGAGAGATTTATTATCTAAAAAAAGTCATAGGCAATTAGAATTAT...,non_ARG,non_ARG,0.989429,non_ARG,
