<a href="https://colab.research.google.com/github/ben900926/Plant-seedling-classification/blob/main/AI_final_project_resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import sys
import torch
import torchvision
import torch.nn as nn
from glob import glob
from torchvision import models
import torch.utils.data as data 
from torchvision import transforms
from torchvision.datasets.folder import DatasetFolder
from PIL import Image
from tqdm import tqdm
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import numpy as np
import pandas as pd

In [None]:
from google.colab import drive

# Mount Google Drive so the dataset folder below is reachable under /content/drive.
drive.mount('/content/drive')
root = '/content/drive/MyDrive/AI_FinalProject/plant-seedlings-classification/train'

val_size = 0.2   # share of each class's images held out for validation (0.2 -> 80/20 split)
batch_size = 64
# Asking for more DataLoader workers than the VM has CPUs makes PyTorch emit a
# "suggested max number of worker" warning and can slow loading down, so cap
# the original value of 8 at the number of CPUs actually available.
num_workers = min(8, os.cpu_count() or 1)
num_epochs = 25

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# VGG Baseline
class VGG(nn.Module):
    """Pretrained VGG-11 whose last classifier layer is swapped for a new head.

    Args:
        num_classes: number of outputs of the replacement ``nn.Linear`` layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.vgg11(pretrained=True)
        # Index 6 of the classifier is replaced with a freshly initialised
        # 4096 -> num_classes linear layer.
        backbone.classifier[6] = nn.Linear(4096, num_classes)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        out = self.model(x)
        return out
    
# ResNet Baseline, Resnet18, Resnet34
class ResNet(nn.Module):
    """Pretrained ResNet-34 with a new classification head and a frozen stem.

    The parameters of ``conv1``, ``bn1`` and ``layer1`` get
    ``requires_grad = False``; all other parameters stay trainable.

    Args:
        num_classes: number of outputs of the replacement ``fc`` layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.model = models.resnet34(pretrained=True)
        # Size the new head from the layer it replaces rather than hard-coding
        # 512, so switching the backbone to a variant with a different feature
        # width does not produce a shape mismatch.
        self.model.fc = nn.Linear(in_features=self.model.fc.in_features,
                                  out_features=num_classes)

        # Turn off gradients for the earliest layers.
        for module in ('conv1', 'bn1', 'layer1'):
            for param in getattr(self.model, module).parameters():
                param.requires_grad = False

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)
             
