In [None]:
import os
import cv2
import imghdr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import datetime
from tqdm import tqdm
from collections import Counter
from PIL import Image

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, precision_score, recall_score, f1_score
from collections import Counter
from sklearn.utils import shuffle

from imblearn.over_sampling import RandomOverSampler

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import Sequential
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.layers import (
    Conv2D, MaxPooling2D, Flatten, Dense, Activation, Dropout, BatchNormalization
)
from tensorflow.keras import regularizers
from tensorflow.keras.metrics import AUC
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

import warnings
warnings.filterwarnings("ignore")

In [None]:
# Report how many GPUs TensorFlow can see in this runtime.
gpu_count = len(tf.config.list_physical_devices('GPU'))
print("Num GPUs Available: ", gpu_count)

In [None]:
# Let TensorFlow grow GPU memory on demand instead of reserving it all up front.
# set_memory_growth may raise RuntimeError (e.g. if devices were already
# initialised); the error is printed rather than re-raised.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        print("GPU is set for TensorFlow")
    except RuntimeError as err:
        print(err)

In [None]:
# Colab only: mount Google Drive so the dataset under /content/drive is reachable.
from google.colab import drive

drive.mount('/content/drive')

In [None]:
# Summarise the dataset tree with one line per directory and its file count.
# Printing every individual path floods the cell output without adding information.
for dirname, _, filenames in os.walk('/content/drive/Othercomputers/My Laptop/AI-ML/LJMU/Dataset/Brain Tumor/Input'):
    print(f"{dirname}: {len(filenames)} files")


In [None]:
# Dataset root: one sub-directory per class, named exactly as in `categories`.
base_path = '/content/drive/Othercomputers/My Laptop/AI-ML/LJMU/Dataset/Brain Tumor/Input'
categories = [
    "glioma",
    "meningioma",
    "notumor",
    "pituitary",
]

In [None]:
# Build a (path, label) table from the per-class sub-directories.
# os.listdir returns entries in arbitrary, filesystem-dependent order. Sorting
# makes the row order reproducible, and with it every seeded split and
# oversampling step further down.
image_paths = []
labels = []

for category in categories:
    category_path = os.path.join(base_path, category)
    for image_name in sorted(os.listdir(category_path)):
        image_paths.append(os.path.join(category_path, image_name))
        labels.append(category)

df = pd.DataFrame({
    "image_path": image_paths,
    "label": labels
})

In [None]:
df.head(5)

In [None]:
df.tail(5)

In [None]:

df.keys()

In [None]:
# 1️⃣ Dataset Overview
def dataset_overview(df, label_column='label', image_column='image_path'):
    """
    Print a quick health summary of an image/label DataFrame.

    Reports the row count, the number of distinct classes and their counts,
    missing values in the label and image-path columns, how many image paths
    are repeated, and finally the output of ``DataFrame.info()``.

    Parameters:
    - df: DataFrame to summarise.
    - label_column: name of the class-label column.
    - image_column: name of the image-path column.
    """
    print(f"Total samples: {len(df)}")
    print(f"Unique classes: {df[label_column].nunique()}")
    print(f"Class distribution:\n{df[label_column].value_counts()}")
    print("\nMissing labels:", df[label_column].isnull().sum())
    print("Missing image paths:", df[image_column].isnull().sum())
    # Count repeated paths (not only fully identical rows), as the label says.
    print("Duplicate image paths:", df[image_column].duplicated().sum())
    # DataFrame.info() prints by itself and returns None, so call it separately.
    print("Info:")
    df.info()

dataset_overview(df, label_column='label', image_column='image_path')

In [None]:
def _is_valid_image(path, allowed):
    """Return True when `path` is an existing file in an allowed format that OpenCV can decode."""
    if pd.isna(path) or not os.path.isfile(path):
        return False
    try:
        # Pillow identifies the format from the file header ('JPEG', 'PNG', ...).
        # imghdr (deprecated, removed in Python 3.13) only recognises JPEGs that
        # carry a JFIF/Exif marker, so it silently rejected other valid JPEGs.
        with Image.open(path) as img:
            file_type = (img.format or '').lower()
        if file_type not in allowed:
            return False
        # Readability check: cv2.imread returns None for files it cannot decode.
        return cv2.imread(path) is not None
    except Exception:
        # Corrupt or unreadable file: mark it invalid instead of aborting the scan.
        return False


def clean_invalid_images(df, image_column='image_path', allowed_formats=('jpeg', 'png', 'jpg'),
                          save_corrupt_log=False, log_path="invalid_images.csv"):
    """
    Removes rows where the image file is missing, unreadable or in an invalid format.

    The format is read from the file contents with Pillow, lower-cased
    (e.g. 'jpeg', 'png') and compared with `allowed_formats`. The file must
    also decode with OpenCV.

    Parameters:
        df: DataFrame holding the image paths.
        image_column: column containing the file paths.
        allowed_formats: accepted lower-case format names.
        save_corrupt_log: if True and invalid files exist, write them to CSV.
        log_path: destination of that CSV.

    Returns:
        - cleaned dataframe (index reset)
        - list of invalid file paths
    """
    allowed = {fmt.lower() for fmt in allowed_formats}
    invalid_images = [
        path
        for path in tqdm(df[image_column], desc="Validating images")
        if not _is_valid_image(path, allowed)
    ]

    print(f"\nTotal invalid images detected: {len(invalid_images)}")

    df_clean = df[~df[image_column].isin(invalid_images)].reset_index(drop=True)
    print(f"Remaining clean samples: {len(df_clean)}")

    if save_corrupt_log and invalid_images:
        pd.DataFrame(invalid_images, columns=['invalid_path']).to_csv(log_path, index=False)
        print(f"Invalid image paths saved to: {log_path}")

    return df_clean, invalid_images

# Drop unreadable or wrongly formatted images. Later cells keep using the name `df`.
df_clean, invalid_files = clean_invalid_images(
    df,
    image_column='image_path',
    allowed_formats=('jpeg', 'png', 'jpg'),
)
df = df_clean

# Exploratory Data Analysis


In [None]:
# Image Dimension Analysis
def image_dimensions_analysis(df, image_column='image_path'):
    """
    Plot width/height histograms for the images and print min/max/mean of each.

    Each file is opened in a `with` block, so its handle is closed right away
    rather than left to the garbage collector. Files that cannot be opened are
    skipped.

    Parameters:
    - df: DataFrame containing the image paths.
    - image_column: column holding the paths.
    """
    widths, heights = [], []
    for path in tqdm(df[image_column], desc="Reading image dimensions"):
        try:
            with Image.open(path) as img:
                w, h = img.size
        except Exception:
            continue
        widths.append(w)
        heights.append(h)

    if not widths:
        # min()/max() below would raise ValueError on empty lists.
        print("No readable images; nothing to plot.")
        return

    plt.figure(figsize=(10, 5))
    sns.histplot(widths, kde=True, color="skyblue", label="Width")
    sns.histplot(heights, kde=True, color="salmon", label="Height")
    plt.title("Image Dimension Distribution")
    plt.xlabel("Pixels")
    plt.legend()
    plt.show()
    print(f"Width: min={min(widths)}, max={max(widths)}, mean={np.mean(widths):.2f}")
    print(f"Height: min={min(heights)}, max={max(heights)}, mean={np.mean(heights):.2f}")


image_dimensions_analysis(df)

In [None]:
# Aspect Ratio Distribution
def aspect_ratio_distribution(df, image_column='image_path'):
    """
    Plot a histogram of image aspect ratios (width / height).

    Each file is opened in a `with` block so its handle is closed right away.
    Unreadable files, and images of zero height, are skipped.

    Parameters:
    - df: DataFrame containing the image paths.
    - image_column: column holding the paths.
    """
    aspect_ratios = []
    for path in tqdm(df[image_column], desc="Computing aspect ratios"):
        try:
            with Image.open(path) as img:
                w, h = img.size
            aspect_ratios.append(w / h)
        except Exception:  # unreadable file or ZeroDivisionError for h == 0
            continue

    if not aspect_ratios:
        print("No readable images; nothing to plot.")
        return

    plt.figure(figsize=(10, 5))
    sns.histplot(aspect_ratios, kde=True, color="purple")
    plt.title("Aspect Ratio Distribution")
    plt.xlabel("Width / Height")
    plt.show()

aspect_ratio_distribution(df)

In [None]:
# Per Channel Mean/Std (for normalization)
def image_channel_statistics(df, image_column='image_path', sample_size=1000):
    """
    Estimate the per-channel (R, G, B) mean and standard deviation of pixel values scaled to [0, 1].

    Up to `sample_size` images are sampled with random_state=42, and every
    image contributes equally:

        mean = average of the per-image channel means
        std  = sqrt(average of (per-image variance + per-image mean**2) - mean**2)

    The std formula (law of total variance) includes the brightness spread
    *between* images. Averaging the per-image stds would drop that term and
    underestimate the dataset std used for normalisation.

    Returns:
        (mean, std) as length-3 arrays, or (None, None) if no sampled image could be read.
    """
    sampled_paths = df[image_column].sample(min(len(df), sample_size), random_state=42)

    means, stds = [], []
    for path in tqdm(sampled_paths, desc="Computing channel stats"):
        try:
            bgr = cv2.imread(path)
            if bgr is None:  # cv2 signals an unreadable file by returning None
                continue
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) / 255.0
        except Exception:
            continue
        means.append(rgb.mean(axis=(0, 1)))
        stds.append(rgb.std(axis=(0, 1)))

    if not means:
        print("No readable images in the sample; channel statistics unavailable.")
        return None, None

    means = np.asarray(means)
    stds = np.asarray(stds)
    mean = means.mean(axis=0)
    second_moment = (stds ** 2 + means ** 2).mean(axis=0)
    # Clip tiny negative values from floating-point rounding before the sqrt.
    std = np.sqrt(np.maximum(second_moment - mean ** 2, 0.0))

    print(f"Channel-wise Mean (R, G, B): {mean}")
    print(f"Channel-wise Std (R, G, B): {std}")
    return mean, std

