In [2]:
# from google.colab import drive
import json
import os
from os.path import join

import numpy as np
import torch
from PIL import Image
from torch import nn, optim, cuda
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.models import ResNet50_Weights, resnet50

In [None]:
# drive.mount("/content/drive")

In [27]:
# Paths: raw image splits live on the synced laptop folder, label maps next to the notebook.
data_dir = "/content/drive/Othercomputers/My Laptop"
notebook_dir = "/content/drive/MyDrive/FruitVegClassification"

# Seed the torch (CPU + CUDA) and numpy RNGs so augmentation and shuffling are repeatable.
# Note: cudnn.benchmark below still allows nondeterministic kernel selection on GPU.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# JSON is UTF-8 by spec; be explicit so non-ASCII label names decode identically on any locale.
with open(join(notebook_dir, "index_to_label.json"), encoding="utf-8") as itl:
    # NOTE(review): JSON object keys are always strings, so lookups presumably need
    # index_to_label[str(i)] rather than index_to_label[i] — confirm at the call sites.
    index_to_label = json.load(itl)

with open(join(notebook_dir, "label_to_index.json"), encoding="utf-8") as lti:
    label_to_index = json.load(lti)

device = "cuda" if cuda.is_available() else "cpu"
print("Using device", device)

# Input sizes are fixed (224x224), so letting cuDNN benchmark and cache the fastest kernels is safe.
cudnn.benchmark = True

Using device cuda


In [28]:
# ImageNet channel statistics — the pretrained ResNet-50 weights expect inputs normalized this way.
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Training-time augmentation.
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    # `scale` is the fraction of the (already 224x224) image area to crop. Values above 1.0
    # request a crop larger than the image; torchvision rejects those attempts and, after
    # repeated failures, falls back to a centre crop. The old (0.85, 1.15) range therefore
    # wasted attempts and biased towards the fallback, so it is capped at 1.0.
    transforms.RandomResizedCrop(224, scale=(0.85, 1.0)),
    transforms.RandomAffine(0, shear=15, translate=(0.2, 0.2)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

# Deterministic preprocessing for validation/test: resize + normalize only.
val_test_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std)
])

def pil_image_loader(path):
    """Load the image file at ``path`` and return it as a 3-channel RGB PIL image.

    ``convert("RGB")`` forces the pixel data to be read and returns a new image,
    so the result stays usable after the underlying file handle is closed.
    """
    with open(path, "rb") as handle:
        rgb_image = Image.open(handle).convert("RGB")
    return rgb_image

In [30]:
# Settings shared by all three loaders. Pinned host memory only helps (and only avoids a
# warning) when batches are copied to a CUDA device.
loader_kwargs = dict(batch_size=32, num_workers=2, pin_memory=(device == "cuda"))

train_dataset = datasets.ImageFolder(join(data_dir, "train"), transform=train_transforms, loader=pil_image_loader)
test_dataset = datasets.ImageFolder(join(data_dir, "test"), transform=val_test_transforms, loader=pil_image_loader)
val_dataset = datasets.ImageFolder(join(data_dir, "validation"), transform=val_test_transforms, loader=pil_image_loader)

# ImageFolder derives class indices from each split's own sorted sub-folder names, so a split
# that is missing (or has an extra) class folder would silently shift every label index.
assert train_dataset.classes == val_dataset.classes == test_dataset.classes, \
    "train/validation/test class folders differ"

# NOTE(review): assumes label_to_index.json maps folder name -> integer index. If it disagrees
# with ImageFolder's ordering, predictions decoded via the JSON maps would be mislabeled.
if label_to_index != train_dataset.class_to_idx:
    print("Warning: label_to_index.json does not match ImageFolder class_to_idx")

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [None]:
# ImageNet-pretrained ResNet-50. `weights=` replaces the deprecated `pretrained=True` flag
# (torchvision >= 0.13); IMAGENET1K_V1 is the same checkpoint `pretrained=True` loaded.
base_model = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)



In [33]:
# Input width of ResNet-50's original classifier layer. It is read here, before `fc` is
# replaced, because it sizes the first Linear layer of the new classification head.
num_ftrs = base_model.fc.in_features

