<a href="https://colab.research.google.com/github/detektor777/colab_list/blob/main/EZ_DAIN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[Google Drive](https://drive.google.com/drive/my-drive)

In [None]:
# Notebook parameters (Colab form fields). INPUT_FILEPATH and OUTPUT_FILE_PATH are
# relative to the root of "My Drive"; FRAME_INPUT_DIR and FRAME_OUTPUT_DIR are
# absolute paths on the mounted Drive.

# Source video, relative to My Drive (copied into /content/DAIN by the FPS cell).
INPUT_FILEPATH = "input.mp4" #@param{type:"string"}

# Name of the finished video written back to My Drive.
OUTPUT_FILE_PATH = "output.mp4" #@param{type:"string"}

# Frame rate of the interpolated output; the interpolation cell passes
# (input fps) / TARGET_FPS to DAIN as --time_step.
TARGET_FPS = 60 #@param{type:"number"}

# Folder that receives the extracted input frames (NNNNN.png).
FRAME_INPUT_DIR = '/content/gdrive/MyDrive/dain_input' #@param{type:"string"}

# Folder that receives DAIN's interpolated frames.
FRAME_OUTPUT_DIR = '/content/gdrive/MyDrive/dain_output' #@param{type:"string"}

# NOTE(review): START_FRAME, END_FRAME and SEAMLESS are not referenced by any other
# cell visible here (the Continue cell derives start/end frames itself) — confirm
# whether they are still needed.
START_FRAME = 1 #@param{type:"number"}

END_FRAME = -1 #@param{type:"number"}

SEAMLESS = False #@param{type:"boolean"}

# When True, the "Create video with sound" cell deletes the interpolated frames and
# the extracted audio track after building the final video.
AUTO_REMOVE = False #@param{type:"boolean"}

In [None]:
#@title Connect Google Drive
# Mount Google Drive at /content/gdrive. Later cells read the input video from,
# and write frames and the final video to, folders under that mount point.
import google.colab.drive as drive

drive.mount('/content/gdrive')
print('Google Drive connected.')

In [None]:
#@title Check your current GPU
# VRAM caps the resolution DAIN can process, and Colab hands out GPUs with
# different amounts of it (16GB if you are lucky, less otherwise):
#   16GB: 720p works; 1080p produces an out-of-memory error.
#    8GB: 480p works;  720p produces an out-of-memory error.
# Print the assigned GPU's name, driver version and total memory as CSV.
!nvidia-smi --format=csv --query-gpu=gpu_name,driver_version,memory.total

In [None]:
#@title Setup everything. This takes a while. Just wait ~20 minutes in total.
%%capture
# Installs a pinned legacy toolchain (Miniconda 4.5.4 + PyTorch 1.1), clones and
# builds DAIN, and downloads its pre-trained weights.
# NOTE: %%capture hides all output of this cell, including errors.

# Install old pytorch to avoid faulty output
%cd /content/
!wget -c https://repo.anaconda.com/miniconda/Miniconda3-4.5.4-Linux-x86_64.sh
!chmod +x Miniconda3-4.5.4-Linux-x86_64.sh
!bash ./Miniconda3-4.5.4-Linux-x86_64.sh -b -f -p /usr/local
!conda install pytorch==1.1 cudatoolkit torchvision -c pytorch -y
!conda install ipykernel -y

!pip install scipy==1.1.0
!pip install imageio
# `!VAR=value` only sets the variable inside a throw-away subshell; %env sets it
# for the kernel process and every later `!` command.
%env CUDA_VISIBLE_DEVICES=0
# -y: apt's confirmation prompt would be invisible under %%capture.
!sudo apt-get install -y imagemagick imagemagick-doc
print("Finished installing dependencies.")

# Clone DAIN sources (skipped if an earlier run already cloned them, so the cell
# can be re-run).
%cd /content
!test -d /content/DAIN/.git || git clone -b master --depth 1 https://github.com/baowenbo/DAIN /content/DAIN
%cd /content/DAIN
!git log -1

# Building DAIN
%cd /content/DAIN/my_package/
!./build.sh
print("Building #1 done.")

