In [None]:
# Get non-AMR
import polars as pl
from pathlib import Path

In [None]:
# Lazily scan the bacterial gene_info table; nothing is read until a later
# .collect() (or sink) runs the query.
bacteria_gene_info = pl.scan_parquet(
    Path("../temp/data/raw/bacteria.gene_info.20251001.parquet")
)

In [None]:
# Lazily scan the gene2accession table, which is joined to gene_info on GeneID
# further down.
gene2accession = pl.scan_parquet(
    Path("../temp/data/raw/gene2accession.20251006.parquet")
)

In [None]:
# Preview the first rows of gene2accession to check column names and dtypes.
gene2accession.head(5).collect()

In [None]:
# Keep protein-coding genes only, attach their gene2accession rows by GeneID,
# and drop any row that is missing either genomic coordinate.
protein_coding_genes = (
    bacteria_gene_info
    .filter(pl.col("type_of_gene") == "protein-coding")
    .select("GeneID")
    .join(gene2accession, on="GeneID", how="inner")
    .filter(pl.col("start_position_on_the_genomic_accession").is_not_null())
    .filter(pl.col("end_position_on_the_genomic_accession").is_not_null())
)

In [None]:
# Mean of (end position - start position) across every protein-coding record.
protein_coding_genes.select(
    (
        pl.col("end_position_on_the_genomic_accession")
        - pl.col("start_position_on_the_genomic_accession")
    ).mean()
).collect()

In [None]:
# Count the distinct genomic accessions (including version) referenced by the
# protein-coding records.
protein_coding_genes.select(
    pl.col("genomic_nucleotide_accession.version").n_unique()
).collect()

In [None]:
# Take the first 20,000 protein-coding records.
# NOTE(review): no sort is applied before the limit, so "top" here only means
# whichever rows the query yields first — confirm that is the intended sample.
limited = protein_coding_genes.limit(20_000)

In [None]:
# Count the distinct genomic_nucleotide_accession.version values in the
# limited set.
limited.select(
    pl.col("genomic_nucleotide_accession.version").n_unique()
).collect()

In [None]:
# Mean of (end position - start position) within the limited set, for
# comparison with the same statistic over all protein-coding records.
limited.select(
    (
        pl.col("end_position_on_the_genomic_accession")
        - pl.col("start_position_on_the_genomic_accession")
    ).mean()
).collect()

In [None]:
# Write the limited set to Parquet, keeping the identifying and coordinate
# columns and adding a constant label column: non_amr is 1 for every row.
# NOTE(review): the file name says "10000" while `limited` is built with
# limit(20000) — confirm which is intended before renaming anything.
(
    limited
    .select(
        "GeneID",
        "#tax_id",
        "status",
        "genomic_nucleotide_accession.version",
        "start_position_on_the_genomic_accession",
        "end_position_on_the_genomic_accession",
        "orientation",
        "Symbol",
    )
    .with_columns(non_amr=pl.lit(1))
    .sink_parquet(
        Path("../temp/data/processed/non_amr_genes_10000.parquet"),
        compression="zstd",
    )
)

In [None]:
# Lazily re-open the labelled non-AMR gene table written to the processed
# data directory.
non_amr = pl.scan_parquet(
    Path("../temp/data/processed/non_amr_genes_10000.parquet")
)

In [None]:
import ssl
import urllib.request
from pathlib import Path

import certifi
from Bio import Entrez
from tqdm import tqdm

# Make urllib verify HTTPS against certifi's CA bundle by installing a global
# opener. (Merely calling ssl._create_default_https_context(...) would build a
# context and discard it, changing nothing; the opener is what takes effect.)
# NOTE(review): this relies on Bio.Entrez issuing its requests through
# urllib.request.urlopen — confirm for the installed Biopython version.
urllib.request.install_opener(
    urllib.request.build_opener(
        urllib.request.HTTPSHandler(
            context=ssl.create_default_context(cafile=certifi.where())
        )
    )
)

Entrez.email = "j7.jacob@hdr.qut.edu.au"

# Distinct genomic accessions referenced by the non-AMR gene table. Nulls are
# dropped so that a missing accession is never turned into a "None" request.
accession_list = (
    non_amr.select("genomic_nucleotide_accession.version")
    .drop_nulls()
    .unique()
    .collect()
    .to_series()
    .to_list()
)

# Create the target directory up front so a missing folder is not misreported
# as an interrupted download.
sequence_dir = Path("../temp/data/external/sequences")
sequence_dir.mkdir(parents=True, exist_ok=True)

# Fetch one FASTA file per accession. Each download is written to a ".part"
# file and renamed only after it has been fully written, so an existing
# "<accession>.fasta" is always complete and rerunning the cell resumes safely.
loop = tqdm(accession_list)
try:
    for acc in loop:
        fasta_path = sequence_dir / f"{acc}.fasta"
        # Skip accessions that were already downloaded in a previous run.
        if fasta_path.exists():
            loop.set_description(f"Skipping {acc}")
            continue

        loop.set_description(f"Fetching {acc}")
        partial_path = fasta_path.with_name(fasta_path.name + ".part")
        try:
            handle = Entrez.efetch(db="nucleotide", id=acc, rettype="fasta", retmode="text")
            try:
                with open(partial_path, "w") as out_handle:
                    out_handle.write(handle.read())
            finally:
                # Release the HTTP response even when reading or writing fails.
                handle.close()
            # Publish the finished file under its final name.
            partial_path.replace(fasta_path)
        finally:
            # Runs on any failure, including Ctrl-C (KeyboardInterrupt), so a
            # half-written file is never left behind. After a successful
            # rename the ".part" file no longer exists and this is a no-op.
            if partial_path.exists():
                partial_path.unlink()
except Exception as e:
    print("\nDownload interrupted. Last file removed. You can rerun the cell to resume.")
    print(e)