# Mushroom Toxicity Classifier (Keras + EfficientNetB0)

This notebook trains a binary image classifier to predict whether a mushroom is toxic (1) or edible (0) using transfer learning with EfficientNetB0.

It includes:

- Automatic Kaggle download of a mushroom images dataset (MO-106, 94 species, ~27k images).
- A species->toxicity mapping step to build a binary folder-of-folders dataset (edible/, toxic/).
- A tf.data pipeline with AUTOTUNE for performance.
- EfficientNetB0 feature extraction + fine-tuning.
- Class weights and threshold tuning to reduce false edible on toxic images.
- (Optional) a tiny scratch CNN baseline for comparison.

> References

> - EfficientNet transfer learning (Keras official example): https://keras.io/examples/vision/image_classification_efficientnet_fine_tuning/

> - EfficientNetB0 API docs: https://www.tensorflow.org/api_docs/python/tf/keras/applications/EfficientNetB0

> - Kaggle MO-106 (94 species): https://www.kaggle.com/datasets/iftekhar08/mo-106


In [14]:
# Optional: install/upgrade packages if needed
#!pip install -U tensorflow scikit-learn kaggle

import os, glob, shutil, zipfile
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
print('TensorFlow:', tf.__version__)


TensorFlow: 2.20.0


## 1) Kaggle setup (one-time)

This notebook uses Kaggle CLI to download the dataset. You need an API token at `~/.kaggle/kaggle.json`.

How to obtain and place it:

1. Go to https://www.kaggle.com/ -> Profile -> Account -> Create New API Token (downloads kaggle.json).

2. Upload it to this runtime at `~/.kaggle/kaggle.json` and set permissions `0o600`.

Uncomment and run the cell below if you need to place the token programmatically.


In [None]:
# PLACE YOUR KAGGLE TOKEN (paste the JSON content between triple quotes)
kaggle_token = '''
{"username":"maurobonacina","key":"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"}
 '''
# import os
# os.makedirs(os.path.expanduser('~/.kaggle'), exist_ok=True)
# with open(os.path.expanduser('~/.kaggle/kaggle.json'), 'w') as f:
#     f.write(kaggle_token.strip())
# os.chmod(os.path.expanduser('~/.kaggle/kaggle.json'), 0o600)


## 2) Download & unpack MO-106 (images) from Kaggle

We will download MO-106 (94 species, ~27k images) and unpack under `data/mo106_raw/`.


In [4]:
DATA_DIR = 'data'
RAW_DIR  = os.path.join(DATA_DIR, 'mo106_raw')
os.makedirs(DATA_DIR, exist_ok=True)
print('DATA_DIR:', DATA_DIR)

# Download (uncomment in a real runtime with Kaggle CLI available)
# !pip -q install kaggle
#!kaggle datasets download -d iftekhar08/mo-106 -p {DATA_DIR} -o

print('Expect raw dataset under:', os.path.join(os.getcwd(),RAW_DIR))


DATA_DIR: data
Expect raw dataset under: C:\Users\MauroBonacina\data\mo106_raw


In [9]:
#unzip
for z in glob.glob(os.path.join(DATA_DIR, '*.zip')):
     print('Unzipping:', z)
     with zipfile.ZipFile(z, 'r') as f:
         f.extractall(RAW_DIR)

Unzipping: data\mo-106.zip


## 3) Build a binary dataset (edible vs toxic)

The MO-106 dataset is organized by species, with each subfolder representing one species.
To build a binary classifier, we need to map each species as either "edible" or "toxic"
and then copy the corresponding images into two separate folders:

data/mushrooms_binary/

   ├── edible/
   
   └── toxic/

A file named 'species_edibility_lookup.csv' will be generated, listing all species found
by scanning the subdirectory names in the archive.
The challenge is that the original dataset does not include any toxicity information.
You can update the toxicity labels manually in the generated CSV file.
As a starting point, this notebook will attempt to retrieve edibility information from Wikipedia.



In [None]:
# ===================================================================================
# === 1) Dataset directory scan with species_edibility_lookup.csv file generation ===
# ===================================================================================

import os, re, shutil
from pathlib import Path
import pandas as pd

# Percorso locale con le cartelle per specie (Windows)
SOURCE_ROOT = Path(os.path.join(RAW_DIR,"MO_94"))
DATA_DIR = Path('data')
BIN_DIR  = Path(os.path.join(DATA_DIR,'mushrooms_binary'))
EDIBLE_DIR = Path(os.path.join(BIN_DIR, 'edible'))
TOXIC_DIR  = Path(os.path.join(BIN_DIR, 'toxic'))
DATA_DIR.mkdir(parents=True, exist_ok=True)