# Building DAIN PyTorch correlation package.
%cd /content/DAIN/PWCNet/correlation_package_pytorch1_0
!./build.sh
print("Building #2 done.")

# Downloading pre-trained model
%cd /content/DAIN
# -p: do not fail when the folder exists from a previous run.
!mkdir -p model_weights
!wget -O model_weights/best.pth http://vllab1.ucmerced.edu/~wenbobao/DAIN/best.pth

In [None]:
#@title Detecting FPS of input file.
# Copy the input video from Google Drive into the DAIN checkout and read its
# frame rate. Defines `filename` and `fps`, which later cells use (interpolation
# time step, audio extraction).
import os
import shutil

import cv2

filename = os.path.basename(INPUT_FILEPATH)
local_video = os.path.join('/content/DAIN', filename)
# A Python copy needs no shell quoting, so paths containing spaces work.
shutil.copy(os.path.join('/content/gdrive/MyDrive', INPUT_FILEPATH), local_video)

cap = cv2.VideoCapture(local_video)
try:
    if not cap.isOpened():
        raise RuntimeError(f"Could not open {local_video} with OpenCV.")
    fps = cap.get(cv2.CAP_PROP_FPS)
finally:
    cap.release()

# OpenCV reports 0 when it cannot determine the rate; fail here rather than
# passing a zero time step to DAIN later.
if fps <= 0:
    raise RuntimeError(f"Could not determine the frame rate of {local_video}.")
print(f"Input file has {fps} fps")

# The interpolation cell passes fps / TARGET_FPS as --time_step; above 0.5 there
# is no room for new in-between frames.
if fps / TARGET_FPS > 0.5:
    print("Define a higher fps, because there is not enough time for new frames. (Old FPS)/(New FPS) must not exceed 0.5. Interpolation will fail if you try.")

In [None]:
#@title Video to png
# Decode the input video into numbered PNG frames (00001.png, 00002.png, ...)
# inside FRAME_INPUT_DIR, using the decoder chosen in `library`. Every decoder
# uses the same 1-based naming. PNGs that already exist are skipped, so an
# interrupted extraction resumes when the cell is re-run. Also (re)defines the
# globals `fps` (always a float), `frame_count` and `duration`.
import cv2
import imageio
import os
import tqdm
import subprocess
import numpy as np
import time

# makedirs: os.mkdir fails when a parent folder is missing.
os.makedirs(FRAME_INPUT_DIR, exist_ok=True)

library = "imageio" #@param ["cv2","pyav","imageio","ffmpeg","skvideo","scipy","moviepy"]
delay = 0.05 #@param [0, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5]

# Every decoder reads the same file from the Drive mounted at /content/gdrive.
path = '/content/gdrive/MyDrive/'
full_path = os.path.join(path, INPUT_FILEPATH)


def _parse_rate(rate):
    """Convert an ffprobe rate string such as '30000/1001' or '25' to float (0.0 if undefined)."""
    numerator, _, denominator = str(rate).partition('/')
    denominator = float(denominator or 1)
    return float(numerator) / denominator if denominator else 0.0


def _write_with_retry(img_path, write):
    """Call write(img_path) until it succeeds.

    Writes to the Drive mount can fail transiently. Retry indefinitely, but pause
    between attempts instead of spinning.
    """
    while True:
        try:
            write(img_path)
            return
        except Exception as e:
            print(f"Error writing to disk: {str(e)}. Retrying...")
            time.sleep(1)


def _dump_frames(frames, frame_count, write_frame):
    """Save each frame of `frames` as FRAME_INPUT_DIR/NNNNN.png, numbered from 1.

    write_frame(img_path, frame) encodes one frame. Files that already exist are
    skipped (resume support), and `delay` seconds are slept after each write.
    """
    with tqdm.tqdm(total=frame_count or None, ncols=100, position=0, leave=True) as pbar:
        for index, frame in enumerate(frames, start=1):
            img_path = f"{FRAME_INPUT_DIR}/{index:05d}.png"
            if not os.path.isfile(img_path):
                _write_with_retry(img_path, lambda p: write_frame(p, frame))
                time.sleep(float(delay))
            pbar.update(1)


