<a href="https://colab.research.google.com/github/ayyucedemirbas/Llama-HINN/blob/main/Llama_HINN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install captum torchinfo

In [1]:
!git clone https://github.com/bozdaglab/HINN.git

Cloning into 'HINN'...
remote: Enumerating objects: 24, done.[K
remote: Counting objects: 100% (24/24), done.[K
remote: Compressing objects: 100% (23/23), done.[K
remote: Total 24 (delta 5), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (24/24), 3.77 MiB | 2.88 MiB/s, done.
Resolving deltas: 100% (5/5), done.


In [2]:
%cd HINN

/content/HINN


In [3]:
import os
import torch
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForCausalLM
import warnings
warnings.filterwarnings('ignore')

os.environ["KERAS_BACKEND"] = "torch"

import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary

import captum
from captum.attr import DeepLift
import plotly.express as px
import plotly.graph_objects as go
import plotly.colors as pc

In [4]:
def load_and_process_data():
    def preprocess(file_path, suffix):
        df = pd.read_csv(file_path)
        df.index = df.iloc[:, 0]
        df = df.drop(df.columns[0], axis=1)
        df.columns = [f"{col}_{suffix}" for col in df.columns]
        return df

    expression = preprocess("gene_data.csv", "expression")
    methy = preprocess("methyl_data.csv", "methy")
    snp = preprocess("snp_data.csv", "snp")
    demograph = pd.read_csv("demo_label_data.csv", usecols=range(7))
    demograph.index = demograph.iloc[:, 0]
    demograph = demograph.drop(demograph.columns[0], axis=1)
    demograph.columns = [f"{col}_demograph" for col in demograph.columns]

    label = pd.read_csv("demo_label_data.csv", usecols=[0, 8])
    label.index = label.iloc[:, 0]
    label = label.drop(label.columns[0], axis=1)
    label.columns = [f"{col}_label" for col in label.columns]


    data = snp.join(expression, how="inner") \
              .join(methy, how="inner") \
              .join(demograph, how="inner") \
              .join(label, how="inner")
    return data


In [5]:
class PrimaryInputLayer(nn.Module):
    def __init__(self, units, output_dim, activation="sigmoid", mask=None):
        super().__init__()
        self.units = units
        self.output_dim = output_dim

        if activation == "sigmoid":
            self.activation = nn.Sigmoid()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        self.w = nn.Parameter(torch.empty(units, output_dim))
        self.b = nn.Parameter(torch.zeros(output_dim))
        nn.init.xavier_normal_(self.w)

        if mask is None:
            raise ValueError("mask tensor is required")
        self.register_buffer("mask", mask.float())

    def forward(self, x):
        masked_w = self.w * self.mask
        out = x @ masked_w + self.b
        return self.activation(out)


class SecondaryInputLayer(nn.Module):
    def __init__(self, units):
        super().__init__()
        self.units = units
        self.register_buffer("mask", torch.eye(units))
        self.w = nn.Parameter(torch.empty(units, units))
        nn.init.xavier_normal_(self.w)

    def forward(self, x):
        masked_w = self.w * self.mask
        return x @ masked_w


class MultiplicationInputLayer(nn.Module):
    def __init__(self, units, activation="sigmoid"):
        super().__init__()
        self.units = units

        if activation == "sigmoid":
            self.activation = nn.Sigmoid()
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        self.b = nn.Parameter(torch.zeros(units))
        nn.init.xavier_normal_(self.b.unsqueeze(0))

    def forward(self, x):
        return self.activation(x + self.b)

In [6]:
class CustomDataset(Dataset):
    def __init__(self, inputs, targets):
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        return [input[idx] for input in self.inputs], self.targets[idx]

In [7]:
class EarlyStopping:
    def __init__(self, patience=50, delta=0.0, restore_best_weights=True):
        self.patience = patience
        self.delta = delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = float("inf")
        self.counter = 0
        self.best_model_state = None

    def __call__(self, val_loss, model):
        if isinstance(val_loss, torch.Tensor):
            val_loss = val_loss.item()

        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best_weights:
                self.best_model_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
        else:
            self.counter += 1

        if self.counter >= self.patience:
            print(f"Early stopping triggered. Best val_loss = {self.best_loss:.4f}")
            if self.restore_best_weights and self.best_model_state is not None:
                model.load_state_dict(self.best_model_state)
            return True
        return False


In [8]:
def train_model_torch(model, train_loader, val_loader, device="cuda",
                      lr=1e-3, epochs=1000, patience=500):
    criterion = torch.nn.L1Loss()  # MAE
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    early_stopper = EarlyStopping(patience=patience, delta=0.0, restore_best_weights=True)

    model.to(device)

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        for inputs, targets in train_loader:
            inputs = [x.to(device).float() for x in inputs]
            targets = targets.to(device).float()

            if inputs[0].size(0) == 1:
                model.eval()
                with torch.no_grad():
                    outputs = model(*inputs).squeeze()
                model.train()
            else:
                optimizer.zero_grad()
                outputs = model(*inputs).squeeze()
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
                train_loss += loss.item() * targets.size(0)

        train_loss /= len(train_loader.dataset)

        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs = [x.to(device).float() for x in inputs]
                targets = targets.to(device).float()
                outputs = model(*inputs).squeeze()
                loss = criterion(outputs, targets)
                val_loss += loss.item() * targets.size(0)

        val_loss /= len(val_loader.dataset)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1:03d} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f}")

        if early_stopper(val_loss, model):
            print(f"Stopping at epoch {epoch+1}")
            break

    return model

In [9]:
def evaluate_model_torch(model, test_loader, device="cuda"):
    model.eval()
    model.to(device)

    all_targets = []
    all_preds = []

    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = [x.to(device).float() for x in inputs]
            targets = targets.to(device).float().unsqueeze(1)
            preds = model(*inputs)
            all_targets.append(targets.cpu().numpy())
            all_preds.append(preds.cpu().numpy())

    y_true = np.concatenate(all_targets, axis=0).squeeze()
    y_pred = np.concatenate(all_preds, axis=0).squeeze()

    mse = np.mean((y_true - y_pred) ** 2)
    mae = np.mean(np.abs(y_true - y_pred))

    return {"mse": mse, "mae": mae, "y_true": y_true, "y_pred": y_pred}

In [10]:
def interpret_model(model, test_inputs, baselines, device="cuda"):
    model.eval()
    model.to(device)

    test_inputs = tuple(t.to(device) for t in test_inputs)
    baselines = tuple(b.to(device) for b in baselines)

    explainer = DeepLift(model)
    attributions = explainer.attribute(
        test_inputs,
        baselines=baselines,
        return_convergence_delta=False,
    )
    return attributions

In [11]:
def export_attributions(attributions, feature_names, save_path_prefix):
    for i, name in enumerate(['snp', 'methy', 'gene', 'demo']):
        df = pd.DataFrame(attributions[i].detach().cpu().numpy(), columns=feature_names[i])
        df.to_csv(f"{save_path_prefix}_{name}.csv", index=False)

In [12]:
def filter_matrices_by_top_features(snp_list, methy_list, gene_list,
                                     sparse_methy, sparse_gene, sparse_pathway):
    subset_methy_matrix = sparse_methy.loc[snp_list, methy_list]
    subset_gene_matrix = sparse_gene.loc[methy_list, gene_list]
    subset_pathway_matrix = sparse_pathway.loc[gene_list, :]

    subset_methy_matrix = subset_methy_matrix.loc[subset_methy_matrix.any(axis=1) == 1, subset_methy_matrix.any(axis=0)]
    subset_gene_matrix = subset_gene_matrix.loc[subset_gene_matrix.any(axis=1) == 1, subset_gene_matrix.any(axis=0)]
    subset_pathway_matrix = subset_pathway_matrix.loc[subset_pathway_matrix.index.isin(subset_gene_matrix.columns)]
    subset_pathway_matrix = subset_pathway_matrix.loc[subset_pathway_matrix.any(axis=1) == 1, subset_pathway_matrix.any(axis=0)]

    return subset_methy_matrix, subset_gene_matrix, subset_pathway_matrix

In [13]:
def summarize_connections(*matrices):
    connection_counts = [int(matrix.sum().sum()) for matrix in matrices]
    labels = ["SNP-Methylation", "Methylation-Gene", "Gene-Pathway"]
    for label, count in zip(labels, connection_counts):
        print(f"Total connections ({label}): {count}")

In [14]:
def build_edge_list(subset_methy_matrix, subset_gene_matrix, subset_pathway_matrix):
    edges_snp_methy = (
        subset_methy_matrix[subset_methy_matrix == 1]
        .stack()
        .reset_index()
    )
    edges_snp_methy.columns = ["source", "target", "value"]
    edges_snp_methy["layer"] = "snp_methy"

    edges_methy_gene = (
        subset_gene_matrix[subset_gene_matrix == 1]
        .stack()
        .reset_index()
    )
    edges_methy_gene.columns = ["source", "target", "value"]
    edges_methy_gene["layer"] = "methy_gene"

    edges_gene_go = (
        subset_pathway_matrix[subset_pathway_matrix == 1]
        .stack()
        .reset_index()
    )
    edges_gene_go.columns = ["source", "target", "value"]
    edges_gene_go["layer"] = "gene_go"

    edges_all = pd.concat(
        [edges_snp_methy, edges_methy_gene, edges_gene_go],
        ignore_index=True,
    )
    edges_all["value"] = 1

    return edges_all

In [15]:
def plot_sankey_from_edges(edges_all):
    edges_all_filtered = edges_all.copy()

    nodes = pd.unique(edges_all_filtered[["source", "target"]].values.ravel())

    snps = [
        node for node in nodes
        if ((node.startswith("rs") or ":" in node) and not node.startswith("GO"))
    ]
    methylation = [node for node in nodes if node.startswith("cg")]
    genes = [node for node in nodes if "_at" in node]
    go_terms = [node for node in nodes if node.startswith("GO:")]

    ordered_nodes = snps + methylation + genes + go_terms
    node_indices = {name: i for i, name in enumerate(ordered_nodes)}

    edges_all_filtered = edges_all_filtered[
        edges_all_filtered["source"].isin(ordered_nodes)
        & edges_all_filtered["target"].isin(ordered_nodes)
    ].copy()

    edges_all_filtered["source_index"] = edges_all_filtered["source"].map(node_indices)
    edges_all_filtered["target_index"] = edges_all_filtered["target"].map(node_indices)

    node_positions_x = [
        0.0 if node in snps
        else 0.33 if node in methylation
        else 0.66 if node in genes
        else 0.99
        for node in ordered_nodes
    ]

    unique_colors = pc.qualitative.Dark24
    repeated_colors = (unique_colors * ((len(ordered_nodes) // len(unique_colors)) + 1))[:len(ordered_nodes)]
    node_colors = repeated_colors

    fig = go.Figure(go.Sankey(
        arrangement="snap",
        node=dict(
            pad=10,
            thickness=20,
            line=dict(color="black", width=0.5),
            label=ordered_nodes,
            color=node_colors,
            x=node_positions_x,
        ),
        link=dict(
            source=edges_all_filtered["source_index"],
            target=edges_all_filtered["target_index"],
            value=edges_all_filtered["value"],
        ),
    ))

    fig.update_layout(
        font_size=14,
        height=1500,
        width=2000,
    )
    fig.show()


In [16]:
class HINN(nn.Module):
    def __init__(
        self,
        snp_dim,
        methy_dim,
        exp_dim,
        demo_dim,
        sparse_methy_tensor,
        sparse_gene_tensor,
        sparse_pathway_tensor,
        dense_nodes_1=128,
        drop_rate=0.7,
        activation_function="sigmoid",
    ):
        super().__init__()

        # First block: SNP -> Methy
        self.primary1 = PrimaryInputLayer(
            units=snp_dim,
            output_dim=methy_dim,
            activation=activation_function,
            mask=sparse_methy_tensor,
        )
        self.secondary1 = SecondaryInputLayer(units=methy_dim)
        self.mult1 = MultiplicationInputLayer(
            units=methy_dim,
            activation=activation_function,
        )
        self.snp_fc = nn.Linear(snp_dim, 20)

        # Second block: Methy -> Gene
        self.primary2 = PrimaryInputLayer(
            units=methy_dim,
            output_dim=exp_dim,
            activation=activation_function,
            mask=sparse_gene_tensor,
        )
        self.secondary2 = SecondaryInputLayer(units=exp_dim)
        self.mult2 = MultiplicationInputLayer(
            units=exp_dim,
            activation=activation_function,
        )
        self.mid_fc = nn.Linear(methy_dim + 20, 20)

        # Third block: Gene -> Pathway
        pathway_dim = sparse_pathway_tensor.shape[1]
        self.primary3 = PrimaryInputLayer(
            units=exp_dim,
            output_dim=pathway_dim,
            activation=activation_function,
            mask=sparse_pathway_tensor,
        )
        self.mid_fc2 = nn.Linear(exp_dim + 20, 20)

        # Dense layers
        custom_input_dim = pathway_dim + 20

        self.bn1 = nn.BatchNorm1d(custom_input_dim)
        self.fc1 = nn.Linear(custom_input_dim, dense_nodes_1)
        self.drop1 = nn.Dropout(drop_rate)

        self.bn2 = nn.BatchNorm1d(dense_nodes_1)
        self.fc2 = nn.Linear(dense_nodes_1, dense_nodes_1)
        self.drop2 = nn.Dropout(drop_rate)

        self.bn3 = nn.BatchNorm1d(dense_nodes_1)
        self.fc3 = nn.Linear(dense_nodes_1, dense_nodes_1)
        self.drop3 = nn.Dropout(drop_rate)

        self.bn4 = nn.BatchNorm1d(dense_nodes_1)
        self.fc4 = nn.Linear(dense_nodes_1, dense_nodes_1)
        self.drop4 = nn.Dropout(drop_rate)

        self.dense_fourth = nn.Linear(dense_nodes_1, 20)
        self.bn_demo = nn.BatchNorm1d(20 + demo_dim)
        self.fc_demo = nn.Linear(20 + demo_dim, dense_nodes_1)
        self.drop_demo = nn.Dropout(drop_rate)

        self.out = nn.Linear(dense_nodes_1, 1)

        self.activation_function = activation_function

    def _nonlin(self, x):
        return torch.sigmoid(x)

    def forward(self, snp, methy, exp, demo):
        # First block
        primary1 = self.primary1(snp)
        secondary1 = self.secondary1(methy)
        mult_res1 = primary1 * secondary1
        mult1 = self.mult1(mult_res1)

        snp_fc = self._nonlin(self.snp_fc(snp))
        out2 = torch.cat([mult1, snp_fc], dim=1)

        # Second block
        primary2 = self.primary2(mult1)
        secondary2 = self.secondary2(exp)

        eps = 1e-6
        denom = primary2.clone()
        denom = torch.where(denom.abs() < eps, eps * torch.ones_like(denom), denom)
        div_res1 = secondary2 / denom
        div_res1 = torch.clamp(div_res1, -1e6, 1e6)

        mult2 = self.mult2(div_res1)

        mid_fc = self._nonlin(self.mid_fc(out2))
        out3 = torch.cat([mult2, mid_fc], dim=1)

        # Third block
        primary3 = self.primary3(mult2)
        mid_fc2 = self._nonlin(self.mid_fc2(out3))
        out4 = torch.cat([primary3, mid_fc2], dim=1)

        # Dense stack
        x = self.bn1(out4)
        x = self._nonlin(self.fc1(x))
        x = self.drop1(x)

        x = self.bn2(x)
        x = self._nonlin(self.fc2(x))
        x = self.drop2(x)

        x = self.bn3(x)
        x = self._nonlin(self.fc3(x))
        x = self.drop3(x)

        x = self.bn4(x)
        x = self._nonlin(self.fc4(x))
        x = self.drop4(x)

        dense_fourth = self._nonlin(self.dense_fourth(x))
        demo_concat = torch.cat([dense_fourth, demo], dim=1)

        x = self.bn_demo(demo_concat)
        x = self._nonlin(self.fc_demo(x))
        x = self.drop_demo(x)

        out = self.out(x)
        return out


In [17]:
class MultiOmicReportGenerator:
    def __init__(self, model_name="NousResearch/Llama-2-7b-chat-hf", device="cuda",
                 offload_to_cpu=True, use_8bit=False):

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        if use_8bit:
            try:
                from transformers import BitsAndBytesConfig
                quantization_config = BitsAndBytesConfig(load_in_8bit=True)
                self.llm = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    quantization_config=quantization_config,
                    device_map="auto",
                    low_cpu_mem_usage=True,
                )
                self.device = device
            except ImportError:
                offload_to_cpu = True
                use_8bit = False

        if not use_8bit:
            if offload_to_cpu and device == "cuda":
                self.llm = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16,
                    device_map="auto",  # Automatically distribute across GPU/CPU
                    low_cpu_mem_usage=True,
                    offload_folder="offload",  # Offload to disk if needed
                    offload_state_dict=True,
                )
                self.device = device
            else:
                self.llm = AutoModelForCausalLM.from_pretrained(
                    model_name,
                    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
                    device_map="auto" if device == "cuda" else None,
                    low_cpu_mem_usage=True,
                )

                if device == "cpu":
                    self.llm = self.llm.to(device)

                self.device = device


    def extract_patient_features(self, attributions, feature_names, top_k=10):
        attr_snp, attr_methy, attr_gene, attr_demo = attributions

        snp_scores = attr_snp.abs().squeeze().detach().cpu().numpy()
        methy_scores = attr_methy.abs().squeeze().detach().cpu().numpy()
        gene_scores = attr_gene.abs().squeeze().detach().cpu().numpy()
        demo_scores = attr_demo.abs().squeeze().detach().cpu().numpy()

        top_snp_idx = np.argsort(-snp_scores)[:top_k]
        top_methy_idx = np.argsort(-methy_scores)[:top_k]
        top_gene_idx = np.argsort(-gene_scores)[:top_k]

        top_features = {
            'snp': [
                {
                    'name': feature_names[0][i].replace('_snp', ''),
                    'score': float(snp_scores[i])
                }
                for i in top_snp_idx
            ],
            'methylation': [
                {
                    'name': feature_names[1][i].replace('_methy', ''),
                    'score': float(methy_scores[i])
                }
                for i in top_methy_idx
            ],
            'gene': [
                {
                    'name': feature_names[2][i].replace('_expression', ''),
                    'score': float(gene_scores[i])
                }
                for i in top_gene_idx
            ],
            'demographics': [
                {
                    'name': feature_names[3][i].replace('_demograph', ''),
                    'score': float(demo_scores[i])
                }
                for i in range(min(len(demo_scores), 5))
            ]
        }

        return top_features

    def get_pathways_for_features(self, top_features, sparse_methy, sparse_gene, sparse_pathway):
        snp_list = [f['name'] for f in top_features['snp']]
        methy_list = [f['name'] for f in top_features['methylation']]
        gene_list = [f['name'] for f in top_features['gene']]

        try:
            # SNP -> Methylation connections
            subset_methy = sparse_methy.loc[
                sparse_methy.index.intersection(snp_list),
                sparse_methy.columns.intersection(methy_list)
            ]

            # Methylation -> Gene connections
            subset_gene = sparse_gene.loc[
                sparse_gene.index.intersection(methy_list),
                sparse_gene.columns.intersection(gene_list)
            ]

            # Gene -> Pathway connections
            subset_pathway = sparse_pathway.loc[
                sparse_pathway.index.intersection(gene_list), :
            ]
            subset_pathway = subset_pathway.loc[:, subset_pathway.any(axis=0)]

            connected_pathways = subset_pathway.columns[subset_pathway.sum(axis=0) > 0].tolist()

            pathway_info = {
                'pathways': connected_pathways[:15],
                'connection_counts': {
                    'snp_to_methylation': int(subset_methy.sum().sum()),
                    'methylation_to_gene': int(subset_gene.sum().sum()),
                    'gene_to_pathway': int(subset_pathway.sum().sum())
                }
            }

        except Exception as e:
            print(f"Warning: Error extracting pathways: {e}")
            pathway_info = {
                'pathways': [],
                'connection_counts': {
                    'snp_to_methylation': 0,
                    'methylation_to_gene': 0,
                    'gene_to_pathway': 0
                }
            }

        return pathway_info

    def create_clinical_prompt(self, patient_id, predicted_score, true_score,
                               top_features, pathway_info):
        prompt = f"""[INST] You are a clinical genetics and neuroscience expert analyzing multi-omic data for cognitive decline prediction.

**PATIENT CASE SUMMARY**
Patient ID: {patient_id}
Predicted MMSE Score: {predicted_score:.2f}
Actual MMSE Score: {true_score:.2f}
Prediction Error: {abs(predicted_score - true_score):.2f} points

**TOP GENETIC VARIANTS (SNPs)**
The following SNPs showed the highest importance in the prediction:
{self._format_features(top_features['snp'][:5])}

**TOP METHYLATION SITES**
The following CpG methylation sites were most influential:
{self._format_features(top_features['methylation'][:5])}

**TOP GENE EXPRESSIONS**
The following genes showed the strongest predictive signal:
{self._format_features(top_features['gene'][:5])}

**BIOLOGICAL PATHWAYS**
These features connect to {len(pathway_info['pathways'])} biological pathways including:
{', '.join(pathway_info['pathways'][:10])}

Network connectivity: {pathway_info['connection_counts']['snp_to_methylation']} SNP-Methylation links,
{pathway_info['connection_counts']['methylation_to_gene']} Methylation-Gene links,
{pathway_info['connection_counts']['gene_to_pathway']} Gene-Pathway links.

**TASK**
Please provide a comprehensive clinical interpretation report with the following sections:

1. **Clinical Interpretation**: Explain what the MMSE score indicates about the patient's cognitive status.

2. **Genetic Risk Factors**: Describe the biological significance of the top SNPs and their known associations with cognitive decline, Alzheimer's disease, or related conditions.

3. **Epigenetic Findings**: Explain the role of the key methylation sites and what changes in methylation at these loci might indicate.

4. **Gene Expression Patterns**: Interpret the expression changes in the top genes and their relevance to neurodegeneration.

5. **Pathway Analysis**: Provide a biological narrative connecting the genetic variants through methylation and gene expression to the identified pathways. What cellular processes are most affected?

6. **Clinical Recommendations**: Suggest potential follow-up investigations or considerations for clinicians based on these findings.

Keep the report professional, scientifically accurate, and accessible to clinicians. Focus on biological mechanisms and clinical relevance. [/INST]"""

        return prompt

    def _format_features(self, features):
        if not features:
            return "None identified"
        return '\n'.join([f"  - {f['name']} (importance: {f['score']:.4f})"
                         for f in features])

    def generate_report(self, prompt, max_length=2048, temperature=0.7):

        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=4096)
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.llm.generate(
                **inputs,
                max_new_tokens=max_length,
                temperature=temperature,
                do_sample=True,
                top_p=0.9,
                num_return_sequences=1,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

        if "[/INST]" in full_text:
            report = full_text.split("[/INST]")[-1].strip()
        else:
            report = full_text

        return report

    def create_patient_report(self, patient_id, patient_data, predicted_score,
                             true_score, model, feature_names, sparse_matrices,
                             device="cuda"):
        print(f"Generating Report for Patient {patient_id}")

        patient_data = tuple(
            x.unsqueeze(0) if x.dim() == 1 else x
            for x in patient_data
        )

        model.eval()
        patient_data = tuple(x.to(device).requires_grad_(True) for x in patient_data)

        baselines = tuple(torch.zeros_like(x) for x in patient_data)

        explainer = DeepLift(model)
        attributions = explainer.attribute(
            patient_data,
            baselines=baselines,
            return_convergence_delta=False
        )


        top_features = self.extract_patient_features(attributions, feature_names, top_k=10)

        sparse_methy, sparse_gene, sparse_pathway = sparse_matrices
        pathway_info = self.get_pathways_for_features(
            top_features, sparse_methy, sparse_gene, sparse_pathway
        )

        prompt = self.create_clinical_prompt(
            patient_id, predicted_score, true_score, top_features, pathway_info
        )


        report = self.generate_report(prompt, max_length=2048, temperature=0.7)


        return {
            'patient_id': patient_id,
            'predicted_score': predicted_score,
            'true_score': true_score,
            'top_features': top_features,
            'pathway_info': pathway_info,
            'report': report,
            'prompt': prompt
        }


In [18]:
def generate_patient_reports(model, test_loader, test_df, y_test, feature_names,
                            sparse_matrices, device="cuda", num_reports=3,
                            move_hinn_to_cpu=True, use_8bit_llm=False):
    model.eval()
    all_predictions = []

    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = [x.to(device).float() for x in inputs]
            preds = model(*inputs).squeeze()
            all_predictions.append(preds.cpu().numpy())

    all_predictions = np.concatenate(all_predictions)

    errors = np.abs(all_predictions - y_test.values)

    indices_to_report = [
        np.argmin(errors),
        np.argmax(errors),
        np.argsort(errors)[len(errors)//2],
    ]

    indices_to_report = indices_to_report[:num_reports]

    if move_hinn_to_cpu and device == "cuda":
        model = model.cpu() #move the hinn model to the cpu
        torch.cuda.empty_cache()
        hinn_device = "cpu"
    else:
        hinn_device = device

    report_gen = MultiOmicReportGenerator(
        model_name="NousResearch/Llama-2-7b-chat-hf",
        device=device,
        offload_to_cpu=True if device == "cuda" else False,
        use_8bit=use_8bit_llm
    )

    reports = []

    for idx in indices_to_report:
        patient_id = test_df.index[idx]

        snp_data = torch.tensor(
            test_df.filter(like="_snp").iloc[idx].values,
            dtype=torch.float32
        )
        methy_data = torch.tensor(
            test_df.filter(like="_methy").iloc[idx].values,
            dtype=torch.float32
        )
        exp_data = torch.tensor(
            test_df.filter(like="_expression").iloc[idx].values,
            dtype=torch.float32
        )
        demo_data = torch.tensor(
            test_df.filter(like="_demograph").iloc[idx].values,
            dtype=torch.float32
        )

        patient_data = (snp_data, methy_data, exp_data, demo_data)

        report_dict = report_gen.create_patient_report(
            patient_id=patient_id,
            patient_data=patient_data,
            predicted_score=all_predictions[idx],
            true_score=y_test.values[idx],
            model=model,
            feature_names=feature_names,
            sparse_matrices=sparse_matrices,
            device=hinn_device
        )

        reports.append(report_dict)


        print(f"CLINICAL REPORT - Patient {patient_id}")
        print(report_dict['report'])

        with open(f"patient_report_{patient_id}.txt", "w") as f:
            f.write(f"CLINICAL REPORT - Patient {patient_id}\n")
            f.write(report_dict['report'])

        print(f"Report saved to: patient_report_{patient_id}.txt")

    #if move_hinn_to_cpu and device == "cuda":
    #    model = model.to(device)

    return reports

In [19]:
def main():
    device = "cuda" if torch.cuda.is_available() else "cpu"

    data = load_and_process_data()
    y = data["MMSE_label"]
    X = data.drop(columns=[c for c in data.columns if c.endswith("MMSE_label")])

    X_train_int, X_test_df, y_train_int, y_test = train_test_split(
        X, y, test_size=0.3, random_state=42
    )

    X_train_df, X_val_df, y_train, y_val = train_test_split(
        X_train_int, y_train_int, test_size=0.2, random_state=42
    )

    print(f"Train samples: {len(X_train_df)}")
    print(f"Val samples: {len(X_val_df)}")
    print(f"Test samples: {len(X_test_df)}")

    X_train_list = [
        torch.tensor(X_train_df.filter(like=s).values, dtype=torch.float32)
        for s in ["_snp", "_methy", "_expression", "_demograph"]
    ]
    y_train_t = torch.tensor(y_train.values, dtype=torch.float32)

    X_val_list = [
        torch.tensor(X_val_df.filter(like=s).values, dtype=torch.float32)
        for s in ["_snp", "_methy", "_expression", "_demograph"]
    ]
    y_val_t = torch.tensor(y_val.values, dtype=torch.float32)

    X_test_list = [
        torch.tensor(X_test_df.filter(like=s).values, dtype=torch.float32)
        for s in ["_snp", "_methy", "_expression", "_demograph"]
    ]
    y_test_t = torch.tensor(y_test.values, dtype=torch.float32)

    train_dataset = CustomDataset(X_train_list, y_train_t)
    val_dataset = CustomDataset(X_val_list, y_val_t)
    test_dataset = CustomDataset(X_test_list, y_test_t)

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


    sparse_methy = pd.read_csv("snp_methyl_matrix.csv", index_col=0)
    sparse_gene = pd.read_csv("methyl_gene_matrix.csv.zip", compression='zip', index_col=0)
    sparse_pathway = pd.read_csv("gene_pathway_matrix.csv", index_col=0)

    print(f"SNP-Methylation matrix: {sparse_methy.shape}")
    print(f"Methylation-Gene matrix: {sparse_gene.shape}")
    print(f"Gene-Pathway matrix: {sparse_pathway.shape}")

    sparse_methy_tensor = torch.tensor(sparse_methy.values, dtype=torch.float32)
    sparse_gene_tensor = torch.tensor(sparse_gene.values, dtype=torch.float32)
    sparse_pathway_tensor = torch.tensor(sparse_pathway.values, dtype=torch.float32)


    snp_dim = X_train_list[0].shape[1]
    methy_dim = X_train_list[1].shape[1]
    exp_dim = X_train_list[2].shape[1]
    demo_dim = X_train_list[3].shape[1]

    print(f"Input dimensions:")
    print(f"  SNP: {snp_dim}")
    print(f"  Methylation: {methy_dim}")
    print(f"  Gene Expression: {exp_dim}")
    print(f"  Demographics: {demo_dim}")

    model = HINN(
        snp_dim=snp_dim,
        methy_dim=methy_dim,
        exp_dim=exp_dim,
        demo_dim=demo_dim,
        sparse_methy_tensor=sparse_methy_tensor,
        sparse_gene_tensor=sparse_gene_tensor,
        sparse_pathway_tensor=sparse_pathway_tensor,
        dense_nodes_1=128,
        drop_rate=0.7,
        activation_function="sigmoid",
    )

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)


    model = train_model_torch(
        model,
        train_loader,
        val_loader,
        device=device,
        lr=1e-3,
        epochs=1000,
        patience=50,
    )

    eval_results = evaluate_model_torch(model, test_loader, device=device)
    print(f"Test MAE: {eval_results['mae']:.4f}")
    print(f"Test MSE: {eval_results['mse']:.4f}")
    print(f"Test RMSE: {np.sqrt(eval_results['mse']):.4f}")

    test_inputs = tuple(
        torch.tensor(arr, dtype=torch.float32, requires_grad=True).to(device)
        for arr in [
            X_test_df.filter(like="_snp").values,
            X_test_df.filter(like="_methy").values,
            X_test_df.filter(like="_expression").values,
            X_test_df.filter(like="_demograph").values
        ]
    )

    baselines = tuple(
        torch.tensor(arr.mean(axis=0), dtype=torch.float32)
        .unsqueeze(0)
        .expand_as(torch.tensor(arr, dtype=torch.float32))
        .to(device)
        for arr in [
            X_test_df.filter(like="_snp").values,
            X_test_df.filter(like="_methy").values,
            X_test_df.filter(like="_expression").values,
            X_test_df.filter(like="_demograph").values
        ]
    )

    attributions = interpret_model(model, test_inputs, baselines, device=device)
    attr_snp, attr_methy, attr_gene, attr_demo = attributions

    snp_importance = attr_snp.abs().mean(dim=0).detach().cpu().numpy()
    methy_importance = attr_methy.abs().mean(dim=0).detach().cpu().numpy()
    gene_importance = attr_gene.abs().mean(dim=0).detach().cpu().numpy()

    print(f"Computed attributions for {len(snp_importance)} SNPs, "
          f"{len(methy_importance)} methylation sites, "
          f"{len(gene_importance)} genes")

    feature_names = [
        X_train_df.filter(like=s).columns.tolist()
        for s in ["_snp", "_methy", "_expression", "_demograph"]
    ]

    export_attributions(attributions, feature_names, "MMSE_attributions")
    print("Attribution files saved with prefix: MMSE_attributions_")

    TOP_SNP = 20
    TOP_METHY = 100
    TOP_GENE = 50

    top_snp_idx = np.argsort(-snp_importance)[:TOP_SNP]
    top_methy_idx = np.argsort(-methy_importance)[:TOP_METHY]
    top_gene_idx = np.argsort(-gene_importance)[:TOP_GENE]

    snp_list = [
        feature_names[0][i].replace("_snp", "")
        for i in top_snp_idx
    ]
    methy_list = [
        feature_names[1][i].replace("_methy", "")
        for i in top_methy_idx
    ]
    gene_list = [
        feature_names[2][i].replace("_expression", "")
        for i in top_gene_idx
    ]

    print(f"  SNPs: {len(snp_list)}")
    print(f"  Methylation: {len(methy_list)}")
    print(f"  Genes: {len(gene_list)}")


    subset_methy_matrix, subset_gene_matrix, subset_pathway_matrix = filter_matrices_by_top_features(
        snp_list, methy_list, gene_list, sparse_methy, sparse_gene, sparse_pathway
    )

    summarize_connections(subset_methy_matrix, subset_gene_matrix, subset_pathway_matrix)

    edges_all = build_edge_list(
        subset_methy_matrix,
        subset_gene_matrix,
        subset_pathway_matrix,
    )

    plot_sankey_from_edges(edges_all)



    sparse_matrices = (sparse_methy, sparse_gene, sparse_pathway)

    reports = generate_patient_reports(
        model=model,
        test_loader=test_loader,
        test_df=X_test_df,
        y_test=y_test,
        feature_names=feature_names,
        sparse_matrices=sparse_matrices,
        device=device,
        num_reports=3,
        move_hinn_to_cpu=True,  # Move HINN to CPU to free GPU memory
        use_8bit_llm=False  # Set to True if you have bitsandbytes installed
    )

    return model, reports


In [20]:
if __name__ == "__main__":
    model, reports = main()

Train samples: 33
Val samples: 9
Test samples: 19
SNP-Methylation matrix: (254, 13688)
Methylation-Gene matrix: (13688, 1727)
Gene-Pathway matrix: (1727, 158)
Input dimensions:
  SNP: 254
  Methylation: 13688
  Gene Expression: 1727
  Demographics: 6
Epoch 010 | train_loss=22.0989 | val_loss=22.4259
Epoch 020 | train_loss=21.5017 | val_loss=21.7850
Epoch 030 | train_loss=20.8181 | val_loss=21.1461
Epoch 040 | train_loss=20.3473 | val_loss=20.4981
Epoch 050 | train_loss=19.9462 | val_loss=19.8380
Epoch 060 | train_loss=19.1264 | val_loss=19.1590
Epoch 070 | train_loss=18.3911 | val_loss=18.4847
Epoch 080 | train_loss=17.7758 | val_loss=17.7912
Epoch 090 | train_loss=17.3699 | val_loss=17.0446
Epoch 100 | train_loss=16.6046 | val_loss=16.2676
Epoch 110 | train_loss=15.2226 | val_loss=15.4404
Epoch 120 | train_loss=14.5437 | val_loss=14.5911
Epoch 130 | train_loss=13.6422 | val_loss=13.6262
Epoch 140 | train_loss=12.2308 | val_loss=12.7106
Epoch 150 | train_loss=11.5046 | val_loss=11.5294

tokenizer_config.json:   0%|          | 0.00/746 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/21.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/435 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/583 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:00, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.50G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/200 [00:00<?, ?B/s]

Generating Report for Patient 4270
CLINICAL REPORT - Patient 4270
Clinical Interpretation:
The predicted MMSE score of 21.89 indicates that the patient's cognitive status is within the normal range. However, the small prediction error of 0.11 points suggests that there may be some degree of cognitive impairment that has not been captured by the MMSE score. Further evaluation and assessment are necessary to confirm or rule out cognitive decline.

Genetic Risk Factors:
The top SNPs identified in the analysis are associated with various genetic variants that have been linked to cognitive decline and Alzheimer's disease. For example, rs6857 has been associated with an increased risk of Alzheimer's disease, while 11:71654207 has been linked to cognitive decline in older adults. These findings suggest that the patient may be at increased risk for cognitive decline or Alzheimer's disease due to genetic factors.

Epigenetic Findings:
The key methylation sites identified in the analysis are inv