In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torchvision.models import convnext_tiny
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt

# Early-stopping helper class
class EarlyStopping:
    """Stop training once validation accuracy stops improving.

    Call the instance once per epoch with the current validation accuracy
    and the model. It keeps the best accuracy seen so far together with an
    independent copy of the model weights at that point, and raises the
    ``early_stop`` flag after ``patience`` consecutive epochs without a
    sufficient improvement.

    Args:
        patience: Number of consecutive non-improving epochs tolerated
            before ``early_stop`` becomes True.
        delta: Minimum margin over ``best_acc`` for an epoch to count as an
            improvement. An accuracy exactly equal to ``best_acc + delta``
            counts as an improvement.
        verbose: When True, print a message on every counter update and on
            every improvement.
    """

    def __init__(self, patience=8, delta=0, verbose=True):
        self.patience = patience
        self.counter = 0              # consecutive epochs without improvement
        self.best_acc = None          # best accuracy observed so far
        self.early_stop = False       # set to True once patience is exhausted
        self.delta = delta
        self.best_model_state = None  # snapshot of the weights at best_acc
        self.verbose = verbose

    @staticmethod
    def _snapshot(model):
        """Return an independent deep copy of ``model.state_dict()``."""
        # Imported locally so the class does not depend on a file-level import.
        import copy

        # state_dict() hands back references to the live parameter tensors;
        # without a copy, later optimizer steps would silently overwrite the
        # stored "best" weights with the most recent ones.
        return copy.deepcopy(model.state_dict())

    def __call__(self, acc, model):
        """Record this epoch's accuracy and update the stopping state.

        Args:
            acc: Validation accuracy for the epoch just finished.
            model: Object exposing ``state_dict()``; its weights are
                snapshotted whenever ``acc`` is a new best.
        """
        if self.best_acc is None:
            # First call: whatever we got is the baseline.
            self.best_acc = acc
            self.best_model_state = self._snapshot(model)
        elif acc < self.best_acc + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"⚠️ 早停計數: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            if self.verbose:
                print(f"✅ 驗證準確度提升: {self.best_acc:.2f} → {acc:.2f}，重置計數")
            self.best_acc = acc
            self.best_model_state = self._snapshot(model)
            self.counter = 0

# Device selection: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image preprocessing settings
dataset_path = "corp_augmented_data"
batch_size = 128
img_size = 64

# Resize every image to img_size x img_size, convert it to a tensor, then
# normalize each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Dataset: one sub-folder per class, split 80% / 20% at random.
# NOTE(review): the 20% split is used both for early stopping during training
# and for the final reported "test" accuracy, so that figure is not measured
# on data unseen by model selection. The split is also unseeded, so it
# differs between runs.
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"類別對應: {dataset.class_to_idx}")

# Model: ConvNeXt-Tiny without pretrained weights. The final linear layer is
# replaced so its output size matches the number of class folders found in the
# dataset instead of a hard-coded 2.
num_classes = len(dataset.classes)
model = convnext_tiny(weights=None)
model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
early_stopper = EarlyStopping(patience=6)

# Training bookkeeping: one entry is appended to each list per epoch.
num_epochs = 40
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

# Training loop: one pass over train_loader, then one evaluation pass over
# test_loader per epoch, until num_epochs is reached or early stopping fires.
for epoch in range(num_epochs):
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    progress = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

    for images, labels in progress:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the mean loss over the batch; weight it by the
        # batch's sample count so a smaller final batch is not over-weighted
        # in the epoch average.
        n_samples = labels.size(0)
        batch_loss = loss.item()
        running_loss += batch_loss * n_samples
        _, preds = outputs.max(1)
        correct += preds.eq(labels).sum().item()
        total += n_samples
        progress.set_postfix(loss=batch_loss, acc=100 * correct / total)

    train_acc = 100 * correct / total
    avg_train_loss = running_loss / total  # per-sample mean over the epoch
    train_losses.append(avg_train_loss)
    train_accuracies.append(train_acc)
    print(f"🧠 Epoch {epoch+1}: Train Loss: {avg_train_loss:.4f}, Accuracy: {train_acc:.2f}%")

    # Validation on the held-out split, with gradients disabled.
    model.eval()
    val_running_loss = 0.0
    val_correct, val_total = 0, 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            n_samples = labels.size(0)
            val_running_loss += loss.item() * n_samples
            _, preds = outputs.max(1)
            val_correct += preds.eq(labels).sum().item()
            val_total += n_samples

    val_acc = 100 * val_correct / val_total
    avg_val_loss = val_running_loss / val_total  # per-sample mean
    val_losses.append(avg_val_loss)
    val_accuracies.append(val_acc)
    print(f"📉 驗證 Loss: {avg_val_loss:.4f}, 準確率: {val_acc:.2f}%")

    # Early-stopping check, driven by validation accuracy.
    early_stopper(val_acc, model)
    if early_stopper.early_stop:
        print("🛑 早停觸發，訓練終止")
        break

# Write the state dict recorded by the early stopper to disk.
best_state = early_stopper.best_model_state
torch.save(best_state, "convnext_tiny_coffeebean_best.pth")
print("📦 模型已儲存")

# Final evaluation: load that state back into the model and measure accuracy
# over test_loader with gradients disabled.
model.load_state_dict(best_state)
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing"):
        images = images.to(device)
        labels = labels.to(device)
        outputs = model(images)
        # Predicted class = index of the largest score along dim 1.
        preds = outputs.max(dim=1).indices
        correct += (preds == labels).sum().item()
        total += labels.size(0)

test_accuracy = 100 * correct / total
print(f"✅ 測試準確率: {test_accuracy:.2f}%")

# Plot the loss and accuracy curves side by side.
plt.figure(figsize=(12, 5))

# Each panel: (subplot position, y-axis label, title, series to draw).
curve_panels = (
    (1, "Loss", "Loss Curve",
     ((train_losses, 'Train Loss'), (val_losses, 'Val Loss'))),
    (2, "Accuracy (%)", "Accuracy Curve",
     ((train_accuracies, 'Train Acc'), (val_accuracies, 'Val Acc'))),
)

for panel_index, y_label, panel_title, panel_series in curve_panels:
    plt.subplot(1, 2, panel_index)
    for history, series_label in panel_series:
        plt.plot(history, label=series_label)
    plt.xlabel("Epoch")
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()
    plt.grid(True)

# Save the figure to disk before showing it.
plt.tight_layout()
plt.savefig("loss_acc_curve.png")
plt.show()


類別對應: {'bad': 0, 'good': 1}


Epoch 1/40:  32%|███▏      | 9/28 [00:03<00:05,  3.24it/s, acc=53.5, loss=0.795]