<a href="https://colab.research.google.com/github/chiu003/git_practice/blob/main/Gradio.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install gradio

Collecting gradio
  Downloading gradio-5.6.0-py3-none-any.whl.metadata (16 kB)
Collecting aiofiles<24.0,>=22.0 (from gradio)
  Downloading aiofiles-23.2.1-py3-none-any.whl.metadata (9.7 kB)
Collecting fastapi<1.0,>=0.115.2 (from gradio)
  Downloading fastapi-0.115.5-py3-none-any.whl.metadata (27 kB)
Collecting ffmpy (from gradio)
  Downloading ffmpy-0.4.0-py3-none-any.whl.metadata (2.9 kB)
Collecting gradio-client==1.4.3 (from gradio)
  Downloading gradio_client-1.4.3-py3-none-any.whl.metadata (7.1 kB)
Collecting markupsafe~=2.0 (from gradio)
  Downloading MarkupSafe-2.1.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.0 kB)
Collecting pydub (from gradio)
  Downloading pydub-0.25.1-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting python-multipart==0.0.12 (from gradio)
  Downloading python_multipart-0.0.12-py3-none-any.whl.metadata (1.9 kB)
Collecting ruff>=0.2.2 (from gradio)
  Downloading ruff-0.8.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metad

In [4]:
!gdown --id 14oQqOOdsig4DzXJPpTYPYNdWnoTrRzDD #chicken.zip
!gdown --id 1sz2PZhN0zS01KK4AcuvgMH7_S7I4DWVT #pork.zip
!gdown --id 1G86Z5Hp1j6bt-n9GsGbKd5rNA1ou0AAg #red_rice.zip

Downloading...
From: https://drive.google.com/uc?id=14oQqOOdsig4DzXJPpTYPYNdWnoTrRzDD
To: /content/chicken.zip
100% 806k/806k [00:00<00:00, 40.9MB/s]
Downloading...
From: https://drive.google.com/uc?id=1sz2PZhN0zS01KK4AcuvgMH7_S7I4DWVT
To: /content/pork.zip
100% 1.01M/1.01M [00:00<00:00, 86.9MB/s]
Downloading...
From: https://drive.google.com/uc?id=1G86Z5Hp1j6bt-n9GsGbKd5rNA1ou0AAg
To: /content/red_rice.zip
100% 720k/720k [00:00<00:00, 78.3MB/s]


In [5]:
import os
import zipfile
from PIL import Image
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader, Dataset
import gradio as gr
import shutil
import random

# 步驟1: 解壓縮ZIP檔案並整理資料夾結構
_ZIP_IMAGE_EXTS = ('.png', '.jpg', '.jpeg', '.webp')


def _is_image_member(info):
    """Return True for ZIP entries that are image files (not folders or macOS metadata)."""
    if info.is_dir():
        return False
    parts = info.filename.replace('\\', '/').split('/')
    # __MACOSX/ and "._name" entries are AppleDouble metadata, not images,
    # even though they often carry an image extension.
    if parts[0] == '__MACOSX' or parts[-1].startswith('._'):
        return False
    return parts[-1].lower().endswith(_ZIP_IMAGE_EXTS)


def _extract_flat(zip_ref, info, dest_dir):
    """Copy one ZIP member into dest_dir, joining its folder parts with '_'."""
    parts = [p for p in info.filename.replace('\\', '/').split('/') if p not in ('', '.', '..')]
    flat_name = '_'.join(parts)
    with zip_ref.open(info) as src, open(os.path.join(dest_dir, flat_name), 'wb') as dst:
        shutil.copyfileobj(src, dst)


