# Tutorial 2: Using ``CAP``

The ``CAP`` class performs co-activation pattern (CAP) analyses. It can analyze all subjects together or separate
groups of subjects, compute CAP-specific metrics, and generate visualizations to aid in the interpretation of
results.


In [None]:
# Download packages
try:
    import neurocaps
except ImportError:
    !pip install neurocaps[windows,demo]

# Set headless display for google colab
import os, sys

if "google.colab" in sys.modules:
    os.environ["DISPLAY"] = ":0.0"
    !apt-get install -y xvfb
    !Xvfb :0 -screen 0 1024x768x24 &> /dev/null &

## Performing CAPs on All Subjects
All information pertaining to CAPs (k-means models, activation vectors/cluster centroids, etc.) is stored as attributes
in the ``CAP`` class, and this information is used by all methods in the class. These attributes are accessible via
[properties](https://neurocaps.readthedocs.io/en/stable/generated/neurocaps.analysis.CAP.html#properties).
**Some properties can also be used as setters.**

In [None]:
import numpy as np
from neurocaps.analysis import CAP

# Parcellation approach used for extracting timeseries
parcel_approach = {"Schaefer": {"n_rois": 100, "yeo_networks": 7, "resolution_mm": 2}}

# Simulate data for example
subject_timeseries = {
    str(x): {f"run-{y}": np.random.rand(100, 100) for y in range(1, 4)} for x in range(1, 11)
}

# Initialize CAP class
cap_analysis = CAP(parcel_approach=parcel_approach)

# Get CAPs
cap_analysis.get_caps(
    subject_timeseries=subject_timeseries,
    n_clusters=range(2, 11),
    cluster_selection_method="elbow",
    show_figs=True,
    step=2,
    progress_bar=True,
)
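
For intuition about what is being clustered: a CAP analysis runs k-means on individual frames pooled across subjects
and runs. The sketch below uses only NumPy and is not NeuroCAPs' internal code. It stacks simulated data of the same
shape as above and confirms the size of the resulting frames-by-ROIs matrix. The helper name ``stack_frames`` is
hypothetical.

```python
import numpy as np

# Simulated data with the same layout as the tutorial: 10 subjects, 3 runs each,
# 100 frames x 100 ROIs per run
rng = np.random.default_rng(0)
subject_timeseries = {
    str(sub): {f"run-{run}": rng.random((100, 100)) for run in range(1, 4)}
    for sub in range(1, 11)
}


def stack_frames(timeseries):
    """Vertically stack every run of every subject into one (frames x ROIs) array."""
    return np.vstack([run for runs in timeseries.values() for run in runs.values()])


concatenated = stack_frames(subject_timeseries)
print(concatenated.shape)  # (3000, 100)
```

Each of the 3000 rows is one frame, so each cluster centroid found by k-means has one value per ROI.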

``print`` can be used to display a string representation of the ``CAP`` class.

In [None]:
print(cap_analysis)

## Performing CAPs on Groups

In [None]:
cap_analysis = CAP(groups={"A": ["1", "2", "3", "5"], "B": ["4", "6", "7", "8", "9", "10"]})

cap_analysis.get_caps(
    subject_timeseries=subject_timeseries,
    n_clusters=range(2, 21),
    cluster_selection_method="silhouette",
    show_figs=True,
    step=2,
    progress_bar=True,
)

# The concatenated data can be safely deleted since only the k-means models and any standardization parameters are
# used for computing temporal metrics.

del cap_analysis.concatenated_timeseries
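
The ``groups`` mapping assigns subject IDs to named groups, and CAPs are then derived per group. A quick sanity check
before running the analysis is to confirm that no subject appears in two groups and that every ID exists in the
timeseries dictionary. The plain-Python sketch below does this. The helper ``check_groups`` is hypothetical and not
part of NeuroCAPs.

```python
groups = {"A": ["1", "2", "3", "5"], "B": ["4", "6", "7", "8", "9", "10"]}
available_ids = {str(x) for x in range(1, 11)}  # subject IDs present in the timeseries


def check_groups(groups, available_ids):
    """Return (IDs listed more than once, IDs absent from the data, IDs in no group)."""
    seen, duplicated = set(), set()
    for ids in groups.values():
        for subj_id in ids:
            if subj_id in seen:
                duplicated.add(subj_id)
            seen.add(subj_id)
    return duplicated, seen - available_ids, available_ids - seen


duplicated, missing, unassigned = check_groups(groups, available_ids)
print(duplicated, missing, unassigned)  # set() set() set()
```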

## Calculate Metrics

Note that if ``standardize`` was set to True in ``CAP.get_caps()``, then the column (ROI) means and standard deviations
computed from the concatenated data used to obtain the CAPs are also used to standardize each subject in the timeseries
data passed to ``CAP.calculate_metrics()``. This ensures proper CAP assignments for each subject's frames.
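
The idea can be sketched with NumPy alone. This is an illustration of the principle, not NeuroCAPs' internal code. The
variable names and the use of the sample standard deviation (``ddof=1``) are assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
concatenated = rng.random((3000, 100))  # stand-in for the pooled frames used to fit the CAPs
new_run = rng.random((100, 100))        # stand-in for one subject's run passed in later

# Per-ROI parameters are estimated once, from the data used to obtain the CAPs
roi_means = concatenated.mean(axis=0)
roi_stds = concatenated.std(axis=0, ddof=1)  # ddof=1 is an assumption

# The same parameters are reused for later data, so frames lie in the
# same space as the cluster centroids before being assigned to a CAP
standardized_run = (new_run - roi_means) / roi_stds
print(standardized_run.shape)  # (100, 100)
```

Standardizing each run with its own mean and standard deviation instead would shift frames relative to the centroids
and could change which CAP a frame is assigned to.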


In [None]:
df_dict = cap_analysis.calculate_metrics(
    subject_timeseries=subject_timeseries,
    return_df=True,
    metrics=["temporal_fraction", "counts", "transition_probability"],
    continuous_runs=True,
    progress_bar=True,
)

print(df_dict["temporal_fraction"])
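
To make two of these metrics concrete, the sketch below computes them by hand for a short, made-up sequence of CAP
labels. It uses common definitions (temporal fraction as the share of frames assigned to each CAP, transition
probability as row-normalized counts of consecutive label pairs). NeuroCAPs' exact definitions may differ in detail,
and the function names here are hypothetical.

```python
import numpy as np

# Made-up CAP assignment for 10 consecutive frames (CAPs indexed from 0)
labels = np.array([0, 0, 1, 1, 1, 0, 2, 2, 0, 0])
n_caps = 3


def temporal_fraction(labels, n_caps):
    """Proportion of frames assigned to each CAP."""
    return np.bincount(labels, minlength=n_caps) / labels.size


def transition_probability(labels, n_caps):
    """Row i, column j: share of transitions out of CAP i that go to CAP j."""
    counts = np.zeros((n_caps, n_caps))
    # Count every consecutive (from, to) pair, including a CAP persisting
    np.add.at(counts, (labels[:-1], labels[1:]), 1)
    row_totals = counts.sum(axis=1, keepdims=True)
    # Rows with no outgoing transitions stay at zero instead of dividing by zero
    return np.divide(counts, row_totals, out=np.zeros_like(counts), where=row_totals > 0)


print(temporal_fraction(labels, n_caps))  # [0.5 0.3 0.2]
print(transition_probability(labels, n_caps))
```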

## Plotting CAPs

In [None]:
import seaborn as sns

cap_analysis = CAP(
    parcel_approach={"Schaefer": {"n_rois": 100, "yeo_networks": 7, "resolution_mm": 1}}
)

cap_analysis.get_caps(subject_timeseries=subject_timeseries, n_clusters=6)

palette = sns.diverging_palette(260, 10, s=80, l=55, n=256, as_cmap=True)

kwargs = {
    "subplots": True,
    "fontsize": 14,
    "ncol": 3,
    "sharey": True,
    "tight_layout": False,
    "xlabel_rotation": 0,
    "hspace": 0.3,
    "cmap": palette,
}

cap_analysis.caps2plot(
    visual_scope="regions", plot_options="outer_product", show_figs=True, **kwargs
)

cap_analysis.caps2plot(
    visual_scope="nodes",
    plot_options="heatmap",
    xticklabels_size=7,
    yticklabels_size=7,
    show_figs=True,
)

## Generate Pearson Correlation Matrix

In [None]:
cap_analysis.caps2corr(annot=True, cmap="viridis", show_figs=True)

corr_dict = cap_analysis.caps2corr(return_df=True)
print(corr_dict["All Subjects"])
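
As a rough picture of what such a matrix contains, the sketch below computes pairwise Pearson correlations between
randomly generated vectors standing in for CAP centroids. It assumes each CAP is represented by one vector with a
value per ROI. It uses only NumPy and is not the implementation behind ``caps2corr``.

```python
import numpy as np

rng = np.random.default_rng(0)
cap_vectors = rng.standard_normal((6, 100))  # 6 hypothetical CAPs x 100 ROIs

# np.corrcoef treats each row as one variable, giving a CAP-by-CAP matrix
corr_matrix = np.corrcoef(cap_vectors)
print(corr_matrix.shape)  # (6, 6)
```

The diagonal is 1 (each CAP with itself) and the matrix is symmetric, so only the off-diagonal entries of one triangle
carry information.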

## Creating Surface Plots


In [None]:
from matplotlib.colors import LinearSegmentedColormap

# Create the colormap
colors = [
    "#1bfffe",
    "#00ccff",
    "#0099ff",
    "#0066ff",
    "#0033ff",
    "#c4c4c4",
    "#ff6666",
    "#ff3333",
    "#FF0000",
    "#ffcc00",
    "#FFFF00",
]

custom_cmap = LinearSegmentedColormap.from_list("custom_cold_hot", colors, N=256)

# Apply custom cmap to surface plots
cap_analysis.caps2surf(progress_bar=True, cmap=custom_cmap, size=(500, 100), layout="row")

## Plotting CAPs to Radar

In [None]:
radialaxis = {
    "showline": True,
    "linewidth": 2,
    "linecolor": "rgba(0, 0, 0, 0.25)",
    "gridcolor": "rgba(0, 0, 0, 0.25)",
    "ticks": "outside",
    "tickfont": {"size": 14, "color": "black"},
    "range": [0, 0.6],
    "tickvals": [0.1, "", "", 0.4, "", "", 0.6],
}

legend = {
    "yanchor": "top",
    "y": 0.99,
    "x": 0.99,
    "title_font_family": "Times New Roman",
    "font": {"size": 12, "color": "black"},
}

colors = {"High Amplitude": "red", "Low Amplitude": "blue"}

kwargs = {
    "radialaxis": radialaxis,
    "fill": "toself",
    "legend": legend,
    "color_discrete_map": colors,
    "height": 400,
    "width": 600,
}

cap_analysis.caps2radar(**kwargs)