In [1]:
# this file is not meant to be run directly, as many dependencies and data files
# are not in the GitHub repository
# instead, this file is here to show you the general training process of our
# machine-learning model

import torch
import numpy as np
from numpy import random
import matplotlib.pyplot as plt

import os
import cv2

from torchvision import models
import torchvision.transforms.functional as TF
from torch import utils, nn, optim
from tqdm import tqdm  # progress bar

import pandas as pd

from datetime import datetime

In [2]:
# hyperparameters and other variables

BATCH_SIZE = 128
LEARNING_RATE = 0.0001
EPOCH = 8

RESIZE = 64

DATA_AUG = True

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {DEVICE}")  # check cuda or cpu

Device: cuda


In [3]:
%%time

columns = ["Label", "Path"]
files_data = pd.DataFrame(columns=columns)

data_path = "data/garbage-classification"
categories = os.listdir(data_path)
for category in categories:
    category_path = data_path + "/" + category
    files = os.listdir(category_path)
    for file in files:
        if not file.startswith('.'):
            files_data.loc[file] = [category, category_path + "/" + file]

CPU times: total: 14.9 s
Wall time: 14.9 s


In [4]:
display(files_data.groupby("Label").count())

Unnamed: 0_level_0,Path
Label,Unnamed: 1_level_1
battery,945
biological,985
brown-glass,607
cardboard,891
clothes,5325
green-glass,629
metal,769
paper,1050
plastic,865
shoes,1977


In [5]:
def square_crop(image_array):
    is_portrait = image_array.shape[0] >= image_array.shape[1]
    large_side = image_array.shape[0] if is_portrait else image_array.shape[1]
    small_side = image_array.shape[1] if is_portrait else image_array.shape[0]
    crop_low = round(large_side / 2) - round(small_side / 2)
    crop_high = round(large_side / 2) + (small_side - round(small_side / 2))
    return image_array[crop_low:crop_high] if is_portrait\
           else image_array[:, crop_low:crop_high]


def read_image(path, crop=True, resize=RESIZE):
    try:
        image_array = square_crop(cv2.imread(path))
    except:
        print("Error: " + path)
    image_array = cv2.resize(
        image_array, dsize=[resize, resize], interpolation=cv2.INTER_CUBIC
    )
    return image_array

In [6]:
labels_unique = sorted(files_data["Label"].unique())
label_to_text = dict(zip(range(0, len(labels_unique)), labels_unique))
text_to_label = dict(zip(labels_unique, range(0, len(labels_unique))))

print(label_to_text)

{0: 'battery', 1: 'biological', 2: 'brown-glass', 3: 'cardboard', 4: 'clothes', 5: 'green-glass', 6: 'metal', 7: 'paper', 8: 'plastic', 9: 'shoes', 10: 'trash', 11: 'white-glass'}


In [7]:
%%time

files_data_shuffle = files_data.sample(frac=1.0)

label_limit = 9999
counts = dict(zip(labels_unique, [0] * len(labels_unique)))

inputs_list = list()
labels_list = list()
for file in files_data_shuffle.index:
    label = files_data.at[file, "Label"]
    
    if counts[label] < label_limit: 
        inputs_list.append(
            np.expand_dims(read_image(files_data.at[file, "Path"]), axis=0)
        )
        labels_list.append(text_to_label[label])
        counts[label] += 1

CPU times: total: 16.9 s
Wall time: 18.8 s


In [8]:
inputs = np.concatenate(inputs_list, axis=0) / 255
labels = np.asarray(labels_list)

In [9]:
def get_split(inputs, labels, ratio=[9, 1]):
    split_index = round(inputs.shape[0] * ratio[0] /sum(ratio))
    train_inputs, test_inputs = inputs[:split_index], inputs[split_index:]
    train_labels, test_labels = labels[:split_index], labels[split_index:]
    print(f"Train: {train_inputs.shape[0]}   Test: {test_inputs.shape[0]}")
    
    return (train_inputs, train_labels), (test_inputs, test_labels)

In [10]:
train_data_np, test_data_np = get_split(inputs, labels)

Train: 13964   Test: 1551


In [11]:
# make test data more balanced

test_limit = 80
test_counts = dict(zip(range(12), [0] * len(labels_unique)))

test_inputs, test_labels = list(), list()
for i, input_ in enumerate(test_data_np[0]):
    label = test_data_np[1][i]
    if test_counts[label] < test_limit:
        test_inputs.append(np.expand_dims(input_, axis=0))
        test_labels.append(label)
        test_counts[label] += 1

test_data_np = (np.concatenate(test_inputs, axis=0), np.asarray(test_labels))

In [12]:
class Dataset(utils.data.Dataset):
    def __init__(self, data):
        self.inputs, self.labels = data  # expects tuple of numpy arrays

    def classes(self):
        return self.labels

    def __len__(self):
        return self.labels.shape[0]
    
    def get_batch_inputs(self, idx):
        # fetch a batch of inputs
        return self.inputs[idx]
    
    def get_batch_labels(self, idx):
        # fetch a batch of labels
        return self.labels[idx]

    def __getitem__(self, idx):
        batch_inputs = self.get_batch_inputs(idx)
        batch_labels = self.get_batch_labels(idx)

        return batch_inputs, batch_labels

