蘑菇分类数据集创建脚本
======================

本脚本实现了蘑菇图像数据集的自动化收集过程，包括：
1. 从百度图片搜索下载蘑菇图像，创建分类数据集
2. 从百度百科抓取蘑菇描述信息，存储到SQLite数据库

涵盖36种常见食用菌的图像和相关信息。

In [None]:
!pip install selenium webdriver_manager

import os
import time
import requests
from bs4 import BeautifulSoup
from selenium import webdriver
from selenium.webdriver.common.by import By
from selenium.webdriver.chrome.service import Service
from selenium.webdriver.chrome.options import Options
from webdriver_manager.chrome import ChromeDriverManager
import sqlite3
from PIL import Image
from io import BytesIO

# Add Google's package-signing key and the Google Chrome apt repository.
# NOTE(review): `>>` appends, so re-running this cell adds a duplicate repository
# line to google-chrome.list; `apt-key` is also deprecated on newer Debian/Ubuntu
# releases — consider `>` and a signed-by keyring instead.
!wget -q -O - https://dl-ssl.google.com/linux/linux_signing_key.pub | apt-key add -
!echo "deb [arch=amd64] http://dl.google.com/linux/chrome/deb/ stable main" >> /etc/apt/sources.list.d/google-chrome.list

# Refresh the package index so the new repository is visible
!apt-get update

# Install the Chrome browser (setup_chrome_driver points Selenium at /usr/bin/google-chrome)
!apt-get install -y google-chrome-stable

# Print the installed version to confirm the installation succeeded
!google-chrome-stable --version

# 常量定义

In [None]:
# Names of the 36 common edible mushrooms; a name's list index i becomes its
# class folder name "class{i}" in the image dataset and label file.
MUSHROOM_TYPES = [
    "羊肚菌", "牛肝菌", "鸡油菌", "鸡枞菌", "青头菌", "奶浆菌", "干巴菌", "虎掌菌",
    "白葱牛肝菌", "老人头菌", "猪肚菌", "谷熟菌", "白参菌", "黑木耳", "银耳", "金耳",
    "猴头菇", "香菇", "平菇", "金针菇", "口蘑", "鹿茸菇", "榆黄蘑", "榛蘑", "草菇",
    "鸡腿菇", "茶树菇", "蟹味菇", "白玉菇", "红菇", "杏鲍菇", "松茸", "姬松茸", "松露",
    "竹荪", "虫草花"
]

# Number of images to collect per mushroom type (default num_images of download_images)
IMAGES_PER_TYPE = 5

# Root directory of the classification dataset (one "class{i}" sub-folder per type)
DATA_DIR = "data"

# Default directory for the single Baidu Baike image saved per mushroom
IMAGES_DIR = "images"

# SQLite database file holding (name, description, image_path) rows
DATABASE_FILE = 'mushrooms.db'

# Output text file mapping each mushroom name to its "class{i}" folder name
LABEL_FILE = 'label.txt'

# Browser-like User-Agent header sent with outgoing HTTP requests
HEADERS = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 '
                  '(KHTML, like Gecko) Chrome/132.0.0.0 Safari/537.36 Edg/132.0.0.0'
}

# 图像下载相关函数

In [None]:
def setup_chrome_driver():
    """
    Build a headless Chrome WebDriver for the Colab runtime.

    Returns:
        webdriver.Chrome: a ready-to-use driver; the caller owns it and must
        call ``quit()`` when done.
    """
    opts = webdriver.ChromeOptions()
    # Headless, sandbox-free and without /dev/shm usage: the flags commonly
    # required to run Chrome inside a container such as Colab.
    for flag in ("--headless", "--no-sandbox", "--disable-dev-shm-usage"):
        opts.add_argument(flag)

    # Use the Chrome binary installed by the apt-get step of this notebook.
    opts.binary_location = '/usr/bin/google-chrome'

    # webdriver_manager downloads a chromedriver matching the installed browser.
    driver_service = Service(ChromeDriverManager().install())
    return webdriver.Chrome(service=driver_service, options=opts)

