# In [None]:
import cv2
import numpy as np
import os
import re
import h5py
import pickle
def sample_frames(video_path, num_frames=10):
    """Return ``num_frames`` frames sampled at evenly spaced positions of a video.

    Args:
        video_path: Path of the file handed to ``cv2.VideoCapture``.
        num_frames: How many frames to sample between the first and the last
            frame (both included).

    Returns:
        The sampled frames stacked along a new leading axis, i.e. an array of
        shape ``(num_frames, *frame_shape)``. Frames are returned exactly as
        OpenCV decodes them (presumably ``(H, W, 3)`` uint8 BGR -- confirm for
        the input format in use).

    Raises:
        ValueError: If the file cannot be opened, reports no frames, a sampled
            frame cannot be decoded, or ``num_frames`` is 0.
    """
    video = cv2.VideoCapture(video_path)
    try:
        if not video.isOpened():
            raise ValueError(f"Cannot open video file: {video_path}")

        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:
            raise ValueError(f"Video reports no frames: {video_path}")

        frame_idxs = np.linspace(0, total_frames - 1, num_frames, dtype=int)
        frames = []
        for idx in frame_idxs:
            video.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ok, frame = video.read()
            # A failed read yields (False, None); stacking None would either
            # fail obscurely or silently produce an object array.
            if not ok or frame is None:
                raise ValueError(
                    f"Failed to read frame {int(idx)} of {video_path}"
                )
            frames.append(frame)
    finally:
        # Release the capture even when sampling fails part-way through.
        video.release()
    return np.stack(frames)

def natural_sort_key(s):
    """Build a sort key that orders embedded numbers by value, not by text.

    The string is split into alternating non-digit and digit runs; digit runs
    become ``int`` and the remaining text is lower-cased, so ``"a2"`` sorts
    before ``"a10"`` and the comparison ignores case.

    Args:
        s: The string (e.g. a file name) to derive the key from.

    Returns:
        A list alternating ``str`` and ``int`` items, always starting and
        ending with a (possibly empty) ``str``.
    """
    # re.split with a capturing group returns the text between matches at even
    # indices and the captured digit runs at odd indices. Classifying by index
    # avoids str.isdigit(), which is also true for characters such as
    # superscript digits that int() cannot parse.
    pieces = re.split(r'(\d+)', s)
    return [
        int(piece) if position % 2 else piece.lower()
        for position, piece in enumerate(pieces)
    ]

def _to_unit_float(frames):
    """Scale a frame array to float32 values in [0, 1] when it is not already."""
    # Integer frames are always rescaled, so a very dark clip (max <= 1) gets
    # the same treatment and dtype as every other clip.
    if np.issubdtype(frames.dtype, np.integer):
        return frames.astype(np.float32) / 255.0
    # Non-integer input keeps the value-based check; the size guard avoids
    # calling max() on an empty array.
    if frames.size and frames.max() > 1.0:
        return frames.astype(np.float32) / 255.0
    return frames


def process_folder_sorted(folder_path, pre_slen=10, aft_slen=10, suffix='.gif'):
    """Yield (input, target) frame sequences for every matching file in a tree.

    Directory entries are visited in natural-sort order; sub-directories are
    descended into recursively at the point where they appear in that order.

    Args:
        folder_path: Directory to scan.
        pre_slen: Number of leading frames that form the input sequence.
        aft_slen: Number of following frames that form the target sequence.
        suffix: Only files whose name ends with this (case-sensitive) suffix
            are processed.

    Yields:
        ``(data_x, data_y)`` where ``data_x`` holds the first ``pre_slen``
        sampled frames and ``data_y`` the remaining ones, both scaled to the
        range [0, 1].
    """
    for item in sorted(os.listdir(folder_path), key=natural_sort_key):
        item_path = os.path.join(folder_path, item)
        if os.path.isdir(item_path):
            # Log message reads "Entering subfolder".
            print(f"进入子文件夹: {item_path}")
            yield from process_folder_sorted(item_path, pre_slen, aft_slen, suffix)
        elif os.path.isfile(item_path) and item.endswith(suffix):
            # Sample both halves in one pass so they share the same spacing.
            video = sample_frames(item_path, pre_slen + aft_slen)
            yield _to_unit_float(video[:pre_slen]), _to_unit_float(video[pre_slen:])

