# In [1]:
import polars as pl

# Build a reduced dataset containing only the 20 drugs with the most
# distinct cell lines, and save it as a new parquet file.

# Load the full dataset into memory.
final_data = pl.read_parquet("data/gdsc_single_cell_aligned.parquet")

# Count distinct SANGER_MODEL_ID values per DRUG_ID.
# group_by() does not guarantee the order of its output groups, so sorting on
# the count alone would leave drugs with equal counts in arbitrary order and
# make the top-20 cut-off non-reproducible between runs. DRUG_ID is used as an
# ascending secondary key to break ties deterministically.
drug_counts = (
    final_data
    .group_by("DRUG_ID")
    .agg(
        pl.col("SANGER_MODEL_ID").n_unique().alias("num_cell_lines")
    )
    .sort(["num_cell_lines", "DRUG_ID"], descending=[True, False])
)

# IDs of the 20 highest-ranked drugs (fewer if the data has under 20 drugs).
top_20_drugs = drug_counts.head(20)["DRUG_ID"].to_list()

# Keep every row of the full dataset that belongs to one of those drugs.
filtered_data = final_data.filter(pl.col("DRUG_ID").is_in(top_20_drugs))

# Persist the filtered dataset.
filtered_data.write_parquet("data/top20_drugs_dataset.parquet")
