In [22]:
# Imports
from pathlib import Path

import plotly.express as px
import polars as pl


In [35]:
# Root of the project checkout. Every path below is derived from it, so
# relocating the project only requires changing this one line.
PROJECT_ROOT = Path("/home/jpatterson87/for_classes/gnn_for_diabetes")
DATA_DIR = PROJECT_ROOT / "data"

# Directory that is globbed for per-sample "*.tsv" files further down.
transformed_data_with_nodes_path = DATA_DIR / "transformed_data_from_vms" / "with_nodes"

# Tab-separated table that is loaded into `sample_id_to_pop` further down.
sample_id_to_pop_path = DATA_DIR / "sample_info" / "sample_id_to_population.txt"

In [36]:
# Load the sample lookup table: tab-separated, first row holds the column names.
# `sample_id_to_pop_path` is already a Path, so it is passed through as-is.
sample_id_to_pop = pl.read_csv(
    sample_id_to_pop_path,
    sep="\t",
    has_header=True,
)

In [100]:
# Tally the rows of the sample table per population, sorted by ascending count.
samples_per_population = pl.count().alias("num_samples")
demographic_tabulation = (
    sample_id_to_pop
    .groupby(by="population_description")
    .agg([samples_per_population])
    .sort("num_samples")
)

# One bar per population; bar height is the number of samples.
population_names = demographic_tabulation.select("population_description").to_series()
sample_counts = demographic_tabulation.select("num_samples").to_series()
demographics_plot = px.bar(
    x=population_names,
    y=sample_counts,
    title="Numbers of Samples for each Population",
    labels={"x": "Population Description", "y": "Count"},
)

# Rotate the x tick labels, following
# https://stackoverflow.com/questions/68400851/how-to-rotate-xtick-label-bar-chart-plotly-express
demographics_plot.update_xaxes(tickangle=270)

# Render inline, then save a standalone HTML copy of the figure.
demographics_plot.show()
demographics_plot.write_html(
    file="/home/jpatterson87/for_classes/gnn_for_diabetes/figures/demographics_plot.html"
)


In [67]:
# https://stackoverflow.com/questions/23782215/iterate-through-files
# Get a list of (file, population) tuples: one entry per "*.tsv" file whose
# sample id appears in the sample table. Files with an unknown sample id are
# skipped.

# Build the sample_id -> population lookup once, instead of filtering the whole
# sample table again for every file.
known_sample_ids = sample_id_to_pop.select("sample_id").to_series().to_list()
known_populations = sample_id_to_pop.select("population_description").to_series().to_list()
population_by_sample_id = {}
for known_sample_id, known_population in zip(known_sample_ids, known_populations):
    # If a sample id is listed more than once, keep the first population seen
    # so the chosen label is deterministic.
    population_by_sample_id.setdefault(known_sample_id, known_population)

files_and_labels = []

# Sorted so the order of the list does not depend on how the filesystem
# happens to enumerate the directory.
for tsv in sorted(transformed_data_with_nodes_path.glob("*.tsv")):
    # The sample id is everything before the first "." in the file name.
    sample_id = tsv.name.split(".")[0]
    if sample_id in population_by_sample_id:
        files_and_labels.append((tsv, population_by_sample_id[sample_id]))


In [99]:
# Every distinct population description in the sample table.
populations_to_compare = (
    sample_id_to_pop
    .select("population_description")
    .unique()
    .to_series()
    .to_list()
)

# Building blocks for the generated expressions in `_summarize_nodes`.
_ALLELE_COLUMNS = ("REF", "ALT")
_NUCLEOTIDES = ("A", "C", "G", "T")
# (quantile, output-name prefix)
_QUANTILES = (
    (0, "min"),
    (0.25, "first_quartile"),
    (0.5, "second_quartile"),
    (0.75, "third_quartile"),
    (1, "max"),
)
# (input column, output-name suffix)
_SUMMARIZED_COLUMNS = (
    ("POS", "start_pos"),
    ("REF_GC_ratio", "REF_GC_ratio"),
    ("ALT_GC_ratio", "ALT_GC_ratio"),
)


