In [1]:
!git clone https://github.com/pratikkayal/PlantDoc-Dataset.git

Cloning into 'PlantDoc-Dataset'...
remote: Enumerating objects: 2670, done.[K
remote: Counting objects: 100% (42/42), done.[K
remote: Compressing objects: 100% (20/20), done.[K
remote: Total 2670 (delta 23), reused 41 (delta 22), pack-reused 2628[K
Receiving objects: 100% (2670/2670), 932.92 MiB | 64.59 MiB/s, done.
Resolving deltas: 100% (24/24), done.
Updating files: 100% (2581/2581), done.


In [72]:
!git clone https://github.com/EliSchwartz/imagenet-sample-images.git

Cloning into 'imagenet-sample-images'...
remote: Enumerating objects: 1012, done.[K
remote: Counting objects: 100% (10/10), done.[K
remote: Compressing objects: 100% (8/8), done.[K
remote: Total 1012 (delta 3), reused 5 (delta 2), pack-reused 1002[K
Receiving objects: 100% (1012/1012), 103.84 MiB | 58.81 MiB/s, done.
Resolving deltas: 100% (3/3), done.
Updating files: 100% (1002/1002), done.


In [1]:
!echo "List of Intel GPUs available on the system:"
!xpu-smi  discovery 2> /dev/null
!echo "Intel Xeon CPU used by this notebook:"
!lscpu | grep "Model name"

List of Intel GPUs available on the system:
+-----------+--------------------------------------------------------------------------------------+
| Device ID | Device Information                                                                   |
+-----------+--------------------------------------------------------------------------------------+
| 0         | Device Name: Intel(R) Data Center GPU Max 1100                                       |
|           | Vendor Name: Intel(R) Corporation                                                    |
|           | UUID: 00000000-0000-0029-0000-002f0bda8086                                           |
|           | PCI BDF Address: 0000:29:00.0                                                        |
|           | DRM Device: /dev/dri/card1                                                           |
|           | Function Type: physical                                                              |
+-----------+----------------------------------

In [23]:
import torch
import torch.nn as nn
import torchvision
torchvision.disable_beta_transforms_warning()
import torchvision.transforms.v2
from torchvision import transforms
import intel_extension_for_pytorch as ipex
from PIL import Image
from tqdm.auto import tqdm
import numpy as np
import os, random

In [24]:
IMG_EXTENSIONS = [".png", ".jpg", ".jpeg"]

class DiseasedPlants(torch.utils.data.Dataset):
    def __init__(self, root_dir, split, size, disease_list = None, transform = None, pad_dir = None, pad_num = None):
        assert split in ['train', 'test', 'val']
        self.split = split
        self.dir = os.path.join(root_dir, split)
        self.size = size
        self.pad_dir = pad_dir
        self.num_disease = 0
        self.num_clean = 0
        
        # where folders that match with disease_list are labeled as 1 and everything else is 0
        self.disease_list = disease_list if not disease_list is None else ["spot", "rot", "blight", "virus", "rust",
                                                                           "mold", "spider", "scab", "bacterial", "mildew"]
        self.transform = transform if not transform is None else transforms.Compose([ 
            transforms.v2.Resize((size,size)), 
            transforms.v2.ToImageTensor(), 
            transforms.v2.ConvertImageDtype(),
            transforms.v2.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])

        self.files_labels = []
        for directory in os.listdir(self.dir):
            file_path = os.path.join(self.dir, directory)
            for img in os.listdir(file_path):
                img_path = os.path.join(file_path, img)
                if any(extension in img_path.lower() for extension in IMG_EXTENSIONS):
                    if any(diseases in directory for diseases in self.disease_list):
                        self.files_labels.append((img_path, 1))
                        self.num_disease += 1
                    else:
                        self.files_labels.append((img_path, 0))
                        self.num_clean += 1
                    
        self.pad_num = max(0, self.num_disease - self.num_clean) if pad_num is None else pad_num
        if not self.pad_dir is None and self.pad_num > 0:
            for image in np.random.choice(os.listdir(self.pad_dir), size=self.pad_num, replace=False):
                if any(extension in image.lower() for extension in IMG_EXTENSIONS):
                    self.files_labels.append((os.path.join(self.pad_dir, image), 0))
                    self.num_clean += 1

    def __len__(self):
        return len(self.files_labels)

    def __getitem__(self, idx):
        image = self.transform(Image.open(self.files_labels[idx][0]).convert('RGB'))
        return image, self.files_labels[idx][1]

In [25]:
BATCH_SIZE = 64
IMAGE_SIZE = 128
DEVICE = "xpu"
train_dataset = DiseasedPlants("./PlantDoc-Dataset/", "train", IMAGE_SIZE, pad_dir="./imagenet-sample-images")
test_dataset = DiseasedPlants("./PlantDoc-Dataset/", "test", IMAGE_SIZE, pad_dir="./imagenet-sample-images")
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [26]:
def train(model, optimizer, criterion, num_epochs=10, test_interval=1, save_dir="./checkpoints"):
    model.to(DEVICE)
    best_acc = 0
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    for epoch in range(num_epochs):
        model.train()
        with tqdm(total=len(train_loader),
                  desc=f'Epoch {epoch + 1}/{num_epochs}',
                  position=0,
                  leave=True) as pbar:
            for image, label in train_loader:
                image = image.to(DEVICE)
                label = label.to(torch.float32).to(DEVICE)
                pred = model(image)
                loss = criterion(pred.squeeze(), label)
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

                # update progress bar
                pbar.update(1)
                pbar.set_postfix(loss=loss.item())
            
        if epoch % test_interval == 0:
            model.eval()
            total_acc = []
            with torch.no_grad():
                with tqdm(total=len(test_loader),
                  desc=f'Testing: ',
                  position=0,
                  leave=True) as pbar:
                    for image, label in test_loader:
                        image = image.to(DEVICE)
                        label = label.to(DEVICE)
                        pred = model(image)
                        total_acc.append((torch.mean(((pred > .5).int() == label).float())).item())
                        pbar.update(1)
                        pbar.set_postfix(acc=total_acc[-1])
            final_acc = sum(total_acc)/ len(total_acc)
            if final_acc > best_acc:
                best_acc = final_acc
                torch.save(model.state_dict(), os.path.join(save_dir, f"epoch{epoch+1}.pth"))
            print(f"Epoch {epoch + 1} got {final_acc:.3f} accuracy on test")

In [27]:
# mobile_net = torchvision.models.mobilenet_v2(weights="MobileNet_V2_Weights.DEFAULT")
mobile_net = torchvision.models.efficientnet_v2_s(weights="EfficientNet_V2_S_Weights.IMAGENET1K_V1")
mobile_net.classifier = torch.nn.Sequential(
    torch.nn.Dropout(p=0.2, inplace=False),
    torch.nn.Linear(in_features=1280, out_features=1, bias=True),
    torch.nn.Sigmoid()
)
mobile_net.classifier.requires_grad = True
criterion = torch.nn.BCELoss()
optimizer = torch.optim.NAdam(mobile_net.classifier.parameters(), lr=.0005, weight_decay=0)

In [None]:
train(mobile_net, optimizer, criterion)

Epoch 1/10:   0%|          | 0/47 [00:00<?, ?it/s]

Testing:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1 got 0.498 accuracy on test


Epoch 2/10:   0%|          | 0/47 [00:00<?, ?it/s]

Testing:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 2 got 0.506 accuracy on test


Epoch 3/10:   0%|          | 0/47 [00:00<?, ?it/s]