In [2]:
import torch, torchvision
import cv2
import matplotlib.pyplot as plt

from tqdm import tqdm

In [9]:
import os

MODELS = "../models/"

print(os.path.exists(MODELS))


True


In [4]:
device = None
device_count = 0
cuda_available = torch.cuda.is_available()
mps_available = torch.backends.mps.is_available()

if cuda_available:
    device = torch.device("cuda")
    device_count = torch.cuda.device_count()
elif mps_available:
    device = torch.device("mps")
    device_count = torch.mps.device_count()
else:
    device = torch.device("cpu")
    device_count = torch.cpu.device_count()

print(f"Using device: {device}:{device_count}")


Using device: mps:1


In [5]:
# load plantseg dataset

PLANT_VILLAGE = "../data/raw/Plant_leave_diseases_dataset_with_augmentation"

transforms = torchvision.transforms.Compose([
    # Convert to tensor and normalize using ImageNet stats
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
    # Resize images to consistent size
    torchvision.transforms.Resize((224, 224)),
    # Random augmentations for training
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.RandomVerticalFlip(p=0.5),
    torchvision.transforms.RandomRotation(degrees=30),
    torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    # Random crop after resize for more robustness
    torchvision.transforms.RandomResizedCrop(224, scale=(0.8, 1.0))
])

plant_village = torchvision.datasets.ImageFolder(PLANT_VILLAGE, transform=transforms)

# plant_village_classes = plant_village.classes
# plant_village_samples = plant_village.samples

In [10]:
# model init

# split dataset into train, test, validation sets (80-10-10)
dataset_size = len(plant_village)
train_size = int(0.8 * dataset_size)
test_size = int(0.1 * dataset_size) 
val_size = dataset_size - train_size - test_size

train, test, val = torch.utils.data.random_split(
    plant_village, 
    [train_size, test_size, val_size]
)

BATCH_SIZE = 32
NUM_WORKERS = 2

