In [46]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
import json
import os
import sys

In [47]:
# Run configuration: everything a reader might want to tune lives here.
DATA_DIR = '../data/PlantVillage'  # ImageFolder root: one sub-directory per class
BATCH_SIZE = 8
EPOCHS = 15
LEARNING_RATE = 0.001  # Adam learning rate
SEED = 42

# Seed PyTorch's RNGs (on all devices) so weight init, loader shuffling and the
# default random_split are reproducible across Restart & Run All.
# NOTE: bitwise-identical CUDA runs additionally need deterministic cuDNN flags.
torch.manual_seed(SEED)

In [48]:
# Make the project root (parent of the kernel's current working directory)
# importable so the local `models` package can be found. The membership check
# keeps the cell idempotent: re-running it no longer appends duplicate entries.
PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from models.model_v1 import get_model

In [49]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [50]:
# Per-channel (R, G, B) normalization statistics — the standard ImageNet values.
# NOTE(review): presumably chosen to match an ImageNet-pretrained backbone in
# models.model_v1 — confirm there.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic preprocessing applied to every image (train and val alike):
# resize to 224x224, convert to a tensor, then normalize each channel.
preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
]
transform = transforms.Compose(preprocessing_steps)

In [51]:
# Fail fast when the dataset is missing: continuing would only surface later as
# a confusing NameError on `full_dataset` in the model cell.
if not os.path.exists(DATA_DIR):
    raise FileNotFoundError(f"path not found: {DATA_DIR}")

# ImageFolder assigns one class per sub-directory of DATA_DIR.
full_dataset = datasets.ImageFolder(DATA_DIR, transform=transform)
print(f"Total image:{len(full_dataset)}")
print(f"classes detected:{len(full_dataset.classes)}")

# 80/20 train/validation split. A dedicated, seeded generator makes the split
# reproducible across kernel restarts, independent of the global RNG state.
TRAIN_FRACTION = 0.8
SPLIT_SEED = 42
train_size = int(TRAIN_FRACTION * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Training drops a trailing partial batch (e.g. a size-1 batch, which
# BatchNorm layers reject in train mode — relevant if the model uses them).
# Validation keeps every sample (drop_last=False) so accuracy is measured on
# the whole validation split rather than a silently truncated subset.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=False)
    

Total image:20638
classes detected:15


In [52]:
# Build the classifier with one output per class discovered by ImageFolder,
# then move its parameters to the selected device (Module.to returns self).
num_classes = len(full_dataset.classes)
model = get_model(num_classes=num_classes).to(device)

In [53]:
# Adam over every model parameter, and multi-class cross-entropy loss
# (assumes the model emits raw, unnormalized class scores — CrossEntropyLoss
# applies log-softmax itself).
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)
criterion = nn.CrossEntropyLoss()

In [54]:
print('starting training')

# Empty loaders would otherwise crash below with a bare ZeroDivisionError when
# averaging the loss or computing accuracy; fail with an actionable message.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError(
        "train_loader and val_loader must each yield at least one batch; "
        "check the dataset size and BATCH_SIZE"
    )

for epoch in range(EPOCHS):
    # --- Training phase ---
    model.train()  # training-mode behavior for layers such as dropout/batch-norm
    running_loss = 0.0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean of the per-batch losses over this epoch.
    avg_train_loss = running_loss / len(train_loader)

    # --- Validation phase ---
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():  # evaluation needs no autograd graph
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Predicted class = index of the highest score for each sample.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Validation accuracy in percent; after the loop, `val_acc` holds the
    # final epoch's value (the metrics cell reads it).
    val_acc = 100 * correct / total

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Accuracy: {val_acc:.2f}%")

starting training
Epoch 1/15 | Train Loss: 0.7479 | Val Accuracy: 88.13%
Epoch 2/15 | Train Loss: 0.3553 | Val Accuracy: 93.41%
Epoch 3/15 | Train Loss: 0.2635 | Val Accuracy: 94.65%
Epoch 4/15 | Train Loss: 0.2059 | Val Accuracy: 93.51%
Epoch 5/15 | Train Loss: 0.1619 | Val Accuracy: 95.69%
Epoch 6/15 | Train Loss: 0.1402 | Val Accuracy: 94.50%
Epoch 7/15 | Train Loss: 0.1100 | Val Accuracy: 95.47%
Epoch 8/15 | Train Loss: 0.0987 | Val Accuracy: 97.50%
Epoch 9/15 | Train Loss: 0.1000 | Val Accuracy: 95.52%
Epoch 10/15 | Train Loss: 0.0715 | Val Accuracy: 98.04%
Epoch 11/15 | Train Loss: 0.0630 | Val Accuracy: 98.09%
Epoch 12/15 | Train Loss: 0.0709 | Val Accuracy: 97.55%
Epoch 13/15 | Train Loss: 0.0622 | Val Accuracy: 97.87%
Epoch 14/15 | Train Loss: 0.0506 | Val Accuracy: 96.90%
Epoch 15/15 | Train Loss: 0.0533 | Val Accuracy: 98.06%


In [55]:
MODEL_PATH = '../models/model_v2.pth'
RESULTS_PATH = '../results/metrics_v2.json'

# Persist only the trained weights (state_dict), not the pickled module.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
print('model saved')

# `val_acc` is the validation accuracy (in percent) from the final training
# epoch. The previous `accuracy` name was never defined in this notebook and
# only resolved through leftover kernel state, so it failed on a fresh kernel.
metrics = {
    "accuracy": val_acc,
    "dataset": "PlantVillage(emmarex dataset 1)",
    "model": "resnet18",  # NOTE(review): hard-coded; confirm it matches models.model_v1.get_model
    "num_classes": len(full_dataset.classes),
}

# Create results/ if needed so the write does not fail on a fresh checkout.
os.makedirs(os.path.dirname(RESULTS_PATH), exist_ok=True)
with open(RESULTS_PATH, 'w') as f:
    json.dump(metrics, f, indent=4)
print("saved json")

model saved
saved json
