In [1]:
import torch
import torch.nn as nn
import warnings
warnings.filterwarnings("ignore")
import torch.optim as optim
from torchvision import transforms, datasets, models

# Preprocessing applied to every image: resize to 224x224, convert to a
# tensor, then normalize each channel with the mean/std values given below.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# ImageFolder derives the class labels from the sub-directory names under
# 'data/train'; the loader yields shuffled mini-batches of 32 samples.
train_dataset = datasets.ImageFolder(root='data/train', transform=transform)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
)


In [2]:
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained ResNet-18 whose output is fed through a small two-layer head
# (1000 -> 512 -> 38) ending in log-probabilities.
# NOTE(review): 38 is presumably the number of class folders in the training
# set -- confirm against len(train_dataset.classes).
backbone = models.resnet18(pretrained=True)
model = nn.Sequential(
    backbone,
    nn.Linear(1000, 512),
    nn.ReLU(),
    nn.Linear(512, 38),
    nn.LogSoftmax(dim=1),
).to(device)

# NLLLoss consumes the log-probabilities produced by the LogSoftmax layer.
criterion = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
num_epochs = 3


In [3]:
# Train `model` for `num_epochs` passes over `train_loader`, saving a
# state_dict checkpoint and printing the mean batch loss after every epoch.
for epoch in range(num_epochs):
    # Explicitly switch to training mode. The inference cell calls .eval() on
    # this same model object, so without this a re-run of the cell would
    # train with BatchNorm layers stuck in evaluation mode.
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Checkpoint names use the zero-based epoch index (epoch0, epoch1, ...),
    # while the progress line below is one-based.
    torch.save(model.state_dict(), f'plant_disease_model_epoch{epoch}.pth')
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(train_loader):.4f}")

Epoch [1/3], Loss: 0.4357
Epoch [2/3], Loss: 0.2056
Epoch [3/3], Loss: 0.1424


In [10]:
import pickle
from PIL import Image

# Save the model architecture and weights as a single pickled object.
# NOTE(review): unpickling executes arbitrary code, so this file must only be
# loaded from a trusted source; the per-epoch state_dict checkpoints are the
# safer artifact to share.
with open('plant_disease_model.pkl', 'wb') as f:
    pickle.dump(model, f)

# `model_loaded` is an alias of `model`, not an independent copy: restoring
# the checkpoint and calling eval() below therefore also affects `model`.
model_loaded = model
# map_location lets a checkpoint written on a GPU be restored on a CPU-only
# machine (and vice versa) instead of failing during deserialization.
model_loaded.load_state_dict(
    torch.load('plant_disease_model_epoch2.pth', map_location=device)
)
model_loaded.eval()

# Load and preprocess the image
img_path = 'data/train/Tomato___Early_blight/0e03c87e-b43f-4cfe-a837-71306c68f4c0___RS_Erly.B 7733.JPG'
# Open inside a context manager so the file handle is released; convert()
# returns a new in-memory image that remains valid after the file is closed.
with Image.open(img_path) as img_file:
    img = img_file.convert('RGB')
# unsqueeze(0) adds the leading batch dimension before the forward pass.
img_tensor = transform(img).unsqueeze(0).to(device)

# Predict
classes = train_dataset.classes
with torch.no_grad():
    output = model_loaded(img_tensor)
    pred = output.argmax(dim=1)
pred_index = pred.item()
print(f"Predicted class: {pred_index}")
print(f"Predicted class name: {classes[pred_index]}")

Predicted class: 6
Predicted class name: Tomato___Early_blight
