In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split

transform = transforms.Compose([
    transforms.Resize((224, 224)),              # Resize to the size accepted by the CNN
    transforms.RandomHorizontalFlip(),          # Data Augmentation: Random Horizontal Flip
    transforms.ToTensor(),                      # transform into Tensor
    transforms.Normalize(                       # normalize
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])

dataset_path = '../datasets/trashnet'

# Automatically read subfolders and assign labels
full_dataset = datasets.ImageFolder(root=dataset_path, transform=transform)

# Check class indices (critical point)
print("Class Reflection:", full_dataset.class_to_idx)

# Partition the Training Set and Validation Set（in this case, 80% for training + 20% for Validation）
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size

train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

# Build Dataloader
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

Class Reflection: {'cardboard': 0, 'glass': 1, 'metal': 2, 'paper': 3, 'plastic': 4, 'trash': 5}


In [2]:
images, labels = next(iter(train_loader))
print("The shape of a batch of image: ", images.shape)
print("Corresponding Label: ", labels[:5])

The shape of a batch of image:  torch.Size([32, 3, 224, 224])
Corresponding Label:  tensor([3, 3, 3, 3, 3])


In [3]:
# 1. Load Models & Set device
import torch
import torchvision.models as models
import torch.nn as nn
import torch.optim as optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Use Device: ", device)

# 2. Load the model architecture and replace the fully connected (fc) layer (modify the structure first!)
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 6)  # 👈 firstly replace the last layer
model.load_state_dict(torch.load('../models/recycling_resnet18_15ep.pth', map_location=device))  # 👈 then load parameters
model = model.to(device)

# 3. Freeze the feature extraction layers
for param in model.parameters():
    param.requires_grad = False
for param in model.fc.parameters():  # need to train the fully connected layers
    param.requires_grad = True

# 4. Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.fc.parameters(), lr=0.001)

# 5. Continue training
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    train_acc = correct / total
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss:.4f} - Train Accuracy: {train_acc:.4f}")

# 6. Save the updated model
torch.save(model.state_dict(), 'recycling_resnet18.pth')

Use Device:  cpu
Epoch 1/5 - Loss: 27.1833 - Train Accuracy: 0.8530
Epoch 2/5 - Loss: 24.6134 - Train Accuracy: 0.8699
Epoch 3/5 - Loss: 23.9142 - Train Accuracy: 0.8718
Epoch 4/5 - Loss: 24.6633 - Train Accuracy: 0.8704
Epoch 5/5 - Loss: 22.9220 - Train Accuracy: 0.8728
