In [40]:
def _read_gene_list(path):
    """Read a one-identifier-per-line text file, skipping blank lines."""
    with open(path, 'r') as handle:
        stripped_lines = (line.strip() for line in handle)
        return [gene for gene in stripped_lines if gene]


def _detect_expression_columns(expression_df):
    """Guess the gene, expression and sample columns by name, then by content.

    Returns a ``(gene_col, expr_col, sample_col)`` tuple; an entry is ``None``
    when no column could be identified for that role.
    """
    # Role keyword -> exact (lower-cased) column name that takes precedence.
    exact_names = {'gene': 'gene', 'expression': 'expression_level', 'sample': 'sample'}
    found = {role: None for role in exact_names}
    exact_roles = set()

    # Pass 1: match on column names. A column serves at most one role
    # (gene > expression > sample), and an exact name is never overridden by
    # a later column that merely contains the keyword.
    for col in expression_df.columns:
        lowered = str(col).lower()
        for role in ('gene', 'expression', 'sample'):
            if role not in lowered:
                continue
            is_exact = lowered == exact_names[role]
            if is_exact or role not in exact_roles:
                found[role] = col
            if is_exact:
                exact_roles.add(role)
            break

    # Pass 2: inspect the values of the columns that are still unclaimed.
    if any(col is None for col in found.values()):
        for col in expression_df.columns:
            claimed = [c for c in found.values() if c is not None]
            if col in claimed:
                continue

            # Only the first 10 values are inspected for the string heuristics.
            preview = expression_df[col].astype(str).head(10).tolist()

            # ENSMUSG identifiers mark the gene column.
            if found['gene'] is None and any('ENSMUSG' in val for val in preview):
                found['gene'] = col
                continue

            # Sample IDs look like "01_Ctrl_morning".
            if (
                found['sample'] is None
                and any('_' in val for val in preview)
                and any('morning' in val.lower() or 'evening' in val.lower() for val in preview)
            ):
                found['sample'] = col
                continue

            # A fully numeric column is taken as the expression values.
            # "Unnamed: N" is the label pandas gives header-less columns
            # (typically a saved row index), so those are not considered.
            if found['expression'] is None and not str(col).startswith('Unnamed:'):
                try:
                    pd.to_numeric(expression_df[col], errors='raise')
                except (ValueError, TypeError):
                    continue
                found['expression'] = col

    return found['gene'], found['expression'], found['sample']


