In [2]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from tqdm import tqdm
import matplotlib.pyplot as plt
from PIL import Image

# ✅ EarlyStopping 類別
class EarlyStopping:
    def __init__(self, patience=8, verbose=True, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_acc = None
        self.early_stop = False
        self.delta = delta
        self.best_model_state = None

    def __call__(self, acc, model):
        if self.best_acc is None:
            self.best_acc = acc
            self.best_model_state = model.state_dict()
        elif acc < self.best_acc + self.delta:
            self.counter += 1
            if self.verbose:
                print(f"⚠️ 早停計數器: {self.counter}/{self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            if self.verbose:
                print(f"✅ 驗證準確度提升: {self.best_acc:.2f} → {acc:.2f}，重置早停計數器")
            self.best_acc = acc
            self.best_model_state = model.state_dict()
            self.counter = 0

# 設定設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 設定圖片路徑與參數
dataset_path = "corp_augmented_data"
batch_size = 1024
IMG_SIZE = 64

# 圖片預處理
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 資料集與資料分割
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print(f"類別對應: {dataset.class_to_idx}")

# 模型定義
class DeepCoffeeBeanCNN(nn.Module):
    def __init__(self):
        super(DeepCoffeeBeanCNN, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),  # 64x64 → 64x64
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),  # 64x64
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 64x64 → 32x32

            nn.Conv2d(32, 64, kernel_size=3, padding=1),  # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),  # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 32x32 → 16x16

            nn.Conv2d(64, 128, kernel_size=3, padding=1),  # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),  # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # 16x16 → 8x8
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),                # 128 x 8 x 8 = 8192
            nn.Linear(128 * 8 * 8, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 2)  # 分兩類
        )

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        return x


# 初始化模型
model = CoffeeBeanCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
early_stopper = EarlyStopping(patience=8, verbose=True)

# 訓練迴圈
num_epochs = 80
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)
        progress_bar.set_postfix(loss=loss.item(), acc=100 * correct / total)

    train_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss:.4f}, Accuracy: {train_acc:.2f}%")

    # 驗證階段
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == labels).sum().item()
            val_total += labels.size(0)
    val_acc = 100 * val_correct / val_total
    print(f"📉 驗證準確度: {val_acc:.2f}%")

    # 檢查是否早停
    early_stopper(val_acc, model)
    if early_stopper.early_stop:
        print("🛑 觸發早停，停止訓練")
        break

# 儲存最佳模型
torch.save(early_stopper.best_model_state, "coffee_bean_cnn.pth")
print("✅ 模型已儲存為最佳驗證表現!")

# 測試模型
model.load_state_dict(early_stopper.best_model_state)  # 使用最佳權重
model.eval()
correct = 0
total = 0
progress_bar = tqdm(test_loader, desc="Testing", leave=True)

with torch.no_grad():
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        progress_bar.set_postfix(acc=100 * correct / total)

test_acc = 100 * correct / total
print(f"📊 測試準確度: {test_acc:.2f}%")


類別對應: {'bad': 0, 'good': 1}


Epoch 1/80: 100%|██████████| 4/4 [00:03<00:00,  1.15it/s, acc=56.9, loss=0.647]


Epoch 1/80 - Loss: 2.6624, Accuracy: 56.93%
📉 驗證準確度: 67.79%


Epoch 2/80: 100%|██████████| 4/4 [00:02<00:00,  1.67it/s, acc=65.4, loss=0.646]


Epoch 2/80 - Loss: 2.5962, Accuracy: 65.40%
📉 驗證準確度: 67.79%
✅ 驗證準確度提升: 67.79 → 67.79，重置早停計數器


Epoch 3/80: 100%|██████████| 4/4 [00:02<00:00,  1.75it/s, acc=65.4, loss=0.63] 


Epoch 3/80 - Loss: 2.5736, Accuracy: 65.40%
📉 驗證準確度: 67.79%
✅ 驗證準確度提升: 67.79 → 67.79，重置早停計數器


Epoch 4/80: 100%|██████████| 4/4 [00:02<00:00,  1.76it/s, acc=65.4, loss=0.638]


Epoch 4/80 - Loss: 2.5665, Accuracy: 65.40%
📉 驗證準確度: 67.79%
✅ 驗證準確度提升: 67.79 → 67.79，重置早停計數器


