In [14]:
import os
import xml.etree.ElementTree as ET

import pandas as pd
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

In [15]:
# Column-oriented accumulator: parse_xml() appends one value to every list
# for each annotated object it reads, so index i across all lists is one row.
data = {
    column: []
    for column in (
        'filename', 'width', 'height', 'class',
        'xmin', 'ymin', 'xmax', 'ymax',
    )
}

In [16]:
def get_file_image_dimensions(file_path):
    """Return ``(width, height)`` read from the image file itself.

    Args:
        file_path: Path of the image to open with PIL.

    Returns:
        The size reported by PIL as ``(width, height)``, or ``(None, None)``
        when ``file_path`` is not an existing regular file.
    """
    if os.path.isfile(file_path):
        with Image.open(file_path) as picture:
            size = picture.size
        return size[0], size[1]
    return None, None

def get_xml_image_dimensions(xml_file):
    """Read the image width and height declared in an annotation XML file.

    The values are taken from ``<size><width>`` and ``<size><height>`` under
    the document root.

    Args:
        xml_file: Path (or file object) accepted by ``ET.parse``.

    Returns:
        ``(width, height)`` as ints. ``(0, 0)`` is returned as a "not
        available" sentinel when ``<size>`` is absent, when either child
        element is missing or empty, or when a value is not a valid integer;
        ``get_image_dimensions`` treats a 0 as the cue to read the image file.

    Raises:
        xml.etree.ElementTree.ParseError: If the file is not well-formed XML.
    """
    root = ET.parse(xml_file).getroot()

    size = root.find('size')
    if size is None:
        return 0, 0

    # findtext() yields None for a missing child (and '' for an empty one)
    # instead of raising AttributeError like ``find(...).text`` would.
    width_text = size.findtext('width')
    height_text = size.findtext('height')
    if not width_text or not height_text:
        return 0, 0

    try:
        return int(width_text), int(height_text)
    except ValueError:
        # Non-integer text: report "unknown" so the caller can fall back.
        return 0, 0


def get_image_dimensions(xml_file, image_file_path):
    """Resolve an image's dimensions, preferring the annotation file.

    Args:
        xml_file: Annotation XML passed to ``get_xml_image_dimensions``.
        image_file_path: Image path passed to ``get_file_image_dimensions``
            when the XML does not give a usable size.

    Returns:
        The XML-declared ``(width, height)`` when both values are non-zero;
        otherwise whatever ``get_file_image_dimensions`` returns for the
        image path.
    """
    xml_width, xml_height = get_xml_image_dimensions(xml_file)
    if xml_width != 0 and xml_height != 0:
        return xml_width, xml_height

    # A zero in either dimension means the XML size is unusable.
    file_width, file_height = get_file_image_dimensions(image_file_path)
    return file_width, file_height


def parse_xml(xml_file, image_file_path):
    """Append one row per annotated object to the module-level ``data`` dict.

    Args:
        xml_file: Annotation file containing ``<filename>`` and zero or more
            ``<object>`` elements, each with ``<name>`` and a ``<bndbox>``
            holding integer ``xmin``/``ymin``/``xmax``/``ymax``.
        image_file_path: Path of the matching image, forwarded to
            ``get_image_dimensions`` for the width/height columns.

    Returns:
        None. The result is the side effect on ``data``.
    """
    root = ET.parse(xml_file).getroot()
    filename = root.find('filename').text
    width, height = get_image_dimensions(xml_file, image_file_path)

    for annotated in root.iter('object'):
        label_text = annotated.find('name').text
        box = annotated.find('bndbox')
        corners = {
            tag: int(box.find(tag).text)
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        }

        # Everything is computed before the first append so a malformed
        # object never leaves the columns with unequal lengths.
        row = {
            'filename': filename,
            'width': width,
            'height': height,
            'class': label_text,
            **corners,
        }
        for column, value in row.items():
            data[column].append(value)


