In [7]:
import os
import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as transforms
import csv


In [8]:
class MyDigitDataset(Dataset):
    """Labelled digit images read from a single directory.

    The label is parsed from the file name: the text before the first
    underscore must consist of digits only (``7_scan01.png`` -> label ``7``).

    Args:
        root_dir: Directory searched (non-recursively) for files matching
            ``*.jpg``, ``*.jpeg``, ``*.png``, ``*.heic`` or ``*.jfif``.
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.

    Unusable samples are not removed from the dataset. When the label prefix
    is not numeric, the file cannot be decoded, or the transform raises,
    ``__getitem__`` prints the error and returns an all-zero ``3x32x32``
    tensor with label ``-1`` so the caller can mask the sample out.
    """

    def __init__(self, root_dir, transform=None):
        self.transform = transform
        image_paths = []
        for ext in ('*.jpg', '*.jpeg', '*.png', '*.heic', '*.jfif'):
            image_paths.extend(glob.glob(os.path.join(root_dir, ext)))
        # glob yields matches in arbitrary, OS-dependent order; sorting keeps
        # the index -> sample mapping stable from one run to the next.
        self.image_paths = sorted(image_paths)

    def __len__(self):
        """Return the number of image files found, including unusable ones."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, or ``(zeros, -1)`` on failure."""
        img_path = self.image_paths[idx]
        try:
            basename = os.path.basename(img_path)
            label_part = basename.split('_')[0]

            # Reject files whose name does not start with a numeric label.
            if not label_part.isdigit():
                raise ValueError(f"Nhãn không hợp lệ: {label_part}")

            label = int(label_part)

            # The context manager releases the file handle as soon as the
            # pixels are decoded; convert() returns an independent image.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
            if self.transform:
                image = self.transform(image)
            return image, label
        except Exception as e:
            print(f"Lỗi đọc ảnh {img_path}: {e}")
            # Placeholder sample; label -1 marks it as invalid for the caller.
            return torch.zeros(3, 32, 32), -1

class MyTestDataset(Dataset):
    """Unlabelled images from one directory, returned with their file name.

    Args:
        image_dir: Directory whose entries are listed (non-recursively).
        transform: Optional callable applied to the RGB ``PIL.Image`` before
            it is returned.

    Each item is ``(image, file_name)``. If a file cannot be opened or
    decoded, the error is printed and an all-zero ``3x32x32`` tensor is
    returned in place of the image. Errors raised by ``transform`` are not
    caught.
    """

    # Accepted file suffixes, lower-case and including the leading dot.
    # These are plain suffixes for str.endswith, not glob patterns.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.heic', '.jfif')

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        # os.listdir order is arbitrary, so sort for a reproducible output
        # order; compare the lower-cased name so "IMG.JPG" is accepted too.
        self.image_paths = [os.path.join(image_dir, fname)
                            for fname in sorted(os.listdir(image_dir))
                            if fname.lower().endswith(self.IMAGE_EXTENSIONS)]
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, file_name)``; the image is zeros if unreadable."""
        img_path = self.image_paths[idx]
        try:
            # The context manager releases the file handle as soon as the
            # pixels are decoded; convert() returns an independent image.
            with Image.open(img_path) as img:
                image = img.convert('RGB')
        except Exception as e:
            print(f"Lỗi đọc ảnh {img_path}: {e}")
            return torch.zeros(3, 32, 32), os.path.basename(img_path)

        if self.transform:
            image = self.transform(image)

        return image, os.path.basename(img_path)


In [9]:
# Preprocessing used for both the training and the test dataset:
# resize every image to 32x32, convert it to a float tensor, then
# normalise each value as (x - 0.5) / 0.5.
transform = transforms.Compose([
    transforms.Resize(size=(32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])



In [None]:
# Locations of the image folders on the local machine.
train_dir = r'D:\Train2'
test_dir = r'D:\Test\data.2025'

# Training set: labels are parsed from the file names; batches are reshuffled
# on every pass over the loader.
train_dataset = MyDigitDataset(root_dir=train_dir, transform=transform)
print("Số lượng ảnh train:", len(train_dataset))
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)

# Test set: no labels, so the order is kept fixed and each image is paired
# with its file name.
test_dataset = MyTestDataset(image_dir=test_dir, transform=transform)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)


Số lượng ảnh train: 5473


In [11]:
class ANNModel(nn.Module):
    """Fully connected classifier producing 10 logits per input image.

    Each input is flattened to ``3 * 32 * 32`` features and passed through
    two ReLU hidden layers (512 and 256 units) followed by a linear layer
    with 10 outputs.
    """

    def __init__(self):
        super().__init__()
        in_features = 3 * 32 * 32
        self.flatten = nn.Flatten()
        # Layers are created in forward order so parameter names
        # (classifier.0, classifier.2, classifier.4) stay as they were.
        layers = [
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw class scores (no softmax applied) for batch ``x``."""
        return self.classifier(self.flatten(x))

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = ANNModel()
model = model.to(device)


