# In [None]:


## MG IMAGE BINARY SPLIT 1

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from kerastuner.tuners import RandomSearch  # NOTE: newer releases ship this package as `keras_tuner`
from sklearn.model_selection import train_test_split
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

# classification metrics
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report
from sklearn.metrics import roc_curve, auc
from sklearn.metrics import precision_recall_curve

## Find your data: has it been processed yet? If not, go back to the
## preprocessing notebook first.

images = np.load('confocal_data.npy')
labels = np.load('labels.npy')

x_coords = np.load('x_coords.npy')  # Shape: (num_images, )
y_coords = np.load('y_coords.npy')  # Shape: (num_images, )

# NOTE(review): `fluorescence_info` was used below without ever being loaded
# or defined. The file name here is an assumption -- confirm it against the
# preprocessing notebook. Presumably one value per image.
fluorescence_info = np.load('fluorescence_info.npy')


def _min_max(values):
    """Linearly rescale *values* to [0, 1] over all of their elements.

    A constant array (zero range) maps to all zeros instead of producing
    NaN from a 0/0 division.
    """
    values = np.asarray(values, dtype=np.float64)
    low = values.min()
    value_range = values.max() - low
    if value_range == 0:
        return np.zeros_like(values)
    return (values - low) / value_range


# Scale pixel intensities to [0, 1]. Assumes 8-bit pixel values.
# Do this exactly once: dividing by 255 again would squash every input into
# [0, 1/255] and cripple training.
images = images.astype(np.float32) / 255.0

# Min-max normalise the auxiliary measurements (fluorescence, x, y).
fluorescence_info = _min_max(fluorescence_info)
x_coords = _min_max(x_coords)
y_coords = _min_max(y_coords)

# Concatenate all features into a single array, one feature per channel.
# Per-image scalars cannot be stacked directly with (H, W) images, so each is
# broadcast to a constant image-sized plane. Grayscale (N, H, W) images first
# get an explicit trailing channel axis.
# NOTE(review): assumes fluorescence_info / x_coords / y_coords hold exactly
# one value per image -- confirm.
image_channels = images[..., np.newaxis] if images.ndim == 3 else images
plane_shape = image_channels.shape[:-1] + (1,)
per_image_shape = (-1,) + (1,) * (image_channels.ndim - 1)
features = np.concatenate(
    [image_channels]
    + [
        np.broadcast_to(
            channel.reshape(per_image_shape), plane_shape
        ).astype(image_channels.dtype)
        for channel in (fluorescence_info, x_coords, y_coords)
    ],
    axis=-1,
)

# NOTE(review): the model is trained on `images` only. `features` is not fed
# to it yet, because build_model's default input shape is (128, 128, 3). That
# also means `images` must be (N, 128, 128, 3) here -- confirm.
x_train, x_test, y_train, y_test = train_test_split(
    images, labels, test_size=0.2, random_state=42
)


def build_model(hp, input_shape=(128, 128, 3)):
    """Build and compile a small CNN binary classifier for Keras Tuner.

    Each ``hp.*`` call declares a tunable hyperparameter and returns the
    value sampled for the current trial. The hyperparameter names are kept
    stable so previously saved tuner results remain valid.

    Args:
        hp: Keras Tuner hyperparameter container for the current trial.
        input_shape: Shape of a single input sample, ``(height, width,
            channels)``. Defaults to ``(128, 128, 3)``, so ``build_model(hp)``
            behaves exactly as before. Pass another shape to train on
            different inputs, e.g. images with extra feature channels.

    Returns:
        A ``Sequential`` model ending in a single sigmoid unit, compiled
        with Adam, binary cross-entropy loss and an accuracy metric.
    """
    model = Sequential()

    # Feature extractor: two convolutions with no pooling in between.
    # NOTE(review): without pooling the Flatten output has roughly
    # H * W * conv_2_filter values, so the Dense layer below is very
    # parameter-heavy -- consider adding MaxPooling2D.
    model.add(layers.Conv2D(
        filters=hp.Int('conv_1_filter', min_value=32, max_value=128, step=16),
        kernel_size=hp.Choice('conv_1_kernel', values=[3, 5]),
        activation='relu',
        input_shape=input_shape,
    ))

    model.add(layers.Conv2D(
        filters=hp.Int('conv_2_filter', min_value=32, max_value=64, step=16),
        kernel_size=hp.Choice('conv_2_kernel', values=[3, 5]),
        activation='relu',
    ))

    # Classifier head.
    model.add(layers.Flatten())
    model.add(layers.Dense(
        units=hp.Int('dense_1_units', min_value=32, max_value=128, step=16),
        activation='relu',
    ))

    # One sigmoid unit paired with binary cross-entropy: binary output.
    model.add(layers.Dense(1, activation='sigmoid'))

    model.compile(
        optimizer=tf.keras.optimizers.Adam(
            hp.Choice('learning_rate', values=[1e-2, 1e-3])
        ),
        loss='binary_crossentropy',
        metrics=['accuracy'],
    )

    return model

