In [1]:
import torch
import torchvision
import torchvision.transforms as transforms # input for augmentation, resizing etc
from torch.utils.data import Subset
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, random_split
from sklearn.model_selection import train_test_split
from PIL import Image # For checking GIF handling if needed

import os
import numpy as np


# Global configuration shared by every cell below.
image_size = (224, 224)  # input resolution fed to the network; change depending on the model
batch_size = 64  # adjust to fit GPU memory

# Root folder of the image dataset (one sub-folder per class).
# The MEME_DATASET_DIR environment variable overrides the location so the
# notebook is not tied to one machine; the original local path is the fallback.
data_path = os.environ.get(
    "MEME_DATASET_DIR", r"C:\Users\tsili\Documents\Meme_cleaner\dataset"
)

# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader needs an integer: fall back to 0 (load data in the main process).
num_workers = os.cpu_count() or 0

# Calculate the per-channel mean and std of the dataset for normalisation.
#
# NOTE(review): the statistics are computed over every image under data_path,
# including images that later end up in the validation/test splits. Confirm
# this small leak is acceptable, or restrict the pass to the training indices.

# Temporary dataset with minimal transforms (NO normalisation), so the
# statistics are measured on the raw ToTensor() output.
stats_transforms = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor()
])

# Load the entire dataset with these minimal transforms
stats_dataset = ImageFolder(root=data_path, transform=stats_transforms)

# Loader used only for this statistics pass; order does not matter, so no shuffling.
stats_loader = DataLoader(
    stats_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers
)

print("Starting Pass 1: Calculating dataset statistics...")

# Accumulators: per-channel sum of pixel values, per-channel sum of squared
# pixel values, number of batches seen, and number of pixels seen per channel.
channels_sum, channels_squared_sum, num_batches = 0, 0, 0
num_pixels = 0

for data, _ in stats_loader:
    # data shape: [B, C, H, W]
    # Accumulate in float64 so summing millions of values does not lose precision.
    data = data.double()
    # Reduce over every dimension except the channel dimension (C).
    channels_sum += data.sum(dim=[0, 2, 3])
    channels_squared_sum += (data ** 2).sum(dim=[0, 2, 3])
    # Count pixels instead of averaging per-batch means: the last batch is
    # usually smaller, and a mean of batch means would over-weight its images.
    num_pixels += data.size(0) * data.size(2) * data.size(3)
    num_batches += 1

# Final statistics: mean = E[x], std = sqrt(E[x^2] - E[x]^2)
mean_f64 = channels_sum / num_pixels
# Clamp at zero: rounding can push the variance of a near-constant channel
# slightly negative, which would turn the square root into NaN.
var_f64 = (channels_squared_sum / num_pixels - mean_f64 ** 2).clamp(min=0)
mean = mean_f64.float()
std = var_f64.sqrt().float()

print("Pass 1 Complete.")
print(f"Calculated Mean: {mean}")
print(f"Calculated Std:  {std}\n")


# Preprocessing pipelines.
# Both pipelines normalise with the statistics measured on this dataset.
# The usual alternative is the ImageNet statistics:
#   mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
_dataset_mean = mean.tolist()
_dataset_std = std.tolist()

# Random augmentations, applied to training images only.
_augmentations = [
    transforms.RandomHorizontalFlip(p=0.5),  # 0.5 is the default; stated explicitly
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    # degrees=0: this step only translates and rescales, rotation is handled above
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
]

# Training: resize -> augment -> tensor -> normalise
train_transforms = transforms.Compose(
    [transforms.Resize(image_size)]
    + _augmentations
    + [
        transforms.ToTensor(),
        transforms.Normalize(mean=_dataset_mean, std=_dataset_std),
    ]
)

# Validation / test: same resize and normalisation, no augmentation
val_test_transforms = transforms.Compose(
    [
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=_dataset_mean, std=_dataset_std),
    ]
)

# Open the image folder without any transform, only to inspect its classes.
full_dataset = ImageFolder(data_path)

# Show the class names and the integer label assigned to each of them.
print(
    f"Classes: {full_dataset.classes}",
    f"Class to index mapping: {full_dataset.class_to_idx}",
    sep="\n",
)

# Expected to be 2 for this binary meme / no-meme task.
num_classes = len(full_dataset.classes)



Starting Pass 1: Calculating dataset statistics...
Pass 1 Complete.
Calculated Mean: tensor([0.4715, 0.4526, 0.4411])
Calculated Std:  tensor([0.3132, 0.3114, 0.3194])

Classes: ['meme', 'no_meme']
Class to index mapping: {'meme': 0, 'no_meme': 1}