# Channel statistics
mean, std = image_channel_statistics(df)

In [None]:
def plot_category_samples(df, categories, num_images=5, figsize=(15, 12)):
    """
    Show a grid of example images: one row per category, `num_images` columns.

    Parameters:
    - df: DataFrame with 'label' and 'image_path' columns.
    - categories: labels to show, one grid row each, in this order.
    - num_images: how many images per row, taken from the top of each category.
    - figsize: size of the whole figure.

    Images OpenCV cannot load are reported and leave their cell empty.
    """
    plt.figure(figsize=figsize)
    n_rows = len(categories)

    for row, category in enumerate(categories):
        row_paths = df.loc[df['label'] == category, 'image_path'].iloc[:num_images]

        for col, img_path in enumerate(row_paths):
            bgr = cv2.imread(img_path)
            if bgr is None:
                print(f"Warning: Image not found at {img_path}")
                continue

            ax = plt.subplot(n_rows, num_images, row * num_images + col + 1)
            ax.imshow(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
            ax.axis('off')

            # Label each row once, on its first image.
            if col == 0:
                ax.set_title(f"Class: {category}", fontsize=12, fontweight='bold')

    plt.tight_layout()
    plt.show()


In [None]:
plot_category_samples(df, categories, num_images=5, figsize=(15, 12))

In [None]:
def label_distribution(df, label_column="label", figsize=(16, 6), palette="viridis", font_size=14, title_size=16):
    """
    Show the class distribution as a count plot and a pie chart side by side.

    Both charts list the classes in the same order (most frequent first) and
    use the same colour for each class.

    Parameters:
    - df: DataFrame containing `label_column`.
    - label_column: column whose values are counted.
    - figsize: size of the two-panel figure.
    - palette: seaborn palette name shared by both charts.
    - font_size: font size for axis labels, ticks, annotations and pie text.
    - title_size: font size of the two titles.
    """
    label_counts = df[label_column].value_counts()
    class_order = label_counts.index.tolist()
    # One colour per class, shared by both charts.
    colors = sns.color_palette(palette, n_colors=len(class_order))

    fig, axes = plt.subplots(1, 2, figsize=figsize)

    # Count plot. An explicit `order` matching the pie keeps every class on the
    # same colour in both charts. Without it the bars follow first-appearance
    # order while the pie follows value_counts order.
    ax1 = sns.countplot(data=df, x=label_column, order=class_order, palette=colors, ax=axes[0])
    ax1.set_title(f"Distribution of {label_column.capitalize()} - Count Plot", fontsize=title_size)
    ax1.set_xlabel(label_column.capitalize(), fontsize=font_size)
    ax1.set_ylabel("Count", fontsize=font_size)
    ax1.tick_params(axis='both', labelsize=font_size)

    # Write the count above each bar.
    for p in ax1.patches:
        ax1.annotate(f'{int(p.get_height())}',
                     (p.get_x() + p.get_width() / 2., p.get_height()),
                     ha='center', va='center', fontsize=font_size, color='black', xytext=(0, 5),
                     textcoords='offset points')

    # Pie chart
    axes[1].pie(
        label_counts,
        labels=class_order,
        autopct='%1.1f%%',
        startangle=140,
        colors=colors,
        textprops={'fontsize': font_size}
    )
    axes[1].set_title(f"Distribution of {label_column.capitalize()} - Pie Chart", fontsize=title_size)

    plt.tight_layout()
    plt.show()


label_distribution(df, font_size=14, title_size=18)

# Label Encoding

In [None]:
# Map each class name to an integer code; the fitted encoder stays available as `label_encoder`.
label_encoder = LabelEncoder()
df['label_encoded'] = label_encoder.fit_transform(df['label'].values)

In [None]:
df = df.loc[:, ['image_path', 'label_encoded']]

# Resampling

In [None]:
# Randomly duplicate rows of under-represented classes (imblearn RandomOverSampler).
# Duplicated rows point at the same image files, so this frame contains repeated
# images. Any split of it can place copies of one image in several subsets.
ros = RandomOverSampler(random_state=42)
features = df[['image_path']]
X_resampled, y_resampled = ros.fit_resample(features, df['label_encoded'])

df_resampled = pd.DataFrame(X_resampled, columns=['image_path']).assign(label_encoded=y_resampled)

In [None]:
print("\nClass distribution after oversampling:")
resampled_counts = df_resampled['label_encoded'].value_counts()
print(resampled_counts)

In [None]:
display(df_resampled)

In [None]:
df_resampled = df_resampled.astype({'label_encoded': str})

In [None]:
label_distribution(df_resampled, label_column="label_encoded", font_size=14, title_size=18)

# Train / Validation / Test Split and Image Generators

In [None]:
# Split the original (not oversampled) images first: 80% train, 10% validation,
# 10% test, stratified by class. Splitting an oversampled frame would let copies
# of the same image file land in train AND valid/test, leaking evaluation images
# into training and inflating the reported scores. Only the training subset is
# oversampled, after the split.
# flow_from_dataframe with class_mode='sparse' expects string labels, hence astype(str).
split_source = df.assign(label_encoded=df['label_encoded'].astype(str))

train_df_unbalanced, temp_df_new = train_test_split(
    split_source,
    train_size=0.8,
    shuffle=True,
    random_state=42,
    stratify=split_source['label_encoded']
)

valid_df_new, test_df_new = train_test_split(
    temp_df_new,
    test_size=0.5,
    shuffle=True,
    random_state=42,
    stratify=temp_df_new['label_encoded']
)

# Balance classes in the training subset only. Validation and test keep the real class mix.
train_ros = RandomOverSampler(random_state=42)
X_train_bal, y_train_bal = train_ros.fit_resample(
    train_df_unbalanced[['image_path']], train_df_unbalanced['label_encoded']
)
train_df_new = pd.DataFrame(X_train_bal, columns=['image_path'])
train_df_new['label_encoded'] = y_train_bal

# No image file may appear in more than one subset.
assert not set(train_df_new['image_path']) & set(valid_df_new['image_path'])
assert not set(train_df_new['image_path']) & set(test_df_new['image_path'])
assert not set(valid_df_new['image_path']) & set(test_df_new['image_path'])

In [None]:
batch_size = 32
img_size = (224, 224)
channels = 3
img_shape = (img_size[0], img_size[1], channels)

tr_gen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=15,           # Slightly increased for better variation
    width_shift_range=0.05,      # This mimics slight off-center placements that happen in real MRI scans.
    height_shift_range=0.05,     # This mimics slight off-center placements that happen in real MRI scans.
    zoom_range=0.1,
    # NOTE(review): Keras applies random transforms before `rescale`, so this
    # shift is likely in 0-255 pixel units and therefore negligible. Confirm
    # the intended magnitude.
    channel_shift_range = 0.05,
    horizontal_flip=True,
    brightness_range=[0.9, 1.1], # Simulates brightness differences in scans
    shear_range=5,               # Mild shearing to simulate spatial warping
    fill_mode='nearest'          # Fills empty pixels after transformations
)

ts_gen = ImageDataGenerator(rescale=1./255)


def _flow(generator, frame, shuffle, seed=None):
    """Create a DataFrameIterator over `frame` with the settings shared by all three subsets."""
    return generator.flow_from_dataframe(
        frame,
        x_col='image_path',
        y_col='label_encoded',
        target_size=img_size,
        class_mode='sparse',
        color_mode='rgb',
        shuffle=shuffle,
        batch_size=batch_size,
        seed=seed,
    )


# Training batches are shuffled with a fixed seed for repeatable runs.
train_gen_new = _flow(tr_gen, train_df_new, shuffle=True, seed=42)
# Validation and test keep a fixed order. Shuffling them changes no metric, and
# a fixed order keeps `.classes` aligned with prediction order.
valid_gen_new = _flow(ts_gen, valid_df_new, shuffle=False)
test_gen_new = _flow(ts_gen, test_df_new, shuffle=False)

# Common Functions

