In [1]:
import torch
import copy
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from PIL import Image, ImageOps
# Workaround that lets the process continue when more than one OpenMP runtime
# gets loaded.
# NOTE(review): this is set after torch/numpy were already imported above; if
# the duplicate-runtime error still shows up, it may need to move before those
# imports -- confirm on the target machine.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

# Ask cuDNN to benchmark convolution algorithms and keep the fastest one.
cudnn.benchmark = True

# Switch matplotlib to interactive mode.
plt.ion()

In [2]:
class CustomImageDataset(Dataset):
    """Image dataset whose class label is encoded in each file name.

    The text before the first ``_`` of a file name is parsed as an integer
    label (for example ``"5_abc.jpg"`` has label 5). Only regular files whose
    label lies in ``0..9`` are kept; everything else in ``img_dir`` is ignored.

    Args:
        img_dir: Directory that contains the image files.
        transform: Optional callable applied to the loaded RGB ``PIL.Image``.
        target_transform: Optional callable applied to the integer label.

    Attributes:
        image_filenames: Sorted list of accepted file names (relative to
            ``img_dir``); index ``i`` of the dataset maps to entry ``i``.
    """

    def __init__(self, img_dir, transform=None, target_transform=None):
        self.img_dir = img_dir
        self.image_filenames = []
        self.transform = transform
        self.target_transform = target_transform

        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> sample mapping stable across runs and machines.
        for f in sorted(os.listdir(img_dir)):
            if self._parse_label(f) is None:
                continue  # name does not carry a valid 0..9 label
            if not os.path.isfile(os.path.join(img_dir, f)):
                continue  # e.g. a sub-directory named "5_something"
            self.image_filenames.append(f)

    @staticmethod
    def _parse_label(filename):
        """Return the 0..9 label encoded in ``filename``, or ``None`` if absent."""
        try:
            label = int(filename.split('_')[0])
        except ValueError:
            # Only a non-numeric prefix is expected here; any other error
            # should surface instead of being silently swallowed.
            return None
        return label if 0 <= label <= 9 else None

    def __len__(self):
        """Return the number of accepted image files."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is opened from disk, converted to RGB and passed through
        ``transform`` (if any); the label is passed through
        ``target_transform`` (if any).
        """
        img_name = self.image_filenames[idx]
        img_path = os.path.join(self.img_dir, img_name)

        # convert() returns a fully loaded copy, so the underlying file can be
        # closed as soon as the context manager exits.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        # File names were validated in __init__, so this always yields 0..9.
        label = self._parse_label(img_name)
        if self.target_transform:
            label = self.target_transform(label)

        return image, label

In [3]:
# Sanity check of the training folder: report every entry whose file-name
# prefix is not a label in 0..9 (the range CustomImageDataset accepts).
# Sorted so the report is reproducible between runs.
for f in sorted(os.listdir('E:/data/train')):
    try:
        label = int(f.split('_')[0])
    except ValueError:
        # Prefix before the first "_" is not an integer.
        print(f"Không đọc được nhãn từ: {f}")
        continue
    if label > 9:
        print(f"Lỗi nhãn >9: {f}")
    elif label < 0:
        # Negative labels are dropped by the dataset too, so report them.
        print(f"Lỗi nhãn <0: {f}")

In [4]:
# Per-channel mean/std used for normalisation in both splits.
# NOTE(review): presumably the ImageNet statistics matching the pretrained
# backbone -- confirm.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop/flip/colour augmentation, then normalise.
train_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    # Colour jitter to improve robustness against lighting/colour variation.
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# Validation pipeline: deterministic resize + centre crop, then normalise.
val_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

data_transforms = {'train': train_pipeline, 'val': val_pipeline}

# Datasets: one folder per split, labels come from the file names.
train_dataset = CustomImageDataset('E:/data/train', transform=data_transforms['train'])
val_dataset = CustomImageDataset('E:/data/val', transform=data_transforms['val'])

