<a href="https://colab.research.google.com/github/ben900926/Plant-seedling-classification/blob/main/AI_final_project_resnet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import sys
import torch
import torchvision
import torch.nn as nn
from glob import glob
from torchvision import models
import torch.utils.data as data 
from torchvision import transforms
from torchvision.datasets.folder import DatasetFolder
from PIL import Image
from tqdm import tqdm
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import numpy as np
import pandas as pd

In [None]:
# Mount Google Drive so the dataset stored there is reachable from the Colab VM.
from google.colab import drive
drive.mount('/content/drive')

# Folder holding one sub-directory of PNG images per class.
root = '/content/drive/MyDrive/AI_FinalProject/plant-seedlings-classification/train'

# --- Hyper-parameters -------------------------------------------------------
# Intended validation fraction. NOTE(review): nothing in the notebook reads this
# yet -- the split cell hard-codes 0.8 -- so keep the two in sync by hand.
val_size = 0.2
batch_size = 64
# Never request more DataLoader workers than the VM has CPUs: asking for 8 on a
# smaller machine triggers PyTorch's "suggested max number of worker" warning
# and can slow down or even freeze data loading.
num_workers = min(8, os.cpu_count() or 1)
num_epochs = 40

Mounted at /content/drive


In [None]:
# VGG baseline (VGG-11)
class VGG(nn.Module):
    """Pretrained torchvision VGG-11 with its last classifier layer replaced.

    Args:
        num_classes: number of output units of the new final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.vgg11(pretrained=True)
        # Index 6 of the classifier is swapped for a fresh 4096 -> num_classes layer.
        backbone.classifier[6] = nn.Linear(in_features=4096, out_features=num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for the batch ``x``."""
        return self.model(x)
    
# ResNet baseline (ResNet-34; the 512-wide head also fits ResNet-18)
class ResNet(nn.Module):
    """Pretrained torchvision ResNet-34 with a new ``fc`` head and a frozen stem.

    Args:
        num_classes: number of output units of the replacement ``fc`` layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.resnet34(pretrained=True)
        backbone.fc = nn.Linear(in_features=512, out_features=num_classes)
        self.model = backbone

        # Exclude the earliest layers from gradient updates.
        frozen_params = [
            param
            for name in ('conv1', 'bn1', 'layer1')
            for param in getattr(backbone, name).parameters()
        ]
        for param in frozen_params:
            param.requires_grad = False

    def forward(self, x):
        """Return the wrapped network's output for the batch ``x``."""
        return self.model(x)
             
# DenseNet baseline (DenseNet-121)
class Densenet(nn.Module):
    """Pretrained torchvision DenseNet-121 with a replaced classifier layer.

    Args:
        num_classes: number of output units of the new classifier.
    """

    def __init__(self, num_classes):
        super().__init__()
        backbone = models.densenet121(pretrained=True)
        # The new classifier maps 1024 input features to num_classes outputs.
        backbone.classifier = nn.Linear(in_features=1024, out_features=num_classes)
        self.model = backbone

    def forward(self, x):
        """Return the wrapped network's output for the batch ``x``."""
        return self.model(x)
    
# Network trained in this notebook: the ResNet wrapper with 12 output classes.
model = ResNet(12)

Downloading: "https://download.pytorch.org/models/resnet34-b627a593.pth" to /root/.cache/torch/hub/checkpoints/resnet34-b627a593.pth


  0%|          | 0.00/83.3M [00:00<?, ?B/s]

In [None]:
# Every PNG file in the test folder; these images have no labels in this notebook.
test_pattern = '/content/drive/MyDrive/AI_FinalProject/plant-seedlings-classification/test/*.png'
test_paths = glob(test_pattern)

In [None]:
# Accumulators for the train/validation split: image file paths and, in matching
# order, their integer class labels.
train_paths, val_paths = [], []
train_labels, val_labels = [], []

In [None]:
# Build a per-class train/validation split from the class sub-folders of `root`.
#
# glob() returns entries in arbitrary, filesystem-dependent order, so both the
# class folders and the images inside them are sorted: this keeps the
# class-index mapping and the split membership identical from run to run.
# Non-directory entries are skipped so a stray file cannot become a class.
class_paths = [path for path in sorted(glob(root + '/*')) if os.path.isdir(path)]

