In [None]:
import os
import requests
import tarfile
import gzip
import shutil
import pandas as pd

# Optional DESeq2 import.
# `use_deseq2` records whether the whole import group below succeeded; it is
# read by run_deg_analysis_from_gse to decide whether to run the DEG step.
try:
    # Legacy pyDESeq2 wrapper import, kept for reference but disabled.
    #from pyDESeq2 import py_DESeq2
    from pydeseq2.dds import DeseqDataSet
    from pydeseq2.default_inference import DefaultInference
    from pydeseq2.ds import DeseqStats
    from sklearn.preprocessing import LabelEncoder
    use_deseq2 = True
except ImportError:
    # Reached when any import above is missing (PyDESeq2 or scikit-learn).
    print("⚠️ pyDESeq2 not found. Skipping DEG step.")
    use_deseq2 = False


def download_and_extract_gse(gse_id="GSE242272", base_dir="rna_seq_analysis", timeout=60):
    """Download the supplementary archive of a GEO series and unpack it.

    Args:
        gse_id: GEO series accession used to build the download URL.
        base_dir: Parent directory; files go to ``<base_dir>/<gse_id>``.
        timeout: Seconds to wait for the server (connect / each read) before
            giving up, so a stalled connection cannot hang the notebook.

    Returns:
        The directory holding the extracted files, or ``None`` when the
        download or the extraction failed (the error is printed).
    """
    download_dir = os.path.join(base_dir, gse_id)
    os.makedirs(download_dir, exist_ok=True)

    url = f"https://www.ncbi.nlm.nih.gov/geo/download/?acc={gse_id}&format=file"
    tar_path = os.path.join(download_dir, f"{gse_id}_supplement.tar")

    print(f"📦 Downloading: {url}")
    try:
        # The context manager releases the streamed connection even if the
        # status check or a write fails half-way through.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tar_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        print(f"✅ Downloaded to: {tar_path}")
    except Exception as e:
        # Best-effort contract: report the problem and signal it with None.
        print(f"❌ Download failed: {e}")
        return None

    print("🗂️ Extracting TAR...")
    try:
        with tarfile.open(tar_path, "r:*") as tar:
            _extract_tar_safely(tar, download_dir)
        print(f"✅ Extracted contents to: {download_dir}")
    except Exception as e:
        print(f"❌ Extraction failed: {e}")
        return None

    return download_dir


def _extract_tar_safely(tar, dest_dir):
    """Extract *tar* into *dest_dir*, refusing members that escape it."""
    if hasattr(tarfile, "data_filter"):
        # Extraction filters are available: let the stdlib reject absolute
        # paths, links pointing outside the destination, device files, etc.
        tar.extractall(path=dest_dir, filter="data")
        return

    # Older interpreters: at least block path traversal via member names.
    root = os.path.realpath(dest_dir)
    for member in tar.getmembers():
        target = os.path.realpath(os.path.join(root, member.name))
        if os.path.commonpath([root, target]) != root:
            raise tarfile.TarError(f"Unsafe path in archive: {member.name}")
    tar.extractall(path=dest_dir)


def decompress_gz_files(root_dir: str) -> None:
    """Decompress every ``*.txt.gz`` file found under *root_dir*.

    Each archive is expanded next to itself with the ``.gz`` suffix removed.
    Archives whose ``.txt`` counterpart already exists are skipped. Failures
    are printed and do not stop the walk.

    The data is first written to a temporary ``.part`` file and only renamed
    to the final name once decompression succeeded. A corrupt archive
    therefore never leaves a truncated ``.txt`` behind that a later run would
    mistake for a finished file and skip.

    Args:
        root_dir: Directory tree to scan recursively.
    """
    print("🔍 Decompressing .txt.gz files...")
    for dirpath, _, filenames in os.walk(root_dir):
        for filename in filenames:
            if not filename.endswith(".txt.gz"):
                continue

            gz_path = os.path.join(dirpath, filename)
            txt_path = gz_path[:-3]  # drop the trailing ".gz"
            if os.path.exists(txt_path):
                print(f"⏭️ Skipping: {txt_path} already exists.")
                continue

            tmp_path = txt_path + ".part"
            try:
                with gzip.open(gz_path, 'rb') as f_in, open(tmp_path, 'wb') as f_out:
                    shutil.copyfileobj(f_in, f_out)
                # Publish the result only after the whole stream was written.
                os.replace(tmp_path, txt_path)
                print(f"✅ Decompressed: {gz_path}")
            except Exception as e:
                print(f"❌ Failed to decompress {gz_path}: {e}")
                # Remove the incomplete output so the file is retried later.
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)

