# 5.0 Training
It is time to train the model. Simply run the script below!

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision import transforms
import matplotlib.pyplot as plt
import os

# Make the project package importable from this notebook.
# The original line only computed a path and discarded it, so it had no effect.
# A package is importable when the directory that CONTAINS it is on sys.path,
# hence the parent of the working directory is added, not the package folder.
# NOTE(review): assumes the working directory sits one level below the project
# root — confirm against the repository layout.
import sys

project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    # Append (rather than prepend) so already-importable modules are not shadowed.
    sys.path.append(project_root)

from dog_and_cat_classifier_cnn_from_scratch.model import ResNet50
from dog_and_cat_classifier_cnn_from_scratch.data import CatAndDogDataset

# --- Hyperparameters ---
LEARNING_RATE = 0.001  # handed to the model constructor as `lr`
NUM_EPOCHS = 10        # number of passes of the training loop over the dataloader
BATCH_SIZE = 8         # batch size given to the DataLoader
NUM_CLASSES = 2        # handed to the model constructor as `num_classes`

# --- Setup ---
# Pick the compute device: prefer Apple's MPS backend (the original target of
# this notebook), then an NVIDIA CUDA GPU, and finally fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using device: mps 🚀")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using device: cuda 🚀")
else:
    device = torch.device("cpu")
    print("MPS device not found. Using device: cpu 🐌")

# Image-to-tensor conversion.
# NOTE(review): `transform` is not passed to the dataset created below — confirm
# whether CatAndDogDataset converts images itself or should receive it.
transform = transforms.ToTensor()

# Build the network first, then move it onto the selected device
model = ResNet50(in_channels=3, num_classes=NUM_CLASSES, lr=LEARNING_RATE)
model = model.to(device)

# The model object supplies both the loss function and the optimizer
criterion = model.loss
optimizer = model.configure_optimizers()

# Dataset read from the processed-data folder, served in shuffled batches
dataset = CatAndDogDataset(img_dir='../data/processed')
dataloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

# One average-loss entry is appended here per completed epoch
epoch_losses = list()

# --- Training Loop ---
# For every epoch: put the model in training mode, run one optimisation step
# per batch, and append the mean of the per-batch losses to `epoch_losses`.
for epoch in range(NUM_EPOCHS):
    model.train()

    # Running sum of the per-batch loss values for the current epoch
    total_loss = 0.0

    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", unit="batch", colour="green")

    # The batch index was never used, so iterate over the batches directly
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Read the scalar once: each .item() call copies the value from the
        # device to the host, so reuse it for both the sum and the display.
        batch_loss = loss.item()
        total_loss += batch_loss

        progress_bar.set_postfix(loss=f'{batch_loss:.4f}')

    # Mean over batches (each batch counts equally, regardless of its size)
    avg_loss = total_loss / len(dataloader)
    epoch_losses.append(avg_loss)

    print(f"\n✨ Epoch {epoch+1} completed! Average loss: {avg_loss:.4f}\n")

print("🌟 Training finished! 🌟")

# --- Plotting the Results ---
# Build the x-axis from the losses actually recorded, so x and y always have
# the same length even if `epoch_losses` holds fewer (or more) entries than
# NUM_EPOCHS; a length mismatch would make plt.plot raise.
completed_epochs = range(1, len(epoch_losses) + 1)

plt.figure(figsize=(10, 6))
plt.plot(completed_epochs, epoch_losses, marker='o', linestyle='-')
plt.title('Training Loss over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Average Loss')
plt.grid(True)
plt.show()

Using device: mps 🚀


Epoch 1/10:   0%|          | 0/3119 [00:00<?, ?batch/s]


KeyboardInterrupt: 