def _collect_image_urls(driver, num_images, max_scroll_attempts=20):
    """
    Scroll the search-result page already loaded in *driver* and collect image URLs.

    Parameters:
        driver: WebDriver whose current page is the image search result
        num_images (int): stop scrolling once this many distinct URLs are found
        max_scroll_attempts (int): upper bound on parse/scroll rounds (prevents endless scrolling)

    Returns:
        list[str]: distinct http(s) URLs of ``img.main_img`` elements, in the order
        they first appeared on the page (may be shorter or longer than num_images).
    """
    if num_images <= 0:
        return []
    found = {}  # dict used as an insertion-ordered set: keeps page ranking, drops duplicates
    for _ in range(max_scroll_attempts):
        soup = BeautifulSoup(driver.page_source, 'html.parser')
        for img in soup.find_all('img', class_='main_img'):
            src = img.get('src')
            if src and src.startswith('http'):
                found.setdefault(src, None)
        if len(found) >= num_images:
            break  # enough URLs; no need for another scroll + sleep
        # Scroll down to trigger lazy loading of more results
        driver.execute_script("window.scrollBy(0, document.body.scrollHeight);")
        time.sleep(2)  # give the page time to load more images
    return list(found)


def _fetch_image_bytes(img_url, index, max_retries):
    """
    Fetch *img_url*, retrying up to *max_retries* times.

    Returns the response body on HTTP 200, otherwise None. *index* is only used
    for the 1-based progress messages.
    """
    for attempt in range(1, max_retries + 1):
        try:
            response = requests.get(img_url, headers=HEADERS, timeout=5)
        except requests.exceptions.RequestException as e:
            print(f"尝试 {attempt}/{max_retries}: 下载图像 {index+1} 失败: {e}")
        else:
            if response.status_code == 200:
                return response.content
            print(f"图像 {index+1} 下载失败: HTTP状态码 {response.status_code}")
        if attempt < max_retries:
            time.sleep(1)  # short pause before the next attempt
    return None


def _is_valid_image(content):
    """Return True if *content* can be parsed by PIL as an image file."""
    try:
        with Image.open(BytesIO(content)) as img:
            img.verify()  # integrity check without decoding all pixel data
    except Exception:  # PIL raises many different exception types on bad data
        return False
    return True


def download_images(query_i, query, num_images=IMAGES_PER_TYPE, max_retries=3):
    """
    Download images for a keyword from Baidu image search into DATA_DIR/class{query_i}.

    Files are saved as ``{i}.jpg`` where i is the URL's rank on the result page.
    Responses that are not decodable images (e.g. HTML error pages) are skipped.

    Parameters:
        query_i (int): class index, used for the directory name
        query (str): search keyword, e.g. a mushroom name
        num_images (int): number of images to download
        max_retries (int): maximum attempts per image

    Returns:
        int: number of images successfully downloaded
    """
    query_dir = os.path.join(DATA_DIR, f"class{query_i}")
    os.makedirs(query_dir, exist_ok=True)

    search_url = f"https://image.baidu.com/search/index?tn=baiduimage&word={query}"

    # The browser is only needed to render the result page; close it before
    # the (slow) downloads start.
    driver = setup_chrome_driver()
    try:
        driver.get(search_url)
        found_urls = _collect_image_urls(driver, num_images)
    finally:
        driver.quit()

    print(f"找到 {len(found_urls)} 张 {query} 的图像。")

    successful_downloads = 0
    for i, img_url in enumerate(found_urls[:num_images]):
        content = _fetch_image_bytes(img_url, i, max_retries)
        if content is None:
            continue
        if not _is_valid_image(content):
            print(f"图像 {i+1} 不是有效的图像文件，已跳过")
            continue
        try:
            with open(os.path.join(query_dir, f"{i}.jpg"), "wb") as file:
                file.write(content)
        except OSError as e:
            print(f"保存图像 {i+1} 失败: {e}")
            continue
        print(f"已下载图像 {i+1}/{num_images} - {query}")
        successful_downloads += 1

    print(f"成功下载 {successful_downloads}/{num_images} 张 {query} 的图像")
    return successful_downloads

def _natural_sort_key(filename):
    """Sort key putting numeric stems first, in numeric order ('2.jpg' before '10.jpg')."""
    stem = os.path.splitext(filename)[0]
    if stem.isdecimal():
        return (0, int(stem), filename)
    return (1, 0, filename)


