In [None]:
from napatrackmater.Trackvector import TrackVector
from pathlib import Path 
import os
import torch
import ipywidgets as widgets
from IPython.display import clear_output
import matplotlib.pyplot as plt
import pandas as pd
from torch.nn.modules.loss import CrossEntropyLoss
from kapoorlabs_lightning.optimizers import Adam

from kapoorlabs_lightning.pytorch_models import DenseNet
from kapoorlabs_lightning.lightning_trainer import LightningModel
from napatrackmater.Trackvector import (
    TrackVector,
    SHAPE_FEATURES,
    DYNAMIC_FEATURES
)





In [None]:
# Root folder for all data and model paths below. Override it with the
# MARI_HOME_FOLDER environment variable instead of editing the notebook.
# os.path.join(..., '') guarantees a trailing separator, because later cells
# build paths by plain string concatenation (f'{home_folder}Mari_...').
home_folder = os.path.join(os.environ.get('MARI_HOME_FOLDER', '/home/debian/jz/'), '')

In [None]:
# --- Dataset selection --------------------------------------------------------
dataset_name = 'Third'
channel = 'nuclei_'
# Number of consecutive rows in every window handed to the classifiers.
tracklet_length = 50

# --- Names and paths derived from the selection ------------------------------
timelapse_nuclei_to_track = f'timelapse_{dataset_name.lower()}_dataset'
timelapse_to_track = timelapse_nuclei_to_track

tracking_directory = (
    f'{home_folder}Mari_Data_Oneat/Mari_{dataset_name}_Dataset_Analysis/'
    'nuclei_membrane_tracking/'
)
base_dir = tracking_directory

master_xml_name = f'master_marching_cubes_filled_{channel}{timelapse_to_track}.xml'
xml_path = Path(os.path.join(base_dir, master_xml_name))

oneat_detections = (
    f'{home_folder}Mari_Data_Oneat/Mari_{dataset_name}_Dataset_Analysis/'
    f'oneat_detections/non_maximal_oneat_mitosis_locations_nuclei_timelapse_{dataset_name.lower()}_dataset.csv'
)

save_file = os.path.join(base_dir, f'results_dataframe_normalized_{channel}.csv')

# Load the per-track results table for this channel (file name indicates normalized features).
dataset_dataframe = pd.read_csv(save_file)



In [None]:
# Shared settings for restoring both trained classifiers.
model_dir = f'{home_folder}Mari_Models/TrackModels/'
device = 'cpu'
loss_func = CrossEntropyLoss()

# Restore the DenseNet classifier trained on the shape features.
shape_model_json = (
    f'{model_dir}'
    'shape_feature_lightning_densenet_mitosis/shape_densenet.json'
)
shape_lightning_model, shape_torch_model = LightningModel.extract_mitosis_model(
    DenseNet, shape_model_json, loss_func, Adam,
    map_location=torch.device(device),
)

# Put the torch module into evaluation mode before running inference.
shape_torch_model.eval()


In [None]:

# Restore the DenseNet classifier trained on the dynamic features, reusing the
# model directory, loss and device configured in the previous cell.
dynamic_model_json = (
    f'{model_dir}'
    'dynamic_feature_lightning_densenet_mitosis/dynamic_densenet.json'
)
dynamic_lightning_model, dynamic_torch_model = LightningModel.extract_mitosis_model(
    DenseNet, dynamic_model_json, loss_func, Adam,
    map_location=torch.device(device),
)

# Put the torch module into evaluation mode before running inference.
dynamic_torch_model.eval()

In [None]:
# Minimum 'Track Duration' a track must exceed to be kept for sampling.
# NOTE(review): 0.07 looks like a frame-interval conversion factor for the
# 'Track Duration' column — confirm the units.
min_track_duration = tracklet_length / 0.07

# Every comparison is parenthesised because `&` binds tighter than `==`:
# `Dividing == 0 & mask` evaluates as `Dividing == (0 & mask)`, which silently
# drops the duration filter.
dividing_tracklets = dataset_dataframe[
    (dataset_dataframe['Dividing'] == 1)
    & (dataset_dataframe['Generation ID'] >= 0)
    & (dataset_dataframe['Track Duration'] > min_track_duration)
]

