lponik/Mila-Classifier

Binary image classifier to identify my cat using transfer learning
Mila Classifier 🐱

A binary image classifier that distinguishes my cat (Mila) from other cats, built using transfer learning with PyTorch.


Project Overview

This is a personal project I built to learn about transfer learning and image classification. The goal is simple: given a photo of a cat, can the model tell if it's my cat Mila or some other cat?

I wanted to see how well modern deep learning techniques work when you don't have a massive dataset — just the photos on my phone.

Key Results

  • Validation Accuracy: ~90%+ (varies by run)
  • Training Time: ~1-2 minutes on Apple M-series GPU
  • Model Size: ~16 MB (EfficientNet-B0)

Dataset

My Cat (Positive Class)

  • Source: Personal photos of Mila
  • Count: ~54 images
  • Format: Mix of JPG, HEIC converted to JPG
  • Characteristics: Various poses, lighting conditions, backgrounds, and distances

Not My Cat (Negative Class)

  • Source: Oxford-IIIT Pet Dataset
  • Count: ~2,400 images (multiple cat breeds)
  • Why Oxford Pets?: It's a well-known academic dataset with high-quality, consistently labeled cat images — perfect for representing "other cats" the model should learn to reject

Data Organization

data/
├── my_cat/          # Images of Mila
│   ├── IMG_0056.JPG
│   ├── IMG_0068.JPG
│   └── ...
└── not_my_cat/      # Oxford Pets cat images (12 breeds)
    ├── Abyssinian_1.jpg
    ├── Bengal_1.jpg
    ├── Siamese_1.jpg
    └── ...

Class Imbalance

The dataset is pretty imbalanced (~1:45 ratio of my_cat to not_my_cat). This is actually realistic — I only have so many photos of Mila, but there are tons of cat images out there. To handle this, I used class-weighted loss during training so the model doesn't just learn to always guess "not my cat."


Technical Approach

Transfer Learning with EfficientNet-B0

Training a neural network from scratch would require way more images than I have. Instead, I used transfer learning — basically taking a model that's already learned to recognize images and teaching it my specific task:

  1. Start with a pretrained model: EfficientNet-B0, which was trained on ImageNet (1.2M images, 1000 classes)
  2. Freeze the feature extractor: The convolutional layers already know how to detect useful things like edges, textures, and shapes — no need to retrain those
  3. Replace the classifier head: Swap out the 1000-class output layer for a simple 2-class layer (my_cat vs not_my_cat)
  4. Train only the new layer: This means I'm only training ~2,500 parameters instead of all ~5.3 million

Why EfficientNet-B0?

I looked at a few different pretrained models before settling on EfficientNet-B0:

| Model | Parameters | ImageNet Acc | My Take |
|-------|------------|--------------|---------|
| ResNet-18 | 11.7M | 69.8% | Solid choice, but a bit dated |
| EfficientNet-B0 | 5.3M | 77.1% | ✅ Great accuracy for its size |
| EfficientNet-B7 | 66M | 84.3% | Way overkill for this project |
| ViT-Base | 86M | 81.8% | Needs way more data to work well |

EfficientNet-B0 hit the sweet spot — good accuracy without being unnecessarily large.

Training Configuration

Here are the hyperparameters I used and why:

| Hyperparameter | Value | Why I Chose This |
|----------------|-------|------------------|
| Optimizer | Adam | It's reliable and adapts learning rates automatically |
| Learning Rate | 1e-3 | Standard starting point for transfer learning |
| Batch Size | 16 | Fits comfortably in memory on my MacBook |
| Epochs | 5 | The model converges pretty quickly since I'm only training the classifier |
| Image Size | 224×224 | What EfficientNet expects as input |
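Put together, a training loop with these settings might look like the sketch below. The model and the data batches are stand-ins (a tiny linear model and random tensors) so the snippet runs on its own; the real notebook would use the frozen EfficientNet and a DataLoader instead:

```python
import torch
import torch.nn as nn

# Stand-in for the frozen EfficientNet + 2-class head
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 2))
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
model = model.to(device)

# Class-weighted loss: total_samples / count[class_i], with ~54 vs ~2,400 images
class_weights = torch.tensor([2454 / 54, 2454 / 2400], device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

model.train()
for epoch in range(5):
    for _ in range(2):  # stand-in for `for images, labels in train_loader:`
        images = torch.randn(16, 3, 224, 224, device=device)
        labels = torch.randint(0, 2, (16,), device=device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
```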

Design Decisions

These are some of the choices I made while building this project and why:

1. Why Freeze the Feature Extractor?

Even with ~2,400 negative examples, I only have ~54 images of Mila. If I tried to fine-tune all 5.3M parameters, the model would probably just memorize my training images instead of learning what actually makes Mila look like Mila. Freezing the backbone is a simple way to prevent this overfitting.

2. Why Class-Weighted Loss?

My dataset is heavily skewed — there are way more "not my cat" images than "my cat" images. Without any correction, the model could hit ~98% accuracy by literally always guessing "not my cat." Class weights penalize mistakes on the minority class more heavily:

weight[class_i] = total_samples / count[class_i]

This forces the model to actually learn both classes.
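That formula is only a couple of lines of PyTorch; with this project's approximate class counts, it penalizes a mistake on a Mila image about 45 times more than a mistake on an "other cat" image:

```python
import torch

counts = torch.tensor([54.0, 2400.0])  # [my_cat, not_my_cat] image counts
weights = counts.sum() / counts        # total_samples / count[class_i]
# weights ≈ [45.44, 1.02]: Mila mistakes cost ~45x more
```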

3. Why Not Data Augmentation?

Data augmentation (random flips, rotations, color changes) usually helps when you have limited data. I decided to skip it here to keep things simple and see how far transfer learning alone could get me. Adding augmentation would be an easy next step if I wanted to improve the model.

4. Why Threshold Tuning?

By default, the model predicts "my cat" if the probability is > 0.5. But depending on what you care about, you might want to adjust this:

  • High threshold (0.9): Be really sure before saying it's Mila — fewer false positives
  • Low threshold (0.3): Don't miss any Mila photos — fewer false negatives

I included code to explore this trade-off at different thresholds.
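A minimal version of that exploration looks like this. One assumption baked in: `ImageFolder` sorts class folders alphabetically, so "my_cat" is class 0 and "not_my_cat" is class 1 (the example logits are made up):

```python
import torch

def predict(logits: torch.Tensor, threshold: float = 0.5) -> str:
    # Class 0 = my_cat, class 1 = not_my_cat (ImageFolder's alphabetical order)
    p_mila = torch.softmax(logits, dim=0)[0].item()
    return "my_cat" if p_mila > threshold else "not_my_cat"

logits = torch.tensor([2.0, 0.5])      # hypothetical raw model output
predict(logits)                        # p_mila ≈ 0.82 -> "my_cat"
predict(logits, threshold=0.9)         # stricter -> "not_my_cat"
```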


Frequently Asked Questions

Here are some concepts that came up while I was building this project:

Q: What is transfer learning?

A: Transfer learning is when you take a model that was trained on one task (like classifying 1000 different objects) and adapt it for a different task (like identifying my cat). The idea is that the model has already learned useful low-level features — edges, textures, shapes — that transfer well to new image tasks. This saves a ton of time and data compared to training from scratch.

Q: What's the difference between logits, softmax, and probabilities?

A: This confused me at first too:

  • Logits: The raw numbers that come out of the model — they can be any value (e.g., [-2.3, 1.7])
  • Softmax: A function that squishes logits into probabilities
  • Probabilities: Nice numbers between 0 and 1 that add up to 1 (e.g., [0.02, 0.98])
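The example numbers above check out directly; the logits [-2.3, 1.7] really do softmax to roughly [0.02, 0.98]:

```python
import torch

logits = torch.tensor([-2.3, 1.7])    # raw model output, any real values
probs = torch.softmax(logits, dim=0)  # squished into [0, 1], summing to 1
# probs ≈ [0.018, 0.982]
```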

Q: Why use CrossEntropyLoss instead of BCELoss?

A: Both work for binary classification. I went with CrossEntropyLoss because:

  • It takes raw logits directly (no need to apply softmax first)
  • It works with integer labels (0 or 1) which is simpler
  • If I ever wanted to add more classes, it would already support that
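The first two points in practice: the criterion takes raw logits and plain integer labels, with softmax handled internally (the numbers here are illustrative):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

logits = torch.tensor([[2.0, -1.0],   # raw logits: no softmax applied first
                       [0.3, 0.8]])
labels = torch.tensor([0, 1])         # plain integer class labels
loss = criterion(logits, labels)      # softmax + log + NLL in one step
```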

Q: What does model.eval() do?

A: It switches the model to evaluation mode. This matters because some layers (like dropout and batch normalization) behave differently during training vs. inference. Important note: model.eval() does NOT stop gradient computation — you need torch.no_grad() for that.
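A tiny demonstration of that distinction, using a throwaway model with a dropout layer:

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

model.eval()              # dropout becomes identity; batchnorm would use running stats
with torch.no_grad():     # this, not eval(), is what disables gradient tracking
    out = model(torch.ones(1, 4))

# out.requires_grad is False only because of no_grad()
```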

Q: Why is MPS (Metal Performance Shaders) used?

A: I'm running this on a MacBook with an M2 chip. MPS is Apple's framework for GPU-accelerated machine learning on their silicon. It makes training noticeably faster than running on CPU. The code automatically falls back to CPU if MPS isn't available.
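The fallback is a one-liner; everything (model and tensors) then needs to live on the chosen device. The toy model here is just a stand-in:

```python
import torch

# Pick Apple's Metal GPU when available, otherwise fall back to CPU
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = torch.nn.Linear(2, 2).to(device)  # stand-in for the classifier
x = torch.randn(1, 2, device=device)      # inputs must be on the same device
y = model(x)
```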

Q: What's the difference between precision and recall?

A: These measure different types of mistakes:

  • Precision = TP / (TP + FP) → "When I predict 'my cat', how often am I right?"
  • Recall = TP / (TP + FN) → "Out of all the actual 'my cat' images, how many did I catch?"

High precision means few false alarms. High recall means you don't miss much. Usually there's a trade-off between them.
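With hypothetical confusion counts for the "my cat" class, the two formulas work out like this:

```python
# Hypothetical counts: 40 true positives, 5 false alarms, 10 missed Milas
tp, fp, fn = 40, 5, 10

precision = tp / (tp + fp)  # 40/45 ≈ 0.889: when it says "my cat", usually right
recall = tp / (tp + fn)     # 40/50 = 0.800: still misses 10 real Mila photos
```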

Q: How could I improve this model?

A: A few ideas I might try:

  1. Add data augmentation — random flips, rotations, brightness changes
  2. Take more photos of Mila — more variety would help
  3. Unfreeze some layers — let the top convolutional blocks fine-tune
  4. Try a slightly larger model — EfficientNet-B2 or B3
  5. Use more diverse negative examples — include more cat breeds

Q: Can this tell if an image is a cat or not?

A: Nope! The model assumes you're already giving it a cat photo. It just decides if that cat is Mila or not. If you fed it a picture of a dog, it would still output one of the two classes (probably "not my cat," but not for the right reasons). You'd need a separate model for general cat detection.


Setup & Usage

Prerequisites

  • Python 3.13
  • PyTorch 2.0+
  • torchvision
  • matplotlib
  • tqdm

Installation

# Clone the repository
git clone https://github.com/yourusername/mila-classifier.git
cd mila-classifier

# Install dependencies
pip install torch torchvision matplotlib tqdm

# Run the notebook
jupyter notebook main.ipynb

Using Your Own Cat

  1. Create a folder data/my_cat/ with 30-50 images of your cat
  2. Download cat images for data/not_my_cat/ (Oxford Pets, Google Images, etc.)
  3. Run all cells in main.ipynb
  4. Adjust the decision threshold based on your precision/recall requirements

Project Structure

mila-classifier/
├── main.ipynb          
├── readme.md           
└── data/
    ├── my_cat/         
    └── not_my_cat/     

Acknowledgments

