<a href="https://colab.research.google.com/github/nguyenanhtienabcd/AIO2024_EXERCISE/blob/feature%2FMODULE7-WEEK2/m07w02_ex3.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### import thư viện

In [50]:
# Import necessary libraries for file and system operations
import os

# Import PyTorch libraries for deep learning
import torch
import torch.nn as nn  # Neural network modules
import torch.optim as optim  # Optimization algorithms
import torch.nn.functional as F  # Activation functions, etc.
import torchvision  # Utilities for computer vision tasks
from torchvision import transforms, models  # Image transformations and pre-trained models

# Import libraries for data handling and analysis
import numpy as np  # Numerical computing
import pandas as pd  # Data manipulation and analysis

# Import libraries for visualization
import matplotlib.pyplot as plt  # Plotting
import seaborn as sns  # Statistical data visualization
import matplotlib.patches as patches  # Drawing shapes on plots

# Import libraries for working with XML data
import xml.etree.ElementTree as ET

# Import libraries for progress bars and utilities
import tqdm.notebook as tqdm  # Progress bars for Jupyter notebooks
from PIL import Image  # Image processing
from torch.utils.data import Dataset, DataLoader  # Custom dataset and data loading
from sklearn.metrics import confusion_matrix  # Model evaluation
from sklearn.model_selection import train_test_split  # Data splitting
from torchvision.models.resnet import ResNet18_Weights, ResNet50_Weights  # Pre-trained ResNet weights

import tqdm

In [51]:
import kagglehub

# Download latest version
data_dir = kagglehub.dataset_download("andrewmvd/dog-and-cat-detection")

print("Path to dataset files:", data_dir)


Path to dataset files: /root/.cache/kagglehub/datasets/andrewmvd/dog-and-cat-detection/versions/1


In [52]:
class MyDataset(Dataset):
    def __init__(self, annotations_dir, image_dir, transform=None):
        self.annotations_dir = annotations_dir
        self.image_dir = image_dir
        self.transform = transform
        self.image_files = self.filter_images_with_multiple_objects()

    def filter_images_with_multiple_objects(self):
        valid_image_files = []
        for f in os.listdir(self.image_dir):
            if os.path.isfile(os.path.join(self.image_dir, f)):
                img_name = f
                annotation_name = os.path.splitext(img_name)[0] + ".xml"
                annotation_path = os.path.join(self.annotations_dir, annotation_name)

                if self.count_objects_in_annotation(annotation_path) == 1:
                    valid_image_files.append(img_name)
        return valid_image_files

    def count_objects_in_annotation(self, annotation_path):
        try:
            tree = ET.parse(annotation_path)
            root = tree.getroot()
            count = 0
            for obj in root.findall("object"):
                count += 1
            return count
        except FileNotFoundError:
            return 0

    def parse_annotation(self, annotation_path):
        tree = ET.parse(annotation_path)
        root = tree.getroot()

        # Get image size for normalization
        image_width = int(root.find("size/width").text)
        image_height = int(root.find("size/height").text)

        label = None
        bbox = None
        for obj in root.findall("object"):
            name = obj.find("name").text
            if label is None:  # Take the first label
                label = name
                # Get bounding box coordinates
                xmin = int(obj.find("bndbox/xmin").text)
                ymin = int(obj.find("bndbox/ymin").text)
                xmax = int(obj.find("bndbox/xmax").text)
                ymax = int(obj.find("bndbox/ymax").text)

                # Normalize bbox coordinates to [0, 1]
                bbox = [
                    xmin / image_width,
                    ymin / image_height,
                    xmax / image_width,
                    ymax / image_height,
                ]

        # Convert label to numerical representation (0 for cat, 1 for dog)
        label_num = 0 if label == "cat" else 1 if label == "dog" else -1

        return label_num, torch.tensor(bbox, dtype=torch.float32)

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img1_file = self.image_files[idx]
        img1_path = os.path.join(self.image_dir, img1_file)

        annotation_name = os.path.splitext(img1_file)[0] + ".xml"
        img1_annotations = self.parse_annotation(
            os.path.join(self.annotations_dir, annotation_name)
        )

        if idx == len(self.image_files) - 1:
            idx2 = 0
        else:
            idx2 = idx + 1
        img2_file = self.image_files[idx2]
        img2_path = os.path.join(self.image_dir, img2_file)

        annotation_name = os.path.splitext(img2_file)[0] + ".xml"
        img2_annotations = self.parse_annotation(
            os.path.join(self.annotations_dir, annotation_name)
        )

        img1 = Image.open(img1_path).convert("RGB")
        img2 = Image.open(img2_path).convert("RGB")

        # Horizontal merge
        merged_image = Image.new(
            "RGB", (img1.width + img2.width, max(img1.height, img2.height))
        )
        merged_image.paste(img1, (0, 0))
        merged_image.paste(img2, (img1.width, 0))
        merged_w = img1.width + img2.width
        merged_h = max(img1.height, img2.height)

        merged_annotations = []

        # Adjust bbox coordinates for objects from img1 AND normalize
        new_bbox1 = [
            img1_annotations[1][0] * img1.width / merged_w,  # Normalize xmin
            img1_annotations[1][1] * img1.height / merged_h,  # Normalize ymin
            img1_annotations[1][2] * img1.width / merged_w,  # Normalize xmax
            img1_annotations[1][3] * img1.height / merged_h,  # Normalize ymax
        ]
        merged_annotations.append({"bbox": new_bbox1, "label": img1_annotations[0]})

        # Adjust bbox coordinates for objects from img2 AND normalize
        new_bbox2 = [
            (img2_annotations[1][0] * img2.width + img1.width)
            / merged_w,  # Normalize xmin
            img2_annotations[1][1] * img2.height / merged_h,  # Normalize ymin
            (img2_annotations[1][2] * img2.width + img1.width)
            / merged_w,  # Normalize xmax
            img2_annotations[1][3] * img2.height / merged_h,  # Normalize ymax
        ]

        merged_annotations.append({"bbox": new_bbox2, "label": img2_annotations[0]})

        # Convert merged image to tensor
        if self.transform:
            merged_image = self.transform(merged_image)
        else:
            merged_image = transforms.ToTensor()(merged_image)

        # Convert annotations to 1D tensors, with shape (4,) for bbox and (1,) for label
        annotations = torch.zeros((len(merged_annotations), 5))
        for i, ann in enumerate(merged_annotations):
            annotations[i] = torch.cat(
                (torch.tensor(ann["bbox"]), torch.tensor([ann["label"]]))
            )

        return merged_image, annotations


