## IMPORT

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

from torchvision import transforms, models
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score

# Constants
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
IMG_SIZE = 128
BATCH_SIZE = 228  # Adjust for available VRAM
NUM_CLASSES = 6  # Fresh/Rotten for each fruit
BRUISED_CLASSES = 2  # Bruised or not bruised
# full file path
# C:\Users\Kenan\Downloads\CNN_train_test_model\datasplit
DATASET_PATH = 'datasplit' 
NUM_EPOCHS = 10

In [2]:
# Preprocessing transformations
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Dataset and Dataloaders
train_dataset = ImageFolder(root=f"{DATASET_PATH}/train", transform=transform)
val_dataset = ImageFolder(root=f"{DATASET_PATH}/val", transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [3]:

# Model definition
class ResNetClassifier(nn.Module):
    def __init__(self, num_classes, bruised_classes):
        super(ResNetClassifier, self).__init__()
        self.base_model = models.resnet50(pretrained=True)
        
        # Save the input feature size of the original fc layer
        in_features = self.base_model.fc.in_features
        
        # Replace the fc layer with an Identity layer
        self.base_model.fc = nn.Identity()
        
        # Define new classification layers
        self.classifier = nn.Linear(in_features, num_classes)
        self.bruised_classifier = nn.Linear(in_features, bruised_classes)

    def forward(self, x):
        x = self.base_model(x)  # Feature extraction
        fruit_class = self.classifier(x)  # Fruit and freshness classification
        bruised_class = self.bruised_classifier(x)  # Bruised/Not Bruised classification
        return fruit_class, bruised_class



In [4]:

# Initialize the model
model = ResNetClassifier(NUM_CLASSES, BRUISED_CLASSES).to(DEVICE)

# Optimizer and Loss
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()




In [5]:

# Training Function
def train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct_preds = 0
        total_samples = 0
        
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            
            # Zero the gradients
            optimizer.zero_grad()
            
            # Forward pass
            fruit_pred, bruised_pred = model(inputs)
            loss1 = criterion(fruit_pred, labels)  # Fruit type and freshness loss
            loss2 = criterion(bruised_pred, labels % 2)  # Bruised/Not Bruised loss
            loss = loss1 + loss2
            
            # Backward pass and optimization
            loss.backward()
            optimizer.step()
            
            # Statistics
            running_loss += loss.item()
            _, preds = torch.max(fruit_pred, 1)
            correct_preds += (preds == labels).sum().item()
            total_samples += labels.size(0)
        
        epoch_loss = running_loss / len(train_loader)
        epoch_acc = correct_preds / total_samples
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}")
    
    print("Training complete!")
    return model


In [6]:

# Train the model
model = train_model(model, train_loader, val_loader, criterion, optimizer, NUM_EPOCHS)

# Save the trained model
torch.save(model.state_dict(), "resnetv1_fruit_model.pth")
print("Model saved to resnetv1_fruit_model.pth")


Epoch 1/10, Loss: 2.1312, Accuracy: 0.3050
Epoch 2/10, Loss: 2.6820, Accuracy: 0.3750
Epoch 3/10, Loss: 1.8759, Accuracy: 0.6500
Epoch 4/10, Loss: 0.9054, Accuracy: 0.8075
Epoch 5/10, Loss: 0.8336, Accuracy: 0.8175
Epoch 6/10, Loss: 0.4757, Accuracy: 0.9075
Epoch 7/10, Loss: 0.3120, Accuracy: 0.9400
Epoch 8/10, Loss: 0.3028, Accuracy: 0.9425
Epoch 9/10, Loss: 0.1134, Accuracy: 0.9775
Epoch 10/10, Loss: 0.1735, Accuracy: 0.9675
Training complete!
Model saved to resnetv1_fruit_model.pth


In [11]:
# Function to evaluate a model
def evaluate_model(model, val_loader):
    model.eval()
    all_preds, all_labels = [], []
    all_bruised_preds, all_bruised_labels = [], []
    val_loss = 0.0

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
            
            # Forward pass
            fruit_pred, bruised_pred = model(inputs)
            loss1 = criterion(fruit_pred, labels)
            loss2 = criterion(bruised_pred, labels % 2)
            loss = loss1 + loss2
            val_loss += loss.item()

            # Store predictions
            all_preds.extend(torch.argmax(fruit_pred, dim=1).cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_bruised_preds.extend(torch.argmax(bruised_pred, dim=1).cpu().numpy())
            all_bruised_labels.extend((labels % 2).cpu().numpy())

    # Metrics
    precision = precision_score(all_labels, all_preds, average='weighted')
    recall = recall_score(all_labels, all_preds, average='weighted')
    f1 = f1_score(all_labels, all_preds, average='weighted')
    accuracy = accuracy_score(all_labels, all_preds)
    avg_loss = val_loss / len(val_loader)

    return precision, recall, f1, accuracy, avg_loss



In [12]:
# Metric Score of the ResNet
resnet_metrics = evaluate_model(model, val_loader)
print(f"ResNet Metrics: Precision: {resnet_metrics[0]:.4f}, Recall: {resnet_metrics[1]:.4f}, F1-Score: {resnet_metrics[2]:.4f}, Accuracy: {resnet_metrics[3]:.4f}, Loss: {resnet_metrics[4]:.4f}")


ResNet Metrics: Precision: 0.5302, Recall: 0.6071, F1-Score: 0.5275, Accuracy: 0.6071, Loss: 15.2111


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


## TODO


In [None]:
# display the graph for f-score, training, testing, and validation accuracy, together with the confusion matrix

In [8]:
# try out other models and techniques

In [9]:
# train and test the model image preprocessed (10,000 imagesssss    )