In [None]:
from google.colab import drive
# Mount Google Drive so the project files stored there are reachable from Colab.
drive.mount('/content/drive')

# Work from the project folder: relative paths used later ('dataset/', 'model_1', 'ModelObj.pkl') resolve here.
%cd /content/drive/MyDrive/TSA_Project/

Mounted at /content/drive
/content/drive/MyDrive/TSA_Project


In [None]:
import os
import pickle
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
import flax.linen as nn
import optax
from flax.training import train_state
from tqdm.auto import tqdm
from sklearn.metrics import roc_auc_score, roc_curve, precision_recall_fscore_support
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator

In [None]:
# --- Hyperparameters --------------------------------------------------------
learning_rate = 1e-4      # Adam step size
batch_size = 32           # mini-batch size; only takes effect where explicitly passed on
num_epochs = 30           # number of training epochs
rng = random.PRNGKey(0)   # root JAX PRNG key (seed 0) for reproducible init

# --- Where checkpoints are written -------------------------------------------
model_folder = "model_1"
os.makedirs(model_folder, exist_ok=True)  # no-op when the folder already exists

In [None]:
# Ensure JAX uses GPU: pin JAX's default backend to the 'gpu' platform.
# NOTE(review): presumably this fails on a CPU-only runtime — confirm, and pick a
# GPU runtime in Colab before running this cell.
jax.config.update('jax_platform_name', 'gpu')

# Data Loading and Preprocessing
def load_and_preprocess_data(directory, batch_size=32, img_height=224, img_width=224,
                             validation_split=0.2, seed=None, shuffle_validation=False):
    """Build train/validation Keras iterators over a class-per-subfolder image directory.

    Pixels are rescaled to [0, 1] and labels are one-hot ('categorical'). The
    split is done by ``ImageDataGenerator(validation_split=...)``.

    Args:
        directory: root folder with one sub-folder per class.
        batch_size: images per batch for both iterators.
        img_height, img_width: target size images are resized to.
        validation_split: fraction of each class held out for validation.
        seed: optional seed for the generators' shuffling (None = unseeded).
        shuffle_validation: whether to shuffle the validation iterator. Defaults
            to False so validation batches come in a fixed order. Evaluation is
            then reproducible, and any contiguous run of ``len(validation_generator)``
            batches covers every validation image exactly once.

    Returns:
        (train_generator, validation_generator)
    """
    datagen = ImageDataGenerator(rescale=1./255, validation_split=validation_split)
    common = dict(target_size=(img_height, img_width), batch_size=batch_size,
                  class_mode='categorical', seed=seed)
    train_generator = datagen.flow_from_directory(directory, subset='training',
                                                  shuffle=True, **common)
    # A shuffled validation iterator combined with eval loops that drift by a
    # batch made each "full" evaluation mix two permutations (duplicates + misses).
    validation_generator = datagen.flow_from_directory(directory, subset='validation',
                                                       shuffle=shuffle_validation, **common)
    return train_generator, validation_generator

# Model Definition
class ResNet(nn.Module):
    """Small ResNet-style classifier: 7x7 stem conv -> max-pool -> residual blocks -> avg-pool -> Dense.

    Attributes:
        num_classes: size of the output layer.
        num_blocks: number of stride-1 ``ResBlock``s after the stem.
        features: channel width of the stem conv and of every ``ResBlock``.

    The defaults reproduce the original hard-coded architecture, including the
    auto-generated submodule names (Conv_0, ResBlock_0..4, Dense_0). Checkpoints
    saved before this change therefore still load.
    """
    num_classes: int = 4
    num_blocks: int = 5
    features: int = 64

    @nn.compact
    def __call__(self, x, training=True):
        """Forward pass.

        Args:
            x: image batch, presumably (batch, height, width, channels). Init uses
               (1, 224, 224, 3); TODO confirm for other callers.
            training: despite the name, this only selects the output head. True
                returns raw logits (what the cross-entropy loss expects). False
                returns softmax probabilities. No layer behaves differently
                between the two modes.
        """
        x = nn.Conv(self.features, (7, 7), strides=(2, 2), padding='SAME', use_bias=False)(x)
        # No normalization layers: BatchNorm was disabled in this experiment.
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
        for _ in range(self.num_blocks):
            x = ResBlock(self.features)(x)
        # NOTE(review): nn.avg_pool defaults to stride 1 and 'VALID' padding. On a
        # 224x224 input it maps the 56x56 feature map to 50x50 rather than to a
        # global 1x1, so the Dense head below sees 50*50*features inputs.
        # Global average pooling (x.mean(axis=(1, 2))) is the usual choice, but it
        # would change parameter shapes and invalidate existing checkpoints.
        x = nn.avg_pool(x, (7, 7))
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(self.num_classes)(x)
        if not training:
            x = nn.softmax(x)
        return x

