# Imports

In [None]:
#!pip install matcouply tensorly tqdm

In [None]:
#!pip install scanpy sccellfie cell2cell

In [None]:
import numpy as np
import tensorly as tl
from tqdm import tqdm
import matcouply

from copy import deepcopy
from matcouply.decomposition import cmf_aoadmm
from matcouply.penalties import MatricesPenalty, NonNegativity,Parafac2
from scipy.io import loadmat

In [None]:
import sccellfie
import scanpy as sc
import cell2cell as c2c

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pickle

In [None]:
# Run configuration: `date` is embedded in several output file names below and
# `folder` is the directory all figures, pickles and CSVs are written to.
date = '2025-07-03'
folder = './Mes/'

In [None]:
import os
# Create the output directory up front; exist_ok=True makes re-running safe.
os.makedirs(folder, exist_ok=True)

In [None]:
# Mesenchymal cells, opened read-only in backed mode; the cell-type subset is
# loaded into memory with .to_memory() further down.
adata = sc.read("/lustre/scratch126/cellgen/vento/hm11/with_Erick/adata_mes_all_harm_anno_p1.h5ad", backed='r')
adata

In [None]:
# Second dataset, used here only as a sample -> age_group lookup table.
adata2 = sc.read("/lustre/scratch126/cellgen/vento/hm11/with_Erick/cdata_thy_noStressEmb.h5ad", backed='r')
adata2

In [None]:
# NOTE(review): if a sample is listed with more than one age_group, to_dict()
# silently keeps only the last one - confirm each sample has a single age group.
age_mapper = adata2.obs[['sample', 'age_group']].drop_duplicates().set_index('sample')['age_group'].to_dict()

In [None]:
# Sanity check (displayed only): every mesenchymal sample has an age group.
# Nothing stops the notebook if this is False; unmapped samples become NaN below.
all([s in age_mapper.keys() for s in adata.obs['sample'].unique()])

In [None]:
# Only displayed below; the model's time axis is age_group2 (derived from pcw).
adata.obs['age_group'] = adata.obs['sample'].map(age_mapper)

In [None]:
# Cell types included in the model; all other cells are dropped further down.
celltypes = ['mes_CYGB', 'mes_KCNB2']

In [None]:
# Display only: samples present in the mesenchymal data.
adata.obs['sample'].unique()

In [None]:
# Display only: age groups after the sample mapping (NaN = unmapped sample).
adata.obs['age_group'].unique()

In [None]:
# Developmental time bins in post-conception weeks. Each label is a single week
# ("20") or an inclusive "lo-hi" range ("11-13"); list order is the time axis.
new_groups = ['9-10', '11-13', '14-15', '16-17', '20']


def assign_group(x, groups=None):
    """Return the time-bin label in ``groups`` whose week range contains ``x``.

    Each label is parsed as an inclusive range ``lo-hi`` (or a single week), so
    every week inside a range matches, not only its endpoints. This removes the
    need for a hard-coded special case for week 12.

    Parameters
    ----------
    x : int, float or str
        Post-conception week, e.g. ``12``, ``12.0`` or ``'12'``.
    groups : list of str, optional
        Bin labels to search, in order; defaults to ``new_groups``.

    Returns
    -------
    str or None
        The first matching label, or ``None`` if ``x`` is missing,
        non-numeric or outside every bin.
    """
    if groups is None:
        groups = new_groups
    try:
        week = float(x)
    except (TypeError, ValueError):
        return None
    for label in groups:
        lo, _, hi = label.partition('-')
        # Single-week labels have no '-', so hi is '' and the range is [lo, lo].
        if float(lo) <= week <= float(hi or lo):
            return label
    return None

In [None]:
adata.obs['age_group2'] = adata.obs['pcw'].apply(lambda x: assign_group(x))

In [None]:
adata.obs['age_group2'].unique()

In [None]:
adata = adata[adata.obs['celltype'].isin(celltypes)].to_memory()
adata.X = adata.layers['counts']

In [None]:
adata.shape

In [None]:
sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)
sc.pp.log1p(adata)

In [None]:
adata_og = adata.copy()

In [None]:
#sc.pp.highly_variable_genes(adata, n_top_genes=250, flavor="seurat")
#adata.shape

