In [None]:
# 数据集地址：https://wukong-dataset.github.io/wukong-dataset/download.html
# 论文地址：https://arxiv.org/abs/2202.06767
# 数据集百度网盘地址：https://pan.baidu.com/share/init?surl=HbVGfdFvN8FIw7f-lSU4pg 提取码：noah

In [None]:
# 数据下载代码，可在此基础上完善
import os
import csv
import subprocess
import pandas as pd
import logging
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Pool

# File and directory paths
csv_dir = "wukong_release/"  # directory holding the metadata files (CSV format)
output_base_dir = "wukong_train/"  # root directory for the downloaded images
anno_dir = "wukong_anno/"  # directory for the per-CSV annotation files written after download
# Error log for failed downloads. Kept relative like the other working paths:
# a file directly under "/" is not writable for a regular user, and
# logging.basicConfig opens the file immediately.
log_file = "download_errors.log"

# Logging setup: only ERROR and above are written to log_file
logging.basicConfig(filename=log_file, level=logging.ERROR)

# Maximum number of download attempts per image
max_retry_attempts = 3

# Timeout limit (seconds)
timeout = 30

# Download function
def download_image(args):
    """Download a single image with curl, retrying on failure.

    Args:
        args: ``(index, row, csv_name)`` tuple. ``row`` must provide the
            ``'url'`` and ``'caption'`` keys, ``index`` is the row index used
            to build the file name, and ``csv_name`` selects the sub-directory
            of ``output_base_dir`` the image is stored in.

    Returns:
        ``(filename, caption)`` when the download succeeded, or ``None`` when
        all ``max_retry_attempts`` attempts failed.
    """
    index, row, csv_name = args
    image_url = row['url']
    caption = row['caption']
    filename = f"{csv_name}_{index + 1}.jpg"  # name the file after the 1-based row index
    image_dir = os.path.join(output_base_dir, csv_name)
    os.makedirs(image_dir, exist_ok=True)
    image_path = os.path.join(image_dir, filename)

    command = [
        'curl',
        '--fail',      # HTTP errors (>= 400) give a non-zero exit code instead of saving the error page as a .jpg
        '--silent',    # no progress meter; nothing is read from curl's output
        '--location',  # follow redirects so the image is saved rather than the redirect response
        '--connect-timeout', str(timeout),
        '--max-time', str(timeout),  # bound the whole transfer so a stalled download cannot block a worker forever
        '-o', image_path,
        image_url,
    ]

    for _attempt in range(max_retry_attempts):
        result = subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        if result.returncode == 0:
            # Download succeeded: report the file name together with its caption
            return (filename, caption)

        # Download failed: record the error in the log
        error_message = f"Error downloading image {image_url}: curl returned non-zero exit code {result.returncode}"
        logging.error(error_message)
        # Do not leave a truncated file behind that could be mistaken for a valid image
        try:
            os.remove(image_path)
        except FileNotFoundError:
            pass  # curl failed before writing anything

    # Every attempt failed: record the final error and report no result
    error_message = f"Failed to download image {image_url} after {max_retry_attempts} attempts."
    logging.error(error_message)
    return None

