In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import pybedtools
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from pathlib import Path
from dotenv import load_dotenv
import pickle
import re

# Load variables from a `.env` file into the process environment; DATA_PATH and
# OUTPUT_PATH are read from the environment further down.
load_dotenv()


In [None]:
# Load the preprocessed cisTopic object (presumably a pycisTopic CistopicObject).
# The `with` block guarantees the file handle is closed after unpickling.
# NOTE: pickle.load can execute arbitrary code — only load pickles from a trusted source.
CISTOPIC_PKL = '/mnt/windows/extradata/meiotic_cells/atac_preprocessing/cistopic_obj.pkl'
with open(CISTOPIC_PKL, "rb") as fh:
    cistopic_obj = pickle.load(fh)


In [None]:
# os.environ[...] raises a KeyError that names the missing variable (e.g. when no .env was
# found), instead of the opaque TypeError that Path(None) would raise.
DATA_PATH = Path(os.environ["DATA_PATH"]) / 'garcia_ATAC'
OUTPUT_PATH = Path(os.environ["OUTPUT_PATH"]) / 'garcia_ATAC'

# NOTE(review): presumably the portable location of the pickle that the load cell above
# reads from a hard-coded absolute path — confirm and switch that cell to this path.
#cistopic_obj = pickle.load(open(OUTPUT_PATH / 'atac_preprocessing/cistopic_obj.pkl', "rb"))

In [None]:
# Restrict the object to cells whose cell-type annotation is not missing.
annotated = cistopic_obj.cell_data['celltype'].notna()
has_celltype = cistopic_obj.cell_data.index[annotated.to_numpy()].tolist()
cistopic_obj = cistopic_obj.subset(cells=has_celltype, copy=True)

In [None]:
# NOTE(review): assumed regions x cells, rows ordered like region_names — confirm.
fragment_matrix = cistopic_obj.fragment_matrix

# Look annotations up by cell name so cell_types follows the matrix column order
# explicitly, rather than relying on cell_data happening to be in the same row order.
celltype_per_cell = cistopic_obj.cell_data.loc[cistopic_obj.cell_names, 'celltype']
cell_types = np.where(
    celltype_per_cell.isin(['oogonia_STRA8', 'oogonia_meiotic']), "meiotic", "non-meiotic"
)
#cell_types = "All"
peaks = cistopic_obj.region_names

# Region names are "chr:start-end" (or "chr-start-end"). Anchor on the trailing
# "<sep>start-end" so contig names that themselves contain '-' or ':' still parse,
# and fail loudly instead of building a malformed BED.
region_series = pd.Series(peaks)
peaks_df = region_series.str.extract(r'^(?P<chr>.+)[:-](?P<start>\d+)-(?P<end>\d+)$')
unparsed = peaks_df.isna().any(axis=1).to_numpy()
if unparsed.any():
    raise ValueError(f"Could not parse region names, e.g. {region_series[unparsed].head().tolist()}")
peaks_df = peaks_df.astype({'start': int, 'end': int})

peaks_bed = pybedtools.BedTool.from_dataframe(peaks_df)
peaks_bed.head()

In [None]:
# EPDnew promoter annotation (file name suggests windows of 900 bp upstream / 400 bp
# downstream of each TSS — confirm against the EPDnew download page).
promoter_bed_file = DATA_PATH / 'feature_annotation/Hs_EPDnew_006_hg38_900up400down.bed'
# Try with this too: Hs_EPDnew_006_hg38.bed
# TSSs are just a single point there.
# Pass a str: not every pybedtools version accepts pathlib.Path for a file name.
promoter_bed = pybedtools.BedTool(str(promoter_bed_file))
promoter_bed.head()


In [None]:
# NOTE(review): absolute, user-specific path — consider deriving it from DATA_PATH.
gtf_file = "/home/bdobre/resources/gencode.v47.basic.annotation.gtf"

# The nine tab-separated columns of a GTF record
columns = ["chrom", "source", "feature", "start", "end", "score", "strand", "frame", "attribute"]

# Read the GTF, skipping comment lines beginning with '#'.
# NOTE(review): comment="#" truncates a line at ANY '#', not only leading ones; this is
# fine as long as attribute values never contain '#'.
df = pd.read_csv(
    gtf_file,
    sep="\t",
    comment="#",
    names=columns,
    low_memory=False,
)

# Keep only rows corresponding to genes
df = df[df["feature"] == "gene"].copy()

# TSS in 1-based GTF coordinates: gene start on the '+' strand, gene end otherwise.
df["tss"] = np.where(df["strand"] == "+", df["start"], df["end"])