In [2]:
# Two datasets over the same folder: one with the augmenting training
# pipeline, one with the plain validation/test pipeline. The split below picks
# indices once and applies them to whichever dataset has the right transforms.
train_dataset = ImageFolder(root=data_path, transform=train_transforms)
val_test_dataset = ImageFolder(root=data_path, transform=val_test_transforms)

# ImageFolder.targets holds one label per image; it drives the stratification.
targets = train_dataset.targets
indices = list(range(len(targets)))

# Step 1: 80 % training / 20 % held out, keeping class proportions.
train_indices, temp_indices = train_test_split(
    indices,
    test_size=0.2,
    random_state=42,
    stratify=targets,
)

# Step 2: halve the held-out part into validation and test (10 % each),
# again stratified on the labels of the held-out images.
temp_targets = [targets[idx] for idx in temp_indices]
val_indices, test_indicies = train_test_split(
    temp_indices,
    test_size=0.5,
    random_state=42,
    stratify=temp_targets,
)

# Pair each index list with the dataset that carries the matching transforms.
train_data = Subset(train_dataset, train_indices)
val_data = Subset(val_test_dataset, val_indices)
test_data = Subset(val_test_dataset, test_indicies)

# --- Verification ---
for split_name, split_subset in (
    ("training", train_data),
    ("validation", val_data),
    ("test", test_data),
):
    print(f"Number of {split_name} samples: {len(split_subset)}")



Number of training samples: 10828
Number of validation samples: 1354
Number of test samples: 1354


In [None]:
import torch
import torch.nn as nn
import torchvision.models as models
from torch.utils.data import DataLoader 
import pytorch_lightning as pl 
from pytorch_lightning.callbacks import ModelCheckpoint,EarlyStopping
import os 
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd # Often useful for this

# Data loaders: identical batch size and worker count for all three splits;
# only the training loader reshuffles its samples every epoch.
_loader_options = dict(batch_size=128, num_workers=num_workers)

train_loader = DataLoader(train_data, shuffle=True, **_loader_options)
val_loader = DataLoader(val_data, shuffle=False, **_loader_options)
test_loader = DataLoader(test_data, shuffle=False, **_loader_options)