class ResBlock(nn.Module):
    """Basic residual block: conv3x3 -> ReLU -> conv3x3, add shortcut, then ReLU.

    No normalization layers are used (BatchNorm was disabled in this experiment).

    Attributes:
        features: output channels of both 3x3 convolutions.
        strides: stride of the second 3x3 conv and of the projection shortcut.
    """
    features: int
    strides: tuple = (1, 1)

    @nn.compact
    def __call__(self, x):
        same_pad_3x3 = dict(kernel_size=(3, 3), padding='SAME', use_bias=False)
        # Submodules are created in the same order as before (two 3x3 convs, then
        # the optional 1x1 projection), so Flax's auto-generated names are unchanged.
        residual = nn.Conv(self.features, **same_pad_3x3)(x)
        residual = nn.Conv(self.features, strides=self.strides, **same_pad_3x3)(nn.relu(residual))
        # Project the input with a 1x1 conv only when its shape differs from the branch output.
        shortcut = (
            nn.Conv(self.features, kernel_size=(1, 1), strides=self.strides, use_bias=False)(x)
            if x.shape != residual.shape
            else x
        )
        return nn.relu(residual + shortcut)

# Train state creation
def create_train_state(rng, learning_rate, input_shape=(1, 224, 224, 3)):
    """Initialize a ``ResNet`` and wrap it in a Flax ``TrainState`` using Adam.

    Args:
        rng: JAX PRNG key used for parameter initialization.
        learning_rate: Adam step size.
        input_shape: shape of the dummy batch passed to ``init``. It must match
            (1, height, width, channels) of the real inputs. ``ResNet`` flattens
            its pooled feature map into ``nn.Dense``, so the Dense kernel shape
            depends on the image size chosen in the data loader.

    Returns:
        ``train_state.TrainState`` whose ``apply_fn`` is ``model.apply``.
    """
    model = ResNet()
    params = model.init(rng, jnp.ones(input_shape))['params']
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

# Training and Evaluation Functions
def train_epoch(state, train_data, batch_size, rng):
    """Run one training epoch and return ``(mean_loss, updated_state)``.

    Args:
        state: Flax ``TrainState`` to update.
        train_data: Keras-style batch iterator (``len()``, ``next()``, optional
            ``reset()``) yielding ``(inputs, one_hot_targets)``. Keras iterators
            loop forever, so exactly ``len(train_data)`` batches are drawn.
        batch_size: unused. The iterator already fixes the batch size; the
            parameter is kept so existing call sites keep working.
        rng: PRNG key forwarded unchanged to ``train_step`` for every batch.

    Returns:
        (training loss averaged over samples as a Python float, or NaN if there
        were no batches; updated state)
    """
    num_batches = len(train_data)
    # Restart at batch 0 (Keras reshuffles there when shuffle=True). The old
    # enumerate+break loop fetched one extra batch per call, so every later
    # "epoch" started one batch late and straddled two shuffles.
    if hasattr(train_data, "reset"):
        train_data.reset()

    losses, sizes = [], []
    for _ in tqdm(range(num_batches)):
        inputs, targets = next(train_data)
        loss, state = train_step(state, jnp.asarray(inputs), jnp.asarray(targets), rng)
        losses.append(loss)  # kept on device; avoids a host sync every step
        sizes.append(len(inputs))

    if not losses:
        return float('nan'), state
    # Weight by batch size so the smaller final batch is not over-counted.
    return float(np.average(jax.device_get(losses), weights=sizes)), state

