In [2]:
"""Tests for the TreatmentEffect class in the treatment_effect module."""

from typing import List, Optional

import pandas as pd

from medmodels import MedRecord
from medmodels.medrecord.types import NodeIndex


def create_patients(patient_list: List[NodeIndex]) -> pd.DataFrame:
    """
    Build the patient-node table, restricted to the requested patients.

    Args:
        patient_list (List[NodeIndex]): Patient ids to keep. Ids that are not
            part of the fixture are simply ignored.

    Returns:
        pd.DataFrame: Rows for the kept patients with columns ``index``,
        ``age`` and ``gender``, in fixture order and with the original row
        labels (the index is not reset).
    """
    patient_rows = [
        ("P1", 20, "male"),
        ("P2", 30, "female"),
        ("P3", 40, "male"),
        ("P4", 30, "female"),
        ("P5", 40, "male"),
        ("P6", 50, "female"),
        ("P7", 60, "male"),
        ("P8", 70, "female"),
        ("P9", 80, "male"),
    ]
    all_patients = pd.DataFrame(patient_rows, columns=["index", "age", "gender"])

    requested = all_patients["index"].isin(patient_list)
    return all_patients.loc[requested]


def create_diagnoses() -> pd.DataFrame:
    """
    Build the diagnosis-node table for the toy record.

    Returns:
        pd.DataFrame: One row per diagnosis, with the node id in ``index`` and
        a human-readable ``name``.
    """
    diagnosis_rows = [{"index": "D1", "name": "Stroke"}]
    return pd.DataFrame(diagnosis_rows)


def create_prescriptions() -> pd.DataFrame:
    """
    Build the medication-node table for the toy record.

    Returns:
        pd.DataFrame: One row per medication, with the node id in ``index``
        and the drug ``name``.
    """
    medication_ids = ["M1", "M2"]
    medication_names = ["Rivaroxaban", "Warfarin"]
    return pd.DataFrame({"index": medication_ids, "name": medication_names})


def create_edges1(patient_list: List[NodeIndex]) -> pd.DataFrame:
    """
    Build the prescription -> patient edge table.

    Args:
        patient_list (List[NodeIndex]): Only edges whose ``target`` is one of
            these patients are kept.

    Returns:
        pd.DataFrame: Columns ``source`` (medication node), ``target``
        (patient node) and ``time`` (date string), in fixture order with the
        original row labels.
    """
    edge_rows = [
        ("M2", "P1", "1999-10-15"),
        ("M1", "P2", "2000-01-01"),
        ("M2", "P2", "1999-12-15"),
        ("M1", "P3", "2000-01-01"),
        ("M2", "P5", "2000-01-01"),
        ("M1", "P6", "2000-01-01"),
        ("M2", "P9", "2000-01-01"),
    ]
    all_edges = pd.DataFrame(edge_rows, columns=["source", "target", "time"])

    targets_kept = all_edges["target"].isin(patient_list)
    return all_edges.loc[targets_kept]


def create_edges2(patient_list: List[NodeIndex]) -> pd.DataFrame:
    """
    Build the diagnosis -> patient edge table, including an ``intensity`` column.

    Args:
        patient_list (List[NodeIndex]): Only edges whose ``target`` is one of
            these patients are kept.

    Returns:
        pd.DataFrame: Columns ``source`` (diagnosis node), ``target`` (patient
        node), ``time`` (date string) and ``intensity`` (float), in fixture
        order with the original row labels. Note that ``P3`` has two edges.
    """
    edge_rows = [
        ("D1", "P1", "2000-01-01", 0.1),
        ("D1", "P2", "2000-07-01", 0.2),
        ("D1", "P3", "1999-12-15", 0.3),
        ("D1", "P3", "2000-01-05", 0.4),
        ("D1", "P4", "2000-01-01", 0.5),
        ("D1", "P7", "2000-01-01", 0.6),
    ]
    all_edges = pd.DataFrame(
        edge_rows, columns=["source", "target", "time", "intensity"]
    )

    targets_kept = all_edges["target"].isin(patient_list)
    return all_edges.loc[targets_kept]


