In [15]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from tqdm import tqdm  # 進度條
import matplotlib.pyplot as plt
from PIL import Image

# 設定設備
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 設定圖片路徑
dataset_path = "corp_augmented_data"
batch_size = 1024
IMG_SIZE = 64

# 影像轉換 (預處理)
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))  # 標準化 -1 到 1
])

# 建立資料集
dataset = datasets.ImageFolder(root=dataset_path, transform=transform)
train_size = int(0.8 * len(dataset))  # 80% 訓練，20% 測試
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

# 建立 DataLoader
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# 顯示類別對應
print(f"類別對應: {dataset.class_to_idx}")

# 建立 CNN 模型
class CoffeeBeanCNN(nn.Module):
    def __init__(self):
        super(CoffeeBeanCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(128 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, 2)  # 2 類別 (good, bad)
        self.dropout = nn.Dropout(0.5)
    
    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = self.pool(torch.relu(self.conv3(x)))
        x = x.view(-1, 128 * 8 * 8)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

# 初始化模型
model = CoffeeBeanCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 訓練模型 (加入進度條)
num_epochs = 40
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    # tqdm 進度條
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}", leave=True)

    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)
        
        # 更新進度條資訊
        progress_bar.set_postfix(loss=loss.item(), acc=100 * correct / total)

    train_acc = 100 * correct / total
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {running_loss:.4f}, Accuracy: {train_acc:.2f}%")

# 儲存模型
torch.save(model.state_dict(), "coffee_bean_cnn.pth")
print("✅ 模型已儲存!")

# 測試模型 (加入進度條)
model.eval()
correct = 0
total = 0
progress_bar = tqdm(test_loader, desc="Testing", leave=True)

with torch.no_grad():
    for images, labels in progress_bar:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)
        
        # 更新進度條資訊
        progress_bar.set_postfix(acc=100 * correct / total)

test_acc = 100 * correct / total
print(f"📊 測試準確度: {test_acc:.2f}%")


類別對應: {'bad': 0, 'good': 1}


Epoch 1/40: 100%|██████████| 4/4 [00:01<00:00,  2.11it/s, acc=58.8, loss=0.654]


Epoch 1/40 - Loss: 2.6891, Accuracy: 58.75%


Epoch 2/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=66.6, loss=0.643]


Epoch 2/40 - Loss: 2.5888, Accuracy: 66.55%


Epoch 3/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=66.6, loss=0.65] 


Epoch 3/40 - Loss: 2.5808, Accuracy: 66.55%


Epoch 4/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=66.6, loss=0.63] 


Epoch 4/40 - Loss: 2.5496, Accuracy: 66.55%


Epoch 5/40: 100%|██████████| 4/4 [00:01<00:00,  2.17it/s, acc=66.6, loss=0.627]


Epoch 5/40 - Loss: 2.5484, Accuracy: 66.55%


Epoch 6/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=66.6, loss=0.642]


Epoch 6/40 - Loss: 2.5478, Accuracy: 66.55%


Epoch 7/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=66.6, loss=0.635]


Epoch 7/40 - Loss: 2.5355, Accuracy: 66.55%


Epoch 8/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=66.6, loss=0.61] 


Epoch 8/40 - Loss: 2.5059, Accuracy: 66.55%


Epoch 9/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=66.6, loss=0.598]


Epoch 9/40 - Loss: 2.4894, Accuracy: 66.55%


Epoch 10/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=66.6, loss=0.622]


Epoch 10/40 - Loss: 2.4857, Accuracy: 66.55%


Epoch 11/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=66.6, loss=0.616]


Epoch 11/40 - Loss: 2.4582, Accuracy: 66.55%


Epoch 12/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=66.6, loss=0.605]


Epoch 12/40 - Loss: 2.4198, Accuracy: 66.55%


Epoch 13/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=66.6, loss=0.603]


Epoch 13/40 - Loss: 2.4044, Accuracy: 66.55%


Epoch 14/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=66.6, loss=0.624]


Epoch 14/40 - Loss: 2.4013, Accuracy: 66.55%


Epoch 15/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=66.6, loss=0.614]


Epoch 15/40 - Loss: 2.3796, Accuracy: 66.55%


Epoch 16/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=66.6, loss=0.557]


Epoch 16/40 - Loss: 2.3368, Accuracy: 66.55%


Epoch 17/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=66.6, loss=0.584]