# Set up Keras Tuner.
# Hyperparameters are scored on a validation split carved out of the
# TRAINING data. The held-out test set is touched only by the final
# evaluation; tuning or early model selection on it would make the reported
# test accuracy optimistically biased.
# NOTE(review): if output/Microglia_Classifier already holds a previous run,
# the tuner may resume from it rather than search afresh -- consider
# overwrite=True when re-running.
tuner = RandomSearch(
    build_model,
    objective='val_accuracy',
    max_trials=5,
    executions_per_trial=3,  # each sampled configuration is fit 3 times
    directory='output',
    project_name='Microglia_Classifier',
)

# validation_split takes the last 20% of the (already shuffled by
# train_test_split) training samples as validation data.
tuner.search(x_train, y_train, epochs=5, validation_split=0.2)

best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]

# Find the right hyperparameters for you: rebuild the best configuration
# from scratch and train it for longer, still validating on training data.
model = tuner.hypermodel.build(best_hps)
history = model.fit(x_train, y_train, epochs=50, validation_split=0.2)

# Score the final model on the held-out test split. The model is compiled
# with metrics=['accuracy'], so evaluate() yields (loss, accuracy).
test_loss, accuracy = model.evaluate(x_test, y_test)
print('Accuracy: {}'.format(accuracy * 100))

# Persist the trained model to disk.
model.save('best_microglia_model.h5')



# Make predictions on the test set.
# Sequential.predict_classes() no longer exists in current TensorFlow. For a
# single sigmoid output it was equivalent to thresholding the predicted
# probabilities at 0.5, which is done explicitly here. ravel() turns the
# (n_samples, 1) output of the Dense(1) head into a 1-D vector.
y_prob = model.predict(x_test).ravel()
y_pred = (y_prob > 0.5).astype('int32')

confusion_mtx = confusion_matrix(y_test, y_pred)
print('Confusion Matrix:\n', confusion_mtx)

# Classification report: per-class precision / recall / F1.
classification_rep = classification_report(y_test, y_pred)
print('Classification Report:\n', classification_rep)

# Score the test set once and reuse the probabilities for both curves.
# The model ends in a single Dense(1) sigmoid unit, so predict() returns
# shape (n_samples, 1). ravel() gives the 1-D score vector sklearn expects.
y_score = model.predict(x_test).ravel()

# ROC curve and its area under the curve.
fpr, tpr, thresholds = roc_curve(y_test, y_score)
roc_auc = auc(fpr, tpr)

plt.figure()
plt.plot(fpr, tpr, label=f'ROC curve (area = {roc_auc:0.2f})')
plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")
plt.show()

# Precision-recall curve (rebinds `thresholds` to the PR thresholds).
precision, recall, thresholds = precision_recall_curve(y_test, y_score)

plt.figure()
plt.plot(recall, precision, label='Precision-Recall curve')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall Curve')
plt.legend(loc="lower left")
plt.show()


# TODO: save something