def create_medrecord(
    patient_list: Optional[List[NodeIndex]] = None,
) -> MedRecord:
    """
    Create the toy MedRecord used throughout this notebook.

    Nodes come from the patient, diagnosis and prescription tables. Edges come
    from ``create_edges1`` (medication -> patient) and ``create_edges2``
    (diagnosis -> patient, with an ``intensity`` column).

    Args:
        patient_list (Optional[List[NodeIndex]]): Patients to include. Patients
            not listed are left out, and so are edges targeting them. Defaults
            to all nine fixture patients ``P1`` .. ``P9``.

    Returns:
        MedRecord: The record with groups ``patients``, ``Stroke``,
        ``Rivaroxaban`` and ``Warfarin``.
    """
    if patient_list is None:
        # Build a fresh list per call. A list literal as the default value would
        # be evaluated once and shared by every call.
        patient_list = ["P1", "P2", "P3", "P4", "P5", "P6", "P7", "P8", "P9"]

    patients = create_patients(patient_list=patient_list)
    diagnoses = create_diagnoses()
    prescriptions = create_prescriptions()
    edges1 = create_edges1(patient_list=patient_list)
    edges2 = create_edges2(patient_list=patient_list)

    medrecord = MedRecord.from_pandas(
        nodes=[(patients, "index"), (diagnoses, "index"), (prescriptions, "index")],
        edges=[(edges1, "source", "target")],
    )

    medrecord.add_group(group="patients", nodes=patients["index"].to_list())
    # One single-node group per diagnosis / medication.
    for group_name, group_nodes in (
        ("Stroke", ["D1"]),
        ("Rivaroxaban", ["M1"]),
        ("Warfarin", ["M2"]),
    ):
        medrecord.add_group(group=group_name, nodes=group_nodes)

    # Diagnosis -> patient edges. The ``intensity`` column is presumably stored
    # as an edge attribute (confirm against MedRecord.add_edges).
    medrecord.add_edges((edges2, "source", "target"))
    return medrecord

In [3]:
medrecord = create_medrecord()

# Edges between the nodes of the first two groups, queried with directed=False.
# NOTE(review): relies on the order of medrecord.groups; confirm it is stable.
first_group = medrecord.groups[0]
second_group = medrecord.groups[1]
medrecord.edges_connecting(
    medrecord.nodes_in_group(first_group),
    medrecord.nodes_in_group(second_group),
    directed=False,
)

[]

In [4]:
groups = medrecord.groups

# Look up each group's nodes once. The comprehension below visits every ordered
# pair of groups and would otherwise re-query the same group repeatedly.
nodes_by_group = {group: medrecord.nodes_in_group(group) for group in groups}

# Ordered (group1, group2) pairs with at least one connecting edge. No
# `directed` argument is passed, so edges_connecting's default applies.
relation_types = [
    (group1, group2)
    for group1 in groups
    for group2 in groups
    if medrecord.edges_connecting(nodes_by_group[group1], nodes_by_group[group2])
]

In [5]:
from medmodels.predictive_modelling.hsgnn.hsgnn_preprocessing import (
    HSGNNPreprocessor,
)

# Wrap the toy record in the HSGNN preprocessor and enumerate its meta-paths.
hsgnn_preprocessing = HSGNNPreprocessor(medrecord)
meta_paths = hsgnn_preprocessing.find_metapaths()

In [6]:
from medmodels.medrecord.querying import node

# Query nodes that satisfy has_neighbor_with(index == "P1").
p1_neighbor_query = node().has_neighbor_with(node().index().equal("P1"))
medrecord.select_nodes(p1_neighbor_query)

['M2', 'D1']

In [7]:
import numpy as np

# Node lists for each group along the first meta-path, in path order.
nodes_groups = [medrecord.nodes_in_group(group) for group in meta_paths[0]]

# Give every distinct node a position in `nodelist` (first occurrence wins).
# For each meta-path step, record the positions of that step's nodes.
# A contiguous range(start, end) per step is only correct when groups are
# disjoint. A group that repeats (or shares nodes with an earlier step) would
# get an empty or truncated range, because its nodes were already numbered.
node_index = {}
nodes_ranges = []
for group_nodes in nodes_groups:
    group_positions = []
    for node1 in group_nodes:
        if node1 not in node_index:
            node_index[node1] = len(node_index)
        group_positions.append(node_index[node1])
    nodes_ranges.append(group_positions)

nodelist = list(node_index.keys())

