In [1]:
import os
import numpy as np
import pytorch_lightning as pl
import torch.nn as nn
import torch
from data import LabeledDataset
from torch.utils.data import DataLoader
from CFS_SG import CFS_SG

In [2]:
import matplotlib
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from scipy.linalg import svd
from scipy.linalg import pinv
from itertools import product
import numpy as np, h5py, os
import matplotlib.pyplot as plt
from operator import itemgetter 
from scipy.sparse import vstack, coo_matrix, csc_matrix, isspmatrix_csc
%matplotlib inline
import scanpy as sc
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler



In [3]:
import pandas as pd
import gzip
from anndata import AnnData
import scanpy as sc
import os
import requests

def download_binary_file(file_url: str, output_path: str) -> None:
    """
    Download binary data file from a URL.

    Args:
    ----
        file_url: URL where the file is hosted.
        output_path: Output path for the downloaded file.

    Returns
    -------
        None.
    """
    request = requests.get(file_url)
    with open(output_path, "wb") as f:
        f.write(request.content)
    print(f"Downloaded data from {file_url} at {output_path}")

    

def download_haber_2017(output_path: str) -> None:
    """
    Download Haber et al. 2017 data from the hosting URLs.

    Args:
    ----
        output_path: Output path to store the downloaded and unzipped
        directories.

    Returns
    -------
        None. File directories are downloaded to output_path.
    """

    url = (
        "https://ftp.ncbi.nlm.nih.gov/geo/series/GSE92nnn/GSE92332/suppl/GSE92332"
        "_SalmHelm_UMIcounts.txt.gz"
    )

    output_filename = os.path.join(output_path, url.split("/")[-1])

    download_binary_file(url, output_filename)
def read_haber_2017(file_directory: str) -> pd.DataFrame:
    """
    Read the expression data from Haber et al. 2017 given the directory.

    Args:
    ----
        file_directory: Directory containing Haber et al. 2017 data.

    Returns
    -------
        A DataFrame containing single-cell gene expression counts, with cell
        identification barcodes as column names and gene IDs as indices.
    """

    # Path to the file
    file_path = os.path.join(file_directory, "GSE92332_SalmHelm_UMIcounts.txt.gz")
    
    # Read the .txt.gz file, assuming the first column contains gene names
    with gzip.open(file_path, "rt") as f:
        df = pd.read_csv(f, sep="\t", index_col=0)

    return df

def preprocess_haber_2017(download_path: str, n_top_genes: int) -> (AnnData, list):
    """
    Preprocess expression data from Haber et al. 2017.

    Args:
    ----
        download_path: Path containing the downloaded Haber et al. 2017 data file.
        n_top_genes: Number of most variable genes to retain.

    Returns
    -------
        An AnnData object containing single-cell expression data. The layer
        "count" contains the count data for the most variable genes. The X
        variable contains the total-count-normalized and log-transformed data
        for the most variable genes (a copy with all the genes is stored in
        .raw).
        A list of conditions associated with each cell.
    """

    # Read the expression data
    df = read_haber_2017(download_path)
    
    # Transpose the dataframe so that genes become columns and cells are rows
    df = df.transpose()

    # Extract metadata from the cell names
    cell_groups = []
    barcodes = []
    conditions = []
    cell_types = []
    cell_names = []

    for cell in df.index:
        try:
            cell_group, barcode, condition, cell_type = cell.split("_")
            cell_groups.append(cell_group)
            barcodes.append(barcode)
            conditions.append(condition)
            cell_types.append(cell_type)
            cell_names.append(cell)
        except ValueError:
            print(f"Error parsing cell name: {cell}")
            continue

    # Create a DataFrame for metadata
    metadata_df = pd.DataFrame(
        {
            "cell_group": cell_groups,
            "barcode": barcodes,
            "condition": conditions,
            "cell_type": cell_types,
        }, index=cell_names  # Ensure the cell names are set as the index
    )

    # Create the AnnData object with gene names in 'var' and cell metadata in 'obs'
    adata = AnnData(X=df.values, obs=metadata_df, var=pd.DataFrame(index=df.columns))

    # Preprocess the data: normalize, log-transform, and keep most variable genes
    adata = adata[adata.obs["condition"] != "Hpoly.Day3"]
    adata.layers["count"] = adata.X.copy()

    # Normalize and log-transform
    sc.pp.normalize_total(adata)
    sc.pp.log1p(adata)
    adata.raw = adata

    # Select highly variable genes
    sc.pp.highly_variable_genes(
        adata, flavor="seurat_v3", n_top_genes=n_top_genes, layer="count", subset=True
    )

    # Remove cells with all zero counts
    adata = adata[adata.layers["count"].sum(1) != 0]  # Remove cells with all zeros.

    # Return the AnnData object and the conditions list
    return adata, conditions


In [4]:
root_data_path = ""   # local computer data path to downloaad the dataset. 

download_haber_2017(root_data_path)

Downloaded data from https://ftp.ncbi.nlm.nih.gov/geo/series/GSE92nnn/GSE92332/suppl/GSE92332_SalmHelm_UMIcounts.txt.gz at GSE92332_SalmHelm_UMIcounts.txt.gz


In [5]:
import numpy as np
data, conditions = preprocess_haber_2017(root_data_path, 1000)

  adata.layers["count"] = adata.X.copy()


In [6]:
foreground = data[data.obs["condition"] != "Control"]
k, p = foreground.shape


background = data[data.obs["condition"] == "Control"]
m, p = background.shape



In [7]:
foreground = foreground.to_df()
background = background.to_df()
genes = foreground.columns.to_numpy()


In [8]:
# Label here determines target versus background
labels_train = np.concatenate([np.zeros(background.shape[0]), np.ones(foreground.shape[0])])
data_train = np.concatenate([background, foreground])

# Ensure your data is converted to the correct dtype outside the dataset class
data_train = torch.from_numpy(data_train).float()  # Convert data to torch float32
labels_train = torch.from_numpy(labels_train).float()  # Convert labels to torch float32

# Create the dataset as before
dataset = LabeledDataset(data_train.numpy(), labels_train.numpy())


In [9]:
input_size = foreground.shape[1]
output_size = background.shape[1]
batch_size = 128

In [10]:
model = CFS_SG(
    input_size=input_size,
    output_size=output_size,
    hidden=[512, 512], # Number of units in each hidden layer
    k_prime=20, # Background dimension size
    lam=0.15, # Tuned to select about 10 features
    lr=1e-3,
    loss_fn=nn.MSELoss()
)


In [11]:
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)



In [12]:
trainer = pl.Trainer(max_epochs=10, accelerator='gpu', devices=1)
trainer.fit(model, loader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/nas/longleaf/home/eyzhang/.local/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA A100-PCIE-40GB MIG 2g.10gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.htm

Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [19]:
indices = model.get_inds(10) 

In [20]:
genes[indices]

array(['H2.Aa', 'Cd74', 'Ang4', 'H2.Ab1', 'Ifitm3', 'Uqcrb', 'S100a6',
       'Mt2', 'Fabp6', 'Reg1'], dtype=object)