print('SOURCE_ROOT =', SOURCE_ROOT)
assert SOURCE_ROOT.is_dir(), f"Il path esiste ma non è una cartella: {SOURCE_ROOT}"

# === 1) SCANSIONA CARTELLE E CREA species_edibility_lookup.csv ===
LOOKUP_CSV = Path('species_edibility_lookup.csv')

species_dirs = sorted([p for p in SOURCE_ROOT.iterdir() if p.is_dir()])
print('Cartelle specie trovate:', len(species_dirs))

import unicodedata
def normalize_ws(s: str) -> str:
    s = unicodedata.normalize('NFKC', s)
    s = re.sub(r"\s+", " ", s.strip())
    return s

def split_genus_species(folder_name: str):
    name = normalize_ws(folder_name)
    parts = name.split(' ')
    if len(parts) >= 2:
        return parts[0], ' '.join(parts[1:])
    return parts[0], ''

rows = []
for sp in species_dirs:
    folder = sp.name
    g, s = split_genus_species(folder)
    rows.append({'genus': g, 'species': s, 'folder_name': folder, 'edibility': 'unknown'})

df_lookup = pd.DataFrame(rows, columns=['genus','species','folder_name','edibility'])
df_lookup.to_csv(LOOKUP_CSV, index=False)
print('Scritto:', LOOKUP_CSV.resolve())
df_lookup.head()





SOURCE_ROOT = data\mo106_raw\MO_94
Cartelle specie trovate: 94
Scritto: C:\Users\MauroBonacina\species_edibility_lookup.csv


Unnamed: 0,genus,species,folder_name,edibility
0,Agaricus,augustus,Agaricus augustus,unknown
1,Agaricus,xanthodermus,Agaricus xanthodermus,unknown
2,Amanita,amerirubescens,Amanita amerirubescens,unknown
3,Amanita,augusta,Amanita augusta,unknown
4,Amanita,brunnescens,Amanita brunnescens,unknown


In [None]:
# ===========================================================================
# 2) (OPZIONAL) AUTO-WIKI LOOKUP to udate 'unknown' edibility entries
#    Build URL: https://en.wikipedia.org/wiki/<Genus>_<species>
#    Example: Agaricus augustus -> Agaricus_augustus
#    Search for keywords in the page content to determine edibility
#   - toxic keywords: "poisonous", "toxic", "inedible"
#   - edible keywords: "choice edible", "edible"
#   Update the CSV file with the found edibility information
#   Log results to wiki_lookup_log.csv
# ===========================================================================
AUTO_WIKI = True  # set to False to disable auto-lookup on Wikipedia

if AUTO_WIKI:
    import requests
    from time import sleep
    from urllib.parse import quote
    import re

    BASE = "https://en.wikipedia.org/wiki/"
    HEADERS = {"User-Agent": "MushroomEdibilityDemo/1.0 (contact: demo@example.com)"}
    TIMEOUT = 10
    SLEEP_S = 0.2

    TOK_TOXIC  = (r"\bpoisonous\b", r"\btoxic\b", r"\binedible\b")
    TOK_EDIBLE = (r"\bchoice edible\b", r"\bedible\b")

    df_lookup = pd.read_csv(LOOKUP_CSV)
    results = []
    updated = 0

    for idx, row in df_lookup[df_lookup['edibility'].astype(str).str.lower().eq('unknown')].iterrows():
        genus   = str(row['genus']).strip()
        species = str(row['species']).strip()
        title = f"{genus.capitalize()}_{species.replace(' ', '_')}".strip('_')
        url   = BASE + quote(title)

        try:
            r = requests.get(url, headers=HEADERS, timeout=TIMEOUT)
            status = r.status_code
            body   = r.text.lower() if status == 200 else ""
        except Exception as e:
            status = f"ERR:{e.__class__.__name__}"
            body   = ""

        decision = "unknown"; reason = ""
        if isinstance(status, int) and status == 200 and len(body) > 500:
            if any(re.search(p, body) for p in TOK_TOXIC):
                decision, reason = "toxic", "keyword toxic/poisonous/inedible"
            elif any(re.search(p, body) for p in TOK_EDIBLE):
                decision, reason = "edible", "keyword (choice) edible"
        else:
            reason = f"http_status={status}"

        if decision in ("edible","toxic"):
            df_lookup.at[idx, 'edibility'] = decision
            updated += 1

        results.append({
            "folder_name": row.get("folder_name", f"{genus} {species}"),
            "genus": genus, "species": species,
            "wiki_url": url, "http_status": status,
            "decision": decision, "reason": reason
        })
        print (results[-1])
        sleep(SLEEP_S)

    df_lookup.to_csv(LOOKUP_CSV, index=False)
    log_path = Path("wiki_lookup_log.csv")
    pd.DataFrame(results).to_csv(log_path, index=False, encoding="utf-8")
    print(f"Updated via Wikipedia: {updated}")
    print("Log:", log_path.resolve())