def setup_data_directory(zip_files=None, data_dir='bento_data', train_ratio=0.8, seed=None):
    """Extract the per-class ZIP archives into a fresh train/val directory tree.

    Any existing ``data_dir`` is deleted first so re-running always starts from
    a clean split. For each archive, only the image members are shuffled; the
    first ``train_ratio`` fraction goes to ``<data_dir>/train/<class>`` and the
    rest to ``<data_dir>/val/<class>``.

    Images are written directly into the class folder (sub-folder parts of the
    member name are joined with ``_``) because ``BentoDataset`` only looks one
    level below each class folder.

    Args:
        zip_files: Mapping of ZIP path -> class name. Defaults to the three
            archives downloaded by the gdown cell.
        data_dir: Root directory to (re)create.
        train_ratio: Fraction of each class's images used for training.
        seed: If given, the shuffle uses a private ``random.Random(seed)`` so the
            split is reproducible. ``None`` uses the global ``random`` state,
            as before.

    Returns:
        Tuple ``(train_dir, val_dir)``.
    """
    if zip_files is None:
        zip_files = {
            r'/content/chicken.zip': '雞腿便當',
            r'/content/pork.zip': '排骨便當',
            r'/content/red_rice.zip': '紅糟肉便當'
        }
    shuffler = random if seed is None else random.Random(seed)

    # Start from an empty tree so files from a previous run never leak in.
    if os.path.exists(data_dir):
        shutil.rmtree(data_dir)
    train_dir = os.path.join(data_dir, 'train')
    val_dir = os.path.join(data_dir, 'val')
    os.makedirs(train_dir)
    os.makedirs(val_dir)

    for zip_file, class_name in zip_files.items():
        if not os.path.exists(zip_file):
            print(f"警告: 找不到 {zip_file}")
            continue

        try:
            with zipfile.ZipFile(zip_file, 'r') as zip_ref:
                members = [info for info in zip_ref.infolist() if _is_image_member(info)]
                if not members:
                    # Creating folders here would add a class with no samples.
                    print(f"警告: {zip_file} 中沒有圖片")
                    continue
                shuffler.shuffle(members)
                split = int(train_ratio * len(members))

                # Class folders are created only once the archive is known to be
                # readable, so a corrupt ZIP does not leave an empty phantom class.
                class_train_dir = os.path.join(train_dir, class_name)
                class_val_dir = os.path.join(val_dir, class_name)
                os.makedirs(class_train_dir, exist_ok=True)
                os.makedirs(class_val_dir, exist_ok=True)

                for info in members[:split]:
                    _extract_flat(zip_ref, info, class_train_dir)
                for info in members[split:]:
                    _extract_flat(zip_ref, info, class_val_dir)
        except zipfile.BadZipFile:
            print(f"錯誤: {zip_file} 不是有效的ZIP檔案")
            continue

    return train_dir, val_dir

# 步驟2: 調整圖片大小
def resize_images(directory, size=(300, 150)):
    """Resize every image file under ``directory`` (recursively) in place.

    Args:
        directory: Root directory to walk.
        size: Target size as PIL's ``(width, height)``. The default ``(300, 150)``
            equals ``transforms.Resize((150, 300))`` in ``train_model``, because
            torchvision takes ``(height, width)``. Using PIL's ``(150, 300)`` here
            would squeeze training images to portrait on disk and stretch them
            back to landscape in the transform. Validation and inference images
            skip that step, so their distribution would differ.

    Per-file errors are printed and the file is left unchanged.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.webp')
    for root, _, files in os.walk(directory):
        for filename in files:
            if not filename.lower().endswith(image_exts):
                continue
            image_path = os.path.join(root, filename)
            try:
                # resize() returns a fully loaded copy, so the source file can be
                # closed before it is overwritten.
                with Image.open(image_path) as img:
                    resized_img = img.resize(size)
                # JPEG cannot store alpha/palette modes; convert so save() succeeds.
                if filename.lower().endswith(('.jpg', '.jpeg')) and resized_img.mode in ('RGBA', 'LA', 'P'):
                    resized_img = resized_img.convert('RGB')
                resized_img.save(image_path)
            except Exception as e:
                print(f"處理圖片 {filename} 時發生錯誤: {str(e)}")

# 自定義數據集類別
class BentoDataset(Dataset):
    """Folder-per-class dataset: ``root_dir/<class>/<image>`` -> ``(image, class_index)``.

    Class indices follow the sorted class-folder names. Datasets built from
    directories with the same class folders (e.g. train and val) therefore share
    one label mapping.
    """

    _IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.webp')

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir: Directory whose sub-directories are the classes.
            transform: Optional callable applied to each RGB PIL image.
        """
        self.root_dir = root_dir
        self.transform = transform
        # Only real, non-hidden directories are classes. Stray files or folders
        # such as Jupyter's ".ipynb_checkpoints" would otherwise add a phantom
        # class and change the number of model outputs.
        with os.scandir(root_dir) as entries:
            self.classes = sorted(
                entry.name for entry in entries
                if entry.is_dir() and not entry.name.startswith('.')
            )
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}

        self.images = []
        for cls in self.classes:
            class_path = os.path.join(root_dir, cls)
            # Sorted so that index -> file is the same on every run.
            for filename in sorted(os.listdir(class_path)):
                if filename.startswith('.') or not filename.lower().endswith(self._IMAGE_EXTENSIONS):
                    continue
                self.images.append((os.path.join(class_path, filename), self.class_to_idx[cls]))

    def __len__(self):
        """Number of image files found."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, and return ``(image, label)``."""
        img_path, label = self.images[idx]
        # convert() returns a loaded copy, so the file handle can be closed here.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label