In [21]:
class FruitDataset(Dataset):
    """Classification dataset of ``.jpg`` images with same-named ``.xml`` annotations.

    Each sample is ``(image, label)``; the label comes from the class of the
    first annotated object of the image. Images that have no annotation file,
    or whose annotation lists no objects, are excluded from ``images`` (and
    recorded in ``skipped_images``) so that every valid index can be loaded.

    Attributes:
        data_dir: Directory holding the images and annotation files.
        transforms: Optional callable applied to the resized PIL image.
        image_size: Size tuple handed to ``PIL.Image.Image.resize``.
        images: Sorted file names of the usable (annotated) images.
        skipped_images: File names of ``.jpg`` files left out for lack of
            annotations.
        dataframe: One row per annotated object parsed for THIS instance.
    """

    # Same columns, in the same order, as the module-level ``data`` accumulator.
    _COLUMNS = ('filename', 'width', 'height', 'class',
                'xmin', 'ymin', 'xmax', 'ymax')
    _CLASS_TO_LABEL = {'apple': 0, 'banana': 1, 'orange': 2}

    def __init__(self, data_dir, transforms=None, image_size=(224, 224)):
        """Scan ``data_dir`` and index every annotated ``.jpg`` image.

        Args:
            data_dir: Directory containing ``<name>.jpg`` / ``<name>.xml`` pairs.
            transforms: Optional callable applied to each PIL image.
            image_size: Target size used to resize every image before
                ``transforms`` runs.
        """
        self.data_dir = data_dir
        self.transforms = transforms
        self.image_size = image_size

        self.images = []
        self.skipped_images = []
        self._labels = {}  # image file name -> integer label
        rows = []

        # os.listdir() order is arbitrary; sorting makes idx -> image stable.
        for image_file in sorted(os.listdir(data_dir)):
            if not image_file.endswith('.jpg'):
                continue

            xml_path = os.path.join(data_dir, os.path.splitext(image_file)[0] + '.xml')
            if not os.path.exists(xml_path):
                self.skipped_images.append(image_file)
                continue

            # parse_xml() appends to the shared module-level ``data`` dict.
            # Only the rows it adds for this file are picked up, so entries
            # left there by earlier datasets or cell re-runs are not mixed in.
            first_new = len(data['filename'])
            parse_xml(xml_path, os.path.join(data_dir, image_file))
            new_rows = [
                {column: data[column][i] for column in self._COLUMNS}
                for i in range(first_new, len(data['filename']))
            ]
            if not new_rows:
                self.skipped_images.append(image_file)
                continue

            rows.extend(new_rows)
            self.images.append(image_file)
            # Keyed by the real file name: the <filename> text inside the XML
            # is not guaranteed to match the image on disk.
            self._labels[image_file] = self.class_to_label(new_rows[0]['class'])

        self.dataframe = pd.DataFrame(rows, columns=list(self._COLUMNS))

    def __len__(self):
        """Return the number of usable (annotated) images."""
        return len(self.images)

    def class_to_label(self, class_name):
        """Map a class name to its integer label.

        NOTE(review): names outside the mapping silently become 0 ('apple');
        kept for backward compatibility — confirm this is intended.
        """
        return self._CLASS_TO_LABEL.get(class_name, 0)

    def __getitem__(self, idx):
        """Load sample ``idx``.

        Returns:
            ``(image, label)`` where ``image`` is the resized RGB image after
            ``transforms`` (if any) and ``label`` is a ``torch.long`` scalar
            tensor.
        """
        image_name = self.images[idx]
        image_path = os.path.join(self.data_dir, image_name)

        # PIL is already a dependency of this notebook and produces the RGB
        # PIL image the transforms expect, without going through OpenCV.
        with Image.open(image_path) as handle:
            image = handle.convert('RGB')

        # Resize the image
        image = image.resize(self.image_size)

        # Apply the transformations
        if self.transforms:
            image = self.transforms(image)

        label = torch.tensor(self._labels[image_name], dtype=torch.long)

        return image, label