Epoch 5/80: 100%|██████████| 4/4 [00:02<00:00,  1.76it/s, acc=65.4, loss=0.614]


Epoch 5/80 - Loss: 2.5436, Accuracy: 65.40%
📉 驗證準確度: 67.79%
✅ 驗證準確度提升: 67.79 → 67.79，重置早停計數器


Epoch 6/80: 100%|██████████| 4/4 [00:02<00:00,  1.78it/s, acc=65.4, loss=0.633]


Epoch 6/80 - Loss: 2.5310, Accuracy: 65.40%
📉 驗證準確度: 67.79%
✅ 驗證準確度提升: 67.79 → 67.79，重置早停計數器


Epoch 7/80: 100%|██████████| 4/4 [00:02<00:00,  1.74it/s, acc=65.4, loss=0.624]


Epoch 7/80 - Loss: 2.4838, Accuracy: 65.38%
📉 驗證準確度: 67.68%
⚠️ 早停計數器: 1/8


Epoch 8/80: 100%|██████████| 4/4 [00:02<00:00,  1.78it/s, acc=65.5, loss=0.622]


Epoch 8/80 - Loss: 2.4957, Accuracy: 65.46%
📉 驗證準確度: 69.02%
✅ 驗證準確度提升: 67.79 → 69.02，重置早停計數器


Epoch 9/80: 100%|██████████| 4/4 [00:02<00:00,  1.78it/s, acc=66, loss=0.613]  


Epoch 9/80 - Loss: 2.4616, Accuracy: 66.05%
📉 驗證準確度: 68.80%
⚠️ 早停計數器: 1/8


Epoch 10/80: 100%|██████████| 4/4 [00:02<00:00,  1.76it/s, acc=67.5, loss=0.593]


Epoch 10/80 - Loss: 2.4301, Accuracy: 67.45%
📉 驗證準確度: 67.79%
⚠️ 早停計數器: 2/8


Epoch 11/80: 100%|██████████| 4/4 [00:02<00:00,  1.66it/s, acc=65.8, loss=0.612]


Epoch 11/80 - Loss: 2.4193, Accuracy: 65.77%
📉 驗證準確度: 68.69%
⚠️ 早停計數器: 3/8


Epoch 12/80: 100%|██████████| 4/4 [00:02<00:00,  1.76it/s, acc=66.5, loss=0.606]


Epoch 12/80 - Loss: 2.4040, Accuracy: 66.47%
📉 驗證準確度: 69.25%
✅ 驗證準確度提升: 69.02 → 69.25，重置早停計數器


Epoch 13/80: 100%|██████████| 4/4 [00:02<00:00,  1.76it/s, acc=67, loss=0.608]  


Epoch 13/80 - Loss: 2.3675, Accuracy: 67.00%
📉 驗證準確度: 68.57%
⚠️ 早停計數器: 1/8


Epoch 14/80: 100%|██████████| 4/4 [00:02<00:00,  1.68it/s, acc=68.4, loss=0.599]


Epoch 14/80 - Loss: 2.3395, Accuracy: 68.41%
📉 驗證準確度: 69.70%
✅ 驗證準確度提升: 69.25 → 69.70，重置早停計數器


Epoch 15/80: 100%|██████████| 4/4 [00:02<00:00,  1.76it/s, acc=69, loss=0.544]  


Epoch 15/80 - Loss: 2.2890, Accuracy: 68.97%
📉 驗證準確度: 69.47%
⚠️ 早停計數器: 1/8


Epoch 16/80: 100%|██████████| 4/4 [00:02<00:00,  1.74it/s, acc=69.1, loss=0.579]


Epoch 16/80 - Loss: 2.2856, Accuracy: 69.11%
📉 驗證準確度: 70.59%
✅ 驗證準確度提升: 69.70 → 70.59，重置早停計數器


Epoch 17/80: 100%|██████████| 4/4 [00:02<00:00,  1.71it/s, acc=69.4, loss=0.568]


Epoch 17/80 - Loss: 2.2695, Accuracy: 69.36%
📉 驗證準確度: 69.70%
⚠️ 早停計數器: 1/8


Epoch 18/80: 100%|██████████| 4/4 [00:02<00:00,  1.70it/s, acc=69.5, loss=0.553]


