# Bulk RNA-Seq: Cohort Selection and Data Retrieval

This notebook demonstrates how to use the **Genestack ODM API** to access and explore bulk RNA-seq data stored in an ODM instance. It explains how to configure the API connection, retrieve metadata and data for selected entities, and interpret the returned results in a reproducible, programmable way.
The notebook is organized into three main parts:
* **Prerequisites** – loads the required Python libraries and helper functions. This section can be minimized when running the notebook end-to-end if all dependencies are already installed.
* **ODM API Configuration** – an interactive setup for establishing a secure connection to your ODM instance using an API token.
* **Working with Data** – examples of typical ODM API endpoints for metadata and data retrieval (including multi-omics), with explanations of the API response structure and its relevance for downstream analysis.

## 1. Prerequisites

Before running the notebook, make sure your environment is ready. You will need Python 3.10+ and `pip`. Install all dependencies with:
```
pip install odm-sdk numpy pandas matplotlib seaborn scipy ipywidgets ipykernel requests
```
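
Before moving on to the imports, it can help to confirm that the interpreter and packages meet the requirements above. The sketch below uses only the standard library. The helper name `check_environment` is hypothetical, and the package list simply mirrors the `pip install` line (these are distribution names, which may differ from import names).

```python
import sys
from importlib import metadata

# distribution names taken from the pip command above
REQUIRED = ["odm-sdk", "numpy", "pandas", "matplotlib", "seaborn",
            "scipy", "ipywidgets", "ipykernel", "requests"]


def check_environment(packages=REQUIRED, min_python=(3, 10)):
    """Return (python_ok, {package: installed version or None})."""
    python_ok = sys.version_info >= min_python
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None  # not installed in this environment
    return python_ok, versions


python_ok, versions = check_environment()
print("Python OK:", python_ok)
for name, version in versions.items():
    print(f"{name:12s} {version or 'MISSING'}")
```

Any package reported as `MISSING` can be installed with the `pip` command above before continuing.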


### 1.1 Imports

In [None]:
# standard library (comes with Python)
import os
import re
import json
import time
import warnings
from getpass import getpass
from io import StringIO

# third-party (need installation)
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
import ipywidgets as widgets
from IPython.display import display
from scipy.stats import spearmanr, pearsonr
import requests
import odm_api

# set default matplotlib style
plt.style.use('default')

# set warnings to ignore FutureWarning
warnings.filterwarnings('ignore', category=FutureWarning)

### 1.2 Functions

This section defines utility functions used across the notebook to streamline interaction with the ODM API and to support visualization of retrieved data. Collecting them here keeps the workflow sections concise and focused on analysis rather than implementation details.

In [None]:
def set_api_credentials(odm_url, api_prefix='/api/v1'):
    """
    Set ODM API credentials interactively, using getpass or fallback to widget-based UI.

    Attempts to use `getpass` to prompt for an API token (works in terminal environments).
    If that fails (e.g., in JupyterLab or web-based notebooks), it falls back to a widget-based input form.

    Parameters
    ----------
    odm_url : str
        The ODM server URL, e.g. 'https://q001-demo.trial.genestack.com/'
    api_prefix : str, optional
        The API endpoint prefix (default is '/api/v1').

    Sets global variables
    ------------
        odm_base_url : str
            The base ODM API URL with prefix.
        token : str
            The provided API authentication token.
    """
    # ensure the API prefix is present in the URL; strip a trailing slash
    # first so that joining does not produce '//api/v1'
    if re.search(r'/api.+', odm_url):
        base_url = odm_url
    else:
        base_url = odm_url.rstrip('/') + api_prefix

    global odm_base_url, token
    try:
        # enter token via getpass (works in terminal-based environments)
        token = getpass("Auth Token: ")
        odm_base_url = base_url

    except (EOFError, OSError):
        # fallback to widget-based input (works in web environments)
        set_api_credentials_ui(base_url)


def set_api_credentials_ui(base_url):
    """
    Displays widgets for ODM API server and token selection.

    Args:
        base_url (str, optional): Default server URL for ODM with API prefix (e.g. /api/v1).
        If not provided, uses 'ODM_BASE_URL' env variable or default.
    """
    odm_base_url_widget = widgets.Text(
        value=base_url if base_url is not None else os.getenv('ODM_BASE_URL', ''),
        description='Base URL:',
        layout=widgets.Layout(width='600px')
    )
    token_widget = widgets.Password(
        value=os.getenv('ODM_API_TOKEN', ''),
        description='Auth Token:',
        layout=widgets.Layout(width='400px')
    )
    set_button = widgets.Button(description='Set Credentials', button_style='primary')
    status_html = widgets.HTML()

    def _set_credentials(_):
        global odm_base_url, token
        odm_base_url = odm_base_url_widget.value.strip()
        token = token_widget.value.strip()
        masked = ('***' if not token else (token[:4] + '…' + token[-4:] if len(token) >= 8 else '***'))
        status_html.value = f"<span style='color: green;'>Credentials set. Token: {masked}</span>"

    set_button.on_click(_set_credentials)
    display(widgets.VBox([odm_base_url_widget, token_widget, set_button, status_html]))


def plot_grouped_violin(
    combined_data,
    group_by='Disease',
    preferred=None,
    log1p=True,
    figsize_scale=0.6,
    min_figsize=12,
    height=6
):
    """
    Plot grouped violin + strip plot, optionally log1p of expression,
    grouped by an inferred or specified metadata column.

    Parameters:
        combined_data (pd.DataFrame): DataFrame with expression data and sample metadata.
        group_by (str): Initial column to try as grouping variable (default: 'Disease').
        preferred (list): Optional list of preferred columns to scan as grouping variables.
        log1p (bool): Whether to plot log1p(expression). Default: True.
        figsize_scale (float): Scale for figure width by number of genes.
        min_figsize (int): Minimum figure width.
        height (float): Figure height.
    """

    sns.set(style="whitegrid", context="notebook")

    plot_df = combined_data.copy()

    # helper to normalize complex values (lists/dicts) to strings for grouping
    def _norm_group(v):
        if isinstance(v, (list, tuple, set)):
            return ', '.join(map(str, v))
        if isinstance(v, dict):
            return json.dumps(v, sort_keys=True)
        return v

    # determine/normalize a usable grouping column
    if preferred is None:
        preferred = ['Disease', 'Tissue', 'Cell Type', 'Project', 'Batch', 'Sex', 'Age']
    chosen = None
    if group_by in plot_df.columns:
        s = plot_df[group_by]
        if np.issubdtype(s.dtype, np.number):
            if s.nunique(dropna=True) >= 2:
                chosen = group_by
        else:
            s2 = s.map(_norm_group)
            if 2 <= s2.nunique(dropna=True) <= 10:
                plot_df[group_by] = s2
                chosen = group_by
    # fallback: scan preferred list
    if not chosen:
        for c in preferred:
            if c not in plot_df.columns:
                continue
            s = plot_df[c]
            if np.issubdtype(s.dtype, np.number):
                if s.nunique(dropna=True) >= 2:
                    bname = f'{c}_binned'
                    plot_df[bname] = pd.cut(
                        s,
                        bins=[-np.inf, 40, 50, 60, 70, 80, np.inf],
                        labels=['<=40', '40-50', '50-60', '60-70', '70-80', '80+']
                    )
                    chosen = bname
                    break
            else:
                s2 = s.map(_norm_group)
                if 2 <= s2.nunique(dropna=True) <= 10:
                    plot_df[c] = s2
                    chosen = c
                    break
    group_by = chosen

    # transform and ordering
    if log1p:
        y_col = 'log1p_value'
        plot_df[y_col] = np.log1p(plot_df['value'])
        ylabel = 'log1p(Expression)'
        title = 'Expression of gene selection (log1p)'
    else:
        y_col = 'value'
        ylabel = 'Expression'
        title = 'Expression of gene selection'

    order = (
        plot_df.groupby('gene')['value']
        .median()
        .sort_values(ascending=False)
        .index
    )

    plt.figure(figsize=(max(min_figsize, figsize_scale * len(order)), height))
    ax = sns.violinplot(
        data=plot_df, x='gene', y=y_col, order=order,
        hue=group_by, dodge=True, inner=None, cut=0, density_norm='width'
    )
    sns.stripplot(
        data=plot_df, x='gene', y=y_col, order=order,
        hue=group_by, dodge=True, size=3, jitter=0.25, alpha=0.6,
        linewidth=0, edgecolor=None
    )
    ax.set(xlabel='Gene', ylabel=ylabel, title=title)
    plt.xticks(rotation=45, ha='right')
    # clean up duplicate legends from violin+strip
    if group_by:
        if ax.legend_:
            ax.legend_.remove()
        plt.legend(title=group_by, bbox_to_anchor=(1.02, 1), loc='upper left')
    else:
        if ax.legend_:
            ax.legend_.remove()
    plt.tight_layout()
    plt.show()


def plot_grouped_heatmap(
    combined_data,
    group_by="Disease",
    value_col='value',
    gene_col='gene',
    sample_col="genestack:accession",
    cmap='vlag',
    log1p=True
):
    """
    Plot enhanced grouped heatmap: genes as rows, samples as columns (optionally group columns/colors).
    Shows gene names on the left side of the heatmap.

    Parameters
    ----------
    combined_data : pd.DataFrame
        DataFrame with gene expression data. Must contain gene_col, sample_col, value_col, and group_by.
    group_by : str
        Column to use for column color grouping (e.g., 'Disease').
    value_col : str
        Column with expression values.
    gene_col : str
        Column with gene names.
    sample_col : str
        Column with sample accessions.
    cmap : str
        Colormap to use for the heatmap.
    log1p : bool
        If True, log1p transform the expression values before visualization.
    """

    plot_df = combined_data.copy()

    def _norm_group(v):
        if isinstance(v, (list, tuple, set)):
            return ', '.join(map(str, v))
        if isinstance(v, dict):
            return json.dumps(v, sort_keys=True)
        return v

    # pivot and normalize matrix: genes as rows, samples as columns
    mat = plot_df.pivot_table(index=gene_col, columns=sample_col, values=value_col, aggfunc='median')

    if log1p:
        mat_proc = np.log1p(mat)
    else:
        mat_proc = mat

    mat_z = (mat_proc - mat_proc.mean(axis=1).values[:, None]) / mat_proc.std(axis=1, ddof=0).values[:, None]
    rows = mat_proc.median(axis=1).sort_values(ascending=False).index
    mat_z = mat_z.loc[rows]

    col_colors = None
    legend_handles = None
    palette = None

    # column colors by group
    if group_by and group_by in plot_df.columns:
        plot_df[group_by] = plot_df[group_by].map(_norm_group)
        meta = (
            plot_df[[sample_col, group_by]]
            .drop_duplicates().set_index(sample_col)
            .reindex(mat_z.columns)
        )
        levels = meta[group_by].astype('category').cat.categories
        palette = dict(zip(levels, sns.color_palette('Set2', n_colors=len(levels))))
        col_colors = meta[group_by].map(palette)
        mat_z = mat_z.loc[:, meta.sort_values(group_by).index]  # sort columns (samples) by group
        # prepare legend handles
        legend_handles = [
            mpatches.Patch(color=palette[level], label=str(level)) for level in levels
        ]

    # create clustermap: genes as rows, samples as columns
    g = sns.clustermap(
        mat_z,
        cmap=cmap,
        center=0,
        row_cluster=False,
        col_cluster=False,
        col_colors=col_colors,
        cbar_kws={'label': 'Z-score (per gene)'},
        figsize=(max(8, 0.22 * len(mat_z.columns)), max(8, 0.4 * len(rows))),
        linewidths=0.1,
        xticklabels=True,
        yticklabels=True,
        dendrogram_ratio=(.08, .02),  # row dendrogram larger, as genes are now rows
        cbar_pos=(0.94, 0.7, 0.02, 0.18)
    )
    g.ax_heatmap.set_xlabel('Sample', fontsize=13, fontweight='bold')
    g.ax_heatmap.set_ylabel('Gene', fontsize=13, fontweight='bold')
    g.ax_heatmap.tick_params(axis='x', rotation=90, labelsize=10)
    g.ax_heatmap.tick_params(axis='y', labelsize=10, labelleft=True, labelright=False)  # only show tick labels on the left

    # add better title and annotation
    if log1p:
        title = 'Expression heatmap for gene selection\n(log1p, per-gene z-score)'
    else:
        title = 'Expression heatmap for gene selection\n(per-gene z-score)'
    if group_by and group_by in plot_df.columns:
        title += f'\nGrouped by: {group_by}'
    g.ax_heatmap.set_title(title, fontsize=15, fontweight='bold', pad=18)

    # move cbar for better visibility
    g.cax.set_position([.93, .18, .02, .55])

    # show group (disease) legend at the top right of the plot, but do not show the title
    if legend_handles is not None:
        # place legend at top right of the figure, outside plot area
        g.ax_heatmap.legend(
            handles=legend_handles,
            title=None,
            loc='upper right',
            bbox_to_anchor=(1.12, 1.01),  # (x, y) as fraction of axes; adjust as needed
            borderaxespad=0.0,
            fontsize='small',
            frameon=False,
            ncol=1
        )

    # show gene names on the left side only: hide any right-side y-tick labels
    g.ax_heatmap.yaxis.tick_left()
    g.ax_heatmap.yaxis.set_label_position("left")
    g.ax_heatmap.yaxis.set_ticks_position("left")
    g.ax_heatmap.tick_params(axis='y', which='both', labelleft=True, labelright=False)

    plt.show()


def plot_correlation_scatter(data, title, method='pearson', color_by=None, width=None, height=None):
    """
    Create a scatter plot with regression line and correlation statistics,
    with optional color annotation.

    Parameters
    ----------
    data : pd.DataFrame
        DataFrame with two numeric columns (e.g., 'transcript' and 'protein').
    title : str
        Title for the plot.
    method : str, optional
        Correlation method to use. Either 'pearson' or 'spearman' (default: 'pearson').
    color_by : str, optional
        Column name in `data` to color points by. May be numeric or categorical.
    width : float, optional
        Width of the figure in inches. If not provided, default figure size is used.
    height : float, optional
        Height of the figure in inches. If not provided, default figure size is used.

    Returns
    -------
    None
        Displays the plot.

    Examples
    --------
    >>> import pandas as pd
    >>> import numpy as np
    >>> df = pd.DataFrame({
    ...     'transcript': np.random.randn(100),
    ...     'protein': np.random.randn(100),
    ...     'disease': np.random.choice(['NASH', 'Healthy'], size=100)
    ... })
    >>> plot_correlation_scatter(df, 'Gene_A', 'spearman', color_by='disease')
    """
    # use first two numeric columns
    numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()
    if len(numeric_cols) < 2:
        raise ValueError("DataFrame must contain at least two numeric columns.")
    x_col = numeric_cols[0]
    y_col = numeric_cols[1]

    # extract data and remove NaN values (in x, y, and possibly color_by)
    cols_to_check = [x_col, y_col]
    if color_by is not None:
        cols_to_check.append(color_by)
    mask = ~data[cols_to_check].isna().any(axis=1)
    clean_data = data.loc[mask]

    x_clean = clean_data[x_col].values
    y_clean = clean_data[y_col].values

    if len(x_clean) < 3:
        raise ValueError("Not enough valid data points (need at least 3) for correlation calculation.")

    # calculate correlation
    if method.lower() == 'pearson':
        corr_coef, p_value = pearsonr(x_clean, y_clean)
    elif method.lower() == 'spearman':
        corr_coef, p_value = spearmanr(x_clean, y_clean)
    else:
        raise ValueError("method must be either 'pearson' or 'spearman'.")

    # handle color_by logic
    color_vals = None
    color_legend_info = None
    from matplotlib.lines import Line2D

    if color_by is not None and color_by in clean_data.columns:
        color_data = clean_data[color_by]
        if pd.api.types.is_numeric_dtype(color_data):
            # Numeric, use colormap
            cmap = plt.cm.viridis
            color_vals = color_data.values
            sc_kwargs = dict(c=color_vals, cmap=cmap)
            color_legend_info = 'numeric'
        else:
            # Categorical
            unique_cats = color_data.unique()
            n_categories = len(unique_cats)
            # Use tab10 or another colormap
            cmap = plt.get_cmap('tab10') if n_categories <= 10 else plt.get_cmap('tab20')
            color_map = {cat: cmap(i % cmap.N) for i, cat in enumerate(unique_cats)}
            color_vals = color_data.map(color_map)
            sc_kwargs = dict(c=color_vals)
            color_legend_info = (unique_cats, color_map)
    else:
        sc_kwargs = {}

    # create scatter plot with appropriate figure size and tight layout usage
    if width is not None and height is not None:
        plt.figure(figsize=(width, height))
        use_tight_layout = False
    else:
        plt.figure()
        use_tight_layout = True

    ax = plt.gca()
    scatter = plt.scatter(
        x_clean,
        y_clean,
        alpha=0.7,
        s=50,
        **sc_kwargs
    )

    # add regression line
    z = np.polyfit(x_clean, y_clean, 1)
    p = np.poly1d(z)
    x_line = np.linspace(x_clean.min(), x_clean.max(), 100)
    plt.plot(x_line, p(x_line), "r--", alpha=0.8, linewidth=2, label='Regression line')

    # add correlation statistics in top left corner
    stats_text = f'r = {corr_coef:.3f}\np = {p_value:.2e}'
    plt.text(0.05, 0.95, stats_text, transform=ax.transAxes,
             fontsize=12, verticalalignment='top',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # set labels and title
    plt.xlabel(x_col, fontsize=12)
    plt.ylabel(y_col, fontsize=12)
    plt.title(title, fontsize=14, fontweight='bold')
    plt.grid(True, alpha=0.3)

    # add color legend/colorbar at right center if requested
    if color_by is not None and color_by in clean_data.columns:
        if color_legend_info == 'numeric':
            cbar = plt.colorbar(scatter, ax=ax, fraction=0.05, pad=0.12)
            cbar.set_label(color_by, fontsize=12)
            cbar.ax.yaxis.set_ticks_position('right')
            cbar.ax.yaxis.set_label_position('right')
        else:
            unique_cats, color_map = color_legend_info
            handles = [Line2D([0], [0], marker='o', color='w', markerfacecolor=color_map[cat],
                              markersize=8, label=str(cat))
                       for cat in unique_cats]
            # Place legend at right center, outside plot
            ax.legend(
                handles=handles,
                title=color_by,
                loc='center left',
                bbox_to_anchor=(1.03, 0.5),
                borderaxespad=0.0,
                fontsize='small',
                frameon=False
            )

    if use_tight_layout:
        plt.tight_layout()
    plt.show()


def map_protein_to_gene(protein_names, organism="Homo sapiens"):
    """
    Map protein names to human gene names using UniProt REST API.

    Parameters
    ----------
    protein_names : list
        List of protein names to map (e.g., ['c-Met', 'Rictor', 'Paxillin']).
    organism : str, optional
        Organism filter for UniProt search (default: "Homo sapiens").

    Returns
    -------
    dict
        Dictionary mapping protein names to gene names.
        If a protein name cannot be mapped, the value will be None.

    Examples
    --------
    >>> protein_names = ['c-Met', 'Rictor', 'Paxillin', 'p21']
    >>> gene_mapping = map_protein_to_gene(protein_names)
    >>> print(gene_mapping)
    {'c-Met': 'MET', 'Rictor': 'RICTOR', 'Paxillin': 'PXN', 'p21': 'CDKN1A'}
    """
    mapping = {}
    base_url = "https://rest.uniprot.org/uniprotkb/search"

    for protein_name in protein_names:
        # skip if already processed
        if protein_name in mapping:
            continue

        # trim surrounding whitespace before searching
        search_name = protein_name.strip()

        # build query: reviewed entries matching protein name and organism
        query = f"protein_name:{search_name} AND organism_name:{organism} AND reviewed:true "
        params = {
            'query': query,
            'fields': 'gene_names',
            'format': 'json',
            'size': 1  # Only need first result
        }

        try:
            # make API request
            response = requests.get(base_url, params=params, timeout=10)
            response.raise_for_status()
            data = response.json()

            # extract gene name from response
            gene_name = None
            if data.get('results') and len(data['results']) > 0:
                result = data['results'][0]
                # try to get gene name from geneNames field
                if 'genes' in result and len(result['genes']) > 0:
                    gene_info = result['genes'][0]
                    # prefer geneName over synonyms
                    if 'geneName' in gene_info:
                        gene_name = gene_info['geneName'].get('value')
                    elif 'synonyms' in gene_info and len(gene_info['synonyms']) > 0:
                        # fallback to first synonym if geneName not available
                        gene_name = gene_info['synonyms'][0].get('value')

            mapping[protein_name] = gene_name

            # small delay to respect API rate limits
            time.sleep(0.01)

        except requests.exceptions.RequestException as e:
            # if API request fails, set to None
            mapping[protein_name] = None
            print(f"Warning: Could not map '{protein_name}': {str(e)}")
        except (KeyError, IndexError) as e:
            # if response structure is unexpected, set to None
            mapping[protein_name] = None
            print(f"Warning: Unexpected response format for '{protein_name}': {str(e)}")

    return mapping


def compute_correlation_with_pvalue(df1, df2, axis=1, method='pearson'):
    """
    Compute correlation coefficients and p-values between two DataFrames.

    For axis=0, both DataFrames must have identical columns.
    For axis=1, both DataFrames must have identical indexes.
    Pairs with fewer than 3 non-missing observations yield NaN.

    Args:
        df1, df2: pd.DataFrame
        axis: If 0, computes correlation over rows for each column.
              If 1, computes correlation over columns for each row.
        method: If 'pearson', computes Pearson correlation.
                If 'spearman', computes Spearman correlation.
                Default is 'pearson'.
    Returns:
        pd.DataFrame: one row per column (axis=0) or per row (axis=1), with
        columns '<method>_r' (e.g. 'pearson_r') and 'p_value'.
    """
    # determine which correlation function to use
    if method == 'pearson':
        corr_func = pearsonr
        r_name = 'pearson_r'
    elif method == 'spearman':
        corr_func = spearmanr
        r_name = 'spearman_r'
    else:
        raise ValueError("method must be either 'pearson' or 'spearman'.")

    if axis == 0:
        # match columns, compute over rows for each column
        if not df1.columns.equals(df2.columns):
            raise ValueError("Columns of the two dataframes are not identical.")

        corrs = []
        for col in df1.columns:
            x = df1[col]
            y = df2[col]
            mask = ~(pd.isna(x) | pd.isna(y))
            if mask.sum() < 3:
                corrs.append({r_name: np.nan, 'p_value': np.nan})
            else:
                corr, pval = corr_func(x[mask], y[mask])
                corrs.append({r_name: corr, 'p_value': pval})
        return pd.DataFrame(corrs, index=df1.columns)

    elif axis == 1:
        # match index, compute over columns for each row
        if not df1.index.equals(df2.index):
            raise ValueError("Indexes of the two dataframes are not identical.")

        corrs = []
        for row in df1.index:
            x = df1.loc[row, :]
            y = df2.loc[row, :]
            mask = ~(pd.isna(x) | pd.isna(y))
            if mask.sum() < 3:
                corrs.append({r_name: np.nan, 'p_value': np.nan})
            else:
                corr, pval = corr_func(x[mask], y[mask])
                corrs.append({r_name: corr, 'p_value': pval})
        return pd.DataFrame(corrs, index=df1.index)

    else:
        raise ValueError("axis must be 0 or 1")

## 2. ODM API configuration

**Configuring Access to Your ODM Instance**

Before querying data, establish a connection to your ODM deployment and authenticate using an API token.  
The ODM API uses token-based authentication, allowing secure programmatic access while preserving user-level permissions.

In this section:
* **Specify the ODM instance URL** – defines the environment you are connecting to.  
* **Provide the API token** – identifies and authorizes your user session.  
* **Initialize the ODM API client** – creates the communication layer for all subsequent requests.


In [None]:
# input ODM server address
server = 'https://q001-demo.trial.genestack.com/'

# input API token
set_api_credentials(server)
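
For non-interactive runs (for example, scheduled execution), the prompt can be skipped by reading credentials from environment variables. This is a minimal sketch, assuming the same `ODM_BASE_URL` and `ODM_API_TOKEN` variables that the widget helper above reads. `credentials_from_env` is a hypothetical helper, not part of the ODM SDK.

```python
import os
from urllib.parse import urlsplit


def credentials_from_env(default_url="", api_prefix="/api/v1"):
    """Read the ODM URL and token from the environment (hypothetical helper)."""
    url = os.getenv("ODM_BASE_URL", default_url).strip().rstrip("/")
    token = os.getenv("ODM_API_TOKEN", "").strip()
    if not url:
        raise ValueError("No ODM URL: set ODM_BASE_URL or pass default_url.")
    if not token:
        raise ValueError("No API token: set ODM_API_TOKEN.")
    # append the API prefix only if the URL path does not already start with it
    if not urlsplit(url).path.startswith("/api"):
        url += api_prefix
    return url, token
```

It could be called as `odm_base_url, token = credentials_from_env(default_url=server)`. Checking the URL path instead of the whole string avoids misreading host names that happen to begin with `api`.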

In [None]:
# credentials sanity check
if len(token) == 0:
    print("API token is empty! Set the token manually (e.g. via `token = 'your_token'`).")
else:
    print("Token successfully set!")

In [None]:
# initialize API client
configuration = odm_api.Configuration(
    host=server,
    api_key={'Genestack-API-Token': token}
)
api_client = odm_api.ApiClient(configuration)

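# the odm-api documentation is available at
# (strip the trailing slash from 'server' to avoid a double slash in the URL)
print(f"{server.rstrip('/')}/user-docs/tools/odm-api/python/generated/")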

## 3. Working with Data

### 3.1 Exploring Sample Endpoints

In ODM, each entity type (such as samples, datasets, or assays) can be accessed through dedicated API endpoints.
The `SampleSPoTAsUserApi` class provides programmatic access to sample-level metadata, enabling you to search, retrieve, and inspect samples based on their attributes and values.

In this step, the `SampleSPoTAsUserApi` class is instantiated, and its available methods are listed.

In [None]:
# initialize API class
sample_api = odm_api.SampleSPoTAsUserApi(api_client)

# list all available sample_api endpoints
for item in [item for item in dir(sample_api) if item.endswith("_as_user")]:
    print(item)

### 3.2 Searching for Samples

This section demonstrates how to retrieve sample metadata using the `search_samples_as_user` endpoint.
The endpoint supports two complementary input parameters:

* `query` — performs a full-text search across all metadata fields.

* `filter` — applies logical conditions using metadata key–value pairs.

Both can be combined in a single request for flexible data retrieval.
The response includes:

* `meta` – summary information about the search results.

* `data` – a list of matching samples with their metadata attributes.

First, we retrieve a large set of samples to demonstrate the `page_offset` parameter, which is needed when the number of matching samples exceeds `page_limit`.

In [None]:
# define search parameters
sample_filter = 'Organism="Homo sapiens"'

# batch 1: retrieve the first 2000 samples matching the human-samples filter
samples_batch1 = sample_api.search_samples_as_user(
    filter=sample_filter,
    page_limit=2000,
    page_offset=0
)

# batch 2: retrieve the next page of samples by setting 'page_offset'
# to the number of samples already requested in batch 1
samples_batch2 = sample_api.search_samples_as_user(
    filter=sample_filter,
    page_limit=2000,
    page_offset=2000
)

# show summary items from API responses
print(json.dumps(samples_batch1.meta.to_dict(), indent=2))
print(json.dumps(samples_batch2.meta.to_dict(), indent=2))
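
The two-batch pattern generalizes to a loop that advances `page_offset` by `page_limit` until a page comes back short. The sketch below is deliberately independent of the ODM client: `fetch_page` is a hypothetical callable standing in for a call such as `search_samples_as_user`, and treating a page shorter than `page_limit` as the last one is an assumption, not documented API behavior.

```python
def fetch_all(fetch_page, page_limit=2000, max_pages=1000):
    """Collect items from an offset-paginated source.

    fetch_page(page_limit, page_offset) must return a list of items.
    """
    items = []
    offset = 0
    for _ in range(max_pages):  # hard cap guards against endless loops
        page = fetch_page(page_limit, offset)
        items.extend(page)
        if len(page) < page_limit:  # short (or empty) page: nothing left
            break
        offset += page_limit  # next page starts where this one ended
    return items


# stand-in data source: 4500 fake records served in slices
records = [{"id": i} for i in range(4500)]


def fake_fetch(page_limit, page_offset):
    return records[page_offset:page_offset + page_limit]


all_items = fetch_all(fake_fetch, page_limit=2000)
print(len(all_items))  # 4500
```

With the ODM client, `fetch_page` could be a small wrapper such as `lambda limit, offset: sample_api.search_samples_as_user(filter=sample_filter, page_limit=limit, page_offset=offset).data`.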

Next, we will use both sample query and filter parameters to retrieve and explore a specific list of samples.

In [None]:
# define search parameters
sample_query = 'steatohepatitis'
sample_filter = 'Organism="Homo sapiens" AND Tissue="liver"'

# search samples with both query and filter parameters
samples = sample_api.search_samples_as_user(
    filter=sample_filter,
    query=sample_query
)

# show first items from API response
print(json.dumps(samples.meta.to_dict(), indent=2))
print(json.dumps(samples.data[0], indent=2))
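
Filter strings like the one above can be assembled from key-value pairs to avoid quoting mistakes. This is a sketch only: `build_filter` is a hypothetical helper, and the syntax it emits (`key="value"` terms joined by `AND` or `OR`) is inferred from the examples in this notebook, not from a formal grammar. Keys are inserted verbatim, and quotes inside values are not escaped.

```python
def build_filter(conditions, operator="AND"):
    """Join (key, value) pairs into a filter string (hypothetical helper)."""
    if operator not in ("AND", "OR"):
        raise ValueError("operator must be 'AND' or 'OR'")
    # each condition becomes key="value"; keys are used exactly as given
    parts = [f'{key}="{value}"' for key, value in conditions]
    return f" {operator} ".join(parts)


print(build_filter([("Organism", "Homo sapiens"), ("Tissue", "liver")]))
# Organism="Homo sapiens" AND Tissue="liver"
```

Keys that need quoting themselves, such as `"genestack:accession"`, can be passed with their quotes included.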

### 3.3 Exploring Sample Metadata Summary

This section summarizes the metadata attributes available in the retrieved sample set.
The sample records returned by the API are converted into a DataFrame, and basic statistics are calculated to inspect attribute completeness and diversity.

For each metadata field, the table shows:

* `unique` – number of distinct values across all samples.

* `total` – total number of samples included.

* `top_values` – all observed attribute values, ordered from most to least frequent.

In [None]:
# convert samples.data list of dicts to a DataFrame
samples_df = pd.DataFrame(samples.data).dropna(axis=1, how='all')

# compute summary statistics for most common attribute values
samples_summary = pd.DataFrame({
    'unique': samples_df.nunique(),
    'total': samples_df.shape[0],
    'top_values': samples_df.apply(
        lambda col: " / ".join(col.value_counts(dropna=False).index.astype(str))
    )
})

# show summary statistics
samples_summary.sort_values('unique')
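
To make the three summary columns concrete, the sketch below computes the same kind of summary on a few made-up records using only the standard library. The records and the `summarize` helper are illustrative assumptions, not ODM output. Mirroring the cell above, missing values are excluded from `unique` but still listed in `top_values`.

```python
from collections import Counter

# made-up sample records for illustration only
records = [
    {"Tissue": "liver", "Sex": "female"},
    {"Tissue": "liver", "Sex": "male"},
    {"Tissue": "liver", "Sex": "male"},
    {"Tissue": "liver"},  # 'Sex' is missing for this sample
]


def summarize(records):
    """Per-field summary: distinct values, sample count, values by frequency."""
    fields = sorted({key for record in records for key in record})
    summary = {}
    for field in fields:
        counts = Counter(record.get(field) for record in records)
        summary[field] = {
            "unique": sum(1 for value in counts if value is not None),
            "total": len(records),
            "top_values": " / ".join(str(v) for v, _ in counts.most_common()),
        }
    return summary


for field, stats in summarize(records).items():
    print(field, stats)
```

A field with a single unique value (here `Tissue`) is constant across the cohort, while a field with a handful of values (here `Sex`) is a candidate grouping variable.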

### 3.4 Exploring Integration Endpoints

Unlike the previous sample endpoints that search for samples using their own attributes, integration endpoints allow you to find entities based on the attributes of related entities.
In this example, studies can be retrieved by filtering on the attributes of samples that belong to them.

The `StudyIntegrationAsUserApi` class provides methods for querying studies through their relationships with other entities such as samples, libraries, preparations, or files.

In [None]:
# initialize API class
study_integration_api = odm_api.StudyIntegrationAsUserApi(api_client)

# list all available study_integration_api endpoints
for item in [item for item in dir(study_integration_api) if item.endswith("_as_user")]:
    print(item)

### 3.5 Retrieving Studies by Associated Samples

Integration endpoints identify higher-level entities based on their connected metadata objects.
Here, the `get_studies_by_samples_as_user` endpoint retrieves the studies that contain matching samples, using the same sample query and filter as in the previous step.

In [None]:
# define search parameters
sample_query = 'steatohepatitis'
sample_filter = 'Organism="Homo sapiens" AND Tissue="liver"'

# search studies by the attributes of their samples
studies = study_integration_api.get_studies_by_samples_as_user(
    filter=sample_filter,
    query=sample_query
)

# show associated studies metadata
studies_df = pd.DataFrame(studies.data).sort_values('genestack:accession')
studies_df = studies_df.drop_duplicates(subset='Study Source ID', keep='first')
studies_df = studies_df[~studies_df['Study Title'].str.contains('Test study')]
studies_df

### 3.6 Exploring Omics Query Endpoints

Omics query endpoints provide access to quantitative datasets such as gene expression, variant, or flow cytometry data, along with their associated metadata.
These endpoints extend integration capabilities, enabling direct retrieval and filtering of omics measurements linked to specific samples or studies.

In this step, the `OmicsQueriesAsUserApi` class is instantiated, and its available methods are listed to illustrate the range of omics data types accessible through the ODM API.

In [None]:
# initialize API class
omics_api = odm_api.OmicsQueriesAsUserApi(api_client)

# list all available omics_api endpoints
for item in [item for item in dir(omics_api) if item.endswith("_as_user")]:
    print(item)

### 3.7 Searching Samples via Omics Query Endpoints

The `omics_search_samples_as_user` endpoint enables sample metadata search within the omics query interface.
It functions similarly to the standard sample search but supports additional integration — combining study-level filters with sample attributes and linking downstream omics data types.

The API response includes:

* `log` – textual summary of matched studies and samples.

* `data` – a list of metadata records for the retrieved samples.

This combined search allows coordinated retrieval of samples and their related omics datasets across studies.

In [None]:
# define search parameters
study_ids = studies_df["genestack:accession"].tolist()
study_filter = ' OR '.join(f'"genestack:accession"="{acc}"' for acc in study_ids)
sample_query = '"steatohepatitis" OR "healthy"'
sample_filter = 'Organism="Homo sapiens" AND Tissue="liver"'

# search samples with both study and sample query/filter parameters
omics_samples = omics_api.omics_search_samples_as_user(
    study_filter=study_filter,
    sample_query=sample_query,
    sample_filter=sample_filter,
    returned_metadata_fields='original_data_included'
)

# show first items from API response
print(json.dumps(omics_samples.log, indent=2))
print(json.dumps(omics_samples.data[0], indent=2))
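
The `OR`-joined accession filter built in the cell above follows a pattern that recurs later in the notebook, so it can be factored into a small helper. The sketch below is one possible way to do this; `build_or_filter` is a hypothetical name introduced here, not part of the ODM SDK, and the accessions are made up for illustration.

```python
def build_or_filter(field, values):
    """Join `"field"="value"` clauses with OR, quoting the field name and each value."""
    # drop duplicates while preserving order so the filter stays short
    unique_values = list(dict.fromkeys(values))
    if not unique_values:
        raise ValueError("at least one value is required to build a filter")
    return " OR ".join(f'"{field}"="{value}"' for value in unique_values)


# hypothetical accessions, used only for illustration
example_filter = build_or_filter(
    "genestack:accession", ["GSF000001", "GSF000002", "GSF000001"]
)
print(example_filter)
```

The helper does not escape double quotes inside values, so it is only suitable for identifiers such as accessions. Raising on an empty list avoids sending an empty filter, which would otherwise silently widen the search.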

### 3.8 Visualizing Sample Metadata Distributions

This section illustrates how metadata attributes can be summarized and explored visually.
By converting the retrieved sample metadata JSON into a structured DataFrame, we can examine the distribution of key attributes such as study group, disease state, and sex.

The bar plots show how samples are distributed across these categories, enabling quick inspection of cohort balance and potential biases.
Visualizing attribute frequencies at this stage helps verify that metadata relationships are consistent before performing downstream omics queries or expression analyses.

In [None]:
# convert omics_samples.data list of nested dicts to a DataFrame
omics_samples_df = pd.DataFrame([
    item['metadata'] for item in omics_samples.data
])
omics_samples_df.dropna(axis=1, how='all', inplace=True)

# shorten the disease name for plotting; NASH is the legacy abbreviation (the current term is MASH)
omics_samples_df.loc[omics_samples_df['Disease'].eq(
    "metabolic dysfunction-associated steatohepatitis"
), 'Disease'] = 'NASH'

# summarize each attribute: distinct-value count, record count, and values ordered by frequency
omics_samples_summary = pd.DataFrame({
    'unique': omics_samples_df.nunique(),
    'total': omics_samples_df.shape[0],
    'top_values': omics_samples_df.apply(
        lambda col: " / ".join(col.value_counts(dropna=False).index.astype(str))
    )
})

# show summary statistics
omics_samples_summary.sort_values('unique')
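
To make the shape of this summary easier to follow, the sketch below rebuilds it on a small made-up table. The records are hypothetical and only mimic the `metadata` nesting of `omics_samples.data`. The `missing` column and the cap of three values per attribute are additions for illustration, not part of the cell above.

```python
import pandas as pd

# hypothetical records mimicking the shape of `omics_samples.data`
records = [
    {"metadata": {"genestack:accession": "S1", "Disease": "NASH", "Sex": "female"}},
    {"metadata": {"genestack:accession": "S2", "Disease": "NASH", "Sex": "male"}},
    {"metadata": {"genestack:accession": "S3", "Disease": "healthy", "Sex": None}},
]
toy_df = pd.DataFrame([record["metadata"] for record in records])

toy_summary = pd.DataFrame({
    # number of distinct non-missing values per attribute
    "unique": toy_df.nunique(),
    # number of records where the attribute is missing
    "missing": toy_df.isna().sum(),
    # up to three values per attribute, most frequent first
    "top_values": toy_df.apply(
        lambda col: " / ".join(col.value_counts(dropna=False).index.astype(str)[:3])
    ),
})
print(toy_summary.sort_values("unique"))
```

Attributes with a single unique value are constant across the cohort, while attributes with as many unique values as records (such as accessions) act as identifiers rather than grouping variables.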

In [None]:
# plot frequencies for disease and groupId attributes
group_counts = pd.crosstab(omics_samples_df['Disease'], omics_samples_df['groupId'])
x = group_counts.plot(kind='barh', stacked=True, colormap='tab20', figsize=(5, 2))
x.set_ylabel('')
x.set_title("Sample count per disease, colored by groupId")
x.legend(bbox_to_anchor=(1.05, 1))
plt.show()

# plot frequencies for disease and sex attributes
group_counts = pd.crosstab(omics_samples_df['Disease'], omics_samples_df['Sex'])
x = group_counts.plot(kind='barh', stacked=True, colormap='tab20', figsize=(5, 2))
x.set_ylabel('')
x.set_title("Sample count per disease, colored by sex")
x.legend(bbox_to_anchor=(1.05, 1))
plt.show()
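
Raw counts can hide imbalance when groups differ in size. As a complementary check, the sketch below computes row-wise proportions with `pd.crosstab(..., normalize="index")` on a made-up table; the values are illustrative and not taken from the ODM instance.

```python
import pandas as pd

# hypothetical cohort, used only for illustration
toy = pd.DataFrame({
    "Disease": ["NASH", "NASH", "NASH", "healthy", "healthy"],
    "Sex": ["female", "male", "male", "female", "female"],
})

# absolute counts per disease/sex combination
counts = pd.crosstab(toy["Disease"], toy["Sex"])

# share of each sex within a disease group (each row sums to 1)
proportions = pd.crosstab(toy["Disease"], toy["Sex"], normalize="index")

print(counts)
print(proportions.round(2))
```

In this toy cohort the healthy group contains only female samples, which is the kind of confounding between disease state and sex that the bar plots are meant to reveal.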

### 3.9 Retrieving Expression Measurements

The `omics_search_expression_data_as_user` endpoint returns quantitative omics measurements together with their provenance, processing metadata, and links to samples.
Here we query by feature (gene symbols) and restrict the results to the sample accessions retrieved earlier.

The response is a list of items; each item includes:

* `itemOrigin` – provenance identifiers (e.g., run and group).

* `metadata` – acquisition/processing details (platform, genome build, assay, files).

* `feature` – the targeted feature (e.g., gene symbol).

* `value` – the numeric measurement for that feature (representation depends on the dataset).

* `relationships` – linked entities such as the sample accession.


In [None]:
# define search parameters
genes = "CD36, CPT1A, CYP7A1, NR1H4, HMGCR, LDLR, LRP1, PPARA, SCARB1, SQLE"
genes = re.sub(r'\s+', '', genes)
ex_query = f"feature={genes}"
sample_ids = omics_samples_df["genestack:accession"].tolist()
sample_filter = ' OR '.join(f'"genestack:accession"="{acc}"' for acc in sample_ids)

# search expression data
ex_data = omics_api.omics_search_expression_data_as_user(
    ex_query=ex_query,
    sample_filter=sample_filter
)

# show first data item from API response
print(json.dumps(ex_data.data[0], indent=2))

After retrieving expression data, the next step is to combine all response items into a unified DataFrame for inspection.
Here, the metadata, feature, value, and relationship fields are merged, producing a tabular view of all expression records.

The summary table shows:

* `unique` – number of distinct values per attribute.

* `total` – total number of expression records.

* `top_values` – all observed values for each attribute, ordered from most to least frequent.

In [None]:
# convert ex_data.data list of nested dicts to a DataFrame
expression_df = pd.DataFrame([
    item['metadata'] | item['feature'] | item['value'] | item['relationships']
    for item in ex_data.data
])
expression_df.dropna(axis=1, how='all', inplace=True)

# summarize each attribute: distinct-value count, record count, and values ordered by frequency
ex_summary = pd.DataFrame({
    'unique': expression_df.nunique(),
    'total': expression_df.shape[0],
    'top_values': expression_df.apply(
        lambda col: " / ".join(col.value_counts(dropna=False).index.astype(str))
    )
})

# show summary statistics
ex_summary.sort_values('unique')

The expression data come from two sources with similar processing protocols, and in both the reads were mapped to the GRCh38 genome. This allows the expression values to be integrated across studies for further analysis.
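
One caveat of flattening with the dict union operator `|` is that, when two sections share a key, the right-hand value silently replaces the left-hand one. The sketch below checks for such collisions before merging. The item is hypothetical: its field names only mimic the response structure described above, and real keys depend on the dataset.

```python
# hypothetical response item; the real field names depend on the dataset
item = {
    "metadata": {"genestack:accession": "GSF_RUN_1", "Genome Version": "GRCh38"},
    "feature": {"gene": "NR1H4"},
    "value": {"value": 5.2},
    "relationships": {"sample": "GSF_SAMPLE_1"},
}

sections = ("metadata", "feature", "value", "relationships")

# collect keys that occur in more than one section
seen = {}
collisions = set()
for section in sections:
    for key in item[section]:
        if key in seen:
            collisions.add(key)
        seen[key] = section

# same merge as in the flattening cell (dict union requires Python 3.9+)
flat = item["metadata"] | item["feature"] | item["value"] | item["relationships"]

print("colliding keys:", collisions)
print(flat)
```

If `collisions` is non-empty for real data, prefixing the keys of each section (for example `metadata.` or `feature.`) before merging is one way to keep both values.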

### 3.10 Visualizing Gene Expression Distributions

Here, expression data are merged with sample metadata to enable joint visualization and context-aware interpretation.

In [None]:
# merge samples and expression data
combined_data = pd.merge(
    omics_samples_df,
    expression_df,
    left_on="genestack:accession",
    right_on="sample"
)
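
The default inner merge silently drops samples without expression records and expression records without a matching sample. A quick way to audit this, sketched below on made-up tables, is an outer merge with `indicator=True`, plus `validate="one_to_many"` to confirm that each sample accession occurs only once in the metadata table. Column names follow the cell above; the data are hypothetical.

```python
import pandas as pd

# hypothetical sample metadata: one row per sample
samples = pd.DataFrame({
    "genestack:accession": ["S1", "S2", "S3"],
    "Disease": ["NASH", "NASH", "healthy"],
})

# hypothetical expression records: one row per sample/gene pair
expression = pd.DataFrame({
    "sample": ["S1", "S1", "S2", "S4"],
    "gene": ["CD36", "LDLR", "CD36", "CD36"],
    "value": [3.1, 7.4, 2.8, 5.0],
})

checked = pd.merge(
    samples,
    expression,
    left_on="genestack:accession",
    right_on="sample",
    how="outer",             # keep unmatched rows from both sides
    validate="one_to_many",  # raise if a sample accession is duplicated
    indicator=True,          # adds a `_merge` column with the row origin
)
print(checked["_merge"].value_counts())
```

Rows marked `left_only` are samples with no expression data and rows marked `right_only` are measurements whose sample was not retrieved; both are worth inspecting before plotting.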

The boxplot below summarizes the data distribution: each box represents the distribution of expression values for one gene across all retrieved samples.

In [None]:
# plot expression values for each gene
x = combined_data.boxplot(column='value', by='gene', rot=0, figsize=(7, 4))
x.set_title("Expression of genes")
x.set_xlabel('')
plt.suptitle('')
plt.show()

This heatmap presents normalized expression profiles for the selected genes, grouped by a chosen metadata attribute.
Here, expression values are transformed using a log1p scale and standardized as per-gene z-scores, allowing differences in relative expression levels to be compared across samples.

In [None]:
plot_grouped_heatmap(combined_data, group_by="Disease", log1p=True)
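
The internals of `plot_grouped_heatmap` are defined in Section 1.2 and are not reproduced here. As a rough illustration of the transformation described above, the sketch below applies `log1p` and a per-gene z-score to a made-up long-format table. The column names and the use of the sample standard deviation (pandas default, `ddof=1`) are assumptions, not necessarily what the helper does.

```python
import numpy as np
import pandas as pd

# hypothetical long-format expression table
toy = pd.DataFrame({
    "gene": ["CD36"] * 3 + ["LDLR"] * 3,
    "sample": ["S1", "S2", "S3"] * 2,
    "value": [0.0, 10.0, 100.0, 5.0, 5.0, 50.0],
})

# log1p compresses the dynamic range and keeps zeros finite
toy["log_value"] = np.log1p(toy["value"])

# per-gene z-score: subtract the gene mean, divide by the gene standard deviation
grouped = toy.groupby("gene")["log_value"]
toy["zscore"] = (toy["log_value"] - grouped.transform("mean")) / grouped.transform("std")

# genes x samples matrix, the usual input shape for a heatmap
matrix = toy.pivot(index="gene", columns="sample", values="zscore")
print(matrix.round(2))
```

Because each gene is standardized separately, the colors of such a heatmap compare samples within a gene, not absolute expression levels between genes.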

### 3.11 Comparing Gene Expression Across Groups

This plot visualizes the distribution of expression values for each gene, grouped by a categorical metadata attribute.

Here, log1p-transformed expression levels are shown as violin plots, providing a combined view of data density and variability across groups.

In [None]:
plot_grouped_violin(combined_data, group_by="Disease", log1p=True, min_figsize=10, height=6)

This plot shows expression distributions for each gene, grouped by a continuous metadata variable that has been binned into discrete intervals.

Here, expression values remain on the original scale, allowing direct interpretation of magnitude differences.

In [None]:
combined_data_filtered = combined_data.loc[combined_data["Disease"].eq("NASH")].copy()
plot_grouped_violin(combined_data_filtered, group_by="Age", log1p=False, min_figsize=10, height=6)
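
How `plot_grouped_violin` bins a continuous attribute is determined by the helper in Section 1.2. For readers who want explicit control over the intervals, the sketch below shows one common approach with `pd.cut`. The bin edges, labels, and values are illustrative assumptions; metadata values are converted with `pd.to_numeric` first because they may arrive as strings.

```python
import pandas as pd

# hypothetical raw metadata values; ages may arrive as strings
raw_ages = pd.Series(["23", "35", "41", "52", "58", "67", "unknown"], name="Age")

# non-numeric entries become NaN instead of raising
ages = pd.to_numeric(raw_ages, errors="coerce")

# right-closed intervals: (0, 40], (40, 60], (60, 120]
age_group = pd.cut(ages, bins=[0, 40, 60, 120], labels=["<=40", "41-60", ">60"])

print(age_group.value_counts(sort=False))
```

Samples with a missing or non-numeric age receive no bin, so they should be counted and reported separately rather than dropped without notice.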

### 3.12 Filtering Samples by Quantitative and Qualitative Conditions

This example demonstrates how `omics_search_samples_as_user` can perform complex, multi-dimensional filtering across metadata and expression data.
Unlike earlier queries limited to categorical metadata (e.g., organism or tissue), this query also introduces quantitative and omics-based conditions within a single search.

Here, multiple query layers are combined:

* Categorical filters define biological context (organism, tissue, or study).

* Quantitative filters specify numeric constraints on metadata attributes (e.g., thresholds for age or concentration).

* Expression filters restrict samples by quantitative omics measurements, such as a minimum expression value for a gene of interest.

The API resolves these conditions in sequence, returning only samples that satisfy all specified filters.

In [None]:
# define search parameters
sample_query = "steatohepatitis"
sample_filter = 'Organism="Homo sapiens" AND Tissue="liver"'
ex_query = 'feature=NR1H4 value >= 4'

# search samples with both qualitative and quantitative query/filter parameters
omics_samples = omics_api.omics_search_samples_as_user(
    sample_query=sample_query,
    sample_filter=sample_filter,
    ex_query=ex_query
)

# show the query log returned by the API
print(json.dumps(omics_samples.log, indent=2))

# extract sample ids
sample_ids = [
    item['metadata'].get('genestack:accession') for item in omics_samples.data
]
print("genestack:accession:", sample_ids)
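
To illustrate the logic of combining the filter layers, the sketch below reproduces the same conditions client-side on a made-up table with pandas boolean masks. It does not describe how the ODM server evaluates the query; the data and column layout are hypothetical.

```python
import pandas as pd

# hypothetical table with one NR1H4 measurement per sample
toy = pd.DataFrame({
    "sample": ["S1", "S2", "S3", "S4"],
    "Organism": ["Homo sapiens", "Homo sapiens", "Mus musculus", "Homo sapiens"],
    "Tissue": ["liver", "liver", "liver", "blood"],
    "gene": ["NR1H4"] * 4,
    "value": [6.5, 2.0, 9.1, 7.3],
})

# categorical layer: biological context
categorical = toy["Organism"].eq("Homo sapiens") & toy["Tissue"].eq("liver")

# expression layer: minimum expression of the gene of interest
quantitative = toy["gene"].eq("NR1H4") & toy["value"].ge(4)

# a sample is returned only if it satisfies every layer
selected = toy.loc[categorical & quantitative, "sample"].tolist()
print(selected)
```

Running the same conditions on data already retrieved can serve as a sanity check that the samples returned by the API match the intended cohort definition.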

---

## 4. Summary and Next Steps

This notebook completes the preprocessing and analysis steps up to the integration of transcriptomics features and protein–gene mappings. The final multi-omics integration step (retrieval of matched transcriptomics expression data and downstream analyses) will be added once the transcriptomics expression group and API parameters are confirmed.

---