# Class names are the folder names; a class's position in this list is its label.
classes = [os.path.basename(path) for path in class_paths]

# Start from fresh lists so that re-running this cell does not append the same
# samples a second time.
train_paths, train_labels = [], []
val_paths, val_labels = [], []

# Derive the training share from the configured validation fraction instead of
# repeating the number here.
train_fraction = 1.0 - val_size

for label, cpath in enumerate(class_paths):
    paths = sorted(glob(cpath + '/*.png'))
    train_split = int(len(paths) * train_fraction)

    # First `train_split` images go to training, the remainder to validation.
    train_paths.extend(paths[:train_split])
    train_labels.extend([label] * train_split)

    val_paths.extend(paths[train_split:])
    val_labels.extend([label] * (len(paths) - train_split))

In [None]:
class Dataset(data.Dataset):
    """Image dataset backed by a list of file paths.

    Args:
        img_paths: sequence of image file paths.
        img_labels: optional sequence of labels aligned with ``img_paths``.
            When omitted, each item carries its own index instead of a label.
        transform: optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, img_paths, img_labels=None, transform=None):
        self.img_paths = img_paths
        self.img_labels = img_labels
        self.transform = transform

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``, or ``(image, idx)`` when there are no labels."""
        image = Image.open(self.img_paths[idx]).convert('RGB')

        if self.transform:
            image = self.transform(image)

        # Unlabelled data: hand back the index so the caller can recover the path.
        if self.img_labels is None:
            return image, idx
        return image, self.img_labels[idx]
    
# Channel-wise normalisation shared by both pipelines (the usual ImageNet
# mean/std used with torchvision's pretrained models).
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

# Training pipeline: resize, random augmentation, then tensor conversion.
train_transforms = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.RandomRotation(degrees=(-8, 8)),
    transforms.RandomCrop((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2,
                           saturation=0.2, hue=0.2),
    transforms.ToTensor(),
    normalize,
])

# Evaluation pipeline: no random augmentation.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.CenterCrop((224, 224)),
    transforms.ToTensor(),
    normalize,
])

train_dataset = Dataset(train_paths, train_labels, transform=train_transforms)
val_dataset = Dataset(val_paths, val_labels, transform=val_transforms)
# The test set has no labels, so its items are (image, index) pairs.
test_dataset = Dataset(test_paths, transform=val_transforms)

# Settings common to all three loaders; only the training loader shuffles.
loader_options = dict(batch_size=batch_size, num_workers=num_workers)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, shuffle=True,
                                           **loader_options)
val_loader = torch.utils.data.DataLoader(dataset=val_dataset, shuffle=False,
                                         **loader_options)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, shuffle=False,
                                          **loader_options)

  cpuset_checked))


In [None]:
# Put the network on the GPU, then create the loss and the optimizer over its
# parameters.
model = model.cuda()
criterion = CrossEntropyLoss()
optimizer = Adam(params=model.parameters(), lr=1e-4)

In [None]:
# Train for `num_epochs` epochs; each epoch is one pass over the training set
# followed by one evaluation pass over the validation set.
for epoch in range(num_epochs):
    # ---- training pass ------------------------------------------------------
    model.train()
    correct = 0  # number of correctly classified training images this epoch

    # len(loader) is the exact batch count; `len(dataset) // batch_size + 1`
    # over-counts by one whenever the dataset size is a multiple of batch_size.
    with tqdm(total=len(train_loader)) as progress_bar:
        for x, y in train_loader:
            x, y = x.cuda(), y.cuda()

            optimizer.zero_grad()
            y_pred = model(x)
            loss = criterion(y_pred, y)
            loss.backward()
            optimizer.step()

            progress_bar.set_postfix(loss=loss.item())
            progress_bar.update(1)

            # Accumulate a plain Python int so the count is valid even when
            # the loader yields no batch at all.
            correct += (torch.argmax(y_pred, dim=-1) == y).sum().item()

    train_acc = correct / max(len(train_loader.dataset), 1)
    print("Epoch %d: train correct: %.4f" % (epoch, train_acc))

    # ---- validation pass ----------------------------------------------------
    model.eval()
    correct = 0

    # No gradients are needed for evaluation; disabling autograd avoids
    # building the graph and saves memory and time.
    with torch.no_grad(), tqdm(total=len(val_loader)) as progress_bar:
        for x, y in val_loader:
            x, y = x.cuda(), y.cuda()

            y_pred = model(x)
            loss = criterion(y_pred, y)

            # The value shown is the loss of the most recent batch only.
            progress_bar.set_postfix(loss=loss.item())
            progress_bar.update(1)

            correct += (torch.argmax(y_pred, dim=-1) == y).sum().item()

    val_acc = correct / max(len(val_loader.dataset), 1)
    print("Epoch %d: val correct: %.4f" % (epoch, val_acc))

  cpuset_checked))