In [12]:
# Train the classifier with cross-entropy loss and Adam, printing the mean
# per-sample loss after every epoch.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

num_epochs = 200
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-sample losses over this epoch
    total = 0  # number of valid samples seen this epoch

    for images, labels in train_loader:
        # Drop samples labelled -1 (unreadable image or malformed file name).
        mask = labels != -1
        if not mask.any():
            continue
        images, labels = images[mask], labels[mask]

        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the mean over the batch. Batches differ in size
        # (masking, last partial batch), so weight each batch by its sample
        # count to get a true per-sample average for the epoch.
        num_valid = labels.size(0)
        running_loss += loss.item() * num_valid
        total += num_valid

    avg_loss = running_loss / total if total > 0 else 0
    print(f"Epoch {epoch+1}, Loss: {avg_loss:.4f}")


Epoch 1, Loss: 2.3265
Epoch 2, Loss: 2.3044
Epoch 3, Loss: 2.2911
Epoch 4, Loss: 2.2785
Epoch 5, Loss: 2.2611
Epoch 6, Loss: 2.2359
Epoch 7, Loss: 2.1987
Epoch 8, Loss: 2.1703
Epoch 9, Loss: 2.1204
Epoch 10, Loss: 2.0623
Epoch 11, Loss: 2.0239
Epoch 12, Loss: 1.9757
Epoch 13, Loss: 1.9341
Epoch 14, Loss: 1.8866
Epoch 15, Loss: 1.8464
Epoch 16, Loss: 1.8250
Epoch 17, Loss: 1.8012
Epoch 18, Loss: 1.7500
Epoch 19, Loss: 1.7072
Epoch 20, Loss: 1.6735
Epoch 21, Loss: 1.6438
Epoch 22, Loss: 1.6461
Epoch 23, Loss: 1.5986
Epoch 24, Loss: 1.5589
Epoch 25, Loss: 1.5573
Epoch 26, Loss: 1.5066
Epoch 27, Loss: 1.5204
Epoch 28, Loss: 1.4688
Epoch 29, Loss: 1.4529
Epoch 30, Loss: 1.3985
Epoch 31, Loss: 1.4044
Epoch 32, Loss: 1.3603
Epoch 33, Loss: 1.3986
Epoch 34, Loss: 1.3354
Epoch 35, Loss: 1.3598
Epoch 36, Loss: 1.3008
Epoch 37, Loss: 1.2834
Epoch 38, Loss: 1.2619
Epoch 39, Loss: 1.2387
Epoch 40, Loss: 1.2165
Epoch 41, Loss: 1.2173
Epoch 42, Loss: 1.1834
Epoch 43, Loss: 1.1509
Epoch 44, Loss: 1.16

In [13]:
# Predict a class for every test image and collect [file name, class] rows.
model.eval()
results = []
with torch.no_grad():
    for images, paths in test_loader:
        images = images.to(device)
        outputs = model(images)
        # Index of the largest score along the class dimension.
        preds = outputs.max(dim=1).indices.cpu().numpy()
        results.extend(
            [os.path.basename(path), int(pred)]
            for path, pred in zip(paths, preds)
        )

# Write a header row followed by one row per test image.
with open('ANN_predictions.csv', 'w', newline='', encoding='utf-8') as f:
    writer = csv.writer(f)
    writer.writerows([['filename', 'prediction'], *results])


Lỗi đọc ảnh D:\Test\data.2025\21151445fd63ff79440449974f17109d.jpg: cannot identify image file 'D:\\Test\\data.2025\\21151445fd63ff79440449974f17109d.jpg'
Lỗi đọc ảnh D:\Test\data.2025\2c18ee0e7cea8354149df435532d74ae.jpeg: cannot identify image file 'D:\\Test\\data.2025\\2c18ee0e7cea8354149df435532d74ae.jpeg'
Lỗi đọc ảnh D:\Test\data.2025\3a816aa78f56749a0822d700ff560924.jpeg: cannot identify image file 'D:\\Test\\data.2025\\3a816aa78f56749a0822d700ff560924.jpeg'
Lỗi đọc ảnh D:\Test\data.2025\4abc1b5dcf1be1de6503dc072e132fa0.jpeg: cannot identify image file 'D:\\Test\\data.2025\\4abc1b5dcf1be1de6503dc072e132fa0.jpeg'
Lỗi đọc ảnh D:\Test\data.2025\70ebf0bec317006017a54d6c9172af45.jpeg: cannot identify image file 'D:\\Test\\data.2025\\70ebf0bec317006017a54d6c9172af45.jpeg'
Lỗi đọc ảnh D:\Test\data.2025\8eeee227b3f244e980b747387bc79bf2.jpeg: cannot identify image file 'D:\\Test\\data.2025\\8eeee227b3f244e980b747387bc79bf2.jpeg'
Lỗi đọc ảnh D:\Test\data.2025\9e748b9617e26b90011f8d7c3f8a7e