In [34]:
# Partial fine-tuning: parameters registered before the first "layer4" parameter stay frozen.
# Every parameter from that point on trains: all of layer4 plus the original fc, which is
# replaced later anyway. This relies on named_parameters() yielding modules in registration
# order (conv1, bn1, layer1..layer4, fc).
unfreeze = False
for name, param in base_model.named_parameters():
    # Latches to True at the first layer4 parameter and stays True afterwards.
    unfreeze = unfreeze or "layer4" in name
    param.requires_grad = unfreeze

In [35]:
# Number of output classes, taken from the training folders rather than hard-coded (was 36).
# A mismatch would either trigger a CUDA device-side assert in CrossEntropyLoss (too few
# outputs) or train dead logits (too many).
NUM_CLASSES = len(train_dataset.classes)

# Replace ResNet-50's single Linear classifier with a small MLP head. The new layers are
# created with requires_grad=True, so they train alongside the unfrozen layer4.
base_model.fc = nn.Sequential(
    nn.Linear(num_ftrs, 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(128, NUM_CLASSES)
)

In [36]:
# Optimizer / scheduler hyperparameters.
# NOTE(review): 1e-6 is a very small starting LR for a freshly initialized head with Adam;
# something like 1e-4 is more typical — kept as-is to preserve the original experiment.
LEARNING_RATE = 1e-6
# Must be strictly below LEARNING_RATE: previously both were 1e-6, so ReduceLROnPlateau
# clamped every reduction back to the starting LR and never changed anything.
MIN_LR = 1e-8
assert MIN_LR < LEARNING_RATE, "min_lr >= lr makes ReduceLROnPlateau a no-op"

criterion = nn.CrossEntropyLoss()
# Only optimize parameters left trainable by the freezing step (layer4 + new head).
optimizer = optim.Adam((p for p in base_model.parameters() if p.requires_grad), lr=LEARNING_RATE)
# Halve the LR after 3 epochs without validation-loss improvement.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3, min_lr=MIN_LR)

# Early-stopping state, consumed by the training loop.
best_val_loss = np.inf
patience = 10   # epochs without val-loss improvement before stopping
counter = 0     # epochs since the last improvement
num_epochs = 100

In [None]:
base_model.to(device)

for epoch in range(num_epochs):
    # ---- Training pass ----
    base_model.train()
    train_losses = []
    for img, label in train_loader:
        # non_blocking lets pinned-memory batches copy to the GPU asynchronously.
        img = img.to(device, non_blocking=True)
        label = label.to(device, non_blocking=True)
        optimizer.zero_grad()
        output = base_model(img)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())

    # ---- Validation pass (eval mode: dropout off, BatchNorm uses running stats) ----
    base_model.eval()
    val_losses = []
    preds, targets = [], []
    with torch.no_grad():
        for img, label in val_loader:
            img = img.to(device, non_blocking=True)
            label = label.to(device, non_blocking=True)
            output = base_model(img)
            loss = criterion(output, label)
            val_losses.append(loss.item())
            pred_labels = torch.argmax(output, dim=1)
            preds.extend(pred_labels.cpu().numpy())
            targets.extend(label.cpu().numpy())

    val_loss = np.mean(val_losses)
    # preds/targets were previously collected but never used; report accuracy too.
    val_acc = np.mean(np.asarray(preds) == np.asarray(targets))
    current_lr = optimizer.param_groups[0]["lr"]
    print(
        f"Epoch {epoch+1}: Train Loss = {np.mean(train_losses):.4f}, "
        f"Val Loss = {val_loss:.4f}, Val Acc = {val_acc:.4f}, LR = {current_lr:.2e}"
    )
    scheduler.step(val_loss)

    # Checkpoint on improvement; stop after `patience` epochs without one.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(base_model.state_dict(), "BestResNet50Model.pth")
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("Early stopping.")
            break

In [None]:
# The training loop already saved the lowest-val-loss weights to BestResNet50Model.pth.
# Writing the current (last-epoch) weights under that name would overwrite the best
# checkpoint with a model up to `patience` epochs past its best. Instead, keep the final
# weights under a separate name and restore the best ones for any later evaluation.
torch.save(base_model.state_dict(), "FinalResNet50Model.pth")

best_checkpoint_path = "BestResNet50Model.pth"
if os.path.exists(best_checkpoint_path):
    base_model.load_state_dict(torch.load(best_checkpoint_path, map_location=device))