# Handle the download job for a single CSV file
def process_csv(csv_file):
    """Download every image listed in one metadata CSV and write its annotation file.

    Args:
        csv_file: File name (not full path) of a CSV inside ``csv_dir``. The
            CSV is read with pandas and each row is handed to
            ``download_image``.

    Side effects:
        Creates ``output_base_dir/<csv_name>/`` for the images and writes
        ``anno_dir/<csv_name>_updated.csv`` containing one row per successful
        download.
    """
    # Create the image directory
    csv_name = os.path.splitext(csv_file)[0]
    image_dir = os.path.join(output_base_dir, csv_name)
    os.makedirs(image_dir, exist_ok=True)
    # to_csv does not create missing parent directories; without this the
    # annotation file could not be written after all downloads have finished.
    os.makedirs(anno_dir, exist_ok=True)

    # Load the CSV to process
    csv_path = os.path.join(csv_dir, csv_file)
    image_anno = pd.read_csv(csv_path)

    args = [(index, row, csv_name) for index, row in image_anno.iterrows()]
    with ThreadPoolExecutor(max_workers=64) as executor:  # number of download threads
        download_results = list(executor.map(download_image, args))

    # Keep successful downloads only (failed ones are reported as None)
    download_results = [result for result in download_results if result is not None]
    # NOTE: the first column is named 'url' but holds the local file name
    # returned by download_image; the name is kept so the output schema is unchanged.
    downloaded_data = pd.DataFrame(download_results, columns=['url', 'caption'])

    # Write the new DataFrame to the annotation CSV
    new_csv_path = os.path.join(anno_dir, f'{csv_name}_updated.csv')
    downloaded_data.to_csv(new_csv_path, index=False)

# Process every CSV file found in csv_dir, one worker process per file at a time
with Pool(processes=16) as pool:  # number of worker processes
    csv_files = []
    for entry in os.listdir(csv_dir):
        if entry.endswith('.csv'):
            csv_files.append(entry)
    pool.map(process_csv, csv_files)

In [None]:
# Upload the data to S3 with s5cmd
# Step 1: install the s5cmd tool
# s5cmd project page: https://github.com/peak/s5cmd
# Fetch the v2.0.0 Linux 64-bit release archive (following redirects) and
# extract only the `s5cmd` member into the current directory
!curl -L https://github.com/peak/s5cmd/releases/download/v2.0.0/s5cmd_2.0.0_Linux-64bit.tar.gz | tar -xz s5cmd
# Mark the extracted binary as executable
!chmod +x ./s5cmd

In [None]:
# Test s5cmd by listing an S3 prefix
!./s5cmd ls s3://com.zetyun.data/test1/

In [None]:
# 第二步：经过反复测试，使用s5cmd上传数据时，如果一次上传的文件太多会导致系统卡死。具体地：上传较多文件时，/tmp路径下会产生大量临时文件，导致CPU进程卡死、终端无响应；
# 因此，需要先将文件再做切分，代码如下：
import os
import shutil

# Source root and the per-folder source / destination paths
source_path = 'wukong_train'
# Only sub-directories are split; a stray file under source_path would make
# os.listdir() on it fail. Sorting makes the processing order reproducible.
SOURCE_DIRS = [
    source_path + '/' + i
    for i in sorted(os.listdir(source_path))
    if os.path.isdir(os.path.join(source_path, i))
]
DEST_DIRS = ['t_' + i for i in SOURCE_DIRS]

# Maximum number of files placed in one chunk sub-folder
CHUNK_SIZE = 4000


def _split_into_chunks(source_dir, dest_dir, chunk_size=CHUNK_SIZE):
    """Move the entries of ``source_dir`` into numbered chunk folders.

    Entries are moved to ``dest_dir/wukong_1``, ``dest_dir/wukong_2``, ...
    with at most ``chunk_size`` entries per folder.

    Args:
        source_dir: Directory whose entries are moved.
        dest_dir: Directory that receives the ``wukong_<k>`` sub-folders; it is
            created if missing.
        chunk_size: Maximum number of entries per sub-folder.
    """
    # Create the destination folder and its first sub-folder (even when
    # source_dir turns out to be empty)
    os.makedirs(dest_dir, exist_ok=True)
    os.makedirs(os.path.join(dest_dir, "wukong_1"), exist_ok=True)

    for position, file_name in enumerate(sorted(os.listdir(source_dir))):
        # Entry number `position` belongs to chunk position // chunk_size (1-based folder name)
        chunk_dir = os.path.join(dest_dir, f"wukong_{position // chunk_size + 1}")
        if position % chunk_size == 0:
            # First entry of a new chunk: make sure its folder exists
            os.makedirs(chunk_dir, exist_ok=True)
        # Move the file into its chunk sub-folder
        shutil.move(os.path.join(source_dir, file_name), os.path.join(chunk_dir, file_name))


