In [61]:
from torch.utils.data import Dataset
from torchvision import transforms
from torch import no_grad
from torch.nn import Flatten

In [62]:
import scipy
import os
from PIL import Image
import matplotlib.pyplot as plt

In [63]:
transform = transforms.Compose([
    transforms.Resize(256), # Resize shorter edge to 256
    transforms.CenterCrop(224), # Extract 224*224 center square
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.48, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [64]:
class OxfordFlowersDataset(Dataset):
    def __init__(self, root_dir, transform):
        self.root_dir = root_dir
        self.img_dir = os.path.join(root_dir)
        self.transform = transform

        # Load Matlab labels
        labels_mat = scipy.io.loadmat(os.path.join(root_dir, 'imagelabels.mat'))
        self.labels = labels_mat['labels'][0] - 1

        # Keep track of any errors we encounter
        self.error_log = []

        print(len(self.labels)) # Number of image labels

        # MIN Label
        print(f"Min Label: {self.labels.min()}")
        # MAX Label
        print(f"Min Label: {self.labels.max()}")


    def __len__(self):
        return len(self.labels) # 8189 samples
    

    def __getitem__(self, idx):
        try:
        # Build the image filename
            img_name = f'image_{idx+1:05d}.jpg'
            img_path = os.path.join(self.img_dir, img_name)
            image = Image.open(img_path)


            # Check for corruption
            image.verify() # Verify() closes the file
            image = Image.open(img_path)

            # Skip tiny images
            if image.size[0] < 32 or image.size[1] < 32:
                raise ValueError(f"Image too small: {image.size}")
            
            # Covert grayscale to RGB
            if image.mode != 'RGB':
                image = image.convert('RGB')


            # Load the image
            image = Image.open(img_path)
            label = self.labels[idx]

            if self.transform:
                image = self.transform(image)

            return image, label
        

        except Exception as e:
            # Log the issue instead of crashing
            self.error_log.append({
                'index': idx,
                'error': str(e),
                'path': img_path if 'img_path' in locals() else 'unknown'
            })

            print(f"Warning: Skipping corrupted image {idx}: {e}")
            # Try the next image (warp around if needed)
            next_idx = (idx + 1) % len(self)
            return self.__getitem__(next_idx)
        
    def get_error_summary(self):
        """Review what went wrong after training."""
        if not self.error_log:
            print("No errors encountered - dataset is clean")
        else:
            print(f"\nEncountered {len(self.error_log)} Problematic images: ")
            for error in self.error_log[:5]: # Show first 5
                print(f"Index {error['index']}: {error['error']}")
                if len(self.error_log):
                    print(f" ... and {len(self.error_log) - 5} more")

    

In [None]:
# Create Dataset
dataset = OxfordFlowersDataset(
    r"Oxford_102_flowers", transform=transform
    )
print(f"Total samples: {len(dataset)}") # Shows 8189
# Try loading one image
img, label = dataset[0]

8189
Min Label: 0
Min Label: 101
Total samples: 8189


In [None]:
len(dataset)

8189

In [67]:
# Check a few images
for i in [0, 100, 500]:
    img, _ = dataset[i]
    print(f'Image {i}: {img.size}')
    print(f'Image {i}: {type(img)}')

Image 0: <built-in method size of Tensor object at 0x00000287544ACE30>
Image 0: <class 'torch.Tensor'>
Image 100: <built-in method size of Tensor object at 0x0000028754522570>
Image 100: <class 'torch.Tensor'>
Image 500: <built-in method size of Tensor object at 0x00000287532D2090>
Image 500: <class 'torch.Tensor'>


In [68]:
# transform = transforms.Resize((224, 224)) # Can cause size

In [69]:
# img, _ = dataset[0]
# resized = transforms.Resize(256)(img)
# print(f"After resize: {resized.size}")  # (302,256) - Keeps aspect ratio!
# cropped = transforms.CenterCrop(224)(resized)
# print(f"After crop: {cropped.size}")    # (224,224) - Perfect Square

In [70]:
img, _ = dataset[1]
img.shape

torch.Size([3, 224, 224])

In [71]:
from torch.utils.data import random_split
from torch.utils.data import DataLoader

In [72]:
# Split into train/val/test: 70/15/15
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * (len(dataset)))
test_size = len(dataset) - train_size - val_size

train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

In [73]:
print(f"Training: {len(train_dataset)}")
print(f"Training: {len(val_dataset)}")
print(f"Training: {len(test_dataset)}")

Training: 5732
Training: 1228
Training: 1229


In [74]:
# Create Dataloders for each set
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [75]:
for images, labels in train_loader:
    print(f"Batch of images shape: {images.shape}")
    print(f"Batch of labels shape: {labels.shape}")
    break

Batch of images shape: torch.Size([32, 3, 224, 224])
Batch of labels shape: torch.Size([32])


In [76]:
# Batching Numbers
batch_count = 0
total_images = 0

for images, labels in train_loader:
    batch_count += 1
    total_images += len(images)

    # Show the last few batches
    if batch_count >= 178:
        print(f"Batch {batch_count}: {len(images)} images")

print(f"\nTotal batches in one epoch: {batch_count}")
print(f'total_images seen: {total_images}')

Batch 178: 32 images
Batch 179: 32 images
Batch 180: 4 images

Total batches in one epoch: 180
total_images seen: 5732


In [77]:
# Verify everything works
print(f"Train: {len(train_loader)} batches")
print(f"Val: {len(val_loader)} batches")
print(f"Train: {len(test_loader)} batches")

# Quick test - get one batch from each
for name, loader, in [("Train", train_loader), ("Val", val_loader), ("Test", test_loader)]:
    images, labels = next(iter(loader))
    print(f"{name} batch: {images.shape}")

Train: 180 batches
Val: 39 batches
Train: 39 batches
Train batch: torch.Size([32, 3, 224, 224])
Val batch: torch.Size([32, 3, 224, 224])
Test batch: torch.Size([32, 3, 224, 224])


# Augmentation Transformation

In [78]:
# Training Transformations - with random augmentation
train_transform = transforms.Compose([
    # Random augmentations (different each time)
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(degrees=10),
    transforms.ColorJitter(brightness=0.2),

    # Standard Preprocessing
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# Validation transformation
val_transforms = transforms.Compose([
    # Only Standard Preprocessing
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

In [79]:
# def visualize_augmentations(dataset, idx=0, num_version=8):
#     """ See what augmentation actually does to your images"""
#     fig, axes = plt.subplot(2, 4, figsize=(12, 6))
#     axes = axes.Flatten()

#     for i in range(num_version):
#         img, label = dataset[idx] # Get augmented version

#         # Denormalize for display
#         img = denormalize(img)

#         axes[i].imshow(img.permute(1,2,0)) # CHW --> HWC
#         axes[i].set_title(f"version {i+1}")
#         axes[i].axis('Off')

#         plt.subtitle(f"Same flower (index {idx}), 8 different augmentations")
#         plt.tight_layout()
#         plt.show()

# Define a simple CNN with only conv2d layers

In [80]:
import torch.nn as nn
from torch import optim

In [None]:
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        # First convolution block
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Second convolution block
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Third convolution block
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Flatten layer
        self.flatten = Flatten()

        # Fully connected layers
        # Input image is 32x32, after 3 pooling layers : 4x4
        self.fc1 = nn.Linear(128 * 28 * 28, 512)
        self.relu4 = nn.ReLU()
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, 102) # 15 classes in dataset


    def forward(self, x):
        # First conv block
        x = self.conv1(x)
        # print(f"Shape of conv1: {x.shape}")
        x = self.relu1(x)
        x = self.pool1(x)

        # Second conv block
        x = self.conv2(x)
        # print(f"Shape of conv2: {x.shape}")
        x = self.relu2(x)
        x = self.pool2(x)

        # Third conv block
        x = self.conv3(x)
        # print(f"Shape of conv3: {x.shape}")
        x = self.relu3(x)
        x = self.pool3(x)
        # print(f"Shape of pool3: {x.shape}")


        # Flatten before the fully connected layer
        x = self.flatten(x)
        # print(f"Shape of Flatten: {x.shape}")


      # Fully connected layers
        # Input image is 32x32, after 3 pooling layers : 4x4
        x = self.fc1(x)
        x = self.relu4(x)
        x = self.dropout(x)
        x = self.fc2(x)

        return x


In [82]:
# Create an instance of our CNN
model = SimpleCNN()
print(model)

SimpleCNN(
  (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (relu3): ReLU()
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=100352, out_features=512, bias=True)
  (relu4): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=512, out_features=102, bias=True)
)


## Defining loss and optimizer

In [83]:
# Define loss function
loss_function = nn.CrossEntropyLoss()

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=0.0005, weight_decay=0.0005)

## Training model

In [None]:
def train_epoch(model, train_loader, loss_function, optimizer, device=None):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (data, target) in enumerate(train_loader):
        # data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = loss_function(output, target)
        loss.backward()
        optimizer.step()

        # Track progress
        running_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()

        # Print every 100 batches
        if batch_idx % 100 == 0 and batch_idx > 0:
            avg_loss = running_loss/100
            accuracy = 100. * correct / total
            print(f" [{batch_idx * 64}/{60000}]"
                  f"Loss: {avg_loss:.3f} | Accuracy: {accuracy:.1f}%")
            running_loss = 0


## Setting up Evaluation

In [None]:
def evaluation(model, test_loader, device=None):
    model.eval()
    correct = 0
    total = 0

    with no_grad():
        for inputs, targets in test_loader:
            # inputs, targets = input.to(device), targets.to(device)
            outputs = model(inputs)
            _,predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
            
    return 100. * correct / total

## Putting it All Together

In [86]:
# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    print(f"\nEpoch: {epoch+1}")
    train_epoch(model, train_loader, loss_function, optimizer)
    accuracy = evaluation(model, val_loader)
    print(f"Test Accuracy: {accuracy:.2f}%")


Epoch: 1
 [6400/60000]Loss: 4.346 | Accuracy: 6.7%


KeyboardInterrupt: 

In [None]:
from pathlib import Path

# Create models directory
MODEL_PATH = Path('Saved_Models')
MODEL_PATH.mkdir(parents=True, exist_ok=True)

# Create model save path
MODEL_NAME = "Oxford_Flower_Classification.pth"
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

MODEL_SAVE_PATH

# Save the model state dict
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(obj=model.state_dict(), f=MODEL_SAVE_PATH)

In [None]:
# Load a PyTorch Model

# Create a new instance of linear regression model
loaded_model = SimpleCNN()

# Load the saved model state_dict
loaded_model.state_dict(torch.load(MODEL_SAVE_PATH))

# Put the loaded model to device
# loaded_model.to(device)

In [None]:
# next(loaded_model.parameters()).device

In [None]:
# Evaluate loaded model
with no_grad():
    for inputs, targets in test_loader:
        # inputs, targets = input.to(device), targets.to(device)
        outputs = model(inputs)

    # Example for classification
    probabilities = torch.softmax(outputs, dim=)
    predicted_class = torch.argmax(probabilities, dim=)
    print(f"Predicted class: {predicted_class.item()}")
    

<torch.utils.data.dataloader.DataLoader at 0x28857632d90>