In [1]:
import gc
import math
from PIL import Image

import torch
import numpy as np
from torchvision.transforms.functional import to_tensor, center_crop
from encoded_video import EncodedVideo, write_video
from IPython.display import Video

In [2]:
# Load the pretrained "generator" entry point of the AK391/animegan2-pytorch
# hub repo. torch.hub reuses its local cache when one exists, so only the
# first run downloads anything. device="cpu" is forwarded to the entry point.
print("🧠 Loading Model...")
model = torch.hub.load("AK391/animegan2-pytorch:main", "generator", pretrained=True, device="cpu", progress=True)

🧠 Loading Model...


Using cache found in /Users/marlenemhangami/.cache/torch/hub/AK391_animegan2-pytorch_main


In [6]:
def face2paint(model: torch.nn.Module, img: Image.Image, size: int = 512, device: str = 'cpu'):
    """Run ``model`` on a single PIL image and return the rescaled output tensor.

    The image is centre-cropped to a square whose side is the shorter image
    dimension, resized to ``size`` x ``size`` with Lanczos resampling, converted
    to a tensor, mapped with ``* 2 - 1`` and passed to ``model`` as a batch of one.

    Args:
        model: Network to call on the prepared batch.
        img: Source image.
        size: Side length, in pixels, of the square fed to the model.
        device: Device the input batch is moved to before the forward pass.
            The model itself is not moved here.

    Returns:
        The first item of the model output after ``* 0.5 + 0.5``, clipping to
        [0, 1] and multiplying by 255, on the CPU.
    """
    width, height = img.size
    side = min(width, height)

    # Box of the centred square: (left, upper, right, lower).
    crop_box = (
        (width - side) // 2,
        (height - side) // 2,
        (width + side) // 2,
        (height + side) // 2,
    )
    square = img.crop(crop_box)
    square = square.resize((size, size), Image.LANCZOS)

    with torch.no_grad():
        batch = to_tensor(square).unsqueeze(0) * 2 - 1
        styled = model(batch.to(device)).cpu()[0]

        styled = (styled * 0.5 + 0.5).clip(0, 1) * 255.0

    return styled


# This function is taken from pytorchvideo!
def uniform_temporal_subsample(x: torch.Tensor, num_samples: int, temporal_dim: int = -3) -> torch.Tensor:
    """Pick ``num_samples`` evenly spaced slices of ``x`` along ``temporal_dim``.

    Sample positions are spread linearly from the first to the last index and
    truncated to integers, so when ``num_samples`` exceeds the length of the
    temporal axis some slices are repeated.

    Args:
        x: Tensor to sample from.
        num_samples: Number of slices to keep; must be positive.
        temporal_dim: Axis of ``x`` to sample along.

    Returns:
        A tensor like ``x`` whose ``temporal_dim`` has length ``num_samples``.

    Raises:
        AssertionError: If ``num_samples`` or the temporal length is not positive.
    """
    num_frames = x.shape[temporal_dim]
    assert num_samples > 0 and num_frames > 0

    last_frame = num_frames - 1
    # Fractional positions are truncated by .long(); the clamp keeps every
    # index inside [0, last_frame].
    positions = torch.linspace(0, last_frame, num_samples)
    frame_ids = torch.clamp(positions, 0, last_frame).long()
    return torch.index_select(x, temporal_dim, frame_ids)


# This function is taken from pytorchvideo!
def short_side_scale(
    x: torch.Tensor,
    size: int,
    interpolation: str = "bilinear",
) -> torch.Tensor:
    """Resize ``x`` so its shorter spatial side equals ``size``, keeping aspect ratio.

    Args:
        x: float32 tensor with four dimensions, unpacked here as (c, t, h, w).
        size: Target length of the shorter of ``h`` and ``w``.
        interpolation: Mode passed to ``torch.nn.functional.interpolate``.

    Returns:
        The resized tensor; the longer side is scaled by the same ratio and
        rounded down.

    Raises:
        AssertionError: If ``x`` is not 4-dimensional or not float32.
    """
    assert x.ndim == 4
    assert x.dtype == torch.float32

    _, _, height, width = x.shape
    if width < height:
        # Portrait: width is the short side.
        target = (int(math.floor((float(height) / width) * size)), size)
    else:
        # Landscape or square: height is the short side.
        target = (size, int(math.floor((float(width) / height) * size)))

    return torch.nn.functional.interpolate(x, size=target, mode=interpolation, align_corners=False)