In [53]:
class SimpleYOLO(nn.Module):
    def __init__(self, num_classes):
        super(SimpleYOLO, self).__init__()
        self.backbone = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        self.num_classes = num_classes

        # Remove the final classification layer of ResNet
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])

        # Add the YOLO head
        self.fcs = nn.Linear(
            2048, 2 * 2 * (4 + self.num_classes)
        )  # 2 is for the number of grid cell

    def forward(self, x):
        # x shape: (batch_size, C, H, W)
        features = self.backbone(x)
        features = F.adaptive_avg_pool2d(
            features, (1, 1)
        )  # shape: (batch_size, 2048, 1, 1)
        features = features.view(features.size(0), -1)  # shape: (batch_size, 2048)
        features = self.fcs(features)

        return features


In [54]:
# khởi tạo model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 2
class_to_idx = {"cat": 0, "dog": 1}
model = SimpleYOLO(num_classes=2).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [55]:
def calculate_loss(output, targets, device, num_classes):
    mse_loss = nn.MSELoss()
    ce_loss = nn.CrossEntropyLoss()

    batch_size = output.shape[0]
    total_loss = 0

    for i in range(batch_size):  # Iterate through each image in the batch
        grid_have_object = torch.zeros((batch_size, 2, 2), device=device)

        for j in range(len(targets[i])):  # Iterate through objects in the image
            # Determine which grid cell the object's center falls into
            # Assuming bbox coordinates are normalized to [0, 1]
            bbox_center_x = (targets[i][j][0] + targets[i][j][2]) / 2
            bbox_center_y = (targets[i][j][1] + targets[i][j][3]) / 2

            grid_x = int(
                bbox_center_x * 2
            )  # Multiply by number of grid cells (2 in this case)
            grid_y = int(bbox_center_y * 2)

            grid_have_object[i, grid_y, grid_x] = 1

            # Calculate the starting index for the current grid cell's predictions
            grid_cell_index = (grid_y * 2 + grid_x) * (4 + num_classes)

            # 1. Classification Loss for the responsible grid cell
            # Convert label to one-hot encoding only for this example
            label_one_hot = torch.zeros(num_classes, device=device)
            label_one_hot[int(targets[i][j][4])] = 1

            # Classification loss (using CrossEntropyLoss)
            classification_loss = ce_loss(output[i, grid_cell_index + 4 : grid_cell_index + 4 + num_classes], label_one_hot) # Reshape the output tensor

            # 2. Regression Loss for the responsible grid cell
            bbox_target = targets[i][j][:4].to(device)
            regression_loss = mse_loss(output[i, grid_cell_index : grid_cell_index + 4], bbox_target) # Reshape the output tensor


            total_loss += classification_loss + regression_loss

        # 3. No Object Loss (for other grid cells)
        no_obj_loss = 0
        for other_grid_y in range(2):
            for other_grid_x in range(2):
                if grid_have_object[i, other_grid_y, other_grid_x] == 0:
                    # Calculate the starting index for the other grid cell's predictions
                    other_grid_cell_index = (other_grid_y * 2 + other_grid_x) * (4 + num_classes)

                    # MSE loss for predicting no object (all zeros)
                    no_obj_loss += mse_loss(
                        output[i, other_grid_cell_index: other_grid_cell_index + 4],
                        torch.zeros(4, device=device),
                    )

        total_loss += no_obj_loss

    return total_loss / batch_size  # Average loss over the batch

