In [1]:
import torch
from torch.utils.data import Dataset
import os
import cv2 

def collate_fn(samples: list[dict]) -> dict:
    """Merge a list of per-sample dicts into a single batch dict.

    Each sample must carry an ``'image'`` tensor (all of identical shape) and a
    scalar ``'label'``. Images are stacked along a new leading batch dimension
    and the labels are packed into a 1-D tensor.

    Returns:
        ``{'image': Tensor(B, *image_shape), 'label': Tensor(B,)}``
    """
    image_list = [sample['image'] for sample in samples]
    label_list = [sample['label'] for sample in samples]

    return {
        'image': torch.stack(image_list, dim=0),
        'label': torch.tensor(label_list),
    }

class VinaFood(Dataset):
    """Image-classification dataset laid out as one sub-folder per class.

    Expected layout::

        path/
            <class_name_a>/  img1.jpg, img2.jpg, ...
            <class_name_b>/  ...

    Every readable image is decoded with OpenCV and kept in memory at
    construction time. ``__getitem__`` resizes it, converts BGR -> RGB and
    returns a float32 tensor of shape ``(3, H, W)`` scaled to ``[0, 1]``.

    Attributes:
        image_size: ``(width, height)`` passed as ``dsize`` to ``cv2.resize``.
        label2idx: class-folder name -> integer id, assigned in sorted
            folder-name order so the mapping is reproducible across runs,
            machines and dataset splits.
        idx2label: inverse mapping of ``label2idx``.
        data: list of ``{'image': ndarray, 'label': folder name}`` records.
    """

    def __init__(self, path: str, image_size: tuple[int, int]):
        super().__init__()

        self.image_size = image_size
        self.label2idx: dict[str, int] = {}
        self.idx2label: dict[int, str] = {}
        self.data: list[dict] = self.load_data(path)

    def load_data(self, path: str) -> list[dict]:
        """Scan ``path`` and load every readable image of every class folder.

        Class folders (and files inside them) are visited in sorted order:
        ``os.listdir`` order is arbitrary, and an unsorted walk would give
        train and test splits (or different machines) different label ids.
        Entries of ``path`` that are not directories are ignored, and files
        OpenCV cannot decode (``cv2.imread`` returns ``None``) are skipped so
        that ``__getitem__`` never receives ``None``.

        Returns:
            List of ``{'image': ndarray as returned by cv2.imread,
            'label': class-folder name}``. Also fills ``label2idx`` and
            ``idx2label``.
        """
        data = []
        print(f"Loading data from: {path}")
        for folder in sorted(os.listdir(path)):
            folder_path = os.path.join(path, folder)
            if not os.path.isdir(folder_path):
                # Stray files next to the class folders (README, .DS_Store, ...)
                # are not classes; os.listdir on them would raise.
                continue

            label = folder
            if label not in self.label2idx:
                self.label2idx[label] = len(self.label2idx)
            print(f"Processing folder: {folder} (label_id: {self.label2idx[label]})")

            for image_file in sorted(os.listdir(folder_path)):
                image_path = os.path.join(folder_path, image_file)
                image = cv2.imread(image_path)
                if image is None:
                    # Unreadable / non-image file: keeping it would make
                    # cv2.resize fail later in __getitem__.
                    continue
                data.append({
                    'image': image,
                    'label': label
                })

        self.idx2label = {idx: name for name, idx in self.label2idx.items()}
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx: int) -> dict:
        """Return ``{'image': float32 Tensor(3, H, W) in [0, 1], 'label': int id}``."""
        item = self.data[idx]

        # cv2.resize takes dsize as (width, height).
        image = cv2.resize(item['image'], self.image_size)
        # OpenCV loads images in BGR channel order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # HWC uint8 -> CHW float scaled to [0, 1].
        image = torch.tensor(image, dtype=torch.float32).permute(2, 0, 1) / 255.0
        return {
            'image': image,
            'label': self.label2idx[item['label']]
        }

In [2]:
import torch 
from torch import nn 
from torch.nn import functional as F