Epoch 17/40 - Loss: 2.3220, Accuracy: 66.55%


Epoch 18/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=66.6, loss=0.559]


Epoch 18/40 - Loss: 2.2789, Accuracy: 66.55%


Epoch 19/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=66.6, loss=0.589]


Epoch 19/40 - Loss: 2.2733, Accuracy: 66.55%


Epoch 20/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=66.6, loss=0.563]


Epoch 20/40 - Loss: 2.2664, Accuracy: 66.55%


Epoch 21/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=67.8, loss=0.522]


Epoch 21/40 - Loss: 2.2122, Accuracy: 67.82%


Epoch 22/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=68.5, loss=0.559]


Epoch 22/40 - Loss: 2.2214, Accuracy: 68.52%


Epoch 23/40: 100%|██████████| 4/4 [00:01<00:00,  2.16it/s, acc=69.4, loss=0.557]


Epoch 23/40 - Loss: 2.1945, Accuracy: 69.44%


Epoch 24/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=68.5, loss=0.516]


Epoch 24/40 - Loss: 2.1548, Accuracy: 68.49%


Epoch 25/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=70.4, loss=0.551]


Epoch 25/40 - Loss: 2.1489, Accuracy: 70.37%


Epoch 26/40: 100%|██████████| 4/4 [00:01<00:00,  2.17it/s, acc=69.9, loss=0.534]


Epoch 26/40 - Loss: 2.1364, Accuracy: 69.92%


Epoch 27/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=70.1, loss=0.552]


Epoch 27/40 - Loss: 2.1346, Accuracy: 70.15%


Epoch 28/40: 100%|██████████| 4/4 [00:01<00:00,  2.17it/s, acc=70.8, loss=0.548]


Epoch 28/40 - Loss: 2.1139, Accuracy: 70.76%


Epoch 29/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=70.5, loss=0.489]


Epoch 29/40 - Loss: 2.0510, Accuracy: 70.45%


Epoch 30/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=71.8, loss=0.507]


Epoch 30/40 - Loss: 2.0422, Accuracy: 71.77%


Epoch 31/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=72.4, loss=0.517]


Epoch 31/40 - Loss: 2.0611, Accuracy: 72.36%


Epoch 32/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=71.1, loss=0.512]


Epoch 32/40 - Loss: 2.0071, Accuracy: 71.13%


Epoch 33/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=72.4, loss=0.509]


Epoch 33/40 - Loss: 2.0003, Accuracy: 72.39%


Epoch 34/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=73.4, loss=0.491]


Epoch 34/40 - Loss: 1.9978, Accuracy: 73.37%


Epoch 35/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=72.9, loss=0.483]


Epoch 35/40 - Loss: 1.9620, Accuracy: 72.92%


Epoch 36/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=74, loss=0.467]  


Epoch 36/40 - Loss: 1.9111, Accuracy: 73.99%


Epoch 37/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=74.5, loss=0.456]


Epoch 37/40 - Loss: 1.9350, Accuracy: 74.49%


Epoch 38/40: 100%|██████████| 4/4 [00:01<00:00,  2.19it/s, acc=74.3, loss=0.551]


Epoch 38/40 - Loss: 1.9726, Accuracy: 74.30%


Epoch 39/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=76.3, loss=0.506]


Epoch 39/40 - Loss: 1.9413, Accuracy: 76.29%


Epoch 40/40: 100%|██████████| 4/4 [00:01<00:00,  2.18it/s, acc=76.5, loss=0.433]


Epoch 40/40 - Loss: 1.8720, Accuracy: 76.54%
✅ 模型已儲存!


Testing: 100%|██████████| 1/1 [00:00<00:00,  2.52it/s, acc=66.8]

📊 測試準確度: 66.78%





In [12]:
print(f"訓練集圖片數量：{len(train_dataset)} 張")
print(f"測試集圖片數量：{len(test_dataset)} 張")
print(f"訓練集 Batch 數量：{len(train_loader)} 個 batch（每 batch {batch_size} 張）")
print(f"測試集 Batch 數量：{len(test_loader)} 個 batch（每 batch {batch_size} 張）")


訓練集圖片數量：3564 張
測試集圖片數量：891 張
訓練集 Batch 數量：56 個 batch（每 batch 64 張）
測試集 Batch 數量：14 個 batch（每 batch 64 張）