#Densenet Baseline, Densenet 121
class Densenet(nn.Module):
    """Pretrained DenseNet-121 with its ``classifier`` layer replaced.

    Args:
        num_classes: number of outputs of the replacement ``nn.Linear`` layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.densenet121(pretrained=True)
        # New 1024 -> num_classes head in place of the pretrained classifier.
        backbone.classifier = nn.Linear(in_features=1024, out_features=num_classes)
        self.model = backbone

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        out = self.model(x)
        return out
    
# Network trained and evaluated by the cells below.
# NOTE(review): 12 is presumably the number of class folders under `root`;
# `classes` is only built in a later cell, so the two are not cross-checked here
# - TODO confirm they match.
model = VGG(num_classes=12)

Downloading: "https://download.pytorch.org/models/vgg11-8a719046.pth" to /root/.cache/torch/hub/checkpoints/vgg11-8a719046.pth


  0%|          | 0.00/507M [00:00<?, ?B/s]

In [None]:
# Every PNG file directly inside the test folder; these images carry no labels.
test_paths = glob(
    '/content/drive/MyDrive/AI_FinalProject/plant-seedlings-classification/test/*.png'
)

In [None]:
# Empty accumulators for the train/validation split: image file paths and the
# integer class index that goes with each path.
train_paths, train_labels = [], []
val_paths, val_labels = [], []

In [None]:
# Start from empty lists so that re-running this cell rebuilds the split
# instead of appending the same files a second time.
train_paths, train_labels = [], []
val_paths, val_labels = [], []

# glob() gives no ordering guarantee; sorting fixes the class-index mapping
# (and the split below) across sessions. Only directories count as classes so
# a stray file under `root` cannot shift the indices.
class_paths = sorted(path for path in glob(root + '/*') if os.path.isdir(path))

# Class name = folder name; its position in this list is the integer label.
classes = [os.path.basename(path) for path in class_paths]

for i, cpath in enumerate(class_paths):
    paths = sorted(glob(cpath + '/*.png'))
    # Per-class split: the first (1 - val_size) share trains, the rest validates.
    train_split = int(len(paths) * (1 - val_size))

    train_paths.extend(paths[:train_split])
    train_labels.extend([i] * train_split)

    val_paths.extend(paths[train_split:])
    val_labels.extend([i] * (len(paths) - train_split))

In [None]:
class Dataset(data.Dataset):
    """Image dataset backed by a list of file paths.

    Args:
        img_paths: sequence of image file paths.
        img_labels: optional sequence of labels, one per path. When omitted,
            items carry their own index instead of a label.
        transform: optional callable applied to the loaded RGB PIL image.

    Raises:
        ValueError: if ``img_labels`` is given but its length differs from
            ``img_paths``.
    """

    def __init__(self, img_paths, img_labels=None, transform=None):
        # Catch a mismatch up front rather than as an IndexError mid-epoch
        # (or as silently ignored extra labels).
        if img_labels is not None and len(img_labels) != len(img_paths):
            raise ValueError(
                "img_paths and img_labels differ in length: %d vs %d"
                % (len(img_paths), len(img_labels))
            )
        self.transform = transform
        self.img_paths = img_paths
        self.img_labels = img_labels

    def __getitem__(self, idx):
        """Return ``(image, label)``, or ``(image, idx)`` when there are no labels."""
        img_path = self.img_paths[idx]
        # convert() returns a new, fully loaded image, so the file handle can
        # be closed as soon as the conversion is done instead of lingering
        # until garbage collection.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')

        if self.transform:
            img = self.transform(img)

        if self.img_labels is not None:
            label = self.img_labels[idx]
            return img, label
        else:
            # Unlabelled data: hand back the index so predictions can be
            # matched to their file path afterwards.
            return img, idx

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.img_paths)

# Data augmentation and normalization for training.
# Pool of stronger perturbations; RandomChoice below draws exactly one of them.
transform_options = [
    transforms.ColorJitter(brightness=0.5, contrast=0.5, hue=0.5),
    transforms.RandomRotation(degrees=[-15, 15]),
    transforms.GaussianBlur(kernel_size=3),
    transforms.RandomAffine(0, shear=20),
]

train_transforms = transforms.Compose([
    # Resize slightly larger than the network input, then rotate/crop/flip.
    transforms.Resize((256, 256)),
    transforms.RandomRotation(degrees=(-8, 8)),
    transforms.RandomCrop((224, 224)),
    transforms.RandomHorizontalFlip(),
    # With probability 0.9, apply one randomly chosen entry of transform_options.
    transforms.RandomApply(
        [transforms.RandomChoice(transform_options)],
        p=0.9,
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Evaluation pipeline: no random augmentation, same normalization as training.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Datasets: train uses the augmenting pipeline; validation and test share the
# deterministic one. The test set has no labels, so it yields (image, index).
train_dataset = Dataset(train_paths, train_labels, transform=train_transforms)
val_dataset = Dataset(val_paths, val_labels, transform=val_transforms)
test_dataset = Dataset(test_paths, transform=val_transforms)

# Loaders: only the training loader shuffles.
train_loader = data.DataLoader(
    train_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=True
)
val_loader = data.DataLoader(
    val_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)
test_loader = data.DataLoader(
    test_dataset, batch_size=batch_size, num_workers=num_workers, shuffle=False
)

  cpuset_checked))


In [None]:
# Move the network to the GPU first, then build the loss and an Adam
# optimizer (learning rate 1e-4) over all of its parameters.
model = model.cuda()
criterion = CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Train for num_epochs epochs; after each epoch, measure accuracy on the
# validation set. Accuracy = correctly classified samples / dataset size.
for epoch in range(num_epochs):
    # ---- training pass ----
    model.train()
    correct = 0

    # len(loader) is the exact number of batches; `len(dataset) // batch_size + 1`
    # over-counts by one whenever the dataset size is a multiple of batch_size.
    with tqdm(total=len(train_loader)) as progress_bar:
        for x, y in train_loader:
            x, y = x.cuda(), y.cuda()

            optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()

            # The bar shows the loss of the most recent batch only.
            progress_bar.set_postfix(loss=loss.item())
            progress_bar.update(1)

            # Accumulate a plain Python int so the counter has one type
            # regardless of how many batches were seen.
            correct += (torch.argmax(y_pred, dim=-1) == y).sum().item()

    train_acc = correct / len(train_loader.dataset)
    print("Epoch %d: train correct: %.4f" % (epoch, train_acc))

    # ---- validation pass ----
    model.eval()
    correct = 0

    # No gradients are needed for evaluation; no_grad() stops autograd from
    # recording the forward pass, saving memory and time.
    with torch.no_grad(), tqdm(total=len(val_loader)) as progress_bar:
        for x, y in val_loader:
            x, y = x.cuda(), y.cuda()

            y_pred = model(x)
            loss = criterion(y_pred, y)

            progress_bar.set_postfix(loss=loss.item())
            progress_bar.update(1)

            correct += (torch.argmax(y_pred, dim=-1) == y).sum().item()

    val_acc = correct / len(val_loader.dataset)
    print("Epoch %d: val correct: %.4f" % (epoch, val_acc))

  cpuset_checked))
100%|██████████| 48/48 [01:34<00:00,  1.97s/it, loss=0.539]


Epoch 0: train correct: 0.4882


100%|██████████| 12/12 [00:33<00:00,  2.81s/it, loss=0.108]


Epoch 0: val correct: 0.7948


100%|██████████| 48/48 [01:07<00:00,  1.41s/it, loss=0.677]


Epoch 1: train correct: 0.7811


100%|██████████| 12/12 [00:12<00:00,  1.05s/it, loss=0.0184]


Epoch 1: val correct: 0.8588


100%|██████████| 48/48 [01:08<00:00,  1.43s/it, loss=0.223]


Epoch 2: train correct: 0.8397


100%|██████████| 12/12 [00:12<00:00,  1.06s/it, loss=0.141]


Epoch 2: val correct: 0.8810


100%|██████████| 48/48 [01:10<00:00,  1.48s/it, loss=0.154]


Epoch 3: train correct: 0.8762


100%|██████████| 12/12 [00:12<00:00,  1.04s/it, loss=0.00267]


Epoch 3: val correct: 0.8967


100%|██████████| 48/48 [01:11<00:00,  1.49s/it, loss=0.299]


Epoch 4: train correct: 0.9042


100%|██████████| 12/12 [00:12<00:00,  1.07s/it, loss=0.0918]


Epoch 4: val correct: 0.8954


100%|██████████| 48/48 [01:11<00:00,  1.49s/it, loss=0.392]


Epoch 5: train correct: 0.9111


100%|██████████| 12/12 [00:12<00:00,  1.04s/it, loss=0.00307]


Epoch 5: val correct: 0.9163


100%|██████████| 48/48 [01:10<00:00,  1.48s/it, loss=0.194]


Epoch 6: train correct: 0.9286


100%|██████████| 12/12 [00:12<00:00,  1.05s/it, loss=0.144]


Epoch 6: val correct: 0.9190


100%|██████████| 48/48 [01:07<00:00,  1.40s/it, loss=0.171]


Epoch 7: train correct: 0.9391


100%|██████████| 12/12 [00:12<00:00,  1.06s/it, loss=0.0307]


Epoch 7: val correct: 0.9281


100%|██████████| 48/48 [01:10<00:00,  1.46s/it, loss=0.117]


Epoch 8: train correct: 0.9282


100%|██████████| 12/12 [00:12<00:00,  1.05s/it, loss=0.0104]


Epoch 8: val correct: 0.9320


100%|██████████| 48/48 [01:08<00:00,  1.43s/it, loss=0.0615]


Epoch 9: train correct: 0.9398


100%|██████████| 12/12 [00:12<00:00,  1.04s/it, loss=0.00241]


Epoch 9: val correct: 0.9373


100%|██████████| 48/48 [01:07<00:00,  1.41s/it, loss=0.275]


Epoch 10: train correct: 0.9546


100%|██████████| 12/12 [00:12<00:00,  1.03s/it, loss=0.0015]


Epoch 10: val correct: 0.9451


100%|██████████| 48/48 [01:08<00:00,  1.43s/it, loss=0.124]


Epoch 11: train correct: 0.9566


100%|██████████| 12/12 [00:12<00:00,  1.06s/it, loss=0.00291]


Epoch 11: val correct: 0.9386


100%|██████████| 48/48 [01:08<00:00,  1.43s/it, loss=0.0629]


Epoch 12: train correct: 0.9562


100%|██████████| 12/12 [00:12<00:00,  1.07s/it, loss=0.00325]


Epoch 12: val correct: 0.9216


100%|██████████| 48/48 [01:11<00:00,  1.49s/it, loss=0.135]


Epoch 13: train correct: 0.9635


100%|██████████| 12/12 [00:12<00:00,  1.07s/it, loss=0.00105]


Epoch 13: val correct: 0.9386


100%|██████████| 48/48 [01:09<00:00,  1.44s/it, loss=0.11]


Epoch 14: train correct: 0.9710


100%|██████████| 12/12 [00:12<00:00,  1.05s/it, loss=0.000176]


Epoch 14: val correct: 0.9190


100%|██████████| 48/48 [01:08<00:00,  1.44s/it, loss=0.0172]


Epoch 15: train correct: 0.9700


100%|██████████| 12/12 [00:12<00:00,  1.04s/it, loss=0.000516]


Epoch 15: val correct: 0.9386


100%|██████████| 48/48 [01:07<00:00,  1.41s/it, loss=0.043]


Epoch 16: train correct: 0.9668


100%|██████████| 12/12 [00:12<00:00,  1.02s/it, loss=0.0141]


Epoch 16: val correct: 0.9320


100%|██████████| 48/48 [01:09<00:00,  1.45s/it, loss=0.291]


Epoch 17: train correct: 0.9747


100%|██████████| 12/12 [00:12<00:00,  1.02s/it, loss=0.0215]


Epoch 17: val correct: 0.9412


100%|██████████| 48/48 [01:08<00:00,  1.43s/it, loss=0.00782]


Epoch 18: train correct: 0.9691


100%|██████████| 12/12 [00:12<00:00,  1.05s/it, loss=0.0283]


Epoch 18: val correct: 0.9111


100%|██████████| 48/48 [01:07<00:00,  1.40s/it, loss=0.278]


Epoch 19: train correct: 0.9737


100%|██████████| 12/12 [00:12<00:00,  1.03s/it, loss=0.0336]


Epoch 19: val correct: 0.9399


100%|██████████| 48/48 [01:09<00:00,  1.46s/it, loss=0.167]


Epoch 20: train correct: 0.9737


100%|██████████| 12/12 [00:12<00:00,  1.04s/it, loss=5.63e-5]


Epoch 20: val correct: 0.9451


100%|██████████| 48/48 [01:09<00:00,  1.44s/it, loss=0.0909]


Epoch 21: train correct: 0.9822


100%|██████████| 12/12 [00:12<00:00,  1.04s/it, loss=0.0115]


Epoch 21: val correct: 0.9451


100%|██████████| 48/48 [01:07<00:00,  1.41s/it, loss=0.00615]


Epoch 22: train correct: 0.9845


100%|██████████| 12/12 [00:12<00:00,  1.03s/it, loss=0.00508]


Epoch 22: val correct: 0.9425


100%|██████████| 48/48 [01:07<00:00,  1.41s/it, loss=0.185]


Epoch 23: train correct: 0.9855


100%|██████████| 12/12 [00:12<00:00,  1.02s/it, loss=0.00603]


Epoch 23: val correct: 0.9425


100%|██████████| 48/48 [01:06<00:00,  1.38s/it, loss=0.0216]


Epoch 24: train correct: 0.9789


100%|██████████| 12/12 [00:12<00:00,  1.03s/it, loss=0.0175]

Epoch 24: val correct: 0.9373





In [None]:
# Predict a class for every test image and write the results to a CSV with
# one row per image: file name and predicted class name.
test_label_indices = []
test_img_indices = []

# Set eval mode explicitly instead of relying on whatever mode the previous
# cell happened to leave the model in, and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    for x, img_idx in test_loader:
        x = x.cuda()
        y_pred = model(x)
        test_label_indices.extend(torch.argmax(y_pred, dim=-1).cpu().tolist())
        # The unlabelled dataset yields each sample's index, which maps the
        # prediction back to its path in test_paths.
        test_img_indices.extend(img_idx.tolist())

test_names = [os.path.basename(test_paths[idx]) for idx in test_img_indices]
test_labels = [classes[idx] for idx in test_label_indices]

out_df = pd.DataFrame({'file': test_names, 'species': test_labels})
out_df.to_csv('/content/drive/MyDrive/AI_FinalProject/VGG.csv', index=False)

  cpuset_checked))
