# DREAM5

# Data Exploration

First, I want to verify two assumptions about the data:
- The set of regulator genes in the reference network is a subset of the set of transcription factors.
- Some transcription factors also appear among the reference network's target genes.

We need to encapsulate the data wrangling for each dataset source. Currently (TODO: subject to change), the source is either 'DREAM5' or 'BEELINE'.

In [37]:
from typing import Literal
# DREAM5 network to analyse. The Literal annotation restricts it to networks
# 1, 3 and 4 (network 2 is not used for evaluation).
# NOTE: a later cell rebinds NETWORK_ID before the dataset is constructed.
NETWORK_ID : Literal[1, 3, 4] = 3

In [58]:
from pathlib import Path
from typing import Dict, Iterable, List, Optional

import numpy as np
import pandas as pd
from numpy.typing import NDArray


class Dream5GRNDataset:
    """One DREAM5 gene-network-inference dataset, with gene IDs mapped to names.

    Constructing an instance reads four tab-separated files below ``root``:

    * ``training data/<network dir>/net<network_id>_expression_data.tsv``:
      expression table whose column labels are gene IDs, loaded as ``float32``;
    * ``training data/<network dir>/net<network_id>_gene_ids.tsv``:
      gene ID -> gene name table (its first line is consumed as the header);
    * ``training data/<network dir>/net<network_id>_transcription_factors.tsv``:
      one transcription-factor gene ID per line, no header;
    * ``test data/DREAM5_NetworkInference_GoldStandard_<network dir>.tsv``
      (with ``"Network N"`` written as ``"NetworkN"``): reference edge list
      with columns regulator ID, target ID and a numeric label, no header.

    Every gene ID in these tables is then replaced by its name, so
    ``gene_expressions`` (columns), ``ref_network`` (``regulator_gene`` /
    ``target_gene``) and ``transcription_factors`` all refer to genes by name.
    """

    def __init__(
        self,
        root: Path,
        network_id: int,
    ):
        """Load the files for ``network_id`` from the DREAM5 directory ``root``.

        Args:
            root: Directory containing the ``training data`` and ``test data``
                sub-directories.
            network_id: DREAM5 network number; one of 1, 3 or 4.

        Raises:
            KeyError: If ``network_id`` is not one of the supported networks.
        """
        self.network_id_to_directory_name = {
            1: Path("Network 1 - in silico"),
            # 2: Path("Network 2 - S. aureus"), # Not used for evaluation
            3: Path("Network 3 - E. coli"),
            4: Path("Network 4 - S. cerevisiae"),
        }

        # Use the ``network_id`` argument (not a notebook-level global) so the
        # instance depends only on what the caller passed in.
        try:
            self.network_dir = self.network_id_to_directory_name[network_id]
        except KeyError:
            raise KeyError(
                f"Unsupported DREAM5 network_id {network_id!r}; expected one of "
                f"{sorted(self.network_id_to_directory_name)}"
            ) from None

        # Honour the ``root`` argument instead of a hard-coded relative path.
        self.root: Path = Path(root)
        self.training_data_dir = self.root / "training data"
        self.reference_network_dir = self.root / "test data"

        network_files_dir = self.training_data_dir / self.network_dir
        self.gene_expression_path = (
            network_files_dir / f"net{network_id}_expression_data.tsv"
        )
        self.id_to_name_path = network_files_dir / f"net{network_id}_gene_ids.tsv"
        self.transcription_factors_path = (
            network_files_dir / f"net{network_id}_transcription_factors.tsv"
        )
        # Gold-standard file names drop the space after "Network"
        # (e.g. "Network 1 - in silico" -> "Network1 - in silico").
        gold_standard_suffix = str(self.network_dir).replace(
            f"Network {network_id}", f"Network{network_id}"
        )
        self.network_data_path = (
            self.reference_network_dir
            / f"DREAM5_NetworkInference_GoldStandard_{gold_standard_suffix}.tsv"
        )

        self.gene_expressions = pd.read_csv(
            self.gene_expression_path,
            sep="\t",
            dtype=np.float32,
        )
        gene_id_table = pd.read_csv(
            self.id_to_name_path,
            sep="\t",
            dtype=str,
        )
        # ``dict`` needs every row to be a (key, value) pair, i.e. the table is
        # assumed to have exactly two columns (ID, name) — TODO confirm for
        # every network.
        self.gene_ids = dict(gene_id_table.values)

        transcription_factor_table = pd.read_csv(
            self.transcription_factors_path,
            sep="\t",
            header=None,
            dtype=str,
        )
        self.transcription_factors = transcription_factor_table[0]
        self.transcription_factors.name = "transcription_factors"

        self.ref_network = pd.read_csv(
            self.network_data_path,
            sep="\t",
            header=None,
            dtype={0: str, 1: str, 2: float},
        )
        self.ref_network.columns = ["regulator_gene", "target_gene", "ground_truth"]

        self.gene_expressions = Dream5GRNDataset._map_gene_ids_to_names_for_expression_data(
            self.gene_expressions, self.gene_ids
        )
        self.ref_network = Dream5GRNDataset._map_gene_ids_to_names_for_network_data(
            self.ref_network, self.gene_ids
        )
        self.transcription_factors = Dream5GRNDataset._map_gene_ids_to_names_for_transcription_factors(
            self.transcription_factors, self.gene_ids
        )

    @staticmethod
    def _map_gene_ids_to_names_for_expression_data(
        gene_expressions: pd.DataFrame,
        gene_ids: Dict[str, str],
    ) -> pd.DataFrame:
        """Rename the expression columns from gene IDs to gene names.

        Modifies ``gene_expressions`` in place and returns it. IDs missing from
        ``gene_ids`` become NaN labels.
        """
        gene_expressions.columns = gene_expressions.columns.map(gene_ids)
        return gene_expressions

    @staticmethod
    def _map_gene_ids_to_names_for_network_data(
        network: pd.DataFrame, gene_ids: Dict[str, str]
    ) -> pd.DataFrame:
        """Map the ``regulator_gene`` and ``target_gene`` columns to gene names.

        Modifies ``network`` in place and returns it.
        """
        network["regulator_gene"] = network["regulator_gene"].map(gene_ids)
        network["target_gene"] = network["target_gene"].map(gene_ids)
        return network

    @staticmethod
    def _map_gene_ids_to_names_for_transcription_factors(
        transcription_factors: pd.Series, gene_ids: Dict[str, str]
    ) -> pd.Series:
        """Return a new Series with each transcription-factor ID mapped to its name."""
        return transcription_factors.map(gene_ids)

    @staticmethod
    def _get_transcription_factor_indices(
        transcription_factors: pd.Series,
        gene_names: Optional[Iterable[str]] = None,
    ) -> List[int]:
        """Return the expression-column positions of the transcription factors.

        Args:
            transcription_factors: Transcription-factor names; must be unique.
            gene_names: Column labels of the expression matrix, in column
                order. When given, each TF is located by name (first
                occurrence). When omitted, the TFs are assumed to be the first
                ``len(transcription_factors)`` columns, in the same order.

        Returns:
            One column position per transcription factor, in TF order.

        Raises:
            ValueError: If the TFs are not unique, or a TF is not found in
                ``gene_names``.
        """
        if len(set(transcription_factors)) != len(transcription_factors):
            raise ValueError("Transcription factors are not unique")
        if gene_names is None:
            return list(range(len(transcription_factors)))

        position: Dict[str, int] = {}
        for index, name in enumerate(gene_names):
            # Keep the first occurrence so duplicated labels resolve
            # deterministically.
            position.setdefault(name, index)
        missing = [tf for tf in transcription_factors if tf not in position]
        if missing:
            raise ValueError(
                f"{len(missing)} transcription factor(s) not found among the "
                f"expression columns, e.g. {missing[:5]}"
            )
        return [position[tf] for tf in transcription_factors]

    def get_inputs(self):
        """Return ``(inputs, transcription_factor_indices)`` for inference.

        ``inputs`` is ``gene_expressions.values`` (one column per gene).
        ``transcription_factor_indices`` are the positions of the transcription
        factors among those columns, looked up by gene name.
        """
        inputs: NDArray = self.gene_expressions.values
        transcription_factor_indices: List[int] = (
            Dream5GRNDataset._get_transcription_factor_indices(
                self.transcription_factors,
                gene_names=self.gene_expressions.columns,
            )
        )
        return inputs, transcription_factor_indices