In [13]:
# define CNN model

class Model(nn.Module):
    
    def __init__(self, input_dim, output_dim):
        super(Model, self).__init__()
        # initialize layers
        self.resnet = models.resnet50(pretrained=True)
        self.resnet.fc = nn.Linear(2048, 1024)
        self.dense = nn.Sequential(
            nn.Linear(1024, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.1)
        )
        self.out = nn.Linear(128, output_dim)
    
    def forward(self, inputs):
        # print(inputs.shape)
        inputs = torch.movedim(inputs, -1, 1)
        # forward propagation
        inputs = self.resnet(inputs)
        inputs = self.dense(inputs)
        return self.out(inputs)

In [14]:
data_count = np.asarray(files_data.groupby("Label").count()["Path"])
aug_counts = (np.rint(data_count.max() / data_count) - 1).astype(int)
for i, aug_count in enumerate(aug_counts):
    print(f"[{i}] {label_to_text[i]}: {aug_count}")

[0] battery: 5
[1] biological: 4
[2] brown-glass: 8
[3] cardboard: 5
[4] clothes: 0
[5] green-glass: 7
[6] metal: 6
[7] paper: 4
[8] plastic: 5
[9] shoes: 2
[10] trash: 7
[11] white-glass: 6


In [15]:
def crop_resize(input_, size, center=True):
    if center:
        input_ = TF.center_crop(input_, size)
    else:
        input_ = TF.five_crop(input_, size)[random.randint(0, 5)]
    return TF.resize(input_, [RESIZE, RESIZE])


def data_aug(data):
    inputs, labels = data
    inputs = torch.movedim(torch.tensor(inputs), -1, 1)
    rotate_crop = RESIZE - round(0.29289321881345254 * RESIZE)
    crop_low = round(0.7 * RESIZE)
    gaussian_high = round(0.05 * RESIZE)
    
    inputs_aug, labels_aug = list(), list()
    for i, input_ in enumerate(tqdm(inputs)):
        for j in range(aug_counts[labels[i]]):
            random_bits = np.random.randint(2, size=8)
            all_aug = not np.sum(random_bits)
            # horizontal flip
            if random_bits[0] or all_aug:
                input_ = TF.hflip(input_)
            # vertical flip
            if random_bits[1] or all_aug:
                input_ = TF.vflip(input_)
            # rotation
            if random_bits[2] or all_aug:
                input_ = TF.rotate(input_, random.uniform(low=0.0, high=360.0))
                input_ = crop_resize(input_, rotate_crop, center=True)
            # hue
            if random_bits[3] or all_aug:
                input_ = TF.adjust_hue(
                    input_, random.uniform(low=-0.2, high=0.2)
                )
            # saturation
            if random_bits[4] or all_aug:
                input_ = TF.adjust_saturation(
                    input_, random.uniform(low=0.5, high=1.5)
                )
            # luminance (gamma)
            if random_bits[5] or all_aug:
                gamma = pow(random.normal(loc=1, scale=0.3), 2)
                input_ = TF.adjust_gamma(input_, np.clip(gamma, 0.3, 1.7))
            # gaussian blur
            if random_bits[6] or all_aug:
                input_ = TF.gaussian_blur(
                    input_, random.randint(low=1, high=gaussian_high) * 2 - 1
                )
            # gaussian noise
            if random_bits[7] or all_aug:
                noise = torch.from_numpy(random.normal(
                    loc=0.0,
                    scale=random.uniform(low=0.0, high=0.1),
                    size=input_.shape
                ))
                input_ = torch.add(input_, noise)
            input_ = torch.clip(input_, 0.0, 1.0)
            inputs_aug.append(torch.unsqueeze(input_, 0))
            labels_aug.append(labels[i])
    
    inputs = torch.cat([inputs, torch.cat(inputs_aug, 0)], 0)
    return (
        torch.movedim(inputs, 1, -1).numpy(),
        np.concatenate([labels, np.asarray(labels_aug)], axis=0)
    )

In [16]:
# evaluate the accuracy of one batch
def evaluate_batch(batch_outputs, batch_labels):
    batch_predictions = batch_outputs.argmax(axis=1)
    return np.average((batch_predictions==batch_labels).cpu())


# evaluate model with test data
def evaluate(model, test_loader, loss, device=DEVICE):
    
    model.eval()
    
    test_accuracies, test_losses = [], []
    for batch_inputs, batch_labels in iter(test_loader):
        
        batch_inputs = batch_inputs.type(torch.FloatTensor).to(device)  # CUDA
        batch_labels = batch_labels.type(torch.int64).to(device)
        
        batch_outputs = model(batch_inputs)
        test_accuracies.append(evaluate_batch(batch_outputs, batch_labels))
        test_losses.append(loss(batch_outputs, batch_labels).cpu().item())
    
    del loss
    
    return np.average(test_accuracies), np.average(test_losses)