# Training step function. jax.jit compiles the forward + backward + update into
# one XLA program instead of dispatching op-by-op each batch. TrainState keeps
# apply_fn/tx as static (non-pytree) fields, so the state can be passed directly.
@jax.jit
def train_step(state, inputs, targets, rng):
    """Apply one optimizer step on a batch and return ``(mean_loss, new_state)``.

    Args:
        state: Flax ``TrainState``.
        inputs: image batch.
        targets: one-hot labels with the same shape as the model's logits.
        rng: currently unused (the model uses no randomness); kept for API compatibility.
    """
    def loss_fn(params):
        # ResNet's default ``training=True`` returns raw logits, as the loss expects.
        logits = state.apply_fn({'params': params}, inputs)
        return optax.softmax_cross_entropy(logits=logits, labels=targets).mean()

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    return loss, state.apply_gradients(grads=grads)

# Evaluation function
def evaluate(state, val_data):
    """Compute sample-weighted validation loss and accuracy over one full pass.

    Args:
        state: Flax ``TrainState`` whose ``apply_fn`` returns logits.
        val_data: Keras-style batch iterator (``len()``, ``next()``, optional
            ``reset()``) yielding ``(inputs, one_hot_targets)``.

    Returns:
        ``{'loss': float, 'accuracy': float}`` averaged over samples (NaN if no samples).
    """
    num_batches = len(val_data)
    # Keras iterators loop forever and keep their position between calls. Reset
    # and draw exactly one epoch so each image is scored once. The old
    # enumerate+break pulled an extra batch, so later calls mixed two epochs.
    if hasattr(val_data, "reset"):
        val_data.reset()

    loss_sum, correct, seen = 0.0, 0, 0
    for _ in tqdm(range(num_batches)):
        inputs, targets = next(val_data)
        inputs, targets = jnp.asarray(inputs), jnp.asarray(targets)
        logits = state.apply_fn({'params': state.params}, inputs)
        # Accumulate sums, not per-batch means: the last batch is usually smaller,
        # and averaging batch means would over-weight it.
        loss_sum += optax.softmax_cross_entropy(logits=logits, labels=targets).sum().item()
        correct += (logits.argmax(1) == targets.argmax(1)).sum().item()
        seen += inputs.shape[0]

    if seen == 0:
        return {'loss': float('nan'), 'accuracy': float('nan')}
    return {'loss': loss_sum / seen, 'accuracy': correct / seen}

# Load data. Pass the configured batch size explicitly so the hyperparameter
# cell actually controls it (the loader's own default merely happens to match).
train_data, val_data = load_and_preprocess_data('dataset/', batch_size=batch_size)

# Initialize state
state = create_train_state(rng, learning_rate)

# Training loop: one training pass plus one validation pass per epoch.
# The best params (by validation accuracy) go to model.pkl, and every epoch
# also gets its own snapshot.
best_acc = 0.0
epoch_rng = rng  # split a copy so re-running this cell reproduces the same key sequence
for epoch in range(num_epochs):
    print(f'Starting Epoch {epoch+1}/{num_epochs}')
    epoch_rng, input_rng = random.split(epoch_rng)
    train_loss, state = train_epoch(state, train_data, batch_size, input_rng)
    metrics = evaluate(state, val_data)
    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Validation Loss: {metrics["loss"]:.4f}, Validation Accuracy: {metrics["accuracy"]:.4f}')

    # Copy params to host memory so the pickles contain plain NumPy arrays.
    host_params = jax.device_get(state.params)

    # Save the model if it has the best accuracy
    if metrics['accuracy'] > best_acc:
        best_acc = metrics['accuracy']
        with open(f'{model_folder}/model.pkl', 'wb') as f:
            pickle.dump(host_params, f)
        print(f'Best model saved to {model_folder}/model.pkl.')

    # Per-epoch snapshot. The epoch number is 1-based to match the log line above
    # (previously the final epoch was written as "epoch_29_of_30"), and the
    # metrics are rounded so filenames stay readable.
    model_accuracy = metrics['accuracy']
    model_loss = metrics['loss']
    snapshot_path = (f'{model_folder}/epoch_{epoch + 1}_of_{num_epochs}'
                     f'_accuracy_{model_accuracy:.4f}_loss_{model_loss:.4f}.pkl')
    with open(snapshot_path, 'wb') as f:
        pickle.dump(host_params, f)
    print(f"Model saved to {snapshot_path}")