from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats
def run_deg_analysis_from_gse(gse_id="GSE242272", base_dir="rna_seq_analysis"):
    """Build a count matrix from per-sample files and run DESeq2 on it.

    Every ``*.txt`` file directly inside ``<base_dir>/<gse_id>`` is read as a
    headerless, tab-separated two-column table (gene, count). The sample ID
    is the file-name prefix before the first underscore. A sample is labelled
    "treated" when its file name contains "treated", "pge2" or "caffeine"
    (case-insensitive) and "control" otherwise.

    Files written to ``base_dir``:
        ``<gse_id>_raw_counts.csv``       genes x samples count matrix
        ``<gse_id>_sample_metadata.csv``  sample -> condition table
        ``<gse_id>_deseq2_results.csv``   top 50 genes by adjusted p-value

    Args:
        gse_id: GEO series accession; names the input folder and the outputs.
        base_dir: Directory containing the ``gse_id`` folder.

    Returns:
        The top-50 DESeq2 result table (indexed by gene), or the raw count
        matrix when PyDESeq2 is disabled or not installed.

    Raises:
        ValueError: If no count file could be read, or if all samples fall
            into a single condition so no contrast can be tested.
    """
    gse_dir = os.path.join(base_dir, gse_id)
    output_csv = os.path.join(base_dir, f"{gse_id}_deseq2_results.csv")

    count_dfs = []
    sample_conditions = []

    print(f"\n🔍 Searching GSM folders in: {gse_dir}")
    # Sorted so the column order of the matrix is reproducible across runs.
    for fname in sorted(os.listdir(gse_dir)):
        if not fname.endswith(".txt"):
            continue
        file_path = os.path.join(gse_dir, fname)
        sample_id = fname.split("_")[0]
        try:
            df = pd.read_csv(file_path, sep="\t", header=None, names=["gene", sample_id])
        except Exception as e:
            # Best effort: one unreadable file must not abort the whole run.
            print(f"⚠️ Failed to read {file_path}: {e}")
            continue
        count_dfs.append(df.set_index("gene"))

        is_treated = any(x in fname.lower() for x in ["treated", "pge2", "caffeine"])
        label = "treated" if is_treated else "control"
        sample_conditions.append((sample_id, label))
        print(f"  ✅ Loaded: {fname} as {label}")

    if not count_dfs:
        raise ValueError("❌ No valid count files found.")

    # Genes missing from a sample are treated as zero counts.
    combined_counts = pd.concat(count_dfs, axis=1).fillna(0).astype(int)

    print(f"\n✅ Count matrix shape: {combined_counts.shape}")
    print(combined_counts.iloc[:5, :5])

    # Save raw counts
    raw_csv = os.path.join(base_dir, f"{gse_id}_raw_counts.csv")
    combined_counts.to_csv(raw_csv)
    print(f"📄 Saved raw count matrix to: {raw_csv}")

    # Save the design metadata too: the labels come from the file names,
    # which cannot be recovered from the sample IDs in the count matrix.
    sample_df = pd.DataFrame(sample_conditions, columns=["sample", "condition"])
    sample_df = sample_df.set_index("sample")
    metadata_csv = os.path.join(base_dir, f"{gse_id}_sample_metadata.csv")
    sample_df.to_csv(metadata_csv)
    print(f"📄 Saved sample metadata to: {metadata_csv}")

    # Honour the notebook-level flag when the optional-import cell has run.
    if not globals().get("use_deseq2", True):
        return combined_counts

    # Imported locally so the function does not depend on cell execution
    # order and degrades to "counts only" when PyDESeq2 is not installed.
    try:
        from pydeseq2.dds import DeseqDataSet
        from pydeseq2.default_inference import DefaultInference
        from pydeseq2.ds import DeseqStats
    except ImportError:
        print("⚠️ pyDESeq2 not found. Skipping DEG step.")
        return combined_counts

    if sample_df["condition"].nunique() < 2:
        raise ValueError("❌ Need both treated and control samples to run DESeq2.")

    print("\n🚀 Running DESeq2...")
    inference = DefaultInference(n_cpus=1)
    dds = DeseqDataSet(
        counts=combined_counts.T,  # PyDESeq2 expects samples x genes
        metadata=sample_df,
        design_factors="condition",
        refit_cooks=True,
        inference=inference,
    )
    dds.deseq2()

    stats = DeseqStats(dds, contrast=["condition", "treated", "control"], inference=inference)
    stats.summary()

    res_sorted = stats.results_df.sort_values("padj").dropna().head(50)
    # Gene names live in the index, so it has to be written out.
    res_sorted.to_csv(output_csv)
    print(f"📊 DESeq2 top results saved to: {output_csv}")

    return res_sorted