def get_id(attr_str, key="gene_name"):
    """Return the quoted value of `key` in a GTF attribute string, or "NA" if absent.

    The key must begin the string or follow a ';', so a short key cannot match
    the tail of a longer key; regex metacharacters in `key` are escaped.
    """
    match = re.search(rf'(?:^|;)\s*{re.escape(key)} "([^"]+)"', attr_str)
    return match.group(1) if match else "NA"


df["name"] = df["attribute"].apply(get_id, key="gene_name")

# Build a BED DataFrame: [chrom, start, end, name, score, strand].
# GTF is 1-based inclusive while BED is 0-based half-open, so the single TSS base is
# the interval [tss - 1, tss); start == end would be an empty interval.
bed = pd.DataFrame({
    "chrom": df["chrom"],
    "start": df["tss"] - 1,
    "end":   df["tss"],
    "name":  df["name"],
    "score": 0,
    "strand": df["strand"],
})
gene_bed = pybedtools.BedTool.from_dataframe(bed)
gene_bed.head()

In [None]:
# Find closest promoter
closest_promoter_bed = peaks_bed.sort().closest(gene_bed.sort(), t='first')
closest_promoter_bed.head()

In [None]:
closest_df = closest_promoter_bed.to_dataframe()

# bedtools `closest` ran on the *sorted* peaks, so rows come back in genomic order
# rather than in region_names / fragment_matrix row order. Re-order them to peaks_df
# order so that masks derived from closest_df select the matching fragment_matrix rows.
peak_cols = closest_df.columns[:3]  # first three columns = the peak interval itself
closest_key = pd.MultiIndex.from_arrays(
    [
        closest_df[peak_cols[0]].astype(str),
        closest_df[peak_cols[1]].astype(int),
        closest_df[peak_cols[2]].astype(int),
    ],
    names=["chr", "start", "end"],
)
peak_key = pd.MultiIndex.from_arrays(
    [
        peaks_df["chr"].astype(str),
        peaks_df["start"].astype(int),
        peaks_df["end"].astype(int),
    ],
    names=["chr", "start", "end"],
)
assert closest_key.is_unique, "expected exactly one closest-TSS hit per peak (t='first')"
closest_df.index = closest_key
closest_df = closest_df.reindex(peak_key).reset_index(drop=True)
assert closest_df[peak_cols[0]].notna().all(), "some peaks are missing from the closest output"
closest_df.head()

TODO: if a peak is more than 200 kb from its nearest TSS, label it as intergenic.

In [None]:
# closest_df columns: 0-2 = peak (chrom, start, end); 3-8 = nearest TSS record
# (chrom, start, end, name, score, strand) as written by the gene BED cell.
peak_center = 0.5 * (closest_df.iloc[:, 1] + closest_df.iloc[:, 2])
tss_position = closest_df.iloc[:, 4]
# Orient by gene strand so negative = upstream and positive = downstream (gene body).
# Without this, the axis is mirrored for '+'-strand genes only.
strand_sign = np.where(closest_df.iloc[:, 8] == '-', -1, 1)
distances = (peak_center - tss_position) * strand_sign
distances

In [None]:
# Distribution of peak-to-nearest-TSS distances over all peaks.
ax = sns.histplot(distances)
ax.set(xlabel='Distance from nearest TSS (bp)', ylabel='Number of peaks',
       title='Peak distance to nearest gene TSS');

In [None]:
# Fraction of peaks whose nearest TSS is more than 3 kb away
(np.abs(distances) > 3000).mean()

In [None]:
window_size = 3000  # bp on each side of the TSS; try 200_000 for a wider view

# Plain numpy boolean mask: indexes the (presumably sparse) fragment_matrix rows
# positionally, without relying on pandas index labels.
in_range_mask = np.asarray(np.abs(distances) < window_size)
assert in_range_mask.shape[0] == fragment_matrix.shape[0], \
    "distances must have exactly one entry per fragment_matrix row (peak)"

# Filter your data
filtered_fragment_matrix = fragment_matrix[in_range_mask, :]  # pick those rows/peaks
filtered_distances = distances[in_range_mask]
closest_df_filtered = closest_df[in_range_mask]

ax = sns.histplot(filtered_distances)
ax.set(xlabel='Distance from nearest TSS (bp)', ylabel='Number of peaks',
       title=f'Peaks within ±{window_size} bp of a TSS');

In [None]:
# Name of the nearest gene for each retained peak (7th column: name field of the TSS record)
gene_name_column = closest_df_filtered.columns[6]
genes = closest_df_filtered[gene_name_column]
genes