# Split every source folder into its matching destination folder
for source_dir, dest_dir in zip(SOURCE_DIRS, DEST_DIRS):
    _split_into_chunks(source_dir, dest_dir)

In [None]:
# 第三步，将处理后的文件上传s3
import os
import time
import multiprocessing as mp

def process_data(file_path, s3_file, s5cmd_bin='/home/ec2-user/SageMaker/s5cmd'):
    """Sync one local path to an S3 prefix with ``s5cmd sync``.

    Args:
        file_path: Local source path handed to ``s5cmd sync``.
        s3_file: Destination S3 prefix; a trailing ``/`` is appended.
        s5cmd_bin: Path of the s5cmd executable. Defaults to the location the
            original command used.

    Returns:
        None. Failures are reported on stderr instead of being dropped.
    """
    # Imported locally so the function does not depend on other cells' imports
    import subprocess
    import sys

    # Argument list instead of a shell string: paths containing spaces or
    # shell metacharacters are passed to s5cmd unchanged.
    command = [s5cmd_bin, 'sync', file_path, f'{s3_file}/']
    try:
        completed = subprocess.run(command)
    except OSError as exc:
        # e.g. the s5cmd binary is missing or not executable
        print(f"Could not run {s5cmd_bin}: {exc}", file=sys.stderr)
        return
    if completed.returncode != 0:
        print(
            f"s5cmd sync {file_path} {s3_file}/ failed with exit code {completed.returncode}",
            file=sys.stderr,
        )

def tanss(local_paths, process_num=14):
    """Upload every sub-directory of ``local_paths`` to S3 in parallel.

    Each sub-directory becomes one ``process_data(local_path, s3_file)`` task.
    The S3 destination is built from the second ``/``-separated component of
    the sub-directory path, so for ``local_paths`` of the form
    ``<root>/<name>`` all sub-directories are synced to the same
    ``.../wukong_train/<name>`` prefix.

    Args:
        local_paths: Directory whose sub-directories are uploaded.
        process_num: Number of worker processes in the pool.
    """
    import sys

    files_list = []

    for item in sorted(os.listdir(local_paths)):
        # Only directories are synced; skip stray files
        if not os.path.isdir(os.path.join(local_paths, item)):
            continue
        local_path = local_paths + '/' + item + '/'
        s3_houzui = local_path.split('/')[1]  # S3 suffix: second component of the local path
        s3_file = 's3://rawdata.s3.bucket/wukong/raw_data/wukong_train/' + s3_houzui
        files_list.append((local_path, s3_file))

    if not files_list:
        # Nothing to upload: do not start a worker pool
        return

    def _report_failure(exc):
        # Runs in the parent process; without it an exception raised in a
        # worker would be dropped because the async results are never read.
        print(f"Upload task failed: {exc!r}", file=sys.stderr)

    # The context manager terminates the pool even if submitting a task fails
    with mp.Pool(process_num) as pool:
        for file_info in files_list:
            pool.apply_async(process_data, args=file_info, error_callback=_report_failure)

        # Wait for all submitted uploads before leaving the block
        pool.close()
        pool.join()

    
if __name__ == "__main__":
    # Root folder whose first-level entries are each handed to tanss
    source_path = 't_wukong_train'
    SOURCE_DIRS = []
    for entry in os.listdir(source_path):
        SOURCE_DIRS.append(f"{source_path}/{entry}")
    # Upload the folders one after another
    for folder in SOURCE_DIRS:
        tanss(folder)

In [None]:
# 30w图像数据，通过分割之后用s5cmd上传只需要450s左右，而通过下面命令上传需要2500s左右
# !./s5cmd sync imgdata/wukong_100m_108  s3://rawdata.s3.bucket/wukong/imgs_wukong/train-imgs/ > /dev/null