Found 3384 images belonging to 4 classes.
Found 843 images belonging to 4 classes.
Starting Epoch 1/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 1, Train Loss: 1.3517, Validation Loss: 0.9814, Validation Accuracy: 0.5491
Best model saved to model_1/model.pkl.
Model saved to model_1/epoch_0_of_30_accuracy_0.5491372059892725_loss_0.9813518170957212.pkl
Starting Epoch 2/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 2, Train Loss: 0.8040, Validation Loss: 0.8352, Validation Accuracy: 0.6151
Best model saved to model_1/model.pkl.
Model saved to model_1/epoch_1_of_30_accuracy_0.6151094282114947_loss_0.8351839758731701.pkl
Starting Epoch 3/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 3, Train Loss: 0.6056, Validation Loss: 0.9341, Validation Accuracy: 0.5886
Model saved to model_1/epoch_2_of_30_accuracy_0.5885942765959987_loss_0.9340691831376817.pkl
Starting Epoch 4/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 4, Train Loss: 0.4766, Validation Loss: 0.6655, Validation Accuracy: 0.7188
Best model saved to model_1/model.pkl.
Model saved to model_1/epoch_3_of_30_accuracy_0.71875_loss_0.665452089022707.pkl
Starting Epoch 5/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 5, Train Loss: 0.4175, Validation Loss: 0.9954, Validation Accuracy: 0.5767
Model saved to model_1/epoch_4_of_30_accuracy_0.5767045462573016_loss_0.9953696440767359.pkl
Starting Epoch 6/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 6, Train Loss: 0.4000, Validation Loss: 0.9046, Validation Accuracy: 0.6116
Model saved to model_1/epoch_5_of_30_accuracy_0.6116372059892725_loss_0.9046219322416518.pkl
Starting Epoch 7/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 7, Train Loss: 0.3942, Validation Loss: 0.6790, Validation Accuracy: 0.6842
Model saved to model_1/epoch_6_of_30_accuracy_0.6842382174951059_loss_0.6790322164694468.pkl
Starting Epoch 8/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 8, Train Loss: 0.4131, Validation Loss: 1.1520, Validation Accuracy: 0.5908
Model saved to model_1/epoch_7_of_30_accuracy_0.5908038726559391_loss_1.1519865923457675.pkl
Starting Epoch 9/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 9, Train Loss: 0.3429, Validation Loss: 0.8023, Validation Accuracy: 0.6534
Model saved to model_1/epoch_8_of_30_accuracy_0.6534090914108135_loss_0.8022554538868092.pkl
Starting Epoch 10/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 10, Train Loss: 0.3068, Validation Loss: 1.0339, Validation Accuracy: 0.5954
Model saved to model_1/epoch_9_of_30_accuracy_0.5954335022855688_loss_1.0338569239333824.pkl
Starting Epoch 11/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 11, Train Loss: 0.2890, Validation Loss: 0.7296, Validation Accuracy: 0.6785
Model saved to model_1/epoch_10_of_30_accuracy_0.6784511804580688_loss_0.7295970762217486.pkl
Starting Epoch 12/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 12, Train Loss: 0.2765, Validation Loss: 1.0253, Validation Accuracy: 0.5895
Model saved to model_1/epoch_11_of_30_accuracy_0.5895412453898677_loss_1.0252691352808918.pkl
Starting Epoch 13/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 13, Train Loss: 0.2682, Validation Loss: 0.7869, Validation Accuracy: 0.6554
Model saved to model_1/epoch_12_of_30_accuracy_0.6554082499610053_loss_0.7869342234399583.pkl
Starting Epoch 14/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 14, Train Loss: 0.2779, Validation Loss: 0.7805, Validation Accuracy: 0.6924
Model saved to model_1/epoch_13_of_30_accuracy_0.6924452869980423_loss_0.7805014374079527.pkl
Starting Epoch 15/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 15, Train Loss: 0.2519, Validation Loss: 1.1088, Validation Accuracy: 0.5886
Model saved to model_1/epoch_14_of_30_accuracy_0.5885942765959987_loss_1.1088248447135642.pkl
Starting Epoch 16/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 16, Train Loss: 0.2480, Validation Loss: 0.7584, Validation Accuracy: 0.7029
Model saved to model_1/epoch_15_of_30_accuracy_0.702861953664709_loss_0.7584205446419893.pkl
Starting Epoch 17/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 17, Train Loss: 0.2453, Validation Loss: 0.8263, Validation Accuracy: 0.6219
Model saved to model_1/epoch_16_of_30_accuracy_0.6219486527972751_loss_0.8262578436621913.pkl
Starting Epoch 18/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 18, Train Loss: 0.2384, Validation Loss: 0.8014, Validation Accuracy: 0.7074
Model saved to model_1/epoch_17_of_30_accuracy_0.7073863656432541_loss_0.8014487028121948.pkl
Starting Epoch 19/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 19, Train Loss: 0.2463, Validation Loss: 0.6241, Validation Accuracy: 0.7237
Best model saved to model_1/model.pkl.
Model saved to model_1/epoch_18_of_30_accuracy_0.7236952869980423_loss_0.6241147087679969.pkl
Starting Epoch 20/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 20, Train Loss: 0.2205, Validation Loss: 0.7572, Validation Accuracy: 0.6843
Model saved to model_1/epoch_19_of_30_accuracy_0.6843434351461904_loss_0.7571600245104896.pkl
Starting Epoch 21/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 21, Train Loss: 0.2131, Validation Loss: 0.6546, Validation Accuracy: 0.7305
Best model saved to model_1/model.pkl.
Model saved to model_1/epoch_20_of_30_accuracy_0.7305345137914022_loss_0.6546022130383385.pkl
Starting Epoch 22/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 22, Train Loss: 0.2082, Validation Loss: 1.1235, Validation Accuracy: 0.6497
Model saved to model_1/epoch_21_of_30_accuracy_0.6497264305750529_loss_1.123515789155607.pkl
Starting Epoch 23/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 23, Train Loss: 0.2063, Validation Loss: 0.6565, Validation Accuracy: 0.7261
Model saved to model_1/epoch_22_of_30_accuracy_0.7261153194639418_loss_0.6565069081606688.pkl
Starting Epoch 24/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 24, Train Loss: 0.1865, Validation Loss: 0.6726, Validation Accuracy: 0.7353
Best model saved to model_1/model.pkl.
Model saved to model_1/epoch_23_of_30_accuracy_0.7352693610721164_loss_0.6726313774232511.pkl
Starting Epoch 25/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 25, Train Loss: 0.1854, Validation Loss: 0.8995, Validation Accuracy: 0.6901
Model saved to model_1/epoch_24_of_30_accuracy_0.6901304721832275_loss_0.8994558895075763.pkl
Starting Epoch 26/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 26, Train Loss: 0.1877, Validation Loss: 0.8053, Validation Accuracy: 0.7203
Model saved to model_1/epoch_25_of_30_accuracy_0.7203282824269047_loss_0.8052720559967889.pkl
Starting Epoch 27/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 27, Train Loss: 0.1757, Validation Loss: 0.9037, Validation Accuracy: 0.6914
Model saved to model_1/epoch_26_of_30_accuracy_0.6913930972417196_loss_0.9037355951688908.pkl
Starting Epoch 28/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 28, Train Loss: 0.1710, Validation Loss: 0.9293, Validation Accuracy: 0.6857
Model saved to model_1/epoch_27_of_30_accuracy_0.6857112800633466_loss_0.9292708668443892.pkl
Starting Epoch 29/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 29, Train Loss: 0.1574, Validation Loss: 1.3151, Validation Accuracy: 0.6104
Model saved to model_1/epoch_28_of_30_accuracy_0.610374578723201_loss_1.3151011886420074.pkl
Starting Epoch 30/30


  0%|          | 0/106 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

