In [1]:
import numpy as np
import os
import os
import torch
from torch.utils.data import DataLoader, Dataset
from torch.nn.modules.loss import CrossEntropyLoss
from kapoorlabs_lightning.optimizers import Adam

from kapoorlabs_lightning.pytorch_models import DenseNet
from kapoorlabs_lightning.lightning_trainer import LightningModel
from napatrackmater.Trackvector import (
    SHAPE_FEATURES,
    DYNAMIC_FEATURES
)
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns


  def groupby_count(xyz, indices, out):
  def groupby_sum(xyz, indices, N, out):
  def groupby_max(xyz, indices, N, out):
2024-07-22 13:17:49.927671: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2024-07-22 13:17:50.048260: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Dataset locations; base_dir is spelled relative to home_folder (same resulting path).
home_folder = '/home/debian/jz/'
base_dir = f'{home_folder}Mari_Data_Training/track_training_data/'

shape_mitosis_npz_file = 'shape_training_data_mitosis.npz'
dynamic_mitosis_npz_file = 'dynamic_training_data_mitosis.npz'

# .npz archives load as dict-like NpzFile objects; arrays are pulled out by key below.
shape_validation_data = np.load(os.path.join(base_dir, shape_mitosis_npz_file))
dynamic_validation_data = np.load(os.path.join(base_dir, dynamic_mitosis_npz_file))

# Only the validation split of each archive is read in the visible cells.
validation_keys = ("dividing_val_arrays", "dividing_val_labels",
                   "non_dividing_val_arrays", "non_dividing_val_labels")

(val_shape_dividing_arrays, val_shape_dividing_labels,
 val_shape_non_dividing_arrays, val_shape_non_dividing_labels) = (
    shape_validation_data[key] for key in validation_keys)

(val_dynamic_dividing_arrays, val_dynamic_dividing_labels,
 val_dynamic_non_dividing_arrays, val_dynamic_non_dividing_labels) = (
    dynamic_validation_data[key] for key in validation_keys)

: 

In [None]:
# Sanity check: array shapes of the mitosis vs. non-mitosis (shape-feature) validation sets.
print('Mitosis {}, Non-Mitosis {}'.format(val_shape_dividing_arrays.shape,
                                         val_shape_non_dividing_arrays.shape))

In [None]:
# Shared settings for loading the trained track classifiers.
model_dir = f'{home_folder}Mari_Models/TrackModels/'
device = 'cpu'
loss_func = CrossEntropyLoss()

# Shape-feature mitosis classifier: JSON description + local model folder share one name.
shape_model_name = 'shape_feature_lightning_densenet_mitosis'
gbr_shape_model_json = f'{model_dir}{shape_model_name}/shape_densenet.json'

gbr_shape_lightning_model, gbr_shape_torch_model = LightningModel.extract_mitosis_model(
    DenseNet,
    gbr_shape_model_json,
    loss_func,
    Adam,
    map_location=torch.device(device),
    local_model_path=os.path.join(model_dir, shape_model_name),
)

# Switch to evaluation mode; the returned module is the cell's displayed output.
gbr_shape_torch_model.eval()

In [None]:

# Dynamic-feature mitosis classifier, loaded with the same loss/device as the shape model.
dynamic_model_name = 'dynamic_feature_lightning_densenet_mitosis'
gbr_dynamic_model_json = f'{model_dir}{dynamic_model_name}/dynamic_densenet.json'

gbr_dynamic_lightning_model, gbr_dynamic_torch_model = LightningModel.extract_mitosis_model(
    DenseNet,
    gbr_dynamic_model_json,
    loss_func,
    Adam,
    map_location=torch.device(device),
    local_model_path=os.path.join(model_dir, dynamic_model_name),
)

# Switch to evaluation mode; the returned module is the cell's displayed output.
gbr_dynamic_torch_model.eval()

In [None]:
# Class index -> human-readable name for the binary mitosis classifiers (index 0 first).
class_map_gbr = dict(enumerate(("Non-Mitosis", "Mitosis")))

class CustomDataset(Dataset):
    """Map-style dataset pairing stored feature arrays with their labels.

    Indexing returns ``(features, label)``: the sample at ``idx`` converted to a
    float32 tensor with its two axes swapped, and the label at the same index.
    """

    def __init__(self, data, labels):
        # data and labels are indexed in parallel by position.
        self.data, self.labels = data, labels

    def __len__(self):
        # Length follows the data container; labels are assumed to match it.
        return len(self.data)

    def __getitem__(self, idx):
        # NOTE(review): permute(1, 0) only works for 2-D samples; presumably this
        # turns (time, features) into (features, time) for the model input — confirm.
        swapped = torch.tensor(self.data[idx]).permute(1, 0)
        return swapped.float(), self.labels[idx]