train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(val, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# print(f"Number of classes in train dataset: {len(train.dataset.classes)}")
# print(f"Number of classes in train_loader dataset: {len(train_loader.dataset.dataset.classes)}")


# print(f"Total dataset size: {dataset_size}")
# print(f"Training set size: {len(train)}")
# print(f"Test set size: {len(test)}")
# print(f"Validation set size: {len(val)}")


# load preloaded resnet 18 cnn
model = torchvision.models.resnet18()
# print(model)

# adding the fc layer to the model
num_classes = len(train.dataset.classes)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

model = model.to(device)

# hyper parameters
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# train model
epochs = 10

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    with tqdm(train_loader, unit='batch', desc=f"Epoch {epoch+1}/{epochs}") as tepoch:
        for i, (inputs, labels) in enumerate(tepoch):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # backward pass
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # update the progress bar with the loss and accuracy
            tepoch.set_postfix(loss=running_loss / (i + 1), accuracy=100 * correct / total)

    # calculate accuracy
    accuracy = 100 * (correct / total)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {accuracy:.2f}%")

    # validation
    model.eval()
    with torch.no_grad():
        valid_loss = 0.0
        valid_correct = 0
        valid_total = 0

        with tqdm(val_loader, unit='batch', desc="Validating") as vepoch:
            for (inputs, labels) in vepoch:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                valid_loss += loss.item()

                _, predicted = torch.max(outputs, 1)
                valid_correct += (predicted == labels).sum().item()
                valid_total += labels.size(0)

                vepoch.set_postfix(loss=valid_loss / (valid_total / 64), accuracy=100 * valid_correct / valid_total)

        valid_accuracy = 100 * (valid_correct / valid_total)
        print(f"Validation Loss: {valid_loss/len(val_loader):.4f}, Accuracy: {valid_accuracy:.2f}%")
        
# test the model
model.eval()
with torch.no_grad():
    test_correct = 0
    test_total = 0
    test_loss = 0.0
    for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            test_correct += (predicted == labels).sum().item()
            test_total += labels.size(0)

    test_accuracy = 100 * (test_correct / test_total)
    print(f"Test Accuracy: {test_accuracy:.2f}%")


torch.save(model.state_dict(), MODELS)


Epoch 1/10: 100%|██████████| 1538/1538 [03:58<00:00,  6.45batch/s, accuracy=49.7, loss=1.72]


Epoch [1/10], Loss: 1.7195, Accuracy: 49.73%


Validating: 100%|██████████| 193/193 [00:35<00:00,  5.48batch/s, accuracy=63.5, loss=2.37]


Validation Loss: 1.1780, Accuracy: 63.54%


Epoch 2/10: 100%|██████████| 1538/1538 [03:52<00:00,  6.63batch/s, accuracy=70.1, loss=0.967]


Epoch [2/10], Loss: 0.9674, Accuracy: 70.06%


Validating: 100%|██████████| 193/193 [00:35<00:00,  5.51batch/s, accuracy=70.8, loss=2]   


Validation Loss: 0.9976, Accuracy: 70.83%


Epoch 3/10: 100%|██████████| 1538/1538 [03:53<00:00,  6.57batch/s, accuracy=77.6, loss=0.699]


Epoch [3/10], Loss: 0.6991, Accuracy: 77.58%


Validating: 100%|██████████| 193/193 [00:34<00:00,  5.60batch/s, accuracy=77.2, loss=1.44]


Validation Loss: 0.7180, Accuracy: 77.24%


Epoch 4/10: 100%|██████████| 1538/1538 [03:52<00:00,  6.60batch/s, accuracy=82.1, loss=0.557]


Epoch [4/10], Loss: 0.5572, Accuracy: 82.09%


Validating: 100%|██████████| 193/193 [00:36<00:00,  5.33batch/s, accuracy=83.8, loss=1.03]


Validation Loss: 0.5141, Accuracy: 83.79%


Epoch 5/10: 100%|██████████| 1538/1538 [03:56<00:00,  6.50batch/s, accuracy=84.7, loss=0.469]


Epoch [5/10], Loss: 0.4691, Accuracy: 84.72%


Validating: 100%|██████████| 193/193 [00:34<00:00,  5.62batch/s, accuracy=87.1, loss=0.783]


Validation Loss: 0.3898, Accuracy: 87.06%


Epoch 6/10: 100%|██████████| 1538/1538 [03:51<00:00,  6.63batch/s, accuracy=87, loss=0.397]  


Epoch [6/10], Loss: 0.3969, Accuracy: 87.00%


Validating: 100%|██████████| 193/193 [00:33<00:00,  5.70batch/s, accuracy=87.9, loss=0.787]


Validation Loss: 0.3921, Accuracy: 87.93%


Epoch 7/10: 100%|██████████| 1538/1538 [03:52<00:00,  6.61batch/s, accuracy=88.5, loss=0.347]


Epoch [7/10], Loss: 0.3471, Accuracy: 88.51%


Validating: 100%|██████████| 193/193 [00:33<00:00,  5.75batch/s, accuracy=90.4, loss=0.605]


Validation Loss: 0.3010, Accuracy: 90.39%


Epoch 8/10: 100%|██████████| 1538/1538 [03:52<00:00,  6.61batch/s, accuracy=89.8, loss=0.311]


Epoch [8/10], Loss: 0.3107, Accuracy: 89.77%


Validating: 100%|██████████| 193/193 [00:33<00:00,  5.74batch/s, accuracy=90.4, loss=0.616]


Validation Loss: 0.3066, Accuracy: 90.36%


Epoch 9/10: 100%|██████████| 1538/1538 [03:50<00:00,  6.68batch/s, accuracy=91.2, loss=0.276]


Epoch [9/10], Loss: 0.2756, Accuracy: 91.15%


Validating: 100%|██████████| 193/193 [00:34<00:00,  5.65batch/s, accuracy=91.1, loss=0.53] 


Validation Loss: 0.2639, Accuracy: 91.06%


Epoch 10/10: 100%|██████████| 1538/1538 [03:50<00:00,  6.66batch/s, accuracy=91.5, loss=0.259]


Epoch [10/10], Loss: 0.2593, Accuracy: 91.49%


Validating: 100%|██████████| 193/193 [00:34<00:00,  5.61batch/s, accuracy=92.2, loss=0.491]

Validation Loss: 0.2445, Accuracy: 92.18%





Test Accuracy: 92.16%


RuntimeError: [enforce fail at inline_container.cc:642] . invalid file name: ../models/

In [11]:
# Define the directory and filename
model_dir = '../models/'
os.makedirs(model_dir, exist_ok=True)  # Make sure the directory exists

# Define the model save path
model_save_path = os.path.join(model_dir, 'modelv1_1.pth')  # Adding a filename with extension

# Save the model state_dict
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}.")


Model saved to ../models/modelv1_1.pth.


In [18]:
checkpoint_dir = model_save_path + '/checkpoints/'
checkpoint_fname = f"modelv1_1_chkpnt.pth"

chkpnt_path = "../models/checkpoints/modelv1_1_checkpoint.pth"

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'criterion_state_dict': criterion.state_dict(),
    'epoch': epoch,
    'loss': loss,
}, chkpnt_path)
