# IMAGE CLASSIFIER

We are going to build an image classifier that when given a picture, can classify whether it is an image or a fish

## Imports

In [1]:
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.optim as optim
import torch.nn.functional as F

## Dataset

We are going to use [imagenet](https://www.image-net.org/) a database with over hundred of thousands of images. It contains more than 14 million images and 20,000 image categories. It’s the standard that all image classifiers judge themselves against

In [2]:
import os
from PIL import Image
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader

# Define a helper function to validate images
def is_valid_image(filepath):
    try:
        with Image.open(filepath) as img:
            img.verify()  # Verify if it's a valid image
        return True
    except Exception:
        return False

# Custom ImageFolder to handle corrupted files
class SafeImageFolder(torchvision.datasets.ImageFolder):
    def __init__(self, root, transform=None):
        super().__init__(root, transform)
        self.samples = [(path, label) for path, label in self.samples if is_valid_image(path)]

# Data transforms
data_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Train Data
train_data_path = "/kaggle/input/fish-or-cat/images/train"
train_data = SafeImageFolder(root=train_data_path, transform=data_transforms)

# Validation Data
val_data_path = "/kaggle/input/fish-or-cat/images/val"
val_data = SafeImageFolder(root=val_data_path, transform=data_transforms)

# Test Data
test_data_path = "/kaggle/input/fish-or-cat/images/test"
test_data = SafeImageFolder(root=test_data_path, transform=data_transforms)

# Dataloaders
batch_size = 64
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_data, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=4)


## Config

In [3]:
config = {
    "batch" : 64,
    "epochs": 50,
    "lr":0.001
}

## DataLoader

In [4]:
train_dataloader = DataLoader(train_data, batch_size = config["batch"])
val_dataloader = DataLoader(val_data, batch_size = config["batch"])
test_dataloader = DataLoader(test_data, batch_size = config["batch"])

for image, label in train_dataloader:
    print(f"{image.shape}, {label.shape}")
    # the image shape should be (batch_size, channels, height, width) - so the first layer of the network should be channel X height X width
    # the label should be (64)
    break

torch.Size([64, 3, 64, 64]), torch.Size([64])


## Model Architecture

In [5]:
class CatorFish(nn.Module):
    def __init__(self):
        super(CatorFish, self).__init__()
        self.fc1 = nn.Linear(12288, 84)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(84, 50)
        self.relu2 = nn.ReLU()
        self.fc3 = nn.Linear(50, 2)
        # self.softmax = nn.Softmax()

    def forward(self, x):
        # x = x.view(-1, 12288) # flattening the image
        x = x.view(x.size(0), -1)  # Flatten the input
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu2(x)
        x = self.fc3(x)
        # x = self.softmax(x)
        return x

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = CatorFish()
model.to(device)
model

CatorFish(
  (fc1): Linear(in_features=12288, out_features=84, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=84, out_features=50, bias=True)
  (relu2): ReLU()
  (fc3): Linear(in_features=50, out_features=2, bias=True)
)

## OPTIMIZER & LOSS FUNCTIONS

In [6]:
criterion = nn.CrossEntropyLoss()
optimizer  = optim.Adam(model.parameters(), lr = config["lr"])

## TRAIN LOOP

In [7]:
def train(model, train_loader, val_loader, criterion, optimizer, epochs, device):
    for epoch in range(epochs):
        training_loss = 0.0
        valid_loss = 0.0
        model.train()
        for images, labels in train_loader:
            # Forward pass
            images.to(device)
            labels.to(device)
            predictions = model(images)
            loss = criterion(predictions, labels)
    
            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item()
        training_loss /= len(train_loader)

        model.eval()
        num_correct = 0
        num_examples = 0
        for batch in val_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            output = model(inputs)
            targets = targets.to(device)
            loss = criterion(output,targets)
            valid_loss += loss.data.item()
            correct = torch.eq(torch.max(F.softmax(output), dim=1)[1],targets).view(-1)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        valid_loss /= len(val_loader)
        print('Epoch: {}, Training Loss: {:.2f},Validation Loss: {:.2f},accuracy = {:.2f}'.format(epoch, training_loss,valid_loss, num_correct / num_examples))    

In [8]:
train(model, train_dataloader, test_dataloader, criterion, optimizer, config["epochs"], device)

  correct = torch.eq(torch.max(F.softmax(output), dim=1)[1],targets).view(-1)


Epoch: 0, Training Loss: 1.61,Validation Loss: 1.63,accuracy = 0.45
Epoch: 1, Training Loss: 1.16,Validation Loss: 0.73,accuracy = 0.49
Epoch: 2, Training Loss: 0.66,Validation Loss: 0.58,accuracy = 0.72
Epoch: 3, Training Loss: 0.45,Validation Loss: 0.64,accuracy = 0.70
Epoch: 4, Training Loss: 0.50,Validation Loss: 0.56,accuracy = 0.77
Epoch: 5, Training Loss: 0.34,Validation Loss: 0.64,accuracy = 0.74
Epoch: 6, Training Loss: 0.39,Validation Loss: 0.58,accuracy = 0.77
Epoch: 7, Training Loss: 0.29,Validation Loss: 0.67,accuracy = 0.74
Epoch: 8, Training Loss: 0.33,Validation Loss: 0.61,accuracy = 0.80
Epoch: 9, Training Loss: 0.25,Validation Loss: 0.66,accuracy = 0.75
Epoch: 10, Training Loss: 0.27,Validation Loss: 0.65,accuracy = 0.80
Epoch: 11, Training Loss: 0.20,Validation Loss: 0.77,accuracy = 0.77
Epoch: 12, Training Loss: 0.24,Validation Loss: 0.70,accuracy = 0.79
Epoch: 13, Training Loss: 0.16,Validation Loss: 0.82,accuracy = 0.76
Epoch: 14, Training Loss: 0.19,Validation Lo

The model is experiencing overfitting, because the training loss flattens, whereas the validation loss is still high. 

- Training loss approaches zero, indicating that the model is learning the training data extremely well.
- Validation loss steadily increases, suggesting that the model is failing to generalize to unseen data.

## Save the Model

In [11]:
torch.save(model, "catorfish.pth")
# torch.save(model.parameters(), "catorfishparams.pth")