In [None]:
import os
import torch
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import cv2


In [None]:
# Standard ImageNet per-channel statistics used with torchvision's pretrained models.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Preprocessing pipeline: mild random colour jitter, fixed 224x224 resize,
# conversion to a tensor, then per-channel normalisation.
transform_steps = [
    transforms.ColorJitter(0.1, 0.1, 0.1, 0.1),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std),
]
transform = transforms.Compose(transform_steps)

# ImageFolder treats each sub-directory of ./dataset as one class.
dataset = datasets.ImageFolder('dataset', transform=transform)

# Check the discovered class labels
print('Classes:', dataset.classes)


In [None]:
# Hold out a fixed-size evaluation set. The seeded generator makes the split
# identical across kernel restarts, so accuracies from different runs compare
# the model on the same images.
TEST_SIZE = 50
SPLIT_SEED = 42

if len(dataset) <= TEST_SIZE:
    raise ValueError(
        f'Need more than {TEST_SIZE} images to hold out a test set, found {len(dataset)}'
    )

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_dataset, held_out = torch.utils.data.random_split(
    dataset, [len(dataset) - TEST_SIZE, TEST_SIZE], generator=split_generator
)

# `dataset` applies a random ColorJitter, which would make test accuracy vary from
# epoch to epoch for reasons unrelated to the model. Serve the same held-out indices
# from a second ImageFolder over the same root whose pipeline omits the jitter.
# ImageFolder orders samples by sorted class/file names, so indices line up.
# Assumes `dataset.transform` is the transforms.Compose it was built with.
eval_transform = transforms.Compose(
    [step for step in dataset.transform.transforms
     if not isinstance(step, transforms.ColorJitter)]
)
eval_dataset = datasets.ImageFolder(dataset.root, transform=eval_transform)
test_dataset = torch.utils.data.Subset(eval_dataset, held_out.indices)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=8, shuffle=True, num_workers=0
)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=8, shuffle=False, num_workers=0
)


In [None]:
# Fine-tune an ImageNet-pretrained ResNet-18 by replacing its final fully connected
# layer with one that has one output per class.
# NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=models.ResNet18_Weights.IMAGENET1K_V1`; kept here for older versions.
NUM_CLASSES = 3  # blocked, boost, normal
# ImageFolder assigns label indices in sorted folder-name order, so output i of the
# head corresponds to dataset.classes[i]. Fail early on a mismatch instead of hitting
# an opaque out-of-range target error inside cross_entropy during training.
assert len(dataset.classes) == NUM_CLASSES, (
    f'Expected {NUM_CLASSES} class folders, found {dataset.classes}'
)

model = models.resnet18(pretrained=True)
model.fc = torch.nn.Linear(model.fc.in_features, NUM_CLASSES)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


In [None]:
# Training length and where the best checkpoint is written.
NUM_EPOCHS = 30
BEST_MODEL_PATH = 'best_model_resnet18_3class.pth'

# SGD with momentum over all model parameters, so the pretrained backbone is
# fine-tuned together with the new classification head.
optimizer = optim.SGD(params=model.parameters(), lr=0.001, momentum=0.9)

# Best held-out accuracy seen so far; a checkpoint is only saved when it is beaten.
best_accuracy = 0.0


In [None]:
def train_one_epoch(net, loader, opt, dev):
    """Run one optimisation pass of `net` over `loader`.

    Returns the mean training loss per sample (0.0 if `loader` yields nothing).
    """
    net.train()
    loss_sum = 0.0
    seen = 0
    for images, labels in loader:
        images = images.to(dev)
        labels = labels.to(dev)

        opt.zero_grad()
        outputs = net(images)
        loss = F.cross_entropy(outputs, labels)
        loss.backward()
        opt.step()

        # cross_entropy averages over the batch; re-weight by batch size so the
        # epoch mean is per sample even when the last batch is smaller.
        loss_sum += loss.item() * labels.size(0)
        seen += labels.size(0)
    return loss_sum / seen if seen else 0.0


def evaluate(net, loader, dev):
    """Return the top-1 accuracy of `net` on `loader` (0.0 if `loader` is empty)."""
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(dev)
            labels = labels.to(dev)
            predicted = torch.argmax(net(images), dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0


# NOTE: re-running this cell continues training the same `model` and keeps the
# current `best_accuracy`; re-run the model/optimizer cells for a fresh run.
for epoch in range(NUM_EPOCHS):
    train_loss = train_one_epoch(model, train_loader, optimizer, device)
    accuracy = evaluate(model, test_loader, device)
    print(f'Epoch {epoch + 1}: Train loss = {train_loss:.4f}, Accuracy = {accuracy:.4f}')

    # Keep only the weights with the highest held-out accuracy.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), BEST_MODEL_PATH)
