# Import necessary libraries

In [None]:
from saltup.ai.classification.datagenerator import ClassificationDataloader, pytorch_ClassificationDataGenerator
from saltup.ai.base_dataformat.base_datagen import *
from saltup.ai.training.app_callbacks import ClassificationEvaluationsCallback
from saltup.ai.classification.evaluate import evaluate_model
from saltup.utils.jupyter_notebook import generate_notebook_id, save_current_notebook
from saltup.ai.training.train import training
from saltup.utils.data.image.image_utils import Image, ColorMode

import zipfile
import os
import numpy as np
from torch.utils.data import DataLoader as pytorch_DataGenerator
from glob import glob
from datetime import datetime
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam

In [None]:
#download the dataset
!wget --no-check-certificate https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip -O cats_and_dogs.zip

# Unzip the file using Python
with zipfile.ZipFile("cats_and_dogs.zip", "r") as zip_ref:
    extract_dir = "dataset"
    if not os.path.exists(extract_dir):
        os.makedirs(extract_dir)
    zip_ref.extractall(extract_dir)
    
    os.remove("cats_and_dogs.zip")

# Define constants

In [None]:
# Class labels; the position of a name in the list is its integer class id.
CLASS_NAMES = ['cats', 'dogs']
CLASS_DICTIONARY = {class_name: class_id for class_id, class_name in enumerate(CLASS_NAMES)}
NUM_CLASSES = len(CLASS_NAMES)

# Dataset locations (created by the download cell).
TRAIN_DATA_DIR = './dataset/cats_and_dogs_filtered/train'
TEST_DATA_DIR = './dataset/cats_and_dogs_filtered/validation'

# Input and training settings.
INPUT_SIZE = (128, 128)
BATCH_SIZE = 32
EPOCHS = 4
CALLBACK_EPOCH = 3  # Test the model every CALLBACK_EPOCH epochs

# Load a pre-trained model or define your own architecture

In [None]:
import torch.nn as nn
import torch.nn.functional as F
import keras
#Define the model architecture

class CNN2Class(nn.Module):
    """Small three-stage CNN classifier.

    Each stage is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(3)
    -> Dropout(0.1), with 4, 8 and 16 output channels. The flattened features
    feed a single linear layer.

    Args:
        input_shape: Channels-first ``(C, H, W)`` shape of one input image.
            It determines the input channels of the first convolution and the
            size of the flattened feature vector.
        num_classes: Number of output logits.

    Raises:
        ValueError: If ``input_shape`` is too small to survive the three
            pooling stages (a spatial side would shrink to zero).
    """

    def __init__(self, input_shape=(3, 128, 128), num_classes=NUM_CLASSES):
        super().__init__()
        self.input_shape = input_shape
        self.num_classes = num_classes

        in_channels, height, width = input_shape
        pool_size = 3

        self.conv1 = nn.Conv2d(in_channels, 4, kernel_size=3, padding=1)
        self.pool1 = nn.MaxPool2d(pool_size)
        self.bn1 = nn.BatchNorm2d(4)
        self.drop1 = nn.Dropout(0.1)

        self.conv2 = nn.Conv2d(4, 8, kernel_size=3, padding=1)
        self.pool2 = nn.MaxPool2d(pool_size)
        self.bn2 = nn.BatchNorm2d(8)
        self.drop2 = nn.Dropout(0.1)

        self.conv3 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.pool3 = nn.MaxPool2d(pool_size)
        self.bn3 = nn.BatchNorm2d(16)
        self.drop3 = nn.Dropout(0.1)

        # The padded 3x3 convolutions keep H and W; each MaxPool2d(3) (stride 3,
        # no padding, floor mode) maps a side s to s // 3.
        # Default (3, 128, 128): 128 -> 42 -> 14 -> 4, so flat_dim = 16 * 4 * 4.
        for _ in range(3):
            height //= pool_size
            width //= pool_size
        if height < 1 or width < 1:
            raise ValueError(
                f"input_shape {tuple(input_shape)} is too small for three MaxPool2d({pool_size}) stages"
            )
        self.flat_dim = 16 * height * width
        self.fc = nn.Linear(self.flat_dim, self.num_classes)

    def forward(self, x):
        """Return raw class logits of shape ``(batch, num_classes)`` for a batch ``x``.

        No softmax is applied here.
        """
        # Stage 1
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.drop1(x)
        # Stage 2
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.drop2(x)
        # Stage 3
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.drop3(x)
        # Flatten everything except the batch dimension for the linear head.
        x = x.view(x.size(0), -1)
        return self.fc(x)
    

