
We will fine-tune a pretrained Mamba LLM (a state-space model that is compatible with the Hugging Face transformers library) to recognize 'algal' or 'bacterial' signatures in protein sequences. First, we need to prepare the input data. If we check line lengths, we see that they are quite variable, which could influence training through inconsistent padding lengths:

In [None]:
import sys

def print_line_char_counts(input_file):
  """
  Prints the number of characters in each line of a text file.

  Args:
    input_file: The path to the input text file.
  """
  with open(input_file, 'r') as infile:
    for line_number, line in enumerate(infile, 1):
      char_count = len(line.strip())
      print(f"Line {line_number}: {char_count} characters")

# Check if input file is provided as sys.argv[1]
if len(sys.argv) > 1:
  input_file = sys.argv[1]
  print_line_char_counts(input_file)
else:
  print("Error: Please provide the input text file as a command-line argument.")
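Printing one count per line becomes unwieldy for tens of thousands of sequences. A compact summary of the same lengths is easier to read. This is a minimal standard-library sketch; summarize_line_lengths is a hypothetical helper, not part of the original workflow:

```python
import statistics


def summarize_line_lengths(lines):
    """Summarize the stripped lengths of the non-empty lines in an iterable of strings."""
    lengths = [len(line.strip()) for line in lines if line.strip()]
    if not lengths:
        raise ValueError("no non-empty lines to summarize")
    return {
        "n": len(lengths),
        "min": min(lengths),
        "median": statistics.median(lengths),
        "mean": statistics.mean(lengths),
        "max": max(lengths),
    }
```

Passing an open file object works because files iterate line by line. A large gap between the median and the max suggests that padding every line to a single max_seq_length will spend much of each batch on padding tokens.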

In [None]:
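# filename: remove_wraps_and_headers.py
import sys

def remove_line_wrapping_and_headers(input_file):
  """
  Removes line wrapping and headers from a FASTA file.

  Note: sequence lines are written without separators, so all records end up
  concatenated into one continuous line in the output file.

  Args:
    input_file: The path to the input FASTA file.
  """
  output_file = f"{input_file}.unwrapped_no_headers"
  with open(input_file, 'r') as infile, open(output_file, 'w') as outfile:
    for line in infile:
      # Skip header lines and write only sequence data (no newline is added)
      if not line.startswith('>'):
        outfile.write(line.strip())
    # Terminate the single output line so downstream line-based tools see it
    outfile.write('\n')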

# Check if input file is provided as sys.argv[1]
if len(sys.argv) > 1:
  input_file = sys.argv[1]
  remove_line_wrapping_and_headers(input_file)
else:
  print("Error: Please provide the input FASTA file as a command-line argument.")
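The downstream filenames (".wrapped", "training_100s.zip") and the 100-residue lines in the inference output later on suggest that this concatenated stream is re-wrapped into fixed-width lines. The exact command is not shown, so this is a minimal sketch assuming a width of 100 (similar to `fold -w 100`). The optional tag argument, which appends a class label such as [<@@@>] to each line, is a hypothetical illustration of how labeled training lines could be built, not the documented training format:

```python
def wrap_fixed_width(sequence, width=100, tag=None, drop_short=False):
    """Split one continuous residue string into fixed-width lines.

    width=100 and the optional tag are assumptions for illustration.
    """
    if width <= 0:
        raise ValueError("width must be positive")
    chunks = [sequence[i:i + width] for i in range(0, len(sequence), width)]
    if drop_short and chunks and len(chunks[-1]) < width:
        chunks.pop()  # discard the trailing partial chunk so every line has equal length
    if tag is not None:
        chunks = [f"{chunk} {tag}" for chunk in chunks]  # hypothetical label format
    return chunks
```

Writing "\n".join(wrap_fixed_width(seq)) to a file gives one chunk per line, so every training line has the same residue count apart from the optional final partial chunk.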

Now we will run the fine-tuning job and log the results to WandB. The main considerations for balancing performance and scalability are:
1. Mamba model size: 130m, 370m, 790m, 1.4B, or 2.8B.
2. The 'per_device_train_batch_size' parameter. Larger batches can speed up the job considerably and smooth out noisy updates, but may blunt fine features and will eventually cause out-of-memory (OOM) errors.
3. Finally, the 'max_seq_length' parameter: raising it directly lowers the batch size at which an OOM occurs. We saw F1 scores of ~0.80 with short sequence lengths (128) but ~0.90 with 2048.
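A rough way to reason about the batch-size/sequence-length trade-off is to treat per-token activation memory as the dominant cost. This sketch assumes memory grows roughly linearly in per_device_train_batch_size × max_seq_length and ignores the fixed cost of weights and optimizer state, so its output is a starting point for trial runs, not a guarantee against OOM errors. Both function names are hypothetical:

```python
def tokens_per_step(per_device_batch_size, max_seq_length, num_gpus=1, grad_accum_steps=1):
    """Upper bound on tokens processed per optimizer step (padding included)."""
    return per_device_batch_size * max_seq_length * num_gpus * grad_accum_steps


def rescale_batch_size(known_batch_size, known_seq_length, new_seq_length):
    """Estimate a batch size for new_seq_length that keeps batch * length constant.

    Assumes activation memory is linear in batch * length (a crude model).
    Never returns less than 1, but a batch of 1 may still run out of memory.
    """
    estimate = (known_batch_size * known_seq_length) // new_seq_length
    return max(estimate, 1)
```

For example, the batch size of 28 used below at max_seq_length=128 scales to about 1 at 2048 under this assumption. The gradient_accumulation_steps option of TrainingArguments can then restore a larger effective batch without raising per-device memory.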

In [None]:
# filename: archive_train_wandb-shuffle-downsample.py

import os
import random
import sys
import zipfile

import pandas as pd
import wandb
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer

# Ensure a command-line argument (the training zip file) is provided
if len(sys.argv) < 2:
    print("Usage: python archive_train_wandb-shuffle-downsample.py <path_to_zip_file>")
    sys.exit(1)

zip_file_path = sys.argv[1]

# Initialize wandb only after the arguments have been validated
wandb.init(project="MambalgaP", entity="algaeai")

#tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-370m-hf")
#model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-370m-hf")
tokenizer = AutoTokenizer.from_pretrained("state-spaces/mamba-1.4b-hf", use_special_tokens=False)
model = AutoModelForCausalLM.from_pretrained("state-spaces/mamba-1.4b-hf")

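# Extract the zip file, read every line from its .txt files, shuffle, and optionally
# downsample: sample_size keeps the first N shuffled lines; otherwise fraction keeps
# that share of them (fraction=1 keeps everything).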
def load_text_from_zip(zip_file_path, sample_size=None, fraction=1):
    text_data = []
    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
        zip_ref.extractall("temp_data")
        for root, dirs, files in os.walk("temp_data"):
            for file in files:
                if file.endswith(".txt"):
                    file_path = os.path.join(root, file)
                    with open(file_path, 'r', encoding='utf-8') as f:
                        text_data.extend(f.readlines())

    # Shuffling the data
    random.shuffle(text_data)

    # Downsampling the data
    if sample_size is not None:
        text_data = text_data[:sample_size]
    elif fraction is not None:
        text_data = text_data[:int(len(text_data) * fraction)]

    return {"text": [line.strip() for line in text_data if line.strip()]}

# Load and preprocess the data
data = load_text_from_zip(zip_file_path)
dataframe = pd.DataFrame(data)
dataset = Dataset.from_pandas(dataframe)

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=1,
    per_device_train_batch_size=28,
    logging_dir='./logs',
    logging_steps=100,
    learning_rate=5e-4,
    report_to="wandb"  # Integrate Weights & Biases for tracking
)

lora_config = LoraConfig(
    r=8,
    target_modules=["x_proj", "embeddings", "in_proj", "out_proj"],
    task_type="CAUSAL_LM",
    bias="none"
)

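# Note: tokenizer=, dataset_text_field= and max_seq_length= follow the older TRL
# SFTTrainer signature; newer TRL releases expect the dataset options in SFTConfig
# and take the tokenizer as processing_class.
trainer = SFTTrainer(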
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=lora_config,
    train_dataset=dataset,
    dataset_text_field="text",  # Ensure this matches your data's text field
    max_seq_length=128,
)

trainer.train()

# Clean up the extracted files
os.system("rm -rf temp_data")

# Finish the wandb run
wandb.finish()
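The evaluation below uses a 10% holdout that was not used for training, but the split itself is not shown in this notebook. Here is a minimal sketch of a reproducible split, assuming the sequence lines are already in a Python list; split_holdout is a hypothetical helper, not part of the original scripts:

```python
import random


def split_holdout(lines, holdout_fraction=0.1, seed=0):
    """Shuffle a copy of lines with a fixed seed and split it into (train, holdout)."""
    if not 0.0 < holdout_fraction < 1.0:
        raise ValueError("holdout_fraction must be between 0 and 1")
    shuffled = list(lines)
    random.Random(seed).shuffle(shuffled)  # local RNG, so global random state is untouched
    n_holdout = int(len(shuffled) * holdout_fraction)
    return shuffled[n_holdout:], shuffled[:n_holdout]
```

If sequences are wrapped into fixed-width lines before splitting, chunks of the same protein can land on both sides of the split and inflate eval scores; splitting at the protein (FASTA record) level first avoids that kind of leakage.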

Now we deploy the training job on an HPC cluster using 4 GPUs per job (A100 or V100):

In [None]:
#!/bin/bash

#SBATCH --mem=800GB
#SBATCH --time=96:00:00
#SBATCH -p nvidia
#SBATCH --gres=gpu:4
#SBATCH --cpus-per-task=10

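# Export the allocator setting so the python process actually inherits it
# (a bare VAR=value followed by && is not passed to child processes).
export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
python archive_train_wandb-shuffle-downsample.py /scratch/drn2/transfer/fasta-processing/finalize/training_100s.zip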

The training run saves its results as multiple checkpoint folders containing .safetensors and .pkl files.


Now we will evaluate the fine-tuned model (Mambalga v.x) on a set of sequences (10% of total) not used for training:

In [None]:
###filename: infer-singleLoad.py or infer-SL2(or 2).py

import sys
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def generate_output(input_text, model, tokenizer):
    # Tokenize input text
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to('cuda')  # Move input_ids to GPU

    # Generate output
    with torch.no_grad():  # Disable gradient calculation for inference
        outputs = model.generate(input_ids, max_new_tokens=12)

    # Decode the output tokens back to text
    output_text = tokenizer.decode(outputs[0], skip_special_tokens=False)

    return output_text

if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python script.py <model_name_or_path> <input_file>")
        sys.exit(1)

    model_name_or_path = sys.argv[1]
    input_file = sys.argv[2]

    # Load the model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to('cuda')  # Move the model to GPU

    with open(input_file, 'r') as file:
        for line in file:
            input_text = line.strip()
            if not input_text:
                continue  # skip blank lines, which would otherwise produce an empty prompt
            output_text = generate_output(input_text, model, tokenizer)
            print(output_text)
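Each printed line is the input sequence followed by whatever tags the model generated. MoreMetrics.sh (further down) counts every line containing '@' or '!', so a line carrying both kinds of tag would be counted in both totals. A stricter alternative, sketched here with hypothetical helper names, is to assign each line the majority label among its complete tags and leave ties unlabeled:

```python
# Tag strings as they appear in the inference output in this notebook.
ALGAL_TAG = "[<@@@>]"
BACTERIAL_TAG = "[<!!!>]"


def predicted_label(output_line):
    """Return 'algal', 'bacterial', or None from the complete tags in one output line.

    Truncated fragments such as '[<' are ignored; ties and untagged lines give None.
    """
    n_algal = output_line.count(ALGAL_TAG)
    n_bacterial = output_line.count(BACTERIAL_TAG)
    if n_algal > n_bacterial:
        return "algal"
    if n_bacterial > n_algal:
        return "bacterial"
    return None
```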

Let's again deploy on GPU nodes. This example looks at the prediction results from a 1.4b training run (0.04 epochs, or < 5% of the data [chkpt 355000]). Tuning is relatively slow compared to the 130m model because of the smaller batch size (16 versus 512, needed to avoid an OOM error) and the roughly 10x increase in parameters. The WandB logs show this run maintaining a steadier grad_norm and better loss minimization, so we want to see whether early checkpoints already have predictive power compared to smaller models.

The input files are the 10% holdout sets from the algal and bacterial sequence blocks, and the output files contain the predictions we will use for evaluation of F1 scores.

In [None]:
#!/bin/bash

#SBATCH -o slurm-logs/arrayJob_%A_%a.out
#SBATCH -e slurm-logs/arrayJob_%A_%a.err
#SBATCH -a 1-2
#SBATCH --mem=250GB
#SBATCH --time=4:00:00
#SBATCH -p nvidia
#SBATCH --gres=gpu:2
#SBATCH --cpus-per-task=10


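# Each array task reads one holdout file name from filelist.txt
LINE=$(sed -n "${SLURM_ARRAY_TASK_ID}p" filelist.txt)
echo "$LINE"

# Note: infer-singleLoad.py moves the model to a single device ('cuda'),
# so only one of the two requested GPUs is used by each task.
python infer-singleLoad.py '/scratch/drn2/PROJECTS/AI/Mambalga/WDN_4/nest-afterEval/results/checkpoint-35500' "$LINE" > "./eval-results_${LINE}"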

In [None]:
>>>The output looks like this:

head eval-results_*

==> eval-results_bact1_accns.headers-fetched.aa.fa.unwrapped_no_headers.wrapped.10 <==
EPTKREAQRIEKLQSKMQELAAAVDDALDAEDEDKADALQEEGEAVGEQLQALEDGLQDYSPTAKAAAGAIVTIDRNRQAVIHRGLMREAEAKALRTLER [<!!!>] [<!!!>] [<!!!>] [<!!!>]
NNYVALHYEARHMNWLHSFWGLGASAGPAVMGLSLTLQWGYRNGYRVLGLMQTLQVIVLIASLALWTDPKGKRTAQDYPSQQKGTLRHKALPFALLSFFL [<!!!>] [<!!!>] [<!!!>] [<!!!>]
VYKIIYQVLPEWWPKIAKRNMIVTGAAAGLAATFNTPLGGIVFAIEELTKTHFSYYKTAIFSSVIIAGLSAQALLGPYLYLGYPKLDGLTNTIFLGVAVV [<!!!>] [<!!!>] [<!!!>] [<!!!>]
GIRQRQPLKIVVAVPVASPQAISNLQAEVDEVVCLYMPPAMGAVGYYYDEFAQVSDDDVVRLLAAFQKKNKANEVLSFLSENASESFSAKELSEKLNISE [<!!!>] [<!!!>] [<!!!>] [<!!!>]
RVSAVADLAAFHACYTPALHRIAIIDLGLPDGEGMVLVRALRAQGHPIGIVVFTARGATQDKVNGLGGGADYYLPKSADLDELAATLGALARRLGAPPAD [<!!!>] [<!!!>] [<!!!>] [<!!!>]
AFFGASTAFVESTLAQVYRFPHRSSFRGGPFCYIQEGLHKRWIGVLFAILTIVCYGVFLPTVQANGASQAFFNSYSVTPAATGIGLAILLAIVIIGGVKR [<!!!>] [<!!!>] [<!!!>] [<!!!>]
YFGFMGEFGWGATWGEDSYDADEYDASSPDAWEGEERQIVLIANPGFRWRLGRSLFLNLGLYAGAAIDVKDEIVYINQNNTTEDYRGAMFFGMIEFALGW [<!!!>] [<!!!>] [<!!!>] [<!!!>]
RITPAPSITKNELENWVRANASTDYHPCGTCRMGTDKHAVVDQELRVHGIDGLRIVDASVMPDILSGNLNAPTQMIAERAADYLMGRPQLPEEHAKFHFM [<!!!>] [<!!!>] [<!!!>] [<!!!>]
LIDFMKKSCLTTFNDIDVTGKIKCSVDARLYYSNISEKTDFYNFKYFYDMKGNEIYVESADAGIKLNLLNMYNKPIHLWDLRGFHKKIHYDNLQRIKEVY [<!!!>] [<!!!>] [<!!!>] [<!!!>]
RQLNPANKADRELAHSLGLATEEMLGSIARWSDDGLTSTHGKSEKLARISSGVASLVMRVSLLNALTAASKVGFTKLLMEKYGRLSRSKAWGDLDIQDRE [<!!!>] [<!!!>] [<!!!>] [<!!!>]

==> eval-results_Filtered_algal_doubled.aa.fa.unwrapped_no_headers.wrapped.10 <==
LARAADLSCDAWGHPCGPLHRMAHGHPLDPLLPRVCPIPAADSGSGTGAAPPRPPPSPSGGKGKGKKRPPPPAPAKDNRRPPSPAKKKKNRPPKNPKLGP [<@@@>] [<@@@>] [<@@@>]
YMIDRNLPSPREMSRSRGSKFRAETPIFSLTGKNNLKVGESVELTAQEEMAKAESEMLRQTAKIERHRRQSRTIMVDSLTVTLEVPRPFCLVLLRQVCP [<@@@>] [<@@@>] [<@@@>]
LDHMIGHIQYCRPAASVNLNPGSFLSQLVRMGAPSQLPPTHPPALSPAYSAAYSAACPHPHDPVHARALSSSCAPPHLMERGSRANWAGVAVTGAEAAF [<@@@>] [<@@@>] [< [< [< [<
YPEGCMGLCRVHRQAQQDRKKGDAGSSVMLDISHLTPASLALHVVGCCRQLLLELVLTTALGPKAGELTSRRPDSRAAIGSARTRNLINYAKTTAAAAL [<@@@>] [<@@@>] [< [< [< [<
GDDPTGGPTFKRADPTSQHDPSHQEVMGQDDVAASQAPAVPRGASESDLLGLGDTLLASRMALWTAGFVVMTLLINAPLLTPLMTLLGLNKATPGQLHV [<@@@>] [<@@@>] [< [< [< [<
PTPHKLRMPLKLSRETVAEVKPASTASCSRLASQHGPGLGQLQLELNHRSQVRLTVNRQAGLKVPGDPRLRSTCCVLLTCSSSFCWASRATSARASTYYA [<@@@>] [<@@@>] [< [< [< [<
QRKLVTLYHCPTTEMPADMLTKPLGRVLFEKFVGMVGLSECKAAQSMMDDQSSGSVRIEKLTASNFYIWKQKIQLLLALRDVDQYGFDDIPENATSDDRL [<@@@>] [<@@@>] [< [< [<
GCAASNVGNGANAGYPCPPCYFGVTATRATHSSPVLSFRQPKESLVRLKELGGKVAGGVGQVVDKVRQKWDARQAHKKQNPPGGRGKVGRKQNKAQGNEG [<@@@>] [<@@@>] [<@@@>]
KVGCHDHFHQRPCFEMHHLTCLACYSAPWGCFTKWGRWKCWRALLLAIRHKPMLHCSLAHLAQAEAINQLQEERAQLSTAVEQGLADQADLHTLRQQAG [<@@@>] [<@@@>] [< [< [< [<
HQHQWLAFLDVDEFIVIDEEKVSIPLMLSNYKQYGGVGLNWVMFGSSGYDQRPEGGVRNYTMCNRSLDFTSLTAQSAHDPWRQCDHIEHNRQSNFPDNRH [<@@@>] [<@@@>] [<@@@>]


At a glance, the model appears to be predicting sequence sources correctly. A [<@@@>] tag means the model assigned the sequence to algae, and [<!!!>] means it assigned it to bacteria; in the lines above, every algal-holdout sequence received algal tags and every bacterial-holdout sequence received bacterial tags. There may be an emergent property of this method of NTG: in these examples, bacterial sequences always get 4 tags, while algal sequences get either 2 or 3. This may indicate that bacterial sequences are thoroughly bacterial, while algal sequences vary in their degree of intrinsic 'algal-ness'. Since algal lineages are evolutionarily younger than bacterial ones, this variance might carry information about a sequence's evolutionary age. Defined sequence motifs, possibly quantified through SHAP, could indicate an 'old/core' algal sequence, while others could indicate specialization or recent acquisitions by some lineages. One caveat: generation is capped at max_new_tokens=12 and several algal lines end in truncated '[<' fragments, so the tag count may partly reflect how many tokens each tag string occupies rather than the sequence itself.
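To check whether this tag-count pattern holds beyond the first ten lines of each file, the complete tags and truncated '[<' fragments can be tallied per line. A minimal sketch, with helper names of our own choosing, that could be run over each eval-results_* file:

```python
from collections import Counter

ALGAL_TAG = "[<@@@>]"
BACTERIAL_TAG = "[<!!!>]"


def tag_profile(output_line):
    """Return (complete algal tags, complete bacterial tags, truncated '[<' fragments)."""
    n_algal = output_line.count(ALGAL_TAG)
    n_bacterial = output_line.count(BACTERIAL_TAG)
    # Every complete tag also starts with "[<", so subtract them to count bare fragments.
    n_fragments = output_line.count("[<") - n_algal - n_bacterial
    return n_algal, n_bacterial, n_fragments


def tag_histogram(lines):
    """Count how often each tag profile occurs across the non-empty lines of a results file."""
    return Counter(tag_profile(line) for line in lines if line.strip())
```

For example, opening a results file and printing tag_histogram(f).most_common() lists the most frequent (algal, bacterial, fragment) combinations. Rerunning inference with a larger max_new_tokens and comparing the histograms would show whether the tag counts depend on the generation budget.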

Full inference on ~60,000 sequences (the combined 10% holdout sets; 10,000 eval sequences are enough for quick checks) takes 3-3.5 hours on two GPUs, so let's run the summary script on a previous run's results as an example:

In [None]:
#!/bin/bash

# filename: MoreMetrics.sh

# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Holdout result files (one output line per input sequence)
ALGAL_RESULTS=eval-results_Filtered_algal_doubled.aa.fa.unwrapped_no_headers.wrapped.10
BACT_RESULTS=eval-results_bact1_accns.headers-fetched.aa.fa.unwrapped_no_headers.wrapped.10

# Count lines containing each tag character ('@' = algal, '!' = bacterial);
# grep -F is the modern spelling of the obsolescent fgrep.
# A line containing both characters is counted in both totals.
echo -e "${GREEN}Counting algal hits...${NC}"
grep -F -c '@' eval-resu*

echo -e "\n${GREEN}Counting bacterial hits...${NC}"
grep -F -c '!' eval-resu*

echo -e "\n${YELLOW}In other words, from the algal holdout set there are:${NC}"
count_algae=$(grep -F -c '@' "$ALGAL_RESULTS")
echo -e "${GREEN}${count_algae} algal signatures.${NC}"

count_bacteria=$(grep -F -c '!' "$ALGAL_RESULTS")
echo -e "${RED}${count_bacteria} bacterial signatures.${NC}"

echo -e "\n${YELLOW}And from the bacterial holdout set, there are:${NC}"
count_algae_bact=$(grep -F -c '@' "$BACT_RESULTS")
echo -e "${GREEN}${count_algae_bact} algal signatures.${NC}"

count_bacteria_bact=$(grep -F -c '!' "$BACT_RESULTS")
echo -e "${RED}${count_bacteria_bact} bacterial signatures.${NC}"

# Calculate performance metrics, treating each holdout file's source as the true label.
# For the algal class:
#   TP = algal sequences tagged algal       (count_algae)
#   FN = algal sequences tagged bacterial   (count_bacteria)
#   FP = bacterial sequences tagged algal   (count_algae_bact)
# For the bacterial class the roles mirror:
#   TP = count_bacteria_bact, FN = count_algae_bact, FP = count_bacteria
tp_algae=$count_algae
fn_algae=$count_bacteria
fp_algae=$count_algae_bact
tp_bacteria=$count_bacteria_bact
fn_bacteria=$count_algae_bact
fp_bacteria=$count_bacteria

precision_algae=$(echo "scale=4; $tp_algae / ($tp_algae + $fp_algae)" | bc)
recall_algae=$(echo "scale=4; $tp_algae / ($tp_algae + $fn_algae)" | bc)
f1_score_algae=$(echo "scale=4; 2 * ($precision_algae * $recall_algae) / ($precision_algae + $recall_algae)" | bc)

precision_bacteria=$(echo "scale=4; $tp_bacteria / ($tp_bacteria + $fp_bacteria)" | bc)
recall_bacteria=$(echo "scale=4; $tp_bacteria / ($tp_bacteria + $fn_bacteria)" | bc)
f1_score_bacteria=$(echo "scale=4; 2 * ($precision_bacteria * $recall_bacteria) / ($precision_bacteria + $recall_bacteria)" | bc)

echo -e "\n${YELLOW}Performance Metrics:${NC}"
echo -e "Algal Precision: ${GREEN}$precision_algae${NC}"
echo -e "Algal Recall: ${GREEN}$recall_algae${NC}"
echo -e "Algal F1 Score: ${GREEN}$f1_score_algae${NC}"

echo -e "Bacterial Precision: ${RED}$precision_bacteria${NC}"
echo -e "Bacterial Recall: ${RED}$recall_bacteria${NC}"
echo -e "Bacterial F1 Score: ${RED}$f1_score_bacteria${NC}"
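As a cross-check on the shell arithmetic, here is the same calculation in Python with the confusion-matrix roles spelled out. It treats each holdout file's source as the true label and each line's tag as the prediction, matching how the counts above are gathered. The function names are ours, not part of the original scripts:

```python
def binary_metrics(tp, fp, fn):
    """Precision, recall, and F1 for one class from confusion-matrix counts."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def holdout_metrics(algal_as_algal, algal_as_bact, bact_as_algal, bact_as_bact):
    """Per-class (precision, recall, F1), using each holdout file's source as the truth."""
    return {
        # Algal class: FP = bacterial sequences tagged algal,
        # FN = algal sequences tagged bacterial.
        "algal": binary_metrics(algal_as_algal, bact_as_algal, algal_as_bact),
        # Bacterial class: the roles mirror.
        "bacterial": binary_metrics(bact_as_bact, algal_as_bact, bact_as_algal),
    }
```

With the counts from the run below, holdout_metrics(27094, 5493, 2910, 28731) gives algal precision ≈ 0.903 and recall ≈ 0.831. Swapping false positives and false negatives swaps precision and recall but leaves F1 unchanged, since F1 is symmetric in the two.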

In [None]:
<<OUTPUT>>>

Counting algal hits...
eval-results_bact1_accns.headers-fetched.aa.fa.unwrapped_no_headers.wrapped.10:2910
eval-results_Filtered_algal_doubled.aa.fa.unwrapped_no_headers.wrapped.10:27094

Counting bacterial hits...
eval-results_bact1_accns.headers-fetched.aa.fa.unwrapped_no_headers.wrapped.10:28731
eval-results_Filtered_algal_doubled.aa.fa.unwrapped_no_headers.wrapped.10:5493

In other words, from the algal holdout set there are:
27094 algal signatures.
5493 bacterial signatures.

And from the bacterial holdout set, there are:
2910 algal signatures.
28731 bacterial signatures.

Performance Metrics:
Algal Precision: .8314
Algal Recall: .9030
Algal F1 Score: .8656
Bacterial Precision: .9080
Bacterial Recall: .8394
Bacterial F1 Score: .8722


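Note that the precision and recall labels in this log are swapped relative to the standard definitions: with algal as the positive class, algal precision is 27094 / (27094 + 2910) ≈ 0.903 and algal recall is 27094 / (27094 + 5493) ≈ 0.831; the bacterial values, and those in the next log, are swapped the same way. The F1 scores are unaffected, since F1 is symmetric in precision and recall. This performance is OK, but we want to boost F1 scores; in particular, we want fewer bacterial sequences falsely labeled as 'algae'. Let's use a lower learning rate and a larger max_seq_length, and switch to the 370m model. Since we will be benchmarking often, let's also subsample each eval set to 10,000 sequences for now.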
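The eval filenames in the next log (AlgalTop10000-10holdout, BactTop10000-10holdout) suggest that the first 10,000 lines of each holdout file were kept, although the command is not shown. Both options are sketched below with hypothetical helper names. The first takes the first N lines (like head -n 10000). The second takes a seeded random sample, which is more representative if the holdout file has any ordering:

```python
import random
from itertools import islice


def first_n_lines(lines, n=10_000):
    """Take the first n non-empty lines, like `head -n` on a file without blank lines."""
    return list(islice((line for line in lines if line.strip()), n))


def random_n_lines(lines, n=10_000, seed=0):
    """Take a seeded random sample of n non-empty lines (all of them if fewer exist)."""
    pool = [line for line in lines if line.strip()]
    return random.Random(seed).sample(pool, min(n, len(pool)))
```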

In [None]:
<<OUTPUT>>>

Counting algal hits...
eval-results_AlgalTop10000-10holdout:8299
eval-results_BactTop10000-10holdout:723

Counting bacterial hits...
eval-results_AlgalTop10000-10holdout:1703
eval-results_BactTop10000-10holdout:9277

In other words, from the algal holdout set there are:
8299 algal signatures.
1703 bacterial signatures.

And from the bacterial holdout set, there are:
723 algal signatures.
9277 bacterial signatures.

Performance Metrics:
Algal Precision: .8297
Algal Recall: .9198
Algal F1 Score: .8723
Bacterial Precision: .9277
Bacterial Recall: .8448
Bacterial F1 Score: .8842

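These results are promising: there are fewer than half as many 'false positives' from the bacterial set (723 vs. 2,910), as expected. Because this eval set is roughly a third the size of the previous one, the rate comparison is more modest: the share of bacterial-set sequences tagged as algal drops from about 9.2% (2,910 of 31,641) to about 7.2% (723 of 10,000). F1 scores improved only marginally. Let's increase lora_alpha (PEFT's default is 8, and the config above does not set it) to 64; because the LoRA update is scaled by lora_alpha / r, this amplifies the influence of the adapter trained on the tagged amino acid sequences relative to the frozen pretrained weights. Adding math instruction data, such as the OpenOrca math training datasets, has also been reported to improve reasoning, so let's add that to the training set to see whether it helps as well.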
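In PEFT's standard LoRA, the low-rank update B·A is multiplied by lora_alpha / r before being added to the frozen weight. With use_rslora=True, PEFT instead scales by lora_alpha / sqrt(r). So with r=8 from the config above, lora_alpha=64 gives a scaling factor of 8.0, versus 2.0 at lora_alpha=16. In practice the change is just LoraConfig(..., lora_alpha=64). The snippet below only illustrates the arithmetic with plain Python lists; the function names are illustrative, not PEFT APIs:

```python
def lora_scaling(lora_alpha, r):
    """Multiplier applied to the low-rank update in standard (non-rsLoRA) LoRA."""
    return lora_alpha / r


def apply_lora(weight, lora_a, lora_b, lora_alpha, r):
    """Return W + (lora_alpha / r) * (B @ A) for nested-list matrices.

    Shapes: weight is d_out x d_in, lora_b is d_out x r, lora_a is r x d_in.
    """
    if len(lora_a) != r:
        raise ValueError("lora_a must have r rows")
    scale = lora_scaling(lora_alpha, r)
    rows, cols = len(lora_b), len(lora_a[0])
    update = [[sum(lora_b[i][k] * lora_a[k][j] for k in range(r)) for j in range(cols)]
              for i in range(rows)]
    return [[weight[i][j] + scale * update[i][j] for j in range(cols)] for i in range(rows)]
```

Because the scale multiplies the whole adapter update, a larger lora_alpha behaves somewhat like a larger effective learning rate for the adapter, so the learning rate may need retuning alongside it.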
