In [None]:
from pathlib import Path

import polars as pl
from plotly import express as px

In [None]:
# Lazily scan the AMR table (card_amr.parquet); nothing is read until `.collect()`.
AMR_PARQUET_PATH = "../temp/data/processed/card_amr.parquet"
amr = pl.scan_parquet(AMR_PARQUET_PATH)

In [None]:
# Lazily scan the non-AMR gene table; nothing is read until `.collect()`.
NON_AMR_PARQUET_PATH = "../temp/data/processed/non_amr_genes_10000.parquet"
non_amr = pl.scan_parquet(NON_AMR_PARQUET_PATH)

In [None]:
# Peek at the first five rows of the non-AMR table.
non_amr.head(5).collect()

In [None]:
# Distinct values of `genomic_nucleotide_accession.version` in the non-AMR table.
accession_column = "genomic_nucleotide_accession.version"
(
    non_amr
    .select(accession_column)
    .unique()
    .collect()
)

In [None]:
# Every AMR column other than 'sequence' is treated as an antibiotic label.
# The synthetic "non-AMR" class is appended last, so it is the final label column.
label_columns = [
    name for name in amr.collect_schema().names()
    if name != "sequence"
]
label_columns.append("non-AMR")

In [None]:
# Show the ordered label list (antibiotic labels followed by "non-AMR").
print(f"Label columns: {label_columns}")

In [None]:
# Write the ordered label names as a single-column CSV (header: "labels").
# This is a plain CSV of names, not a JSON schema. It is presumably read
# downstream to recover the label order — confirm against consumers.
pl.Series(label_columns).to_frame("labels").write_csv("../temp/data/processed/labels.csv")

In [None]:
# AMR label matrix: keep every label column from `amr` and add a constant
# "non-AMR" column set to 0, because every row here comes from the AMR table.
# Missing columns are NOT filled in: `select` raises if any name in
# `label_columns` is absent. Only "non-AMR" needs adding, since all other labels
# were taken from `amr`'s own schema.
amr_labels = (
    amr
    .with_columns(pl.lit(0).alias("non-AMR"))
    .select(label_columns)
    .collect()
)

In [None]:
# Non-AMR label matrix: every antibiotic label is 0 and "non-AMR" is 1.
# The row count is computed once with a lazy `pl.len()`. Calling
# `non_amr.collect()` inside the comprehension would instead re-read and
# materialise the full parquet file once per label column.
n_non_amr = non_amr.select(pl.len()).collect().item()
non_amr_labels = pl.DataFrame(
    {
        **{col: [0] * n_non_amr for col in label_columns if col != "non-AMR"},
        "non-AMR": [1] * n_non_amr,
    }
)

In [None]:
# Align the non-AMR labels with the AMR label schema. Each column is cast to the
# AMR dtype and selected in `label_columns` order, because a vertical concat
# requires matching column names, order, and dtypes.
non_amr_labels_casted = non_amr_labels.select(
    [pl.col(col).cast(amr_labels.schema[col]) for col in label_columns]
)
# Row order of the cache: all AMR rows first, then all non-AMR rows.
label_cache = pl.concat([amr_labels, non_amr_labels_casted], how="vertical")

In [None]:
# Visualize the distribution of each label as per-column sums of `label_cache`.
# This assumes label values are 0/1 indicators, so a sum is a row count — TODO
# confirm for the antibiotic columns.
label_sums = label_cache.select(pl.all().sum())
fig = px.bar(
    x=label_sums.columns,
    y=label_sums.row(0),
    labels={"x": "Label", "y": "Number of rows"},
    title="Label distribution (AMR + non-AMR)",
)
fig.show()

In [None]:
# Save the label cache to parquet. Its row order is AMR rows then non-AMR rows,
# as built by the concat above. The cache directory is created first so a fresh
# checkout does not fail on a missing `../temp/data/cache/` folder.
LABELS_CACHE_PATH = Path("../temp/data/cache/labels_cache.parquet")
LABELS_CACHE_PATH.parent.mkdir(parents=True, exist_ok=True)
label_cache.write_parquet(LABELS_CACHE_PATH)