Epoch 18/80 - Loss: 2.2451, Accuracy: 69.47%
📉 驗證準確度: 71.60%
✅ 驗證準確度提升: 70.59 → 71.60，重置早停計數器


Epoch 19/80: 100%|██████████| 4/4 [00:02<00:00,  1.76it/s, acc=69.4, loss=0.562]


Epoch 19/80 - Loss: 2.2153, Accuracy: 69.44%
📉 驗證準確度: 70.71%
⚠️ 早停計數器: 1/8


Epoch 20/80: 100%|██████████| 4/4 [00:02<00:00,  1.65it/s, acc=70.7, loss=0.533]


Epoch 20/80 - Loss: 2.2022, Accuracy: 70.65%
📉 驗證準確度: 69.47%
⚠️ 早停計數器: 2/8


Epoch 21/80: 100%|██████████| 4/4 [00:02<00:00,  1.72it/s, acc=70.3, loss=0.529]


Epoch 21/80 - Loss: 2.1718, Accuracy: 70.31%
📉 驗證準確度: 69.81%
⚠️ 早停計數器: 3/8


Epoch 22/80: 100%|██████████| 4/4 [00:02<00:00,  1.65it/s, acc=71.1, loss=0.55] 


Epoch 22/80 - Loss: 2.1772, Accuracy: 71.13%
📉 驗證準確度: 71.83%
✅ 驗證準確度提升: 71.60 → 71.83，重置早停計數器


Epoch 23/80: 100%|██████████| 4/4 [00:02<00:00,  1.68it/s, acc=72, loss=0.538]  


Epoch 23/80 - Loss: 2.1387, Accuracy: 72.00%
📉 驗證準確度: 72.73%
✅ 驗證準確度提升: 71.83 → 72.73，重置早停計數器


Epoch 24/80: 100%|██████████| 4/4 [00:02<00:00,  1.67it/s, acc=72.7, loss=0.536]


Epoch 24/80 - Loss: 2.1117, Accuracy: 72.70%
📉 驗證準確度: 72.17%
⚠️ 早停計數器: 1/8


Epoch 25/80: 100%|██████████| 4/4 [00:02<00:00,  1.67it/s, acc=72.9, loss=0.518]


Epoch 25/80 - Loss: 2.0783, Accuracy: 72.90%
📉 驗證準確度: 71.60%
⚠️ 早停計數器: 2/8


Epoch 26/80: 100%|██████████| 4/4 [00:02<00:00,  1.65it/s, acc=73.3, loss=0.484]


Epoch 26/80 - Loss: 2.0263, Accuracy: 73.32%
📉 驗證準確度: 70.71%
⚠️ 早停計數器: 3/8


Epoch 27/80: 100%|██████████| 4/4 [00:02<00:00,  1.66it/s, acc=73.7, loss=0.526]


Epoch 27/80 - Loss: 2.0407, Accuracy: 73.68%
📉 驗證準確度: 73.18%
✅ 驗證準確度提升: 72.73 → 73.18，重置早停計數器


Epoch 28/80: 100%|██████████| 4/4 [00:02<00:00,  1.75it/s, acc=74.1, loss=0.511]


Epoch 28/80 - Loss: 2.0233, Accuracy: 74.13%
📉 驗證準確度: 72.28%
⚠️ 早停計數器: 1/8


Epoch 29/80: 100%|██████████| 4/4 [00:02<00:00,  1.71it/s, acc=74.8, loss=0.5]  


Epoch 29/80 - Loss: 1.9695, Accuracy: 74.78%
📉 驗證準確度: 72.39%
⚠️ 早停計數器: 2/8


Epoch 30/80: 100%|██████████| 4/4 [00:02<00:00,  1.69it/s, acc=75.1, loss=0.501]


Epoch 30/80 - Loss: 1.9799, Accuracy: 75.11%
📉 驗證準確度: 72.95%
⚠️ 早停計數器: 3/8


Epoch 31/80: 100%|██████████| 4/4 [00:02<00:00,  1.68it/s, acc=74.7, loss=0.51] 


Epoch 31/80 - Loss: 1.9522, Accuracy: 74.72%
📉 驗證準確度: 73.18%
✅ 驗證準確度提升: 73.18 → 73.18，重置早停計數器


Epoch 32/80: 100%|██████████| 4/4 [00:02<00:00,  1.69it/s, acc=75.6, loss=0.478]