def _cv2_write(img_path, frame):
    """Write a BGR frame with OpenCV, which signals failure by returning False."""
    if not cv2.imwrite(img_path, frame):
        raise IOError(f"cv2.imwrite failed for {img_path}")


def _pil_write(img_path, frame):
    """Write an RGB frame with PIL."""
    from PIL import Image
    Image.fromarray(frame).save(img_path, quality=100)


def _iterate_and_close(iterable, resource):
    """Yield from `iterable`, then close `resource` once iteration ends."""
    try:
        yield from iterable
    finally:
        resource.close()


# Each opener returns (fps, frame_count, frame iterator, write_frame).

def _open_cv2():
    cap = cv2.VideoCapture(full_path)
    fps_value = cap.get(cv2.CAP_PROP_FPS)
    count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    def frames():
        try:
            ok, image = cap.read()
            while ok:
                yield image
                ok, image = cap.read()
        finally:
            cap.release()

    return fps_value, count, frames(), _cv2_write


def _open_pyav():
    !pip install av
    import av
    container = av.open(full_path)
    stream = container.streams.video[0]
    if stream.average_rate is None:
        container.close()
        raise RuntimeError("Error: FPS value is None.")
    # Frames are numbered by decode order, not by packet timestamps.
    frames = (f.to_ndarray(format='rgb24').astype('uint8') for f in container.decode(stream))
    return stream.average_rate, stream.frames, _iterate_and_close(frames, container), imageio.imwrite


def _open_imageio():
    reader = imageio.get_reader(full_path)
    fps_value = reader.get_meta_data()['fps']
    count = reader.count_frames()
    return fps_value, count, _iterate_and_close(reader, reader), imageio.imwrite


def _open_ffmpeg():
    !pip install ffmpeg-python
    import ffmpeg
    probe = ffmpeg.probe(full_path)
    info = next(s for s in probe['streams'] if s['codec_type'] == 'video')
    width, height = info['width'], info['height']
    frame_size = width * height * 3
    nb_frames = str(info.get('nb_frames', ''))
    count = int(nb_frames) if nb_frames.isdigit() else 0

    def frames():
        process = (
            ffmpeg
            .input(full_path)
            .output('pipe:', format='rawvideo', pix_fmt='rgb24')
            .run_async(pipe_stdout=True)
        )
        try:
            while True:
                raw = process.stdout.read(frame_size)
                if len(raw) < frame_size:  # end of stream
                    break
                yield np.frombuffer(raw, dtype='uint8').reshape((height, width, 3))
        finally:
            process.stdout.close()
            process.wait()

    # r_frame_rate is a string such as '30/1'; convert it so `fps` stays numeric.
    return _parse_rate(info['r_frame_rate']), count, frames(), imageio.imwrite


def _open_skvideo():
    !pip install scikit-video
    import skvideo.io
    video = skvideo.io.vread(full_path)
    fps_value = _parse_rate(skvideo.io.ffprobe(full_path)['video']['@avg_frame_rate'])
    outputdict = {'-q:1': '1'}

    def write_frame(img_path, frame):
        skvideo.io.vwrite(img_path, frame, outputdict=outputdict)

    return fps_value, len(video), iter(video), write_frame


def _open_scipy():
    # Despite its name, this option decodes with imageio and encodes with PIL.
    reader = imageio.get_reader(full_path)
    fps_value = reader.get_meta_data()['fps']
    # count_frames() gives a real count; get_length() may be inf for streams.
    count = reader.count_frames()
    return fps_value, count, _iterate_and_close(reader, reader), _pil_write


def _open_moviepy():
    from moviepy.editor import VideoFileClip
    clip = VideoFileClip(full_path)
    return clip.fps, int(clip.duration * clip.fps), _iterate_and_close(clip.iter_frames(), clip), _pil_write