{'folder_name': 'Agaricus augustus', 'genus': 'Agaricus', 'species': 'augustus', 'wiki_url': 'https://en.wikipedia.org/wiki/Agaricus_augustus', 'http_status': 200, 'decision': 'toxic', 'reason': 'keyword toxic/poisonous/inedible'}
{'folder_name': 'Agaricus xanthodermus', 'genus': 'Agaricus', 'species': 'xanthodermus', 'wiki_url': 'https://en.wikipedia.org/wiki/Agaricus_xanthodermus', 'http_status': 200, 'decision': 'toxic', 'reason': 'keyword toxic/poisonous/inedible'}
{'folder_name': 'Amanita amerirubescens', 'genus': 'Amanita', 'species': 'amerirubescens', 'wiki_url': 'https://en.wikipedia.org/wiki/Amanita_amerirubescens', 'http_status': 404, 'decision': 'unknown', 'reason': 'http_status=404'}
{'folder_name': 'Amanita augusta', 'genus': 'Amanita', 'species': 'augusta', 'wiki_url': 'https://en.wikipedia.org/wiki/Amanita_augusta', 'http_status': 200, 'decision': 'toxic', 'reason': 'keyword toxic/poisonous/inedible'}
{'folder_name': 'Amanita brunnescens', 'genus': 'Amanita', 'species': 

In [None]:
########################################################################################################
# 3) COPY IMAGES INTO edible/ AND toxic/ BASED ON THE CSV
# This step organizes the dataset into two folders according to the edibility labels in the CSV file.
# Species marked as 'unknown' will be skipped by default.
# If you prefer to include them, you can modify the logic to assign a default class (e.g., toxic).
########################################################################################################

# destination cleanup
if BIN_DIR.exists():
    shutil.rmtree(BIN_DIR)
EDIBLE_DIR.mkdir(parents=True, exist_ok=True)
TOXIC_DIR.mkdir(parents=True, exist_ok=True)

valid_ext = {'.jpg','.jpeg','.png','.bmp','.gif','.tif','.tiff'}
stats = {'edible':0, 'toxic':0, 'skipped_species':0}

df_lookup = pd.read_csv(LOOKUP_CSV)

for _, row in df_lookup.iterrows():
    label = row['edibility']
    folder = row['folder_name']
    src_dir = SOURCE_ROOT / folder
    if label not in ('edible','toxic'):
        stats['skipped_species'] += 1
        continue
    dest_dir = EDIBLE_DIR if label=='edible' else TOXIC_DIR
    for img in src_dir.glob('*'):
        if img.suffix.lower() in valid_ext and img.is_file():
            new_name = f"{folder}__{img.name}".replace('/', '_')  # evita collisioni nomi
            shutil.copy2(str(img), str(dest_dir / new_name))
            stats[label] += 1

print('Copy stats:', stats)
print('Binary root:', BIN_DIR.resolve())

Copy stats: {'edible': 4006, 'toxic': 22265, 'skipped_species': 4}
Binary root: C:\Users\MauroBonacina\data\mushrooms_binary


## 4) tf.data pipeline (with AUTOTUNE)

We build a train/validation split directly from the `edible/` and `toxic/` folders.
- Image Preprocessing: Images are resized and normalized to prepare them for training.
- Data Augmentation: Applied only to training data to improve generalization.
- Dataset Splitting: 80% for training, 20% for validation using a fixed seed for reproducibility.
- Class Weights: Calculated to address class imbalance, useful for training models that are sensitive to skewed class distributions.

In [9]:
# Set the target size for image resizing (width, height)
IMG_SIZE = (224, 224)

# Define the number of images to process in each batch
BATCH_SIZE = 32

# Use TensorFlow's automatic tuning for performance optimization
AUTOTUNE = tf.data.AUTOTUNE

# Load the training dataset from a directory, using 80% of the data for training
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    BIN_DIR,                      # Path to the directory containing image data
    validation_split=0.2,         # Reserve 20% of data for validation
    subset='training',            # This is the training subset
    seed=42,                      # Seed for reproducibility of the split
    label_mode='binary',          # Use binary labels (e.g., 0 or 1)
    image_size=IMG_SIZE,          # Resize all images to 224x224
    batch_size=BATCH_SIZE         # Number of images per batch
)