In [22]:
# Preprocessing applied to every PIL image: resize to 224x224, convert to a
# tensor, then normalise each channel with the mean/std below.
# NOTE(review): these look like the usual ImageNet channel statistics used
# with torchvision's pretrained models — confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


In [28]:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import numpy as np
from torchvision import models

# Define class mapping and parameters
class_mapping = {0: 'apple', 1: 'banana', 2: 'orange'}
num_classes = len(class_mapping)
data_dir = './datasets/train_zip/train'
num_epochs = 4
batch_size = 32
learning_rate = 0.001
incorrect_grid = 3                          # misclassified samples are shown in an N x N grid
max_incorrect_shown = incorrect_grid ** 2   # only this many image tensors are kept in memory

# Per-channel statistics used to undo the Normalize step of `transform`
# before plotting; they must match the mean/std passed to transforms.Normalize.
norm_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
norm_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

# Load the dataset and DataLoader
train_dataset = FruitDataset(data_dir, transforms=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Initialize and modify AlexNet model: replace the last classifier layer so
# the network outputs one score per fruit class.
# NOTE(review): newer torchvision releases prefer `weights=` over
# `pretrained=`; kept as-is for the installed version — confirm.
model = models.alexnet(pretrained=True)
model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Initialize lists to store results (also covers num_epochs == 0)
all_predictions, all_labels = [], []
incorrect_images, incorrect_labels, incorrect_preds = [], [], []

# Training loop
for epoch in range(num_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0

    # Start fresh every epoch so that, after the loop, the confusion matrix
    # and the misclassified samples describe the final epoch only instead of
    # a mix of all epochs.
    all_predictions, all_labels = [], []
    incorrect_images, incorrect_labels, incorrect_preds = [], [], []

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()  # Backward pass
        optimizer.step()  # Optimization

        running_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Store predictions and labels. These come from the training-mode
        # forward pass, so they describe training behaviour, not held-out
        # performance.
        all_predictions.extend(predicted.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

        # Save incorrect predictions. Labels/predictions are cheap and kept
        # in full; image tensors are kept only for the first few mistakes,
        # which stay index-aligned with the head of the label lists.
        incorrect_indices = (predicted != labels).nonzero(as_tuple=True)[0]
        remaining = max_incorrect_shown - len(incorrect_images)
        if remaining > 0:
            incorrect_images.extend(images[incorrect_indices[:remaining]].cpu())
        incorrect_labels.extend(labels[incorrect_indices].cpu().tolist())
        incorrect_preds.extend(predicted[incorrect_indices].cpu().tolist())

    # Guard against an empty loader (no annotated images found).
    epoch_loss = running_loss / max(len(train_loader), 1)
    epoch_accuracy = correct / total if total else 0.0
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}")

# Compute and plot confusion matrix
cm = confusion_matrix(all_labels, all_predictions, labels=list(class_mapping.keys()))
plt.figure(figsize=(6, 6))
sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
            xticklabels=list(class_mapping.values()), yticklabels=list(class_mapping.values()))
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

# Display incorrect images
if incorrect_images:
    print("Displaying images with incorrect predictions:")
    plt.figure(figsize=(10, 10))
    for i, img in enumerate(incorrect_images[:max_incorrect_shown]):
        plt.subplot(incorrect_grid, incorrect_grid, i+1)
        # Undo the normalisation and clamp to [0, 1]; imshow would otherwise
        # clip the normalised values and distort the colours.
        img = (img * norm_std + norm_mean).clamp(0, 1)
        img = img.permute(1, 2, 0)  # Convert from (C, H, W) to (H, W, C)
        plt.imshow(img.numpy())
        true_label = class_mapping[incorrect_labels[i]]
        pred_label = class_mapping[incorrect_preds[i]]
        plt.title(f"True: {true_label}, Pred: {pred_label}")
        plt.axis('off')
    plt.show()

print("Training finished")


TypeError: parse_xml() missing 1 required positional argument: 'image_file_path'