In [None]:
# First, import the packages every later cell relies on (`pl`, `pd`, `HTML`,
# `display`, `show`, `to_html_datatable`). They must be live: with this cell
# commented out the notebook fails on a fresh kernel (Restart & Run All).
import polars as pl
import pandas as pd
from IPython.display import HTML, display
from itables import init_notebook_mode, to_html_datatable, show
import itables.options as opt

# Render DataFrames as interactive itables by default.
init_notebook_mode(all_interactive=True)
opt.maxBytes = "100KB"

# Load pretty jupyter's magics (the `%%jmd` cells below need them).
# Equivalent to the line magic `%load_ext pretty_jupyter`.
get_ipython().run_line_magic("load_ext", "pretty_jupyter")

## **Introduction**

This notebook details the processes (semi-automated) done to further process the raw output files from Arriba and FusionCatcher fusion transcript callers. 

1. Run the `pypolars-process-ft-tsv.py` script to generate the fusion transcript list from the Arriba and FusionCatcher output files. The script takes, as its mandatory first argument, the path to the directory where the sample-specific fusion call output files from Arriba or FusionCatcher are stored, followed by the string that identifies the tool (for instance `arr`, the prefix of the Arriba fusion transcript call output files).

	For example:
	> ``` pypolars-process-ft-tsv.py data/FTmyBRCAs_raw/Arriba arr ```

	Do the same for the FusionCatcher raw output files, and for the Arriba and FusionCatcher output files generated by processing the 113 TCGA-Normals samples (used as a panel of normals for FT filtering).

2. Then, load up the two datasets on Jupyter Notebook and concatenate the dataframes together so that Arriba+FusionCatcher unfiltered FT data are combined into one data table and saved in one `.parquet` and `.tsv` file. Do the same for the `TCGANormals` panel of normals.

In [None]:
# Lazily scan the MyBrCa datasets (nothing is read until `.collect()`)
arr_mdf = pl.scan_parquet('../output/MyBrCa/Arriba-FT-all-unfilt-list-v2.parquet')
fc_mdf = pl.scan_parquet('../output/MyBrCa/FusionCatcher-FT-all-unfilt-list-v2.parquet')

# Lazily scan the TCGANormals datasets
arr_tdf = pl.scan_parquet('../output/TCGANormals/Arriba-Normal-FT-all-unfilt-list-v2.parquet')
fc_tdf = pl.scan_parquet('../output/TCGANormals/FusionCatcher-Normal-FT-all-unfilt-list-v2.parquet')

# HTML previews of the first 5 rows. The head is taken on the LazyFrame so only
# those rows are materialized, instead of collecting the whole table first.
html_arr_t = to_html_datatable(
    arr_tdf.head(5).collect().to_pandas(use_pyarrow_extension_array=True),
    display_logo_when_loading=False,
)

html_fc_t = to_html_datatable(
    fc_tdf.head(5).collect().to_pandas(use_pyarrow_extension_array=True),
    display_logo_when_loading=False,
)