class LeNet(nn.Module):
    """LeNet-5 style CNN: two conv + average-pool stages, then three FC layers.

    Args:
        image_size: ``(width, height)`` of the input images (unpacked into
            ``self.w`` / ``self.h``). It determines the size of the first fully
            connected layer, so any input of at least 12x12 works; ``(28, 28)``
            gives the classic ``16 * 5 * 5`` flattened features.
        num_labels: number of output classes (width of the returned logits).
        in_channels: number of input image channels (default 1).

    Raises:
        ValueError: if ``image_size`` is too small to survive both conv/pool
            stages.
    """

    def __init__(self, image_size, num_labels, in_channels=1):
        super().__init__()
        self.w, self.h = image_size
        self.input_size = self.w * self.h

        feat_w = self._feature_dim(self.w)
        feat_h = self._feature_dim(self.h)
        if feat_w < 1 or feat_h < 1:
            raise ValueError(
                f"image_size={tuple(image_size)!r} is too small for LeNet; "
                "width and height must both be >= 12"
            )
        # Flattened feature count fed to FC1 (400 for 28x28 inputs).
        self.flat_features = 16 * feat_h * feat_w

        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=6,
            kernel_size=5,
            stride=1,
            padding=2
        )
        self.avg_pool = nn.AvgPool2d(
            kernel_size=2,
            stride=2,
        )
        self.conv2 = nn.Conv2d(
            in_channels=6,
            out_channels=16,
            kernel_size=5,
            stride=1,
            padding=0
        )
        self.FC1 = nn.Linear(self.flat_features, 120)
        self.FC2 = nn.Linear(120, 84)
        self.FC3 = nn.Linear(84, num_labels)
        self.relu = nn.ReLU()

    @staticmethod
    def _feature_dim(n):
        """Spatial size left from ``n`` after conv1 -> pool -> conv2 -> pool."""
        # conv1 (k=5, pad=2) keeps n; pool halves (floor); conv2 (k=5, pad=0)
        # removes 4; the second pool halves again (floor).
        return (n // 2 - 4) // 2

    def forward(self, x):
        """
        Args:
            x: ``Tensor(B, H, W)`` (a channel dim is inserted) or
               ``Tensor(B, C, H, W)`` with ``C == in_channels``, where
               ``(W, H)`` matches ``image_size``.

        Returns:
            Logits of shape ``(B, num_labels)``.
        """
        if x.dim() == 3:
            x = x.unsqueeze(1)  # (B, H, W) -> (B, 1, H, W)
        x = self.avg_pool(self.relu(self.conv1(x)))
        x = self.avg_pool(self.relu(self.conv2(x)))
        # Flatten per sample. A fixed view(-1, 400) would silently reshuffle
        # rows across the batch if the spatial size did not match.
        x = torch.flatten(x, start_dim=1)  # (B, flat_features)
        x = self.relu(self.FC1(x))
        x = self.relu(self.FC2(x))
        x = self.FC3(x)
        return x

In [3]:
from mnist_dataset import MNISTDataset, collate_fn
from torch.utils.data import DataLoader
from torch import nn 
import torch
import numpy as np 
from LeNet import LeNet
from sklearn.metrics import precision_score, recall_score, f1_score

device = "cpu"
SEED = 42
# Seed torch's global RNG so weight initialisation and DataLoader shuffling
# are reproducible across Restart & Run All.
torch.manual_seed(SEED)

# NOTE(review): `MNISTDataset`, `collate_fn` and `LeNet` used below are the ones
# imported from the `mnist_dataset` / `LeNet` modules at the top of this cell;
# those imports rebind the `collate_fn` and `LeNet` names defined in the
# earlier cells of this notebook.
train_dataset = MNISTDataset(
    image_path="train-images.idx3-ubyte",
    label_path="train-labels.idx1-ubyte"
)

test_dataset = MNISTDataset(
    image_path="t10k-images.idx3-ubyte",
    label_path="t10k-labels.idx1-ubyte"
)

# Prompt (Vietnamese): "Enter the model type (1 or 3): "
print("Hay nhap loai mo hinh (1 hoac 3): ")
model_type = input().strip()
if model_type == "1":
    model = LeNet(
        image_size=(28, 28),
        num_labels=10
    ).to(device)
elif model_type == "3":
    # Not implemented yet: fail fast here instead of a NameError on `model` below.
    raise NotImplementedError("Model type '3' is not implemented yet")
else:
    raise ValueError(f"Unknown model type {model_type!r}; expected '1' or '3'")

loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn
)