In [68]:
# Load DREAM5 network 1 ("Network 1 - in silico") and pull out the pieces
# that the following cells use.
NETWORK_ID = 1
ROOT = Path("../data/raw/syn2787209/Gene Network Inference")

dataset = Dream5GRNDataset(ROOT, NETWORK_ID)
ref_network = dataset.ref_network
transcription_factors = dataset.transcription_factors
inputs, transcription_factor_indices = dataset.get_inputs()

In [69]:
# Sanity checks on the reference network:
# - every regulator in the reference network should be a transcription factor (TF);
# - some TFs are expected to also appear as target genes.
unique_tfs = set(transcription_factors.unique())
unique_regulators_ref_network = set(ref_network["regulator_gene"].unique())
unique_targets_ref_network = set(ref_network["target_gene"].unique())

print("Number of TFs: ", len(unique_tfs))
print(
    f"All transcription factor entries are unique: {len(unique_tfs) == len(transcription_factors)}"
)
print(
    "Number of unique regulators in ref network: ",
    len(unique_regulators_ref_network),
)
# A regulator can have several targets, so uniqueness is meaningful per
# (regulator, target) edge rather than per regulator.
edges_are_unique = not ref_network.duplicated(
    subset=["regulator_gene", "target_gene"]
).any()
print(
    f"All (regulator, target) edges in the reference network are unique: {edges_are_unique}"
)
print(
    f"Set of regulator genes present in the reference network is subset of set of TFs: {unique_regulators_ref_network.issubset(unique_tfs)}"
)
print(
    "Number of unique target genes in ref network: ",
    len(unique_targets_ref_network),
)

