In [None]:
# /*==========================================================================================*\
# **                        _           _ _   _     _  _         _                            **
# **                       | |__  _   _/ | |_| |__ | || |  _ __ | |__                         **
# **                       | '_ \| | | | | __| '_ \| || |_| '_ \| '_ \                        **
# **                       | |_) | |_| | | |_| | | |__   _| | | | | | |                       **
# **                       |_.__/ \__,_|_|\__|_| |_|  |_| |_| |_|_| |_|                       **
# \*==========================================================================================*/


# -----------------------------------------------------------------------------------------------
# Author: Bùi Tiến Thành - Tien-Thanh Bui (@bu1th4nh)
# Title: X-intNMF-main.ipynb
# Date: 2025/03/17 16:11:50
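# Description: Example end-to-end run of X-intNMF on the BRCA micro dataset (mRNA, miRNA and
#              DNA methylation layers, plus an mRNA-miRNA interaction graph).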
# 
# (c) 2025 bu1th4nh / UCF Computational Biology Lab. All rights reserved. 
# Written with dedication at the University of Central Florida, EPCOT and the Magic Kingdom.
# -----------------------------------------------------------------------------------------------


# Libraries
import os
import sys
import logging
import numpy as np
import pandas as pd
from tqdm import tqdm
from colorlog import ColoredFormatter
from typing import List, Dict, Any, Tuple, Union, Literal

# Model
from model.crossOmicNMF import XIntNMF

# Settings
# Logging
# Detach any handlers already on the root logger (e.g. from a previous run of
# this cell); logging.basicConfig does nothing while the root logger still has
# handlers, so this is what lets the configuration below take effect.
logging.root.handlers = []

# Colour used for the level name of each log level.
_log_colors = {
    'DEBUG': 'white',
    'INFO': 'green',
    'WARNING': 'yellow',
    'ERROR': 'red',
    'CRITICAL': 'red,bg_white',
}

# Timestamp (ms precision), coloured level, message, then the source location.
_log_formatter = ColoredFormatter(
    "%(cyan)s%(asctime)s.%(msecs)03d %(log_color)s[%(levelname)s]%(reset)s %(light_white)s%(message)s%(reset)s %(blue)s(%(filename)s:%(lineno)d)",
    datefmt='%Y/%m/%d %H:%M:%S',
    log_colors=_log_colors,
)

# Single colourised handler that writes every record to stdout.
handler_sh = logging.StreamHandler(sys.stdout)
handler_sh.setFormatter(_log_formatter)

logging.basicConfig(level=logging.INFO, handlers=[handler_sh])

# Data Acquisition

In [None]:
# Directory holding the bundled BRCA micro dataset used by this example.
DATA_DIR = os.path.join('sample_processed_data', 'BRCA_micro_dataset')

# Per-omics input matrices.
mRNA = pd.read_parquet(os.path.join(DATA_DIR, 'mRNA.parquet'))
miRNA = pd.read_parquet(os.path.join(DATA_DIR, 'miRNA.parquet'))
# Previously this re-read 'mRNA.parquet' (copy-paste error), which silently fed
# a duplicate of the mRNA layer to the model as the third omics layer.
# NOTE(review): file name follows the naming of the sibling files — confirm it
# matches the dataset on disk.
DNAMethylation = pd.read_parquet(os.path.join(DATA_DIR, 'DNAMethylation.parquet'))

# Cross-omics interaction matrix between the mRNA and miRNA layers.
mRNA_miRNA_graph = pd.read_parquet(os.path.join(DATA_DIR, 'mRNA_miRNA_graph.parquet'))

# Settings

In [None]:
# Number of latent components; also the number of `Latent_XXX` labels built below.
k_latent_components: int = 10
# Graph-regularisation weight, passed to XIntNMF as `alpha`.
alpha_graph_reg: int = 100
# Sparsity-regularisation weight on the omics factor matrices, passed as `betas`.
beta_omics_factors_sparsity_reg: int = 100

# Model

In [None]:
# Omics layers in model order. This single list fixes the relative indices used
# in `cross_omics_interaction` and the order of the returned factor matrices `Ws`.
omics_dataframes = [mRNA, miRNA, DNAMethylation]

# The sample factor matrix H is labelled with mRNA's columns, so every layer
# must carry exactly the same sample columns in the same order; otherwise the
# labels below would be silently wrong.
sample_list = mRNA.columns
for layer_idx, omics_df in enumerate(omics_dataframes):
    if not omics_df.columns.equals(sample_list):
        raise ValueError(
            f"Omics layer {layer_idx} does not have the same sample columns as the "
            "mRNA layer; align the sample columns of all layers before fitting."
        )

model = XIntNMF(
    omics_layers = [
        omics_df.to_numpy(dtype=np.float64, copy=True)
        for omics_df in omics_dataframes
    ],
    cross_omics_interaction = {
        # (0, 1) are relative indices into `omics_layers`: mRNA (0) and miRNA (1).
        (0, 1): mRNA_miRNA_graph.to_numpy(dtype=np.float64, copy=True)
    },
    k = k_latent_components,
    alpha = alpha_graph_reg,
    betas = beta_omics_factors_sparsity_reg,
    gammas = 1,
    max_iter = 5000,
    tol = 1e-4,
    verbose = True,
    backend = 'numpy'
)

# Solve
Ws, H = model.solve(run_mode='full', use_cross_validation=True)


# Post-process: feature labels per layer (row index of each input DataFrame)
# and one label per latent component.
omics_features_list = [omics_df.index for omics_df in omics_dataframes]
latent_columns = [f"Latent_{i:03}" for i in range(k_latent_components)]


# Wrap the raw factor matrices in labelled DataFrames (nothing is written to
# disk here): H is latent components x samples, each W is features x latent components.
sample_factor_matrix_H = pd.DataFrame(H, index=latent_columns, columns=sample_list)
omics_factor_matrices_Ws = [
    pd.DataFrame(W, index=omics_features_list[i], columns=latent_columns)
    for i, W in enumerate(Ws)
]

