DATASET CREATION (Bacterial Genomes to Labeled Dataset)

## Notebook dependency helper

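If this notebook raises `ModuleNotFoundError` for a module such as `pandas` or `Bio` (Biopython), run the cell below once. It installs the missing packages with `sys.executable -m pip`, i.e. into the same Python interpreter the kernel is running. Restart the kernel afterwards if the imports still fail.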


In [None]:
# Install any missing third-party dependencies into the *running* kernel's
# interpreter. Run this cell once, before the main script, if that script
# raises ModuleNotFoundError.
import importlib
import subprocess
import sys

# import name -> pip distribution name (they differ for Biopython).
# tqdm is imported by the main cell as well, so it has to be checked here too.
required_packages = {
    "pandas": "pandas",
    "Bio": "biopython",
    "tqdm": "tqdm",
}

missing = []
for module_name, dist_name in required_packages.items():
    try:
        importlib.import_module(module_name)
    except ImportError:  # includes ModuleNotFoundError; other errors surface
        missing.append(dist_name)

if missing:
    print('Installing packages:', missing)
    # sys.executable makes pip target exactly the interpreter this kernel runs.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', '--upgrade'] + missing)
    # Let the import system see the freshly installed packages in this session.
    importlib.invalidate_caches()
    print('Installation finished - restart the kernel or re-run the notebook cells if necessary.')
else:
    print('All required packages are already available in the kernel')


In [2]:
# ===============================
#   DNA DATASET CREATION SCRIPT
#   Multiclass (4 classes)
#   15,000 samples per class
#   Sequence length = 300 bp
#   Output: CSV files
# ===============================

import os
import random
import pandas as pd
from Bio import SeqIO
from tqdm import tqdm

# ==========================================================
# CONFIG
# ==========================================================
GENOMES_DIR = "data/genomes/"          # Folder containing .gbff/.gbk files
SEQ_LEN = 300                          # Fixed window length (bp) of every sample
SAMPLES_PER_CLASS = 15000              # Samples drawn per class (balanced)
CLASSES = ["promoter", "cds", "terminator", "intergenic"]
RANDOM_SEED = 42                       # Seed for random.sample / random.shuffle

# Seed once so the sampling and shuffling below are reproducible
# (given the same input files in the same order).
random.seed(RANDOM_SEED)

# ==========================================================
# HELPER FUNCTIONS
# ==========================================================

_COMPLEMENT = str.maketrans("ACGTNacgtn", "TGCANtgcan")


def _reverse_complement(seq):
    """Return the reverse complement of DNA string *seq*; other characters pass through."""
    return seq.translate(_COMPLEMENT)[::-1]


def extract_promoters(record, flanking=None, downstream=50):
    """Extract putative promoter windows around each annotated gene/CDS start.

    For every ``gene`` or ``CDS`` feature the window is ``flanking`` bp upstream
    of the feature's start plus ``downstream`` bp into the feature. Minus-strand
    windows are reverse-complemented, so every window reads upstream -> start,
    exactly like a plus-strand window.

    Args:
        record: parsed sequence record exposing ``.seq`` and ``.features``.
        flanking: upstream length in bp. ``None`` (default) uses
            ``SEQ_LEN - downstream`` so windows are exactly ``SEQ_LEN`` long.
            A 150 bp flank gives 200 bp windows, shorter than SEQ_LEN (300),
            which ``clean_and_fix_length`` rejects.
        downstream: bp taken from inside the feature (default 50).

    Returns:
        List of window strings, one per distinct (start, end, strand) window.
        Windows clipped at a record edge may be shorter than requested.
        Features with no location, positions that cannot be converted to int,
        or a strand other than +1/-1 are skipped.
    """
    if flanking is None:
        flanking = SEQ_LEN - downstream
    seq_len = len(record.seq)
    promoters = []
    seen = set()
    for feature in record.features:
        if feature.type not in ("gene", "CDS"):
            continue
        loc = feature.location
        if loc is None:
            continue
        try:
            start, end = int(loc.start), int(loc.end)
        except (TypeError, ValueError):
            continue  # position cannot be resolved to a coordinate
        strand = loc.strand
        if strand == 1:
            win_start = max(0, start - flanking)
            win_end = min(seq_len, start + downstream)
        elif strand == -1:
            # On the minus strand the gene start is at `end`; upstream is to the right.
            win_start = max(0, end - downstream)
            win_end = min(seq_len, end + flanking)
        else:
            continue  # strandless / mixed-strand: orientation is ambiguous
        key = (win_start, win_end, strand)
        if key in seen:
            continue  # a gene and a CDS sharing a start would give the same window
        seen.add(key)
        window = str(record.seq[win_start:win_end])
        promoters.append(_reverse_complement(window) if strand == -1 else window)
    return promoters


def extract_cds(record):
    """Return the sequence of every ``CDS`` feature in *record* as a string.

    Each sequence is ``str(feature.extract(record.seq))``. Extraction is
    best-effort: a feature whose ``extract`` raises an ordinary exception is
    skipped. Unlike a bare ``except:``, this lets KeyboardInterrupt and
    SystemExit propagate, so a long genome scan can still be interrupted.
    """
    cds_list = []
    for feature in record.features:
        if feature.type != "CDS":
            continue
        try:
            cds_list.append(str(feature.extract(record.seq)))
        except Exception:
            continue
    return cds_list


def extract_terminators(record, length=None, window_size=40, t_threshold=20):
    """Crudely extract terminator-like windows by scanning for T-rich stretches.

    The forward strand is scanned in non-overlapping ``window_size`` bp steps.
    Wherever a step contains more than ``t_threshold`` T's (case-insensitive),
    the ``length`` bp starting at that step is emitted in its original case.
    Windows that run past the end of the sequence come out shorter than
    ``length``; ``clean_and_fix_length`` in the main loop drops those.

    NOTE(review): only forward-strand poly-T is detected (minus-strand
    terminators would appear as poly-A) and there is no hairpin check, so this
    is a heuristic label, not an annotation.

    Args:
        record: parsed sequence record exposing ``.seq``.
        length: emitted window length; ``None`` (default) uses ``SEQ_LEN``.
        window_size: scan step and T-counting window, in bp (default 40).
        t_threshold: a step qualifies when its T count exceeds this (default 20).
    """
    if length is None:
        length = SEQ_LEN
    seq = str(record.seq)
    upper = seq.upper()  # count T and t alike
    terminators = []
    # "+ 1" so the final full window is scanned as well.
    for i in range(0, len(seq) - window_size + 1, window_size):
        if upper.count("T", i, i + window_size) > t_threshold:
            terminators.append(seq[i:i + length])
    return terminators


def extract_intergenic(record, length=None, step=100, exclude_types=("source",)):
    """Cut fixed-length windows from stretches not covered by any annotated feature.

    The extents of all features are swept in sorted order. Every uncovered gap
    of at least ``length`` bp is tiled with ``length``-bp windows every ``step``
    bp, including the gaps before the first and after the last feature.

    Features whose type is in ``exclude_types`` are ignored. In GenBank files
    the ``source`` feature typically spans the entire record; counting it as
    occupied would leave no gaps at all.

    NOTE(review): gaps directly upstream of genes overlap the promoter windows
    from ``extract_promoters``; consider padding feature extents if that label
    overlap matters.

    Args:
        record: parsed sequence record exposing ``.seq`` and ``.features``.
        length: window length; ``None`` (default) uses ``SEQ_LEN``.
        step: stride between consecutive windows within a gap (default 100).
        exclude_types: feature types that do not mark a region as occupied.
    """
    if length is None:
        length = SEQ_LEN
    seq = str(record.seq)

    occupied = []
    for feature in record.features:
        if feature.type in exclude_types:
            continue
        loc = getattr(feature, "location", None)
        if loc is None:
            continue
        # Take each part on its own so an origin-spanning join(...) does not
        # collapse into one interval covering the whole record.
        for part in getattr(loc, "parts", None) or [loc]:
            try:
                occupied.append((int(part.start), int(part.end)))
            except (TypeError, ValueError):
                continue  # position cannot be resolved to a coordinate
    occupied.sort()

    gaps = []
    covered_to = 0
    for start, end in occupied:
        if start - covered_to >= length:
            gaps.append((covered_to, start))
        # max(): a feature nested inside an earlier one must not pull the
        # cursor backwards and expose part of the enclosing feature.
        covered_to = max(covered_to, end)
    if len(seq) - covered_to >= length:
        gaps.append((covered_to, len(seq)))

    windows = []
    for gap_start, gap_end in gaps:
        for i in range(gap_start, gap_end - length + 1, step):
            windows.append(seq[i:i + length])
    return windows


def clean_and_fix_length(seq, length=300):
    """Normalise *seq* to exactly ``length`` uppercase A/C/G/T bases.

    Returns ``None`` when *seq* is shorter than ``length``. Longer input is
    truncated to its first ``length`` bases. Every character other than
    A/C/G/T is replaced with ``A``. This covers ``N`` and any other IUPAC
    ambiguity code (R, Y, K, ...), so downstream k-mers and tokens only ever
    see the four canonical bases.
    """
    if len(seq) < length:
        return None
    return "".join(base if base in "ACGT" else "A" for base in seq[:length].upper())


def kmerize(seq, k=3):
    """Render *seq* as space-separated overlapping k-mers (stride 1).

    A sequence shorter than ``k`` yields an empty string.

    >>> kmerize("ACGTA")
    'ACG CGT GTA'
    """
    n_kmers = len(seq) - k + 1
    kmers = []
    for offset in range(n_kmers):
        kmers.append(seq[offset:offset + k])
    return " ".join(kmers)


def tokenize_dl(seq):
    """Map each base of *seq* to an integer id: A=0, C=1, G=2, T=3.

    Any other symbol (including lowercase letters) also maps to 0.
    """
    base_to_id = {base: idx for idx, base in enumerate("ACGT")}
    return [base_to_id[b] if b in base_to_id else 0 for b in seq]
    

# ==========================================================
# MAIN EXTRACTION LOOP
# ==========================================================
all_promoters = []
all_cds = []
all_terminators = []
all_intergenics = []

print("Reading genomes from:", GENOMES_DIR)
# Sorted so the pool order -- and therefore the seeded random.sample() later --
# is reproducible; os.listdir() order depends on the filesystem.
files = sorted(f for f in os.listdir(GENOMES_DIR) if f.endswith((".gbff", ".gbk")))
if not files:
    raise FileNotFoundError(f"No .gbff/.gbk files found in {GENOMES_DIR!r}")

# Each extractor paired with the pool its cleaned SEQ_LEN windows go into,
# in the order they run for every record.
extractors = (
    (extract_promoters, all_promoters),
    (extract_cds, all_cds),
    (extract_terminators, all_terminators),
    (extract_intergenic, all_intergenics),
)

for file in files:
    print("Processing genome:", file)
    path = os.path.join(GENOMES_DIR, file)
    for record in SeqIO.parse(path, "genbank"):
        for extract, pool in extractors:
            for raw in extract(record):
                # clean_and_fix_length returns None for windows shorter than SEQ_LEN.
                window = clean_and_fix_length(raw, SEQ_LEN)
                if window:
                    pool.append(window)


# ==========================================================
# BALANCE & SAMPLE CLASSES
# ==========================================================
print(f"Sampling {SAMPLES_PER_CLASS:,} per class...")

# Candidate pools per label. Exact duplicates are dropped (order-preserving,
# so the seeded draw stays reproducible); a duplicated window could otherwise
# be drawn twice and end up on both sides of a later train/test split.
candidate_pools = {
    "promoter": all_promoters,
    "cds": all_cds,
    "terminator": all_terminators,
    "intergenic": all_intergenics,
}
unique_pools = {label: list(dict.fromkeys(candidate_pools[label])) for label in CLASSES}

# Fail up front with per-class counts instead of random.sample's generic
# ValueError from whichever class happens to be short.
shortfalls = {
    label: len(pool)
    for label, pool in unique_pools.items()
    if len(pool) < SAMPLES_PER_CLASS
}
if shortfalls:
    raise ValueError(
        f"Need {SAMPLES_PER_CLASS} unique sequences per class; only found: {shortfalls}"
    )

# Draw in CLASSES order so the seeded RNG is consumed in a fixed sequence.
dataset = []
for label in CLASSES:
    dataset.extend(
        [seq, label] for seq in random.sample(unique_pools[label], SAMPLES_PER_CLASS)
    )

random.shuffle(dataset)

df_raw = pd.DataFrame(dataset, columns=["sequence", "label"])

# ==========================================================
# SAVE RAW
# ==========================================================
def _save_csv(frame, path, message):
    """Write *frame* to *path* without its index, then print *message*."""
    frame.to_csv(path, index=False)
    print(message)


_save_csv(df_raw, "dataset_raw.csv", "Saved: dataset_raw.csv")

# ==========================================================
# GENERATE ML VERSION (k-mer 3)
# ==========================================================
# Raw rows plus a space-separated overlapping 3-mer "sentence" per sequence.
df_ml = df_raw.copy()
df_ml["kmers"] = df_raw["sequence"].apply(kmerize, k=3)
_save_csv(df_ml, "dataset_ml.csv", "Saved: dataset_ml.csv (Naive Bayes + N-gram LM)")

# ==========================================================
# GENERATE DL VERSION (integer tokens)
# ==========================================================
# Raw rows plus a per-base integer token list. NOTE: to_csv writes each list
# as its Python repr (e.g. "[0, 1, 2]"); parse it back (e.g. ast.literal_eval)
# when loading.
df_dl = df_raw.copy()
df_dl["tokens"] = df_raw["sequence"].apply(tokenize_dl)
_save_csv(df_dl, "dataset_dl.csv", "Saved: dataset_dl.csv (RNN / BiLSTM / Transformer)")

# ==========================================================
# SUMMARY
# ==========================================================
total_rows = len(df_raw)
print("\nDataset creation complete!")
print("Total samples:", total_rows)
print(df_raw["label"].value_counts())


ModuleNotFoundError: No module named 'pandas'