In [None]:
def getProteinCodingGenes(adata, group='protein_coding',
                          path='/lustre/scratch126/cellgen/vento/hm11/with_Erick/gene_symbol_type.tsv'):
    '''Return the genes of ``adata`` whose biotype in the annotation table equals ``group``.

    The annotation file is a header-less TSV; column 1 holds the gene names
    matched against ``adata.var_names`` and column 2 the gene type.

    Parameters
    ----------
    adata : AnnData
        Object whose ``var_names`` are gene names.
    group : str
        Gene type to keep (default ``'protein_coding'``). Earlier versions
        ignored this argument and always used ``'protein_coding'``.
    path : str
        Location of the annotation TSV.

    Returns
    -------
    list of str
        Matching gene names, in ``adata.var`` order.

    Example: getProteinCodingGenes(adata)
    '''
    import pandas as pd
    annotation = pd.read_csv(path, sep='\t', header=None)
    wanted = annotation.loc[annotation[2] == group, 1].tolist()
    in_data = adata.var_names.isin(wanted)
    print(f'[INFO] {group} genes found in adata {in_data.sum()} ({len(wanted)} total {group} genes)')
    return adata.var.loc[in_data, :].index.tolist()

In [None]:
def getExpPercGroup(adata, groups, min_cells = 20,  min_pct = 0.1):
    '''Return genes detected in enough cells of at least one group, plus the group labels.

    Cells are grouped by the values of the ``groups`` obs columns joined with
    ``'_'``. A gene is kept when, in at least one group, it is detected
    (X > 0) in ``>= min_cells`` cells AND in a fraction ``>= min_pct`` of that
    group's cells.

    Parameters
    ----------
    adata : AnnData
        Expression data; ``X`` may be a dense array or a scipy sparse matrix.
    groups : list of str
        obs columns defining the groups. Values need not be strings; cells with
        a missing value in any of these columns are ignored.
    min_cells : int
        Minimum number of detecting cells within a group.
    min_pct : float
        Minimum fraction (0-1) of detecting cells within a group.

    Returns
    -------
    (list of str, list of str)
        Kept gene names (in ``adata.var_names`` order) and the sorted group labels.

    Example:
    sf.getExpPercGroup(adata[adata.obs['celltype'].isin(['thy_TH_processing', 'thy_Lumen-forming'])], ['age_group','celltype','karyotype'])
    '''
    import numpy as np
    import pandas as pd

    keys = adata.obs[groups]
    # astype(str) lets non-string columns (ints, categoricals) be joined; rows with
    # a missing key become NaN and are skipped instead of crashing '_'.join.
    labels = keys.astype(str).agg('_'.join, axis=1).where(keys.notna().all(axis=1))

    perc_dict = {}
    exp_dict = {}
    for label in sorted(labels.dropna().unique()):
        # Boolean mask: independent of the obs index name, one indexing per group.
        X_group = adata[(labels == label).to_numpy(), :].X
        # (X > 0).sum(axis=0) is an np.matrix for sparse X and an ndarray for
        # dense X; np.asarray(...).ravel() yields a flat vector in both cases.
        n_detected = np.asarray((X_group > 0).sum(axis=0)).ravel()
        perc_dict[label] = n_detected / X_group.shape[0]
        exp_dict[label] = n_detected

    p = pd.DataFrame(perc_dict, index=adata.var_names)
    e = pd.DataFrame(exp_dict, index=adata.var_names)
    keep = ((e >= min_cells) & (p >= min_pct)).any(axis=1)
    HI_GENES = keep[keep].index.tolist()
    print(f'[INFO] {len(perc_dict)} aggs using {groups} at min_cells={min_cells},  min_pct={min_pct}. {len(HI_GENES)} genes')
    return HI_GENES, list(perc_dict.keys())

In [None]:
# Restrict to protein-coding genes (annotation TSV read by getProteinCodingGenes).
protein_encoding = getProteinCodingGenes(adata)

In [None]:
adata = adata[:, protein_encoding]

In [None]:
# Keep genes detected in >= 20 cells and >= 10% of cells of at least one time bin
# (getExpPercGroup defaults). `groups` receives the list of bin labels here.
filter_genes, groups = getExpPercGroup(adata, ['age_group2'])

In [None]:
groups

In [None]:
adata = adata[:, filter_genes]

In [None]:
adata.shape

In [None]:
# `groups` is re-bound from the list of labels to the obs column name that the
# aggregation loop below uses to split cells by time bin.
groups = 'age_group2'

In [None]:
def _trimean_by_celltype(time_bin):
    """Trimean expression per cell type, computed from the cells of one time bin only."""
    in_bin = adata.obs[groups] == time_bin
    return sccellfie.expression.aggregation.agg_expression_cells(
        adata[in_bin, :], "celltype", layer=None, gene_symbols=None, agg_func='trimean'
    )


# One cell-type x gene table per time bin, in `new_groups` order.
agg_dfs = []
for group in tqdm(new_groups):
    print(group)
    agg_dfs.append(_trimean_by_celltype(group))