# Load the validation dataset from the same directory, using the remaining 20%
val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    BIN_DIR,
    validation_split=0.2,
    subset='validation',
    seed=42,
    label_mode='binary',
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE
)

# Define a normalization layer to scale pixel values from [0, 255] to [0, 1]
normalize = layers.Rescaling(1./255)

# Define a data augmentation pipeline to improve model generalization
augment = tf.keras.Sequential([
    layers.RandomFlip('horizontal'),  # Randomly flip images horizontally
    layers.RandomRotation(0.05),      # Randomly rotate images by up to 5%
    layers.RandomZoom(0.05),          # Randomly zoom images by up to 5%
])

# Define a preprocessing function to normalize and optionally augment images
def prep(x, y, training=False):
    x = normalize(x)           # Normalize pixel values
    if training:
        x = augment(x)         # Apply data augmentation only during training
    return x, y

# Apply preprocessing and augmentation to the training dataset
train_ds = train_ds.map(lambda x, y: prep(x, y, True), num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

# Apply only normalization (no augmentation) to the validation dataset
val_ds = val_ds.map(prep, num_parallel_calls=AUTOTUNE).prefetch(AUTOTUNE)

# Initialize a dictionary to count the number of samples per class
counts = {'edible': 0, 'toxic': 0}

# Iterate through the training dataset (unbatched) and count class occurrences
for _, y in train_ds.unbatch().take(20000):  # Limit to 20,000 samples for efficiency
    counts['toxic' if int(y.numpy()) == 1 else 'edible'] += 1

# Calculate total number of samples
total = max(1, counts['edible'] + counts['toxic'])

# Compute class weights to handle class imbalance during training
w_edible = total / (2.0 * max(1, counts['edible']))
w_toxic  = total / (2.0 * max(1, counts['toxic']))
class_weights = {0: w_edible, 1: w_toxic}

# Print the number of samples per class and the computed class weights
print('Class counts:', counts)
print('Class weights:', class_weights)

Found 26271 files belonging to 2 classes.
Using 21017 files for training.
Found 26271 files belonging to 2 classes.
Using 5254 files for validation.


  counts['toxic' if int(y.numpy()) == 1 else 'edible'] += 1


Class counts: {'edible': 3065, 'toxic': 16935}
Class weights: {0: 3.2626427406199023, 1: 0.5904930617065249}


## 5) EfficientNetB0 (feature extraction -> fine-tuning)

We follow Keras' transfer learning pattern: freeze the base (ImageNet weights), train the head, then unfreeze top blocks and fine-tune with a smaller learning rate.
Here's a concise summary of what it does:

1. Load Pretrained Model: It loads the EfficientNetB0 model without its top (classification) layer and with pretrained ImageNet weights. The base model is initially frozen to prevent its weights from being updated during training.

2. Build Custom Model: A new model is built on top of the base, adding:
- An input layer matching the image size.
- The EfficientNetB0 base model.
- A global average pooling layer to reduce the feature map dimensions.
- A dropout layer for regularization.
- A dense output layer with a sigmoid activation for binary classification.

3. Initial Training (Feature Extraction):
- The model is compiled with the Adam optimizer and binary crossentropy loss.
- It is trained for 5 epochs on the training dataset, using class weights to handle class imbalance.

4. Fine-Tuning:

- The base model is unfrozen, but only the last 30 layers are set to be trainable.
- The model is recompiled with a lower learning rate to avoid large updates.
- It is trained again for 5 more epochs to fine-tune the model for better performance.


In [10]:
# Import the EfficientNetB0 model from Keras applications
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras import layers, models
import tensorflow as tf

# Load the EfficientNetB0 model without the top classification layer
# Use pre-trained weights from ImageNet
# Set the input shape to match the image size with 3 color channels (RGB)
base = EfficientNetB0(include_top=False, weights='imagenet', input_shape=IMG_SIZE + (3,))

# Freeze the base model so its weights are not updated during training
base.trainable = False

# Define the input layer with the same shape as the images
inputs = layers.Input(shape=IMG_SIZE + (3,))

# Pass the inputs through the base model
# Set training=False to ensure batch normalization layers run in inference mode
x = base(inputs, training=False)

# Apply global average pooling to reduce the spatial dimensions
x = layers.GlobalAveragePooling2D()(x)

# Add dropout for regularization to prevent overfitting
x = layers.Dropout(0.25)(x)

# Add a dense output layer with sigmoid activation for binary classification
outputs = layers.Dense(1, activation='sigmoid')(x)

# Create the final model
model = models.Model(inputs, outputs, name='mushroom_effnet_b0')

# Compile the model with Adam optimizer, binary crossentropy loss, and accuracy metric
model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
              loss='binary_crossentropy', metrics=['accuracy'])

