In [1]:
import utils
import anndata
import numpy as np
import pandas as pd
import requests
import scanpy as sc
import scanpy.external as sce
import scanorama
import os
import anndata
from typing import List, Optional
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import umap
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.metrics import silhouette_score
from sklearn.cluster import KMeans
import plotly.express as px
import plotly.graph_objects as go

In [2]:
def _is_categorical_column(values):
    """Return True when a label column should be drawn as discrete groups.

    Object, categorical and pandas string dtypes are treated as categorical;
    every other dtype is coloured on a continuous scale.
    """
    dtype = values.dtype
    return bool(
        dtype == 'object'
        or dtype.name == 'category'
        or pd.api.types.is_string_dtype(dtype)
    )


def _ordered_categories(values):
    """Return the distinct non-missing entries of ``values`` in sorted order.

    Missing entries are dropped first because NaN cannot be ordered against
    strings. If the remaining values are not mutually comparable (for example
    a mix of str and int) they are ordered by their string representation.
    """
    distinct = list(values.dropna().unique())
    try:
        return sorted(distinct)
    except TypeError:
        return sorted(distinct, key=str)


def plot_pca_2d_interactive(data, labels_df, plotname):
    """Project ``data`` onto two principal components and plot it with Plotly.

    One group of scatter traces is created per column of ``labels_df``; a
    dropdown switches which column colours the points. Categorical columns get
    one trace per distinct value (plus a grey "NaN" trace for missing values),
    other columns get a single trace with a continuous Viridis colour scale.

    Parameters
    ----------
    data : array-like
        Sample-by-feature matrix handed to ``PCA.fit_transform``.
    labels_df : pandas.DataFrame
        One row per sample, in the same order as ``data``. Rows are matched by
        position; the index is ignored.
    plotname : str
        Prefix of the output file ``"{plotname}_interactive.html"``.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure that was written to disk and shown.

    Raises
    ------
    ValueError
        If ``labels_df`` and ``data`` do not have the same number of rows.
    """
    print("Generating interactive 2D PCA plot...")

    # Perform PCA
    pca = PCA(n_components=2)
    pca_result = pca.fit_transform(data)
    explained_var = pca.explained_variance_ratio_

    if len(labels_df) != len(pca_result):
        raise ValueError(
            f"labels_df has {len(labels_df)} rows but data has "
            f"{len(pca_result)} samples"
        )

    # Coordinates are kept apart from the label columns so that a label column
    # that happens to be called 'PC1' or 'PC2' cannot overwrite them.
    pc1, pc2 = pca_result[:, 0], pca_result[:, 1]
    hover = "PC1: %{x:.3f}<br>PC2: %{y:.3f}<extra></extra>"
    palette = px.colors.qualitative.Set1

    fig = go.Figure()
    column_traces = {}          # column -> indices of its traces in fig.data
    column_is_categorical = {}  # column -> bool, reused for the button labels

    for i, col in enumerate(labels_df.columns):
        # .values drops the index so labels line up with PCA rows by position.
        values = pd.Series(labels_df[col].values)
        categorical = _is_categorical_column(values)
        column_is_categorical[col] = categorical
        show_initially = i == 0  # only the first column is visible at start
        first_trace = len(fig.data)

        if categorical:
            # One trace per category so each gets its own legend entry.
            for j, val in enumerate(_ordered_categories(values)):
                mask = (values == val).fillna(False).to_numpy(dtype=bool)
                fig.add_trace(
                    go.Scatter(
                        x=pc1[mask],
                        y=pc2[mask],
                        mode='markers',
                        marker=dict(
                            size=6,
                            color=palette[j % len(palette)],
                            opacity=0.7
                        ),
                        name=str(val),
                        hovertemplate=hover,
                        visible=show_initially,
                        legendgroup=col
                    )
                )

            # Samples without a label would otherwise disappear from the plot.
            missing = values.isna().to_numpy()
            if missing.any():
                fig.add_trace(
                    go.Scatter(
                        x=pc1[missing],
                        y=pc2[missing],
                        mode='markers',
                        marker=dict(size=6, color='lightgray', opacity=0.7),
                        name='NaN',
                        hovertemplate=hover,
                        visible=show_initially,
                        legendgroup=col
                    )
                )
        else:
            # For numerical data, use a continuous colour scale.
            fig.add_trace(
                go.Scatter(
                    x=pc1,
                    y=pc2,
                    mode='markers',
                    marker=dict(
                        size=6,
                        color=values,
                        colorscale='Viridis',
                        opacity=0.7,
                        colorbar=dict(title=col),
                        showscale=True
                    ),
                    name=col,
                    hovertemplate=hover,
                    visible=show_initially,
                    showlegend=False
                )
            )

        column_traces[col] = list(range(first_trace, len(fig.data)))

    # Create dropdown buttons: each one shows exactly the traces of its column.
    n_traces = len(fig.data)
    dropdown_buttons = []
    for col in labels_df.columns:
        selected = set(column_traces[col])
        visible = [idx in selected for idx in range(n_traces)]
        kind = 'Categorical' if column_is_categorical[col] else 'Continuous'
        dropdown_buttons.append(
            dict(
                args=[{"visible": visible}],
                label=f"{col} ({kind})",
                method="restyle"
            )
        )

    if dropdown_buttons:
        fig.update_layout(
            updatemenus=[
                dict(
                    buttons=dropdown_buttons,
                    direction="down",
                    pad={"r": 10, "t": 10},
                    showactive=True,
                    x=0.01,
                    xanchor="left",
                    y=1.02,
                    yanchor="top",
                    bgcolor="black",
                    bordercolor="gray",
                    borderwidth=1
                ),
            ],
            annotations=[
                dict(
                    text="Color by:",
                    showarrow=False,
                    x=0.01, y=1.05,
                    xref="paper", yref="paper",
                    align="left",
                    font=dict(size=12)
                )
            ]
        )
    else:
        # No label columns: still draw the projection, just without colouring.
        fig.add_trace(
            go.Scatter(
                x=pc1,
                y=pc2,
                mode='markers',
                marker=dict(size=6, opacity=0.7),
                hovertemplate=hover,
                showlegend=False
            )
        )

    # Update layout
    fig.update_layout(
        title='Interactive PCA Plot',
        xaxis_title=f'PC1 ({explained_var[0]:.1%} variance explained)',
        yaxis_title=f'PC2 ({explained_var[1]:.1%} variance explained)',
        width=900,
        height=700,
        hovermode='closest',
        legend=dict(
            yanchor="top",
            y=0.99,
            xanchor="left",
            x=1.01
        )
    )

    # Save and display
    filename = f"{plotname}_interactive.html"
    fig.write_html(filename)
    fig.show()

    # Print summary information
    print("\nPCA Summary:")
    print(f"- PC1 explains {explained_var[0]:.1%} of variance")
    print(f"- PC2 explains {explained_var[1]:.1%} of variance")
    print(f"- Total variance explained: {sum(explained_var):.1%}")
    print(f"\nAvailable columns: {list(labels_df.columns)}")
    print(f"Plot saved as: {filename}")

    return fig