In [61]:
def evaluate_model(model, data_loader, device, num_classes):
    model.eval()
    running_loss = 0.0
    all_predictions = []
    all_targets = []

    with torch.no_grad():
        for images, targets in tqdm.tqdm(data_loader, desc="Validation", leave=False):
            images = images.to(device)
            targets = [target.to(device) for target in targets]
            outputs = model(images)
            total_loss = calculate_loss(outputs, targets, device, num_classes)
            running_loss += total_loss.item()

            for batch_idx in range(images.shape[0]):  # Lặp qua từng ảnh trong batch
                for target in targets[batch_idx]:  # Lặp qua từng object trong ảnh
                    # Determine which grid cell the object's center falls into
                    # Assuming bbox coordinates are normalized to [0, 1]
                    bbox_center_x = (target[0] + target[2]) / 2
                    bbox_center_y = (target[1] + target[3]) / 2
                    grid_x = int(bbox_center_x * 2)
                    grid_y = int(bbox_center_y * 2)

                    # Calculate the starting index for the current grid cell's predictions
                    grid_cell_index = (grid_y * 2 + grid_x) * (4 + num_classes)

                    # Dự đoán class (chọn class có xác suất cao nhất)
                    # Get predictions for the current grid cell
                    grid_cell_predictions = outputs[
                        batch_idx, grid_cell_index + 4 : grid_cell_index + 4 + num_classes
                    ]

                    prediction = grid_cell_predictions.argmax()  # Get predicted class

                    all_predictions.append(prediction.item())  # Lưu dự đoán
                    all_targets.append(target[4].item())  # Lưu nhãn thực tế

            # tính toán loss
            val_loss = running_loss / len(data_loader)

            all_predictions = torch.tensor(all_predictions)
            all_targets = torch.tensor(all_targets)

            # Tính toán accuracy
            val_accuracy = (all_predictions == all_targets).float().mean()
            return val_loss, val_accuracy.item()

In [62]:
def train_model(model, train_loader, val_loader, optimizer, num_epochs, device, num_classes):
    best_val_accuracy = 0.0  # Lưu độ chính xác cao nhất
    train_losses, val_losses = [], []  # Danh sách lưu Loss
    val_accuracies = []  # Lưu độ chính xác

    for epoch in tqdm.tqdm(range(num_epochs), desc="Epochs"):
        model.train()  # Đặt mô hình vào chế độ train
        running_loss = 0.0

        # Huấn luyện từng batch trong train_loader
        for images, targets in tqdm.tqdm(train_loader, desc="Batches", leave=False):
            images = images.to(device)

            optimizer.zero_grad()  # Reset gradient
            output = model(images)  # Forward pass

            # Tính loss
            total_loss = calculate_loss(output, targets, device, num_classes)

            # Try reducing the batch size
            # If it still fails, consider detaching the output from the computation graph and clearing the graph occasionally.
            # This can help conserve memory.
            try:
                total_loss.backward()  # Backpropagation
            except RuntimeError as e:
                if "cuDNN error: CUDNN_STATUS_INTERNAL_ERROR" in str(e):
                    print("Encountered cuDNN error. Trying to reduce batch size or clear CUDA cache.")
                    # You could potentially clear the CUDA cache using torch.cuda.empty_cache()
                    torch.cuda.empty_cache()
                    # Also consider detaching the output from the graph before backward:
                    # total_loss = total_loss.detach().cpu()
                    # total_loss.backward()
                    # Continue training if successful or raise the exception otherwise
                else:
                    raise e # Re-raise the exception if it's not the cuDNN error


            optimizer.step()  # Cập nhật trọng số

            running_loss += total_loss.item()  # Cộng dồn loss

        # Tính loss trung bình của epoch
        epoch_loss = running_loss / len(train_loader)
        train_losses.append(epoch_loss)

        # Đánh giá mô hình trên tập validation
        val_loss, val_accuracy = evaluate_model(model, val_loader, device, num_classes)
        val_losses.append(val_loss)
        val_accuracies.append(val_accuracy)

        # In kết quả sau mỗi epoch
        print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f}, "
              f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")

        # Lưu mô hình tốt nhất
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), "best_model.pth")
            print(" Saved best model with accuracy:", best_val_accuracy)

    return train_losses, val_losses, val_accuracies

