In [15]:
import polars as pl

In [16]:
# Lazily scan the processed CARD AMR table (nothing is read until .collect()).
card_amr_path = "../temp/data/processed/card_amr.parquet"
amr = pl.scan_parquet(card_amr_path)

In [17]:
# Lazily scan the processed non-AMR gene table used as negative examples.
non_amr_path = "../temp/data/processed/non_amr_genes_10000.parquet"
non_amr = pl.scan_parquet(non_amr_path)

In [18]:
# Peek at the first five non-AMR rows (LazyFrame.head() is an alias of limit()).
non_amr.limit(5).collect()

GeneID,#tax_id,status,genomic_nucleotide_accession.version,start_position_on_the_genomic_accession,end_position_on_the_genomic_accession,orientation,Symbol,non-AMR
i64,i64,str,str,i64,i64,str,str,i32
57034424,446,,"""NZ_CP013742.1""",460139,460801,"""+""","""AVR58_RS02155""",1
57034428,446,,"""NZ_CP013742.1""",464223,464762,,"""AVR58_RS02175""",1
57034440,446,,"""NZ_CP013742.1""",473737,474546,,"""ankJ""",1
57034448,446,,"""NZ_CP013742.1""",482059,484410,"""+""","""icmO""",1
57034451,446,,"""NZ_CP013742.1""",485695,486333,"""+""","""AVR58_RS02300""",1


In [None]:
# Find all distinct genomic_nucleotide_accession.version values in non_amr.
# maintain_order=True keeps first-occurrence order so the output is
# reproducible across runs (plain unique() on a LazyFrame has no fixed order).
(
    non_amr
    .select("genomic_nucleotide_accession.version")
    .unique(maintain_order=True)
    .collect()
)

AttributeError: 'LazyFrame' object has no attribute 'distinct'

In [20]:
# In the AMR data every column except "sequence" is a resistance-class label
# (e.g. "carbapenem", "disinfecting agents and antiseptics"); "non-AMR" is
# appended as one extra label for the negative (non-AMR) genes.
amr_column_names = amr.collect_schema().names()
label_columns = [*(name for name in amr_column_names if name != "sequence"), "non-AMR"]

In [21]:
# Show the full, ordered list of label names.
print(f"Label columns: {label_columns}")

Label columns: ['disinfecting agents and antiseptics', 'glycylcycline', 'rifamycin antibiotic', 'macrolide antibiotic', 'streptogramin antibiotic', 'pyrazine antibiotic', 'tetracycline antibiotic', 'bicyclomycin-like antibiotic', 'isoniazid-like antibiotic', 'nitroimidazole antibiotic', 'orthosomycin antibiotic', 'nitrofuran antibiotic', 'carbapenem', 'pactamycin-like antibiotic', 'moenomycin antibiotic', 'cycloserine-like antibiotic', 'cephalosporin', 'diaminopyrimidine antibiotic', 'fluoroquinolone antibiotic', 'antibiotic without defined classification', 'phosphonic acid antibiotic', 'pleuromutilin antibiotic', 'elfamycin antibiotic', 'nucleoside antibiotic', 'peptide antibiotic', 'phenicol antibiotic', 'streptogramin A antibiotic', 'aminocoumarin antibiotic', 'streptogramin B antibiotic', 'sulfonamide antibiotic', 'fusidane antibiotic', 'zoliflodacin-like antibiotic', 'sulfone antibiotic', 'thiosemicarbazone antibiotic', 'glycopeptide antibiotic', 'oxazolidinone antibiotic', 'amino

In [22]:
# Save the ordered list of label names as a one-column CSV (header "labels").
# The order matches the column order of the label cache built below.
pl.DataFrame({"labels": label_columns}).write_csv("../temp/data/processed/labels.csv")

In [23]:
# Label matrix for the AMR sequences: keep only the label columns (dropping
# "sequence") and add the synthetic "non-AMR" label as a constant 0, since
# these rows come from the CARD AMR table. Every other label column already
# exists in `amr` (label_columns was derived from its schema), so select()
# does not fill anything else in.
amr_labels = (
    amr
    .with_columns(pl.lit(0, dtype=pl.Int32).alias("non-AMR"))
    .select(label_columns)
    .collect()
)

In [24]:
# Label matrix for the non-AMR genes: 0 for every AMR label, 1 for "non-AMR".
# Count the rows once with a cheap aggregate instead of materializing the
# whole LazyFrame with .collect() for every label column.
n_non_amr = non_amr.select(pl.len()).collect().item()
non_amr_labels = pl.DataFrame(
    {
        **{col: [0] * n_non_amr for col in label_columns if col != "non-AMR"},
        "non-AMR": [1] * n_non_amr,
    }
)

In [25]:
# Align the non-AMR label dtypes with the AMR label frame, then stack both
# frames row-wise into one label cache (same columns, same order).
target_dtypes = {name: amr_labels.schema[name] for name in label_columns}
non_amr_labels_casted = non_amr_labels.with_columns(
    [pl.col(name).cast(dtype) for name, dtype in target_dtypes.items()]
)
label_cache = pl.concat([amr_labels, non_amr_labels_casted], how="vertical")

In [26]:
from plotly import express as px

# Visualize the distribution of each label: column sums of the label cache.
# NOTE(review): this reads as "number of positive sequences per label" only if
# the AMR label columns are 0/1 indicators — confirm against card_amr.parquet.
label_sums = label_cache.select(pl.all().sum())
fig = px.bar(
    x=label_sums.columns,
    y=label_sums.row(0),
    labels={"x": "Label", "y": "Number of sequences"},
    title="Label distribution (AMR resistance classes and non-AMR)",
)
# Label names are long; tilt them so they stay readable.
fig.update_layout(xaxis_tickangle=-45)
fig.show()

In [27]:
from pathlib import Path

# Save the label cache to parquet for downstream steps.
# The interim directory may not exist yet; write_parquet does not create it.
labels_cache_path = Path("../temp/data/interim/labels_cache.parquet")
labels_cache_path.parent.mkdir(parents=True, exist_ok=True)
label_cache.write_parquet(labels_cache_path)