def create_gene_expression_heatmap(expression_tsv, all_genes_file, treatment_genes_file, timepoint_genes_file, rest_genes_file, output_dir):
    """
    Create a sample-by-gene expression heatmap with treatment/timepoint row
    bands and a gene-group column band, and save it together with a separate
    legend figure.

    The original note says the layout follows ``combined_heatmap_unified``
    (that function is not visible in this cell).

    Parameters
    ----------
    expression_tsv : str or path-like or file-like
        Long-format expression table with one row per (gene, sample). Read as
        tab-separated when the file name ends in ``.tsv``/``.tab``, otherwise
        with the ``pd.read_csv`` default (comma).
    all_genes_file : str
        Text file, one gene ID per line: the genes to plot.
    treatment_genes_file, timepoint_genes_file : str
        Text files, one gene ID per line: genes labelled 'Treatment' and
        'Timepoint' in the column band (treatment wins on overlap).
    rest_genes_file : str
        Text file, one gene ID per line. Only its size is reported; every
        plotted gene that is in neither of the two sets above is labelled 'Other'.
    output_dir : str
        Directory (created if needed) that receives
        ``gene_expression_heatmap.png`` and ``gene_expression_legend.png``.

    Returns
    -------
    dict
        ``'expression_pivot'`` (samples x genes DataFrame, genes ordered by
        group), ``'metadata'`` (Sample/Treatment/Timepoint DataFrame) and
        ``'gene_groups'`` (gene -> group name).

    Raises
    ------
    ValueError
        If the gene/expression/sample columns cannot be identified, or if none
        of the requested genes occur in the expression table.
    """
    print("Reading input files...")

    # Read the gene expression data. Only switch to tab separation when the
    # file name explicitly says so; everything else is read as before.
    is_path = isinstance(expression_tsv, (str, os.PathLike))
    if is_path and str(expression_tsv).lower().endswith(('.tsv', '.tab')):
        expression_df = pd.read_csv(expression_tsv, sep='\t')
    else:
        expression_df = pd.read_csv(expression_tsv)

    # Display the first few rows to understand the structure
    print("Data structure:")
    print(expression_df.head())
    print(f"Columns: {expression_df.columns.tolist()}")

    gene_col, expr_col, sample_col = _detect_expression_columns(expression_df)

    # Print which columns were found and their data types
    print(f"Detected columns:")
    print(f"  - Gene column: {gene_col} (type: {expression_df[gene_col].dtype if gene_col is not None else 'N/A'})")
    print(f"  - Expression column: {expr_col} (type: {expression_df[expr_col].dtype if expr_col is not None else 'N/A'})")
    print(f"  - Sample column: {sample_col} (type: {expression_df[sample_col].dtype if sample_col is not None else 'N/A'})")

    if gene_col is None or expr_col is None or sample_col is None:
        # Fall back to the layout <index>, Gene, Expression, Sample.
        print("Could not detect all column names automatically. Using positional mapping.")
        if len(expression_df.columns) < 4:
            # The positional layout reads columns 1..3, so four columns are required.
            raise ValueError(
                "Could not identify the gene, expression and sample columns; "
                f"found columns: {expression_df.columns.tolist()}"
            )
        column_mapping = {
            expression_df.columns[1]: 'Gene',               # Second column is Gene
            expression_df.columns[2]: 'Expression_level',   # Third column is Expression
            expression_df.columns[3]: 'Sample'              # Fourth column is Sample
        }
    else:
        # Rename detected columns to standard names
        column_mapping = {
            gene_col: 'Gene',
            expr_col: 'Expression_level',
            sample_col: 'Sample'
        }
    expression_df = expression_df.rename(columns=column_mapping)

    # Read the list of genes to include and the gene groups
    all_genes = _read_gene_list(all_genes_file)
    treatment_genes = set(_read_gene_list(treatment_genes_file))
    timepoint_genes = set(_read_gene_list(timepoint_genes_file))
    rest_genes = set(_read_gene_list(rest_genes_file))

    # Print gene group statistics for validation
    print(f"Gene group statistics:")
    print(f"  - All genes: {len(all_genes)} genes")
    print(f"  - Treatment genes: {len(treatment_genes)} genes")
    print(f"  - Timepoint genes: {len(timepoint_genes)} genes")
    print(f"  - Rest genes: {len(rest_genes)} genes")

    # Check for overlaps between groups (should be none)
    treatment_timepoint_overlap = treatment_genes.intersection(timepoint_genes)
    if treatment_timepoint_overlap:
        print(f"Warning: {len(treatment_timepoint_overlap)} genes appear in both treatment and timepoint groups.")
        print(f"Examples: {sorted(treatment_timepoint_overlap)[:3]}")

    print(f"Processing {len(all_genes)} genes...")

    # Filter the expression data to only include the genes of interest
    filtered_expr = expression_df[expression_df['Gene'].isin(all_genes)]

    # Check if we have data for all genes
    found_genes = set(filtered_expr['Gene'].unique())
    missing_genes = set(all_genes) - found_genes

    if missing_genes:
        print(f"Warning: {len(missing_genes)} genes from the input list were not found in the expression data.")
        print(f"First few missing genes: {sorted(missing_genes)[:5]}")

    if filtered_expr.empty:
        raise ValueError(
            "None of the requested genes were found in the expression data; nothing to plot."
        )

    # Create a pivot table: rows are samples, columns are genes, values are expression levels
    pivot_df = filtered_expr.pivot_table(
        index='Sample',
        columns='Gene',
        values='Expression_level',
        aggfunc='first'  # In case of duplicates, take the first value
    )

    # Check for missing values
    missing_values = pivot_df.isna().sum().sum()
    if missing_values > 0:
        print(f"Warning: {missing_values} missing values in the pivot table. Filling with zeros.")
        pivot_df = pivot_df.fillna(0)

    # Extract metadata from sample names
    # Assuming sample names are in format like "01_Ctrl_morning"
    metadata_df = pd.DataFrame(index=pivot_df.index)
    sample_parts = metadata_df.index.str.split('_')

    if all(len(parts) >= 3 for parts in sample_parts):
        # Standard format with at least 3 parts: <id>_<treatment>_<timepoint>
        metadata_df['Treatment'] = sample_parts.str[1]
        metadata_df['Timepoint'] = sample_parts.str[2]
    else:
        # Try a different approach - look for morning/evening keywords
        metadata_df['Timepoint'] = 'Unknown'
        metadata_df.loc[metadata_df.index.str.contains('morning', case=False), 'Timepoint'] = 'morning'
        metadata_df.loc[metadata_df.index.str.contains('evening', case=False), 'Timepoint'] = 'evening'

        # For treatment, check common keywords (later matches override earlier ones)
        metadata_df['Treatment'] = 'Unknown'
        metadata_df.loc[metadata_df.index.str.contains('ctrl', case=False), 'Treatment'] = 'Ctrl'
        metadata_df.loc[metadata_df.index.str.contains('control', case=False), 'Treatment'] = 'Ctrl'
        metadata_df.loc[metadata_df.index.str.contains('treat', case=False), 'Treatment'] = 'Treated'
        metadata_df.loc[metadata_df.index.str.contains('crs', case=False), 'Treatment'] = 'CRS'

    # Print sample metadata statistics
    print(f"Sample metadata:")
    print(f"  - Treatment values: {sorted(metadata_df['Treatment'].unique())}")
    print(f"  - Timepoint values: {sorted(metadata_df['Timepoint'].unique())}")

    # Reset index to make 'Sample' a regular column
    metadata_df = metadata_df.reset_index()

    # Prepare the gene groups for coloring (treatment takes precedence on overlap)
    gene_groups = {}
    for gene in pivot_df.columns:
        if gene in treatment_genes:
            gene_groups[gene] = 'Treatment'
        elif gene in timepoint_genes:
            gene_groups[gene] = 'Timepoint'
        else:
            gene_groups[gene] = 'Other'

    # Validate gene group assignments
    group_order = ['Treatment', 'Timepoint', 'Other']
    group_counts = {group: 0 for group in group_order}
    for group in gene_groups.values():
        group_counts[group] += 1

    print(f"Genes assigned to groups:")
    for group, count in group_counts.items():
        print(f"  - {group}: {count} genes")

    # Reorder columns by group first, then by gene ID
    group_order_dict = {group: i for i, group in enumerate(group_order)}
    column_group_df = pd.DataFrame({
        'gene': list(pivot_df.columns),
        'group': [gene_groups.get(gene, 'Other') for gene in pivot_df.columns]
    })
    column_group_df['group_order'] = column_group_df['group'].map(group_order_dict)
    column_group_df = column_group_df.sort_values(['group_order', 'gene'])

    ordered_genes = column_group_df['gene'].tolist()
    pivot_df = pivot_df[ordered_genes]

    # Sort samples by metadata factors
    sample_metadata = metadata_df.set_index('Sample')
    sorted_df = pivot_df.loc[sample_metadata.sort_values(['Treatment', 'Timepoint']).index]

    # Row annotations in the same order as the heatmap rows
    row_colors = sample_metadata.loc[sorted_df.index, ['Treatment', 'Timepoint']]

    # Define fixed colors for treatments and timepoints
    treatment_levels = sorted(row_colors['Treatment'].unique())
    timepoint_levels = sorted(row_colors['Timepoint'].unique())
    treatment_palette = dict(zip(treatment_levels, sns.color_palette("Set1", len(treatment_levels))))
    timepoint_palette = dict(zip(timepoint_levels, sns.color_palette("Set2", len(timepoint_levels))))

    # Define group colors
    group_colors = {
        'Treatment': "#1f77b4",  # Blue
        'Timepoint': "#ff7f0e",  # Orange
        'Other': "#7f7f7f"       # Gray
    }

    # Apply color palettes
    row_colors_mapped = pd.DataFrame({
        'Treatment': row_colors['Treatment'].map(treatment_palette),
        'Timepoint': row_colors['Timepoint'].map(timepoint_palette)
    })

    # Create column colors based on gene groups
    col_colors = pd.Series({
        gene: group_colors.get(gene_groups.get(gene, 'Other'), "#7f7f7f")
        for gene in ordered_genes
    })

    # Create the plot
    print("Creating heatmap...")

    # Figure size grows with the number of genes (width) and samples (height)
    fig_width = max(12, len(ordered_genes) * 0.3) + 2
    fig_height = max(8, len(sorted_df) * 0.3)
    combined_fig = plt.figure(figsize=(fig_width, fig_height))

    # Layout: two narrow row-color bands, the heatmap, and the colorbar
    row_colors_width = 0.05  # 5% for each row color band
    heatmap_width = 1 - (row_colors_width * 2) - 0.15  # 15% for the colorbar
    gs = plt.GridSpec(1, 4, width_ratios=[row_colors_width, row_colors_width, heatmap_width, 0.15])

    treatment_ax = combined_fig.add_subplot(gs[0, 0])
    timepoint_ax = combined_fig.add_subplot(gs[0, 1])
    heatmap_ax = combined_fig.add_subplot(gs[0, 2])
    cbar_ax = combined_fig.add_subplot(gs[0, 3])

    # Draw one colored cell per sample in each band
    for position, (_, colors) in enumerate(row_colors_mapped.iterrows()):
        treatment_ax.add_patch(plt.Rectangle((0, position), 1, 1, color=colors['Treatment']))
        timepoint_ax.add_patch(plt.Rectangle((0, position), 1, 1, color=colors['Timepoint']))

    # Set axes properties for color bands
    for ax, title in [(treatment_ax, 'Treatment'), (timepoint_ax, 'Timepoint')]:
        ax.set_xlim(0, 1)
        # sns.heatmap inverts its y axis so the first row is drawn at the top;
        # use inverted limits here so each band cell lines up with its sample row.
        ax.set_ylim(len(sorted_df), 0)
        ax.set_xticks([0.5])
        ax.set_xticklabels([title])
        ax.set_yticks([])
        for side in ('top', 'right', 'bottom', 'left'):
            ax.spines[side].set_visible(False)

    # Draw the heatmap in the main section
    sns.heatmap(
        sorted_df,
        cmap="viridis",
        ax=heatmap_ax,
        cbar_ax=cbar_ax,
        cbar_kws={"label": "Expression Level"},
        xticklabels=False,   # No per-gene tick labels
        yticklabels=False    # No per-sample tick labels
    )

    # Remove all spines and ticks
    for spine in heatmap_ax.spines.values():
        spine.set_visible(False)

    heatmap_ax.tick_params(
        axis='both',
        which='both',
        bottom=False,
        top=False,
        left=False,
        right=False
    )

    # Add column colors as a thin strip directly above the heatmap
    heatmap_box = heatmap_ax.get_position()
    col_colors_ax = combined_fig.add_axes([
        heatmap_box.x0,
        heatmap_box.y1,
        heatmap_box.width,
        0.02
    ])

    for i, gene in enumerate(sorted_df.columns):
        col_colors_ax.add_patch(plt.Rectangle(
            (i, 0),
            1.0,
            1.0,
            color=col_colors.get(gene, "#7f7f7f")
        ))

    # Set column color axes properties
    col_colors_ax.set_xlim(0, len(sorted_df.columns))
    col_colors_ax.set_ylim(0, 1)
    col_colors_ax.set_xticks([])
    col_colors_ax.set_yticks([])
    for side in ('top', 'right', 'bottom', 'left'):
        col_colors_ax.spines[side].set_visible(False)

    # Add title to the combined figure
    combined_fig.suptitle(f"Gene Expression Heatmap ({len(ordered_genes)} genes)",
                          fontsize=16, y=0.98)

    # Create a separate figure for the legend
    legend_fig = plt.figure(figsize=(3, 6))
    legend_ax = legend_fig.add_subplot(111)
    legend_ax.axis('off')

    # Create legend for all color elements
    legend_handles = []
    legend_labels = []

    for label in treatment_levels:
        legend_handles.append(plt.Rectangle((0, 0), 1, 1, color=treatment_palette[label]))
        legend_labels.append(f"Treatment: {label}")

    for label in timepoint_levels:
        legend_handles.append(plt.Rectangle((0, 0), 1, 1, color=timepoint_palette[label]))
        legend_labels.append(f"Timepoint: {label}")

    for group, color in group_colors.items():
        legend_handles.append(plt.Rectangle((0, 0), 1, 1, color=color))
        legend_labels.append(f"Gene Group: {group}")

    legend_ax.legend(
        legend_handles,
        legend_labels,
        loc='center',
        frameon=True
    )

    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Save the figures
    combined_fig.savefig(f"{output_dir}/gene_expression_heatmap.png",
                         dpi=300, bbox_inches='tight')
    legend_fig.savefig(f"{output_dir}/gene_expression_legend.png",
                       dpi=300, bbox_inches='tight')

    # Close only the figures created here so the caller's open figures survive
    plt.close(combined_fig)
    plt.close(legend_fig)

    print(f"Heatmap completed. Files saved to:")
    print(f"  - {output_dir}/gene_expression_heatmap.png")
    print(f"  - {output_dir}/gene_expression_legend.png")

    # Return the dataframes for reference
    return {
        'expression_pivot': pivot_df,
        'metadata': metadata_df,
        'gene_groups': gene_groups
    }