_OPENERS = {
    "cv2": _open_cv2,
    "pyav": _open_pyav,
    "imageio": _open_imageio,
    "ffmpeg": _open_ffmpeg,
    "skvideo": _open_skvideo,
    "scipy": _open_scipy,
    "moviepy": _open_moviepy,
}

fps, frame_count, _frames, _write_frame = _OPENERS[library]()
fps = float(fps)
duration = frame_count / fps if fps else 0.0

print("FPS: ", fps)
print("Duration: ", duration)
print("Frames: ", frame_count)

_dump_frames(_frames, frame_count, _write_frame)

In [None]:
#@title Check frames (optional)
# Sanity-check the PNGs extracted into FRAME_INPUT_DIR: print the lowest and
# highest frame numbers and list any numbers missing in between.
import os
import time

frames = None
while frames is None:
    try:
        file_names = os.listdir(FRAME_INPUT_DIR)
    except OSError:
        # The Drive mount can be briefly unavailable; wait and try again.
        print("Error reading the list of files. Retrying in 2 seconds")
        time.sleep(2)
        continue
    frames = []
    for f in file_names:
        if not f.endswith('.png'):
            continue
        try:
            frames.append(int(f.split('.')[0].replace('frame', '')))
        except ValueError:
            # Not a numbered frame; it cannot belong to the sequence.
            print(f"Ignoring non-numbered file {f}")

if not frames:
    print(f"No numbered .png frames found in {FRAME_INPUT_DIR}")
else:
    min_frame = min(frames)
    max_frame = max(frames)
    print(min_frame)
    print(max_frame)

    # Set membership keeps the gap scan linear.
    present = set(frames)
    missing_frames = [i for i in range(min_frame, max_frame + 1) if i not in present]

    if len(missing_frames) > 0:
        print(f"Missing frames: {missing_frames}")
    else:
        print("All frames present")
    

In [None]:
#@title Continue
# Decide where interpolation should resume and count the extracted input frames.
# The Interpolation cell passes `thousands` as --start_frame and `frame_count`
# as --end_frame.
import glob
import os
import re


def _frame_number(file_path):
    """Return the last run of digits in the file name of `file_path`, or None if there is none."""
    digits = re.findall(r'\d+', os.path.basename(file_path))
    return int(digits[-1]) if digits else None


# Find all numbered .png files already produced in FRAME_OUTPUT_DIR.
png_files = [f for f in glob.glob(FRAME_OUTPUT_DIR + '/' + '*.png') if _frame_number(f) is not None]