In [None]:
# Drop genes whose aggregated expression sums to zero in every time-bin table:
# a gene is kept if its column sum is positive in at least one table.
nonzero_per_table = [agg.sum(axis=0) > 0 for agg in agg_dfs]
agg_filter = nonzero_per_table[0]
for nonzero in nonzero_per_table[1:]:
    agg_filter = agg_filter | nonzero

agg_dfs = [agg.loc[:, agg_filter] for agg in agg_dfs]

In [None]:
# Axis labels of the tensor built below: genes per time slice, cell types
# (row labels of the first table) and time bins.
gene_names = [agg.columns.tolist() for agg in agg_dfs]
cell_names = agg_dfs[0].index.tolist()
time_names = new_groups
# Displayed only: True if every time slice lists the genes in the same order.
# NOTE(review): the cell-type (row) order is not checked; it is assumed identical across slices.
all(lst == gene_names[0] for lst in gene_names)

In [None]:
# Save the labels so the factors can be interpreted without re-running the notebook.
with open(f'{folder}/tPARAFAC2_tensor_indexes-{date}.pkl', 'wb') as f:
    pickle.dump([cell_names, gene_names, time_names], f)

In [None]:
# Each table is cell types x genes; .T per table and stacking gives
# time x genes x cell types, and the final .T reverses that order.
tensor = tl.tensor([df.values.T for df in agg_dfs]).T
tensor.shape # cell types by genes by time points

In [None]:
# mu = adata.X.mean()
# print(mu)
# tensor = tensor / (tensor + mu) # regularization

# Custom penalty class for matCoupLy

In [None]:
class myTemporalSmoothnessPenalty(MatricesPenalty):
    """Temporal smoothness penalty on the evolving factor matrices of (t)PARAFAC2.

    Penalises differences between consecutive matrices,

        smoothness_l * sum_k ||B_k - B_{k+1}||_F^2,

    and is passed to matcouply's ``cmf_aoadmm`` as an extra constraint on the
    evolving mode (alongside ``Parafac2()``).

    Parameters
    ----------
    smoothness_l : float
        Strength of the smoothness penalty.
    aux_init, dual_init : str
        Initialisation schemes forwarded to ``MatricesPenalty``.
    """

    def __init__(
        self, smoothness_l, aux_init="random_uniform", dual_init="random_uniform"
    ):
        super().__init__(aux_init=aux_init, dual_init=dual_init)
        self.smoothness_l = smoothness_l

    # No ``@copy_ancestor_docstring`` decorator: it is not imported in this
    # notebook, so using it raised NameError when the class was defined.
    def factor_matrices_update(self, factor_matrices, feasibility_penalties, auxes):
        """Return the proximal (auxiliary-variable) update for all K matrices.

        Solves, for Z_1, ..., Z_K,

            argmin  smoothness_l * sum_k ||Z_k - Z_{k+1}||^2
                    + sum_k rho_k / 2 * ||Z_k - X_k||^2,

        where X_k are ``factor_matrices`` (factor + scaled dual) and rho_k are
        ``feasibility_penalties``. Setting the gradient to zero gives, for
        every entry, the same symmetric tridiagonal K x K system: diagonal
        ``rho_k + 4*l`` (``rho_k + 2*l`` at both ends), off-diagonal ``-2*l``.
        It is solved with the Thomas algorithm. ``auxes`` is not needed.

        Returns
        -------
        list
            The K updated auxiliary matrices.
        """
        lam = self.smoothness_l
        n_slices = len(factor_matrices)
        off_diag = -2 * lam
        diag = np.array([4 * lam + feasibility_penalties[k] for k in range(n_slices)], dtype=float)
        # The first and last slices have only one neighbour.
        diag[0] -= 2 * lam
        diag[-1] -= 2 * lam
        rhs = [feasibility_penalties[k] * factor_matrices[k] for k in range(n_slices)]

        # Forward elimination. No pivoting is needed: for rho_k > 0 the system is
        # strictly diagonally dominant.
        for k in range(1, n_slices):
            m = off_diag / diag[k - 1]
            diag[k] -= m * off_diag
            rhs[k] = rhs[k] - m * rhs[k - 1]  # new arrays; inputs are not modified

        # Back substitution.
        new_ZBks = [None] * n_slices
        new_ZBks[-1] = rhs[-1] / diag[-1]
        for k in range(n_slices - 2, -1, -1):
            new_ZBks[k] = (rhs[k] - off_diag * new_ZBks[k + 1]) / diag[k]

        return new_ZBks

    def penalty(self, x):
        """Return ``smoothness_l * sum_k ||x_k - x_{k+1}||^2`` for a sequence of matrices ``x``."""
        penalty = 0
        for x1, x2 in zip(x[:-1], x[1:]):
            penalty += np.sum((x1 - x2) ** 2)
        return self.smoothness_l * penalty