def rename_images_in_folders(data_dir):
    """
    Renumber the image files of every sub-folder of *data_dir* as 0.jpg, 1.jpg, ...

    Files keep their relative numeric order. Renaming happens in two phases
    (original name -> unique temporary name -> final name), so a file whose
    current name equals another file's target name is never overwritten or
    deleted.

    Parameters:
        data_dir (str): dataset directory containing one sub-folder per class
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.gif')
    for folder_name in sorted(os.listdir(data_dir)):
        folder_path = os.path.join(data_dir, folder_name)
        if not os.path.isdir(folder_path):
            continue

        print(f"处理文件夹: {folder_name}")
        image_files = sorted(
            (f for f in os.listdir(folder_path) if f.lower().endswith(image_exts)),
            key=_natural_sort_key,
        )
        print(f"找到 {len(image_files)} 个图像文件")

        # Phase 1: move every file whose name must change to a unique temporary
        # name; this frees all target names before anything is written to them.
        pending = []  # (temporary path, final path, original file name)
        for i, old_filename in enumerate(image_files):
            new_filename = f"{i}.jpg"
            if old_filename == new_filename:
                continue  # already in place
            tmp_path = os.path.join(folder_path, f".renaming_{i}_{old_filename}")
            try:
                os.rename(os.path.join(folder_path, old_filename), tmp_path)
            except OSError as e:
                print(f"重命名 {old_filename} 失败: {e}")
                continue
            pending.append((tmp_path, os.path.join(folder_path, new_filename), old_filename))

        # Phase 2: move the temporary files to their final names, never
        # overwriting an existing file.
        success_count = 0
        for tmp_path, new_file_path, old_filename in pending:
            if os.path.exists(new_file_path):
                # e.g. phase 1 could not move the file holding this name; keep
                # both files (this one stays under its temporary name).
                print(f"重命名 {old_filename} 失败: 目标文件 {os.path.basename(new_file_path)} 已存在")
                continue
            try:
                os.rename(tmp_path, new_file_path)
            except OSError as e:
                print(f"重命名 {old_filename} 失败: {e}")
                continue
            success_count += 1

        print(f"在 {folder_name} 中成功重命名 {success_count} 个文件")

# 第一步: 下载蘑菇图像数据集

In [None]:
def download_all_mushroom_images():
    """
    Download the image set for every entry of MUSHROOM_TYPES.

    Writes one "<name> class<index>" line per type to LABEL_FILE while
    downloading, then renumbers the files inside each class folder.
    """
    data_dir_missing = not os.path.exists(DATA_DIR)
    if data_dir_missing:
        os.makedirs(DATA_DIR)

    with open(LABEL_FILE, 'w', encoding='utf-8') as label_out:
        for class_idx, name in enumerate(MUSHROOM_TYPES):
            print(f"\n开始下载 {name} 的图像（类别{class_idx}）...")
            download_images(class_idx, name, IMAGES_PER_TYPE)
            # One mapping line per class, newline-terminated for readability
            label_out.write(f"{name} class{class_idx}\n")

    print(f"所有图像下载完成。标签映射已保存到{LABEL_FILE}文件。")

    # Normalise the file names inside every class folder
    rename_images_in_folders(DATA_DIR)
    print("图像重命名完成")

download_all_mushroom_images()

# 百度百科信息抓取相关函数

In [None]:
def _meta_content(soup, meta_name):
    """Return the stripped ``content`` of the first ``<meta name=meta_name>`` tag, or None."""
    tag = soup.find('meta', attrs={'name': meta_name})
    if tag is None or 'content' not in tag.attrs:
        return None
    return tag['content'].strip()


def fetch_mushroom_info(keyword):
    """
    Look up a mushroom on Baidu Baike.

    Parameters:
        keyword (str): mushroom name used as the Baike item path

    Returns:
        tuple: (description, image URL) taken from the page's
        ``<meta name="description">`` and ``<meta name="image">`` tags;
        either may be None, and both are None if the request fails.
    """
    page_url = f"https://baike.baidu.com/item/{keyword}"

    try:
        resp = requests.get(page_url, headers=HEADERS, allow_redirects=True, timeout=10)
        resp.raise_for_status()  # turn HTTP error statuses into exceptions
    except requests.exceptions.RequestException as err:
        print(f"请求异常: {err}")
        return None, None

    page = BeautifulSoup(resp.content, 'html.parser')
    return _meta_content(page, 'description'), _meta_content(page, 'image')

def download_image(name, img_url, folder=IMAGES_DIR):
    """
    Download one image and save it as ``<folder>/<name>.jpg``.

    Parameters:
        name (str): base file name (without extension) of the saved image
        img_url (str): URL to fetch
        folder (str): target directory, created if missing

    Returns:
        str | None: path of the saved file, or None if the download or the
        write failed (the error is printed, not raised).
    """
    os.makedirs(folder, exist_ok=True)
    img_name = os.path.join(folder, f"{name}.jpg")

    try:
        # Same browser-like headers as the Baike page request
        response = requests.get(img_url, headers=HEADERS, timeout=10)
        response.raise_for_status()
        with open(img_name, 'wb') as handler:
            handler.write(response.content)
    except (requests.exceptions.RequestException, OSError) as e:
        print(f"下载图像失败: {e}")
        return None
    return img_name

def create_database(db_file=None):
    """
    Create the SQLite database and the ``mushrooms`` table if they do not exist.

    Parameters:
        db_file (str | None): database path; defaults to DATABASE_FILE
    """
    path = DATABASE_FILE if db_file is None else db_file
    conn = sqlite3.connect(path)
    try:
        with conn:  # commits on success, rolls back on error
            conn.execute('''CREATE TABLE IF NOT EXISTS mushrooms
                    (name TEXT PRIMARY KEY, description TEXT, image_path TEXT)''')
    finally:
        # sqlite3's connection context manager does not close the connection
        conn.close()

def save_to_database(name, description, image_path, db_file=None):
    """
    Insert or update the record of one mushroom.

    An existing row with the same *name* (the primary key) is updated in place;
    otherwise a new row is inserted. The ``mushrooms`` table must already exist
    (see create_database).

    Parameters:
        name (str): mushroom name
        description (str): description text
        image_path (str | None): path of the saved image file, may be None
        db_file (str | None): database path; defaults to DATABASE_FILE
    """
    path = DATABASE_FILE if db_file is None else db_file
    conn = sqlite3.connect(path)
    try:
        with conn:  # commits on success, rolls back on error
            # Single-statement upsert instead of SELECT followed by UPDATE/INSERT
            conn.execute(
                "INSERT INTO mushrooms (name, description, image_path) VALUES (?, ?, ?) "
                "ON CONFLICT(name) DO UPDATE SET "
                "description = excluded.description, image_path = excluded.image_path",
                (name, description, image_path),
            )
    finally:
        # sqlite3's connection context manager does not close the connection
        conn.close()


# 第二步: 获取蘑菇描述信息

In [None]:
def fetch_and_save_mushroom_info():
    """
    Fetch the Baidu Baike entry of every mushroom type and store it in SQLite.

    For each name in MUSHROOM_TYPES: read the page description and image URL,
    download the image when a URL was found, and save a database row when a
    description exists. Prints a summary count at the end.
    """
    create_database()
    success_count = 0

    for keyword in MUSHROOM_TYPES:
        print(f"获取 {keyword} 的信息...")
        description, img_url = fetch_mushroom_info(keyword)
        image_path = None

        if img_url:
            try:
                image_path = download_image(keyword, img_url)
            except Exception as e:  # best effort: one failed image must not stop the run
                print(f"下载 {keyword} 的图像失败: {e}")
            else:
                # download_image reports failure by returning None instead of raising
                if image_path:
                    print(f"已下载 {keyword} 的图像: {image_path}")
                else:
                    print(f"下载 {keyword} 的图像失败")

        if description:
            save_to_database(keyword, description, image_path)
            print(f"已保存 {keyword} 的信息。")
            print(description)
            success_count += 1
        else:
            print(f"未找到 {keyword} 的信息。")

    print(f"成功获取了 {success_count}/{len(MUSHROOM_TYPES)} 种蘑菇的信息")

fetch_and_save_mushroom_info()