# train for one epoch
def train(model, train_data_np, test_data_np, loss, optimizer, device=DEVICE):
    
    if DATA_AUG:
        train_data_np = data_aug(train_data_np)
    train_data, test_data = Dataset(train_data_np), Dataset(test_data_np)
    
    train_loader = utils.data.DataLoader(
        train_data, batch_size=BATCH_SIZE, shuffle=True
    )
    test_loader = utils.data.DataLoader(test_data, batch_size=1)
    
    model.train()
    
    train_accuracies, train_losses = list(), list()
    for batch_inputs, batch_labels in tqdm(train_loader):
        
        batch_inputs = batch_inputs.type(torch.FloatTensor).to(device)  # CUDA
        batch_labels = batch_labels.type(torch.int64).to(device)
        
        optimizer.zero_grad()
        batch_outputs = model(batch_inputs)
        batch_loss = loss(batch_outputs, batch_labels)
        
        train_accuracies.append(evaluate_batch(batch_outputs, batch_labels))
        train_losses.append(batch_loss.cpu().item())
        
        batch_loss.backward()
        optimizer.step()
        
        del batch_inputs  # free CUDA memory (I think)
        del batch_labels
    
    test_accuracy, test_loss = evaluate(model, test_loader, loss)
    del loss
    return np.average(train_accuracies), np.average(train_losses),\
           test_accuracy, test_loss

In [17]:
model = Model(3 * RESIZE * RESIZE, len(label_to_text.keys())).to(DEVICE)
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

outputs = list()

In [18]:
def generate_model_name(name):
    now = datetime.now()
    time_string = now.strftime("%m%d_%H%M")
    return name + "_" + time_string

In [19]:
for i in range(2):
    for epoch in range(4):
        output = train(model, train_data_np, test_data_np, loss, optimizer)
        outputs.append(output)
        print(output)
    torch.save(model.state_dict(), "models/" + generate_model_name("model"))

100%|███████████████████████████████████████████████████████████████████████████| 13964/13964 [01:16<00:00, 181.97it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 454/454 [03:09<00:00,  2.40it/s]


(0.33253298566122486, 1.9319297369356196, 0.4994413407821229, 1.5552341625735673)


100%|███████████████████████████████████████████████████████████████████████████| 13964/13964 [01:15<00:00, 185.29it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 454/454 [02:56<00:00,  2.58it/s]


(0.435537596095707, 1.6459998383395997, 0.553072625698324, 1.330761539833048)


100%|███████████████████████████████████████████████████████████████████████████| 13964/13964 [01:20<00:00, 172.95it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 454/454 [02:56<00:00,  2.58it/s]


(0.47370932128357945, 1.543314242415491, 0.5519553072625698, 1.3745200985701584)


100%|███████████████████████████████████████████████████████████████████████████| 13964/13964 [01:34<00:00, 147.47it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 454/454 [02:56<00:00,  2.57it/s]


(0.503119398915954, 1.46013837737659, 0.5586592178770949, 1.3520910876889274)


100%|███████████████████████████████████████████████████████████████████████████| 13964/13964 [01:35<00:00, 146.14it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 454/454 [02:56<00:00,  2.58it/s]


(0.5221381159842792, 1.4002034396327014, 0.5932960893854748, 1.2737352651929623)


100%|███████████████████████████████████████████████████████████████████████████| 13964/13964 [01:44<00:00, 133.70it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 454/454 [02:56<00:00,  2.57it/s]


(0.5474648008983329, 1.3274510374153239, 0.6167597765363129, 1.2517553084814146)


100%|███████████████████████████████████████████████████████████████████████████| 13964/13964 [01:48<00:00, 128.87it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 454/454 [02:56<00:00,  2.57it/s]


(0.5631630198885722, 1.2795418023012808, 0.6268156424581005, 1.215572213744464)


100%|███████████████████████████████████████████████████████████████████████████| 13964/13964 [01:59<00:00, 116.80it/s]
100%|████████████████████████████████████████████████████████████████████████████████| 454/454 [02:56<00:00,  2.58it/s]


(0.578973597974432, 1.2328403139166895, 0.617877094972067, 1.2396404721822736)


In [20]:
for output in outputs:
    print(output[2], end=", ")

0.4994413407821229, 0.553072625698324, 0.5519553072625698, 0.5586592178770949, 0.5932960893854748, 0.6167597765363129, 0.6268156424581005, 0.617877094972067, 

# Load model

In [None]:
model_path = "models/model_0416_1135"

model_load = Model(3 * RESIZE & RESIZE, 12)
model_load.load_state_dict(torch.load(model_path))
model_load.eval()
print(f"Model at \"{model_path}\" loaded.")

In [None]:
test_loader = utils.data.DataLoader(
    Dataset(test_data_np), batch_size=BATCH_SIZE, shuffle=True
)
loss = nn.CrossEntropyLoss()

print(f"Test accuracy: {evaluate(model, test_loader, loss)[0]}")