if len(png_files) > 0:
    # Files exist: pick the one with the highest number.
    latest_file = max(png_files, key=_frame_number)
    print(latest_file)
    # Dividing by 1000 gives the number of "thousands". Presumably DAIN names its
    # output <input frame><3-digit index>, so this recovers the input frame reached
    # (TODO confirm against colab_interpolate.py). The -1 presumably redoes the last,
    # possibly incomplete frame. Clamp to 1: the default extraction numbers input
    # frames from 00001.png, so 0 or -1 would be an invalid start frame.
    thousands = max(1, _frame_number(latest_file) // 1000 - 1)
else:
    # No output yet: start from the first frame.
    thousands = 1

print(f"thousands: {thousands}")

frame_count = len([file for file in os.listdir(FRAME_INPUT_DIR) if file.endswith('.png')])

print(f'frame_count {frame_count}')


In [None]:
#@title Interpolation
# Run DAIN over the input frames from `thousands` up to `frame_count`, writing the
# interpolated frames to FRAME_OUTPUT_DIR. The time step is the ratio of the source
# frame rate to the target frame rate.
time_step = fps / TARGET_FPS
!mkdir -p '{FRAME_OUTPUT_DIR}'
%cd /content/DAIN
!python -W ignore colab_interpolate.py --netName DAIN_slowmotion --time_step {time_step} --start_frame {thousands} --end_frame {frame_count} --frame_input_dir '{FRAME_INPUT_DIR}' --frame_output_dir '{FRAME_OUTPUT_DIR}'

In [None]:
#@title Rename
# Renumber the PNGs in FRAME_OUTPUT_DIR as 00001.png, 00002.png, ... keeping
# their current sorted (file-name) order.
import glob
import os
import time


def _rename_with_retry(src, dst):
    """os.rename(src, dst), retrying every second until it succeeds (Drive I/O can fail transiently)."""
    while True:
        try:
            os.rename(src, dst)
            return
        except Exception as e:
            print(f"Error renaming {src} to {os.path.basename(dst)}: {e}")
            time.sleep(1)  # wait 1 second before retrying


# Normalize so the source/target comparisons below are exact string matches.
png_files = sorted(os.path.normpath(p) for p in glob.glob(FRAME_OUTPUT_DIR + '/' + '*.png'))
plan = []
for i, file in enumerate(png_files):
    target = os.path.join(os.path.dirname(file), '{:05d}.png'.format(i + 1))
    if file != target:
        plan.append((file, target))

# On POSIX, os.rename silently replaces an existing destination. If a target is
# still occupied by a file that is only renamed later in the plan (e.g. 0-based or
# mixed names), a direct pass would overwrite and lose that frame. In that case,
# first move every file aside under a temporary name.
order = {src: k for k, (src, _) in enumerate(plan)}
if any(order.get(dst, -1) > k for k, (_, dst) in enumerate(plan)):
    staged = []
    for src, dst in plan:
        tmp = src + '.renaming'
        _rename_with_retry(src, tmp)
        staged.append((tmp, dst))
    plan = staged

for src, dst in plan:
    _rename_with_retry(src, dst)

In [None]:
#@title Create output video
# Encode the PNG frames in FRAME_OUTPUT_DIR, in file-name order, into a silent
# MP4 at TARGET_FPS and save it to Google Drive as OUTPUT_FILE_PATH.
import cv2
import os
from tqdm.notebook import tqdm

gdrive_path = '/content/gdrive/MyDrive'
output_file = os.path.join(gdrive_path, OUTPUT_FILE_PATH)
print(output_file)

# Only PNGs: the audio cell writes output-audio.aac into this same directory.
img_files = sorted(f for f in os.listdir(FRAME_OUTPUT_DIR) if f.endswith('.png'))
if not img_files:
    raise FileNotFoundError(f"No .png frames found in {FRAME_OUTPUT_DIR}")

first_path = os.path.join(FRAME_OUTPUT_DIR, img_files[0])
frame = cv2.imread(first_path)
if frame is None:
    raise IOError(f"Could not read {first_path}")
height, width, _ = frame.shape

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(output_file, fourcc, TARGET_FPS, (width, height))
if not writer.isOpened():
    raise IOError(f"Could not open video writer for {output_file}")

try:
    for img_file in tqdm(img_files, desc='Create video'):
        img_path = os.path.join(FRAME_OUTPUT_DIR, img_file)
        frame = cv2.imread(img_path)
        if frame is None:
            raise IOError(f"Could not read {img_path}")
        writer.write(frame)
finally:
    # Always finalize the container so a partial video is still playable.
    writer.release()
print('created')


In [None]:
#@title [Experimental] Create video with sound
# Only run this, if the original had sound.
# Rebuild OUTPUT_FILE_PATH from the interpolated PNGs at TARGET_FPS and mux in the
# audio track of the input video copied to /content/DAIN by the FPS cell.
%cd {FRAME_OUTPUT_DIR}
# -y: overwrite a leftover output-audio.aac from an earlier run instead of
# blocking on ffmpeg's interactive prompt. -vn: extract the audio stream only.
%shell ffmpeg -y -i '/content/DAIN/{filename}' -vn -acodec copy output-audio.aac
%shell ffmpeg -y -r {TARGET_FPS} -f image2 -pattern_type glob -i '*.png' -i output-audio.aac -shortest '/content/gdrive/My Drive/{OUTPUT_FILE_PATH}'

if (AUTO_REMOVE):
  # Guard: an empty FRAME_OUTPUT_DIR (or '/') would turn this into `rm -rf /*`.
  if FRAME_OUTPUT_DIR.strip().rstrip('/'):
    !rm -rf '{FRAME_OUTPUT_DIR}'/*
  !rm -rf output-audio.aac