def save_to_h5(filename, dataset_name, data):
    """Append ``data`` along axis 0 to a dataset in an HDF5 file.

    The file is opened in append mode, so it is created when missing. If the
    dataset does not exist yet it is created from ``data`` with an unlimited
    first dimension; otherwise it is grown and the new rows are written at
    the end.

    Args:
        filename: Path of the HDF5 file.
        dataset_name: Name of the dataset inside the file.
        data: Array whose first axis counts the rows to append; the remaining
            axes must match the existing dataset.
    """
    new_rows = data.shape[0]
    with h5py.File(filename, 'a') as h5f:
        if dataset_name not in h5f:
            maxshape = (None,) + data.shape[1:]  # Allow unlimited rows
            h5f.create_dataset(dataset_name, data=data, maxshape=maxshape, chunks=True)
            return

        if new_rows == 0:
            # Nothing to append; a "-0:" style slice would address the whole
            # dataset instead of an empty tail.
            return

        dset = h5f[dataset_name]
        start = dset.shape[0]
        dset.resize(start + new_rows, axis=0)
        # Write from the old end explicitly rather than via a negative slice.
        dset[start:] = data


folders = ['train', 'val', 'test']
pre_seq_length = 10
aft_seq_length = 10
# Number of clips buffered in memory before they are flushed to disk.
batch_size = 10


def _flush_batch(folder, x_clips, y_clips):
    """Stack buffered clips, reorder their axes and append them to the folder's HDF5 file."""
    h5_path = f'dataset_{folder}_sorted_test.h5'
    # Move the last axis in front of the two before it:
    # (clip, frame, a, b, c) -> (clip, frame, c, a, b).
    # Presumably (H, W, C) -> (C, H, W) -- confirm against the frame layout.
    x_batch = np.transpose(np.array(x_clips), (0, 1, 4, 2, 3))
    y_batch = np.transpose(np.array(y_clips), (0, 1, 4, 2, 3))
    save_to_h5(h5_path, f'X_{folder}', x_batch)
    save_to_h5(h5_path, f'Y_{folder}', y_batch)


# NOTE(review): the HDF5 files are opened in append mode, so re-running this
# script without deleting them first adds the same clips a second time.
for folder in folders:
    folder_path = os.path.join(os.getcwd(), folder)
    data_x_list = []
    data_y_list = []
    batch_count = 0
    clips = process_folder_sorted(
        folder_path,
        pre_slen=pre_seq_length,
        aft_slen=aft_seq_length,
        suffix='.gif',
    )
    for data_x, data_y in clips:
        data_x_list.append(data_x)
        data_y_list.append(data_y)
        if len(data_x_list) < batch_size:
            continue
        _flush_batch(folder, data_x_list, data_y_list)
        data_x_list.clear()
        data_y_list.clear()
        batch_count += 1
        print(f"Processed batch {batch_count} for folder {folder}")

    # Write whatever is left over when the clip count is not a multiple of
    # batch_size.
    if data_x_list:
        _flush_batch(folder, data_x_list, data_y_list)


final_dataset = {}
# NOTE(review): only the last entry of `folders` is exported to the pickle.
# A loop over every folder used to wrap this block and was disabled; confirm
# whether the other splits should be included as well.
export_folder = folders[-1]
with h5py.File(f'dataset_{export_folder}_sorted_test.h5', 'r') as h5f:
    # np.array() copies the datasets into memory so they outlive the file handle.
    final_dataset[f'X_{export_folder}'] = np.array(h5f[f'X_{export_folder}'])
    final_dataset[f'Y_{export_folder}'] = np.array(h5f[f'Y_{export_folder}'])

with open('100_gif_augumentation.pkl', 'wb') as f:
    pickle.dump(final_dataset, f)