In [3]:
# Parent of the current working directory, kept with a trailing separator
# because later cells build file names as `path + 'data/...'`.
# Taken with dirname rather than by slicing off a fixed number of characters,
# so it does not depend on how long the working directory's name is.
path = os.path.join(os.path.dirname(os.getcwd()), '')

In [3]:

# Read data/processed_613_data.h5ad (located under `path`) with scanpy.
anndata_613 = sc.read_h5ad(f"{path}data/processed_613_data.h5ad")

In [4]:
# Remove the 'dataset' column from the per-cell metadata. errors='ignore'
# keeps this line from raising KeyError when the cell is executed again after
# the column has already been removed.
anndata_613.obs = anndata_613.obs.drop(columns=['dataset'], errors='ignore')

# Flag the 2000 top-ranked genes in .var['highly_variable'].
# NOTE(review): the 'cell_ranger' flavor is assumed to fit how this file was
# preprocessed; confirm against the preprocessing step.
sc.pp.highly_variable_genes(anndata_613, n_top_genes=2000, flavor='cell_ranger')

# Keep only the flagged genes. .copy() stores an independent object instead of
# a view, so the full-size parent is no longer referenced through this name.
anndata_613 = anndata_613[:, anndata_613.var['highly_variable']].copy()

In [5]:
# Convert the expression matrix to a dense array via `.toarray()`.
# Presumably this runs on the 2000-gene subset produced by the previous cell
# (the shape printed below is (116403, 2000)) - confirm the execution order.
arr_613 = anndata_613.X.toarray()
# NOTE(review): the commented-out lines below are leftovers; `batch_corrected`
# and `unbatch_corrected` are not assigned in any visible cell.
#unbatch_corrected = unbatch_corrected.X.toarray()
#combined_data = np.vstack([batch_corrected, unbatch_corrected])
#labels = np.array(['Batch Corrected'] * batch_corrected.shape[0] + ['UnBatch Corrrected'] * unbatch_corrected.shape[0])

In [11]:
# Show the matrix dimensions, then the per-cell metadata table, one per line.
print(anndata_613.shape, anndata_613.obs, sep='\n')

(116403, 2000)
                     batch  n_genes       Strain     Sex  Age at Launch  \
