'''
Author:
        
        PARK, JunHo, junho@ccnets.org

        
        KIM, JoengYoong, jeongyoong@ccnets.org
        
    COPYRIGHT (c) 2024. CCNets. All Rights reserved.
'''

# Recyclable and Household Waste Classification 


Dataset source (Kaggle): https://www.kaggle.com/datasets/alistairking/recyclable-and-household-waste-classification

In [1]:
import sys

# Make the parent directory importable so the project packages below
# (`tools`, `nn`) resolve; presumably this is the repository root — confirm.
# This MUST run before those imports. Relative sys.path entries are resolved
# against the current working directory, so the notebook is expected to be
# launched from its own directory.
path_append = "../" # Go up one directory from where you are.
sys.path.append(path_append) 

from tools.setting.ml_params import MLParameters
from tools.setting.data_config import DataConfig
from nn.utils.init import set_random_seed
# Seed RNGs for reproducibility. set_random_seed is defined in nn.utils.init
# (not visible here); presumably it seeds Python's `random`, which the
# shuffle in gather_and_split_data relies on — confirm.
set_random_seed(0)

In [2]:
# Dataset location relative to `path_append`; the two are combined into
# `root_dir` in a later cell. Layout read by gather_and_split_data:
#   <root>/<class_name>/{default,real_world}/<image files>
dataset_path = '../data/Recyclable and Household Waste Classification/images/images'

In [3]:
import os
import torch
import random
from PIL import Image
from torch.utils.data import Dataset

def gather_and_split_data(root_dir, train_split=0.6, val_split=0.2, test_split=0.2, seed=None):
    """Collect every image under *root_dir* and randomly split it into train/val/test.

    Expected layout::

        root_dir/<class_name>/default/<image>
        root_dir/<class_name>/real_world/<image>

    Class directories are sorted by name and labelled ``0..K-1`` in that order.
    Hidden entries (names starting with '.') and non-directories at the class
    level, and hidden entries / non-files at the image level, are skipped so
    stray files such as ``.DS_Store`` or ``.ipynb_checkpoints`` do not break
    loading.

    Args:
        root_dir: Directory containing one sub-directory per class.
        train_split: Fraction of all images assigned to the training split.
        val_split: Fraction of all images assigned to the validation split.
        test_split: Kept for API compatibility only; the test split always
            receives the remaining images after train and val are taken.
        seed: Optional seed for a private ``random.Random``. When ``None``
            (default) the module-level ``random`` generator is used, so the
            result depends on any global seeding done by the caller.

    Returns:
        ``(train_data, val_data, test_data)``, each an ``(image_paths, labels)``
        pair of equal-length tuples.

    Raises:
        ValueError: if a split fraction is negative or
            ``train_split + val_split`` exceeds 1.
        FileNotFoundError: if a class directory lacks a ``default`` or
            ``real_world`` sub-folder.
    """
    if train_split < 0 or val_split < 0 or train_split + val_split > 1 + 1e-9:
        raise ValueError(
            "train_split and val_split must be non-negative and sum to at most 1, "
            f"got train_split={train_split}, val_split={val_split}"
        )

    rng = random if seed is None else random.Random(seed)

    classes = sorted(
        name for name in os.listdir(root_dir)
        if not name.startswith('.') and os.path.isdir(os.path.join(root_dir, name))
    )

    # Gather (image_path, label) pairs.
    samples = []
    for label, class_name in enumerate(classes):
        class_dir = os.path.join(root_dir, class_name)
        for subfolder in ('default', 'real_world'):
            subfolder_dir = os.path.join(class_dir, subfolder)
            # Sort so that, for a given seed, the split does not depend on the
            # arbitrary order in which the filesystem lists directory entries.
            for image_name in sorted(os.listdir(subfolder_dir)):
                image_path = os.path.join(subfolder_dir, image_name)
                if image_name.startswith('.') or not os.path.isfile(image_path):
                    continue
                samples.append((image_path, label))

    # Shuffle paths and labels together so they stay aligned.
    rng.shuffle(samples)

    # zip(*[]) yields nothing and cannot be unpacked, so handle "no images".
    if samples:
        all_image_paths, all_labels = zip(*samples)
    else:
        all_image_paths, all_labels = (), ()

    # Compute split indices; the test split takes whatever remains.
    num_images = len(all_image_paths)
    train_end = int(train_split * num_images)
    val_end = train_end + int(val_split * num_images)

    train_data = (all_image_paths[:train_end], all_labels[:train_end])
    val_data = (all_image_paths[train_end:val_end], all_labels[train_end:val_end])
    test_data = (all_image_paths[val_end:], all_labels[val_end:])

    return train_data, val_data, test_data

class WasteDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` pairs for waste classification.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of integer class indices, aligned with *image_paths*.
        transform: Optional callable applied to the RGB ``PIL.Image``
            (e.g. a torchvision transform pipeline).

    Raises:
        ValueError: if *image_paths* and *labels* differ in length.
    """

    def __init__(self, image_paths, labels, transform=None):
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths ({len(image_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load image *index* as RGB, apply the transform, and pair it with its label.

        The label is returned as a ``torch.long`` tensor of shape ``(1,)``.
        """
        image_path = self.image_paths[index]
        label = self.labels[index]
        # Open via a context manager so the file handle is released
        # deterministically; convert() returns a new, fully loaded image that
        # remains valid after the source file is closed.
        with Image.open(image_path) as img:
            image = img.convert('RGB')

        if self.transform is not None:
            image = self.transform(image)

        # unsqueeze(-1) turns the 0-d label into shape (1,).
        label = torch.tensor(label, dtype=torch.long).unsqueeze(-1)
        return image, label

In [4]:
from torchvision import transforms

# Preprocessing shared by all splits: resize to 128x128, convert to a float
# tensor in [0, 1], then map each channel from [0, 1] to [-1, 1].
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

# Resolve the dataset root relative to the directory appended to sys.path.
root_dir = os.path.join(path_append, dataset_path)

# Gather all image paths and split them (60/20/20 by default) into train/val/test.
train_data, val_data, test_data = gather_and_split_data(root_dir)

# Wrap each split in a Dataset; every split uses the same non-augmenting transform.
train_dataset = WasteDataset(*train_data, transform=transform)
val_dataset = WasteDataset(*val_data, transform=transform)
test_dataset = WasteDataset(*test_data, transform=transform)

# Sanity-check one sample: the image should be (3, 128, 128), the label (1,).
X, y = train_dataset[0]
print(X.shape)
print(y.shape)

torch.Size([3, 128, 128])
torch.Size([1])


In [5]:

# Dataset description handed to TrainerHub: 3x128x128 observations (matching
# the Resize/ToTensor transform) and 30 target classes.
# NOTE(review): label_size is hard-coded; it should equal the number of class
# folders under root_dir — confirm if the dataset changes.
data_config = DataConfig(
    dataset_name='recycle_image',
    task_type='multi_class_classification',
    obs_shape=[3, 128, 128],
    label_size=30,
)

# Training parameters: 'resnet18' as the CCNet network and encoder_network='none'
# (presumably meaning no separate encoder — confirm in MLParameters).
ml_params = MLParameters(ccnet_network='resnet18', encoder_network='none')

In [6]:
from trainer_hub import TrainerHub

# Prefer a CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the trainer from the ML parameters, data configuration and device.
# use_print / use_wandb presumably toggle console output and Weights & Biases
# logging respectively — confirm in TrainerHub.
trainer_hub = TrainerHub(ml_params, data_config, device, use_print=True, use_wandb=False)



In [7]:
# Train on the training split; the validation split is presumably used for
# evaluation during training — confirm in TrainerHub.train.
# NOTE(review): test_dataset is built above but not used in the cells shown here.
trainer_hub.train(train_dataset, val_dataset)