In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
from sklearn.metrics import classification_report, confusion_matrix

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], 
                         [0.229, 0.224, 0.225])
])

In [4]:
data_dir = "./dataset"
dataset = datasets.ImageFolder(root=data_dir, transform=train_transform)

In [5]:
train_size = int(0.7 * len(dataset))
val_size   = len(dataset) - train_size
train_ds, val_ds = torch.utils.data.random_split(dataset, [train_size, val_size])

In [6]:
val_ds.dataset.transform = val_transform

train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader   = DataLoader(val_ds, batch_size=32, shuffle=False)

class_names = dataset.classes
print("Classes:", class_names)

Classes: ['Hawar', 'Karat', 'Sehat']


In [7]:
model = models.resnet50(pretrained=True)



In [8]:
# Freeze backbone
for param in model.parameters():
    param.requires_grad = False

# Replace classifier
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, len(class_names))  # 3 classes

model = model.to(device)

# Loss & Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=1e-4)

In [9]:
def train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=10):
    for epoch in range(num_epochs):
        model.train()
        running_loss, running_corrects = 0, 0

        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            _, preds = torch.max(outputs, 1)
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc  = running_corrects.double() / len(train_loader.dataset)

        print(f"Epoch {epoch+1}/{num_epochs} - "
              f"Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

    return model

In [10]:
# Train
model = train_model(model, criterion, optimizer, train_loader, val_loader, num_epochs=10)

# Save
torch.save(model.state_dict(), "corn_resnet50.pth")
print("Model saved as corn_resnet50.pth")

Epoch 1/10 - Loss: 1.1066 Acc: 0.3272
Epoch 2/10 - Loss: 1.0388 Acc: 0.4424
Epoch 3/10 - Loss: 0.9867 Acc: 0.6037
Epoch 4/10 - Loss: 0.9483 Acc: 0.6728
Epoch 5/10 - Loss: 0.8977 Acc: 0.7419
Epoch 6/10 - Loss: 0.8601 Acc: 0.8157
Epoch 7/10 - Loss: 0.8246 Acc: 0.8157
Epoch 8/10 - Loss: 0.8059 Acc: 0.8203
Epoch 9/10 - Loss: 0.7610 Acc: 0.8479
Epoch 10/10 - Loss: 0.7347 Acc: 0.8479
Model saved as corn_resnet50.pth


In [11]:
model.eval()
y_true, y_pred = [], []

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

print(classification_report(y_true, y_pred, target_names=class_names))

              precision    recall  f1-score   support

       Hawar       0.79      0.70      0.74        33
       Karat       0.71      0.81      0.76        31
       Sehat       0.97      0.97      0.97        29

    accuracy                           0.82        93
   macro avg       0.82      0.82      0.82        93
weighted avg       0.82      0.82      0.82        93