# Model input: the aggregated cell type × gene × time tensor

In [None]:
# Alias used as the model input below (cell types x genes x time points).
data = tensor

In [None]:
data.shape

# Fitting a PARAFAC2 model with AO-ADMM

In [None]:
# Fit `initializations` tPARAFAC2 models with AO-ADMM, each from a different
# random start (random_state=init_no). Every run's factors and diagnostics are
# kept so the best run can be chosen further down.
initializations = 10
rank = 6
factors_list = []
diagnostics_list = []

for init_no in tqdm(range(initializations)):

    # Fresh copy of the tensor for every run.
    input_data = deepcopy(data)

    # `weights` is returned but not stored; only [D, B, A] are kept.
    (weights, (D, B, A)), diagnostics = cmf_aoadmm(
        matrices=input_data.T, # data is cell types x genes x time; .T gives time x genes x cell types, i.e. one genes x cell types matrix per time point
        rank=rank, # No of components
        regs=[
            [NonNegativity()], # Mode-3 constraints
            [Parafac2(),myTemporalSmoothnessPenalty(smoothness_l=200)], # Mode-2 constraints (PARAFAC2 structure + temporal smoothness)
            [NonNegativity()], # Mode-1 constraints
        ],
        l1_penalty=[0, 0, 20], # Lasso penalties for each mode [mode-3,mode-2,mode-1] / sparsity # [0, 0, 20],
        l2_penalty=[20, 0, 0], # Ridge penalties for each mode [mode-3,mode-2,mode-1] / low values
        return_errors=True,
        n_iter_max=8000,
        inner_n_iter_max=5, # inner admm iters
        tol=1e-8,
        absolute_tol=1e-10,
        feasibility_tol=1e-6,
        inner_tol=1e-5,
        verbose=500, # print intermediate run info every 500 iters
        random_state=init_no # seed, so each initialization is reproducible
    )
    factors_list.append([D,B,A])
    diagnostics_list.append(diagnostics)

In [None]:
def plot_convergence(diagnostics_per_init,factors_per_init,zoom_in_to_first_n=50,filename=None):
    '''
    Plot convergence of all initializations in the following format:

    rel_sse | parafac2 constraint feasiblity gap
    -----------------------------------------------
    total_loss | temporal smoothness feasibility gap

    For each iteration, a panel shows the median over initializations (line)
    and the min-max range (shaded). An inset zooms in on the first
    ``zoom_in_to_first_n`` iterations.

    Degenerate cases (``check_degenerate``) are not plotted. If every
    initialization is degenerate, a message is printed and nothing is drawn.
    A panel is skipped if its diagnostic is not available for every iteration,
    e.g. the smoothness panel when no smoothness penalty was used.

    Parameters
    ----------
    diagnostics_per_init : list
        Diagnostics returned by ``cmf_aoadmm(..., return_errors=True)``, one per
        initialization. ``n_iter``, ``rec_errors``, ``feasibility_gaps`` and
        ``regularized_loss`` are read.
    factors_per_init : list
        Factors of each initialization, in the same order.
    zoom_in_to_first_n : int
        Number of leading iterations shown in the insets.
    filename : str or None
        If given, the figure is saved there at 300 dpi.
    '''
    kept = []
    for init_no, factors in enumerate(factors_per_init):
        if check_degenerate(factors):
            print(f'Initialization {init_no} is degenerate and will not be plotted.')
        else:
            kept.append(diagnostics_per_init[init_no])

    if not kept:
        print('All initializations are degenerate; nothing to plot.')
        return

    # Take the x-range from the plotted runs only. Using all runs left trailing
    # iterations with no data when the longest run was degenerate, and min([])
    # then raised ValueError.
    max_iters = max(diag.n_iter for diag in kept)

    panel_keys = ('rec_errors', 'parafac2_feasibility', 'reg_loss', 'smoothness_feasiblity')
    # key -> ([min per iteration], [max per iteration], [median per iteration])
    bands = {key: ([], [], []) for key in panel_keys}

    for iter_no in range(max_iters):
        values = {key: [] for key in panel_keys}
        for diag in kept:
            if iter_no > diag.n_iter or iter_no >= len(diag.rec_errors):
                continue
            values['rec_errors'].append(diag.rec_errors[iter_no])
            # Index 1 = the mode-2 penalties, in regs order: [Parafac2, smoothness].
            mode2_gaps = diag.feasibility_gaps[iter_no][1]
            values['parafac2_feasibility'].append(mode2_gaps[0])
            try:
                values['smoothness_feasiblity'].append(mode2_gaps[1])
            except IndexError:  # model fitted without the smoothness penalty
                pass
            values['reg_loss'].append(diag.regularized_loss[iter_no])

        for key, vals in values.items():
            if vals:
                lows, highs, medians = bands[key]
                lows.append(min(vals))
                highs.append(max(vals))
                medians.append(np.median(vals))

    fig, axs = plt.subplot_mosaic([['rec_errors','parafac2_feasibility'],['reg_loss','smoothness_feasiblity']],figsize=(12,6))

    plt.tight_layout()

    # key -> (colour, title, inset rectangle, inset y-ticks, log-scale inset)
    styles = {
        'rec_errors': ('tab:blue', 'Relative SSE', [0.285, 0.765, 0.2, 0.2], [0.2, 0.4, 0.6, 0.8], False),
        'parafac2_feasibility': ('tab:orange', 'Parafac2 feasibility', [0.778, 0.765, 0.2, 0.2], [0.2, 0.4, 0.6, 0.8], True),
        'reg_loss': ('tab:green', 'Total loss', [0.285, 0.278, 0.2, 0.2], None, True),
        'smoothness_feasiblity': ('tab:red', 'Smoothness feasibility', [0.778, 0.278, 0.2, 0.2], None, True),
    }
    iters = range(max_iters)
    complete = [key for key in panel_keys if len(bands[key][2]) == max_iters]

    # Main panels: median line plus min-max band.
    for key in complete:
        colour, title, _, _, _ = styles[key]
        lows, highs, medians = bands[key]
        axs[key].fill_between(iters, lows, highs, color=colour, alpha=0.2)
        axs[key].plot(iters, medians, color=colour, label=key)
        axs[key].set_title(title)

    # Zoomed-in insets over the first iterations.
    n = zoom_in_to_first_n
    for key in complete:
        colour, _, rect, yticks, log_scale = styles[key]
        lows, highs, medians = bands[key]
        inset = fig.add_axes(rect)
        inset.fill_between(iters[:n], lows[:n], highs[:n], color=colour, alpha=0.2)
        inset.plot(iters[:n], medians[:n], color=colour, label=key)
        if yticks is not None:
            inset.set_yticks(yticks)
        if log_scale:
            inset.set_yscale('log')

    if filename is not None:
        plt.savefig(filename, dpi=300, bbox_inches='tight')
    