In [None]:
from collections import defaultdict

# gene name -> positional indices (within the filtered peaks) of the peaks assigned to it
gene_to_peak_indices = defaultdict(list)
for peak_pos, gene_name in enumerate(genes):
    gene_to_peak_indices[gene_name].append(peak_pos)


In [None]:
# Sorted distinct gene names; each becomes one heatmap row.
unique_genes = np.unique(genes)
n_genes = unique_genes.shape[0]
n_genes

In [None]:
from scipy.ndimage import gaussian_filter1d

bins = 100
n_labels = 7          # number of x-axis tick labels on each heatmap
heatmap_vmax = 0.02   # colour-scale cap shared by all heatmaps


def _binned_mean(values, bin_index, n_bins):
    """Mean of `values` grouped by integer `bin_index` in [0, n_bins); empty bins are 0.0."""
    sums = np.bincount(bin_index, weights=values, minlength=n_bins)
    counts = np.bincount(bin_index, minlength=n_bins)
    return np.divide(sums, counts, out=np.zeros(n_bins), where=counts > 0)


# Create distance bins; np.digitize returns 1..bins for distances inside the window.
bin_edges = np.linspace(-window_size, window_size, bins + 1)
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
dist_bin = np.digitize(filtered_distances, bin_edges)
peak_bin = dist_bin - 1                        # 0-based bin of each filtered peak
valid = (peak_bin >= 0) & (peak_bin < bins)    # drop anything outside the window

# Heatmap row (index into unique_genes) of each filtered peak, then one flattened
# (gene, bin) cell index per peak so a single bincount fills the whole heatmap.
gene_row = {g: i for i, g in enumerate(unique_genes)}
peak_gene_row = np.fromiter((gene_row[g] for g in genes), dtype=np.intp, count=len(genes))
peak_cell = peak_gene_row[valid] * bins + peak_bin[valid]

# One column per cell type (no empty or missing panels whatever the number of types).
unique_cell_types = np.unique(cell_types)
n_cols = len(unique_cell_types)
fig = plt.figure(figsize=(5 * n_cols, 10))
gs = plt.GridSpec(2, n_cols, height_ratios=[1, 3])

# Heatmap x data coordinates run from 0 (left edge of the first bin = -window_size)
# to `bins` (right edge of the last bin = +window_size).
tick_positions = np.linspace(0, bins, n_labels)
tick_labels = np.linspace(-window_size, window_size, n_labels, dtype=int)

for idx, cell_type in enumerate(unique_cell_types):
    # Get cells of this type
    cell_mask = cell_types == cell_type
    submatrix_all_peaks = filtered_fragment_matrix[:, cell_mask]
    # Per-peak mean over this cell type's cells; asarray/ravel flattens the (n, 1)
    # matrix that a scipy.sparse .mean(axis=1) returns.
    peak_accessibility = np.asarray(submatrix_all_peaks.mean(axis=1)).ravel()
    valid_access = peak_accessibility[valid]

    # Aggregate profile: mean accessibility of all peaks per distance bin, smoothed.
    profile = _binned_mean(valid_access, peak_bin[valid], bins)
    profile_smooth = gaussian_filter1d(profile, sigma=2.0)

    ax_top = fig.add_subplot(gs[0, idx])
    ax_top.plot(bin_centers, profile_smooth)
    ax_top.set_title(f'Cell Type: {cell_type}')
    ax_top.set_xlabel('Distance from TSS (bp)')
    ax_top.set_ylabel('Average Accessibility')

    # Per-gene heatmap: mean accessibility of that gene's peaks in each distance bin.
    heatmap_data = _binned_mean(valid_access, peak_cell, n_genes * bins).reshape(n_genes, bins)

    # Sort rows by their maximum accessibility (descending). The order is computed per
    # cell type, so a given row is not necessarily the same gene in every panel.
    row_max = heatmap_data.max(axis=1)
    sort_idx = np.argsort(row_max)[::-1]
    heatmap_data = heatmap_data[sort_idx, :]

    ax_bottom = fig.add_subplot(gs[1, idx])
    sns.heatmap(heatmap_data, cmap='YlOrRd', ax=ax_bottom, vmax=heatmap_vmax, yticklabels=False)
    ax_bottom.set_xlabel('Distance from TSS (bp)')
    ax_bottom.set_ylabel('Genes')
    ax_bottom.set_xticks(tick_positions)
    ax_bottom.set_xticklabels(tick_labels)

plt.tight_layout()
plt.show()