In [63]:
# Define paths
from torch.utils.data import DataLoader, Dataset, Subset
batch_size = 8

annotations_dir = os.path.join(data_dir, "annotations")
image_dir = os.path.join(data_dir, "images")

# Define transformations
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# First create two base datasets with different transforms
base_dataset = MyDataset(annotations_dir, image_dir, transform=transform)

# Get the total dataset size
dataset_size = len(base_dataset)

# Calculate train and validation sizes
val_size = int(0.2 * dataset_size)
train_size = dataset_size - val_size

# Generate indices
indices = np.arange(dataset_size)
np.random.seed(42)  # Ensure reproducibility
np.random.shuffle(indices)

# Split indices for train and validation sets
train_indices, val_indices = indices[:train_size], indices[train_size:]
# train_indices, val_indices = indices[:1], indices[:1]

# Create Subsets using the appropriate base dataset
train_dataset = Subset(base_dataset, train_indices)
val_dataset = Subset(base_dataset, val_indices)

# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=4,
    persistent_workers=True,
    drop_last=True,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    pin_memory=True,
    num_workers=4,
    persistent_workers=True,
    drop_last=False,
)

# Initialize model, criterion, and optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_classes = 2  # Assuming two classes: dog and cat
class_to_idx = {"dog": 0, "cat": 1}


In [64]:
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

def inference(model, image_path, transform, device, class_to_idx, threshold=0.5):
    """
    Hàm thực hiện dự đoán bounding box và class trên một hình ảnh đầu vào.

    Args:
        model: Mô hình đã huấn luyện.
        image_path (str): Đường dẫn đến hình ảnh.
        transform: Các phép biến đổi ảnh giống lúc training.
        device: Thiết bị sử dụng (CPU/GPU).
        class_to_idx (dict): Dictionary ánh xạ tên class sang index.
        threshold (float): Ngưỡng xác suất để hiển thị bounding box.

    Returns:
        Hiển thị hình ảnh với bounding box và class được dự đoán.
    """
    model.eval()  # Đặt mô hình vào chế độ đánh giá (evaluation mode)

    # 1️⃣ Đọc ảnh và lấy kích thước gốc
    image = Image.open(image_path).convert("RGB")
    original_width, original_height = image.size

    # 2️⃣ Resize ảnh về kích thước phù hợp với mô hình
    resized_image = image.resize((448, 448))
    resized_width, resized_height = resized_image.size

    # 3️⃣ Áp dụng biến đổi giống lúc training
    transformed_image = transform(resized_image).unsqueeze(0).to(device)

    # 4️⃣ Thực hiện dự đoán
    with torch.no_grad():
        output = model(transformed_image)  # Forward pass
        output = output.view(1, 2, 2, 4 + len(class_to_idx))  # Reshape thành 2x2 grid

    # 5️⃣ Hiển thị ảnh gốc đã resize
    fig, ax = plt.subplots(1)
    ax.axis("off")
    ax.imshow(resized_image)

    # 6️⃣ Duyệt qua từng ô (grid cell) trong lưới 2x2
    for grid_y in range(2):
        for grid_x in range(2):
            # Lấy class và bounding box tại ô (grid_x, grid_y)
            class_pred = output[0, grid_y, grid_x, 4:].argmax().item()
            bbox = output[0, grid_y, grid_x, :4].tolist()

            # Xác suất của class
            confidence = torch.softmax(output[0, grid_y, grid_x, 4:], dim=0)[class_pred].item()

            # 7️⃣ Chuyển bounding box từ tọa độ chuẩn hóa về ảnh gốc
            x_min = bbox[0] * (resized_width / 2) + grid_x * (resized_width / 2)
            y_min = bbox[1] * (resized_height / 2) + grid_y * (resized_height / 2)
            x_max = bbox[2] * (resized_width / 2) + grid_x * (resized_width / 2)
            y_max = bbox[3] * (resized_height / 2) + grid_y * (resized_height / 2)

            # 8️⃣ Vẽ bounding box và hiển thị label nếu xác suất cao hơn `threshold`
            if confidence > threshold:
                rect = patches.Rectangle(
                    (x_min, y_min),
                    x_max - x_min,
                    y_max - y_min,
                    linewidth=2, edgecolor="r",
                    facecolor="none"
                )
                ax.add_patch(rect)

                plt.text(
                    x_min,
                    y_min,
                    f"{list(class_to_idx.keys())[class_pred]}: {confidence:.2f}",
                    color="white",
                    fontsize=12,
                    bbox=dict(facecolor="red", alpha=0.5),
                )

    # 9️⃣ Hiển thị ảnh với bounding box và nhãn
    plt.show()