### Loaded Polars dataFrames
[//]: # (-.- .tabset .tabset-pills)

Here all datasets from the two different fusion transcript calling tools run on both MyBrCa and TCGA-Normals cohorts are shown in tabs.

In [None]:
%%jmd 

#### **Dataset 1A** (MyBrCa): Arriba unfiltered
Arriba MyBrCa datatable dimension: <b>{{arr_mdf.collect().shape}}</b>

Showing truncated table:

In [None]:
# -.-|m { input: false, output: true}
# Take the 5-row head lazily so only the preview rows are collected.
show(arr_mdf.head(5).collect().to_pandas(use_pyarrow_extension_array=True), maxBytes=0)

In [None]:
%%jmd

#### **Dataset 1B** (MyBrCa): FusionCatcher unfiltered
FusionCatcher MyBrCa datatable dimension: <b>{{fc_mdf.collect().shape}}</b>

Showing truncated table:

In [None]:
# -.-|m { input: false, output: true}
# Take the 5-row head lazily so only the preview rows are collected.
show(fc_mdf.head(5).collect().to_pandas(use_pyarrow_extension_array=True), maxBytes=0)

In [None]:
%%jmd 

#### **Dataset 2A** (TCGA Normals): Arriba unfiltered
Arriba TCGA Normals datatable dimension: <b>{{arr_tdf.collect().shape}}</b>

Showing truncated table:

In [None]:
# -.-|m { input: false, output: true}
# Take the 5-row head lazily so only the preview rows are collected.
show(arr_tdf.head(5).collect().to_pandas(use_pyarrow_extension_array=True), maxBytes=0)

In [None]:
%%jmd

#### **Dataset 2B** (TCGA Normals): FusionCatcher unfiltered
FusionCatcher TCGA Normals datatable dimension: <b>{{fc_tdf.collect().shape}}</b>

Showing truncated table:

In [None]:
# -.-|m { input: false, output: true}
# Take the 5-row head lazily so only the preview rows are collected.
show(fc_tdf.head(5).collect().to_pandas(use_pyarrow_extension_array=True), maxBytes=0)

## **Concatenate Arriba and FusionCatcher Datasets**

Now, we can merge the two dataframes into one masterFrame for each cohort data (MyBrCa & TCGA panel of normals) using Polars' `concat`.

**NOTE:** Vertical concatenation is the default, where two dataframes sharing the exact same columns would be joined together, adding all rows of dataframe 1 and 2 vertically.

[//]: # (-.- .alert .alert-warning)

In [None]:
# -.-|m { input: false, output: false}
# Stack the materialized Arriba rows on top of the FusionCatcher rows
# (pl.concat's default vertical strategy).
joined_df = pl.concat([arr_mdf.collect(), fc_mdf.collect()])

In [None]:
# -.-|m { input: false, output: true}
# Report the (rows, columns) of the combined MyBrCa table, then preview 5 rows.
display(HTML(f"Concatenated MyBrCa Arriba+FusionCatcher datatable dimension: <b>{joined_df.shape}</b>"))

show(
    joined_df.head(5),
    maxBytes=0,
    classes="display compact",
)

Do the same with the TCGA panel of normal FTs.

In [None]:
# -.-|m { input: false, output: false}
# Stack the materialized TCGA-Normals Arriba rows on top of the FusionCatcher
# rows (pl.concat's default vertical strategy).
joined_norms_df = pl.concat([arr_tdf.collect(), fc_tdf.collect()])

In [None]:
# -.-|m { input: false, output: true}

# Report the (rows, columns) of the combined TCGA-Normals table, then preview 5 rows.
display(HTML(f"Concatenated TCGA-Normals Arriba+FusionCatcher datatable dimension: <b>{joined_norms_df.shape}</b>"))

show(
    joined_norms_df.head(5),
    maxBytes=0,
    classes="display compact",
)

## **Filter MyBrCa Merged DataFrame using Panel of Normals**
Now we can filter the unfiltered, concatenated FT dataframes by discarding those that are present in TCGA Normal data.

First, load the parquet file from the MyBrCa datasets. Then load the parquet file from the TCGANormals datasets.

In [None]:
# -.-|m { input: true, output: true}

# Lazily scan the concatenated (Arriba + FusionCatcher) tables for both cohorts.
# NOTE(review): no cell visible in this notebook writes these two parquet
# files — presumably they are saved by a separate step; confirm before running
# on a fresh checkout.
mybrca_ccdf = pl.scan_parquet("../output/MyBrCa/Arr_FC-concat-FT-all-unfilt-list-v2.parquet")

tcganorms_ccdf = pl.scan_parquet("../output/TCGANormals/Arr_FC-Normals-concat-FT-all-unfilt-list-v2.parquet")


Once they are loaded, we can convert to Pandas from Polars for ease of processing.

In [None]:
# -.-|m { input: true, output: true, input_fold: show}

# Materialize each LazyFrame and hand it to pandas, keeping Arrow-backed
# extension arrays for the columns.
my_concat_df = mybrca_ccdf.collect().to_pandas(use_pyarrow_extension_array=True)
tn_concat_df = tcganorms_ccdf.collect().to_pandas(use_pyarrow_extension_array=True)

In [None]:
# -.-|m { input: false, output: true}
# Report the (rows, columns) of the pandas MyBrCa table, then preview 5 rows.
display(HTML(f"Concatenated MyBrCa Arriba+FusionCatcher datatable dimension: <b>{my_concat_df.shape}</b>"))

show(
    my_concat_df.head(5),
    maxBytes=0,
    classes="display compact",
)

In [None]:
# -.-|m { input: false, output: true}
# Report the (rows, columns) of the pandas TCGA Normals table, then preview 5 rows.
display(HTML(f"Concatenated TCGA Normals Arriba+FusionCatcher datatable dimension: <b>{tn_concat_df.shape}</b>"))

show(
    tn_concat_df.head(5),
    maxBytes=0,
    classes="display compact",
)

Next, use Polars' `filter` expression with `is_in` and the negation `~` to keep only the rows of the MyBrCa dataframe whose `breakpointID` does **NOT** appear in the `breakpointID` column of the TCGANormals dataframe. (This step does not remove duplicate rows; duplicates are handled later.)

In [None]:
# -.-|m { input: true, output: true, input_fold: show}

# Only the `breakpointID` column of the normals is needed for the lookup, so
# select it before collecting instead of materializing the whole normals table.
normal_breakpoint_ids = tcganorms_ccdf.select('breakpointID').collect()['breakpointID']

# Keep MyBrCa rows whose breakpoint is NOT present in the panel of normals.
normfilt_mybrca_ccdf = mybrca_ccdf.collect().filter(
    ~pl.col('breakpointID').is_in(normal_breakpoint_ids)
)

The TCGA Normal-filtered dataframe is as below:

In [None]:
# -.-|m { input: false, output: true}
# Report the (rows, columns) of the normal-filtered table, then render it in full
# (maxBytes=0 lifts the itables size limit).
normfilt_dims = normfilt_mybrca_ccdf.shape
display(HTML(f"<b>Normal-filtered</b> Concatenated MyBrCa FT DataFrame dimension: <b>{normfilt_dims}</b>"))

show(
    normfilt_mybrca_ccdf,
    maxBytes=0,
)

### Interrogate Shared `breakpointID`

The most identifying column for our FT data is the gene fusion breakpoint information coded in the column `breakpointID`. We can interrogate our dataframe based on this column to explore the *sharedness* of each unique breakpoint (*how many patients share the same breakpoints*).

To do so, subset the dataframe to just `breakpointID` and `sampleID`, `group_by` the `breakpointID` column, and count the number of distinct `sampleID` values within each group using `n_unique()`.

In [None]:
# -.-|m { input: true, input_fold: show}

# One row per breakpointID, with the number of distinct samples carrying it.
normfilt_my_sharedness = (
    normfilt_mybrca_ccdf
    .group_by("breakpointID")
    .agg(pl.col("sampleID").n_unique().alias("sharednessDegree"))
)

This would return a count of unique samples (patients) one particular unique breakpoint appears in. This is the `sharednessDegree` henceforth.

In [None]:
# -.-|m { input: false, input_fold: hide}

# Most widely shared breakpoints first.
sorted_normfilt_my_sharedness = normfilt_my_sharedness.sort(
    by="sharednessDegree",
    descending=True,
)

display(HTML(f"Nonredundant <b>normal-filtered</b> sharednessDegree dimension: <b>{sorted_normfilt_my_sharedness.shape}</b>"))

show(
    sorted_normfilt_my_sharedness,
    maxBytes=0,
)


**NOTE:** This is the best way to address miscounting breakpoints that appear in multiple rows due to differences in gene naming but they are only seen in one sample. Using other counting strategies such as window function (`.over` method) will count these duplicate rows as separate entities when in reality they are the same breakpoint seen in just one patient.

[//]: # (-.- .alert .alert-warning)

## **Plot the Sharedness Degree**

We have used Polars to easily group and count the number of patients sharing a particular breakpoint ID for each unique breakpoint ID as above, let's formalize that again by using Pandas instead.

First, subset the filtered dataframe to just the two columns we are interested in using Polars, but this time prepend the "P" string to all values of the `sampleID` column, then convert to Pandas for visualization.

In [None]:
# -.-|m { input: true, input_fold: show}
# Keep the breakpoint and sample columns only, prefix every sampleID with "P",
# and convert the result to pandas.
bp_sample_array_pdf = normfilt_mybrca_ccdf.select(
    "breakpointID",
    sampleID=pl.concat_str([pl.lit("P"), pl.col("sampleID")]),
).to_pandas()

In [None]:
# -.-|m { input: false, output: true}
# Report the (rows, columns) of the breakpoint-patient table, then render it in full.
display(HTML(f"Nonredundant <b>normal-filtered</b> BreakPoint–Patient Connection dimension: <b>{bp_sample_array_pdf.shape}</b>"))

show(
    bp_sample_array_pdf,
    maxBytes=0,
    classes="display compact",
)


Due to the annotation redundancy in `fusionGeneID` column in the original df, we now have rows in `breakpointID` and `sampleID` that are repeated (i.e. `6:36132629-17:44965446	P1` as seen above). 

Let's filter these out, as they represent the same putative FT.

In [None]:
# Collapse repeated (breakpointID, sampleID) rows into one
bpsample_pdf_unique = bp_sample_array_pdf.drop_duplicates()

# Report how many rows the de-duplication removed
rows_before = len(bp_sample_array_pdf)
rows_after = len(bpsample_pdf_unique)
print("Original number of rows:", rows_before)
print("Number of rows after removing duplicates:", rows_after)
print("Number of duplicates removed:", rows_before - rows_after)

Now we group by each unique `breakpointID` and count how many distinct `sampleID` values are associated with that breakpoint (its *sharedness degree*).

In [None]:
# Per breakpointID, count the distinct sampleIDs and name that count
# `sharednessDegree` directly in the aggregation.
breakpoint_counts = (
    bpsample_pdf_unique
    .groupby('breakpointID', as_index=False)
    .agg(sharednessDegree=('sampleID', 'nunique'))
)

show(breakpoint_counts, maxBytes=0, classes="display compact")

Then we can count the number of unique `breakpointID`s for each `sharednessDegree` value.

In [None]:
# For each sharednessDegree value, how many distinct breakpointIDs have it
sharedness_counts = (
    breakpoint_counts
    .groupby('sharednessDegree')['breakpointID']
    .nunique()
    .rename('uniqueBPCounts')
    .reset_index()
    .sort_values('sharednessDegree')
)

show(
    sharedness_counts.reset_index(drop=True),
    maxBytes=0,
    classes="display compact",
)

In [None]:
# Inspect the per-degree breakpoint counts that the bar plot below labels.
sharedness_counts['uniqueBPCounts']

Plot the `sharedness_counts` dataFrame in a bar plot.

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Bar plot of unique-FT counts per sharedness degree, drawn on an explicit Axes
bar_fig, bar_ax = plt.subplots(figsize=(10, 8), dpi=300)
sns.barplot(
    data=sharedness_counts,
    x='sharednessDegree',
    y='uniqueBPCounts',
    color='crimson',
    ax=bar_ax,
)

# Write each bar's count above it; bars sit at positions 0..n-1 in row order
for position, count in enumerate(sharedness_counts['uniqueBPCounts']):
    bar_ax.text(position, count, str(count), color='black', ha='center', fontweight='bold', fontsize=8)

# Axis labels and title
bar_ax.set_xlabel('Sharedness Degree (Number of Patients A Unique FT Is Observed)')
bar_ax.set_ylabel('Count of Unique FTs')
bar_ax.set_title('Frequency of Unique Tumor-Specific FTs by Sharedness Degree')

# Vertical x tick labels for readability
plt.setp(bar_ax.get_xticklabels(), rotation=90)

plt.show()

> **NOTE**: This means that **94% (41080/43927)** of the unique breakpoint FTs are patient-specific.

## **Implement Bipartite Network Analysis**

We can use bipartite network analysis from graph theory to explore the underlying relationships between unique `breakpointID` and `sampleID`.

### Design an Analysis Class
Create a complex class called `NetworkAnalyzer` to do graph network analysis between `breakpointID` and `sampleID`. 

(Expand the code below to see the full class methods.)

In [None]:
# -.-|m { input: true, output: true, input_fold: hide}
import numpy as np
import networkx as nx
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy.spatial.distance import pdist, squareform
from networkx.algorithms.bipartite import density as bipartite_density
from scipy.sparse import csr_matrix

class NetworkAnalyzer:
	"""Bipartite patient-breakpoint network stored as a sparse adjacency matrix.

	Rows of the adjacency matrix are patients and columns are breakpoints; a
	non-zero entry marks a breakpoint observed in a patient. Network metrics
	are computed lazily, the first time a plotting/summary method needs them,
	and are then available as attributes:

	- ``patient_degrees`` / ``breakpoint_degrees``: row / column sums.
	- ``patient_similarity`` / ``breakpoint_similarity``: pairwise Jaccard
	  *distance* matrices (the plots display ``1 - distance``).
	- ``density``: bipartite density of the graph.
	- ``breakpoint_centrality``: degree centrality of each breakpoint node.

	The class relies on the module-level names ``np``, ``nx``, ``go``,
	``make_subplots``, ``pdist``, ``squareform``, ``bipartite_density`` and
	``csr_matrix`` imported in the same cell.
	"""

	def __init__(self, df=None, patient_col=None, breakpoint_col=None,
					precomputed_matrix=None, patients=None, breakpoints=None):
		"""
		Initialize NetworkAnalyzer with either DataFrame or precomputed matrix.

		Args:
			df (pl.DataFrame.Polars, optional): Input DataFrame HAS TO BE IN POLARS
			patient_col (str, optional): Column name for patients
			breakpoint_col (str, optional): Column name for breakpoints
			precomputed_matrix (csr_matrix, optional): Pre-computed sparse adjacency matrix
				(patients x breakpoints); other formats are converted to CSR
			patients (list, optional): List of patient IDs (required if using precomputed_matrix)
			breakpoints (list, optional): List of breakpoint IDs (required if using precomputed_matrix)

		Raises:
			ValueError: If neither input form is complete, or if the precomputed
				matrix shape does not match the lengths of the label lists.
		"""
		if precomputed_matrix is not None:
			if patients is None or breakpoints is None:
				raise ValueError("Must provide patients and breakpoints lists with precomputed matrix")
			# save_matrix reads .data/.indices/.indptr, which only exist on
			# compressed formats, so coerce anything that is not already CSR.
			if getattr(precomputed_matrix, "format", None) == "csr":
				matrix = precomputed_matrix
			else:
				matrix = csr_matrix(precomputed_matrix)
			expected_shape = (len(patients), len(breakpoints))
			if tuple(matrix.shape) != expected_shape:
				raise ValueError(
					f"precomputed_matrix has shape {tuple(matrix.shape)} but the labels imply {expected_shape}"
				)
			self.adj_matrix_sparse = matrix
			self.patients = patients
			self.breakpoints = breakpoints
			# Same label -> index lookups the DataFrame path provides
			self.patient_idx_dict = {p: i for i, p in enumerate(patients)}
			self.breakpoint_idx_dict = {b: i for i, b in enumerate(breakpoints)}
		elif df is not None and patient_col and breakpoint_col:
			self.df = df
			self.patient_col = patient_col
			self.breakpoint_col = breakpoint_col

			# Sorted unique labels fix the row/column order of the matrix
			self.patients = sorted(df[patient_col].unique().to_list())
			self.breakpoints = sorted(df[breakpoint_col].unique().to_list())

			# Create sparse adjacency matrix
			self.adj_matrix_sparse, self.patient_idx_dict, self.breakpoint_idx_dict = self._create_adjacency_matrix()
		else:
			raise ValueError("Must provide either DataFrame with column names or precomputed matrix with labels")

		# Don't calculate metrics immediately - do it lazily
		self._metrics_calculated = False

	def _ensure_metrics_calculated(self):
		"""Calculate metrics if they haven't been calculated yet."""
		if not self._metrics_calculated:
			self._calculate_metrics()
			self._metrics_calculated = True

	def _create_adjacency_matrix(self) -> tuple:
		"""
		Create the sparse adjacency matrix from the input DataFrame.

		The matrix is assembled straight from (row, column) coordinates, so no
		dense patients x breakpoints array is ever allocated.

		Returns:
			tuple: (csr_matrix of 0/1 float64 entries,
				patient -> row index dict, breakpoint -> column index dict)
		"""
		patient_idx = {p: i for i, p in enumerate(self.patients)}
		breakpoint_idx = {b: i for i, b in enumerate(self.breakpoints)}

		# One coordinate pair per DataFrame row
		row_ids = np.asarray(
			[patient_idx[p] for p in self.df[self.patient_col].to_list()], dtype=np.intp
		)
		col_ids = np.asarray(
			[breakpoint_idx[b] for b in self.df[self.breakpoint_col].to_list()], dtype=np.intp
		)

		matrix = csr_matrix(
			(np.ones(len(row_ids), dtype=np.float64), (row_ids, col_ids)),
			shape=(len(self.patients), len(self.breakpoints)),
		)
		# Repeated (patient, breakpoint) rows are summed by the constructor;
		# reset every stored entry to 1 so the matrix stays binary.
		matrix.data[:] = 1.0

		return matrix, patient_idx, breakpoint_idx

	def _calculate_metrics(self):
		"""Calculate degrees, Jaccard distances, bipartite density and centrality."""
		n_patients = len(self.patients)
		n_breakpoints = len(self.breakpoints)

		# Degrees come straight from the sparse matrix
		self.patient_degrees = np.asarray(self.adj_matrix_sparse.sum(axis=1)).flatten()
		self.breakpoint_degrees = np.asarray(self.adj_matrix_sparse.sum(axis=0)).flatten()

		# pdist needs a dense array. The results are square matrices
		# (patients x patients and breakpoints x breakpoints), so memory grows
		# quadratically with the number of breakpoints.
		dense_matrix = self.adj_matrix_sparse.toarray()
		self.patient_similarity = squareform(pdist(dense_matrix, metric='jaccard'))
		self.breakpoint_similarity = squareform(pdist(dense_matrix.T, metric='jaccard'))

		# Bipartite graph: patient nodes are 0..n_patients-1, breakpoint nodes follow
		breakpoint_nodes = range(n_patients, n_patients + n_breakpoints)
		G = nx.Graph()
		G.add_nodes_from(range(n_patients), bipartite=0)
		G.add_nodes_from(breakpoint_nodes, bipartite=1)

		# Add edges using sparse matrix coordinates
		rows, cols = self.adj_matrix_sparse.nonzero()
		G.add_edges_from(zip(rows, cols + n_patients))

		self.density = bipartite_density(G, breakpoint_nodes)

		# Degree centrality, reported for the breakpoint nodes only
		centrality = nx.degree_centrality(G)
		self.breakpoint_centrality = [centrality[i + n_patients] for i in range(n_breakpoints)]

	def save_matrix(self, filename):
		"""
		Save the adjacency matrix in CSR format along with patient and breakpoint labels.

		Writes ``{filename}_adjac_matrix.npz``, ``{filename}_matrix_label_patients.npy``
		and ``{filename}_matrix_label_breakpoints.npy``.

		Args:
			filename (str): Base filename to save the data (without extension)
		"""
		# Save the sparse matrix as its three CSR arrays plus the shape
		sparse_matrix = self.adj_matrix_sparse
		np.savez(f"{filename}_adjac_matrix.npz",
					data=sparse_matrix.data,
					indices=sparse_matrix.indices,
					indptr=sparse_matrix.indptr,
					shape=sparse_matrix.shape)

		# Save the labels
		np.save(f"{filename}_matrix_label_patients.npy", np.array(self.patients))
		np.save(f"{filename}_matrix_label_breakpoints.npy", np.array(self.breakpoints))

	@classmethod
	def load_from_files(cls, filename):
		"""
		Load a NetworkAnalyzer instance from saved files.

		Args:
			filename (str): Base filename (without extension) used when saving

		Returns:
			NetworkAnalyzer: New instance with loaded data
		"""
		# np.load on an .npz keeps the archive open; the context manager closes
		# it once the arrays have been read.
		with np.load(f"{filename}_adjac_matrix.npz") as loader:
			shape = tuple(int(dim) for dim in loader['shape'])
			matrix = csr_matrix((loader['data'], loader['indices'], loader['indptr']), shape=shape)

		# Load the labels
		patients = np.load(f"{filename}_matrix_label_patients.npy").tolist()
		breakpoints = np.load(f"{filename}_matrix_label_breakpoints.npy").tolist()

		return cls(precomputed_matrix=matrix, patients=patients, breakpoints=breakpoints)

	def _adjacency_heatmap(self, bins, name):
		"""Build one binary heatmap trace for the given breakpoint column indexes."""
		# Columns are shown from most to least connected (stable for ties)
		ordered = sorted(bins, key=lambda idx: self.breakpoint_degrees[idx], reverse=True)
		labels = [self.breakpoints[idx] for idx in ordered]
		# Densify only the selected columns, not the whole matrix
		block = self.adj_matrix_sparse[:, ordered].toarray()
		hover = [[f"Patient: {p}<br>Breakpoint: {b}<br>Connected: {'Yes' if block[i][j] else 'No'}"
					for j, b in enumerate(labels)]
					for i, p in enumerate(self.patients)]
		return go.Heatmap(
			z=block,
			x=labels,
			y=self.patients,
			# Two-colour scale: white = no connection, crimson = connection
			colorscale=[[0, 'white'], [1, 'crimson']],
			showscale=True,
			hoverongaps=False,
			hoverinfo='text',
			text=hover,
			colorbar=dict(title="Connection"),
			name=name
		)

	def create_adjacency_matrix_plot(self, top_bins: list = None, bottom_bins: list = None) -> "go.Figure":
		"""
		Create a standalone plot of the patient-breakpoint adjacency matrix.

		Args:
			top_bins (list, optional): Breakpoint column indexes for the "Top Breakpoints" trace
			bottom_bins (list, optional): Breakpoint column indexes for the "Bottom Breakpoints" trace

		With neither argument, every breakpoint is plotted, sorted by degree.
		Passing only ``bottom_bins`` plots just that trace (previously this
		raised an UnboundLocalError).
		"""
		# Ensure metrics are calculated before accessing them
		self._ensure_metrics_calculated()

		if not top_bins and not bottom_bins:
			# Full matrix, breakpoints in descending degree order
			top_bins = np.argsort(self.breakpoint_degrees)[::-1].tolist()

		fig = go.Figure()
		if top_bins:
			fig.add_trace(self._adjacency_heatmap(top_bins, "Top Breakpoints"))
		if bottom_bins:
			fig.add_trace(self._adjacency_heatmap(bottom_bins, "Bottom Breakpoints"))

		fig.update_layout(
			height=1500,
			width=850,
			title=dict(
				text="Adjacency Matrix",
				x=0.5,
				y=0.95,
				font=dict(size=18)
			),
			xaxis_title="Breakpoints (Sorted by Number of Connected Patients)",
			yaxis_title="Patients",
			template="simple_white"
		)

		return fig

	def _degree_bar(self, bins, name):
		"""Build one bar trace of breakpoint degrees for the given column indexes."""
		labels = [self.breakpoints[idx] for idx in bins]
		degrees = [self.breakpoint_degrees[idx] for idx in bins]
		return go.Bar(
			x=labels,
			y=degrees,
			hovertext=[f"Breakpoint: {bp}<br>Connected to {deg} patients"
						for bp, deg in zip(labels, degrees)],
			hoverinfo='text',
			marker_color='steelblue',
			marker_line_color='rgb(8,48,107)',
			marker_line_width=1.5,
			name=name
		)

	def create_degree_distribution_plot(self, top_bins: list = None, bottom_bins: list = None) -> "go.Figure":
		"""
		Create a standalone plot of the breakpoint degree distribution.

		Args:
			top_bins (list, optional): Breakpoint column indexes for the "Degree Distribution" trace
			bottom_bins (list, optional): Breakpoint column indexes for the "Bottom Breakpoints" trace

		With neither argument, every breakpoint is plotted in matrix column
		order. Passing only ``bottom_bins`` plots just that trace (previously
		this raised an UnboundLocalError).
		"""
		# Ensure metrics are calculated before accessing them
		self._ensure_metrics_calculated()

		if not top_bins and not bottom_bins:
			top_bins = list(range(len(self.breakpoints)))

		fig = go.Figure()
		if top_bins:
			fig.add_trace(self._degree_bar(top_bins, "Degree Distribution"))
		if bottom_bins:
			fig.add_trace(self._degree_bar(bottom_bins, "Bottom Breakpoints"))

		fig.update_layout(
			height=1000,
			width=850,
			title=dict(
				text="Breakpoint Degree Distribution",
				x=0.5,
				y=0.95,
				font=dict(size=18)
			),
			xaxis_title="Breakpoints",
			yaxis_title="Number of Patients",
			template="simple_white"
		)

		return fig

	@staticmethod
	def _mask_upper_triangle(similarity):
		"""
		Return a copy of a square similarity matrix with the redundant upper
		triangle (excluding the diagonal) overwritten by the largest
		off-diagonal value of the lower triangle.

		The matrix is symmetric, so the lower triangle already carries every
		pair. The previous mask was built with ``np.tri(n, n, k=1)``, which
		selects the lower triangle, the diagonal and the first super-diagonal
		rather than the upper triangle, hiding the values of adjacent pairs.
		"""
		similarity = np.asarray(similarity, dtype=float)
		if similarity.size == 0:
			return similarity
		fill_value = np.max(np.tril(similarity, -1))
		upper = np.triu(np.ones(similarity.shape, dtype=bool), k=1)
		return np.where(upper, fill_value, similarity)

	def _similarity_figure(self, distances, labels, colorbar_title, trace_name, title, axis_title) -> "go.Figure":
		"""Build a Viridis heatmap figure of ``1 - distances`` with the upper triangle masked."""
		fig = go.Figure(
			data=go.Heatmap(
				z=self._mask_upper_triangle(1 - distances),
				x=labels,
				y=labels,
				colorscale="Viridis",
				showscale=True,
				colorbar=dict(title=colorbar_title),
				name=trace_name
			)
		)

		fig.update_layout(
			height=1000,
			width=850,
			title=dict(
				text=title,
				x=0.5,
				y=0.95,
				font=dict(size=18)
			),
			xaxis_title=axis_title,
			yaxis_title=axis_title,
			template="simple_white"
		)

		return fig

	def create_patient_similarity_matrix_plot(self) -> "go.Figure":
		"""Create a standalone plot of the patient similarity matrix (1 - Jaccard distance)."""
		# Ensure metrics are calculated before accessing them
		self._ensure_metrics_calculated()
		return self._similarity_figure(
			self.patient_similarity,
			self.patients,
			colorbar_title="Similarity",
			trace_name="Patient Similarity",
			title="Patient Similarity Matrix",
			axis_title="Patients",
		)

	def create_breakpoint_cooccurrence_plot(self) -> "go.Figure":
		"""Create a standalone plot of the breakpoint co-occurrence matrix (1 - Jaccard distance)."""
		# Ensure metrics are calculated before accessing them
		self._ensure_metrics_calculated()
		return self._similarity_figure(
			self.breakpoint_similarity,
			self.breakpoints,
			colorbar_title="Co-occurrence",
			trace_name="Breakpoint Co-occurrence",
			title="Breakpoint Co-occurrence Matrix",
			axis_title="Breakpoints",
		)

	def create_dashboard(self) -> "go.Figure":
		"""
		Create a comprehensive 2x2 visualization dashboard: adjacency matrix,
		breakpoint degree distribution, patient similarity and breakpoint
		co-occurrence. The full (unmasked) matrices are plotted.
		"""
		# Ensure metrics are calculated before accessing them
		self._ensure_metrics_calculated()
		# The adjacency heatmap and its hover text need the dense matrix
		dense_matrix = self.adj_matrix_sparse.toarray()
		# Two-colour scale: white = no connection, crimson = connection
		binary_colorscale = [
			[0, 'white'],
			[1, 'crimson']
		]
		fig = make_subplots(
			horizontal_spacing=0.15,
			rows=2, cols=2,
			subplot_titles=("Patient-Breakpoint Adjacency Matrix",
							"Breakpoint Degree Distribution",
							"Patient Similarity Matrix",
							"Breakpoint Co-occurrence Matrix"),
			specs=[[{"type": "heatmap"}, {"type": "bar"}],
					[{"type": "heatmap"}, {"type": "heatmap"}]]
		)

		# 1. Adjacency Matrix with custom hover text
		adjacency_hover = [[f"Patient: {p}<br>Breakpoint: {b}<br>Connected: {'Yes' if dense_matrix[i][j] else 'No'}"
							for j, b in enumerate(self.breakpoints)]
							for i, p in enumerate(self.patients)]

		fig.add_trace(
			go.Heatmap(
				z=dense_matrix,
				x=self.breakpoints,
				y=self.patients,
				colorscale=binary_colorscale,
				showscale=True,
				hoverongaps=False,
				hoverinfo='text',
				text=adjacency_hover,
				colorbar=dict(
					title="Connection",
					x=0.43,  # Position to the right of the first subplot
					y=1.00,
					len=0.3,
					yanchor='top'
				),
				name="Connections"
			),
			row=1, col=1
		)

		# 2. Degree Distribution with custom hover
		degree_hover = [f"Breakpoint: {bp}<br>Connected to {deg} patients"
						for bp, deg in zip(self.breakpoints, self.breakpoint_degrees)]

		fig.add_trace(
			go.Bar(
				x=self.breakpoints,
				y=self.breakpoint_degrees,
				hovertext=degree_hover,
				hoverinfo='text',
				marker_color='crimson',
				marker_line_color='rgb(8,48,107)',
				marker_line_width=1.5,
				name="Breakpoint Degrees"
			),
			row=1, col=2
		)

		# 3. Patient Similarity Matrix
		fig.add_trace(
			go.Heatmap(
				z=1 - self.patient_similarity,
				x=self.patients,
				y=self.patients,
				colorscale="Viridis",
				showscale=True,
				colorbar=dict(
					title="Similarity",
					x=0.43,  # Position to the right of the third subplot
					y=0.35,
					len=0.3,
					yanchor='top'
				),
				name="Patient Similarity"
			),
			row=2, col=1
		)

		# 4. Breakpoint Co-occurrence
		fig.add_trace(
			go.Heatmap(
				z=1 - self.breakpoint_similarity,
				x=self.breakpoints,
				y=self.breakpoints,
				colorscale="Viridis",
				showscale=True,
				colorbar=dict(
					title="Co-occurrence",
					x=1.02,  # Position to the right of the fourth subplot
					y=0.35,
					len=0.3,
					yanchor='top'
				),
				name="Breakpoint Co-occurrence"
			),
			row=2, col=2
		)

		fig.update_layout(
			height=900,
			width=1400,
			title=dict(
				text="Network Analysis Dashboard",
				x=0.5,
				y=0.95,
				font=dict(size=22)
			),
			showlegend=False,
			template="simple_white"
		)

		# (row, col, x-axis title, y-axis title) for each subplot
		font_size = 14
		axis_titles = [
			(1, 1, "Breakpoints", "Patients"),
			(1, 2, "Breakpoints", "Number of Patients"),
			(2, 1, "Patients", "Patients"),
			(2, 2, "Breakpoints", "Breakpoints"),
		]
		for row, col, x_title, y_title in axis_titles:
			fig.update_xaxes(title_text=x_title, title_font=dict(size=font_size), row=row, col=col)
			fig.update_yaxes(title_text=y_title, title_font=dict(size=font_size), row=row, col=col)

		return fig

	def get_breakpoint_bins(self, top_percentile: float = 0.001, bottom_percentile: float = 0.001) -> tuple:
		"""
		Calculate the indexes of the breakpoints at the specified percentiles.

		Breakpoints are ranked by degree, highest first (ties keep their column
		order). The returned values index into ``self.breakpoints`` and can be
		passed to the plotting methods as ``top_bins`` / ``bottom_bins``.

		Args:
			top_percentile (float): Fraction of breakpoints to return as the top bin
			bottom_percentile (float): Fraction of breakpoints to return as the bottom bin

		Returns a tuple of two lists:
		- The first list contains the indexes of the top `top_percentile` breakpoints by degree.
		- The second list contains the indexes of the bottom `bottom_percentile` breakpoints by degree.
		"""
		# Ensure metrics are calculated before accessing them
		self._ensure_metrics_calculated()
		degrees = np.asarray(self.breakpoint_degrees, dtype=float)
		n_breakpoints = len(degrees)

		# Rank column indexes by descending degree. The old code returned
		# range(...) positions, which are unrelated to the degree ranking.
		ranked = np.argsort(-degrees, kind="stable")

		top_cutoff = int(n_breakpoints * top_percentile)
		bottom_cutoff = int(n_breakpoints * (1 - bottom_percentile))

		top_bins = ranked[:top_cutoff].tolist()
		bottom_bins = ranked[bottom_cutoff:].tolist()

		return top_bins, bottom_bins

	def print_summary_stats(self, top_n: int = 5, show_centrality: bool = False):
		"""
		Print size, density and degree statistics of the network.

		Args:
			top_n (int): How many of the highest-degree breakpoints to list
			show_centrality (bool): Also list the ``top_n`` breakpoints by degree centrality
		"""
		# Ensure metrics are calculated before accessing them
		self._ensure_metrics_calculated()
		print(f"Network Summary Statistics:")
		print(f"---------------------------")
		print(f"Number of Patients: {len(self.patients)}")
		print(f"Number of Breakpoints: {len(self.breakpoints)}")
		print(f"Network Density: {self.density:.3f}")
		print(f"Average Patient Degree: {np.mean(self.patient_degrees):.2f}")
		print(f"Average Breakpoint Degree: {np.mean(self.breakpoint_degrees):.2f}")
		print(f"\nTop Breakpoints by Degree:")
		by_degree = sorted(zip(self.breakpoints, self.breakpoint_degrees),
							key=lambda pair: pair[1], reverse=True)
		for bp, degree in by_degree[:top_n]:
			print(f"  {bp}: {degree}")

		if show_centrality:
			print(f"\nTop {top_n} Breakpoints by Degree Centrality:")
			by_centrality = sorted(zip(self.breakpoints, self.breakpoint_centrality),
									key=lambda pair: pair[1], reverse=True)
			for bp, centrality in by_centrality[:top_n]:
				print(f"  {bp}: {centrality:.3f}")

### Test Class on Toy Data

In [None]:
# Toy cohort used to exercise NetworkAnalyzer
np.random.seed(420)  # fixed seed so the toy network is reproducible

patients = [f'P{idx}' for idx in range(1, 21)]  # P1..P20
breakpoints = [f'BP{idx}' for idx in range(1, 16)]  # BP1..BP15

# Every patient draws 2-6 distinct breakpoints; one record per patient-breakpoint pair
data = []
for patient in patients:
	drawn = np.random.choice(breakpoints, size=np.random.randint(2, 7), replace=False)
	data.extend({'patient_id': patient, 'breakpoint_id': bp} for bp in drawn)

# Long-format Polars DataFrame of the connections
df = pl.DataFrame(data)

# Build the analyzer from the DataFrame and render the dashboard
analyzer = NetworkAnalyzer(df, patient_col='patient_id', breakpoint_col='breakpoint_id')
fig = analyzer.create_dashboard()
fig.show()

# Text summary of the same network
analyzer.print_summary_stats()

In [None]:
%%jmd 
{{ fig.to_html(include_plotlyjs=False, full_html=False, default_height=400, default_width=600) }}

### Save Adjacency Matrix as Precomputed `.npz` File

In [None]:
%%jmd
We can call the method `save_matrix` directly like so:

> `analyzer.save_matrix('output/toy_data')`

and then try reloading into an instance using the decorated class method `load_from_file` like so:

> `analyzer_reloaded = NetworkAnalyzer.load_from_files('output/toy_data')`

Finally print summary statistics as sanity check.

> `analyzer_reloaded.print_summary_stats()`

### Keep FT Breakpoints Seen in More than 9 Patients (1% of Cohort)

The distribution of the sharedness degree of each unique breakpoint is, as expected, heavily skewed: most breakpoints are patient-specific, and very few are shared across patients.

We can try to visualize the adjacency matrix, but because of its massive dimensions (**988 patients x 43927 unique breakpoints**), it is best to filter out patient-specific breakpoints first. In fact, as the putative FT neoantigen distribution is so skewed towards individualized presence, let's set a filtering threshold that keeps only the breakpoint IDs seen in **more than 9 patients (approximately 1% of the MyBrCa cohort)**.


In [None]:
# -.-|m { input: true, output: true, input_fold: show}
# This is DF with redundant duplicate breakpointID-sampleID pairing FILTERED OUT

# Convert the de-duplicated pandas connection table to Polars; the cells below filter this frame
bpsample_poldf = pl.from_pandas(bpsample_pdf_unique)

In [None]:
# -.-|m { input: false, output: true}

# Report the (rows, columns) shape of the connection table, then preview its first 10 rows
display(HTML("Nonredundant <b>normal-filtered</b> BreakPoint–Patient Connection dimension: " + f"<b>{bpsample_poldf.shape}</b>"))
display(HTML("Displaying truncated table:"))
# maxBytes=0 presumably switches off itables' size-based truncation — TODO confirm
show(bpsample_poldf.head(10), maxBytes=0, classes="display compact")

We can then count how many patients from our cohort these 53k unique breakpoints cover.

In [None]:
# Collect the distinct sampleID values into a set; the cell's output is how many there are
unique_samples = set(bpsample_poldf["sampleID"])
len(unique_samples)

Now we go back to the `sharednessDegree` Pandas dataFrame to select the rows that have *sharednessDegree* > 9.

Alternatively we can directly use the Polars dataframe `normfilt_my_sharedness` by running the command below:

> bp_sharedness_gt9 = normfilt_mybrca_sharedness.filter(pl.col('sharednessDegree') > 9)

In [None]:
# -.-|m { input: true, output: true}
# Keep only the breakpoints whose sharednessDegree is strictly greater than 9
bp_sharedness_gt9 = pl.from_pandas(breakpoint_counts).filter(pl.col('sharednessDegree') > 9)

show(bp_sharedness_gt9, maxBytes=0)

We can explore how many breakpoints and patients are covered when we filter out just the patient-specific breakpoints. 

In [None]:
# Drop patient-specific breakpoints: keep those with sharednessDegree strictly greater than 1
bp_sharedness_gt1 = pl.from_pandas(breakpoint_counts).filter(pl.col('sharednessDegree') > 1)

show(bp_sharedness_gt1, maxBytes=0)

Now we can use the unique, filtered, thresholded elements in the `breakpointID` column of the filtered dataFrame above as the filtering list to keep only these same breakpoints in the other dataFrame used to instantiate `NetworkAnalyzer`. 

The resulting ***connection*** dataframe can be used to instantiate a `NetworkAnalyzer` instance.

Below is the *greater than 9 sharedness degree* connection dataFrame.

In [None]:
# -.-|m { input: true, output: true}
# bp_sharedness_gt9 is the dataframe with unique breakpointIDs to be used as filter
# bpsample_poldf is the dataframe to be filtered

# Keep every breakpoint–sample row whose breakpointID passed the >9 sharedness threshold
filt_bpsample_gt9_poldf = bpsample_poldf.filter(
    pl.col("breakpointID").is_in(bp_sharedness_gt9["breakpointID"])
)

show(filt_bpsample_gt9_poldf, maxBytes=0)

Below is the *greater than 1 sharedness degree* connection dataFrame.

In [None]:
# -.-|m { input: true, output: true}
# bp_sharedness_gt1 is the dataframe with unique breakpointIDs to be used as filter
# bpsample_poldf is the dataframe to be filtered

# Keep every breakpoint–sample row whose breakpointID passed the >1 sharedness threshold
filt_bpsample_gt1_poldf = bpsample_poldf.filter(
    pl.col("breakpointID").is_in(bp_sharedness_gt1["breakpointID"])
)

show(filt_bpsample_gt1_poldf, maxBytes=0)

### Instantiate NetworkAnalyzer on Filtered Data

Now instantiate the class we built on our normal-filtered, unique-breakpoint-only subset dataFrame from MyBrCa FT data.

In [None]:
# -.-|m { input: true, output: true, input_fold: show}
# Build the network from the >9-sharedness connection table (sampleID = patients, breakpointID = breakpoints)
analyzer_my_filt_gt9 = NetworkAnalyzer(filt_bpsample_gt9_poldf, patient_col='sampleID', breakpoint_col='breakpointID')

# Print summary statistics
analyzer_my_filt_gt9.print_summary_stats()

In [None]:
# Derive the covered-patient count from the connection table rather than pasting it in
# by hand, so the percentage stays correct when upstream filtering changes.
MYBRCA_COHORT_SIZE = 988  # total patients in the MyBrCa cohort
n_patients_gt9 = filt_bpsample_gt9_poldf["sampleID"].n_unique()
print(f"Percentage of Patients in the Cohort covered by breakpoints with more than 9 sharedness degree: {n_patients_gt9 / MYBRCA_COHORT_SIZE * 100}%")

In [None]:
# -.-|m { input: true, output: true, input_fold: show}
# Build the network from the >1-sharedness connection table (sampleID = patients, breakpointID = breakpoints)
analyzer_my_filt_gt1 = NetworkAnalyzer(filt_bpsample_gt1_poldf, patient_col='sampleID', breakpoint_col='breakpointID')

# Print summary statistics
analyzer_my_filt_gt1.print_summary_stats()

In [None]:
# Derive the covered-patient count from the connection table rather than pasting it in
# by hand, so the percentage stays correct when upstream filtering changes.
MYBRCA_COHORT_SIZE = 988  # total patients in the MyBrCa cohort
n_patients_gt1 = filt_bpsample_gt1_poldf["sampleID"].n_unique()
print(f"Percentage of Patients in the Cohort covered by breakpoints with more than 1 sharedness degree: {n_patients_gt1 / MYBRCA_COHORT_SIZE * 100}%")

#### Plot Adjacency Matrix

In [None]:
# Adjacency-matrix figure for the >9-sharedness network.
# The name `plot` is reused by every plotting cell and read by the %%jmd cell right after it.
plot = analyzer_my_filt_gt9.create_adjacency_matrix_plot()
plot.show()

In [None]:
%%jmd 
{{ plot.to_html(include_plotlyjs=False, full_html=False, default_height=400, default_width=600) }}

#### Plot Breakpoint Degree Distribution

In [None]:
# Degree-distribution figure for the >9-sharedness network (overwrites the previous `plot`)
plot = analyzer_my_filt_gt9.create_degree_distribution_plot()
plot.show()

In [None]:
%%jmd 
{{ plot.to_html(include_plotlyjs=False, full_html=False, default_height=400, default_width=600) }}

#### Plot Patient Similarity Matrix

In [None]:
# Patient-similarity figure for the >9-sharedness network (overwrites the previous `plot`)
plot = analyzer_my_filt_gt9.create_patient_similarity_matrix_plot()
plot.show()

In [None]:
%%jmd 
{{ plot.to_html(include_plotlyjs=False, full_html=False, default_height=400, default_width=600) }}

#### Plot Breakpoint Co-Occurrence 

In [None]:
# Breakpoint co-occurrence figure for the >9-sharedness network (overwrites the previous `plot`)
plot = analyzer_my_filt_gt9.create_breakpoint_cooccurrence_plot()
plot.show()

In [None]:
%%jmd 
{{ plot.to_html(include_plotlyjs=False, full_html=False, default_height=400, default_width=600) }}

## **Implement Set Cover Problem Solution (Greedy Algorithm)**

### Background

If we have a set of unique FT breakpoints, each found in a different subset of the patients in our cohort, we can look for the smallest number of unique FT breakpoints that collectively appear in 100% of our filtered cohort.

This is called **set cover problem** in set theory, and it is considered one of the classical problems in combinatorics, with real-world applications such as shift scheduling and operations optimization (https://en.wikipedia.org/wiki/Set_cover_problem).

The traditional formalization of the problem is as follows:

> Given a universe set $U$ of $n$ elements, $U := \{ e_1, \ldots, e_n \}$, and a collection of subsets
of $U$, $S := \{ S_1, \ldots, S_k \}$, with a cost function $c : S \to \mathbb{Q}^{+}$, the goal is to pick the minimum-cost subcollection
of $S$ that covers all the elements of $U$.

We can reformulate this problem in the context of patient cohort and unique breakpoint neoantigens found in the cohort as follows:

>Given a patient cohort $U$ of $n$ patients, such that;

$$U := \{ P_1, P_2, \ldots, P_n \}$$

>and a collection of patient subsets, $S$, such that each subset represents the subset of patients within which a unique neoantigen ($NeoX$) can be found; 

$$S := \{ \text{NeoA}\{ P_1, P_2, P_3 \}, \text{NeoB}\{ P_2, P_3 \}, \ldots, \text{NeoZ}\{ P_1, P_4 \} \}$$

>with a *neoantigen quality* function; 

$$q : S \to \mathbb{Q}^{+}$$

>the goal is to pick the ***smallest*** subcollection of $S$ that covers all the patients in the cohort $U$ (while either minimizing the cost function or maximizing the neoantigen quality function).

The resulting minimal subcollection of patient subsets represents the smallest set of unique neoantigens that would cover the whole patient cohort.

![Venn Diagram](../docs/assets/neo-patient-venn.png)

### Implementation

Below is the **connection** dataframe filtered for just breakpoints above sharedness degree of 9.

In [None]:
# Re-display the >9-sharedness connection table that the set-cover steps below start from
show(filt_bpsample_gt9_poldf, maxBytes=0)

Below is the **connection** dataframe filtered for just the breakpoints that are shared between at least 2 patients. 

In [None]:
# Re-display the >1-sharedness connection table
show(filt_bpsample_gt1_poldf, maxBytes=0)

1. Using the ***connection*** dataframe above (which maps the bipartite relationships between `breakpointID`s and `sampleID`s), let us create a dataframe that is grouped by `breakpointID` and aggregated on the `sampleID` column.

In [None]:
# -.-|m { input: false, output: true}
# One row per breakpointID; its sampleID values are gathered into a list column,
# and rows are ordered by breakpointID.

temp_df = (
    filt_bpsample_gt9_poldf
    .group_by("breakpointID")
    .agg(pl.col("sampleID"))
    .sort("breakpointID")
)

# print("Type of temp_df:", type(temp_df))
# print("\nSchema of temp_df:")
# print(temp_df.schema)
# print("\nFirst few rows of temp_df:")
# print(temp_df.head())

# # Let's also look at a single row
# first_row = temp_df.row(0)
# print("\nType of first row:", type(first_row))
# print("First row contents:", first_row)

2. Next, let us create a dictionary that has each `breakpointID` value as the key, and a set object containing the `sampleID` for each breakpoint as the dictionary value. This allows us to map each breakpoint to a subset of patients.

In [None]:
# -.-|m { input: false, output: true}
# Map each breakpointID to the set of sampleIDs it was seen in.
# Each row of temp_df is a (breakpointID, list-of-sampleIDs) pair.

bp_coverage_dict = {}
for breakpoint_id, samples in temp_df.iter_rows():
    bp_coverage_dict[breakpoint_id] = set(samples)

# print(bp_coverage_dict)

In [None]:
# Densify the sparse adjacency matrix of the >9-sharedness network into a NumPy array
filt_gt9_matrix = analyzer_my_filt_gt9.adj_matrix_sparse.toarray()

In [None]:
# Densify the sparse adjacency matrix of the >1-sharedness network into a NumPy array
filt_gt1_matrix = analyzer_my_filt_gt1.adj_matrix_sparse.toarray()

Now we can implement a greedy algorithm in a function to search for a minimal subset of breakpoints that covers 100% of the patient space. Because we have filtered out patients who do not have any breakpoint at all, every remaining patient is covered by at least one breakpoint, so given enough computation there is at least one (and possibly many) solution to our set cover problem that covers 100% of the input patient space.

In [None]:
def find_minimal_covering_subset(adjacency_matrix, coverage_threshold, label_to_index_dict=None):
    """
    Greedily select columns of ``adjacency_matrix`` until together they cover
    at least ``coverage_threshold`` of its rows.

    The matrix is transposed internally, so the items being *selected* are the
    columns of the input and the items being *covered* are its rows. The
    progress messages say "columns" for the covered items because they are the
    columns of the transposed working matrix.

    At every step the candidate covering the most still-uncovered items is
    chosen; ties go to the lowest index. This is the greedy heuristic for set
    cover, so the result is small but not guaranteed to be the true minimum.

    Parameters:
    -----------
    adjacency_matrix : np.ndarray
        2-D binary matrix. Entry [r, c] == 1 means input column c covers input
        row r; any other value counts as "no overlap".
    coverage_threshold : float
        Fraction of the input rows that needs to be covered (between 0 and 1)
    label_to_index_dict : dict, optional
        Dictionary mapping labels (strings) of the input columns to their
        indices (int). Example: {'label1': 0, 'label2': 1, ...}
        When given, each selection is also printed with its label.

    Returns:
    --------
    tuple
        (selected_indices, selected_labels, actual_coverage)
        - selected_indices: List of input-column indices, in selection order
        - selected_labels: Labels for those indices, or None without a dict
        - actual_coverage: Achieved coverage fraction (0.0 when the input has
          no rows to cover)
    """
    # Create reverse mapping from index to label
    index_to_label = None
    if label_to_index_dict is not None:
        index_to_label = {v: k for k, v in label_to_index_dict.items()}

    # hits[i, j] is True when candidate i (input column) covers item j (input row)
    hits = np.asarray(adjacency_matrix).T == 1
    num_rows, num_cols = hits.shape

    # Subtract a tiny tolerance before rounding up so float noise in the product
    # (e.g. 7.000000000000001) cannot inflate the target by one item.
    target_coverage = int(np.ceil(num_cols * coverage_threshold - 1e-9))
    print(f"Target coverage: {target_coverage} columns out of {num_cols}")

    selected_rows = []
    covered_cols = np.zeros(num_cols, dtype=bool)

    # Stop once the target is reached or every candidate has been used
    while covered_cols.sum() < target_coverage and len(selected_rows) < num_rows:
        # Newly covered items per candidate, computed for all candidates at once.
        # Already-selected candidates score 0 here, so they are never picked again.
        gains = (hits & ~covered_cols).sum(axis=1)
        best_row_idx = int(np.argmax(gains))  # argmax returns the first maximum
        new_coverage = gains[best_row_idx]

        if new_coverage == 0:
            print("No more improvements possible")
            break

        selected_rows.append(best_row_idx)

        # Print progress with label if available
        if index_to_label is not None:
            label = index_to_label[best_row_idx]
            print(f"Selected {label} (index {best_row_idx}) covering {new_coverage} new columns")

        # Update covered columns
        covered_cols |= hits[best_row_idx]

    # Guard the division: with nothing to cover the fraction is reported as 0.0, not NaN
    actual_coverage = float(covered_cols.sum() / num_cols) if num_cols else 0.0
    print(f"Achieved {actual_coverage:.2%} coverage")

    # Convert indices to labels if dictionary provided
    if index_to_label is not None:
        selected_labels = [index_to_label[idx] for idx in selected_rows]
    else:
        selected_labels = None

    return selected_rows, selected_labels, actual_coverage


We can run the function on the matrix of breakpoints shared between more than 9 patients first. 

In [None]:
# Breakpoint label -> matrix index mapping for the >9-sharedness network
breakpoint_dict_gt9 = analyzer_my_filt_gt9.breakpoint_idx_dict
# Greedy cover of 40% of the patients in the >9-sharedness matrix
filt_bp_indices_gt9, filt_bp_labels_gt9, coverage_gt9 = find_minimal_covering_subset(
    filt_gt9_matrix,
    coverage_threshold=0.4,
    label_to_index_dict=breakpoint_dict_gt9,
)

In [None]:
# Report the achieved coverage (as a percentage), the selected breakpoint indices, and how many were selected
print(f"Minimal coverage percent (%): {coverage_gt9 * 100}")
print(f"Minimal set of breakpoints: {filt_bp_indices_gt9}")
print(f"Length of set: {len(filt_bp_indices_gt9)}")
# print(f"Labels of the minimal cover set of breakpoints: {filt_bp_labels_gt9}")

Then run the same function on the matrix of breakpoints shared between 2 or more patients.

In [None]:
# Breakpoint label -> matrix index mapping for the >1-sharedness network
breakpoint_dict_gt1 = analyzer_my_filt_gt1.breakpoint_idx_dict
# Greedy cover of 100% of the patients in the >1-sharedness matrix
filt_bp_indices_gt1, filt_bp_labels_gt1, coverage_gt1 = find_minimal_covering_subset(
    filt_gt1_matrix,
    coverage_threshold=1.0,
    label_to_index_dict=breakpoint_dict_gt1,
)

In [None]:
# Report the achieved coverage (as a percentage), the selected breakpoint indices, their count, and their labels
print(f"Minimal coverage percent (%): {coverage_gt1 * 100}")
print(f"Minimal set of breakpoints: {filt_bp_indices_gt1}")
print(f"Length of set: {len(filt_bp_indices_gt1)}")
print(f"Labels of the minimal cover set of breakpoints: {filt_bp_labels_gt1}")

In [None]:
# Patient label -> index mappings (not used further in this cell)
patient_dict_gt9 = analyzer_my_filt_gt9.patient_idx_dict
# print(patient_dict_gt9)
patient_dict_gt1 = analyzer_my_filt_gt1.patient_idx_dict
# print(patient_dict_gt1)
# subset original matrix
# The selected breakpoint indices address columns, so all rows are kept and only those columns retained
minimal_set_cover_subset_gt9_matrix = filt_gt9_matrix[:, filt_bp_indices_gt9]
print(minimal_set_cover_subset_gt9_matrix.shape)
minimal_set_cover_subset_gt1_matrix = filt_gt1_matrix[:, filt_bp_indices_gt1]
print(minimal_set_cover_subset_gt1_matrix.shape)

Below are the subset dataframes containing the minimal breakpoint sets for the 'greater than 9 sharedness degree' and 'greater than 1 sharedness degree' matrices.

In [None]:
# Keep the connection rows of bpsample_poldf whose breakpointID is one of the labels
# selected by the set-cover run on the >9-sharedness matrix

filt_bpsample_minsetcover_gt9_poldf = bpsample_poldf.filter(
    pl.col("breakpointID").is_in(filt_bp_labels_gt9)
)

show(filt_bpsample_minsetcover_gt9_poldf, maxBytes=0)


In [None]:
# Keep the connection rows of bpsample_poldf whose breakpointID is one of the labels
# selected by the set-cover run on the >1-sharedness matrix

filt_bpsample_minsetcover_gt1_poldf = bpsample_poldf.filter(
    pl.col("breakpointID").is_in(filt_bp_labels_gt1)
)

show(filt_bpsample_minsetcover_gt1_poldf, maxBytes=0)

In [None]:
# create a new NetworkAnalyzer instance
# One network per set-cover result: built from the connection rows restricted to the selected breakpoints
analyzer_subset_gt9_filt = NetworkAnalyzer(filt_bpsample_minsetcover_gt9_poldf, patient_col="sampleID", breakpoint_col="breakpointID")
analyzer_subset_gt1_filt = NetworkAnalyzer(filt_bpsample_minsetcover_gt1_poldf, patient_col="sampleID", breakpoint_col="breakpointID")

In [None]:
# Print summary statistics for the network built from the >9-sharedness set-cover breakpoints
analyzer_subset_gt9_filt.print_summary_stats()

In [None]:
# Print summary statistics for the network built from the >1-sharedness set-cover breakpoints
analyzer_subset_gt1_filt.print_summary_stats()

In [None]:
# Adjacency-matrix figure for the >9-sharedness set-cover network (overwrites the earlier `plot`)
plot = analyzer_subset_gt9_filt.create_adjacency_matrix_plot()
plot.show()

In [None]:
# Adjacency-matrix figure for the >1-sharedness set-cover network (overwrites the earlier `plot`)
plot = analyzer_subset_gt1_filt.create_adjacency_matrix_plot()
plot.show()

#### **Filtering Out non-TNBCs**

We can filter out rows whose numeric `sampleID` is greater than 172, because these are not TNBC samples.



In [None]:
# Re-display the breakpoints with sharednessDegree > 1 as the starting point for the TNBC filter
show(bp_sharedness_gt1, maxBytes=0)

In [None]:
# Rebuilds filt_bpsample_gt1_poldf with the same expression used earlier in the notebook,
# so this section can start from the >1-sharedness connection table
filt_bpsample_gt1_poldf = bpsample_poldf.filter(
    pl.col("breakpointID").is_in(bp_sharedness_gt1["breakpointID"])
)

show(filt_bpsample_gt1_poldf, maxBytes=0)

In [None]:
# Strip the 'P' from sampleID, cast it to an integer, and keep only IDs below 173
filt_bpsample_gt1_tnbc_poldf = (
    filt_bpsample_gt1_poldf
    .with_columns(
        pl.col('sampleID').str.replace('P', '').cast(pl.Int64).alias('sampleID')
    )
    .filter(pl.col('sampleID') < 173)
)

show(filt_bpsample_gt1_tnbc_poldf, maxBytes=0)

##### **SANITY CHECK**

We can also check the original Pandas df that was normal-filtered.

In [None]:
# Compare row counts before and after de-duplication; the difference is the number of rows removed
print("Original number of rows:", len(bp_sample_array_pdf))
print("Number of rows after removing duplicates:", len(bpsample_pdf_unique))
print("Number of duplicates removed:", len(bp_sample_array_pdf) - len(bpsample_pdf_unique))

In [None]:
# Remove 'P' prefix and convert to integer
# Work on a copy so bpsample_pdf_unique keeps its original string sampleIDs
bpsample_pdf_unique_converted = bpsample_pdf_unique.copy()
bpsample_pdf_unique_converted['sampleID'] = bpsample_pdf_unique_converted['sampleID'].str.replace('P', '').astype(int)
show(bpsample_pdf_unique_converted, maxBytes=0)

# Group by breakpointID and count unique sampleIDs
# NOTE(review): no sampleID < 173 filter has been applied at this point, so despite the
# `_tnbc` suffix this sharednessDegree is counted over every sample in the frame —
# confirm that a cohort-wide (not TNBC-only) count is what is intended here.
breakpoint_counts_tnbc = bpsample_pdf_unique_converted.groupby('breakpointID')['sampleID'].nunique().reset_index()

# Rename the column for clarity
breakpoint_counts_tnbc = breakpoint_counts_tnbc.rename(columns={'sampleID': 'sharednessDegree'})

# Keep breakpoints seen in more than one distinct sample
tnbc_breakpoints_gt1_pandf = breakpoint_counts_tnbc[breakpoint_counts_tnbc['sharednessDegree'] > 1]
show(tnbc_breakpoints_gt1_pandf, maxBytes=0)

Now filter for TNBC.

In [None]:
# Convert to Polars and keep only the rows whose integer sampleID is below 173
tnbc_only_breakpoints_poldf = pl.from_pandas(bpsample_pdf_unique_converted).filter(pl.col('sampleID') < 173)

show(tnbc_only_breakpoints_poldf, maxBytes=0)

In [None]:
# Restrict the TNBC-only connection rows to breakpointIDs listed in tnbc_breakpoints_gt1_pandf
# (the filter values come from a pandas column passed directly to Polars' is_in)
tnbc_only_gt1_breakpoints_uniq_poldf = tnbc_only_breakpoints_poldf.filter(
    pl.col("breakpointID").is_in(tnbc_breakpoints_gt1_pandf["breakpointID"]))

show(tnbc_only_gt1_breakpoints_uniq_poldf, maxBytes=0)

Instantiate NetworkAnalyzer class.

In [None]:
# Build the network from the TNBC-only connection table (sampleID is an integer column here)
analyzer_tnbc_filt_gt1 = NetworkAnalyzer(tnbc_only_gt1_breakpoints_uniq_poldf, patient_col='sampleID', breakpoint_col='breakpointID')

# Print summary statistics
analyzer_tnbc_filt_gt1.print_summary_stats()

In [None]:
# Densify the sparse adjacency matrix of the TNBC-only network into a NumPy array
filt_tnbc_gt1_matrix = analyzer_tnbc_filt_gt1.adj_matrix_sparse.toarray()

In [None]:
# Breakpoint label -> matrix index mapping for the TNBC-only network
breakpoint_dict_gt1_tnbc = analyzer_tnbc_filt_gt1.breakpoint_idx_dict
# Greedy cover of 100% of the patients in the TNBC-only matrix
filt_bp_indices_gt1_tnbc, filt_bp_labels_gt1_tnbc, coverage_gt1_tnbc = find_minimal_covering_subset(
    filt_tnbc_gt1_matrix,
    coverage_threshold=1.0,
    label_to_index_dict=breakpoint_dict_gt1_tnbc,
)

In [None]:
# Report the achieved coverage (as a percentage), the selected breakpoint indices, and how many were selected
print(f"Minimal coverage percent (%): {coverage_gt1_tnbc * 100}")
print(f"Minimal set of breakpoints: {filt_bp_indices_gt1_tnbc}")
print(f"Length of set: {len(filt_bp_indices_gt1_tnbc)}")
# print(f"Labels of the minimal cover set of breakpoints: {filt_bp_labels_gt1_tnbc}")

In [None]:
# Show the distinct breakpoint labels selected for the TNBC-only network, and how many there are
print(set(filt_bp_labels_gt1_tnbc))
print(len(set(filt_bp_labels_gt1_tnbc)))

##### Read 85 Breakpoint List from Joyce for Filtering

We can read the breakpoint list from Joyce to restrict the `tnbc_only_gt1_breakpoints_uniq_poldf` dataframe to just TNBC data. 

In [None]:
# Read breakpointIDs from text file, one per line.
# Blank lines (e.g. a trailing empty line) are skipped so that an empty string
# is never counted as a breakpoint ID in the unique count below.
with open('../data/tnbc_breakpoints.txt', 'r') as f:
    breakpoint_list = [line.strip() for line in f if line.strip()]

print(breakpoint_list)
print(len(set(breakpoint_list)))

In [None]:
# Read the validated breakpointIDs from text file, one per line.
# Blank lines (e.g. a trailing empty line) are skipped so that an empty string
# is never counted as a breakpoint ID in the unique count below.
with open('../data/tnbc_breakpoints_validated.txt', 'r') as f:
    breakpoint_val_list = [line.strip() for line in f if line.strip()]

print(breakpoint_val_list)
print(len(set(breakpoint_val_list)))

# # Filter DataFrame to keep only rows where breakpointID is in the list
# tnbc_connection_validated_poldf = tnbc_only_gt1_breakpoints_uniq_poldf.filter(pl.col('breakpointID').is_in(breakpoint_list))

# show(tnbc_connection_validated_poldf, maxBytes=0)

In [None]:
# Print every validated breakpoint that is also part of the >1-sharedness cover set.
# The lookup set is built once, not once per loop iteration.
# NOTE(review): this compares against filt_bp_labels_gt1 (the cover computed on the
# matrix without the TNBC filter), not filt_bp_labels_gt1_tnbc — confirm which is intended.
min_cover_labels_gt1 = set(filt_bp_labels_gt1)
for i in breakpoint_val_list:
    if i in min_cover_labels_gt1:
        print(i)

In [None]:
# Breakpoints found both in the list read from file and in the >1-sharedness cover set,
# followed by how many there are
overlap_with_cover_gt1 = set(breakpoint_list).intersection(filt_bp_labels_gt1)
print(overlap_with_cover_gt1)
print(len(overlap_with_cover_gt1))

#### Function to Work Out The Coverage of a Specific Subset of Breakpoints

In [None]:
def calculate_coverage_from_breakpoints(adjacency_matrix, selected_breakpoints, label_to_index_dict=None):
    """
    Calculate the coverage of samples achieved by a given set of breakpoints.

    The matrix must be oriented breakpoints × samples: the selected entries are
    used as ROW indices. A samples × breakpoints matrix has to be transposed
    before it is passed in.

    Parameters:
    -----------
    adjacency_matrix : np.ndarray
        Binary matrix where rows are breakpoints and columns are samples
        1 indicates overlap, 0 indicates no overlap
    selected_breakpoints : list
        List of breakpoint indices or labels to evaluate. May be empty, in
        which case no sample is covered.
    label_to_index_dict : dict, optional
        Dictionary mapping breakpoint labels to matrix indices
        Example: {'breakpoint1': 0, 'breakpoint2': 1, ...}
        When given, ``selected_breakpoints`` is interpreted as labels.

    Returns:
    --------
    tuple
        (coverage_fraction, covered_samples, uncovered_samples)
        - coverage_fraction: Fraction of samples covered by the selected
          breakpoints (0.0 when the matrix has no sample columns)
        - covered_samples: Indices of covered samples
        - uncovered_samples: Indices of samples not covered by any selected breakpoint

    Raises:
    -------
    KeyError
        If a label is missing from ``label_to_index_dict``.
    ValueError
        If a selected index is beyond the last row of the matrix.
    """
    import numpy as np

    # Convert labels to indices if dictionary is provided
    if label_to_index_dict is not None:
        try:
            selected_indices = [label_to_index_dict[label] for label in selected_breakpoints]
        except KeyError as e:
            raise KeyError(f"Breakpoint label {e} not found in label dictionary") from e
    else:
        selected_indices = list(selected_breakpoints)

    # Validate indices; an empty selection is valid and skips the max() call
    if selected_indices and max(selected_indices) >= adjacency_matrix.shape[0]:
        raise ValueError("Selected breakpoint index exceeds matrix dimensions")

    # Extract rows for selected breakpoints (explicit integer dtype keeps an
    # empty selection a valid index array, yielding a 0-row result)
    selected_matrix = adjacency_matrix[np.asarray(selected_indices, dtype=np.intp)]

    # A sample is covered when at least one selected breakpoint has a 1 in its column
    covered_mask = np.any(selected_matrix == 1, axis=0)
    covered_samples = np.where(covered_mask)[0]
    uncovered_samples = np.where(~covered_mask)[0]

    # Calculate coverage fraction, guarding against a matrix with no samples
    total_samples = adjacency_matrix.shape[1]
    coverage_fraction = len(covered_samples) / total_samples if total_samples else 0.0

    return coverage_fraction, covered_samples, uncovered_samples

# Example usage:
# NOTE(review): inside a notebook kernel this guard is presumably always true, so the demo
# runs on every execution and its variables (matrix, labels, coverage, covered, uncovered, ...)
# are assigned at notebook level — confirm this is intended.
if __name__ == "__main__":
    import numpy as np
    
    # Example adjacency matrix (5 breakpoints × 8 samples)
    matrix = np.array([
        [1, 1, 0, 0, 1, 0, 0, 1],  # breakpoint 0
        [0, 1, 1, 0, 0, 1, 0, 0],  # breakpoint 1
        [1, 0, 0, 1, 0, 0, 1, 0],  # breakpoint 2
        [0, 0, 1, 1, 0, 0, 0, 1],  # breakpoint 3
        [1, 0, 0, 0, 1, 1, 0, 0],  # breakpoint 4
    ])
    
    # Example with indices
    selected_breakpoints = [0, 2]  # Using breakpoints 0 and 2
    coverage, covered, uncovered = calculate_coverage_from_breakpoints(matrix, selected_breakpoints)
    
    print(f"Coverage with breakpoints {selected_breakpoints}: {coverage:.2%}")
    print(f"Covered samples: {covered}")
    print(f"Uncovered samples: {uncovered}")
    
    # Example with labels
    # Same two breakpoints as above, addressed by label through the mapping
    labels = {
        'break_A': 0,
        'break_B': 1,
        'break_C': 2,
        'break_D': 3,
        'break_E': 4
    }
    
    selected_labels = ['break_A', 'break_C']
    coverage, covered, uncovered = calculate_coverage_from_breakpoints(
        matrix, 
        selected_labels,
        labels
    )
    
    print(f"\nCoverage with breakpoints {selected_labels}: {coverage:.2%}")
    print(f"Covered samples: {covered}")
    print(f"Uncovered samples: {uncovered}")

In [None]:
# Validated breakpoints whose combined coverage of the TNBC samples is evaluated
breakpoint_val_list_edit = ['10:49885203-10:42791627', '18:32092886-18:32068298', '19:804438-19:726138', '1:26239838-1:26197686', '11:102448812-11:102610030', '8:60743097-8:61376605', '3:132713126-3:138131282', '9:92108740-9:92210889']

# calculate_coverage_from_breakpoints indexes ROWS by breakpoint index, whereas
# filt_tnbc_gt1_matrix has breakpoints on its COLUMNS (find_minimal_covering_subset
# transposes it before applying the same breakpoint_dict_gt1_tnbc mapping).
# The transpose is therefore passed, so rows are breakpoints and columns are samples.
coverage, covered, uncovered = calculate_coverage_from_breakpoints(filt_tnbc_gt1_matrix.T, breakpoint_val_list_edit, label_to_index_dict=breakpoint_dict_gt1_tnbc)

print(f"Coverage with breakpoints {breakpoint_val_list_edit}: {coverage:.2%}")
print(f"Covered samples: {covered}")
print(f"Uncovered samples: {uncovered}")