AAACAGCCAAAGCGCA-1       1     3001  C57BL/6NTac  Female            0.0   
AAACATGCAATTAGGA-1       1     1767  C57BL/6NTac  Female            0.0   
AAACATGCAGGGAGGA-1       1     3368  C57BL/6NTac  Female            0.0   
AAACATGCATGTTGCA-1       1     1074  C57BL/6NTac  Female            0.0   
AAACCAACACAACCTA-1       1     1355  C57BL/6NTac  Female            0.0   
...                    ...      ...          ...     ...            ...   
TTTGTTGGTAACGAGG-18     18     2304  C57BL/6NTac  Female            0.0   
TTTGTTGGTAGGATTT-18     18     2837  C57BL/6NTac  Female            0.0   
TTTGTTGGTGCTGGTG-18     18     2290  C57BL/6NTac  Female            0.0   
TTTGTTGGTGCTTACT-18     18     3550  C57BL/6NTac  Female            0.0   
TTTGTTGGTTCGGGAT-18     18     1422  C57BL/6NTac  Female            0.0   

                     Duration  Flight  
AAACAGCCAAAGCGCA-1       53.0     1.0  
AAAC

In [4]:
# Second, untouched load of data/processed_613_data.h5ad, kept under its own
# name so it stays separate from `anndata_613`.
anndata_613_no_gene_drop = sc.read_h5ad(f"{path}data/processed_613_data.h5ad")

In [None]:

# Interactive PCA over every gene of the unfiltered object, coloured by each
# column of its .obs table; writes
# 'Interactive_613_no_gene_selection_PCA_interactive.html'.
# NOTE(review): `.X.toarray()` densifies the whole matrix. A later cell that
# densifies this same file is recorded failing with MemoryError (11.0 GiB for
# shape (116403, 25323), float32), so this call likely needs as much memory.
plot_pca_2d_interactive(anndata_613_no_gene_drop.X.toarray(), anndata_613_no_gene_drop.obs, 'Interactive_613_no_gene_selection_PCA')

Generating interactive 2D PCA plot...


In [0]:
# Load the two .h5ad files from the data/ folder under `path`; the "_no_arr"
# suffix marks these as the loaded objects rather than dense arrays.
batch_corrected_no_arr = sc.read_h5ad(f"{path}data/corrected_data.h5ad")
unbatch_corrected_no_arr = sc.read_h5ad(f"{path}data/unbatch_corrected_data.h5ad")

In [0]:
# The column is read as 'dataset' (lower case), matching the next cell, which
# reads obs['dataset'] from these same two objects.
# NOTE(review): confirm no separate 'Dataset' column was intended.
print(batch_corrected_no_arr.obs['dataset'])

# One label per cell, batch-corrected cells first, then the uncorrected ones.
# A Flight value equal to 0 becomes 'Ground Control'; every other value
# (including missing ones) becomes 'Flight'.
labels = [
    'Ground Control' if flight == 0 else 'Flight'
    for adata in (batch_corrected_no_arr, unbatch_corrected_no_arr)
    for flight in adata.obs['Flight']
]
print(set(labels))

In [64]:
# Display the dataset column of the uncorrected object.
print(unbatch_corrected_no_arr.obs['dataset'])
# Dataset label of every cell: batch-corrected cells first, uncorrected after.
labels = [
    *batch_corrected_no_arr.obs['dataset'],
    *unbatch_corrected_no_arr.obs['dataset'],
]

AAACAGCCAAGGTGCA-1     352
AAACAGCCACAATTAC-1     352
AAACAGCCACAGGGAC-1     352
AAACAGCCACCTCACC-1     352
AAACAGCCAGGCATCT-1     352
                      ... 
TTTGTTGGTAACGAGG-18    613
TTTGTTGGTAGGATTT-18    613
TTTGTTGGTGCTGGTG-18    613
TTTGTTGGTGCTTACT-18    613
TTTGTTGGTTCGGGAT-18    613
Name: dataset, Length: 212653, dtype: category
Categories (3, object): ['352', '612', '613']


In [71]:
# Reload the processed 613 file; this rebinds `anndata_613`.
anndata_613 = sc.read_h5ad(path + 'data/processed_613_data.h5ad')
# NOTE(review): this assignment was left unfinished (`sc.`), which made the
# whole cell a syntax error. The per-cell metadata table is used here because
# that is what is passed as labels to plot_pca_2d_interactive elsewhere in the
# notebook - confirm this is what was meant.
labels_613 = anndata_613.obs
# NOTE(review): densifying the full matrix was recorded failing with
# MemoryError (11.0 GiB for shape (116403, 25323), float32). Consider selecting
# genes first, as done for `arr_613`. The variable name ("annadata") is kept
# unchanged in case other cells refer to it.
annadata_613_arr = anndata_613.X.toarray()

MemoryError: Unable to allocate 11.0 GiB for an array with shape (116403, 25323) and data type float32

In [69]:

# 2D PCA plot of the uncorrected data, labelled by each cell's dataset.
# NOTE(review): neither `plot_pca_2d` nor `unbatch_corrected` is defined in the
# visible cells (the only assignment to `unbatch_corrected` is commented out),
# so this cell appears to rely on kernel state from elsewhere - confirm it
# still runs after Restart & Run All.
plot_pca_2d(unbatch_corrected, list(unbatch_corrected_no_arr.obs['dataset']), 'unbatch_dataset_labels')

Generating 2D PCA plot...