In [8]:
def create_full_adjacency_matrix(nodelist, medrecord, directed=False):
    """
    Build a sparse matrix of edge counts between every pair of nodes.

    Args:
        nodelist: Node indices. The i-th node becomes row/column ``i``. Any
            iterable is accepted; it is materialised into a list first.
        medrecord: Object providing ``edges_connecting(source, target, directed=...)``.
            Its result must be falsy when there is no edge and sized otherwise
            (a MedRecord in this notebook).
        directed (bool): Passed through to ``edges_connecting``. Defaults to
            False, which is how this notebook uses it.

    Returns:
        scipy.sparse.csr_matrix: ``(n, n)`` matrix whose entry ``(i, j)`` is the
        number of edges ``edges_connecting`` reports between the i-th and j-th
        node. Pairs without edges are not stored.
    """
    # Imported locally so the function does not depend on an import that
    # appears after its definition in the cell.
    from scipy.sparse import csr_matrix

    nodes = list(nodelist)
    n = len(nodes)
    rows, cols, data = [], [], []

    for i, node1 in enumerate(nodes):
        # Undirected queries are assumed symmetric (a-b reports the same edges
        # as b-a). So only the upper triangle, diagonal included, is queried
        # and then mirrored. Directed queries need every ordered pair.
        first_j = 0 if directed else i
        for j in range(first_j, n):
            edges = medrecord.edges_connecting(node1, nodes[j], directed=directed)
            if not edges:
                continue
            count = len(edges)
            rows.append(i)
            cols.append(j)
            data.append(count)
            if not directed and i != j:
                rows.append(j)
                cols.append(i)
                data.append(count)

    return csr_matrix((data, (rows, cols)), shape=(n, n))


from scipy.sparse import coo_matrix, csr_matrix
from scipy.sparse import identity

# Give every distinct node a position in `nodelist` (first occurrence wins).
# For each meta-path step, record the positions of that step's nodes.
# Position lists, rather than contiguous ranges, stay correct when a group
# repeats or overlaps an earlier step. setdefault leaves known nodes unchanged
# and numbers new ones with the next free position.
node_indices = {}
nodes_ranges = []
for group_nodes in nodes_groups:
    nodes_ranges.append(
        [node_indices.setdefault(node_name, len(node_indices)) for node_name in group_nodes]
    )

nodelist = list(node_indices.keys())
full_adjacency_matrix = create_full_adjacency_matrix(nodelist, medrecord)

# Chain the per-hop adjacency blocks along the meta-path. The seed is the
# square identity over the first group, so after the loop entry (a, b) sums,
# over all intermediate nodes, the products of edge counts from the a-th node
# of the first group to the b-th node of the last group. A
# (first x last)-shaped seed would only conform to the first hop when both end
# groups had the same size.
path_count_matrix = identity(len(nodes_groups[0]), format="csr")
for current_positions, next_positions in zip(nodes_ranges, nodes_ranges[1:]):
    adj_submatrix = full_adjacency_matrix[current_positions, :][:, next_positions]
    path_count_matrix = path_count_matrix @ adj_submatrix

  

In [10]:
# NOTE(review): the saved output below shows this call failing with
# "IndexError: range object index out of range" at 0/12 meta-paths.
# Possibly the same per-group range problem reproduced in the cells above
# (an empty range for a repeated/overlapping group). TODO: confirm inside
# HSGNNPreprocessor.
hsgnn_preprocessing.compute_all_subgraphs()

Computing similarity subgraphs:   0%|          | meta-path | 0/12 [00:00<?, ?it/s]


IndexError: range object index out of range

In [None]:
import numpy as np
# Dense edge-count matrix over `nodelist`: entry (i, j) is the number of edges
# edges_connecting reports between nodelist[i] and nodelist[j] with
# directed=False. Each pair is queried only once. `or []` maps a falsy result
# to a count of 0. reshape keeps the (n, n) shape even when nodelist is empty.
adjancency_matrix = np.array(
    [
        len(medrecord.edges_connecting(node1, node2, directed=False) or [])
        for node1 in nodelist
        for node2 in nodelist
    ]
).reshape((len(nodelist), len(nodelist)))
adjancency_matrix

array([[0, 0, 2, 1, 0, 1, 1, 0, 1, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [2, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
       [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [None]:
# NOTE(review): `nodelist` is built with list(...) in the cells above, but the
# saved output below is a set literal. That output looks stale; re-run to refresh.
nodelist

{'D1', 'P1', 'P2', 'P3', 'P4', 'P5', 'P6', 'P7', 'P8', 'P9'}