In [73]:
import os
import torch
import pandas as pd
from pathlib import Path
import numpy as np
import matplotlib
from tqdm import tqdm
from itertools import combinations
from napatrackmater.Trackvector import (
    BROWNIAN_FEATURES
)
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


In [None]:
# ---------------------------------------------------------------------------
# Notebook configuration: dataset selection, analysis parameters, paths.
# ---------------------------------------------------------------------------

# Which dataset / channel to analyse.
dataset_name = 'Sixth'
channel = 'nuclei_'
home_folder = '/home/debian/jz/'

# Analysis parameters (tracklet_length, deltat and device are not referenced
# by the cells visible in this notebook section).
tracklet_length = 25
deltat = 10
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Integer class index -> cell-type label; the labels are compared against
# the 'Cell_Type' column of the tracks dataframe further down.
class_map_gbr = dict(enumerate(("Basal", "Radial", "Goblet")))

# Input / output locations, all derived from the settings above.
tracking_directory = (
    f'{home_folder}Mari_Data_Oneat/'
    f'Mari_{dataset_name}_Dataset_Analysis/nuclei_membrane_tracking/'
)
data_frames_dir = os.path.join(tracking_directory, 'dataframes/')
normalized_dataframe = os.path.join(
    data_frames_dir,
    f'goblet_basal_dataframe_normalized_{channel}.csv',
)
save_dir = os.path.join(tracking_directory, f'{channel}phasespaces')

# Load the normalised tracks table, then make sure the output folder exists.
print(f'reading data from {normalized_dataframe}')
tracks_dataframe = pd.read_csv(normalized_dataframe)
Path(save_dir).mkdir(exist_ok=True)

In [56]:
# Build a nested lookup of pairwise feature values:
#
#   result_dict[cell_type][time_point]["<feature1>_vs_<feature2>"]
#       -> list of {'Track ID': int, 'Pairwise Values': ndarray (n_rows, 2)}
#
# where n_rows is the number of dataframe rows that the track has at that
# time point for that cell type.
all_trackmate_ids = list(tracks_dataframe['TrackMate Track ID'].unique())
result_dict = {cell_type: {} for cell_type in class_map_gbr.values()}
unique_time_points = tracks_dataframe['t'].unique()

# Loop-invariant: every unordered pair of Brownian features together with the
# column positions of its two members inside the extracted feature matrix.
# Computing this once avoids rebuilding the combinations list and doing
# repeated list.index() scans for every single track.
feature_pair_columns = [
    (f"{feature_name1}_vs_{feature_name2}", column1, column2)
    for (column1, feature_name1), (column2, feature_name2)
    in combinations(enumerate(BROWNIAN_FEATURES), 2)
]

for time_point in unique_time_points:
    time_data = tracks_dataframe[tracks_dataframe['t'] == time_point]

    for cell_type in class_map_gbr.values():
        cell_type_data = time_data[time_data['Cell_Type'] == cell_type]
        if cell_type_data.empty:
            continue

        # Created only when this cell type actually has rows at this time.
        pairs_at_time = result_dict[cell_type].setdefault(time_point, {})

        for track_id in cell_type_data['Track ID'].unique():
            track_features = cell_type_data[
                cell_type_data['Track ID'] == track_id
            ][BROWNIAN_FEATURES].to_numpy()

            for pair_key, column1, column2 in feature_pair_columns:
                pairs_at_time.setdefault(pair_key, []).append({
                    'Track ID': int(track_id),
                    # Two-column slice: feature1 values, feature2 values.
                    'Pairwise Values': track_features[:, [column1, column2]],
                })


In [110]:
def build_dataframe(result_dict):
    """Flatten the nested phase-space dictionary into a long-format DataFrame.

    Parameters
    ----------
    result_dict : dict
        Mapping ``cell_type -> time_point -> "<f1>_vs_<f2>" -> list`` where
        each list item is a dict holding ``'Track ID'`` and
        ``'Pairwise Values'`` (an iterable of 2-element rows).

    Returns
    -------
    pandas.DataFrame
        One row per (track, pairwise-value row). Every row carries
        'Cell Type', 'Time Point', 'Feature Pair' and 'Track ID', plus two
        columns named after the features of its pair; the columns that
        belong to other feature pairs are left as NaN for that row.
    """

    def _iter_records():
        # Walk the four nesting levels and emit one record per value row.
        for cell_label, per_time in result_dict.items():
            for moment, per_pair in per_time.items():
                for pair_key, entries in per_pair.items():
                    for entry in entries:
                        tid = entry['Track ID']
                        value_rows = entry['Pairwise Values']

                        # The pair key encodes both column names.
                        first_name, second_name = pair_key.split('_vs_')
                        for first_value, second_value in value_rows:
                            yield {
                                'Cell Type': cell_label,
                                'Time Point': moment,
                                'Feature Pair': pair_key,
                                'Track ID': tid,
                                first_name: first_value,
                                second_name: second_value,
                            }

    return pd.DataFrame(list(_iter_records()))