Epoch 30, Train Loss: 0.1622, Validation Loss: 0.9067, Validation Accuracy: 0.6981
Model saved to model_1/epoch_29_of_30_accuracy_0.6981271063839948_loss_0.9067362800792411.pkl


In [None]:
import pickle

# Step to load the model from the pickle file
def load_pickle(file_path):
    """Return the object unpickled from ``file_path``; the file is closed afterwards.

    Security: ``pickle.load`` can execute arbitrary code. Only load files that
    this notebook wrote or that otherwise come from a trusted source.
    """
    with open(file_path, mode='rb') as fh:
        return pickle.load(fh)

# Reload the params the training loop saved as the best by validation accuracy.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
best_model_params = load_pickle(f'{model_folder}/model.pkl')

In [None]:
# The Model wrapper hard-codes its class-name order. Both splits must agree on
# the same folder-name -> index mapping for predictions to be labelled correctly.
assert train_data.class_indices == val_data.class_indices, "train/val class mappings differ"
print(train_data.class_indices)
print(val_data.class_indices)
print(train_data.class_indices.keys())

{'cataract': 0, 'diabetic_retinopathy': 1, 'glaucoma': 2, 'normal': 3}
{'cataract': 0, 'diabetic_retinopathy': 1, 'glaucoma': 2, 'normal': 3}
dict_keys(['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal'])


