In [11]:
!pip install einops imageio decord

Collecting imageio
  Downloading imageio-2.33.1-py3-none-any.whl (313 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m313.3/313.3 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
Installing collected packages: imageio
Successfully installed imageio-2.33.1


In [14]:
!pip install imageio[ffmpeg]

Collecting imageio-ffmpeg
  Downloading imageio_ffmpeg-0.4.9-py3-none-macosx_10_9_intel.macosx_10_9_x86_64.macosx_10_10_intel.macosx_10_10_x86_64.whl (22.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m22.5/22.5 MB[0m [31m15.6 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: imageio-ffmpeg
Successfully installed imageio-ffmpeg-0.4.9


In [5]:
!pwd

/Users/l_y_o/Work/AnimateDiff/data


In [1]:
%cd ../

/Users/l_y_o/Work/AnimateDiff


In [5]:
import os, io, csv, math, random
import numpy as np
from einops import rearrange
from decord import VideoReader


from animatediff.data.gif_reader import GifReader

import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
def zero_rank_print(*args):
    print(*args)

from torchvision.transforms import v2


# import RandomIoUCrop
# transforms = RandomIoUCrop()

def tranform_image(img):
    try:
        w, h = img.shape[-2:]
        if w == h:
            pass
        elif w > h:
            transform = v2.Pad((0, int((w - h) / 2), 0, int((w - h) / 2)), padding_mode='edge')
            img = transform(img)
        else:
            transform = v2.Pad((int((h - w) / 2), 0, int((h - w) / 2), 0), padding_mode='edge')
            img = transform(img)
        
    
        return img
    except Exception as ex:
        print(f"eror: {ex}, img:{img.shape}")
        raise ex


class SquarePad:
    def __call__(self, image):
        return tranform_image(image)



class WebVid10M(Dataset):
    def __init__(
            self,
            csv_path, video_folder,
            sample_size=256, sample_stride=4, sample_n_frames=16,
            is_image=False,
        ):
        zero_rank_print(f"loading annotations from {csv_path} ...")
        with open(csv_path, 'r') as csvfile:
            self.dataset = list(csv.DictReader(csvfile))
        self.length = len(self.dataset)
        zero_rank_print(f"data scale: {self.length}")

        self.video_folder    = video_folder
        self.sample_stride   = sample_stride
        self.sample_n_frames = sample_n_frames
        self.is_image        = is_image
        
        sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
        self.pixel_transforms = transforms.Compose([
            SquarePad(),
            transforms.RandomHorizontalFlip(),
            transforms.Resize(sample_size[0], antialias=True),
            transforms.CenterCrop(sample_size),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
        ])
    
    def get_batch(self, idx):
        video_dict = self.dataset[idx]
        videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']

        file_ext = 'mp4'
        if 'file_ext' in video_dict:
            file_ext = video_dict['file_ext']

        if file_ext == 'mp4':
            video_dir    = os.path.join(self.video_folder,  f"{videoid}.mp4")
            video_reader = VideoReader(video_dir)
        else:
            video_dir    = os.path.join(self.video_folder, page_dir, f"{videoid}.{file_ext}")
            video_reader = GifReader(video_dir)
        video_length = len(video_reader)
        
        if not self.is_image:
            clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
            start_idx   = random.randint(0, video_length - clip_length)
            batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
        else:
            batch_index = [random.randint(0, video_length - 1)]

        if file_ext == 'mp4':
            pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
        else:
            pixel_values = torch.from_numpy(video_reader.get_batch(batch_index)).permute(0, 3, 1, 2).contiguous()
        pixel_values = pixel_values / 255.
        del video_reader

        if self.is_image:
            pixel_values = pixel_values[0]
        
        return pixel_values, name, page_dir, videoid

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        while True:
            try:
                pixel_values, name, page_dir, videoid = self.get_batch(idx)
                break

            except Exception as e:
                print(f"error loading {idx}: {e}")
                idx = random.randint(0, self.length-1)

        pixel_values = self.pixel_transforms(pixel_values)
        sample = dict(pixel_values=pixel_values, text=name , page_dir=page_dir, videoid=videoid)
        return sample



In [6]:
import torch
import imageio
import torchvision
def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
    videos = rearrange(videos, "b c t h w -> t b c h w")
    outputs = []
    for x in videos:
        x = torchvision.utils.make_grid(x, nrow=n_rows)
        x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
        if rescale:
            x = (x + 1.0) / 2.0  # -1,1 -> 0,1
        x = (x * 255).numpy().astype(np.uint8)
        outputs.append(x)

    os.makedirs(os.path.dirname(path), exist_ok=True)
    imageio.mimsave(path, outputs, fps=fps)



In [7]:
#from animatediff.utils.util import save_videos_grid

dataset = WebVid10M(
    csv_path="./data/animated_diff_ds_1704397657.csv",
    video_folder="./data/",
    sample_size=256,
    sample_stride=4, sample_n_frames=16,
    is_image=False,
)


dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=0,)
for idx, batch in enumerate(dataloader):
    print(batch["pixel_values"].shape, len(batch["text"]))
    for i in range(batch["pixel_values"].shape[0]):
        title = batch["text"]
        dir = batch['page_dir']
        fn = batch['videoid']
        save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)
        print(f"{dir}/{fn}: as {idx}-{i}.mp4")

    if idx > 3:
        break


loading annotations from ./data/animated_diff_ds_1704397657.csv ...
data scale: 117
torch.Size([4, 16, 3, 256, 256]) 4
['smoke', 'smoke', 'smoke', 'smoke']/['JV24kJi', '3fe3291bf506e4c66dddd3b6f51b7656_w200', 'olga-ryzhychenko-31', 'EEPpNs']: as 0-0.mp4
['smoke', 'smoke', 'smoke', 'smoke']/['JV24kJi', '3fe3291bf506e4c66dddd3b6f51b7656_w200', 'olga-ryzhychenko-31', 'EEPpNs']: as 0-1.mp4
['smoke', 'smoke', 'smoke', 'smoke']/['JV24kJi', '3fe3291bf506e4c66dddd3b6f51b7656_w200', 'olga-ryzhychenko-31', 'EEPpNs']: as 0-2.mp4
['smoke', 'smoke', 'smoke', 'smoke']/['JV24kJi', '3fe3291bf506e4c66dddd3b6f51b7656_w200', 'olga-ryzhychenko-31', 'EEPpNs']: as 0-3.mp4
torch.Size([4, 16, 3, 256, 256]) 4
['smoke', 'smoke', 'smoke', 'smoke']/['giphy', 'jeong-h-lee-ezgif-com-resize', 'ezgif.com-video-to-gif+(14)', '761391_6bc8c']: as 1-0.mp4
['smoke', 'smoke', 'smoke', 'smoke']/['giphy', 'jeong-h-lee-ezgif-com-resize', 'ezgif.com-video-to-gif+(14)', '761391_6bc8c']: as 1-1.mp4
['smoke', 'smoke', 'smoke', 's