def plot_phasespace(df, title='phasespace'):
    """Plot time-coloured 2D KDE phase spaces, one panel per (cell type, feature pair).

    For every cell type and feature pair found in ``df`` a panel is drawn in
    which each time point contributes one seaborn ``kdeplot`` of the two
    features, coloured by time through the viridis colormap. Each panel gets
    its own colour bar for the time axis.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with 'Cell Type', 'Time Point', 'Feature Pair'
        columns plus one column per feature name appearing in the
        "<f1>_vs_<f2>" pair keys.
    title : str
        Prefix used for the names of the saved PNG files.

    Side effects
    ------------
    Writes PNG files into the module-level ``save_dir`` (file names also use
    the module-level ``dataset_name``) and shows the figure.
    """
    if df.empty:
        # Nothing to draw; avoids a KeyError / zero-row subplot grid below.
        print("No phase-space data available to plot.")
        return

    cmap = plt.cm.viridis
    norm = matplotlib.colors.Normalize(vmin=df['Time Point'].min(), vmax=df['Time Point'].max())

    all_cell_types = list(df['Cell Type'].unique())
    all_feature_pairs = list(df['Feature Pair'].unique())

    # One panel is drawn per (cell type, feature pair) combination, so the
    # grid has to hold the product of both counts. Sizing it for the feature
    # pairs alone overflows the axes array as soon as a second cell type is
    # present.
    n_panels = len(all_cell_types) * len(all_feature_pairs)
    n_cols = 3
    n_rows = -(-n_panels // n_cols)  # ceiling division
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(15, 5 * n_rows),
        constrained_layout=True,
        squeeze=False,  # always a 2D array, whatever the grid shape
    )
    axes = axes.flatten()

    plot_idx = 0

    for cell_type in all_cell_types:
        cell_type_df = df[df['Cell Type'] == cell_type]

        for feature_pair in all_feature_pairs:
            pair_df = cell_type_df[cell_type_df['Feature Pair'] == feature_pair]
            print(f"Plotting for {cell_type} and feature pair {feature_pair}")

            actual_feature_name1, actual_feature_name2 = feature_pair.split('_vs_')
            ax = axes[plot_idx]

            for time_point in sorted(pair_df['Time Point'].unique()):
                time_df = pair_df[pair_df['Time Point'] == time_point]

                color = cmap(norm(time_point))

                sns.kdeplot(
                    data=time_df,
                    x=actual_feature_name1,
                    y=actual_feature_name2,
                    label=f"Time {time_point}",
                    alpha=0.5,
                    ax=ax,
                    color=color
                )

            ax.set_xlabel(f"{actual_feature_name1}", fontsize=12)
            ax.set_ylabel(f"{actual_feature_name2}", fontsize=12)
            ax.set_title(f"{cell_type}: {actual_feature_name1} vs {actual_feature_name2}", fontsize=14)

            # Only build a legend when there are labelled artists; otherwise
            # matplotlib just emits a "no artists with labels" warning.
            handles, _ = ax.get_legend_handles_labels()
            if handles:
                ax.legend(title="Time Points", loc="upper right", fontsize=10)

            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            sm.set_array([])
            cbar = fig.colorbar(sm, ax=ax, orientation='vertical')
            cbar.set_label('Time Point', fontsize=12)

            # NOTE(review): this writes the *whole* figure as drawn so far,
            # not just this panel, under a per-panel file name. Kept as in
            # the original; confirm whether a cropped single-panel image was
            # intended.
            plot_filename = os.path.join(save_dir, f"{title}_{cell_type}_{actual_feature_name1}_{actual_feature_name2}_{dataset_name}_phasespace.png")
            fig.savefig(plot_filename, dpi=300, bbox_inches='tight')

            plot_idx += 1

    # Hide grid cells left over when n_panels is not a multiple of n_cols.
    for unused_ax in axes[n_panels:]:
        unused_ax.set_visible(False)

    combined_plot_filename = os.path.join(save_dir, f"{title}_combined_phasespace.png")
    fig.savefig(combined_plot_filename, dpi=300, bbox_inches='tight')

    plt.show()