def build_model(input_shape=(128, 128, 3), num_classes=2):
    """Assemble a small Keras functional-API CNN classifier.

    The network is two Conv2D(3x3, ReLU) + MaxPooling2D stages (8 then 16
    filters), a Flatten, a Dense(64, ReLU) layer and a softmax Dense head.

    Args:
        input_shape: Shape passed to ``keras.Input``.
        num_classes: Number of units in the softmax output layer.

    Returns:
        The assembled ``keras.Model``.
    """
    model_input = keras.Input(shape=input_shape)

    features = model_input
    for n_filters in (8, 16):
        features = keras.layers.Conv2D(n_filters, (3, 3), activation='relu')(features)
        features = keras.layers.MaxPooling2D()(features)

    features = keras.layers.Flatten()(features)
    hidden = keras.layers.Dense(64, activation='relu')(features)
    class_probabilities = keras.layers.Dense(num_classes, activation='softmax')(hidden)
    return keras.Model(model_input, class_probabilities)




# Data generator

In [None]:
#Define the preprocessing function

def preprocess(image: np.ndarray, target_size: tuple) -> np.ndarray:
    """Resize an image to ``target_size`` and divide its pixel data by 255.

    Args:
        image: Raw image array; it is wrapped in ``Image`` for resizing.
        target_size: Size forwarded unchanged to ``Image.resize``.

    Returns:
        The resized pixel data scaled by 1/255.
    """
    resized = Image(image).resize(target_size)
    pixels = resized.get_data()
    # Maps values into [0, 1] (assumes 8-bit pixel data in [0, 255] — TODO confirm).
    return pixels / 255.0

In [None]:
# Augmentation pipeline; it is passed to the training generator only.
# NOTE(review): `A` is not imported explicitly in this notebook; presumably it is
# albumentations made available through the `base_datagen` star import — confirm.
augmentation_steps = [
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.2),
    A.GaussianBlur(blur_limit=(3, 7), p=0.5),
]
transformed_img = A.Compose(augmentation_steps)

In [None]:
# Create one dataloader per split; both share the class mapping and image size.
train_dataloader, test_dataloader = (
    ClassificationDataloader(
        source=split_dir,
        classes_dict=CLASS_DICTIONARY,
        img_size=(224, 224, 3),
    )
    for split_dir in (TRAIN_DATA_DIR, TEST_DATA_DIR)
)

In [None]:
# Create the classification data generators: the training one applies the
# augmentation pipeline, the test one does not.
train_gen, test_gen = (
    pytorch_ClassificationDataGenerator(
        dataloader=split_loader,
        target_size=INPUT_SIZE,
        num_classes=NUM_CLASSES,
        batch_size=BATCH_SIZE,
        preprocess=preprocess,
        transform=split_transform,
    )
    for split_loader, split_transform in (
        (train_dataloader, transformed_img),
        (test_dataloader, None),  # no augmentation
    )
)

# Loader handed to the evaluation callback and to the final evaluation;
# shuffling is disabled.
callback_test_data = pytorch_DataGenerator(test_gen, shuffle=False, batch_size=BATCH_SIZE)

# Sanity check: pull one item from the training generator and show its shapes.
images, labels = next(iter(train_gen))
print("image shape", images.shape)
print("label shape", labels.shape)

In [None]:
# Plot a single example image and its label
import matplotlib.pyplot as plt

def plot_image(image, label):
    """Display ``image`` with matplotlib, using ``label`` in the title.

    Objects exposing ``detach`` (e.g. torch tensors) are converted to NumPy
    first. A 3-D array whose first axis has length 1 or 3 is treated as
    channels-first and transposed to (H, W, C) before plotting.
    """
    is_tensor_like = hasattr(image, 'detach')
    if is_tensor_like:
        image = image.detach().cpu().numpy()

    looks_channels_first = image.ndim == 3 and image.shape[0] in (1, 3)
    if looks_channels_first:
        image = image.transpose(1, 2, 0)

    # squeeze() drops a trailing single channel so grayscale images render.
    plt.imshow(image.squeeze())
    plt.title(f"Label: {label}")
    plt.show()

# Fetch index 0 once: indexing the generator twice repeats loading,
# preprocessing and (random) augmentation, and a single access guarantees the
# image and its label come from the same draw.
# NOTE(review): assumes train_gen[0] is an (images, labels) pair — confirm.
first_item = train_gen[0]
example_image = first_item[0][0]  # Get the first image from the first item
example_label = first_item[1][0]  # Get the corresponding label
print(f"Example image shape: {example_image.shape}")
plot_image(example_image, example_label)

# Training

In [None]:
# Define a timestamped output directory for this run and create it.
todaytime = datetime.now()
output_dir = "./training_outputs"
current_tests_folder_name = "train_{}".format(todaytime.strftime("%d-%m-%Y_%H-%M-%S"))
current_output_dir = os.path.join(output_dir, current_tests_folder_name)
# exist_ok=True replaces the separate os.path.exists check, which is racy and
# redundant.
os.makedirs(current_output_dir, exist_ok=True)