In [41]:
# Show the kernel's current working directory; relative paths such as '../data_expr' are resolved from here.
! pwd

/cluster/home/taekim/stressed_mice/jupyter_notebooks


In [42]:
# Example usage: every argument is passed by keyword so each path's role is explicit.
result = create_gene_expression_heatmap(
    expression_tsv='../data_expr/CRS_Morning_Evening_TPM_rearr.csv',
    all_genes_file='../sig_genes/all_genes.txt',
    treatment_genes_file='../sig_genes/treatment_only_genes.txt',
    timepoint_genes_file='../sig_genes/timepoint_only_genes.txt',
    rest_genes_file='../sig_genes/rest_genes.txt',
    output_dir='../images/expr',
)

Reading input files...
Data structure:
   Unnamed: 0                Gene  Expression_level           Sample
0           0  ENSMUSG00000000001         15.613475  01_Ctrl_morning
1           1  ENSMUSG00000000003          0.000000  01_Ctrl_morning
2           2  ENSMUSG00000000028          0.775004  01_Ctrl_morning
3           3  ENSMUSG00000000031          0.099366  01_Ctrl_morning
4           4  ENSMUSG00000000037          0.403567  01_Ctrl_morning
Columns: ['Unnamed: 0', 'Gene', 'Expression_level', 'Sample']
Detected columns:
  - Gene column: Gene (type: object)
  - Expression column: Expression_level (type: float64)
  - Sample column: Sample (type: object)
Gene group statistics:
  - All genes: 65 genes
  - Treatment genes: 29 genes
  - Timepoint genes: 21 genes
  - Rest genes: 15 genes
Processing 65 genes...
First few missing genes: ['ENSMUSG00000106018', 'ENSMUSG00000102642']
Sample metadata:
  - Treatment values: ['CRS', 'Ctrl']
  - Timepoint values: ['evening', 'morning']
Genes as