from tensorly.cp_tensor import CPTensor
from tlviz.factor_tools import degeneracy_score

# Function for checking degeneracy

def check_degenerate(factors,threshold=-0.85):
    '''
    Return True if a fitted PARAFAC2 solution looks degenerate.

    Thin wrapper around tlviz's ``degeneracy_score``. The evolving matrices
    B_k are stacked vertically into one (K*J) x R matrix so the solution can be
    wrapped in a CPTensor (unit weights). The solution is flagged when the score
    falls below ``threshold``.

    Parameters
    ----------
    factors : sequence
        Factors as stored in ``factors_list`` in this notebook, ``[D, B, A]``,
        where ``B`` is a list of per-time-point matrices.
    threshold : float
        Score below which the solution counts as degenerate.

    Returns
    -------
    bool
    '''
    first, B, last = factors[0], factors[1], factors[2]

    # Stacking B_0, ..., B_{K-1} row-wise gives, column by column, the
    # concatenation of each component's B_k[:, r] across k.
    stacked_B = np.concatenate([np.asarray(B_k) for B_k in B], axis=0)

    # Take the number of components from the factors, not from the global `rank`,
    # so models of any rank can be checked.
    n_components = stacked_B.shape[1]
    decomp = CPTensor((np.ones(n_components), (first, stacked_B, last)))

    return bool(degeneracy_score(decomp) < threshold)

In [None]:
# Convergence summary over all initializations; the insets zoom in on the first 10 iterations.
plot_convergence(diagnostics_list,factors_list,10, f'{folder}/{date}-Plot-Convergence.pdf')


# Choose the best run according to loss

In [None]:
# Choose the initialization with the lowest final regularized loss across ALL runs.
# (The earlier loop ranged over len(best_factors) == 3 factor matrices instead
# of the runs, and never updated `best_error`, so it could miss the best run.)
# NOTE(review): degenerate runs (see check_degenerate) are not excluded here.
final_losses = [diag.regularized_loss[-1] for diag in diagnostics_list]
best_init = int(np.argmin(final_losses))
best_factors = factors_list[best_init]
best_error = final_losses[best_init]
best_errors = best_error  # same value under the name the earlier version assigned
print(f'Best initialization: {best_init} (final regularized loss {best_error:.6g})')