In [None]:
# Updated import using correct PyDESeq2 structure
from pydeseq2.dds import DeseqDataSet
from pydeseq2.default_inference import DefaultInference
from pydeseq2.ds import DeseqStats

def run_deseq2_manual(counts_df, metadata_df, contrast=["condition", "treated", "control"], output_path="rna_seq_analysis"):
    """Run a PyDESeq2 differential-expression analysis and save the results.

    Args:
        counts_df: Count matrix with samples as rows and genes as columns.
        metadata_df: Per-sample table with a ``condition`` column, indexed
            like ``counts_df``.
        contrast: Passed unchanged to ``DeseqStats`` (factor, tested level,
            reference level). The default list is never modified here.
        output_path: Directory that receives ``results.csv``; created if it
            does not exist.

    Returns:
        ``DeseqStats.results_df`` for the requested contrast.
    """
    # Make sure the destination exists *before* the expensive fit, so a
    # missing directory cannot throw away a finished analysis at save time.
    if output_path:
        os.makedirs(output_path, exist_ok=True)

    # Filter out samples with missing conditions
    print("📋 Initial metadata:")
    print(metadata_df)

    samples_to_keep = ~metadata_df.condition.isna()
    counts_df = counts_df.loc[samples_to_keep]
    metadata_df = metadata_df.loc[samples_to_keep]

    # Filter genes with low counts (fewer than 10 reads over all samples)
    genes_to_keep = counts_df.columns[counts_df.sum(axis=0) >= 10]
    counts_df = counts_df[genes_to_keep]

    # Set up DESeq2
    inference = DefaultInference(n_cpus=1)
    dds = DeseqDataSet(
        counts=counts_df,
        metadata=metadata_df,
        design_factors="condition",
        refit_cooks=True,
        inference=inference,
    )

    dds.deseq2()
    print("📊 LFC matrix:")
    print(dds.varm["LFC"].head())

    # Run differential expression stats
    ds = DeseqStats(dds, contrast=contrast, inference=inference)
    ds.summary()

    results_path = os.path.join(output_path, "results.csv")
    ds.results_df.to_csv(results_path)
    print(f"✅ DESeq2 results saved to: {results_path}")

    return ds.results_df
# Load real data saved earlier
gse_id = "GSE242272"
base_dir = "rna_seq_analysis"
counts_path = os.path.join(base_dir, f"{gse_id}_raw_counts.csv")
metadata_path = os.path.join(base_dir, f"{gse_id}_sample_metadata.csv")

# Read count matrix. It is stored with genes as rows and samples as columns,
# while run_deseq2_manual expects samples as rows, hence the transpose.
counts_df = pd.read_csv(counts_path, index_col=0).T

# Read or reconstruct metadata (if not saved separately, recreate it from the
# sample names, i.e. the column names of the saved matrix)
sample_names = counts_df.index.tolist()
if os.path.exists(metadata_path):
    # Align to the count matrix; samples without an entry get a missing
    # condition and are dropped by run_deseq2_manual.
    metadata_df = pd.read_csv(metadata_path, index_col=0).reindex(sample_names)
else:
    conditions = ["treated" if any(x in name.lower() for x in ["pge2", "caffeine", "treated"])
                  else "control" for name in sample_names]
    metadata_df = pd.DataFrame({"condition": conditions}, index=sample_names)

# Run DESeq2 analysis on real data
results_df = run_deseq2_manual(counts_df, metadata_df, contrast=["condition", "treated", "control"], output_path=base_dir)
results_df.head()