In [None]:
def get_callbacks(model_name, monitor_metric='val_loss', patience=5, save_dir='./'):
    """
    Build EarlyStopping, ModelCheckpoint and ReduceLROnPlateau callbacks for one model.

    Parameters:
    - model_name (str): used in the checkpoint file name ``best_<model_name>.keras``.
    - monitor_metric (str): quantity watched by all three callbacks (default: 'val_loss').
      Names containing 'acc' or 'auc' are treated as higher-is-better; all
      others as lower-is-better.
    - patience (int): epochs without improvement before the learning rate is
      halved. Early stopping waits twice as long, so the default of 5 means
      5 epochs for LR reduction and 10 for early stopping.
    - save_dir (str): directory where the best model is saved (default: current directory).

    Returns:
    - List of callbacks [early_stopping, model_checkpoint, reduce_lr]
    """
    model_path = os.path.join(save_dir, f"best_{model_name}.keras")

    # All callbacks must agree on the direction of improvement. A hard-coded
    # 'min' would treat a *falling* accuracy as an improvement.
    metric_name = monitor_metric.lower()
    mode = 'max' if ('acc' in metric_name or 'auc' in metric_name) else 'min'

    early_stopping = EarlyStopping(
        monitor=monitor_metric,
        mode=mode,
        patience=2 * patience,
        restore_best_weights=True
    )

    model_checkpoint = ModelCheckpoint(
        model_path,
        monitor=monitor_metric,
        mode=mode,
        save_best_only=True,
        verbose=1
    )

    reduce_lr = ReduceLROnPlateau(
        monitor=monitor_metric,
        mode=mode,
        factor=0.5,
        patience=patience,
        min_lr=1e-7,
        verbose=1
    )

    return [early_stopping, model_checkpoint, reduce_lr]


In [None]:
def plot_training_history(history, metrics=('accuracy', 'loss')):
    """
    Plot training (and, when recorded, validation) curves for each metric side by side.

    Parameters:
    - history: History object returned by model.fit()
    - metrics: iterable of metric names to plot (default: ('accuracy', 'loss')).
      A metric missing from the history gets a placeholder panel. A missing
      'val_<metric>' series (e.g. training without validation data) is not drawn.
    """
    metrics = list(metrics)
    if not metrics:
        return

    n_metrics = len(metrics)
    # squeeze=False always returns a 2-D array of axes, even for a single metric.
    fig, axes = plt.subplots(1, n_metrics, figsize=(7 * n_metrics, 5), squeeze=False)

    for ax, metric in zip(axes[0], metrics):
        if metric not in history.history:
            ax.text(0.5, 0.5, f"⚠️ Metric '{metric}' not found",
                    ha='center', va='center', fontsize=12)
            ax.set_axis_off()
            continue

        ax.plot(history.history[metric], label='Train')
        val_key = 'val_' + metric
        if val_key in history.history:
            ax.plot(history.history[val_key], label='Validation')
        ax.set_title(f'Model {metric.capitalize()}')
        ax.set_ylabel(metric.capitalize())
        ax.set_xlabel('Epoch')
        ax.legend(loc='best')
        ax.grid(True)

    plt.tight_layout()
    plt.show()


In [None]:
def evaluate_model_performance(true_labels, predicted_labels, class_names, figsize=(10, 8), cmap='Blues'):
    """
    Print a classification report and plot the confusion matrix as a heatmap.

    Labels are expected to be integer class indices 0..len(class_names)-1, with
    ``class_names[i]`` naming index ``i``. Every class is always included, so
    the report and the heatmap stay aligned with ``class_names`` even when a
    class is missing from the true or the predicted labels.

    Parameters:
    - true_labels: Ground truth labels
    - predicted_labels: Predicted labels from the model
    - class_names: List of class names, ordered by class index
    - figsize: Tuple for the size of the heatmap figure
    - cmap: Color map for heatmap
    """
    label_ids = list(range(len(class_names)))

    report = classification_report(true_labels, predicted_labels, labels=label_ids,
                                   target_names=class_names, zero_division=0)
    print("Classification Report:\n", report)

    conf_matrix = confusion_matrix(true_labels, predicted_labels, labels=label_ids)

    plt.figure(figsize=figsize)
    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap=cmap,
                xticklabels=class_names, yticklabels=class_names)
    plt.title('Confusion Matrix')
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.tight_layout()
    plt.show()

In [None]:
# Global results collector: one dict per evaluated model.
model_results = []


def record_model_results(model_name, history, test_generator, model):
    """
    Record train/val/test accuracy and weighted precision, recall and F1 (all in %).

    The entry is stored in the global `model_results` list. An earlier entry
    with the same `model_name` is replaced, so re-running a cell does not
    add duplicate rows.

    Parameters:
    - model_name (str): Name of the model (ViT/Swin/MaxViT)
    - history: History object from model.fit(). Its last-epoch 'accuracy' and
      'val_accuracy' values are reported.
    - test_generator: indexable batch generator (supports ``len()`` and
      ``[i] -> (x_batch, y_batch)``), e.g. a Keras DataFrameIterator.
    - model: Trained Keras model that outputs class probabilities.
    """
    train_acc = history.history['accuracy'][-1]
    val_acc = history.history['val_accuracy'][-1]

    # One pass over exactly len(test_generator) batches. Indexing avoids the
    # endless `for batch in generator` protocol and its shared iteration state.
    y_true, y_pred = [], []
    for batch_index in range(len(test_generator)):
        x_batch, y_batch = test_generator[batch_index]
        preds = model.predict(x_batch, verbose=0)
        y_true.extend(np.asarray(y_batch).astype(int))
        y_pred.extend(np.argmax(preds, axis=1))

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
    # Accuracy from the same predictions. This avoids a second full pass via
    # model.evaluate(), whose return value changes shape with extra metrics.
    test_acc = float(np.mean(y_true == y_pred)) if y_true.size else 0.0

    entry = {
        'Model': model_name,
        'Train Acc': round(train_acc * 100, 2),
        'Val Acc': round(val_acc * 100, 2),
        'Test Acc': round(test_acc * 100, 2),
        'Precision': round(precision * 100, 2),
        'Recall': round(recall * 100, 2),
        'F1 Score': round(f1 * 100, 2)
    }
    model_results[:] = [row for row in model_results if row['Model'] != model_name]
    model_results.append(entry)


def display_results_table():
    """Display all recorded model results as a DataFrame."""
    results_df = pd.DataFrame(model_results)
    display(results_df)


# Vision Transformer