# Train the model for 5 epochs using the training and validation datasets
# Apply class weights to handle class imbalance
history_fe = model.fit(train_ds, validation_data=val_ds, epochs=5, class_weight=class_weights)

# Unfreeze the base model to allow fine-tuning
base.trainable = True

# Freeze all layers in the base model except the last 30 layers
# This allows fine-tuning of the deeper layers while keeping earlier layers fixed
for layer in base.layers[:-30]:
    layer.trainable = False

# Re-compile the model with a lower learning rate for fine-tuning
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss='binary_crossentropy', metrics=['accuracy'])

# Continue training (fine-tuning) the model for another 5 epochs
history_ft = model.fit(train_ds, validation_data=val_ds, epochs=5, class_weight=class_weights)


Epoch 1/5
[1m657/657[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m782s[0m 1s/step - accuracy: 0.4946 - loss: 0.7008 - val_accuracy: 0.1494 - val_loss: 0.6992
Epoch 2/5
[1m657/657[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m729s[0m 1s/step - accuracy: 0.4988 - loss: 0.7003 - val_accuracy: 0.8506 - val_loss: 0.6626
Epoch 3/5
[1m657/657[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m674s[0m 1s/step - accuracy: 0.5083 - loss: 0.6999 - val_accuracy: 0.1494 - val_loss: 0.7558
Epoch 4/5
[1m657/657[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m683s[0m 1s/step - accuracy: 0.4920 - loss: 0.7002 - val_accuracy: 0.8506 - val_loss: 0.6771
Epoch 5/5
[1m657/657[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m794s[0m 1s/step - accuracy: 0.5072 - loss: 0.6986 - val_accuracy: 0.1572 - val_loss: 0.7021
Epoch 1/5
[1m657/657[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m926s[0m 1s/step - accuracy: 0.5044 - loss: 0.7084 - val_accuracy: 0.8506 - val_loss: 0.5990
Epoch 2/5
[1m657/657[0m [

## 6) Threshold tuning (prioritize toxic recall)

We sweep decision thresholds and pick one with high recall on the toxic class.
This code evaluates the model's performance on the validation set across a range of classification thresholds (from 0.30 to 0.80). It calculates precision, recall, F1-score, and the confusion matrix for each threshold, then selects and prints the threshold that gives the best combination of these metrics—prioritizing recall, then precision, then F1-score.

In [11]:
# Import necessary libraries
import numpy as np
from sklearn.metrics import precision_recall_fscore_support, confusion_matrix

# Initialize empty lists to store true labels and predicted probabilities
y_true, y_prob = [], []

# Iterate over the validation dataset
for xb, yb in val_ds:
    # Append the true labels (converted to NumPy arrays and flattened)
    y_true.append(yb.numpy().ravel())
    
    # Predict probabilities for each batch and flatten the output
    y_prob.append(model.predict(xb, verbose=0).ravel())

# Concatenate all batches into single NumPy arrays
y_true = np.concatenate(y_true)
y_prob = np.concatenate(y_prob)

# Define a range of threshold candidates from 0.30 to 0.80 (inclusive) with 21 evenly spaced values
cands = np.linspace(0.30, 0.80, 21)

# Initialize a list to store evaluation metrics for each threshold
records = []

# Loop through each threshold candidate
for th in cands:
    # Convert predicted probabilities to binary predictions using the current threshold
    y_pred = (y_prob >= th).astype(int)
    
    # Compute precision, recall, and F1-score for binary classification
    p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='binary', zero_division=0)
    
    # Compute the confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    
    # Store the threshold and corresponding metrics
    records.append((th, p, r, f1, cm))

# Find the threshold with the best performance
# Sorting by recall (x[2]), then precision (x[1]), then F1-score (x[3])
best = sorted(records, key=lambda x: (x[2], x[1], x[3]))[-1]

# Print the best threshold and its associated metrics
print('Best threshold:', round(best[0], 2))
print('Precision:', round(best[1], 3), 'Recall:', round(best[2], 3), 'F1:', round(best[3], 3))
print('Confusion matrix:', best[4])

Best threshold: 0.48
Precision: 0.851 Recall: 1.0 F1: 0.919
Confusion matrix: [[   0  785]
 [   0 4469]]


## 7) (Optional) Tiny CNN baseline (from scratch)

A compact CNN to visualize the lift you get from transfer learning.

This function defines a lightweight CNN model for binary image classification. It uses three convolutional layers with increasing filter sizes, followed by pooling, global average pooling, dropout for regularization, and a final sigmoid-activated dense layer. The model is compiled with the Adam optimizer and binary crossentropy loss, suitable for binary classification tasks.


In [13]:
# Define a simple Convolutional Neural Network (CNN) model
def tiny_cnn(input_shape=IMG_SIZE+(3,)):
    # Input layer with the specified image shape (height, width, channels)
    inputs = layers.Input(shape=input_shape)
    
    # First convolutional layer with 16 filters, 3x3 kernel, ReLU activation, and same padding
    x = layers.Conv2D(16, 3, padding='same', activation='relu')(inputs)
    
    # First max pooling layer to reduce spatial dimensions
    x = layers.MaxPooling2D()(x)
    
    # Second convolutional layer with 32 filters
    x = layers.Conv2D(32, 3, padding='same', activation='relu')(x)
    
    # Second max pooling layer
    x = layers.MaxPooling2D()(x)
    
    # Third convolutional layer with 64 filters
    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)
    
    # Global average pooling to reduce each feature map to a single value
    x = layers.GlobalAveragePooling2D()(x)
    
    # Dropout layer for regularization (25% of the neurons are randomly dropped during training)
    x = layers.Dropout(0.25)(x)
    
    # Output layer with 1 neuron and sigmoid activation for binary classification
    outputs = layers.Dense(1, activation='sigmoid')(x)
    
    # Create and return the model
    return models.Model(inputs, outputs, name='tiny_cnn')

# Instantiate the model
baseline = tiny_cnn()

# Compile the model with Adam optimizer, binary crossentropy loss, and accuracy as the evaluation metric
baseline.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Train the model on the training dataset with validation and class weights
# (This line is currently commented out)
baseline.fit(train_ds, validation_data=val_ds, epochs=6, class_weight=class_weights)


Epoch 1/6
[1m657/657[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m344s[0m 516ms/step - accuracy: 0.5046 - loss: 0.6936 - val_accuracy: 0.2297 - val_loss: 0.6960
Epoch 2/6
[1m657/657[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m339s[0m 514ms/step - accuracy: 0.4838 - loss: 0.6932 - val_accuracy: 0.7103 - val_loss: 0.6870
Epoch 3/6
[1m657/657[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m313s[0m 474ms/step - accuracy: 0.3823 - loss: 0.6934 - val_accuracy: 0.3536 - val_loss: 0.6947
Epoch 4/6
[1m657/657[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m321s[0m 471ms/step - accuracy: 0.4210 - loss: 0.6930 - val_accuracy: 0.5434 - val_loss: 0.6919
Epoch 5/6
[1m657/657[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m440s[0m 669ms/step - accuracy: 0.4075 - loss: 0.6929 - val_accuracy: 0.4821 - val_loss: 0.6934
Epoch 6/6
[1m657/657[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m422s[0m 640ms/step - accuracy: 0.5279 - loss: 0.6922 - val_accuracy: 0.8022 - val_loss: 0.6465


<keras.src.callbacks.history.History at 0x289b0353c10>

## Appendix — What is EfficientNetB0 (in brief)?

EfficientNetB0 is the baseline model in the EfficientNet family. It achieves a strong accuracy-efficiency trade-off by compound scaling of depth, width, and resolution. It uses MBConv blocks, depthwise separable convolutions, and squeeze-and-excitation modules.

- Great for transfer learning on small/medium datasets.

- Input size: 224x224x3.

- Docs: https://www.tensorflow.org/api_docs/python/tf/keras/applications/EfficientNetB0

- Keras example: https://keras.io/examples/vision/image_classification_efficientnet_fine_tuning/