non_dividing_tracklets = dataset_dataframe[
    (dataset_dataframe['Dividing'] == 0)
    & (dataset_dataframe['Track Duration'] > min_track_duration)
]



In [None]:
# Seed for track sampling. None draws a new pair of tracks on every run; set an
# int to make the chosen tracks (and every downstream output) reproducible.
random_seed = None

# Fail early with a readable message instead of an IndexError from `.iloc[0]`.
assert not dividing_tracklets.empty, 'No dividing tracklets pass the filters; relax tracklet_length or the duration cut.'
assert not non_dividing_tracklets.empty, 'No non-dividing tracklets pass the filters; relax tracklet_length or the duration cut.'

# Sampling the 'Track ID' column picks a row, so tracks with more rows are
# proportionally more likely to be chosen.
random_dividing_track = dividing_tracklets['Track ID'].sample(random_state=random_seed).iloc[0]
random_non_dividing_track = non_dividing_tracklets['Track ID'].sample(random_state=random_seed).iloc[0]

sub_dividing_dataframe = dividing_tracklets[dividing_tracklets['Track ID'] == random_dividing_track]
sub_dividing_dataframe_dynamic = sub_dividing_dataframe[DYNAMIC_FEATURES].values
sub_dividing_dataframe_shape = sub_dividing_dataframe[SHAPE_FEATURES].values

sub_non_dividing_dataframe = non_dividing_tracklets[non_dividing_tracklets['Track ID'] == random_non_dividing_track]
sub_non_dividing_dataframe_dynamic = sub_non_dividing_dataframe[DYNAMIC_FEATURES].values
sub_non_dividing_dataframe_shape = sub_non_dividing_dataframe[SHAPE_FEATURES].values

print(f'Loaded data from prediction, mitosis dynamic data {sub_dividing_dataframe_dynamic.shape} and shape data {sub_dividing_dataframe_shape.shape}')
print(f'Loaded data from prediction, non mitosis dynamic data {sub_non_dividing_dataframe_dynamic.shape} and shape data {sub_non_dividing_dataframe_shape.shape}')

print(f'Choosing random mitosis track {random_dividing_track} and a random non-mitosis track {random_non_dividing_track}')

In [None]:
# Model output index -> human-readable class name (index 1 is the mitosis class).
class_map = dict(enumerate(("non-mitosis", "mitosis")))

def create_prediction_arrays(array, start_index, legend='shape', window_length=None):
    """Slice a fixed-length window of rows out of ``array`` and plot it.

    Parameters
    ----------
    array : numpy.ndarray
        2-D array indexed as ``array[row, feature]``; rows are plotted on the
        'Time' axis by ``plot_input_data``.
    start_index : int
        First row of the window.
    legend : str or sequence of str
        Line labels forwarded to ``plot_input_data``.
    window_length : int, optional
        Number of rows in the window. Defaults to the notebook-level
        ``tracklet_length``.

    Returns
    -------
    numpy.ndarray
        ``array[start_index:start_index + window_length, :]``.

    Raises
    ------
    ValueError
        If the window does not fit inside ``array``. Without this check, numpy
        slicing would silently return a shorter window, or wrap around for a
        negative start, and the model would be fed a wrongly sized input.
    """
    if window_length is None:
        window_length = tracklet_length

    end_index = start_index + window_length
    if start_index < 0 or end_index > array.shape[0]:
        raise ValueError(
            f'Window [{start_index}, {end_index}) does not fit in an array with {array.shape[0]} rows'
        )

    sub_array = array[start_index:end_index, :]
    plot_input_data(sub_array, legend)
    return sub_array

def make_prediction(input_data, model):
    """Classify one feature window with ``model`` and return the winning class index.

    ``input_data`` must be 2-D (rows, features). It is converted into a float32
    batch of one with the axes reordered to (1, features, rows) before being
    passed to ``model``. The softmax probabilities of the first (only) batch
    entry are printed, and the index of the most probable class is returned as
    a 0-dim ``torch`` tensor.
    """
    with torch.no_grad():
        batch = torch.tensor(input_data).unsqueeze(0)
        batch = batch.permute(0, 2, 1).float()
        logits = model(batch)
        class_probabilities = torch.softmax(logits[0], dim=0)
        print(class_probabilities)
        best_class = class_probabilities.max(dim=0).indices
    return best_class