100%|██████████| 48/48 [01:57<00:00,  2.45s/it, loss=0.448]


Epoch 0: train correct: 0.6636


100%|██████████| 12/12 [00:35<00:00,  2.96s/it, loss=0.0536]


Epoch 0: val correct: 0.8458


100%|██████████| 48/48 [01:01<00:00,  1.28s/it, loss=0.208]


Epoch 1: train correct: 0.9049


100%|██████████| 12/12 [00:11<00:00,  1.05it/s, loss=0.0186]


Epoch 1: val correct: 0.9137


100%|██████████| 48/48 [01:00<00:00,  1.26s/it, loss=0.204]


Epoch 2: train correct: 0.9325


100%|██████████| 12/12 [00:11<00:00,  1.06it/s, loss=0.0373]


Epoch 2: val correct: 0.9203


100%|██████████| 48/48 [00:59<00:00,  1.25s/it, loss=0.145]


Epoch 3: train correct: 0.9490


100%|██████████| 12/12 [00:11<00:00,  1.08it/s, loss=0.0765]


Epoch 3: val correct: 0.8706


100%|██████████| 48/48 [01:01<00:00,  1.28s/it, loss=0.202]


Epoch 4: train correct: 0.9546


100%|██████████| 12/12 [00:11<00:00,  1.03it/s, loss=0.00875]


Epoch 4: val correct: 0.9242


100%|██████████| 48/48 [01:01<00:00,  1.29s/it, loss=0.0484]


Epoch 5: train correct: 0.9664


100%|██████████| 12/12 [00:11<00:00,  1.03it/s, loss=0.00659]


Epoch 5: val correct: 0.9150


100%|██████████| 48/48 [01:01<00:00,  1.27s/it, loss=0.02]


Epoch 6: train correct: 0.9677


100%|██████████| 12/12 [00:11<00:00,  1.03it/s, loss=0.00363]


Epoch 6: val correct: 0.9124


100%|██████████| 48/48 [01:03<00:00,  1.32s/it, loss=0.0556]


Epoch 7: train correct: 0.9760


100%|██████████| 12/12 [00:11<00:00,  1.03it/s, loss=0.0115]


Epoch 7: val correct: 0.9399


100%|██████████| 48/48 [01:02<00:00,  1.30s/it, loss=0.0564]


Epoch 8: train correct: 0.9770


100%|██████████| 12/12 [00:11<00:00,  1.01it/s, loss=0.00483]


Epoch 8: val correct: 0.9399


100%|██████████| 48/48 [01:02<00:00,  1.30s/it, loss=0.0864]


Epoch 9: train correct: 0.9779


100%|██████████| 12/12 [00:11<00:00,  1.00it/s, loss=0.01]


Epoch 9: val correct: 0.9451


100%|██████████| 48/48 [01:04<00:00,  1.33s/it, loss=0.0123]


Epoch 10: train correct: 0.9783


100%|██████████| 12/12 [00:12<00:00,  1.01s/it, loss=0.00245]


Epoch 10: val correct: 0.9425


100%|██████████| 48/48 [01:02<00:00,  1.30s/it, loss=0.0682]


Epoch 11: train correct: 0.9842


100%|██████████| 12/12 [00:11<00:00,  1.05it/s, loss=0.0403]


Epoch 11: val correct: 0.9294


100%|██████████| 48/48 [01:01<00:00,  1.28s/it, loss=0.0262]


Epoch 12: train correct: 0.9882


100%|██████████| 12/12 [00:11<00:00,  1.03it/s, loss=0.00719]