test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=32,
    shuffle=False,
    collate_fn=collate_fn
)

def evaluate(model, dataloader, device=None):
    """Compute macro-averaged recall, precision and F1 of ``model`` on ``dataloader``.

    Args:
        model: classifier whose forward returns logits of shape
            ``(B, num_classes)``; the predicted class is the arg-max.
        dataloader: iterable of batch dicts with ``"image"`` and ``"label"``
            tensors.
        device: device each image batch is moved to. Defaults to the device of
            the model's first parameter (``"cpu"`` if it has none).

    Returns:
        ``{"recall": ..., "precision": ..., "f1": ...}`` (macro averages).

    Note:
        Puts ``model`` in eval mode and leaves it there; call ``model.train()``
        before resuming training.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"

    model.eval()
    predicted = []
    targets = []
    # Evaluation needs no gradients; without no_grad() every forward pass
    # builds an autograd graph and keeps its activations alive.
    with torch.no_grad():
        for batch in dataloader:
            images = batch["image"].to(device)
            logits = model(images)
            predicted.extend(torch.argmax(logits, dim=-1).tolist())
            targets.extend(batch["label"].tolist())

    y_true = np.array(targets)
    y_pred = np.array(predicted)
    return {
        "recall": recall_score(y_true, y_pred, average="macro"),
        "precision": precision_score(y_true, y_pred, average="macro"),
        "f1": f1_score(y_true, y_pred, average="macro"),
    }


# Train for EPOCHS passes over the training set; after each epoch report the
# mean batch loss and the macro metrics on the test set.
EPOCHS = 10
for epoch_no in range(1, EPOCHS + 1):
    print(f"Epoch: {epoch_no}")

    model.train()
    batch_losses = []
    for batch in train_dataloader:
        inputs = batch["image"].to(device)
        targets = batch["label"].to(device)

        logits = model(inputs)
        # CrossEntropyLoss expects int64 class indices.
        batch_loss = loss_fn(logits, targets.long())
        batch_losses.append(batch_loss.item())

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

    print(f"Loss: {np.mean(batch_losses)}")
    epoch_metrics = evaluate(model, test_dataloader)
    for metric_name, metric_value in epoch_metrics.items():
        print(f"{metric_name}: {metric_value}")



Hay nhap loai mo hinh (1 hoac 3): 
Epoch: 1
Loss: 0.1345107184850766
recall: 0.9761101776184266
precision: 0.9765654543809559
f1: 0.9760230300121615
Epoch: 2
Loss: 0.05697057037660852
recall: 0.9858137967355193
precision: 0.9862839039568503
f1: 0.9860002372904061
Epoch: 3
Loss: 0.045977044524114656
recall: 0.9869929980735378
precision: 0.9871829546912523
f1: 0.98704699377796
Epoch: 4
Loss: 0.03840104632595709
recall: 0.9892916946204912
precision: 0.9892542302749401
f1: 0.9892608541624831
Epoch: 5
Loss: 0.030955076661509037
recall: 0.9852252384246466
precision: 0.9856666588411519
f1: 0.9853854957031574
Epoch: 6
Loss: 0.028377845545517387
recall: 0.9897196989534276
precision: 0.9899638726896841
f1: 0.9898247230185151
Epoch: 7
Loss: 0.02560418547791408
recall: 0.9884145943435371
precision: 0.9886125012714736
f1: 0.9884913739982795
Epoch: 8
Loss: 0.019512649226860475
recall: 0.9885948629768635
precision: 0.9887567727514291
f1: 0.9886670672102319
Epoch: 9
Loss: 0.021490124352066764
recall: 