class MemeClassifier(pl.LightningModule):
    """Pre-trained MobileNetV3-Small with frozen weights and a new final linear layer.

    Args:
        learning_rate: Learning rate passed to the Adam optimizer.
        num_classes: Number of outputs of the replaced final linear layer.
        class_names: Optional display names, one per class index, used for the
            confusion matrix and classification report produced after testing.
            Defaults to ``['meme', 'no meme']`` for two classes and to
            ``class_<index>`` otherwise.

    Raises:
        ValueError: If ``class_names`` is given and its length differs from
            ``num_classes``.
    """

    def __init__(self, learning_rate=0.001, num_classes=2, class_names=None):
        super().__init__()
        self.save_hyperparameters()

        if class_names is None:
            if num_classes == 2:
                # IMPORTANT: this order must match the dataset's class indices
                class_names = ['meme', 'no meme']
            else:
                class_names = [f'class_{i}' for i in range(num_classes)]
        elif len(class_names) != num_classes:
            raise ValueError(
                f"class_names has {len(class_names)} entries but num_classes is {num_classes}"
            )
        self.class_names = list(class_names)

        self.model = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.DEFAULT)

        print(self.model)

        # Freeze every pre-trained parameter.
        # NOTE(review): freezing only disables gradients; BatchNorm layers still
        # update their running statistics while the module is in train mode.
        # Confirm that is intended for this feature-extraction setup.
        for param in self.model.parameters():
            param.requires_grad = False

        # Replace the last layer of the classifier head. The new nn.Linear is
        # created after the freeze loop, so it is the only trainable layer.
        num_ftrs = self.model.classifier[3].in_features
        self.model.classifier[3] = nn.Linear(in_features=num_ftrs, out_features=self.hparams.num_classes)

        self.loss_fn = nn.CrossEntropyLoss()

        # Per-batch predictions and labels collected during a test run
        self.test_preds = []
        self.test_labels = []

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its raw outputs (logits)."""
        return self.model(x)

    def _calculate_metrics(self, batch):
        """Shared step logic: return ``(loss, accuracy, predictions, labels)`` for one batch."""
        inputs, labels = batch
        outputs = self.forward(inputs)
        loss = self.loss_fn(outputs, labels)

        # Predicted class = index of the largest logit
        preds = torch.argmax(outputs, dim=1)
        acc = (preds == labels).float().mean()

        return loss, acc, preds, labels

    def training_step(self, batch, batch_idx):
        """Compute and log loss/accuracy for one training batch; return the loss."""
        loss, acc, _, _ = self._calculate_metrics(batch)
        # on_step=False, on_epoch=True: only the per-epoch aggregate is logged
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', acc, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log loss/accuracy for one validation batch."""
        loss, acc, _, _ = self._calculate_metrics(batch)
        self.log('val_loss', loss)
        self.log('val_acc', acc)

    def test_step(self, batch, batch_idx):
        """Log loss/accuracy for one test batch and keep its predictions and labels."""
        loss, acc, preds, labels = self._calculate_metrics(batch)
        self.log('test_loss', loss)
        self.log('test_acc', acc)
        # Store on the CPU so the collected results do not occupy accelerator memory
        self.test_preds.append(preds.detach().cpu())
        self.test_labels.append(labels.detach().cpu())

    def on_test_epoch_end(self):
        """Plot the confusion matrix and print the classification report for the test run."""
        # Nothing was collected (e.g. empty test loader): torch.cat would fail on an empty list
        if not self.test_preds:
            return

        # Concatenate the predictions and labels collected from each batch
        all_preds = torch.cat(self.test_preds).cpu().numpy()
        all_labels = torch.cat(self.test_labels).cpu().numpy()

        class_names = self.class_names
        # Passing the full label list keeps the matrix and the report aligned
        # with class_names even if a class never occurs in this test run.
        label_ids = list(range(len(class_names)))

        # --- 1. Plot the confusion matrix ---
        cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
        plt.figure(figsize=(8, 6))
        sns.heatmap(
            cm,
            annot=True,
            fmt='d',
            cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names
        )
        plt.xlabel('Predicted Label', fontsize=12)
        plt.ylabel('True Label', fontsize=12)
        plt.title('Confusion Matrix', fontsize=16)
        plt.show()  # Display the plot

        print("\n" + "="*50)  # Separator for clarity

        # --- 2. Calculate and print the classification report ---
        print("Classification Report:")
        report = classification_report(
            all_labels,
            all_preds,
            labels=label_ids,
            target_names=class_names
        )
        print(report)

        print("="*50)

        # Clear the lists for the next potential test run
        self.test_preds.clear()
        self.test_labels.clear()

    def configure_optimizers(self):
        """Return an Adam optimizer using the ``learning_rate`` hyperparameter."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return optimizer


# Seed Python, NumPy and PyTorch (and the DataLoader workers) so that weight
# initialisation, shuffling and augmentation are repeatable between runs.
pl.seed_everything(42, workers=True)

# Initialise the model
model_pl = MemeClassifier(learning_rate=0.001, num_classes=2)

# Callbacks
checkpoint_callback = ModelCheckpoint(
    monitor='val_acc',
    dirpath='my_checkpoints/',  # change this
    filename='meme-classifier-{epoch:02d}-{val_acc:.2f}',
    save_top_k=1,                # Save only the best 1 model
    mode='max'                   # Higher accuracy is better
)

early_stopping_callback = EarlyStopping(
    monitor='val_acc',  # The metric to monitor
    patience=5,         # How many epochs to wait for improvement before stopping
    verbose=True,       # Print a message when stopping
    # The monitored metric is an accuracy, so improvement means it goes UP.
    # With 'min', a rising accuracy counted as "no improvement" and training
    # would be stopped while the model was still getting better.
    mode='max'
)

trainer = pl.Trainer(
    accelerator="auto",          # Automatically uses GPU or MPS if available
    precision="16-mixed",
    max_epochs=100,              # Upper bound; early stopping usually ends training sooner
    callbacks=[checkpoint_callback, early_stopping_callback],  # checkpointing + early stopping
    # NOTE(review): the plotting cell reads lightning_logs/<run>/metrics.csv.
    # Confirm that logger=True produces a CSV log in this environment; if it
    # selects a TensorBoard logger instead, pass a CSV logger explicitly.
    logger=True
)

# --- Start Training! ---
# This single call replaces a manual training loop.
print("Starting training with PyTorch Lightning...")
trainer.fit(model=model_pl, train_dataloaders=train_loader, val_dataloaders=val_loader)

# --- Run the Final Test ---
# ckpt_path='best' loads the best checkpoint tracked by checkpoint_callback before testing.
print("\nTraining finished. Running on test set with the best model...")
trainer.test(model=model_pl, dataloaders=test_loader, ckpt_path='best')




Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


MobileNetV3(
  (features): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
      (2): Hardswish()
    )
    (1): InvertedResidual(
      (block): Sequential(
        (0): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=16, bias=False)
          (1): BatchNorm2d(16, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
          (2): ReLU(inplace=True)
        )
        (1): SqueezeExcitation(
          (avgpool): AdaptiveAvgPool2d(output_size=1)
          (fc1): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))
          (fc2): Conv2d(8, 16, kernel_size=(1, 1), stride=(1, 1))
          (activation): ReLU()
          (scale_activation): Hardsigmoid()
        )
        (2): Conv2dNormActivation(
          (0): Conv2d(16, 16, kernel_size=(1, 1), 

MisconfigurationException: No `configure_optimizers()` method defined. Lightning `Trainer` expects as minimum a `training_step()`, `train_dataloader()` and `configure_optimizers()` to be defined.

In [None]:
# plot_results.py

import pandas as pd
import matplotlib.pyplot as plt
import os

# --- Configuration ---
# Path to the lightning_logs directory
log_dir = "lightning_logs"
# ---

# --- Find the latest run and the metrics.csv file ---
# Failures raise an exception instead of calling exit(): inside a notebook,
# exit() shuts the kernel down rather than just stopping this cell.
try:
    # Every sub-directory of lightning_logs is one run (the 'version_X' folders)
    run_dirs = [
        os.path.join(log_dir, entry)
        for entry in os.listdir(log_dir)
        if os.path.isdir(os.path.join(log_dir, entry))
    ]
    # Most recently modified run directory; max() raises ValueError if there is none
    latest_run_dir = max(run_dirs, key=os.path.getmtime)
except (ValueError, FileNotFoundError) as err:
    raise FileNotFoundError(
        f"Could not find 'metrics.csv' in the '{log_dir}' directory. "
        "Please ensure your training has run and created the log files."
    ) from err

metrics_file_path = os.path.join(latest_run_dir, "metrics.csv")

if not os.path.exists(metrics_file_path):
    raise FileNotFoundError(
        f"The directory '{latest_run_dir}' does not contain a 'metrics.csv' file."
    )

print(f"Reading logs from: {metrics_file_path}")

# --- Read the data using pandas ---
df = pd.read_csv(metrics_file_path)

# --- DEBUGGING STEP: Inspect the data ---
# Shows the exact column names present in the log
print("\n--- First 5 rows of the metrics file ---")
print(df.head())
print("\n--- Available columns in the CSV ---")
print(df.columns.tolist())
print("-------------------------------------------\n")


# --- Prepare the Data for Plotting ---
# A metric column is NaN on every row where that metric was not logged, and
# the training and validation values of one epoch may sit on different rows.
# Collapse the log to one row per epoch (first non-NaN value of each column),
# then keep only the epochs for which validation actually ran.
if 'epoch' in df.columns and 'val_loss' in df.columns:
    epoch_data = df.groupby('epoch', as_index=False).first()
    epoch_data = epoch_data[epoch_data['val_loss'].notna()].reset_index(drop=True)
else:
    epoch_data = df.iloc[0:0]

if epoch_data.empty:
    # Raise instead of exit(): exit() would shut down the notebook kernel.
    raise RuntimeError(
        "No completed validation epochs found in the log file. "
        "Please make sure you have run at least one full training epoch."
    )


def _pick_column(frame, *candidates):
    """Return the first name in ``candidates`` that is a column of ``frame``, else None."""
    return next((name for name in candidates if name in frame.columns), None)


def _plot_metric(ax, frame, train_col, val_col, metric_name):
    """Draw training and validation curves of one metric on ``ax``.

    If either column is missing (None), only an explanatory title is set.
    """
    if train_col is None or val_col is None:
        ax.set_title(f"{metric_name} data not found.\nCheck CSV columns and logging.", fontsize=14)
        return
    ax.plot(frame['epoch'], frame[train_col], label=f'Training {metric_name}', marker='o', linestyle='--')
    ax.plot(frame['epoch'], frame[val_col], label=f'Validation {metric_name}', marker='o')
    ax.set_title(f'Model {metric_name} vs. Epochs', fontsize=14)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric_name)
    ax.legend()


# --- Create the Plots ---
plt.style.use('seaborn-v0_8-whitegrid')
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
fig.suptitle(f'Performance for Run: {os.path.basename(latest_run_dir)}', fontsize=18)

# Metrics logged with on_step=False, on_epoch=True keep their plain name
# ('train_acc'); the '_epoch' suffix only appears when a metric is logged both
# per step and per epoch. Accept either spelling.

# Plot 1: Accuracy vs. Epochs
_plot_metric(
    ax1,
    epoch_data,
    _pick_column(epoch_data, 'train_acc_epoch', 'train_acc'),
    _pick_column(epoch_data, 'val_acc', 'val_acc_epoch'),
    'Accuracy',
)

# Plot 2: Loss vs. Epochs
_plot_metric(
    ax2,
    epoch_data,
    _pick_column(epoch_data, 'train_loss_epoch', 'train_loss'),
    _pick_column(epoch_data, 'val_loss', 'val_loss_epoch'),
    'Loss',
)

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.show()