Epoch 12: val correct: 0.9359


100%|██████████| 48/48 [01:02<00:00,  1.30s/it, loss=0.0587]


Epoch 13: train correct: 0.9865


100%|██████████| 12/12 [00:11<00:00,  1.02it/s, loss=0.0107]


Epoch 13: val correct: 0.9359


100%|██████████| 48/48 [01:01<00:00,  1.29s/it, loss=0.0377]


Epoch 14: train correct: 0.9822


100%|██████████| 12/12 [00:11<00:00,  1.06it/s, loss=0.00579]


Epoch 14: val correct: 0.9386


100%|██████████| 48/48 [00:59<00:00,  1.25s/it, loss=0.0788]


Epoch 15: train correct: 0.9832


100%|██████████| 12/12 [00:11<00:00,  1.01it/s, loss=0.00709]


Epoch 15: val correct: 0.9451


100%|██████████| 48/48 [01:02<00:00,  1.31s/it, loss=0.0351]


Epoch 16: train correct: 0.9888


100%|██████████| 12/12 [00:11<00:00,  1.02it/s, loss=0.00914]


Epoch 16: val correct: 0.9516


100%|██████████| 48/48 [01:03<00:00,  1.32s/it, loss=0.123]


Epoch 17: train correct: 0.9882


100%|██████████| 12/12 [00:11<00:00,  1.00it/s, loss=0.00498]


Epoch 17: val correct: 0.9425


100%|██████████| 48/48 [01:00<00:00,  1.27s/it, loss=0.128]


Epoch 18: train correct: 0.9947


100%|██████████| 12/12 [00:11<00:00,  1.08it/s, loss=0.0251]


Epoch 18: val correct: 0.9582


100%|██████████| 48/48 [00:59<00:00,  1.25s/it, loss=0.0215]


Epoch 19: train correct: 0.9832


100%|██████████| 12/12 [00:11<00:00,  1.06it/s, loss=0.00872]


Epoch 19: val correct: 0.9595


100%|██████████| 48/48 [00:59<00:00,  1.25s/it, loss=0.00251]


Epoch 20: train correct: 0.9944


100%|██████████| 12/12 [00:11<00:00,  1.08it/s, loss=0.0108]


Epoch 20: val correct: 0.9516


100%|██████████| 48/48 [00:58<00:00,  1.21s/it, loss=0.0015]


Epoch 21: train correct: 0.9951


100%|██████████| 12/12 [00:11<00:00,  1.07it/s, loss=0.00575]


Epoch 21: val correct: 0.9399


100%|██████████| 48/48 [01:00<00:00,  1.26s/it, loss=0.151]


Epoch 22: train correct: 0.9941


100%|██████████| 12/12 [00:11<00:00,  1.06it/s, loss=0.0162]


Epoch 22: val correct: 0.9333


100%|██████████| 48/48 [00:58<00:00,  1.23s/it, loss=0.231]


Epoch 23: train correct: 0.9914


100%|██████████| 12/12 [00:11<00:00,  1.06it/s, loss=0.013]


Epoch 23: val correct: 0.9621


100%|██████████| 48/48 [00:58<00:00,  1.22s/it, loss=0.00786]


Epoch 24: train correct: 0.9905


100%|██████████| 12/12 [00:11<00:00,  1.06it/s, loss=0.0163]


Epoch 24: val correct: 0.9686


100%|██████████| 48/48 [01:00<00:00,  1.26s/it, loss=0.0255]


Epoch 25: train correct: 0.9918


100%|██████████| 12/12 [00:11<00:00,  1.06it/s, loss=0.00481]


Epoch 25: val correct: 0.9529


100%|██████████| 48/48 [00:58<00:00,  1.23s/it, loss=0.0258]


Epoch 26: train correct: 0.9951


100%|██████████| 12/12 [00:11<00:00,  1.07it/s, loss=0.00239]


Epoch 26: val correct: 0.9529


100%|██████████| 48/48 [00:59<00:00,  1.23s/it, loss=0.0024]


Epoch 27: train correct: 0.9964


100%|██████████| 12/12 [00:11<00:00,  1.05it/s, loss=0.00744]


Epoch 27: val correct: 0.9595


100%|██████████| 48/48 [01:00<00:00,  1.26s/it, loss=0.105]


