<a href="https://colab.research.google.com/github/nbrg-ppcu/PhaStyle/blob/main/bin/PhaStyleExample.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ProkBERT PhaStyle example

Inference consists of three steps:

 - loading the model
 - preparing the dataset (parsing the FASTA file and creating a tokenized dataset for inference)
 - running inference and generating the final report

## Setting Up the Environment

While ProkBERT can run on CPUs, GPUs significantly speed up inference. Google Colab offers free GPU access, making it a convenient platform for experimenting with ProkBERT models.



### Enabling and testing the GPU (if you are using Google Colab)

First, you'll need to enable GPUs for the notebook:

- Navigate to Edit→Notebook Settings
- Select GPU from the Hardware Accelerator drop-down

In [None]:
!pip install git+https://github.com/nbrg-ppcu/prokbert.git --quiet
!pip install transformers datasets --quiet

from transformers import TrainingArguments, Trainer, DataCollatorWithPadding
from prokbert.sequtils import *
from prokbert.training_utils import *
from prokbert.models import ProkBertForSequenceClassification
from prokbert.tokenizer import LCATokenizer
from datasets import Dataset

import multiprocessing
import pandas as pd
import torch
import numpy as np
from os.path import join
import os



Next, we'll confirm that we can connect to the GPU with PyTorch:


In [None]:
# Check if CUDA (GPU support) is available
if not torch.cuda.is_available():
    raise SystemError('GPU device not found')
else:
    device_name = torch.cuda.get_device_name(0)
    print(f'Found GPU at: {device_name}')
num_cores = os.cpu_count()
print(f'Number of available CPU cores: {num_cores}')

# Preparing the models

The next cell downloads the PhaStyle model and its tokenizer.


In [None]:
model_path = 'neuralbioinfo/PhaStyle-mini'
model = ProkBertForSequenceClassification.from_pretrained(model_path, trust_remote_code=True)
tokenizer = LCATokenizer.from_pretrained(model_path, trust_remote_code=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)


# Preparing the Dataset

In this section, you can either **upload your own FASTA file** or use our **default EXTREMOPHILE dataset** for a quick start.

1. **Upload Your FASTA**  
   Run the next cell, then click the **Browse** button to select and upload a FASTA file from your local computer. The button will be located at the bottom of the cell.

2. **Use the Default EXTREMOPHILE Dataset**  
   If you don’t have a FASTA file handy, you can download the small **extremophiles.fasta** file, which contains the archaeal phages described in our recent paper. Simply click the link below to download and save it to your local computer (or right-click and choose “Save Link As”):

   <a href="https://raw.githubusercontent.com/nbrg-ppcu/PhaStyle/main/data/EXTREMOPHILE/extremophiles.fasta" download="extremophiles.fasta">📥 Download <code>extremophiles.fasta</code></a>


In [None]:
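from google.colab import files

# Opens the Colab upload widget; returns a dict mapping file names to contents
uploaded = files.upload()
if uploaded:
    fasta_filename = list(uploaded.keys())[0]
    print("Uploaded FASTA file name:", fasta_filename)
else:
    raise ValueError("No file was uploaded; rerun this cell and select a FASTA file.")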
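Before running the rest of the notebook, it can help to confirm that the uploaded file at least looks like FASTA. The sketch below is a minimal, standard-library check and is not part of ProkBERT; the helper name `looks_like_fasta` is hypothetical. It only tests that the first non-empty line is a header line starting with `>`, so it will not catch every malformed file.

```python
from pathlib import Path
import tempfile


def looks_like_fasta(path):
    """Return True if the first non-empty line starts with '>' (a FASTA header)."""
    with open(path, "r", encoding="utf-8", errors="replace") as handle:
        for line in handle:
            if line.strip():
                return line.startswith(">")
    return False  # an empty file is not a usable FASTA file


# Demo with a throwaway file; in the notebook you would pass fasta_filename instead.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = Path(tmp_dir) / "demo.fasta"
    demo_path.write_text(">phage_1 example record\nACGTACGT\n")
    print(looks_like_fasta(demo_path))
```

In the notebook, a call such as `looks_like_fasta(fasta_filename)` right after the upload would flag an obviously wrong file before any segmentation or tokenization is attempted.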


## Dataset Preprocessing

In this section, we read, segment, and tokenize the FASTA sequences.

### Steps
1. Reading and parsing the FASTA file  
2. Cutting long sequences into smaller segments (~512 bp)  
3. Tokenizing and preparing for the model  
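
To make the first two steps concrete, here is a minimal pure-Python sketch of FASTA parsing and contiguous segmentation. It is an illustration only, not the implementation behind `load_contigs` or `segment_sequences`: the function names `parse_fasta` and `contiguous_segments` are hypothetical, and the sketch assumes that a trailing piece shorter than `min_length` is discarded, which the library may handle differently.

```python
def parse_fasta(text):
    """Parse FASTA-formatted text into a list of (header, sequence) tuples."""
    records = []
    header, chunks = None, []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            # A new header closes the previous record, if there was one
            if header is not None:
                records.append((header, "".join(chunks).upper()))
            header, chunks = line[1:], []
        elif header is not None:
            chunks.append(line)
    if header is not None:
        records.append((header, "".join(chunks).upper()))
    return records


def contiguous_segments(sequence, max_length=512, min_length=256):
    """Cut a sequence into non-overlapping windows, dropping a too-short tail."""
    segments = []
    for start in range(0, len(sequence), max_length):
        segment = sequence[start:start + max_length]
        if len(segment) >= min_length:  # assumption: short tails are discarded
            segments.append((start, segment))
    return segments


# One 1200 bp record: two full 512 bp windows, and a 176 bp tail that is dropped
demo = ">phage_1\n" + "ACGT" * 300 + "\n"
for header, seq in parse_fasta(demo):
    pieces = contiguous_segments(seq)
    print(header, [(start, len(seg)) for start, seg in pieces])
# prints: phage_1 [(0, 512), (512, 512)]
```

The default `min_length` of 256 mirrors the `int(max_length * 0.5)` setting used in this notebook.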

In [None]:


max_length = 512  # Fit to the model size
print(f"[prepare_dataset] Loading sequences from: {fasta_filename}")

sequences = load_contigs(
    [fasta_filename],
    IsAddHeader=True,
    adding_reverse_complement=False,
    AsDataFrame=True,
    to_uppercase=True,
    is_add_sequence_id=True,
)
print(f"[prepare_dataset] Number of raw sequences: {len(sequences)}")

print("[prepare_dataset] Running segmentation")
segmentation_params = {
    "max_length": max_length,
    "min_length": int(max_length * 0.5),
    "type": "contiguous",
}
raw_segment_df = segment_sequences(
    sequences, segmentation_params, AsDataFrame=True
)
print(f"[prepare_dataset] Number of segments: {len(raw_segment_df)}")

# Wrap into HF Dataset (in memory)
hf_dataset = Dataset.from_pandas(raw_segment_df)

# Tokenization function (same as before, except no labels)
def _tokenize_fn(batch):
    tokenized = tokenizer(
        batch["segment"],
        padding="longest",
        truncation=True,
        max_length=max_length,
    )
    # Zero out first/last attention token
    masks = tokenized["attention_mask"]
    for m in masks:
        m[0] = 0
        m[-1] = 0
    return {
        "input_ids": tokenized["input_ids"],
        "attention_mask": masks
    }

print(f"[prepare_dataset] Tokenizing with {num_cores} CPU core(s)")
tokenized_ds = hf_dataset.map(
    _tokenize_fn,
    batched=True,
    num_proc=num_cores,
    remove_columns=hf_dataset.column_names,
    keep_in_memory=True,
)

# Predicting phage lifestyle phenotype
We now have a tokenized dataset that can be passed through the fine-tuned model.


In [None]:
final_columns = ['sequence_id', 'fasta_id', 'predicted_label', 'score_temperate', 'score_virulent']
final_columns_rename = ['sequence_id', 'predicted_label', 'score_temperate', 'score_virulent', 'fasta_id']
tmp_output = "./prokbert_inference_output"
os.makedirs(tmp_output, exist_ok=True)


training_args = TrainingArguments(
    output_dir=tmp_output,
    do_train=False,
    do_eval=False,
    per_device_eval_batch_size=32,
    fp16=True,
    remove_unused_columns=False,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

print("[main] Running prediction on segments...")
predictions = trainer.predict(tokenized_ds)
# Building the final table:


final_table = inference_binary_sequence_predictions(predictions, hf_dataset)
# class_1 is reported as 'virulent'; any other class is reported as 'temperate'
final_table['predicted_label'] = final_table['predicted_label'].map(lambda label: 'virulent' if label == 'class_1' else 'temperate')
final_table = final_table.merge(sequences[['sequence_id', 'fasta_id']], how='left',
                                  left_on='sequence_id', right_on='sequence_id')
# Rename the merged columns by position, then reorder them for the report
final_table.columns = final_columns_rename
final_table = final_table[final_columns]

final_table
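
The model scores individual segments, while the final table reports one row per sequence, so segment-level outputs have to be combined somewhere. The sketch below shows one common way to do this: average the per-segment softmax probabilities for each sequence and pick the class with the highest mean. This is an assumption made for illustration and may differ from what `inference_binary_sequence_predictions` actually does; the names `softmax` and `aggregate_by_sequence` are hypothetical. The label order follows the notebook's mapping, where `class_1` is virulent.

```python
import math
from collections import defaultdict


def softmax(logits):
    """Convert a list of logits into probabilities that sum to 1."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def aggregate_by_sequence(segment_logits, sequence_ids, labels=("temperate", "virulent")):
    """Average per-segment class probabilities per sequence and pick the top class."""
    sums = defaultdict(lambda: [0.0] * len(labels))
    counts = defaultdict(int)
    for logits, seq_id in zip(segment_logits, sequence_ids):
        for i, prob in enumerate(softmax(logits)):
            sums[seq_id][i] += prob
        counts[seq_id] += 1
    report = {}
    for seq_id, totals in sums.items():
        means = [t / counts[seq_id] for t in totals]
        best = max(range(len(labels)), key=lambda i: means[i])
        report[seq_id] = {
            "predicted_label": labels[best],
            "scores": dict(zip(labels, means)),
        }
    return report


# Made-up logits for illustration: two segments from sequence 0, one from sequence 1
demo_logits = [[2.0, 0.0], [1.0, 0.0], [0.0, 3.0]]
demo_ids = [0, 0, 1]
for seq_id, row in aggregate_by_sequence(demo_logits, demo_ids).items():
    print(seq_id, row["predicted_label"])
```

Averaging probabilities gives every segment equal weight; other schemes, such as majority voting over segment labels, are equally plausible and can give different results for sequences with mixed segment predictions.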

# Enjoy! :)