In [None]:
import pickle
# Save the selected factors [D, B, A] of the best initialization
with open(f'{folder}/tPARAFAC2_factors-{date}.pkl', 'wb') as f:
    pickle.dump(best_factors, f)

In [None]:
import pickle

# Reload the factors so the cells below can be re-run without refitting.
# NOTE: pickle.load can execute arbitrary code; only load files this notebook wrote.
with open(f'{folder}/tPARAFAC2_factors-{date}.pkl', 'rb') as f:
    best_factors = pickle.load(f)

In [None]:
# Same [D, B, A] order as appended in the fitting loop: D / B / A correspond to
# the mode-3 / mode-2 / mode-1 entries of `regs`; B is one matrix per time point.
D, B, A = best_factors

In [None]:
D.shape, len(B), B[0].shape, A.shape

In [None]:
def plot_Bs(est_B, J, K, rank, filename, time_labels=None):
    """
    Plot the evolving B factors of a PARAFAC2 model as one heatmap per component.

    For each component r, the columns B_0[:, r], ..., B_{K-1}[:, r] are scaled
    so that their concatenation has unit Euclidean norm (all-zero components
    are left as zeros). The scaling is applied to a copy; ``est_B`` itself is
    not modified. Each panel shows genes (rows) x time points (columns).

    Parameters:
    est_B (list of np.array): K factor matrices, each J x rank
    J (int): Number of variables (genes)
    K (int): Number of time points
    rank (int): Number of components
    filename (str or None): If given, the figure is saved there at 300 dpi
    time_labels (list of str, optional): One x tick label per time point;
        defaults to the global ``new_groups``.
    """
    if time_labels is None:
        time_labels = new_groups

    # Work on copies so that plotting has no side effects on the fitted factors.
    est_B = [np.array(B_k, dtype=float, copy=True) for B_k in est_B]

    for r in range(rank):
        norm = np.linalg.norm(np.concatenate([est_B[k][:, r] for k in range(K)]))
        if norm > 0:  # avoid 0/0 -> NaN for an all-zero component
            for k in range(K):
                est_B[k][:, r] = est_B[k][:, r] / norm

    fig, axes = plt.subplots(1, rank, figsize=(4*rank, 10))
    fig.suptitle('B factors', fontsize=40)
    plt.subplots_adjust(wspace=0.3, top=0.85)

    for pattern_no in range(rank):
        B_2_plot = form_plotting_B(est_B, pattern_no, J, K)

        ax = axes[pattern_no] if rank > 1 else axes
        sns.heatmap(B_2_plot.T,
                    ax=ax,
                    cbar=(pattern_no == rank-1),  # colour bar on the last panel only
                    cmap='Reds',
                   )

        ax.tick_params(left=False, bottom=True)
        ax.patch.set_edgecolor('black')
        ax.patch.set_linewidth(1.5)
        # One tick per heatmap column (K time points); the count no longer
        # depends on the global `adata`, whose unique labels may include None.
        ax.set_xticks(np.arange(0.5, K + 0.5, 1), labels=time_labels, fontsize=3.5, rotation=90)
        ax.set_xlabel(r'time',fontsize=16)
        ax.set_ylabel(r'genes',fontsize=16)
        ax.set_yticks([])
        ax.set_title(f'Factor {pattern_no+1}',pad=3.5,fontsize=36)

    plt.tight_layout()
    if filename is not None:
        plt.savefig(filename, dpi=300, bbox_inches='tight')

def form_plotting_B(B_list, pattern_no, J, K):
    """
    Collect one component from the first K evolving factor matrices.

    Returns a K x J float array whose row k is column ``pattern_no`` of
    ``B_list[k]``.
    """
    out = np.zeros((K, J))
    for k, row in enumerate(out):
        row[:] = B_list[k][:, pattern_no]
    return out

In [None]:
def euclidean_normalization(matrix, axis=0):
    """Scale ``matrix`` to unit Euclidean norm along ``axis``.

    With the default ``axis=0`` each *column* is divided by its norm; use
    ``axis=1`` for rows. Columns (or rows) whose norm is zero stay zero
    instead of becoming NaN.

    Parameters
    ----------
    matrix : array-like
        2-D ndarray (a DataFrame also works; labels are kept by the division).
    axis : int
        Axis along which norms are computed.

    Returns
    -------
    ``matrix`` divided by its (zero-safe) norms.
    """
    norms = np.linalg.norm(matrix, axis=axis, keepdims=True)
    # Divide all-zero vectors by 1 so they stay zero rather than 0/0 = NaN.
    safe_norms = np.where(norms == 0, 1, norms)
    return matrix / safe_norms