# Loaders: only the training split is shuffled.
loader_options = dict(batch_size=32, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

dataloaders = dict(train=train_loader, val=val_loader)
dataset_sizes = {
    split: len(split_dataset)
    for split, split_dataset in (('train', train_dataset), ('val', val_dataset))
}

# Use the current accelerator when one is available, otherwise the CPU.
if torch.accelerator.is_available():
    device = torch.accelerator.current_accelerator().type
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


Training

In [5]:
# ResNet-18 with IMAGENET1K_V1 weights; inputs stay 3-channel RGB.
pretrained_weights = models.ResNet18_Weights.IMAGENET1K_V1
model_conv = models.resnet18(weights=pretrained_weights)

# Freeze every existing parameter so that only the new head is trained.
for param in model_conv.parameters():
    param.requires_grad_(False)

# Swap the final fully connected layer for a fresh 10-class classifier.
num_ftrs = model_conv.fc.in_features
classifier_head = nn.Linear(num_ftrs, 10)
model_conv.fc = classifier_head
model_conv = model_conv.to(device)

# Loss, optimizer (new head's parameters only) and step-wise LR decay
# (x0.1 every 7 scheduler steps).
criterion = nn.CrossEntropyLoss()
head_parameters = model_conv.fc.parameters()
optimizer_conv = optim.SGD(head_parameters, lr=0.001, momentum=0.9)
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

In [6]:
def _run_phase(model, criterion, optimizer, loader, num_samples, device, training):
    """Run one full pass over ``loader`` and return ``(mean_loss, accuracy)``.

    When ``training`` is true the model is put in train mode and the optimizer
    is stepped after every batch; otherwise the model is put in eval mode and
    gradients are disabled. Both returned values are plain Python floats,
    obtained by dividing the accumulated totals by ``num_samples``.
    """
    if training:
        model.train()
    else:
        model.eval()

    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        if training:
            # Gradients are only produced in the training phase.
            optimizer.zero_grad()

        with torch.set_grad_enabled(training):
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

        # Weight the batch loss by the batch size.
        # NOTE(review): assumes `criterion` returns a per-batch mean -- confirm.
        running_loss += loss.item() * inputs.size(0)
        # .item() keeps the counter a Python int instead of a device tensor.
        running_corrects += (preds == labels).sum().item()

    return running_loss / num_samples, running_corrects / num_samples


def train_model(model, criterion, optimizer, scheduler, dataloaders, dataset_sizes, device, num_epochs=25, best_model_path='best_model.pth'):
    """Train ``model`` and return it with the best validation weights loaded.

    Each epoch runs a ``'train'`` phase followed by a ``'val'`` phase. Whenever
    validation accuracy exceeds the best value seen so far, the model's
    ``state_dict`` is written to ``best_model_path``; after the last epoch that
    checkpoint is loaded back into ``model``.

    Args:
        model: Network to train; moved to ``device`` in place.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per epoch, after the train phase.
        dataloaders: Mapping with ``'train'`` and ``'val'`` iterables that
            yield ``(inputs, labels)`` batches.
        dataset_sizes: Mapping with the number of samples for ``'train'`` and
            ``'val'``; used as the denominator of the epoch loss and accuracy.
        device: Device that the model and every batch are moved to.
        num_epochs: Number of epochs to run.
        best_model_path: File used to store the best checkpoint. Defaults to
            ``'best_model.pth'`` in the current working directory; pass a
            different path to avoid overwriting a previous run's checkpoint.

    Returns:
        The same ``model`` object, with the best checkpoint's weights loaded.
    """
    since = time.time()

    # Save the initial weights first so a checkpoint exists even if validation
    # accuracy never rises above 0.
    torch.save(model.state_dict(), best_model_path)
    best_acc = 0.0

    model.to(device)

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            epoch_loss, epoch_acc = _run_phase(
                model, criterion, optimizer, dataloaders[phase],
                dataset_sizes[phase], device, is_train,
            )

            if is_train:
                scheduler.step()

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Keep the checkpoint with the highest validation accuracy.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), best_model_path)
                print(f'>> SAVED best model (epoch {epoch}, acc {epoch_acc:.4f})')

        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Restore the best weights. map_location keeps the load independent of the
    # device the tensors were saved from; weights_only restricts unpickling to
    # tensor data, which is all a state_dict contains.
    best_state = torch.load(best_model_path, map_location=device, weights_only=True)
    model.load_state_dict(best_state)
    return model

In [7]:
# Launch training; `model_conv` is rebound to the model that train_model returns.
NUM_EPOCHS = 25
model_conv = train_model(
    model_conv,
    criterion,
    optimizer_conv,
    exp_lr_scheduler,
    dataloaders,
    dataset_sizes,
    device,
    num_epochs=NUM_EPOCHS,
)


Epoch 0/24
----------
train Loss: 2.3467 Acc: 0.1244
val Loss: 2.1740 Acc: 0.2163
>> SAVED best model (epoch 0, acc 0.2163)

Epoch 1/24
----------
train Loss: 2.1301 Acc: 0.2371
val Loss: 1.9631 Acc: 0.3282
>> SAVED best model (epoch 1, acc 0.3282)

Epoch 2/24
----------
train Loss: 1.9918 Acc: 0.2990
val Loss: 1.8607 Acc: 0.3609
>> SAVED best model (epoch 2, acc 0.3609)

Epoch 3/24
----------
train Loss: 1.8901 Acc: 0.3615
val Loss: 1.7280 Acc: 0.4135
>> SAVED best model (epoch 3, acc 0.4135)

Epoch 4/24
----------
train Loss: 1.8450 Acc: 0.3762
val Loss: 1.6732 Acc: 0.4384
>> SAVED best model (epoch 4, acc 0.4384)

Epoch 5/24
----------
train Loss: 1.7320 Acc: 0.4185
val Loss: 1.6239 Acc: 0.4507
>> SAVED best model (epoch 5, acc 0.4507)

Epoch 6/24
----------
train Loss: 1.7394 Acc: 0.4154
val Loss: 1.5795 Acc: 0.4688
>> SAVED best model (epoch 6, acc 0.4688)

Epoch 7/24
----------
train Loss: 1.6930 Acc: 0.4498
val Loss: 1.5482 Acc: 0.4939
>> SAVED best model (epoch 7, acc 0.4939)