def plot_input_data(input_data, legend):
    """Plot every column of ``input_data`` against its row index in a new figure.

    The current cell/widget output is cleared first; ``wait=True`` defers the
    clear until new output arrives, which reduces flicker when this is driven
    by a slider.

    Parameters
    ----------
    input_data : array-like
        2-D data; each column is drawn as one line.
    legend : str or sequence of str
        Line labels. A single string is used as one label instead of being
        split into characters.

    Returns
    -------
    matplotlib.figure.Figure
        The newly created figure. Existing callers may ignore it.
    """
    clear_output(wait=True)

    # legend() iterates a bare string character by character (newer matplotlib
    # rejects it outright), so wrap a single label in a list.
    labels = [legend] if isinstance(legend, str) else list(legend)

    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(input_data)
    ax.set_xlabel('Time')
    ax.set_ylabel('Value')
    ax.set_title('Input Data')
    ax.legend(labels)
    return fig
   

In [None]:

def interactive_shape_prediction(start_index):
    """Slider callback: plot the shape-feature window at ``start_index`` of the
    chosen dividing track and print the shape model's predicted class."""
    window = create_prediction_arrays(sub_dividing_dataframe_shape, start_index, legend=SHAPE_FEATURES)
    predicted_label = class_map[int(make_prediction(window, shape_torch_model))]
    print("Prediction:", predicted_label)


# The largest start index still leaves a full window of tracklet_length rows.
start_index_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=sub_dividing_dataframe_shape.shape[0] - tracklet_length,
    description='Start Index:',
)
widgets.interact(interactive_shape_prediction, start_index=start_index_slider)

In [None]:
def interactive_dynamic_prediction(start_index):
    """Slider callback: plot the dynamic-feature window at ``start_index`` of the
    chosen dividing track and print the dynamic model's predicted class."""
    predicted_index = make_prediction(
        create_prediction_arrays(sub_dividing_dataframe_dynamic, start_index, legend=DYNAMIC_FEATURES),
        dynamic_torch_model,
    )
    print("Prediction:", class_map[int(predicted_index)])

# Slider range keeps a full window of tracklet_length rows inside the track.
second_start_index_slider = widgets.IntSlider(
    description='Start Index:',
    value=0,
    min=0,
    max=sub_dividing_dataframe_dynamic.shape[0] - tracklet_length,
)
widgets.interact(interactive_dynamic_prediction, start_index=second_start_index_slider)

In [None]:
def interactive_dual_prediction(start_index):
    """Slider callback: run both the dynamic- and shape-feature models on the
    same window of the chosen dividing track and combine their verdicts.

    The window is reported as mitosis when EITHER model predicts class 1
    (``class_map[1] == "mitosis"``). Each model's own prediction is printed
    with a label so the combined result can be traced.

    Parameters
    ----------
    start_index : int
        First row of the window, driven by the slider below.
    """
    dynamic_window = create_prediction_arrays(sub_dividing_dataframe_dynamic, start_index, legend=DYNAMIC_FEATURES)
    prediction_dynamic = int(make_prediction(dynamic_window, dynamic_torch_model))

    shape_window = create_prediction_arrays(sub_dividing_dataframe_shape, start_index, legend=SHAPE_FEATURES)
    prediction_shape = int(make_prediction(shape_window, shape_torch_model))

    # Explicit logical OR over the class indices; `tensor or tensor` relied on
    # tensor truthiness and only worked because the classes happen to be 0/1.
    prediction = int(prediction_dynamic == 1 or prediction_shape == 1)

    print("Dynamic model:", class_map[prediction_dynamic])
    print("Shape model:", class_map[prediction_shape])
    print("Prediction:", class_map[prediction])

# Own slider name so the dynamic-only cell's `second_start_index_slider` is not
# overwritten. Both windows are sliced from the same start index, so the range
# must fit inside both arrays.
dual_start_index_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=min(sub_dividing_dataframe_dynamic.shape[0], sub_dividing_dataframe_shape.shape[0]) - tracklet_length,
    description='Start Index:',
)
widgets.interact(interactive_dual_prediction, start_index=dual_start_index_slider)