def test_ergodicity(df, feature1='Radius', feature2='Eccentricity_Comp_First', error_tolerance=0.2, time_delta=10, max_failed_tracks=10):
    """Find, per cell type, the first time window in which track averages match ensemble averages.

    For the feature pair made of ``feature1`` and ``feature2`` (in either
    order) and for each cell type, the sorted time points are walked in
    consecutive windows of ``time_delta`` points. Within a window a track
    "fails" when its mean value of either feature differs by more than
    ``error_tolerance`` from the ensemble mean (mean over all rows of that
    cell type) at any time point of the window. The first window with fewer
    than ``max_failed_tracks`` failing tracks is recorded and the search for
    that cell type stops. The recorded windows are then plotted, or a message
    is printed when none was found.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format table with 'Feature Pair', 'Cell Type', 'Time Point',
        'Track ID' columns and one column per feature name.
    feature1, feature2 : str
        Exact feature names; the pair key must consist of exactly these two.
    error_tolerance : float
        Maximum allowed absolute difference between a track average and an
        ensemble average.
    time_delta : int
        Number of consecutive time points per window.
    max_failed_tracks : int
        A window counts as ergodic when fewer tracks than this fail.

    Notes
    -----
    NOTE(review): the per-track average is taken over *all* rows of the track
    for this cell type, not only over the rows inside the current window —
    confirm that this is the intended definition.
    NOTE(review): the value recorded and plotted on the x axis is the
    exclusive end *index* of the window within the sorted time points, not a
    value from the 'Time Point' column.
    """
    ergodicity_times = []
    requested_features = {feature1, feature2}

    for feature_pair in df['Feature Pair'].unique():
        # Match the exact feature names (in either order). A substring test
        # would also accept pairs whose names merely contain the requested
        # names.
        pair_features = feature_pair.split('_vs_')
        if len(pair_features) != 2 or set(pair_features) != requested_features:
            continue
        actual_feature_name1, actual_feature_name2 = pair_features

        pair_df = df[df['Feature Pair'] == feature_pair]

        for cell_type in df['Cell Type'].unique():
            cell_type_df = pair_df[pair_df['Cell Type'] == cell_type]

            time_points = sorted(cell_type_df['Time Point'].unique())
            total_time_points = len(time_points)

            # Per-track averages do not depend on the window, so they are
            # computed once here instead of once per window. Result has
            # shape (n_tracks, 2).
            track_averages = np.array(
                [
                    (
                        np.mean(track_df[actual_feature_name1].values),
                        np.mean(track_df[actual_feature_name2].values),
                    )
                    for _, track_df in cell_type_df.groupby('Track ID', sort=False)
                ],
                dtype=float,
            ).reshape(-1, 2)

            for start_time in tqdm(range(0, total_time_points, time_delta)):
                end_time = min(start_time + time_delta, total_time_points)
                time_interval_points = time_points[start_time:end_time]

                # Ensemble averages per time point, shape (n_window_points, 2).
                ensemble_rows = []
                for time_point in time_interval_points:
                    time_df = cell_type_df[cell_type_df['Time Point'] == time_point]
                    ensemble_rows.append((
                        np.mean(time_df[actual_feature_name1].values),
                        np.mean(time_df[actual_feature_name2].values),
                    ))
                ensemble_averages = np.array(ensemble_rows, dtype=float).reshape(-1, 2)

                # A track fails when either feature is off by more than the
                # tolerance at any time point of the window. Comparisons that
                # involve NaN evaluate to False, i.e. never count as failure.
                deviations = np.abs(
                    track_averages[:, None, :] - ensemble_averages[None, :, :]
                )
                track_failed = (deviations > error_tolerance).any(axis=(1, 2))
                failed_tracks_count = int(track_failed.sum())

                if failed_tracks_count < max_failed_tracks:
                    ergodicity_times.append((cell_type, end_time))
                    break

    if ergodicity_times:
        plt.figure(figsize=(10, 6))

        for cell_type, time in ergodicity_times:
            plt.plot(time, 1, 'go', label=f'Ergodicity Reached for {cell_type}')

        plt.title(f"Ergodicity Reached for {feature1} and {feature2}")
        plt.xlabel('Time (Time Interval)')
        plt.ylabel('Ergodicity Reached')
        plt.legend()
        plt.show()
    else:
        print(f"Ergodicity was not reached for {feature1} and {feature2} within the specified intervals.")




            

In [58]:
# Flatten the nested result_dict into a single long-format DataFrame
# (one row per track observation and feature pair).
feature_dataframe = build_dataframe(result_dict)


In [111]:
# Run the ergodicity check with the function's default feature pair,
# tolerance, window size and failure threshold.
test_ergodicity(feature_dataframe)

 31%|███       | 11/36 [00:12<00:28,  1.14s/it]

In [None]:

# Draw the KDE phase-space panels; PNG files are written into save_dir.
plot_phasespace(feature_dataframe)