print(
    f"Set of regulator genes in the reference network that are not TFs: {unique_regulators_ref_network.difference(unique_tfs)}"
)
# ``any(...)`` over the intersection would test the truthiness of the gene
# names themselves; what matters is whether the intersection is non-empty.
print(
    f"Transcription factors are present in the set of target genes: {not unique_tfs.isdisjoint(unique_targets_ref_network)}"
)

Number of TFs:  195
All transcription factor entries are unique: True
Number of unique regulators in ref network:  178
All regulator genes in the reference network are unique: True
Set of regulator genes present in the reference network is subset of set of TFs: True
Number of unique target genes in ref network:  1565
Set of regulator genes in the reference network that are not TFs: set()
Transciption factors are present in the set of target genes: True


In [70]:
from fedgenie3.genie3.modeling import GENIE3

# GENIE3 configured with tree method "RF" (presumably random forest). The
# keyword arguments are handed to GENIE3 as ``tree_init_kwargs`` — presumably
# forwarded to the tree-ensemble constructor; see fedgenie3.genie3.modeling.
tree_method = "RF"
tree_init_kwargs = dict(
    n_estimators=100,
    max_features="sqrt",
    random_state=42,  # fixed seed, presumably for repeatable runs
    n_jobs=-1,
)
genie3 = GENIE3(tree_method=tree_method, tree_init_kwargs=tree_init_kwargs)

In [71]:
# Compute regulator -> target importances with the GENIE3 model configured
# above. The TF column positions presumably restrict which genes act as
# candidate regulators (see fedgenie3.genie3.modeling).
importance_matrix = genie3.compute_importances(inputs, transcription_factor_indices)

Computing importances:   0%|          | 0/1643 [00:00<?, ?gene/s]

In [None]:
# Turn the importance matrix into a ranked edge table with columns
# regulator_gene, target_gene, importance (see the cell output). The gene
# columns still hold integer expression-column positions at this point.
gene_ranking = genie3.get_gene_ranking(importance_matrix, transcription_factor_indices)
gene_ranking  # notebook display

Unnamed: 0,regulator_gene,target_gene,importance
0,187,937,0.182317
1,83,589,0.180860
2,94,469,0.169391
3,94,1105,0.167880
4,83,426,0.167562
...,...,...,...
320380,72,72,0.000000
320381,39,39,0.000000
320382,19,19,0.000000
320383,10,10,0.000000


In [73]:
def map_gene_indices_to_names_for_network_data(
    network: pd.DataFrame, gene_expressions: pd.DataFrame
):
    """Replace integer gene positions in an edge table with gene names.

    Args:
        network: Edge table whose ``regulator_gene`` and ``target_gene``
            columns hold integer column positions into ``gene_expressions``
            (e.g. the ranking returned by ``genie3.get_gene_ranking``).
        gene_expressions: Expression table whose column labels are the gene
            names, in the column order the positions refer to.

    Returns:
        A copy of ``network`` with both gene columns replaced by names. The
        input frame is left unchanged.
    """
    gene_names = gene_expressions.columns.to_numpy()
    named_network = network.copy()
    for column in ("regulator_gene", "target_gene"):
        # Positional lookup via NumPy integer-array indexing.
        named_network[column] = gene_names[named_network[column].to_numpy()]
    return named_network

# Translate the integer gene positions in the ranking into gene names, so it
# can be compared with the name-based reference network.
gene_ranking = map_gene_indices_to_names_for_network_data(gene_ranking, dataset.gene_expressions)

In [76]:
from fedgenie3.genie3.eval import evaluate

# Score the predicted ranking against the gold-standard reference network;
# the cell output reports AUROC and AUPR.
evaluate(gene_ranking, ref_network)

{'auroc': 0.8240701674817695, 'aupr': 0.2711146041283269}