Epoch 28: train correct: 0.9901


100%|██████████| 12/12 [00:11<00:00,  1.04it/s, loss=0.00728]


Epoch 28: val correct: 0.9451


100%|██████████| 48/48 [00:59<00:00,  1.24s/it, loss=0.00171]


Epoch 29: train correct: 0.9944


100%|██████████| 12/12 [00:11<00:00,  1.08it/s, loss=0.0173]


Epoch 29: val correct: 0.9464


100%|██████████| 48/48 [00:58<00:00,  1.22s/it, loss=0.0234]


Epoch 30: train correct: 0.9868


100%|██████████| 12/12 [00:11<00:00,  1.08it/s, loss=0.0102]


Epoch 30: val correct: 0.9477


100%|██████████| 48/48 [00:59<00:00,  1.25s/it, loss=0.00111]


Epoch 31: train correct: 0.9931


100%|██████████| 12/12 [00:11<00:00,  1.07it/s, loss=0.015]


Epoch 31: val correct: 0.9451


100%|██████████| 48/48 [00:59<00:00,  1.24s/it, loss=0.0037]


Epoch 32: train correct: 0.9957


100%|██████████| 12/12 [00:11<00:00,  1.07it/s, loss=0.00973]


Epoch 32: val correct: 0.9556


100%|██████████| 48/48 [01:00<00:00,  1.25s/it, loss=0.0389]


Epoch 33: train correct: 0.9941


100%|██████████| 12/12 [00:11<00:00,  1.01it/s, loss=0.00389]


Epoch 33: val correct: 0.9242


100%|██████████| 48/48 [01:02<00:00,  1.31s/it, loss=0.021]


Epoch 34: train correct: 0.9934


100%|██████████| 12/12 [00:12<00:00,  1.01s/it, loss=0.0153]


Epoch 34: val correct: 0.9333


100%|██████████| 48/48 [01:02<00:00,  1.31s/it, loss=0.00985]


Epoch 35: train correct: 0.9911


100%|██████████| 12/12 [00:11<00:00,  1.01it/s, loss=0.0912]


Epoch 35: val correct: 0.9503


100%|██████████| 48/48 [01:02<00:00,  1.29s/it, loss=0.00415]


Epoch 36: train correct: 0.9954


100%|██████████| 12/12 [00:11<00:00,  1.05it/s, loss=0.00662]


Epoch 36: val correct: 0.9373


100%|██████████| 48/48 [00:59<00:00,  1.25s/it, loss=0.0723]


Epoch 37: train correct: 0.9891


100%|██████████| 12/12 [00:11<00:00,  1.04it/s, loss=0.0199]


Epoch 37: val correct: 0.9359


100%|██████████| 48/48 [01:00<00:00,  1.25s/it, loss=0.106]


Epoch 38: train correct: 0.9862


100%|██████████| 12/12 [00:11<00:00,  1.06it/s, loss=0.0267]


Epoch 38: val correct: 0.9412


100%|██████████| 48/48 [00:59<00:00,  1.23s/it, loss=0.0132]


Epoch 39: train correct: 0.9947


100%|██████████| 12/12 [00:11<00:00,  1.06it/s, loss=0.0178]

Epoch 39: val correct: 0.9425





In [None]:
# Predict a class for every test image and write the results to a CSV file.
test_label_indices = []  # predicted class index per image
test_img_indices = []    # position of each image in `test_paths`

# Set evaluation mode explicitly instead of relying on the state left behind by
# the training cell, and disable autograd since no gradients are needed here.
model.eval()
with torch.no_grad():
    for x, img_idx in test_loader:
        y_pred = model(x.cuda())
        test_label_indices.extend(torch.argmax(y_pred, dim=-1).cpu().tolist())
        test_img_indices.extend(img_idx.tolist())

# The unlabelled dataset returns each item's index, which maps back to its path;
# only the file name (without directories) goes into the output.
test_names = [os.path.basename(test_paths[idx]) for idx in test_img_indices]
test_labels = [classes[idx] for idx in test_label_indices]

out_df = pd.DataFrame({'file': test_names, 'species': test_labels})
out_df.to_csv('ResNet-34.csv', index=False)

  cpuset_checked))