In [None]:
plot_Bs(B, data.shape[1], data.shape[2], rank, f'{folder}/Plot-Factors-B.pdf')

In [None]:
mode1_df = pd.DataFrame(A,index=agg_dfs[0].index, columns=[f'Factor {i+1}' for i in range(rank)])
sns.heatmap(mode1_df, cmap='Reds', vmin=0)
plt.savefig(f'{folder}/Plot-Factors-A.pdf', dpi=300, bbox_inches='tight')

In [None]:
def vmax_func(data):
    """Return the largest absolute value in ``data`` (upper limit of a zero-centred colour scale)."""
    magnitude = np.abs(data)
    return magnitude.max()


def vmin_func(data):
    """Return minus the largest absolute value in ``data`` (lower limit of a zero-centred colour scale)."""
    return -vmax_func(data)

In [None]:
# Per component: genes whose peak |loading| over time is at or above the 99th
# percentile, ordered from the strongest loading down.
factor_genes = {}
percentile = 99
for factor in range(1, rank+1):
    # genes x time-bin loadings of this component
    plot_df = pd.DataFrame(form_plotting_B(B, factor-1, tensor.shape[1], tensor.shape[2]).T,
                           index=agg_dfs[0].columns, columns=new_groups)
    peak_loading = plot_df.abs().max(axis=1)
    cutoff = np.percentile(peak_loading.values, percentile)
    ranked = peak_loading.sort_values(ascending=False)
    factor_genes[f'Factor {factor}'] = ranked[ranked >= cutoff].index.tolist()

In [None]:
factor_sum = {'shared' : ['Factor 5', 'Factor 6'],
              'mes_CYGB' : ['Factor 1'], 
              'mes_KCNB2' : ['Factor 2', 'Factor 3', 'Factor 4']}

In [None]:
from collections import defaultdict

factor_sum_genes = defaultdict(set)
total_genes = set()

for case in ['shared', 'mes_CYGB', 'mes_KCNB2']:
    for f in factor_sum[case]:
        for g in factor_genes[f]:
            if g not in factor_sum_genes['shared']:
                factor_sum_genes[case].add(g)
            total_genes.add(g)

In [None]:
len(total_genes)

In [None]:
ct1 = 'mes_CYGB'
ct2 = 'mes_KCNB2'
t1 = pd.concat([df.loc[[ct1], list(total_genes)].rename(index={ct1 : '_'.join([ct1, month])}) for df, month in zip(agg_dfs, new_groups)])
t2 = pd.concat([df.loc[[ct2], list(total_genes)].rename(index={ct2 : '_'.join([ct2, month])}) for df, month in zip(agg_dfs, new_groups)])
plot_df = pd.concat([t1, t2]).T.drop_duplicates()

cm = sns.clustermap(plot_df, cmap='Reds', col_cluster=False, row_cluster=True, yticklabels=1, figsize=(9, 30), standard_scale=0,
                    cbar_kws={"shrink": 0.5, "label": "Scaled Expression", "location" : "left"},  # Reduce colorbar size
                    dendrogram_ratio=(.1, .2),  # Adjust dendrogram sizes (left, top)
                    cbar_pos=(0.02, 0.85, .03, .1),  # Adjust colorbar position (left, bottom, width, height),
                    method='ward'
                   )


cm.ax_heatmap.set_xticklabels(cm.ax_heatmap.get_xmajorticklabels(), fontsize = 16)
cm.ax_heatmap.set_yticklabels(cm.ax_heatmap.get_ymajorticklabels(), fontsize = 6)

plt.savefig(f'{folder}/Plot-Expression-All.pdf', dpi=300, bbox_inches='tight')

In [None]:
plot_df.to_csv(f'{folder}/DataFrame-tPARAFAC.csv')

In [None]:
plot_df.shape

In [None]:
ct1 = 'mes_CYGB'
ct2 = 'mes_KCNB2'
# Genes of the shared factors, sorted for a reproducible order (set iteration
# order varies between Python sessions).
genes_to_plot = sorted(factor_sum_genes['shared'])
t1 = pd.concat([df.loc[[ct1], genes_to_plot].rename(index={ct1 : '_'.join([ct1, month])}) for df, month in zip(agg_dfs, new_groups)])
t2 = pd.concat([df.loc[[ct2], genes_to_plot].rename(index={ct2 : '_'.join([ct2, month])}) for df, month in zip(agg_dfs, new_groups)])
plot_df = pd.concat([t1, t2]).T.drop_duplicates()

