In [1]:
import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import os
import pandas as pd
import json

from models.hybrid_model import BrainTumorModel
from dataloader.loader import get_data_loaders
from utils.helpers import init_weights

In [2]:
def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    correct = 0
    total = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += labels.size(0)

    acc = correct / total
    avg_loss = total_loss / len(train_loader)

    return avg_loss, acc

In [3]:
def validate(model, val_loader, criterion, device):
    model.eval()
    total_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, preds = outputs.max(1)
            correct += preds.eq(labels).sum().item()
            total += labels.size(0)

    acc = correct / total
    avg_loss = total_loss / len(val_loader)

    return avg_loss, acc

In [4]:
data_dir = "../dataset"
batch_size = 32
img_size = 224
num_epochs = 25
lr = 1e-4

In [5]:
MODELS_DIR = './models'
LOGS_DIR = './logs'

In [10]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

Device: cpu


In [6]:
train_loader, val_loader, test_loader, class_names = get_data_loaders(data_dir, batch_size, img_size)

In [11]:
model = BrainTumorModel(num_classes=len(class_names)).to(device)
model.apply(init_weights)

BrainTumorModel(
  (stem): Sequential(
    (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (1): LayerNorm2d((96,), eps=1e-05, elementwise_affine=True)
  )
  (stage1): Sequential(
    (0): ConvNeXtBlock(
      (dwconv): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
      (norm): LayerNorm2d((96,), eps=1e-05, elementwise_affine=True)
      (pwconv1): Linear(in_features=96, out_features=384, bias=True)
      (act): GELU(approximate='none')
      (pwconv2): Linear(in_features=384, out_features=96, bias=True)
      (drop_path): Identity()
    )
    (1): ConvNeXtBlock(
      (dwconv): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
      (norm): LayerNorm2d((96,), eps=1e-05, elementwise_affine=True)
      (pwconv1): Linear(in_features=96, out_features=384, bias=True)
      (act): GELU(approximate='none')
      (pwconv2): Linear(in_features=384, out_features=96, bias=True)
      (drop_path): Identity()
    )
    (2): ConvNe

In [9]:
criterion = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

In [12]:
history = []
best_val_loss = float('inf')
best_val_acc = 0
patience = 5
patience_counter = 0

In [14]:
class_names_path = os.path.join(MODELS_DIR, "class_names.json")
with open(class_names_path, 'w') as f:
    json.dump(class_names, f)
print(f"Đã lưu tên lớp: {class_names_path}")

Đã lưu tên lớp: ./models/class_names.json


In [None]:
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, val_loader, criterion, device)

    scheduler.step()

    history.append({
        'epoch': epoch + 1,
        'train_loss': train_loss,
        'train_acc': train_acc,
        'val_loss': val_loss,
        'val_acc': val_acc
    })

    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}")

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_val_acc = val_acc
        patience_counter = 0
        save_path = os.path.join(MODELS_DIR, "best_model.pth")
        torch.save(model.state_dict(), save_path)
        print(f"Saved best model! Val Loss: {val_loss:.4f}")
        
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print(f"Dừng sớm: Val Loss không cải thiện sau {patience} epoch.")
            break

In [None]:
df_history = pd.DataFrame(history)
log_path = os.path.join(LOGS_DIR, "training_history.csv")
df_history.to_csv(log_path, index=False)
print(f"Đã lưu lịch sử huấn luyện tại: {log_path}")
print(f"Best Validation Accuracy: {best_val_acc:.4f}")