In [None]:
### 1. DropPath Layer
class DropPath(layers.Layer):
    """Per-sample stochastic depth: in training, zero a sample's whole branch with probability `drop_prob`."""

    def __init__(self, drop_prob=0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def call(self, x, training=False):
        # Identity at inference time or when dropping is disabled.
        if not training or self.drop_prob == 0.0:
            return x
        keep_prob = 1.0 - self.drop_prob
        # One random draw per sample, broadcast over every remaining axis.
        mask_shape = (tf.shape(x)[0],) + (1,) * (len(x.shape) - 1)
        # floor(keep_prob + U[0, 1)) is 1 with probability keep_prob, else 0.
        keep_mask = tf.floor(keep_prob + tf.random.uniform(mask_shape, dtype=x.dtype))
        # Rescale the kept samples so the expected value matches the input.
        return tf.math.divide(x, keep_prob) * keep_mask

### 2. Patch Embedding Layer
class PatchEmbedding(layers.Layer):
    """Cut images into non-overlapping patches and project each to `embed_dim`.

    Returns a (batch, num_patches, embed_dim) token sequence.
    """

    def __init__(self, patch_size, embed_dim):
        super().__init__()
        # Kernel == stride == patch_size: one linear projection per patch.
        self.proj = layers.Conv2D(embed_dim, patch_size, strides=patch_size, padding='valid')

    def call(self, images):
        patch_grid = self.proj(images)
        batch = tf.shape(images)[0]
        channels = patch_grid.shape[-1]
        return tf.reshape(patch_grid, (batch, -1, channels))

### 3. Multi-Head Self Attention with Tracking
class MultiHeadSelfAttention(layers.Layer):
    """Self-attention wrapper that can keep the last attention scores for inspection.

    Scores are requested and stored only when the layer runs eagerly (e.g. a
    direct `model(images)` call). Inside `fit`/`predict` the call is traced
    into a graph. Storing a tensor there would leave a symbolic graph tensor,
    not concrete values, on the layer. Skipping it also avoids computing
    scores on every training step.
    """

    def __init__(self, num_heads, embed_dim):
        super().__init__()
        # NOTE(review): key_dim=embed_dim gives every head the full embedding width
        # (ViT commonly uses embed_dim // num_heads). Kept as-is to preserve the model.
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.attn_weights = None

    def call(self, x):
        if tf.executing_eagerly():
            output, weights = self.attn(x, x, return_attention_scores=True)
            self.attn_weights = weights
            return output
        return self.attn(x, x)

### 4. Transformer Block with DropPath
class TransformerBlock(layers.Layer):
    """Pre-norm encoder block.

    x = x + DropPath(MHSA(LayerNorm(x)))
    x = x + DropPath(MLP(LayerNorm(x)))
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout_rate, drop_path_rate):
        super().__init__()
        # Attention sub-layer
        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.attn = MultiHeadSelfAttention(num_heads, embed_dim)
        self.drop_path1 = DropPath(drop_path_rate)

        # Feed-forward sub-layer: expand to mlp_dim, project back to embed_dim
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        mlp_layers = [
            layers.Dense(mlp_dim, activation='gelu'),
            layers.Dropout(dropout_rate),
            layers.Dense(embed_dim),
            layers.Dropout(dropout_rate),
        ]
        self.mlp = tf.keras.Sequential(mlp_layers)
        self.drop_path2 = DropPath(drop_path_rate)

    def call(self, x, training=False):
        attn_out = self.attn(self.norm1(x))
        x = x + self.drop_path1(attn_out, training=training)
        mlp_out = self.mlp(self.norm2(x), training=training)
        return x + self.drop_path2(mlp_out, training=training)

### 5. Vision Transformer Model
class VisionTransformer(tf.keras.Model):
    """ViT classifier.

    Pipeline: patch embedding -> prepend a learned [CLS] token -> add a learned
    position embedding -> dropout -> TransformerBlocks (drop-path rate rising
    linearly from 0 to `max_drop_path_rate`) -> LayerNorm -> softmax head
    applied to the [CLS] token.
    """

    def __init__(self, image_size, patch_size, embed_dim, num_heads, num_blocks, mlp_dim,
                 num_classes, dropout_rate=0.1, max_drop_path_rate=0.1):
        super().__init__()
        self.patch_embed = PatchEmbedding(patch_size, embed_dim)

        # image_size is unpacked as (height, width, channels).
        height, width, _ = image_size
        num_patches = (height // patch_size) * (width // patch_size)
        self.cls_token = self.add_weight(name="cls_token", shape=(1, 1, embed_dim), initializer="random_normal", trainable=True)
        self.pos_embed = self.add_weight(name="pos_embed", shape=(1, num_patches + 1, embed_dim), initializer="random_normal", trainable=True)
        self.dropout = layers.Dropout(dropout_rate)

        self.transformer_blocks = [
            TransformerBlock(embed_dim, num_heads, mlp_dim, dropout_rate, rate)
            for rate in np.linspace(0.0, max_drop_path_rate, num_blocks)
        ]

        self.norm = layers.LayerNormalization(epsilon=1e-6)
        self.head = layers.Dense(num_classes, activation='softmax')

    def call(self, images, training=False):
        tokens = self.patch_embed(images)
        cls = tf.broadcast_to(self.cls_token, [tf.shape(images)[0], 1, tokens.shape[-1]])
        tokens = tf.concat([cls, tokens], axis=1)
        tokens = tokens + self.pos_embed
        tokens = self.dropout(tokens, training=training)

        for block in self.transformer_blocks:
            tokens = block(tokens, training=training)

        tokens = self.norm(tokens)
        return self.head(tokens[:, 0])

    def get_attention_weights(self):
        """Last attention scores stored by each block's attention layer (None where nothing was stored)."""
        return [block.attn.attn_weights for block in self.transformer_blocks]

# --- ViT hyperparameters ---
image_size = (224, 224, 3)   # (height, width, channels)
patch_size = 16
embed_dim = 256
num_heads = 8
num_blocks = 6
mlp_dim = 256
num_classes = 4
dropout_rate = 0.1
max_drop_path_rate = 0.1

# Changed to a fixed learning rate to work with ReduceLROnPlateau
learning_rate = 2e-4
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

vit_model = VisionTransformer(
    image_size=image_size,
    patch_size=patch_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_blocks=num_blocks,
    mlp_dim=mlp_dim,
    num_classes=num_classes,
    dropout_rate=dropout_rate,
    max_drop_path_rate=max_drop_path_rate,
)

vit_model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
epochs = 80
callbacks = get_callbacks(model_name='vit_model')
# The batch size (32) is set on the generators in flow_from_dataframe. Keras
# documents that fit()'s `batch_size` must not be given for generator/PyDataset
# input, so it is not passed here.
vit_history = vit_model.fit(
    train_gen_new,
    epochs=epochs,
    validation_data=valid_gen_new,
    callbacks=callbacks,
)

In [None]:
plot_training_history(vit_history, metrics=['accuracy', 'loss'])

In [None]:
predictions = vit_model.predict(test_gen_new)
# Predicted class = index of the highest output probability.
predicted_classes = predictions.argmax(axis=1)
# Ground-truth labels in generator order (the test generator is built with shuffle=False).
test_labels = test_gen_new.classes

In [None]:
# The generator's class_indices map the string-encoded labels ('0', '1', ...) to
# prediction indices. Translate them back to the original class names so the
# report and confusion matrix are readable.
class_names = [
    label_encoder.classes_[int(encoded_label)]
    for encoded_label, _ in sorted(test_gen_new.class_indices.items(), key=lambda item: item[1])
]
evaluate_model_performance(test_labels, predicted_classes, class_names, figsize=(10, 8), cmap='Blues')

In [None]:
record_model_results(model_name="ViT", history=vit_history, test_generator=test_gen_new, model=vit_model)
display_results_table()

# Swin Transformer



In [None]:
class WindowPartition(layers.Layer):
    """Split a feature map into non-overlapping square windows.

    Input:  (B, H, W, C) tensor. H and W are read statically and must be
            divisible by window_size for the reshape to succeed.
    Output: (B * num_windows, window_size * window_size, C), one row per window.
    """
    def __init__(self, window_size):
        super().__init__()
        self.window_size = window_size

    def call(self, x):
        # Batch size is dynamic; spatial dims and channels are static.
        B, H, W, C = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
        # -> (B, H/ws, ws, W/ws, ws, C)
        x = tf.reshape(x, [B, H // self.window_size, self.window_size,
                           W // self.window_size, self.window_size, C])
        # Group the window-grid axes together: (B, H/ws, W/ws, ws, ws, C)
        x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
        # Windows go into the leading axis; the pixels of a window become a sequence.
        windows = tf.reshape(x, [-1, self.window_size * self.window_size, C])
        return windows

class WindowReverse(layers.Layer):
    """Inverse of WindowPartition: reassemble windows into a (B, H, W, C) feature map.

    H and W are fixed at construction. The batch size `B` is passed to call().
    Input: (B * num_windows, window_size * window_size, C) as produced by WindowPartition.
    """
    def __init__(self, window_size, H, W):
        super().__init__()
        self.window_size = window_size
        self.H = H
        self.W = W

    def call(self, windows, B):
        # -> (B, H/ws, W/ws, ws, ws, C); the channel count is inferred via -1
        x = tf.reshape(windows, [B, self.H // self.window_size, self.W // self.window_size,
                                 self.window_size, self.window_size, -1])
        # Undo the partition transpose: (B, H/ws, ws, W/ws, ws, C)
        x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
        x = tf.reshape(x, [B, self.H, self.W, -1])
        return x

class WindowAttention(layers.Layer):
    """Multi-head self-attention applied independently within each window.

    Input: (num_windows * B, N, C) with N tokens per window. C must be divisible
    by num_heads. The output has the same shape.
    NOTE(review): no relative position bias and no attention mask are used, and
    `window_size` is stored but never read in call().
    NOTE: the MaxViT section of this notebook defines another class with this
    same name; whichever definition ran last is what the name refers to.
    """
    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.num_heads = num_heads
        self.window_size = window_size
        # 1/sqrt(head_dim) scaling for dot-product attention
        self.scale = (dim // num_heads) ** -0.5
        # One projection yields queries, keys and values (3 * dim outputs).
        self.qkv = layers.Dense(dim * 3, use_bias=False)
        self.proj = layers.Dense(dim)

    def call(self, x):
        B_, N, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
        qkv = self.qkv(x)
        # (B_, N, 3, heads, head_dim) -> (3, B_, heads, N, head_dim)
        qkv = tf.reshape(qkv, (B_, N, 3, self.num_heads, C // self.num_heads))
        qkv = tf.transpose(qkv, [2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]
        # Scaled attention scores per head: (B_, heads, N, N)
        attn = tf.matmul(q, k, transpose_b=True) * self.scale
        attn = tf.nn.softmax(attn)
        x = tf.matmul(attn, v)
        # Merge heads back: (B_, heads, N, head_dim) -> (B_, N, heads, head_dim) -> (B_, N, C)
        x = tf.transpose(x, [0, 2, 1, 3])
        x = tf.reshape(x, (B_, N, C))
        return self.proj(x)

class SwinTransformerBlock(layers.Layer):
    """
    One (shifted-)window transformer block in pre-norm form:

        x = x + W-MSA(LN(x))   # attention inside window_size x window_size windows
        x = x + MLP(LN(x))

    When ``shift_size > 0`` the feature map is cyclically rolled by
    ``-shift_size`` before window partitioning and rolled back afterwards.

    The previous version had no residual around the window attention. It also
    re-applied the window-attention layer to the full H*W token sequence, i.e.
    global attention whose memory grows with (H*W)**2. Both are fixed here,
    and the set of trainable layers is unchanged.

    NOTE(review): shifted windows use no attention mask, so tokens that the
    roll wraps around can attend to each other. Add a mask if exact Swin
    behaviour is required.

    Input/output: (batch, H*W, dim) token sequences. H and W are passed to call().
    """

    def __init__(self, dim, num_heads, window_size, shift_size=0):
        super().__init__()
        self.norm1 = layers.LayerNormalization()
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm2 = layers.LayerNormalization()
        self.mlp = [
            layers.Dense(dim * 4, activation='gelu'),
            layers.Dense(dim)
        ]
        self.window_size = window_size
        self.shift_size = shift_size

    def call(self, x, H, W):
        B, C = tf.shape(x)[0], tf.shape(x)[2]

        # --- windowed self-attention with residual connection ---
        shortcut = x
        x = tf.reshape(self.norm1(x), (B, H, W, C))
        if self.shift_size > 0:
            x = tf.roll(x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2])

        x_windows = WindowPartition(self.window_size)(x)
        attn_windows = self.attn(x_windows)
        x = WindowReverse(self.window_size, H, W)(attn_windows, B)

        if self.shift_size > 0:
            x = tf.roll(x, shift=[self.shift_size, self.shift_size], axis=[1, 2])
        x = shortcut + tf.reshape(x, (B, H * W, C))

        # --- MLP with residual connection ---
        y = self.norm2(x)
        for mlp_layer in self.mlp:
            y = mlp_layer(y)
        return x + y


class PatchMerging(layers.Layer):
    """Halve the token grid in each spatial direction and project to 2 * input_dim channels.

    Each 2x2 neighbourhood is concatenated along channels (4 * C) and then
    linearly projected by `reduction`. Input: (B, H*W, C) with H and W passed to
    call(), assumed even. Output: (B, (H/2)*(W/2), 2 * input_dim).
    NOTE(review): the reference Swin applies LayerNorm before the reduction; none is applied here.
    """
    def __init__(self, input_dim):
        super().__init__()
        self.reduction = layers.Dense(input_dim * 2, use_bias=False)

    def call(self, x, H, W):
        B, L, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2]
        x = tf.reshape(x, (B, H, W, C))
        # The four pixels of every 2x2 block (even/odd rows x even/odd columns)
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = tf.concat([x0, x1, x2, x3], axis=-1)
        # Back to a token sequence with 4 * C channels
        x = tf.reshape(x, (B, -1, 4 * C))
        return self.reduction(x)


class SwinTransformer(layers.Layer):
    """
    Compact Swin-style classifier packaged as a single layer.

    Pipeline: strided-conv patch embedding -> dropout -> for each stage,
    `depths[i]` SwinTransformerBlocks (alternating regular and shifted windows),
    with PatchMerging between stages -> LayerNorm -> global average pooling ->
    softmax over `num_classes`.

    Parameters:
    - input_shape: (height, width, channels) of the input images.
    - patch_size: patch side length (conv kernel and stride).
    - embed_dim: channels after patch embedding, doubled at each PatchMerging.
    - depths: number of blocks per stage.
    - num_heads: attention heads per stage; must have the same length as `depths`.
    - window_size: attention window side. Every stage's token grid must be divisible by it.
    - num_classes: number of output classes (default 4).
    """

    def __init__(self, input_shape, patch_size=4, embed_dim=96, depths=(2, 2), num_heads=(3, 6),
                 window_size=7, num_classes=4):
        super().__init__()
        if len(depths) != len(num_heads):
            raise ValueError("depths and num_heads must have the same length")

        self.patch_embed = layers.Conv2D(embed_dim, patch_size, strides=patch_size, padding='same')
        self.pos_drop = layers.Dropout(0.1)
        self.layers = []
        self.H, self.W = input_shape[0] // patch_size, input_shape[1] // patch_size

        for i_layer, (depth, heads) in enumerate(zip(depths, num_heads)):
            dim = embed_dim * (2 ** i_layer)
            for i_block in range(depth):
                # Even blocks use regular windows, odd blocks shifted windows.
                shift_size = 0 if i_block % 2 == 0 else window_size // 2
                self.layers.append(SwinTransformerBlock(dim, heads, window_size, shift_size))
            if i_layer < len(depths) - 1:
                self.layers.append(PatchMerging(dim))

        self.norm = layers.LayerNormalization()
        self.pool = layers.GlobalAveragePooling1D()
        self.fc = layers.Dense(num_classes, activation='softmax')

    def call(self, x):
        x = self.patch_embed(x)
        H, W = x.shape[1], x.shape[2]
        # (B, H, W, C) -> (B, H*W, C) token sequence
        x = tf.reshape(x, (tf.shape(x)[0], -1, x.shape[-1]))
        x = self.pos_drop(x)

        for layer in self.layers:
            x = layer(x, H=H, W=W)
            if isinstance(layer, PatchMerging):
                # Merging halves the token grid in each direction.
                H, W = H // 2, W // 2

        x = self.norm(x)
        x = self.pool(x)
        return self.fc(x)

# Wrap the Swin layer in a Sequential model with a fixed 224x224 RGB input.
input_shape = (224, 224, 3)
swin_layers = [
    layers.Input(shape=input_shape),
    SwinTransformer(input_shape=input_shape),
]
swin_model = Sequential(swin_layers)

In [None]:
# Fixed starting learning rate. The ReduceLROnPlateau callback from
# get_callbacks halves it when the monitored metric stalls.
learning_rate = 2e-4
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
callbacks = get_callbacks(model_name='swin_model')

swin_model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

swin_history = swin_model.fit(
    train_gen_new,
    validation_data=valid_gen_new,
    epochs=80,
    callbacks=callbacks,
)

In [None]:
plot_training_history(swin_history, metrics=['accuracy', 'loss'])

In [None]:
predictions = swin_model.predict(test_gen_new)
# Predicted class = index of the highest output probability.
predicted_classes = predictions.argmax(axis=1)
# Ground-truth labels in generator order (the test generator is built with shuffle=False).
test_labels = test_gen_new.classes

In [None]:
# The generator's class_indices map the string-encoded labels ('0', '1', ...) to
# prediction indices. Translate them back to the original class names so the
# report and confusion matrix are readable.
class_names = [
    label_encoder.classes_[int(encoded_label)]
    for encoded_label, _ in sorted(test_gen_new.class_indices.items(), key=lambda item: item[1])
]
evaluate_model_performance(test_labels, predicted_classes, class_names, figsize=(10, 8), cmap='Blues')

In [None]:
record_model_results(model_name="Swin", history=swin_history, test_generator=test_gen_new, model=swin_model)
display_results_table()

# MaxViT Transformer

In [None]:
# === Hyperparameters ===
# Input. Here `img_size` is (height, width, channels), unlike the (height, width)
# value of the same name in the data-generator cell.
batch_size = 32
img_size = (224, 224, 3)

# Architecture
patch_size = 16
embed_dim = 64
num_heads = 4
window_size = 7
num_blocks = 2
mlp_dim = 128
num_classes = 4

# Regularisation and optimisation
dropout_rate = 0.1
weight_decay = 1e-4
initial_learning_rate = 1e-4
# === MBConv Block ===
class MBConv(layers.Layer):
    """Inverted-bottleneck conv block.

    1x1 expand -> 3x3 depthwise -> 1x1 project, then a residual add,
    then BatchNorm on the sum.
    """

    def __init__(self, embed_dim, expansion_factor=4):
        super().__init__()
        hidden_channels = embed_dim * expansion_factor
        self.expand = layers.Conv2D(hidden_channels, 1, padding='same', activation='gelu')
        self.depthwise = layers.DepthwiseConv2D(3, padding='same', activation='gelu')
        self.project = layers.Conv2D(embed_dim, 1, padding='same')
        self.norm = layers.BatchNormalization()

    def call(self, x):
        branch = self.project(self.depthwise(self.expand(x)))
        return self.norm(branch + x)

# === Local and Grid Attention (Simplified) ===
class WindowAttention(layers.Layer):
    """Local self-attention inside non-overlapping window_size x window_size windows.

    Input/output: (B, H, W, C) feature map. H and W are read statically and must
    be divisible by window_size. Uses Keras MultiHeadAttention with
    key_dim = embed_dim // num_heads.
    NOTE: this redefines the `WindowAttention` name that the Swin cells also
    define, with a different input rank (4-D map rather than window sequences).
    SwinTransformerBlock resolves the name when it is constructed, so building
    a Swin model after this cell would pick up this class.
    """
    def __init__(self, embed_dim, num_heads, window_size):
        super().__init__()
        self.num_heads = num_heads
        self.window_size = window_size
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads)

    def call(self, x):
        B = tf.shape(x)[0]
        H, W, C = x.shape[1], x.shape[2], x.shape[3]
        # Partition: (B, H/ws, ws, W/ws, ws, C) -> (B, H/ws, W/ws, ws, ws, C) -> (B*nW, ws*ws, C)
        x_reshaped = tf.reshape(x, (B, H // self.window_size, self.window_size, W // self.window_size, self.window_size, C))
        x_reshaped = tf.transpose(x_reshaped, [0, 1, 3, 2, 4, 5])
        x_windows = tf.reshape(x_reshaped, (-1, self.window_size * self.window_size, C))
        # Self-attention within each window (query = value = window tokens)
        attn_windows = self.attn(x_windows, x_windows)
        # Reverse the partition back to (B, H, W, C)
        attn_windows = tf.reshape(attn_windows, (B, H // self.window_size, W // self.window_size, self.window_size, self.window_size, C))
        attn_windows = tf.transpose(attn_windows, [0, 1, 3, 2, 4, 5])
        return tf.reshape(attn_windows, (B, H, W, C))

class GridAttention(layers.Layer):
    """Self-attention along the height axis, independently for every column.

    The (B, H, W, C) map is transposed to (B, W, H, C) and reshaped to
    (B*W, H, C), so the H tokens of each column attend to each other. The
    result is reshaped and transposed back to (B, H, W, C).
    NOTE(review): despite the name this is column-wise (axial) attention, not
    MaxViT's dilated grid attention. Confirm this is intended.
    """
    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim // num_heads)

    def call(self, x):
        B, H, W, C = tf.shape(x)[0], x.shape[1], x.shape[2], x.shape[3]
        # (B, H, W, C) -> (B, W, H, C) -> (B*W, H, C): one sequence per column
        x = tf.transpose(x, [0, 2, 1, 3])
        x = tf.reshape(x, (B * W, H, C))
        x = self.attn(x, x)
        # Back to (B, W, H, C) and then (B, H, W, C)
        x = tf.reshape(x, (B, W, H, C))
        return tf.transpose(x, [0, 2, 1, 3])

# === MaxViT Block ===
class MaxViTBlock(layers.Layer):
    """One simplified MaxViT stage: an MBConv followed by three residual branches.

    The branches are window attention, column-wise "grid" attention, and a
    1x1-conv feed-forward network with dropout. They are applied in that order,
    and each one is added back onto its own input. The dropout probability is
    read from the notebook-level `dropout_rate`. No normalisation wraps the
    attention/FFN branches; the only norm is the BatchNormalization inside MBConv.
    """

    def __init__(self, embed_dim, num_heads, window_size):
        super().__init__()
        # Sub-layers are created in the original order, because Keras
        # auto-generated layer names depend on creation order.
        self.mbconv = MBConv(embed_dim)
        self.window_attn = WindowAttention(embed_dim, num_heads, window_size)
        self.grid_attn = GridAttention(embed_dim, num_heads)
        hidden_dim = embed_dim * 2
        self.ffn = tf.keras.Sequential([
            layers.Conv2D(hidden_dim, 1, activation='gelu'),
            layers.Dropout(dropout_rate),
            layers.Conv2D(embed_dim, 1),
            layers.Dropout(dropout_rate),
        ])

    def call(self, x):
        x = self.mbconv(x)
        # Each branch sees the running sum produced by the previous branch.
        for residual_branch in (self.window_attn, self.grid_attn, self.ffn):
            x = x + residual_branch(x)
        return x

# === Patch Embedding Layer ===
class PatchEmbedding(layers.Layer):
    """Project non-overlapping patch_size x patch_size patches to `embed_dim` channels.

    This is implemented as a Conv2D whose kernel size equals its stride, with
    'valid' padding. Each output pixel therefore corresponds to exactly one
    input patch; any remainder rows/columns are dropped.
    """

    def __init__(self, patch_size, embed_dim):
        super().__init__()
        self.proj = layers.Conv2D(
            filters=embed_dim,
            kernel_size=patch_size,
            strides=patch_size,
            padding='valid',
        )

    def call(self, images):
        patch_tokens = self.proj(images)
        return patch_tokens

# === MaxViT Model ===
class MaxViT(tf.keras.Model):
    """Simplified MaxViT-style image classifier.

    Pipeline: patch embedding -> dropout -> `num_blocks` MaxViTBlocks ->
    LayerNorm -> global average pooling -> softmax Dense head.

    No positional embedding is added; `pos_drop` is plain dropout on the patch
    tokens, using the notebook-level `dropout_rate`.
    NOTE: `image_size` is accepted but not used by this class.
    """

    def __init__(self, image_size, patch_size, embed_dim, num_heads, num_blocks, window_size, num_classes):
        super().__init__()
        self.patch_embed = PatchEmbedding(patch_size, embed_dim)
        self.pos_drop = layers.Dropout(dropout_rate)
        stages = []
        for _ in range(num_blocks):
            stages.append(MaxViTBlock(embed_dim, num_heads, window_size))
        self.blocks = stages
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        self.pool = layers.GlobalAveragePooling2D()
        self.head = layers.Dense(num_classes, activation='softmax')

    def call(self, x):
        x = self.pos_drop(self.patch_embed(x))
        for stage in self.blocks:
            x = stage(x)
        pooled = self.pool(self.norm(x))
        return self.head(pooled)

# === Instantiate and Compile ===
# img_size, patch_size, embed_dim, num_heads, num_blocks, window_size and
# num_classes are hyper-parameter globals defined in an earlier cell.
# NOTE: MaxViT.__init__ accepts image_size but does not use it.
maxVit_model = MaxViT(
    image_size=img_size,
    patch_size=patch_size,
    embed_dim=embed_dim,
    num_heads=num_heads,
    num_blocks=num_blocks,
    window_size=window_size,
    num_classes=num_classes
)

In [None]:
# Changed to a fixed learning rate to work with ReduceLROnPlateau
learning_rate = 2e-4
# NOTE: the name `optimizer` is re-bound later to the PyTorch AdamW optimizers.
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)

# sparse_categorical_crossentropy expects integer class ids, not one-hot
# targets. This assumes the generators yield sparse labels (TODO confirm class_mode).
maxVit_model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

In [None]:
# Callbacks come from the notebook's get_callbacks helper (defined earlier).
callbacks = get_callbacks(model_name='maxvit_model')

# batch_size is deliberately not passed. train_gen_new appears to be a Keras
# image iterator (cf. test_gen_new.classes / .class_indices below) that already
# yields batches. Keras documents that batch_size must not be specified for
# generator / Sequence / dataset inputs, and any value passed is ignored there.
maxVit_history = maxVit_model.fit(
    train_gen_new,
    validation_data=valid_gen_new,
    epochs=80,
    callbacks=callbacks
)

In [None]:
# Training-history plot for the from-scratch MaxViT run (helper defined earlier).
plot_training_history(maxVit_history)

In [None]:
# NOTE(review): pairing `.classes` with predict() output is only valid if
# test_gen_new was created with shuffle=False. Confirm in the generator cell.
test_labels = test_gen_new.classes
predictions = maxVit_model.predict(test_gen_new)
# Softmax probabilities -> predicted class index per sample
predicted_classes = np.argmax(predictions, axis=1)

In [None]:
# Test-set report via the notebook's evaluate_model_performance helper.
# Class names follow the generator's class_indices order.
evaluate_model_performance(test_labels, predicted_classes, list(test_gen_new.class_indices.keys()), figsize=(10, 8), cmap='Blues')

In [None]:
# Add the from-scratch MaxViT run to the comparison table, then show it.
record_model_results("MaxViT", maxVit_history, test_gen_new, maxVit_model)
display_results_table()

# Pre-Trained Models

In [None]:
!pip install transformers

In [None]:
from transformers import ViTForImageClassification, ViTFeatureExtractor
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image



In [None]:
# Release unused cached GPU memory held by PyTorch's caching allocator.
# This does not free memory held by TensorFlow (e.g. the Keras MaxViT above).
torch.cuda.empty_cache()

In [None]:
# Device used by the PyTorch helpers below (train_model reads it as a global).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
class CustomImageDataset(Dataset):
    """PyTorch dataset that loads images listed in a DataFrame and preprocesses
    them with a Hugging Face image processor / feature extractor.

    Parameters:
    - df: DataFrame with an 'image_path' column (path to an image file) and a
      'label_encoded' column (integer class id)
    - feature_extractor: callable used as
      feature_extractor(images=<PIL image>, return_tensors="pt"); its result must
      contain 'pixel_values' with a leading batch dimension of size 1

    Each item is a dict {'pixel_values': tensor without the batch dim,
    'label': int64 tensor}.
    """

    def __init__(self, df, feature_extractor):
        # Re-index to 0..n-1 so __getitem__ can look rows up by position via .loc.
        self.df = df.reset_index(drop=True)
        self.feature_extractor = feature_extractor

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        image_path = self.df.loc[idx, 'image_path']
        label = int(self.df.loc[idx, 'label_encoded'])

        # The context manager closes the file handle deterministically, which
        # matters with many DataLoader workers. convert("RGB") forces a full load
        # into a new image and normalises grayscale/RGBA inputs to 3 channels.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        encoding = self.feature_extractor(images=image, return_tensors="pt")

        return {
            # Remove only the leading batch dimension; a bare squeeze() would also
            # drop any other size-1 dimension.
            'pixel_values': encoding['pixel_values'].squeeze(0),
            'label': torch.tensor(label, dtype=torch.long)
        }

In [None]:
def plot_training_curves(history_dict):
    """
    Draw accuracy (left) and loss (right) curves for training vs. validation.

    Parameters:
    - history_dict: dict returned by train_model, with per-epoch lists under
      "train_acc", "val_acc", "train_loss" and "val_loss"
    """
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(14, 5))

    panels = (
        (acc_ax, "train_acc", "val_acc", "Model Accuracy", "Accuracy"),
        (loss_ax, "train_loss", "val_loss", "Model Loss", "Loss"),
    )
    for ax, train_key, val_key, title, ylabel in panels:
        ax.plot(history_dict[train_key], label="Train")
        ax.plot(history_dict[val_key], label="Validation")
        ax.set_title(title)
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.legend()
        ax.grid(True)

    plt.tight_layout()
    plt.show()


In [None]:
def evaluate_model(model, test_loader, device=None):
    """
    Run inference on `test_loader` and print accuracy plus macro-averaged
    precision, recall and F1.

    Parameters:
    - model: Hugging Face image-classification model whose output has `.logits`
    - test_loader: DataLoader yielding dicts with 'pixel_values' and 'label'
    - device: torch device to run on; defaults to the notebook-level `device`

    Returns:
    - dict with keys 'accuracy', 'precision', 'recall', 'f1'
    """
    # An explicit argument wins. Otherwise keep the original behaviour of
    # reading the notebook global.
    if device is None:
        device = globals()["device"]

    model.eval()
    y_true, y_pred = [], []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating", leave=False):
            inputs = batch['pixel_values'].to(device)
            outputs = model(inputs).logits
            preds = torch.argmax(outputs, dim=1)

            # sklearn needs CPU values only, so labels skip the device round-trip.
            y_true.extend(batch['label'].cpu().numpy())
            y_pred.extend(preds.cpu().numpy())

    # zero_division=0 yields the same value as sklearn's default ("warn")
    # without emitting UndefinedMetricWarning for never-predicted classes.
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'precision': precision_score(y_true, y_pred, average='macro', zero_division=0),
        'recall': recall_score(y_true, y_pred, average='macro', zero_division=0),
        'f1': f1_score(y_true, y_pred, average='macro', zero_division=0),
    }

    print(f"\nTest Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}")
    print(f"Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}")
    return metrics


# Diagnostic Code Block

In [None]:
# 1. Check Class Balance
def check_class_distribution(y, title='Class Distribution'):
    """Bar-plot the number of samples per label in `y`, then print the counts.

    Bars appear in the order labels are first seen (Counter insertion order).
    """
    label_counts = Counter(y)

    plt.figure(figsize=(8, 4))
    sns.barplot(x=list(label_counts.keys()), y=list(label_counts.values()))
    plt.title(title)
    plt.ylabel("Count")
    plt.xlabel("Class")
    plt.grid(True)
    plt.show()

    print("Sample count per class:", dict(label_counts))

# 2. Check Train-Val Split Leakage
def check_data_leak(train_paths, val_paths):
    """Report how many image paths appear in both the training and validation lists.

    Prints the overlap count, followed by a warning when it is non-zero or a
    confirmation otherwise. Returns None.
    """
    overlap = set(train_paths) & set(val_paths)
    print(f"🔍 Common images between train and validation sets: {len(overlap)}")
    if not overlap:
        print("✅ No image overlap between train and validation sets.")
    else:
        print("⚠️ Potential data leakage! Review your split logic.")

# 3. Confusion Matrix
def plot_confusion_matrix(y_true, y_pred, class_names):
    """Draw an annotated heat-map of sklearn's confusion_matrix(y_true, y_pred).

    Rows are true classes and columns are predicted classes; `class_names`
    labels both axes.
    """
    matrix = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(6, 5))
    sns.heatmap(
        matrix,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

In [None]:
# Image paths and encoded labels used by the leakage / class-balance checks below.
train_image_paths = train_df_new['image_path'].tolist()
val_image_paths = valid_df_new['image_path'].tolist()

y_train = train_df_new['label_encoded'].tolist()
y_val = valid_df_new['label_encoded'].tolist()

In [None]:
# Sanity checks before fine-tuning: class balance per split and train/val path overlap.
check_class_distribution(y_train, "Train Class Distribution")
check_class_distribution(y_val, "Validation Class Distribution")
check_data_leak(train_image_paths, val_image_paths)

In [None]:
def run_pretrained_evaluation(model, test_loader, device, class_names, model_name):
    """
    Predict every batch of `test_loader` (showing a tqdm progress bar), then
    report the results through the notebook's evaluate_model_performance helper.

    Parameters:
    - model: Hugging Face / PyTorch classifier whose output exposes `.logits`
    - test_loader: DataLoader yielding dicts with 'pixel_values' and 'label'
    - device: device the batches are moved to
    - class_names: label names forwarded to evaluate_model_performance
    - model_name (str): used in the progress-bar text and the section header
    """
    model.eval()
    truths, guesses = [], []

    with torch.no_grad():
        progress = tqdm(test_loader, desc=f"Evaluating {model_name}")
        for batch in progress:
            pixel_values = batch['pixel_values'].to(device)
            targets = batch['label'].to(device)

            logits = model(pixel_values).logits
            top_class = torch.max(logits, dim=1).indices

            truths.extend(targets.cpu().numpy())
            guesses.extend(top_class.cpu().numpy())

    print(f"\n===== {model_name} Evaluation =====")
    evaluate_model_performance(truths, guesses, class_names)


In [None]:
def record_model_results_hf(model_name, train_acc, val_acc, test_loader, model, device):
    """
    Append one summary row for a Hugging Face / PyTorch model to the
    notebook-level `model_results` list.

    Parameters:
    - model_name (str): label shown in the results table
    - train_acc (list): per-epoch training accuracies (fractions); only the last is recorded
    - val_acc (list): per-epoch validation accuracies (fractions); only the last is recorded
    - test_loader: DataLoader yielding dicts with 'pixel_values' and 'label'
    - model: trained model whose forward output exposes `.logits`
    - device: 'cuda' or 'cpu'

    Values are stored as percentages rounded to 2 decimals. Precision, recall
    and F1 use support-weighted averaging with zero_division=0.
    """
    def as_pct(fraction):
        return round(fraction * 100, 2)

    model.eval()
    true_labels, predicted_labels = [], []

    with torch.no_grad():
        for batch in test_loader:
            pixel_values = batch['pixel_values'].to(device)
            targets = batch['label'].to(device)

            logits = model(pixel_values).logits
            top_class = torch.argmax(logits, dim=1)

            true_labels.extend(targets.cpu().numpy())
            predicted_labels.extend(top_class.cpu().numpy())

    weighted = dict(average='weighted', zero_division=0)
    precision = precision_score(true_labels, predicted_labels, **weighted)
    recall = recall_score(true_labels, predicted_labels, **weighted)
    f1 = f1_score(true_labels, predicted_labels, **weighted)

    # Fraction of test samples whose predicted class matches the true class.
    test_acc = np.mean(np.array(true_labels) == np.array(predicted_labels))

    model_results.append({
        'Model': model_name,
        'Train Acc': as_pct(train_acc[-1]),
        'Val Acc': as_pct(val_acc[-1]),
        'Test Acc': as_pct(test_acc),
        'Precision': as_pct(precision),
        'Recall': as_pct(recall),
        'F1 Score': as_pct(f1)
    })


# ViT Pretrained

In [None]:
from transformers import ViTImageProcessor

# ViTFeatureExtractor is a deprecated alias of ViTImageProcessor: it only adds a
# FutureWarning and is slated for removal, while preprocessing is identical.
feature_extractor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")

In [None]:
# Pre-trained ViT backbone with a classification head sized to our label set.
vit_model_pretrained = ViTForImageClassification.from_pretrained(
    "google/vit-base-patch16-224-in21k",
    num_labels=len(label_encoder.classes_)
)
# Reuse the shared `device` (same expression as before). The trailing semicolon
# suppresses the very long module repr that `.to()` would otherwise display.
vit_model_pretrained.to(device);


In [None]:
# Loss and optimizer for ViT fine-tuning. train_model reads `criterion` and
# `optimizer` as notebook globals, so re-run this cell before retraining ViT.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(vit_model_pretrained.parameters(), lr=1e-4)

In [None]:
def _run_hf_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over `loader` and return (accuracy, mean batch loss).

    When `optimizer` is given, the pass trains (forward, backward, step).
    Otherwise it evaluates with autograd disabled.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to model.eval()
    correct, total, loss_sum, n_batches = 0, 0, 0.0, 0

    with torch.set_grad_enabled(training):
        for batch in tqdm(loader, desc=desc):
            inputs = batch['pixel_values'].to(device)
            labels = batch['label'].to(device)

            outputs = model(inputs).logits
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            loss_sum += loss.item()
            n_batches += 1
            predicted = torch.max(outputs, 1).indices
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError(f"{desc}: data loader yielded no samples")
    return correct / total, loss_sum / n_batches


def train_model(model, train_loader, val_loader, epochs=5,
                criterion=None, optimizer=None, device=None,
                save_path="best_vit_hf_model.pth"):
    """
    Fine-tune a Hugging Face image-classification model and track per-epoch metrics.

    Parameters:
    - model: torch module whose forward output exposes `.logits`
    - train_loader / val_loader: DataLoaders yielding dicts with
      'pixel_values' and 'label' (see CustomImageDataset)
    - epochs (int): number of passes over train_loader
    - criterion: loss function; defaults to the notebook-level `criterion`
    - optimizer: torch optimizer over `model`'s parameters; defaults to the
      notebook-level `optimizer`
    - device: torch device; defaults to the notebook-level `device`
    - save_path (str): file that receives the state_dict of the epoch with the
      best validation accuracy. Pass a distinct path per model; with the default,
      successive models overwrite each other's checkpoint.

    Returns:
    - dict with per-epoch lists under 'train_acc', 'val_acc', 'train_loss',
      'val_loss'. The in-memory model keeps the LAST epoch's weights; the best
      checkpoint is only written to `save_path`.

    Raises:
    - ValueError if `optimizer` updates none of `model`'s parameters (e.g. a
      stale optimizer left over from another model's cell), or if a loader is empty.
    """
    # Explicit arguments take precedence. Otherwise fall back to the notebook
    # globals, exactly as the original implementation did.
    namespace = globals()
    criterion = namespace["criterion"] if criterion is None else criterion
    optimizer = namespace["optimizer"] if optimizer is None else optimizer
    device = namespace["device"] if device is None else device

    # Guard against the silent failure where the forward pass uses `model`
    # while optimizer.step() updates some other model's parameters.
    model_param_ids = {id(p) for p in model.parameters()}
    if not any(id(p) in model_param_ids
               for group in optimizer.param_groups for p in group["params"]):
        raise ValueError(
            "optimizer does not update any parameter of `model`; "
            "re-create it from model.parameters() before training")

    history = {"train_acc": [], "val_acc": [], "train_loss": [], "val_loss": []}
    # -inf guarantees the first epoch is always checkpointed.
    best_val_accuracy = float("-inf")

    for epoch in range(epochs):
        epoch_train_acc, epoch_train_loss = _run_hf_epoch(
            model, train_loader, criterion, device,
            desc=f"Epoch {epoch+1}/{epochs} [Training]", optimizer=optimizer)
        epoch_val_acc, epoch_val_loss = _run_hf_epoch(
            model, val_loader, criterion, device, desc="Validating")

        history["train_acc"].append(epoch_train_acc)
        history["val_acc"].append(epoch_val_acc)
        history["train_loss"].append(epoch_train_loss)
        history["val_loss"].append(epoch_val_loss)

        print(f"Epoch {epoch+1}: "
              f"Train Acc = {epoch_train_acc:.4f}, Val Acc = {epoch_val_acc:.4f}, "
              f"Train Loss = {epoch_train_loss:.4f}, Val Loss = {epoch_val_loss:.4f}")

        # Save best model
        if epoch_val_acc > best_val_accuracy:
            best_val_accuracy = epoch_val_acc
            torch.save(model.state_dict(), save_path)

    return history

In [None]:
# Datasets/loaders preprocessed with the ViT feature_extractor. Only the
# training loader shuffles; validation and test keep a fixed order.
train_dataset_hf = CustomImageDataset(train_df_new, feature_extractor)
val_dataset_hf = CustomImageDataset(valid_df_new, feature_extractor)
test_dataset_hf = CustomImageDataset(test_df_new, feature_extractor)

train_loader = DataLoader(train_dataset_hf, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset_hf, batch_size=16)
test_loader = DataLoader(test_dataset_hf, batch_size=16)

In [None]:
# Fine-tune ViT for 6 epochs. The best-val-accuracy weights go to best_vit_hf_model.pth.
hist_vit = train_model(vit_model_pretrained, train_loader, val_loader, epochs=6)

In [None]:
# Accuracy / loss curves for the ViT fine-tuning run.
plot_training_curves(hist_vit)

In [None]:
# Evaluate fine-tuned ViT on the test split. test_loader uses the ViT
# feature_extractor, and `device` is the one defined in the setup cell.
class_names = label_encoder.classes_

run_pretrained_evaluation(vit_model_pretrained, test_loader, device, class_names, "ViT (Pretrained)")

In [None]:
# Records the LAST epoch's train/val accuracy. The in-memory model is also the
# last epoch's; the best checkpoint saved by train_model is not reloaded here.
record_model_results_hf(
    "ViT (Pretrained)",
    hist_vit["train_acc"],
    hist_vit["val_acc"],
    test_loader,
    vit_model_pretrained,
    device
)
display_results_table()

# Swin Pretrained

In [None]:
from transformers import AutoImageProcessor, AutoModelForImageClassification

In [None]:
# Load Swin-Tiny preprocessor (resize/normalisation matching this checkpoint)
swin_model_name = "microsoft/swin-tiny-patch4-window7-224"
swin_processor = AutoImageProcessor.from_pretrained(swin_model_name)

# ignore_mismatched_sizes=True loads every pre-trained backbone weight and
# re-initialises only the classifier head, whose shape differs for our label
# count. No separate state_dict filtering or second load is needed. Load errors
# propagate here instead of being printed and resurfacing later as a NameError.
swin_model = AutoModelForImageClassification.from_pretrained(
    swin_model_name,
    num_labels=len(label_encoder.classes_),
    ignore_mismatched_sizes=True,
).to(device)
print("Pre-trained Swin model loaded successfully, excluding classifier weights.")


# Define optimizer and loss function. train_model reads these notebook globals,
# so this cell must run right before training Swin.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(swin_model.parameters(), lr=1e-4)

# Create datasets and data loaders preprocessed with the Swin processor
swin_train_dataset_hf = CustomImageDataset(train_df_new, swin_processor)
swin_val_dataset_hf = CustomImageDataset(valid_df_new, swin_processor)
swin_test_dataset_hf = CustomImageDataset(test_df_new, swin_processor)

swin_train_loader = DataLoader(swin_train_dataset_hf, batch_size=16, shuffle=True)
swin_val_loader = DataLoader(swin_val_dataset_hf, batch_size=16)
swin_test_loader = DataLoader(swin_test_dataset_hf, batch_size=16)

In [None]:
# Train Swin for 5 epochs.
# NOTE: train_model's default checkpoint path is best_vit_hf_model.pth, so this
# run overwrites the ViT checkpoint on disk.
hist_swin = train_model(swin_model, swin_train_loader, swin_val_loader, epochs=5)

In [None]:
# Accuracy / loss curves for the Swin fine-tuning run.
plot_training_curves(hist_swin)

In [None]:
# Evaluate Swin on swin_test_loader, which applies swin_processor.
# test_loader was built with the ViT feature_extractor, so its preprocessing is
# not necessarily what this Swin checkpoint expects.
class_names = label_encoder.classes_
run_pretrained_evaluation(swin_model, swin_test_loader, device, class_names, "Swin (Pretrained)")

In [None]:
# Record Swin results on the test loader built with its own processor
# (swin_test_loader), not the ViT-preprocessed test_loader.
record_model_results_hf(
    "Swin (Pretrained)",
    hist_swin["train_acc"],
    hist_swin["val_acc"],
    swin_test_loader,
    swin_model,
    device
)
display_results_table()

# MaxViT Pretrained

In [None]:
from transformers import AutoImageProcessor, AutoModelForImageClassification, AutoConfig

In [None]:
# timm checkpoint loaded through transformers' Auto classes.
# NOTE(review): this needs a transformers release that can wrap timm
# checkpoints. Confirm the processor actually honours the resize/normalize
# overrides passed below.
maxvit_model_name = "timm/maxvit_tiny_rw_224.sw_in1k"

maxvit_processor = AutoImageProcessor.from_pretrained(
    maxvit_model_name,
    do_resize=True,
    size={"shortest_edge": 224},
    do_normalize=True
)

# ignore_mismatched_sizes=True lets the classifier head be re-sized to our label count.
maxvit_model = AutoModelForImageClassification.from_pretrained(
    maxvit_model_name,
    num_labels=len(label_encoder.classes_),
    ignore_mismatched_sizes=True
).to(device)

In [None]:
# Loss/optimizer for MaxViT. train_model reads these notebook globals, so this
# cell must run right before the MaxViT training cell.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(maxvit_model.parameters(), lr=1e-4)

In [None]:
# Datasets/loaders preprocessed with the MaxViT processor. batch_size=4 (vs 16
# for ViT/Swin) is presumably to fit GPU memory; TODO confirm.
maxVit_train_dataset_hf = CustomImageDataset(train_df_new, maxvit_processor)
maxVit_val_dataset_hf = CustomImageDataset(valid_df_new, maxvit_processor)
maxVit_test_dataset_hf = CustomImageDataset(test_df_new, maxvit_processor)

maxVit_train_loader = DataLoader(maxVit_train_dataset_hf, batch_size=4, shuffle=True)
maxVit_val_loader = DataLoader(maxVit_val_dataset_hf, batch_size=4)
maxVit_test_loader = DataLoader(maxVit_test_dataset_hf, batch_size=4)

In [None]:
# Free PyTorch's cached, unused GPU memory before the memory-heavy MaxViT run.
torch.cuda.empty_cache()

In [None]:
# Train
# NOTE: with train_model's default save path this overwrites best_vit_hf_model.pth.
hist_maxvit = train_model(maxvit_model, maxVit_train_loader, maxVit_val_loader, epochs=5)

In [None]:
# Accuracy / loss curves for the MaxViT fine-tuning run.
plot_training_curves(hist_maxvit)

In [None]:
# Evaluate MaxViT on maxVit_test_loader, which applies maxvit_processor.
# test_loader was built with the ViT feature_extractor, so its preprocessing
# does not necessarily match this checkpoint. It also uses batch size 16 instead
# of the memory-friendly 4.
class_names = label_encoder.classes_
run_pretrained_evaluation(maxvit_model, maxVit_test_loader, device, class_names, "MaxViT (Pretrained)")

In [None]:
# Record MaxViT results on the test loader built with its own processor
# (maxVit_test_loader), not the ViT-preprocessed test_loader.
record_model_results_hf(
    "MaxVit (Pretrained)",
    hist_maxvit["train_acc"],
    hist_maxvit["val_acc"],
    maxVit_test_loader,
    maxvit_model,
    device
)
display_results_table()