def _summarize_nodes(file):
    """Build the lazy per-node summary query for one tab-separated file.

    The file must have a header row and the columns ``REF``, ``ALT``, ``POS``
    and ``node_id``. For each of ``REF`` and ``ALT`` the query counts the
    matches of A, C, G and T in the upper-cased string and derives
    ``(G + C) / (A + C + G + T)`` as ``<allele>_GC_ratio``. Rows are then
    grouped by ``node_id`` and the 0, 0.25, 0.5, 0.75 and 1 quantiles of
    ``POS``, ``REF_GC_ratio`` and ``ALT_GC_ratio`` are taken per group.

    NOTE(review): an allele containing none of A/C/G/T gives a zero
    denominator in the ratio; confirm how such rows should be treated.

    Parameters
    ----------
    file : path accepted by ``pl.scan_csv``

    Returns
    -------
    The un-collected query produced by ``pl.scan_csv``; nothing is read or
    computed until the caller collects it. Columns are ``node_id`` followed
    by ``{min,first_quartile,second_quartile,third_quartile,max}_`` +
    ``{start_pos,REF_GC_ratio,ALT_GC_ratio}``.
    """
    scanned_file = pl.scan_csv(file=file,
                               has_header=True,
                               sep="\t")

    # REF_A_count, REF_C_count, ..., ALT_T_count
    count_exprs = [
        pl.col(allele)
            .str.to_uppercase()
            .str.count_match(base)
            .alias(f"{allele}_{base}_count")
        for allele in _ALLELE_COLUMNS
        for base in _NUCLEOTIDES
    ]

    # REF_GC_ratio, ALT_GC_ratio
    ratio_exprs = [
        (
            (pl.col(f"{allele}_G_count") + pl.col(f"{allele}_C_count")) /
            (
                pl.col(f"{allele}_G_count") + pl.col(f"{allele}_C_count")
                + pl.col(f"{allele}_A_count") + pl.col(f"{allele}_T_count")
            )
        ).alias(f"{allele}_GC_ratio")
        for allele in _ALLELE_COLUMNS
    ]

    # min_start_pos, ..., max_start_pos, min_REF_GC_ratio, ..., max_ALT_GC_ratio
    summary_columns = []
    quantile_exprs = []
    for column, suffix in _SUMMARIZED_COLUMNS:
        for quantile, prefix in _QUANTILES:
            summary_name = f"{prefix}_{suffix}"
            summary_columns.append(summary_name)
            quantile_exprs.append(
                pl.col(column).quantile(quantile=quantile).alias(summary_name)
            )

    return (
        scanned_file
        .with_columns(count_exprs)
        .with_columns(ratio_exprs)
        .groupby(by="node_id")
        .agg(quantile_exprs)
        # Pin the column order explicitly.
        .select(["node_id"] + summary_columns)
    )


# population -> list of per-file node summaries (one entry per matching file).
node_data = dict.fromkeys(populations_to_compare)

for population in populations_to_compare:
    node_summaries = []
    for file, label in files_and_labels:
        if label == population:
            node_summary = _summarize_nodes(file)
            # Append only for matching files. Previously the append sat
            # outside this `if`, so non-matching files re-appended a stale
            # summary (or raised NameError when the first file did not match).
            node_summaries.append(node_summary)

    # Keep this population's summaries; `node_summaries` is reset on the
    # next iteration, so without this every population but the last is lost.
    node_data[population] = node_summaries
        
    
        
        

shape: (1134, 16)
┌─────────┬────────────┬────────────┬────────────┬─────┬────────────┬────────────┬────────────┬────────────┐
│ node_id ┆ min_start_ ┆ first_quar ┆ second_qua ┆ ... ┆ first_quar ┆ second_qua ┆ third_quar ┆ max_ALT_GC │
│ ---     ┆ pos        ┆ tile_start ┆ rtile_star ┆     ┆ tile_ALT_G ┆ rtile_ALT_ ┆ tile_ALT_G ┆ _ratio     │
│ i64     ┆ ---        ┆ _pos       ┆ t_pos      ┆     ┆ C_ratio    ┆ GC_ratio   ┆ C_ratio    ┆ ---        │
│         ┆ f64        ┆ ---        ┆ ---        ┆     ┆ ---        ┆ ---        ┆ ---        ┆ f64        │
│         ┆            ┆ f64        ┆ f64        ┆     ┆ f64        ┆ f64        ┆ f64        ┆            │
╞═════════╪════════════╪════════════╪════════════╪═════╪════════════╪════════════╪════════════╪════════════╡
│ 576     ┆ 1.8556522e ┆ 1.8557225e ┆ 1.8557279e ┆ ... ┆ 0.0        ┆ 0.5        ┆ 1.0        ┆ 1.0        │
│         ┆ 7          ┆ 7          ┆ 7          ┆     ┆            ┆            ┆            ┆            │
├