In [None]:
class Model:
  """Picklable inference wrapper: a Flax network class, its trained params and class names.

  Args:
    nn_class: Flax module class, instantiated with no arguments.
    params: parameter pytree passed to ``apply`` as ``{'params': params}``.
      NOTE: this default is bound when the cell runs, i.e. to whatever
      ``best_model_params`` held at that moment. Re-run this cell after
      reloading weights.
    classes: class names indexed by predicted class id. Must follow the data
      generator's ``class_indices`` ordering.
  """
  def __init__(self, nn_class=ResNet, params=best_model_params, classes=['cataract', 'diabetic_retinopathy', 'glaucoma', 'normal']):
    self.model = nn_class()
    self.params = params
    # Copy so instances never alias the shared default list (mutating one
    # instance's classes used to change the default for every later Model()).
    self.classes = list(classes)

  def apply_model(self, inputs):
    """Run the network with ``training=False``; for ResNet that returns softmax probabilities."""
    return self.model.apply({'params': self.params}, inputs, training=False)

  def predict(self, input_image):
    """Classify one 224x224 RGB image and return its class name.

    ``input_image`` is anything ``np.asarray`` accepts with 224*224*3 elements,
    e.g. a (224, 224, 3) array or a PIL image. The network was trained on
    pixels rescaled by 1/255 (``ImageDataGenerator(rescale=1./255)``), so
    integer inputs (raw 0-255 pixels) are rescaled the same way here. Float
    inputs are assumed to already be in [0, 1].
    """
    image = np.asarray(input_image)
    if np.issubdtype(image.dtype, np.integer):
      # Raw pixels would otherwise reach the model 255x larger than in training.
      image = image.astype(np.float32) / 255.0
    batch = image.reshape(1, 224, 224, 3)
    probs = self.apply_model(jnp.asarray(batch))
    idx = int(np.argmax(probs, axis=1).item())
    return self.classes[idx]