# 步驟3: 建立深度學習模型
class BentoClassifier(nn.Module):
    """ResNet-18 backbone whose final fully connected layer outputs ``num_classes`` logits."""

    def __init__(self, num_classes, pretrained=True):
        """
        Args:
            num_classes: Number of output classes.
            pretrained: Load the ImageNet IMAGENET1K_V1 weights, i.e. what the
                deprecated ``pretrained=True`` selected. Pass False for a
                randomly initialised backbone (e.g. offline use).
        """
        super().__init__()
        weights = models.ResNet18_Weights.IMAGENET1K_V1 if pretrained else None
        self.resnet = models.resnet18(weights=weights)
        # Take the input width from the layer being replaced instead of hard-coding 512.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return raw (pre-softmax) class logits for a batch of images."""
        return self.resnet(x)

# 驗證與訓練流程
def validate_model(val_dir, model, transform, device, batch_size=4):
    """Compute ``model``'s classification accuracy (in percent) on ``val_dir``.

    Labels come from ``val_dir``'s own sorted class folders (via BentoDataset).
    ``val_dir`` must therefore contain the same class folders as the training
    directory for the indices to line up.

    The model is left in eval mode; ``train_model`` switches it back with
    ``model.train()`` at the start of each epoch.

    Args:
        val_dir: Folder-per-class validation directory.
        model: Module returning logits of shape (batch, num_classes).
        transform: Preprocessing applied to each image.
        device: Device the inputs are moved to.
        batch_size: Evaluation batch size.

    Returns:
        Accuracy in [0, 100], or 0.0 when ``val_dir`` contains no images.
    """
    model.eval()
    val_dataset = BentoDataset(root_dir=val_dir, transform=transform)
    if len(val_dataset) == 0:
        # Avoid dividing by zero below; an empty split has no accuracy to report.
        print(f"警告: 驗證集 {val_dir} 中沒有圖片")
        return 0.0
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"驗證集準確率: {accuracy:.2f}%")
    return accuracy

