# validation.ipynb

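Validation implementation: stratified k-fold cross-validation of a pre-trained image classifier on the augmented CCMT dataset.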

Author: Connacher Murphy

In [19]:
# Libraries
import pest_classification as pest

import os
from sklearn.model_selection import StratifiedKFold
import timm
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from types import SimpleNamespace

In [20]:
# Grab training observations from the images dataframe.
# CM: subsampling to speed up execution. The sample is seeded so that a
# Restart & Run All draws the same rows (and therefore the same folds).
df_all = pest.df
train_rows = df_all[df_all["set"] == "train_set"]
# Cap the sample size so a training set smaller than 1024 rows does not raise.
df = train_rows.sample(n=min(1024, len(train_rows)), random_state=42)
# Reset to a default 0..n-1 index for the downstream fold split.
df = df.reset_index(drop=True)

In [21]:
# Dataset and dataloader settings, collected on one shared namespace
config = SimpleNamespace(
    batch_size=32,
    image_dir=os.path.expanduser("~/data/ccmt/CCMT Dataset-Augmented"),
    image_size=256,
)


In [22]:
# Add folds to the dataframe
config.n_folds = 4

# Shuffle within each class before splitting so fold membership does not
# depend on the row order of df; the fixed random_state keeps the split
# reproducible across runs.
skf = StratifiedKFold(n_splits=config.n_folds, shuffle=True, random_state=42)

In [23]:
# Partition into folds
# Start from an integer sentinel so the column keeps an integer dtype
# (assigning into a brand-new column row-by-row would leave it float),
# and so re-running this cell starts from a clean slate.
df["fold"] = -1
fold_col = df.columns.get_loc("fold")
for fold, (train_index, val_index) in enumerate(skf.split(df, df.label)):
    # split() yields positional row indices, so assign by position rather
    # than by index label.
    df.iloc[val_index, fold_col] = fold

# Every row should have been assigned a validation fold.
assert (df["fold"] >= 0).all(), "some rows were not assigned to a fold"

In [24]:
# Architecture parameters (num_classes, backbone) and optimizer parameters
# (lr, num_epochs), set on the shared config in one step.
# Alternative for num_classes: len(pest.crop_classes["Maize"])
# NOTE(review): num_classes is hard-coded; presumably it should match the
# number of distinct labels in df -- confirm.
vars(config).update(
    num_classes=2,
    backbone="resnet18",
    lr=1e-4,
    num_epochs=3,
)

In [25]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
use_cuda = torch.cuda.is_available()
print(use_cuda)
config.device = torch.device("cuda") if use_cuda else torch.device("cpu")

False


In [26]:
# Training function
def train(train_dataloader, valid_dataloader, model, optimizer, config):
    """Train and validate `model` for `config.num_epochs` epochs.

    Each epoch calls `pest.train_epoch` on `train_dataloader` and then
    `pest.validate_epoch` on `valid_dataloader`, printing the loss and
    accuracy reported by each.

    Args:
        train_dataloader: Passed to `pest.train_epoch`.
        valid_dataloader: Passed to `pest.validate_epoch`.
        model: Model handed to both epoch functions.
        optimizer: Optimizer handed to `pest.train_epoch`.
        config: Namespace providing `num_epochs`; also forwarded to both
            epoch functions.

    Returns:
        A list with one dict per epoch, with keys "epoch" (1-based),
        "train_loss", "train_accuracy", "valid_loss" and "valid_accuracy",
        so callers can aggregate results across folds. Callers that ignore
        the return value behave exactly as before.
    """
    history = []
    for epoch in range(config.num_epochs):
        print(f"Epoch {epoch + 1}")
        print("Training...")
        train_loss, train_accuracy = pest.train_epoch(
            train_dataloader, model, optimizer, config
        )
        print(f"Training: loss = {train_loss}, accuracy = {train_accuracy}")
        print("Validating...")
        valid_loss, valid_accuracy = pest.validate_epoch(
            valid_dataloader, model, config
        )
        print(f"Validation: loss = {valid_loss}, accuracy = {valid_accuracy}")
        # Keep the metrics instead of only printing them.
        history.append(
            {
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_accuracy": train_accuracy,
                "valid_loss": valid_loss,
                "valid_accuracy": valid_accuracy,
            }
        )
    return history

In [27]:
# Specify the loss function once: it is identical for every fold, so there is
# no need to rebuild it inside the loop. It is stored on config, presumably
# for use by pest.train_epoch / pest.validate_epoch -- confirm.
config.criterion = nn.CrossEntropyLoss()

# Cross-validation: each fold serves once as the validation set while the
# remaining folds form the training set.
for fold in range(config.n_folds):
    print(f"Fold {fold}")

    # Split into training and validation sets
    train_df = df[df["fold"] != fold].reset_index(drop=True)
    valid_df = df[df["fold"] == fold].reset_index(drop=True)

    train_dataset = pest.AugmentedCCMT(config, train_df)
    valid_dataset = pest.AugmentedCCMT(config, valid_df)

    # Dataloaders (only the training data is shuffled)
    train_dataloader = DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=0
    )
    valid_dataloader = DataLoader(
        valid_dataset, batch_size=config.batch_size, shuffle=False, num_workers=0
    )

    # Initialize a fresh (pre-trained) model for every fold so that no
    # weights carry over from one fold to the next
    model = timm.create_model(
        config.backbone, pretrained=True, num_classes=config.num_classes
    )
    model.to(config.device)

    # Initialize optimizer for this fold's model
    optimizer = optim.Adam(model.parameters(), lr=config.lr, weight_decay=0.0)

    # Call training function
    train(train_dataloader, valid_dataloader, model, optimizer, config)
    

Fold 0
Epoch 1
Training...


100%|██████████| 24/24 [01:14<00:00,  3.11s/it]


Training: loss = 0.3677935041487217, accuracy = 1.0
Validating...


100%|██████████| 8/8 [00:12<00:00,  1.60s/it]


Validation: loss = 0.2659143526107073, accuracy = 1.0
Epoch 2
Training...


100%|██████████| 24/24 [01:12<00:00,  3.04s/it]


Training: loss = 0.19109038884441057, accuracy = 1.0
Validating...


100%|██████████| 8/8 [00:11<00:00,  1.46s/it]


Validation: loss = 0.1385960392653942, accuracy = 1.0
Epoch 3
Training...


100%|██████████| 24/24 [01:13<00:00,  3.05s/it]


Training: loss = 0.10954604391008615, accuracy = 1.0
Validating...


100%|██████████| 8/8 [00:11<00:00,  1.46s/it]


Validation: loss = 0.10629034414887428, accuracy = 1.0
Fold 1
Epoch 1
Training...


100%|██████████| 24/24 [01:13<00:00,  3.06s/it]


Training: loss = 0.546210303902626, accuracy = 0.88671875
Validating...


100%|██████████| 8/8 [00:11<00:00,  1.47s/it]


Validation: loss = 0.3700367659330368, accuracy = 1.0
Epoch 2
Training...


100%|██████████| 24/24 [01:13<00:00,  3.06s/it]


Training: loss = 0.2859940864145756, accuracy = 1.0
Validating...


100%|██████████| 8/8 [00:11<00:00,  1.46s/it]


Validation: loss = 0.18649562820792198, accuracy = 1.0
Epoch 3
Training...


100%|██████████| 24/24 [01:13<00:00,  3.04s/it]


Training: loss = 0.15600505533317724, accuracy = 1.0
Validating...


100%|██████████| 8/8 [00:11<00:00,  1.47s/it]


Validation: loss = 0.11109901033341885, accuracy = 1.0
Fold 2
Epoch 1
Training...


100%|██████████| 24/24 [01:13<00:00,  3.07s/it]


Training: loss = 0.46675671512881917, accuracy = 0.98828125
Validating...


100%|██████████| 8/8 [00:12<00:00,  1.56s/it]


Validation: loss = 0.31502656266093254, accuracy = 1.0
Epoch 2
Training...


100%|██████████| 24/24 [01:22<00:00,  3.43s/it]


Training: loss = 0.2489154109110435, accuracy = 1.0
Validating...


100%|██████████| 8/8 [00:13<00:00,  1.64s/it]


Validation: loss = 0.1868429183959961, accuracy = 1.0
Epoch 3
Training...


100%|██████████| 24/24 [01:13<00:00,  3.08s/it]


Training: loss = 0.1387908694644769, accuracy = 1.0
Validating...


100%|██████████| 8/8 [00:12<00:00,  1.62s/it]


Validation: loss = 0.12221243605017662, accuracy = 1.0
Fold 3
Epoch 1
Training...


100%|██████████| 24/24 [01:13<00:00,  3.06s/it]


Training: loss = 0.5676878032584985, accuracy = 0.8333333333333334
Validating...


100%|██████████| 8/8 [00:11<00:00,  1.46s/it]


Validation: loss = 0.38919882103800774, accuracy = 1.0
Epoch 2
Training...


100%|██████████| 24/24 [01:13<00:00,  3.07s/it]


Training: loss = 0.2984118480235338, accuracy = 1.0
Validating...


100%|██████████| 8/8 [00:11<00:00,  1.47s/it]


Validation: loss = 0.19266971945762634, accuracy = 1.0
Epoch 3
Training...


100%|██████████| 24/24 [01:12<00:00,  3.03s/it]


Training: loss = 0.16366835994025072, accuracy = 1.0
Validating...


100%|██████████| 8/8 [00:11<00:00,  1.47s/it]

Validation: loss = 0.12075814884155989, accuracy = 1.0



