In [1]:
import torch, torchvision
import cv2
import matplotlib.pyplot as plt

from tqdm import tqdm

In [6]:
import os

MODELS = "../models/"

print(os.path.exists(MODELS))


True


In [3]:
device = None
device_count = 0
cuda_available = torch.cuda.is_available()
mps_available = torch.backends.mps.is_available()

if cuda_available:
    device = torch.device("cuda")
    device_count = torch.cuda.device_count()
elif mps_available:
    device = torch.device("mps")
    device_count = torch.mps.device_count()
else:
    device = torch.device("cpu")
    device_count = torch.cpu.device_count()

print(f"Using device: {device}:{device_count}")


Using device: cuda:1


In [4]:
# load plantseg dataset

PLANT_VILLAGE = "../data/raw/Plant_leave_diseases_dataset_with_augmentation"

transforms = torchvision.transforms.Compose([
    # Convert to tensor and normalize using ImageNet stats
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    ),
    # Resize images to consistent size
    torchvision.transforms.Resize((224, 224)),
    # Random augmentations for training
    torchvision.transforms.RandomHorizontalFlip(p=0.5),
    torchvision.transforms.RandomVerticalFlip(p=0.5),
    torchvision.transforms.RandomRotation(degrees=30),
    torchvision.transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    # Random crop after resize for more robustness
    torchvision.transforms.RandomResizedCrop(224, scale=(0.8, 1.0))
])

plant_village = torchvision.datasets.ImageFolder(PLANT_VILLAGE, transform=transforms)

# plant_village_classes = plant_village.classes
# plant_village_samples = plant_village.samples

In [None]:
# model init

# split dataset into train, test, validation sets (80-10-10)
dataset_size = len(plant_village)
train_size = int(0.8 * dataset_size)
test_size = int(0.1 * dataset_size) 
val_size = dataset_size - train_size - test_size

train, test, val = torch.utils.data.random_split(
    plant_village, 
    [train_size, test_size, val_size]
)

BATCH_SIZE = 32
NUM_WORKERS = 2