def _train_one_epoch(model, loader, criterion, optimizer, device, epoch, log_every=5):
    """Run one optimisation pass over ``loader``, printing the mean loss every ``log_every`` batches."""
    model.train()
    running_loss = 0.0
    for i, (inputs, labels) in enumerate(loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if i % log_every == log_every - 1:
            print(f'Epoch {epoch + 1}, Batch {i + 1}, Loss: {running_loss / log_every:.3f}')
            running_loss = 0.0


def train_model(train_dir, val_dir, num_epochs=25, batch_size=4, lr=0.0001,
                checkpoint_path='best_bento_classifier.pth'):
    """Fine-tune a BentoClassifier on ``train_dir`` and return the best epoch's model.

    After every epoch the model is scored on ``val_dir``. Whenever validation
    accuracy improves, the weights and class list are saved to
    ``checkpoint_path``. Before returning, the best weights are loaded back into
    the model. The returned model is therefore the checkpointed one, not
    whatever the final epoch produced; validation accuracy can drop in later
    epochs.

    Args:
        train_dir: Folder-per-class training directory.
        val_dir: Folder-per-class validation directory with the same classes.
        num_epochs: Number of training epochs.
        batch_size: Training batch size.
        lr: Adam learning rate.
        checkpoint_path: Where the best checkpoint is written.

    Returns:
        ``(model, transform, classes, device)``. ``transform`` is the
        preprocessing used for training/validation; reuse it at inference.

    Raises:
        FileNotFoundError: If ``train_dir`` does not exist.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用的計算設備: {device}")

    if not os.path.exists(train_dir):
        raise FileNotFoundError(f"找不到訓練數據目錄: {train_dir}")

    # torchvision Resize takes (height, width); Normalize uses ImageNet statistics
    # to match the pretrained backbone.
    transform = transforms.Compose([
        transforms.Resize((150, 300)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    train_dataset = BentoDataset(root_dir=train_dir, transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

    num_classes = len(train_dataset.classes)
    model = BentoClassifier(num_classes).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # -inf (not 0) so the first epoch is always checkpointed, even at 0% accuracy.
    best_val_accuracy = float('-inf')
    best_state = None
    for epoch in range(num_epochs):
        _train_one_epoch(model, train_loader, criterion, optimizer, device, epoch)

        val_accuracy = validate_model(val_dir, model, transform, device)
        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            # Clone: state_dict() tensors are live views that later optimizer
            # steps would overwrite.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save({
                'model_state_dict': best_state,
                'classes': train_dataset.classes
            }, checkpoint_path)

    if best_state is not None:
        model.load_state_dict(best_state)

    return model, transform, train_dataset.classes, device

# 步驟4: 預測函數
def predict_bento(image):
    """Gradio callback: classify one uploaded bento photo.

    Reads the notebook-level ``model``, ``transform``, ``classes`` and ``device``
    that the main section assigns from ``train_model``.

    Args:
        image: PIL image from ``gr.Image(type="pil")``; ``None`` if the user
            submits without uploading.

    Returns:
        Dict mapping every class name to its softmax probability, the format
        ``gr.Label`` consumes.

    Raises:
        gr.Error: If no image was provided (shown as a message in the UI).
    """
    if image is None:
        raise gr.Error("請先上傳一張便當照片")

    model.eval()
    # BentoDataset converts training images to RGB; do the same here so RGBA or
    # greyscale uploads produce the 3 channels the model and Normalize expect.
    image_tensor = transform(image.convert('RGB')).unsqueeze(0).to(device)

    with torch.no_grad():
        probabilities = torch.nn.functional.softmax(model(image_tensor), dim=1)[0]

    return dict(zip(classes, probabilities.tolist()))

# 主程式
if __name__ == "__main__":
    # Seed the global `random` module (used by the train/val shuffle) and PyTorch
    # (new classifier-head initialisation, DataLoader shuffling) so reruns get the
    # same split and closely comparable accuracies. GPU kernels may still add
    # small run-to-run differences.
    SEED = 42
    random.seed(SEED)
    torch.manual_seed(SEED)

    try:
        print("開始設置數據目錄...")
        train_dir, val_dir = setup_data_directory()

        print("調整圖片大小...")
        resize_images(train_dir)

        print("開始訓練模型...")
        # These four names become notebook globals that predict_bento reads.
        model, transform, classes, device = train_model(train_dir, val_dir)

        print("設置Gradio介面...")
        iface = gr.Interface(
            fn=predict_bento,
            inputs=gr.Image(type="pil"),
            outputs=gr.Label(),
            title="便當分類器",
            description="上傳一張便當照片，讓AI幫你判斷是哪種便當"
        )

        print("啟動Gradio介面...")
        # share=True requests a public Gradio share link for the running UI.
        iface.launch(share=True)

    except Exception as e:
        # Print a short summary, then re-raise so the full traceback is kept.
        print(f"發生錯誤: {str(e)}")
        raise


開始設置數據目錄...
調整圖片大小...
開始訓練模型...
使用的計算設備: cuda


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 192MB/s]


Epoch 1, Batch 5, Loss: 1.272
Epoch 1, Batch 10, Loss: 0.969
Epoch 1, Batch 15, Loss: 0.733
Epoch 1, Batch 20, Loss: 0.639
Epoch 1, Batch 25, Loss: 0.460
Epoch 1, Batch 30, Loss: 0.535
驗證集準確率: 93.55%
Epoch 2, Batch 5, Loss: 0.201
Epoch 2, Batch 10, Loss: 0.451
Epoch 2, Batch 15, Loss: 0.475
Epoch 2, Batch 20, Loss: 0.169
Epoch 2, Batch 25, Loss: 0.120
Epoch 2, Batch 30, Loss: 0.276
驗證集準確率: 90.32%
Epoch 3, Batch 5, Loss: 0.385
Epoch 3, Batch 10, Loss: 0.244
Epoch 3, Batch 15, Loss: 0.164
Epoch 3, Batch 20, Loss: 0.086
Epoch 3, Batch 25, Loss: 0.222
Epoch 3, Batch 30, Loss: 0.183
驗證集準確率: 90.32%
Epoch 4, Batch 5, Loss: 0.062
Epoch 4, Batch 10, Loss: 0.110
Epoch 4, Batch 15, Loss: 0.197
Epoch 4, Batch 20, Loss: 0.304
Epoch 4, Batch 25, Loss: 0.141
Epoch 4, Batch 30, Loss: 0.115
驗證集準確率: 96.77%
Epoch 5, Batch 5, Loss: 0.055
Epoch 5, Batch 10, Loss: 0.125
Epoch 5, Batch 15, Loss: 0.057
Epoch 5, Batch 20, Loss: 0.452
Epoch 5, Batch 25, Loss: 0.026
Epoch 5, Batch 30, Loss: 0.116
驗證集準確率: 87.10%