def save_model(filename="ModelObj.pkl", model=None):
    """Pickle ``model`` to ``filename``; if no model is given, a default ``Model()`` is built and saved.

    Loading the file later requires the ``Model`` class (and the network class
    it wraps) to be importable under the same names.
    """
    target = Model() if model is None else model
    with open(filename, "wb") as handle:
        pickle.dump(target, handle)

In [None]:
# Build the default inference wrapper (ResNet + best_model_params + class names)
# and pickle it to ModelObj.pkl in the working directory.
save_model(model=Model(), filename="ModelObj.pkl")

In [None]:
from sklearn.metrics import classification_report, confusion_matrix, accuracy_score, roc_auc_score

# Inference wrapper around the best checkpoint: Model() defaults to best_model_params,
# which was loaded from model_folder/model.pkl.
model = Model()

def evaluate_model(model, val_data):
    """Print a classification report, accuracy and ROC AUC for ``model`` on ``val_data``.

    Args:
        model: object with ``apply_model(inputs)`` returning per-class
            probabilities of shape (batch, num_classes). The notebook's ``Model``
            wrapper does this by calling the network with ``training=False``.
        val_data: Keras-style batch iterator (``len()``, ``next()``, optional
            ``reset()``, ``class_indices``) yielding ``(inputs, one_hot_labels)``.

    Raises:
        ValueError: if ``val_data`` has no batches.
    """
    num_batches = len(val_data)
    if num_batches == 0:
        raise ValueError("val_data yielded no batches")
    # Keras iterators loop forever and keep their position between calls, and
    # earlier passes may have left this one mid-epoch. Reset, then draw exactly
    # one epoch so every validation image is scored exactly once. The old
    # enumerate+break pulled an extra batch and, on a shuffled generator, mixed
    # two permutations (duplicating some images, skipping others).
    if hasattr(val_data, "reset"):
        val_data.reset()

    prob_chunks, label_chunks = [], []
    for _ in tqdm(range(num_batches)):
        inputs, labels = next(val_data)
        probs = model.apply_model(jnp.asarray(inputs))
        # Move to host NumPy once per batch instead of extending element by element.
        prob_chunks.append(np.asarray(probs))
        label_chunks.append(np.asarray(labels).argmax(axis=1))

    all_probs = np.concatenate(prob_chunks)
    all_labels = np.concatenate(label_chunks)
    all_preds = all_probs.argmax(axis=1)

    # Order target names by class index explicitly rather than relying on dict order.
    class_indices = val_data.class_indices
    target_names = sorted(class_indices, key=class_indices.get)

    print("Classification Report:")
    print(classification_report(all_labels, all_preds, target_names=target_names))

    accuracy = accuracy_score(all_labels, all_preds)
    print("Accuracy:", accuracy)

    # Binary: score of the positive class. Multiclass: one-vs-rest.
    if all_probs.shape[1] == 2:
        auc = roc_auc_score(all_labels, all_probs[:, 1])
    else:
        auc = roc_auc_score(all_labels, all_probs, multi_class='ovr')

    print("AUC ROC:", auc)

# Evaluate the best-checkpoint wrapper `model` on the validation iterator built in the training cell.
# evaluate_model prints its report and returns None, so the cell shows no Out[] value.
evaluate_model(model, val_data)

  0%|          | 0/27 [00:00<?, ?it/s]

Classification Report:
                      precision    recall  f1-score   support

            cataract       0.84      0.71      0.77       214
diabetic_retinopathy       1.00      1.00      1.00       215
            glaucoma       0.79      0.35      0.48       208
              normal       0.50      0.87      0.64       206

            accuracy                           0.73       843
           macro avg       0.78      0.73      0.72       843
        weighted avg       0.79      0.73      0.72       843

Accuracy: 0.7319098457888493
AUC ROC: 0.9288825135702397