def make_predictions(model, data, labels, batch_size=32, subset_size=None, random_state=None):
    """Run ``model`` over ``data`` and return predicted and true class indices.

    Parameters
    ----------
    model : torch.nn.Module
        Classifier whose forward pass returns per-class scores with classes on
        dim 1 (argmax is taken over that dim).
    data, labels : array-like
        Samples and their labels, indexed in parallel and wrapped in
        ``CustomDataset``. When ``subset_size`` is given they must support
        NumPy fancy indexing.
    batch_size : int
        Number of samples per forward pass.
    subset_size : int or None
        If given, evaluate on a random subset of at most this many samples,
        drawn without replacement. Values larger than ``len(data)`` are clamped
        instead of raising.
    random_state : int, numpy.random.Generator or None
        Seed or generator for subset sampling. ``None`` draws from NumPy's
        global random state (the original behaviour).

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        Predicted class indices and the corresponding true labels, in the same
        (possibly subsampled) order.
    """
    if subset_size is not None:
        # choice(..., replace=False) raises if asked for more samples than exist.
        n_samples = min(subset_size, len(data))
        rng = np.random if random_state is None else np.random.default_rng(random_state)
        indices = rng.choice(len(data), size=n_samples, replace=False)
        data = data[indices]
        labels = labels[indices]

    # Feed inputs to whatever device the model's weights live on. The models in
    # this notebook are loaded with map_location=torch.device('cpu'); moving the
    # batch to CUDA just because it is available would cause a device mismatch.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    dataloader = DataLoader(CustomDataset(data, labels), batch_size=batch_size, shuffle=False)

    predictions = []
    true_labels = []
    model.eval()
    with torch.no_grad():
        for batch_data, batch_labels in dataloader:
            scores = model(batch_data.to(device))
            # softmax is monotonic, so argmax over raw scores picks the same class.
            predicted_classes = torch.argmax(scores, dim=1)
            predictions.extend(predicted_classes.cpu().numpy())
            # Labels never need to leave the CPU; they are only collected.
            true_labels.extend(batch_labels.numpy())

    return np.array(predictions), np.array(true_labels)


In [None]:
# Evaluate both classifiers on a random subset (at most `subset_size` tracks per class).
# NOTE(review): predictions from the shape model and the dynamic model are pooled
# into one confusion matrix below, so the counts mix two different classifiers
# (each evaluated on its own random subset) — compare per-model matrices if needed.
subset_size = 10000
# Seed NumPy's global RNG so the subsets drawn inside make_predictions are
# reproducible across Restart & Run All.
random_seed = 42
np.random.seed(random_seed)

# Clamp per array: sampling without replacement fails if a class has fewer tracks.
shape_dividing_preds, shape_dividing_labels = make_predictions(
    gbr_shape_torch_model, val_shape_dividing_arrays, val_shape_dividing_labels,
    batch_size=32, subset_size=min(subset_size, len(val_shape_dividing_arrays)))
dynamic_dividing_preds, dynamic_dividing_labels = make_predictions(
    gbr_dynamic_torch_model, val_dynamic_dividing_arrays, val_dynamic_dividing_labels,
    batch_size=32, subset_size=min(subset_size, len(val_dynamic_dividing_arrays)))

shape_non_dividing_preds, shape_non_dividing_labels = make_predictions(
    gbr_shape_torch_model, val_shape_non_dividing_arrays, val_shape_non_dividing_labels,
    batch_size=32, subset_size=min(subset_size, len(val_shape_non_dividing_arrays)))
dynamic_non_dividing_preds, dynamic_non_dividing_labels = make_predictions(
    gbr_dynamic_torch_model, val_dynamic_non_dividing_arrays, val_dynamic_non_dividing_labels,
    batch_size=32, subset_size=min(subset_size, len(val_dynamic_non_dividing_arrays)))

all_preds = np.concatenate([shape_dividing_preds, dynamic_dividing_preds,
                            shape_non_dividing_preds, dynamic_non_dividing_preds])
all_labels = np.concatenate([shape_dividing_labels, dynamic_dividing_labels,
                             shape_non_dividing_labels, dynamic_non_dividing_labels])

# Pin the matrix to every known class so it stays 2x2 even if a class is absent.
conf_matrix = confusion_matrix(all_labels, all_preds, labels=list(class_map_gbr))
print(conf_matrix)

In [None]:
def plot_confusion_matrix(true_labels, predictions, class_names, labels=None,
                          save_path='classification_metrics/confusion_matrix_cell_fate_mitosis.png'):
    """Plot (and optionally save) a confusion matrix as an annotated heatmap.

    Parameters
    ----------
    true_labels, predictions : array-like
        Ground-truth and predicted class values.
    class_names : iterable of str
        Tick labels, one per class, in the same order as the matrix rows/columns.
    labels : list, optional
        Class values that index the matrix (forwarded to
        ``sklearn.metrics.confusion_matrix``). Passing every class keeps the
        matrix square and aligned with ``class_names`` even when a class is
        missing from the data. ``None`` uses the classes present in the data.
    save_path : str or None
        Where to write the figure (300 dpi). Parent folders are created if
        needed; ``None`` or ``''`` skips saving.
    """
    class_names = list(class_names)  # accept dict views / generators
    conf_matrix = confusion_matrix(true_labels, predictions, labels=labels)

    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel('Predicted Labels')
    ax.set_ylabel('True Labels')
    ax.set_title('Confusion Matrix')

    if save_path:
        # Use a real '.png' extension (',png' is not one, so matplotlib would
        # append '.png'), and create the folder so savefig does not fail.
        save_dir = os.path.dirname(save_path)
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        fig.savefig(save_path, dpi=300)
    plt.show()


class_names = list(class_map_gbr.values())
plot_confusion_matrix(all_labels, all_preds, class_names, labels=list(class_map_gbr))