cm = sns.clustermap(plot_df, cmap='Reds', col_cluster=False, row_cluster=True, yticklabels=1, figsize=(9, 24), standard_scale=0,
                    cbar_kws={"shrink": 0.5, "label": "Scaled Expression", "location" : "left"},  # Reduce colorbar size
                    dendrogram_ratio=(.1, .2),  # Adjust dendrogram sizes (left, top)
                    cbar_pos=(0.02, 0.85, .03, .1),  # Adjust colorbar position (left, bottom, width, height)
                   )


cm.ax_heatmap.set_xticklabels(cm.ax_heatmap.get_xmajorticklabels(), fontsize = 16)
cm.ax_heatmap.set_yticklabels(cm.ax_heatmap.get_ymajorticklabels(), fontsize = 6)

plt.savefig(f'{folder}/Plot-Expression-shared.pdf', dpi=300, bbox_inches='tight')

In [None]:
ct1 = 'mes_CYGB'
ct2 = 'mes_KCNB2'
# Genes specific to ct1's factors, shown for both cell types; sorted for a
# reproducible order.
genes_to_plot = sorted(factor_sum_genes[ct1])
t1 = pd.concat([df.loc[[ct1], genes_to_plot].rename(index={ct1 : '_'.join([ct1, month])}) for df, month in zip(agg_dfs, new_groups)])
t2 = pd.concat([df.loc[[ct2], genes_to_plot].rename(index={ct2 : '_'.join([ct2, month])}) for df, month in zip(agg_dfs, new_groups)])
plot_df = pd.concat([t1, t2]).T.drop_duplicates()

cm = sns.clustermap(plot_df, cmap='Reds', col_cluster=False, row_cluster=True, yticklabels=1, figsize=(9, 24), standard_scale=0,
                    cbar_kws={"shrink": 0.5, "label": "Scaled Expression", "location" : "left"},  # Reduce colorbar size
                    dendrogram_ratio=(.1, .2),  # Adjust dendrogram sizes (left, top)
                    cbar_pos=(0.02, 0.85, .03, .1),  # Adjust colorbar position (left, bottom, width, height)
                   )


cm.ax_heatmap.set_xticklabels(cm.ax_heatmap.get_xmajorticklabels(), fontsize = 16)
cm.ax_heatmap.set_yticklabels(cm.ax_heatmap.get_ymajorticklabels(), fontsize = 6)

# Name the file after the cell type it describes (the previous name,
# 'thy_Lumen-forming', was a leftover from the thyroid analysis).
plt.savefig(f'{folder}/Plot-Expression-{ct1}.pdf', dpi=300, bbox_inches='tight')

In [None]:
ct1 = 'mes_CYGB'
ct2 = 'mes_KCNB2'
# Genes specific to ct2's factors, shown for both cell types; sorted for a
# reproducible order.
genes_to_plot = sorted(factor_sum_genes[ct2])
t1 = pd.concat([df.loc[[ct1], genes_to_plot].rename(index={ct1 : '_'.join([ct1, month])}) for df, month in zip(agg_dfs, new_groups)])
t2 = pd.concat([df.loc[[ct2], genes_to_plot].rename(index={ct2 : '_'.join([ct2, month])}) for df, month in zip(agg_dfs, new_groups)])
plot_df = pd.concat([t1, t2]).T.drop_duplicates()

cm = sns.clustermap(plot_df, cmap='Reds', col_cluster=False, row_cluster=True, yticklabels=1, figsize=(9, 24), standard_scale=0,
                    cbar_kws={"shrink": 0.5, "label": "Scaled Expression", "location" : "left"},  # Reduce colorbar size
                    dendrogram_ratio=(.1, .2),  # Adjust dendrogram sizes (left, top)
                    cbar_pos=(0.02, 0.85, .03, .1),  # Adjust colorbar position (left, bottom, width, height)
                   )


cm.ax_heatmap.set_xticklabels(cm.ax_heatmap.get_xmajorticklabels(), fontsize = 16)
cm.ax_heatmap.set_yticklabels(cm.ax_heatmap.get_ymajorticklabels(), fontsize = 6)

# Name the file after the cell type it describes (the previous name,
# 'thy_TH_processing', was a leftover from the thyroid analysis).
plt.savefig(f'{folder}/Plot-Expression-{ct2}.pdf', dpi=300, bbox_inches='tight')

In [None]:
# Persist the gene sets per group ('shared' and one per cell type).
with open(f'{folder}/Genes-{date}.pkl', 'wb') as f:
    pickle.dump(factor_sum_genes, f)

In [None]:
# Spot check (displayed only): is SLC5A5 among the selected genes?
# NOTE(review): this check looks carried over from the thyroid analysis - confirm it is relevant here.
'SLC5A5' in [g for v in factor_sum_genes.values() for g in v]