train_loader = torch.utils.data.DataLoader(train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = torch.utils.data.DataLoader(test, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = torch.utils.data.DataLoader(val, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

# print(f"Number of classes in train dataset: {len(train.dataset.classes)}")
# print(f"Number of classes in train_loader dataset: {len(train_loader.dataset.dataset.classes)}")


# print(f"Total dataset size: {dataset_size}")
# print(f"Training set size: {len(train)}")
# print(f"Test set size: {len(test)}")
# print(f"Validation set size: {len(val)}")


# load preloaded resnet 18 cnn
model = torchvision.models.resnet18()
# print(model)

# adding the fc layer to the model
num_classes = len(train.dataset.classes)
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

model = model.to(device)

# hyper parameters
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# train model
epochs = 10

for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    with tqdm(train_loader, unit='batch', desc=f"Epoch {epoch+1}/{epochs}") as tepoch:
        for i, (inputs, labels) in enumerate(tepoch):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()

            # forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # backward pass
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)

            # update the progress bar with the loss and accuracy
            tepoch.set_postfix(loss=running_loss / (i + 1), accuracy=100 * correct / total)

    # calculate accuracy
    accuracy = 100 * (correct / total)
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}, Accuracy: {accuracy:.2f}%")

    # validation
    model.eval()
    with torch.no_grad():
        valid_loss = 0.0
        valid_correct = 0
        valid_total = 0

        with tqdm(val_loader, unit='batch', desc="Validating") as vepoch:
            for (inputs, labels) in vepoch:
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                valid_loss += loss.item()

                _, predicted = torch.max(outputs, 1)
                valid_correct += (predicted == labels).sum().item()
                valid_total += labels.size(0)

                vepoch.set_postfix(loss=valid_loss / (valid_total / 64), accuracy=100 * valid_correct / valid_total)

        valid_accuracy = 100 * (valid_correct / valid_total)
        print(f"Validation Loss: {valid_loss/len(val_loader):.4f}, Accuracy: {valid_accuracy:.2f}%")
        
# test the model
model.eval()
with torch.no_grad():
    test_correct = 0
    test_total = 0
    test_loss = 0.0
    for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            test_correct += (predicted == labels).sum().item()
            test_total += labels.size(0)

    test_accuracy = 100 * (test_correct / test_total)
    print(f"Test Accuracy: {test_accuracy:.2f}%")


Epoch 1/10: 100%|██████████| 1538/1538 [03:59<00:00,  6.43batch/s, accuracy=50.1, loss=1.7] 


Epoch [1/10], Loss: 1.6962, Accuracy: 50.12%


Validating: 100%|██████████| 193/193 [00:28<00:00,  6.71batch/s, accuracy=52.8, loss=3.29]


Validation Loss: 1.6374, Accuracy: 52.81%


Epoch 2/10: 100%|██████████| 1538/1538 [04:13<00:00,  6.06batch/s, accuracy=70.8, loss=0.935]


Epoch [2/10], Loss: 0.9353, Accuracy: 70.80%


Validating: 100%|██████████| 193/193 [00:30<00:00,  6.31batch/s, accuracy=69.5, loss=2.01]


Validation Loss: 1.0002, Accuracy: 69.46%


Epoch 3/10: 100%|██████████| 1538/1538 [04:10<00:00,  6.14batch/s, accuracy=78.7, loss=0.678]


Epoch [3/10], Loss: 0.6776, Accuracy: 78.66%


Validating: 100%|██████████| 193/193 [00:26<00:00,  7.18batch/s, accuracy=81.3, loss=1.17]


Validation Loss: 0.5806, Accuracy: 81.25%


Epoch 4/10: 100%|██████████| 1538/1538 [04:19<00:00,  5.92batch/s, accuracy=82.5, loss=0.543]


Epoch [4/10], Loss: 0.5428, Accuracy: 82.54%


Validating: 100%|██████████| 193/193 [00:28<00:00,  6.88batch/s, accuracy=80.9, loss=1.17]


Validation Loss: 0.5839, Accuracy: 80.94%


Epoch 5/10: 100%|██████████| 1538/1538 [04:18<00:00,  5.94batch/s, accuracy=85.1, loss=0.454]


Epoch [5/10], Loss: 0.4539, Accuracy: 85.12%


Validating: 100%|██████████| 193/193 [00:28<00:00,  6.77batch/s, accuracy=86.9, loss=0.83] 


Validation Loss: 0.4133, Accuracy: 86.89%


Epoch 6/10: 100%|██████████| 1538/1538 [04:18<00:00,  5.94batch/s, accuracy=87.2, loss=0.391]


Epoch [6/10], Loss: 0.3909, Accuracy: 87.21%


Validating: 100%|██████████| 193/193 [00:30<00:00,  6.42batch/s, accuracy=88.4, loss=0.728]


Validation Loss: 0.3626, Accuracy: 88.42%


Epoch 7/10: 100%|██████████| 1538/1538 [04:19<00:00,  5.94batch/s, accuracy=88.8, loss=0.341]


Epoch [7/10], Loss: 0.3414, Accuracy: 88.81%


Validating: 100%|██████████| 193/193 [00:31<00:00,  6.07batch/s, accuracy=88.5, loss=0.691]


Validation Loss: 0.3439, Accuracy: 88.52%


Epoch 8/10: 100%|██████████| 1538/1538 [04:27<00:00,  5.75batch/s, accuracy=90, loss=0.306]  


Epoch [8/10], Loss: 0.3057, Accuracy: 90.01%


Validating: 100%|██████████| 193/193 [00:34<00:00,  5.56batch/s, accuracy=89.1, loss=0.67] 


Validation Loss: 0.3336, Accuracy: 89.11%


Epoch 9/10: 100%|██████████| 1538/1538 [04:53<00:00,  5.24batch/s, accuracy=90.9, loss=0.28] 


Epoch [9/10], Loss: 0.2798, Accuracy: 90.88%


Validating: 100%|██████████| 193/193 [00:34<00:00,  5.62batch/s, accuracy=90.1, loss=0.582]


Validation Loss: 0.2898, Accuracy: 90.15%


Epoch 10/10: 100%|██████████| 1538/1538 [04:27<00:00,  5.76batch/s, accuracy=91.6, loss=0.258]


Epoch [10/10], Loss: 0.2578, Accuracy: 91.64%


Validating: 100%|██████████| 193/193 [00:32<00:00,  5.97batch/s, accuracy=92.7, loss=0.449]

Validation Loss: 0.2236, Accuracy: 92.73%





Test Accuracy: 92.32%


RuntimeError: [enforce fail at inline_container.cc:642] . invalid file name: ../models/

In [7]:
# Define the directory and filename
model_dir = '../models/'
os.makedirs(model_dir, exist_ok=True)  # Make sure the directory exists

# Define the model save path
model_save_path = os.path.join(model_dir, 'modelv1_1.pth')  # Adding a filename with extension

# Save the model state_dict
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}.")

Model saved to ../models/modelv1_1.pth.


In [8]:
checkpoint_dir = model_save_path + '/checkpoints/'
checkpoint_fname = f"modelv1_1_chkpnt.pth"

chkpnt_path = "../models/checkpoints/modelv1_1_checkpoint.pth"

torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'criterion_state_dict': criterion.state_dict(),
    'epoch': epoch,
    'loss': loss,
}, chkpnt_path)
print(f"Model checkpoint saved to {model_save_path}.")

Model checkpoint saved to ../models/modelv1_1.pth.


In [None]:
# loading the model for some ruf inference