# Evaluation callback: the same test loader is used every CALLBACK_EPOCH epochs
# and at the end of training; results go to a text file in the run folder.
custom_cb = ClassificationEvaluationsCallback(
    datagen=callback_test_data,
    end_of_train_datagen=callback_test_data,
    every_epoch=CALLBACK_EPOCH,
    output_file=os.path.join(current_output_dir, "classification_evaluations.txt"),
    class_names=CLASS_NAMES
)

In [None]:
# Name used for the saved model files
model_output_name = "tiny_model"

# K-fold settings: disabled here for simplicity. When enabled, 'split' defines
# the proportion of data in each fold.
kfold_parameters = {
    'enable': False,
    'split': [0.2, 0.2, 0.2, 0.2, 0.2],
}

# Model, optimizer and loss
torch_model = CNN2Class(num_classes=NUM_CLASSES)
optimizer = Adam(torch_model.parameters(), lr=0.001)
loss_function = nn.CrossEntropyLoss()

# Launch training; the returned dictionary is kept for later inspection.
results_dict = training(
    train_gen,
    model=torch_model,
    loss_function=loss_function,
    optimizer=optimizer,
    epochs=EPOCHS,
    output_dir=current_output_dir,
    validation_split=[0.2, 0.8],
    kfold_param=kfold_parameters,
    model_output_name=model_output_name,
    training_callback=[custom_cb],
)

In [None]:
import onnx
from onnx2keras import onnx_to_keras

# Load the ONNX model written by this run's training step.
# The path is built from the current run's output directory and model name
# rather than the timestamped folder of one specific earlier run, so the cell
# works after Restart & Run All.
# NOTE(review): assumes `training` saves to <output_dir>/saved_models/<name>_best.onnx,
# as the previously hard-coded path suggests — confirm.
onnx_model_path = os.path.join(
    current_output_dir, "saved_models", f"{model_output_name}_best.onnx"
)
onnx_model = onnx.load(onnx_model_path)

# Now convert to Keras
k_model = onnx_to_keras(onnx_model, ['input'])

In [None]:
# Switch to inference mode before exporting: the train/eval flag is stored with
# the scripted module, and Dropout/BatchNorm must not run in training mode when
# the saved model is used for evaluation.
torch_model.eval()

# TorchScript export (self-contained module)
scripted_model = torch.jit.script(torch_model)
scripted_model.save("model.pt")

# Plain weights (state dict only; needs the CNN2Class definition to reload)
torch.save(torch_model.state_dict(), "model.pth")


In [None]:
# "model.pth" holds only a state dict, so torch.load alone returns a mapping of
# tensors, not a usable model. Rebuild the architecture and load the weights
# into it so that `model_torch` really is a model.
model_torch = CNN2Class(num_classes=NUM_CLASSES)
model_torch.load_state_dict(torch.load("model.pth"))
model_torch.eval()  # inference mode for Dropout/BatchNorm

In [None]:
# Inspect the shape of the first element of the item at index 0 of the training generator.
train_gen[0][0].shape

In [None]:
img = torch.rand(3, 128, 128)  # random tensor; not used by the forward pass below

# Run one forward pass as inference: eval() disables Dropout and makes
# BatchNorm use its running statistics (and stops it updating them), and
# no_grad() skips building the autograd graph.
# NOTE(review): assumes train_gen[0][0] is a single (C, H, W) tensor, so that
# unsqueeze(0) adds the batch dimension — confirm.
torch_model.eval()
with torch.no_grad():
    output = torch_model(train_gen[0][0].unsqueeze(0))
print(output)

In [None]:
# Print the module's layer-by-layer summary.
print(torch_model)

# Inference on test dataset

In [None]:
# Evaluate the exported TorchScript file on the test loader; outputs are written
# to the current run folder.
model_path = "model.pt"

global_metric, metric_per_class = evaluate_model(
    model_path,
    test_gen=callback_test_data,
    output_dir=current_output_dir,
    conf_matrix=True,
)

# Aggregate results first
print("Global metrics:")
print("FP:", global_metric.getFP())
print("FN:", global_metric.getFN())
print("Accuracy:", global_metric.getAccuracy())

# Then one entry per class, looked up by class index
print("\nPer-class metrics:")
for class_index, class_label in enumerate(CLASS_NAMES):
    class_metric = metric_per_class[class_index]
    print(f"Class: {class_label}")
    print("  FP:", class_metric.getFP())
    print("  FN:", class_metric.getFN())
    print("  Accuracy:", class_metric.getAccuracy())

In [None]:
# Save the current notebook with the results. This is done at the end to ensure all outputs are captured.
# The run's output directory is passed as the argument.
save_current_notebook(current_output_dir)