In [None]:
num_epochs = 10 # You can adjust the number of epochs as needed
train_losses, val_losses, val_accuracies = train_model(model, train_loader, val_loader, optimizer, num_epochs, device, num_classes)


Epochs:   0%|          | 0/10 [00:00<?, ?it/s]
Batches:   0%|          | 0/368 [00:00<?, ?it/s][A
Batches:   0%|          | 1/368 [00:00<05:13,  1.17it/s][A
Batches:   1%|          | 2/368 [00:01<02:48,  2.18it/s][A
Batches:   1%|          | 3/368 [00:01<02:11,  2.77it/s][A
Batches:   1%|          | 4/368 [00:01<01:55,  3.15it/s][A
Batches:   1%|▏         | 5/368 [00:01<01:48,  3.35it/s][A
Batches:   2%|▏         | 6/368 [00:02<01:39,  3.64it/s][A
Batches:   2%|▏         | 7/368 [00:02<01:26,  4.17it/s][A
Batches:   2%|▏         | 8/368 [00:02<01:15,  4.80it/s][A
Batches:   2%|▏         | 9/368 [00:02<01:19,  4.51it/s][A
Batches:   3%|▎         | 10/368 [00:02<01:30,  3.97it/s][A
Batches:   3%|▎         | 11/368 [00:03<01:25,  4.20it/s][A
Batches:   3%|▎         | 12/368 [00:03<01:16,  4.67it/s][A
Batches:   4%|▎         | 13/368 [00:03<01:28,  4.01it/s][A
Batches:   4%|▍         | 14/368 [00:03<01:24,  4.17it/s][A
Batches:   4%|▍         | 15/368 [00:04<01:18,  4.50it/s

Epoch 1/10, Train Loss: 0.5825, Validation Loss: 0.0034, Validation Accuracy: 1.0000


Epochs:  10%|█         | 1/10 [01:17<11:35, 77.31s/it]

 Saved best model with accuracy: 1.0



Batches:   0%|          | 0/368 [00:00<?, ?it/s][A
Batches:   0%|          | 1/368 [00:00<04:32,  1.35it/s][A
Batches:   1%|          | 2/368 [00:00<02:30,  2.44it/s][A
Batches:   1%|          | 3/368 [00:01<01:52,  3.23it/s][A
Batches:   1%|          | 4/368 [00:01<01:35,  3.81it/s][A
Batches:   1%|▏         | 5/368 [00:01<01:20,  4.50it/s][A
Batches:   2%|▏         | 6/368 [00:01<01:12,  4.97it/s][A
Batches:   2%|▏         | 7/368 [00:01<01:07,  5.34it/s][A
Batches:   2%|▏         | 8/368 [00:01<01:04,  5.62it/s][A
Batches:   2%|▏         | 9/368 [00:02<01:05,  5.45it/s][A
Batches:   3%|▎         | 10/368 [00:02<01:00,  5.89it/s][A
Batches:   3%|▎         | 11/368 [00:02<00:57,  6.17it/s][A
Batches:   3%|▎         | 12/368 [00:02<00:55,  6.46it/s][A
Batches:   4%|▎         | 13/368 [00:02<00:54,  6.50it/s][A
Batches:   4%|▍         | 14/368 [00:02<00:53,  6.62it/s][A
Batches:   4%|▍         | 15/368 [00:02<00:50,  6.96it/s][A
Batches:   4%|▍         | 16/368 [00:03<0

Epoch 2/10, Train Loss: 0.4654, Validation Loss: 0.0024, Validation Accuracy: 1.0000



Batches:   0%|          | 0/368 [00:00<?, ?it/s][A
Batches:   0%|          | 1/368 [00:01<06:15,  1.02s/it][A
Batches:   1%|          | 2/368 [00:01<03:21,  1.82it/s][A
Batches:   1%|          | 3/368 [00:01<02:13,  2.74it/s][A
Batches:   1%|          | 4/368 [00:01<01:51,  3.27it/s][A
Batches:   1%|▏         | 5/368 [00:01<01:36,  3.75it/s][A
Batches:   2%|▏         | 6/368 [00:01<01:24,  4.27it/s][A
Batches:   2%|▏         | 7/368 [00:02<01:19,  4.57it/s][A
Batches:   2%|▏         | 8/368 [00:02<01:12,  4.98it/s][A
Batches:   2%|▏         | 9/368 [00:02<01:13,  4.88it/s][A
Batches:   3%|▎         | 10/368 [00:02<01:09,  5.17it/s][A
Batches:   3%|▎         | 11/368 [00:02<01:04,  5.54it/s][A
Batches:   3%|▎         | 12/368 [00:03<01:01,  5.82it/s][A
Batches:   4%|▎         | 13/368 [00:03<00:58,  6.06it/s][A
Batches:   4%|▍         | 14/368 [00:03<00:54,  6.55it/s][A
Batches:   4%|▍         | 15/368 [00:03<00:52,  6.71it/s][A
Batches:   4%|▍         | 16/368 [00:03<0

Epoch 3/10, Train Loss: 0.3652, Validation Loss: 0.0012, Validation Accuracy: 1.0000



Batches:   0%|          | 0/368 [00:00<?, ?it/s][A
Batches:   0%|          | 1/368 [00:01<10:42,  1.75s/it][A
Batches:   1%|          | 2/368 [00:02<05:39,  1.08it/s][A
Batches:   1%|          | 3/368 [00:02<04:02,  1.51it/s][A
Batches:   1%|          | 4/368 [00:02<03:12,  1.89it/s][A
Batches:   1%|▏         | 5/368 [00:03<02:39,  2.28it/s][A
Batches:   2%|▏         | 6/368 [00:03<02:19,  2.59it/s][A
Batches:   2%|▏         | 7/368 [00:03<02:08,  2.82it/s][A
Batches:   2%|▏         | 8/368 [00:03<02:07,  2.83it/s][A
Batches:   2%|▏         | 9/368 [00:04<02:04,  2.88it/s][A
Batches:   3%|▎         | 10/368 [00:04<01:58,  3.02it/s][A
Batches:   3%|▎         | 11/368 [00:04<01:51,  3.20it/s][A
Batches:   3%|▎         | 12/368 [00:05<01:51,  3.20it/s][A
Batches:   4%|▎         | 13/368 [00:05<01:49,  3.25it/s][A
Batches:   4%|▍         | 14/368 [00:05<01:41,  3.48it/s][A
Batches:   4%|▍         | 15/368 [00:06<01:42,  3.44it/s][A
Batches:   4%|▍         | 16/368 [00:06<0

Epoch 4/10, Train Loss: 0.2843, Validation Loss: 0.0020, Validation Accuracy: 1.0000



Batches:   0%|          | 0/368 [00:00<?, ?it/s][A
Batches:   0%|          | 1/368 [00:01<10:01,  1.64s/it][A
Batches:   1%|          | 2/368 [00:02<05:27,  1.12it/s][A
Batches:   1%|          | 3/368 [00:02<03:54,  1.55it/s][A
Batches:   1%|          | 4/368 [00:02<03:08,  1.93it/s][A
Batches:   1%|▏         | 5/368 [00:02<02:32,  2.39it/s][A
Batches:   2%|▏         | 6/368 [00:03<02:12,  2.74it/s][A
Batches:   2%|▏         | 7/368 [00:03<01:59,  3.02it/s][A
Batches:   2%|▏         | 8/368 [00:03<01:51,  3.24it/s][A
Batches:   2%|▏         | 9/368 [00:03<01:45,  3.39it/s][A
Batches:   3%|▎         | 10/368 [00:04<01:43,  3.47it/s][A
Batches:   3%|▎         | 11/368 [00:04<01:41,  3.53it/s][A
Batches:   3%|▎         | 12/368 [00:04<01:40,  3.53it/s][A
Batches:   4%|▎         | 13/368 [00:05<01:39,  3.57it/s][A
Batches:   4%|▍         | 14/368 [00:05<01:32,  3.84it/s][A
Batches:   4%|▍         | 15/368 [00:05<01:30,  3.91it/s][A
Batches:   4%|▍         | 16/368 [00:05<0

Epoch 5/10, Train Loss: 0.2456, Validation Loss: 0.0067, Validation Accuracy: 0.9375



Batches:   0%|          | 0/368 [00:00<?, ?it/s][A
Batches:   0%|          | 1/368 [00:01<06:20,  1.04s/it][A
Batches:   1%|          | 2/368 [00:01<03:24,  1.79it/s][A
Batches:   1%|          | 3/368 [00:01<02:19,  2.62it/s][A
Batches:   1%|          | 4/368 [00:01<01:45,  3.46it/s][A
Batches:   1%|▏         | 5/368 [00:01<01:35,  3.80it/s][A
Batches:   2%|▏         | 6/368 [00:01<01:26,  4.17it/s][A
Batches:   2%|▏         | 7/368 [00:02<01:17,  4.66it/s][A
Batches:   2%|▏         | 8/368 [00:02<01:14,  4.86it/s][A
Batches:   2%|▏         | 9/368 [00:02<01:06,  5.44it/s][A
Batches:   3%|▎         | 10/368 [00:02<01:03,  5.65it/s][A
Batches:   3%|▎         | 11/368 [00:02<00:59,  5.95it/s][A
Batches:   3%|▎         | 12/368 [00:02<00:58,  6.09it/s][A
Batches:   4%|▎         | 13/368 [00:03<01:02,  5.69it/s][A
Batches:   4%|▍         | 14/368 [00:03<00:58,  6.01it/s][A
Batches:   4%|▍         | 15/368 [00:03<00:55,  6.37it/s][A
Batches:   4%|▍         | 16/368 [00:03<0

Epoch 6/10, Train Loss: 0.2194, Validation Loss: 0.0019, Validation Accuracy: 0.9375



Batches:   0%|          | 0/368 [00:00<?, ?it/s][A
Batches:   0%|          | 1/368 [00:01<09:59,  1.63s/it][A
Batches:   1%|          | 2/368 [00:01<05:16,  1.16it/s][A
Batches:   1%|          | 3/368 [00:02<03:28,  1.75it/s][A
Batches:   1%|          | 4/368 [00:02<02:32,  2.39it/s][A
Batches:   1%|▏         | 5/368 [00:02<02:02,  2.97it/s][A
Batches:   2%|▏         | 6/368 [00:02<01:42,  3.52it/s][A
Batches:   2%|▏         | 7/368 [00:02<01:25,  4.24it/s][A
Batches:   2%|▏         | 8/368 [00:03<01:20,  4.45it/s][A
Batches:   2%|▏         | 9/368 [00:03<01:17,  4.65it/s][A
Batches:   3%|▎         | 10/368 [00:03<01:09,  5.15it/s][A
Batches:   3%|▎         | 11/368 [00:03<01:07,  5.29it/s][A
Batches:   3%|▎         | 12/368 [00:03<01:03,  5.65it/s][A
Batches:   4%|▎         | 13/368 [00:03<01:01,  5.78it/s][A
Batches:   4%|▍         | 14/368 [00:04<00:58,  6.06it/s][A
Batches:   4%|▍         | 15/368 [00:04<00:55,  6.38it/s][A
Batches:   4%|▍         | 16/368 [00:04<0

Epoch 7/10, Train Loss: 0.1888, Validation Loss: 0.0011, Validation Accuracy: 1.0000



Batches:   0%|          | 0/368 [00:00<?, ?it/s][A
Batches:   0%|          | 1/368 [00:02<12:35,  2.06s/it][A
Batches:   1%|          | 2/368 [00:02<06:25,  1.05s/it][A
Batches:   1%|          | 3/368 [00:02<04:20,  1.40it/s][A
Batches:   1%|          | 4/368 [00:03<03:17,  1.84it/s][A
Batches:   1%|▏         | 5/368 [00:03<02:40,  2.26it/s][A
Batches:   2%|▏         | 6/368 [00:03<02:27,  2.46it/s][A
Batches:   2%|▏         | 7/368 [00:03<02:06,  2.86it/s][A
Batches:   2%|▏         | 8/368 [00:04<01:55,  3.13it/s][A
Batches:   2%|▏         | 9/368 [00:04<01:50,  3.23it/s][A
Batches:   3%|▎         | 10/368 [00:04<01:50,  3.24it/s][A
Batches:   3%|▎         | 11/368 [00:04<01:45,  3.39it/s][A
Batches:   3%|▎         | 12/368 [00:05<01:42,  3.46it/s][A
Batches:   4%|▎         | 13/368 [00:05<01:38,  3.62it/s][A
Batches:   4%|▍         | 14/368 [00:05<01:36,  3.69it/s][A
Batches:   4%|▍         | 15/368 [00:05<01:35,  3.70it/s][A
Batches:   4%|▍         | 16/368 [00:06<0

Epoch 8/10, Train Loss: 0.1768, Validation Loss: 0.0034, Validation Accuracy: 0.8750



Batches:   0%|          | 0/368 [00:00<?, ?it/s][A
Batches:   0%|          | 1/368 [00:00<05:29,  1.11it/s][A
Batches:   1%|          | 2/368 [00:01<03:05,  1.97it/s][A
Batches:   1%|          | 3/368 [00:01<02:21,  2.59it/s][A
Batches:   1%|          | 4/368 [00:01<01:52,  3.23it/s][A
Batches:   1%|▏         | 5/368 [00:01<01:32,  3.92it/s][A
Batches:   2%|▏         | 6/368 [00:01<01:19,  4.57it/s][A
Batches:   2%|▏         | 7/368 [00:02<01:10,  5.09it/s][A
Batches:   2%|▏         | 8/368 [00:02<01:06,  5.45it/s][A
Batches:   2%|▏         | 9/368 [00:02<01:12,  4.98it/s][A
Batches:   3%|▎         | 10/368 [00:02<01:13,  4.86it/s][A
Batches:   3%|▎         | 11/368 [00:02<01:09,  5.12it/s][A
Batches:   3%|▎         | 12/368 [00:02<01:07,  5.29it/s][A
Batches:   4%|▎         | 13/368 [00:03<00:59,  6.00it/s][A
Batches:   4%|▍         | 14/368 [00:03<00:59,  5.92it/s][A
Batches:   4%|▍         | 15/368 [00:03<00:55,  6.33it/s][A
Batches:   4%|▍         | 16/368 [00:03<0

Epoch 9/10, Train Loss: 0.1896, Validation Loss: 0.0010, Validation Accuracy: 1.0000



Batches:   0%|          | 0/368 [00:00<?, ?it/s][A
Batches:   0%|          | 1/368 [00:01<10:28,  1.71s/it][A
Batches:   1%|          | 2/368 [00:02<05:34,  1.09it/s][A
Batches:   1%|          | 3/368 [00:02<03:38,  1.67it/s][A
Batches:   1%|          | 4/368 [00:02<02:38,  2.29it/s][A
Batches:   1%|▏         | 5/368 [00:02<02:02,  2.97it/s][A
Batches:   2%|▏         | 6/368 [00:02<01:41,  3.57it/s][A
Batches:   2%|▏         | 7/368 [00:03<01:31,  3.93it/s][A
Batches:   2%|▏         | 8/368 [00:03<01:22,  4.37it/s][A
Batches:   2%|▏         | 9/368 [00:03<01:17,  4.66it/s][A
Batches:   3%|▎         | 10/368 [00:03<01:11,  5.03it/s][A
Batches:   3%|▎         | 11/368 [00:03<01:07,  5.25it/s][A
Batches:   3%|▎         | 12/368 [00:03<01:05,  5.41it/s][A
Batches:   4%|▎         | 13/368 [00:04<01:00,  5.91it/s][A
Batches:   4%|▍         | 14/368 [00:04<00:59,  6.00it/s][A
Batches:   4%|▍         | 15/368 [00:04<00:55,  6.37it/s][A
Batches:   4%|▍         | 16/368 [00:04<0

In [None]:
model.load_state_dict(torch.load("best_model.pth", map_location=torch.device("cpu")))

# 🔹 Định nghĩa thông tin class
class_to_idx = {"cat": 0, "dog": 1}  # Ví dụ: cat = 0, dog = 1

# 🔹 Thực hiện dự đoán trên một hình ảnh mẫu
image_path = "/mnt/c/Study/OD_Project/good_1.jpg"
inference(model, image_path, transform, device, class_to_idx, threshold=0.5)