def inference_step(vid, start_sec, duration, out_fps):
    """Run the module-level ``model`` on one clip of ``vid``.

    Args:
        vid: Video object providing ``get_clip(start, end)``; the returned clip
            is indexed with the keys ``'video'`` and ``'audio'``.
        start_sec: Clip start, in seconds.
        duration: Clip length, in seconds.
        out_fps: Frames kept per second; ``duration * out_fps`` frames are sampled.

    Returns:
        Tuple ``(output_video, audio_arr, out_fps, audio_fps)`` where
        ``output_video`` is the model output as a numpy array with the channel
        axis moved last, ``audio_arr`` is the clip audio with a leading axis
        added, and ``audio_fps`` is the audio sample rate, or ``None`` when the
        video reports no audio.
    """
    clip = vid.get_clip(start_sec, start_sec + duration)

    # Move the last axis first; presumably (T, H, W, C) -> (C, T, H, W).
    # TODO confirm against what get_clip returns.
    frames = torch.from_numpy(clip['video']).permute(3, 0, 1, 2)
    audio_arr = np.expand_dims(clip['audio'], 0)
    if vid._has_audio:
        audio_fps = vid._container.streams.audio[0].sample_rate
    else:
        audio_fps = None

    frames = uniform_temporal_subsample(frames, duration * out_fps)
    frames = short_side_scale(frames, 512)
    frames = center_crop(frames, 512)
    # NOTE(review): this feeds the model values in [0, 1], while face2paint
    # maps its input with `* 2 - 1` first. Confirm which range the generator
    # expects before changing either path.
    frames /= 255.0
    # Swap the first two axes so frames form the batch dimension.
    batch = frames.permute(1, 0, 2, 3)

    with torch.no_grad():
        styled = model(batch.to('cpu')).detach().cpu()
        styled = (styled * 0.5 + 0.5).clip(0, 1) * 255.0
        output_video = styled.permute(0, 2, 3, 1).numpy()

    return output_video, audio_arr, out_fps, audio_fps


def predict_fn(filepath, start_sec, duration):
    """Stylise ``duration`` seconds of a video and write the result to ``out.mp4``.

    The video is processed one second at a time through ``inference_step`` at a
    fixed 18 output frames per second. The per-second results are joined and
    written with audio; if that write fails, the video is written again
    without audio.

    Args:
        filepath: Path of the input video, opened with ``EncodedVideo.from_path``.
        start_sec: Second at which processing starts.
        duration: Number of whole seconds to process; must be at least 1.

    Returns:
        The output path, ``'out.mp4'``.

    Raises:
        ValueError: If ``duration`` is less than 1, since there would be no
            frames to write.
    """
    if duration < 1:
        raise ValueError(f"duration must be at least 1 second, got {duration!r}")

    out_fps = 18
    vid = EncodedVideo.from_path(filepath)

    # Collect per-second chunks and join them once after the loop; growing the
    # arrays inside the loop would re-copy everything on every iteration.
    video_chunks = []
    audio_chunks = []
    for i in range(duration):
        print(f"🖼️ Processing step {i + 1}/{duration}...")
        video, audio, fps, audio_fps = inference_step(vid=vid, start_sec=i + start_sec, duration=1, out_fps=out_fps)
        gc.collect()
        video_chunks.append(video)
        audio_chunks.append(audio)

    video_all = np.concatenate(video_chunks)
    audio_all = np.hstack(audio_chunks)

    print("💾 Writing output video...")

    try:
        write_video('out.mp4', video_all, fps=fps, audio_array=audio_all, audio_fps=audio_fps, audio_codec='aac')
    except Exception as exc:
        # Catch Exception rather than everything, so Ctrl-C / SystemExit still
        # stop the run, and report the cause instead of hiding it.
        print("❌ Error when writing with audio...trying without audio")
        print(f"   ({type(exc).__name__}: {exc})")
        write_video('out.mp4', video_all, fps=fps)

    print("✅ Done!")

    return 'out.mp4'

In [7]:
# Stylise the first 6 seconds of the input; the cell displays the returned output path.
predict_fn(filepath='myvideo.mp4', start_sec=0, duration=6)

🖼️ Processing step 1/6...
🖼️ Processing step 2/6...
🖼️ Processing step 3/6...
🖼️ Processing step 4/6...
🖼️ Processing step 5/6...
🖼️ Processing step 6/6...
💾 Writing output video...
❌ Error when writing with audio...trying without audio
✅ Done!


'out.mp4'

In [9]:
# Show the input clip that was passed to predict_fn.
Video(
    "myvideo.mp4",
    height=300,
)

In [8]:
# Show the file written by predict_fn.
Video(
    "out.mp4",
    width=500,
    height=300,
)