Epoch 32/80 - Loss: 1.9349, Accuracy: 75.62%
📉 驗證準確度: 72.50%
⚠️ 早停計數器: 1/8


Epoch 33/80: 100%|██████████| 4/4 [00:02<00:00,  1.66it/s, acc=76.7, loss=0.46] 


Epoch 33/80 - Loss: 1.8627, Accuracy: 76.66%
📉 驗證準確度: 72.50%
⚠️ 早停計數器: 2/8


Epoch 34/80: 100%|██████████| 4/4 [00:02<00:00,  1.59it/s, acc=76.6, loss=0.456]


Epoch 34/80 - Loss: 1.8398, Accuracy: 76.63%
📉 驗證準確度: 71.72%
⚠️ 早停計數器: 3/8


Epoch 35/80: 100%|██████████| 4/4 [00:02<00:00,  1.66it/s, acc=76.2, loss=0.454]


Epoch 35/80 - Loss: 1.8622, Accuracy: 76.15%
📉 驗證準確度: 74.07%
✅ 驗證準確度提升: 73.18 → 74.07，重置早停計數器


Epoch 36/80: 100%|██████████| 4/4 [00:02<00:00,  1.63it/s, acc=77.6, loss=0.441]


Epoch 36/80 - Loss: 1.8046, Accuracy: 77.55%
📉 驗證準確度: 72.50%
⚠️ 早停計數器: 1/8


Epoch 37/80: 100%|██████████| 4/4 [00:02<00:00,  1.39it/s, acc=78.5, loss=0.424]


Epoch 37/80 - Loss: 1.7752, Accuracy: 78.51%
📉 驗證準確度: 73.06%
⚠️ 早停計數器: 2/8


Epoch 38/80: 100%|██████████| 4/4 [00:02<00:00,  1.70it/s, acc=78.5, loss=0.419]


Epoch 38/80 - Loss: 1.7383, Accuracy: 78.54%
📉 驗證準確度: 73.29%
⚠️ 早停計數器: 3/8


Epoch 39/80: 100%|██████████| 4/4 [00:02<00:00,  1.69it/s, acc=79.4, loss=0.442]


Epoch 39/80 - Loss: 1.7350, Accuracy: 79.41%
📉 驗證準確度: 72.50%
⚠️ 早停計數器: 4/8


Epoch 40/80: 100%|██████████| 4/4 [00:02<00:00,  1.71it/s, acc=78.8, loss=0.415]


Epoch 40/80 - Loss: 1.7013, Accuracy: 78.84%
📉 驗證準確度: 73.85%
⚠️ 早停計數器: 5/8


Epoch 41/80: 100%|██████████| 4/4 [00:02<00:00,  1.57it/s, acc=79.4, loss=0.437]


Epoch 41/80 - Loss: 1.6890, Accuracy: 79.38%
📉 驗證準確度: 73.63%
⚠️ 早停計數器: 6/8


Epoch 42/80: 100%|██████████| 4/4 [00:02<00:00,  1.59it/s, acc=81.3, loss=0.4]  


Epoch 42/80 - Loss: 1.6178, Accuracy: 81.34%
📉 驗證準確度: 73.51%
⚠️ 早停計數器: 7/8


Epoch 43/80: 100%|██████████| 4/4 [00:02<00:00,  1.68it/s, acc=80.7, loss=0.404]


Epoch 43/80 - Loss: 1.6231, Accuracy: 80.72%
📉 驗證準確度: 72.95%
⚠️ 早停計數器: 8/8
🛑 觸發早停，停止訓練
✅ 模型已儲存為最佳驗證表現!


Testing: 100%|██████████| 1/1 [00:00<00:00,  2.06it/s, acc=73]

📊 測試準確度: 72.95%





In [2]:
print(f"訓練集圖片數量：{len(train_dataset)} 張")
print(f"測試集圖片數量：{len(test_dataset)} 張")
print(f"訓練集 Batch 數量：{len(train_loader)} 個 batch（每 batch {batch_size} 張）")
print(f"測試集 Batch 數量：{len(test_loader)} 個 batch（每 batch {batch_size} 張）")


訓練集圖片數量：3564 張
測試集圖片數量：891 張
訓練集 Batch 數量：4 個 batch（每 batch 1024 張）
測試集 Batch 數量：1 個 batch（每 batch 1024 張）
