In [None]:
!git clone https://github.com/harlanhong/ICCV2023-MCNET.git
%cd ICCV2023-MCNET
!pip install -r requirements.txt
!pip install gdown
!pip install  einops torchdiffeq scikit-image==0.18.0
!pip install imageio[ffmpeg]
!pip install imageio[pyav]
!pip install aiogram accelerate peft  git+https://github.com/hukkelas/DSFD-Pytorch-Inference.git git+https://github.com/huggingface/diffusers.git@ee35f1914af802efd4945f47232e8501276d1662

In [1]:
#%cd ICCV2023-MCNET if restart notebook

/kaggle/working/ICCV2023-MCNET


In [2]:
import matplotlib
matplotlib.use('Agg')
import os, sys
import yaml
from argparse import ArgumentParser, Namespace
from tqdm import tqdm
import modules.generator as GEN
import imageio
import numpy as np
from skimage.transform import resize
from skimage import img_as_ubyte
import torch
from sync_batchnorm import DataParallelWithCallback
from modules.keypoint_detector import KPDetector
from animate import normalize_kp
from scipy.spatial import ConvexHull
from collections import OrderedDict
import pdb
if sys.version_info[0] < 3:
    raise Exception("You must use Python 3 or higher. Recommended version is Python 3.7")

def load_checkpoints(config_path, checkpoint_path, cpu=False):

    with open(config_path) as f:
        config = yaml.load(f, Loader=yaml.Loader)
    if opt.kp_num != -1:
        config['model_params']['common_params']['num_kp'] = opt.kp_num
    generator = getattr(GEN, opt.generator)(**config['model_params']['generator_params'],**config['model_params']['common_params'],**{'mbunit':opt.mbunit,'mb_spatial':opt.mb_spatial,'mb_channel':opt.mb_channel})
    if not cpu:
        generator.cuda()
    kp_detector = KPDetector(**config['model_params']['kp_detector_params'],
                             **config['model_params']['common_params'])
    if not cpu:
        kp_detector.cuda()
    
    if cpu:
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
    else:
        checkpoint = torch.load(checkpoint_path,map_location="cuda:0")
    
    ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in checkpoint['generator'].items())
    generator.load_state_dict(ckp_generator)
    ckp_kp_detector = OrderedDict((k.replace('module.',''),v) for k,v in checkpoint['kp_detector'].items())
    kp_detector.load_state_dict(ckp_kp_detector)
    
    if not cpu:
        generator = DataParallelWithCallback(generator)
        kp_detector = DataParallelWithCallback(kp_detector)

    generator.eval()
    kp_detector.eval()
    
    return generator, kp_detector


def make_animation(source_image, driving_video, generator, kp_detector, relative=True, adapt_movement_scale=True, cpu=False):
    sources = []
    drivings = []
    with torch.no_grad():
        predictions = []
        source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2)
        if not cpu:
            source = source.cuda()
        driving = torch.tensor(np.array(driving_video)[np.newaxis].astype(np.float32)).permute(0, 4, 1, 2, 3)

        kp_source = kp_detector(source)
        if not cpu:
            kp_driving_initial = kp_detector(driving[:, :, 0].cuda())
        else:
            kp_driving_initial = kp_detector(driving[:, :, 0])
        for frame_idx in tqdm(range(driving.shape[2])):
            driving_frame = driving[:, :, frame_idx]
            if not cpu:
                driving_frame = driving_frame.cuda()
            kp_driving = kp_detector(driving_frame)
            kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
                                   kp_driving_initial=kp_driving_initial, use_relative_movement=relative,
                                   use_relative_jacobian=relative, adapt_movement_scale=adapt_movement_scale)
            out = generator(source, kp_source=kp_source, kp_driving=kp_norm)
            drivings.append(np.transpose(driving_frame.data.cpu().numpy(), [0, 2, 3, 1])[0])
            sources.append(np.transpose(source.data.cpu().numpy(), [0, 2, 3, 1])[0])
            predictions.append(np.transpose(out['prediction'].data.cpu().numpy(), [0, 2, 3, 1])[0])
    return sources, drivings, predictions

opt = Namespace(config='config/vox-256.yaml',
                checkpoint='/kaggle/input/checkpoint/00000099-checkpoint.pth.tar',
                source_image='img.jpg',
                driving_video='output.mp4',
                result_video='result.mp4',
                relative=True,
                adapt_scale=True,
                generator='Unet_Generator_keypoint_aware',
                kp_num=15,
                mb_channel=512,
                mb_spatial=32,
                mbunit='ExpendMemoryUnit',
                memsize=1,
                find_best_frame=False,
                best_frame=None,
                cpu=False)

generator, kp_detector = load_checkpoints(config_path=opt.config, checkpoint_path=opt.checkpoint, cpu=opt.cpu)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


In [13]:
def generate_animation(source_image, driving_video, output_video):    
    source_image = imageio.imread(source_image)
    reader = imageio.get_reader(driving_video)
    fps = reader.get_meta_data()['fps']
    driving_video = []
    try:
        for im in reader:
            driving_video.append(im)
    except RuntimeError:
        pass
    reader.close()

    source_image = resize(source_image, (256, 256))[..., :3]
    driving_video = [resize(frame, (256, 256))[..., :3] for frame in driving_video]

    sources, drivings, predictions = make_animation(source_image, driving_video, generator, kp_detector, relative=opt.relative, adapt_movement_scale=opt.adapt_scale, cpu=opt.cpu)
    imageio.mimsave(output_video, [img_as_ubyte(p) for p in predictions], fps=fps)

In [4]:
from PIL import Image
from io import BytesIO
import os
import logging
import cv2
import face_detection
import numpy as np
import asyncio
from collections import deque

from aiogram.types import FSInputFile, InputFile
from aiogram.enums import ParseMode
from aiogram import Bot, Dispatcher, Router, types
from aiogram.filters import CommandStart, Command
from aiogram import F

import torch
from diffusers import StableDiffusionInstructPix2PixPipeline, LCMScheduler


photo_storage = {}


def load_pix2pix_model(path, device):
    pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
        path,
        torch_dtype=torch.float16,
        safety_checker=None,
        # local_files_only=True,
        # cache_dir=mode
    )

    pipeline.scheduler = LCMScheduler.from_config(pipeline.scheduler.config)

    pipeline.load_lora_weights(
        pretrained_model_name_or_path_or_dict="latent-consistency/lcm-lora-sdv1-5",
        weight_name="pytorch_lora_weights.safetensors",
        # cache_dir="lcm",
        # local_files_only=True
    )

    pipeline.generator = torch.Generator(device=device).manual_seed(42)

    pipeline.load_ip_adapter(
        pretrained_model_name_or_path_or_dict="h94/IP-Adapter",
        subfolder="models",
        weight_name="ip-adapter_sd15.bin",
        # local_files_only=True,
        # cache_dir="adapter"
    )
    pipeline.set_ip_adapter_scale(1)

    pipeline = pipeline.to(device)
    return pipeline
    


model_flowers_id = "misshimichka/pix2pix_people_flowers_v2"
model_cat_id = "misshimichka/pix2pix_cat_ears"
model_clown_id = "misshimichka/pix2pix_clown_faces"
model_butterfly_id = "misshimichka/pix2pix_butterflies"
model_pink_id = "misshimichka/pix2pix_pink_hair"
model_id = "misshimichka/instructPix2PixCartoon_4860_ckpt"


    
models = {
    "default": load_pix2pix_model(model_id, 'cuda:0'),
    "flowers": load_pix2pix_model(model_flowers_id, 'cuda:0'),
    "cat": load_pix2pix_model(model_cat_id, 'cuda:0'),
    "butterfly": load_pix2pix_model(model_butterfly_id, 'cuda:1'),
    "clown": load_pix2pix_model(model_clown_id, 'cuda:1'),
    "pink": load_pix2pix_model(model_pink_id, 'cuda:1')
}

detector = face_detection.build_detector(
  "DSFDDetector", confidence_threshold=.5, nms_iou_threshold=.3)


2024-05-15 09:58:33.467607: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-15 09:58:33.467664: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-15 09:58:33.474717: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

The config attributes {'algorithm_type': 'dpmsolver++', 'lower_order_final': True, 'skip_prk_steps': True, 'solver_order': 2, 'solver_type': 'midpoint'} were passed to LCMScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.
  return self.fget.__get__(instance, owner)()


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

The config attributes {'algorithm_type': 'dpmsolver++', 'lower_order_final': True, 'skip_prk_steps': True, 'solver_order': 2, 'solver_type': 'midpoint'} were passed to LCMScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

The config attributes {'algorithm_type': 'dpmsolver++', 'lower_order_final': True, 'skip_prk_steps': True, 'solver_order': 2, 'solver_type': 'midpoint'} were passed to LCMScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

The config attributes {'algorithm_type': 'dpmsolver++', 'lower_order_final': True, 'skip_prk_steps': True, 'solver_order': 2, 'solver_type': 'midpoint'} were passed to LCMScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

The config attributes {'algorithm_type': 'dpmsolver++', 'lower_order_final': True, 'skip_prk_steps': True, 'solver_order': 2, 'solver_type': 'midpoint'} were passed to LCMScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.


Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

The config attributes {'algorithm_type': 'dpmsolver++', 'lower_order_final': True, 'skip_prk_steps': True, 'solver_order': 2, 'solver_type': 'midpoint'} were passed to LCMScheduler, but are not expected and will be ignored. Please verify your scheduler_config.json configuration file.


In [15]:
!gdown 1dk973WGzD7n9NlIw3cl6J4bqHKObgPs2

Downloading...
From: https://drive.google.com/uc?id=1dk973WGzD7n9NlIw3cl6J4bqHKObgPs2
To: /kaggle/working/ICCV2023-MCNET/wow-grey.mp4
100%|█████████████████████████████████████████| 295k/295k [00:00<00:00, 107MB/s]


In [20]:
def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols
    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols*w, rows*h))
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i%cols*w, i//cols*h))
    return grid

def crop_img(im):
  if isinstance(im, Image.Image):
    im = cv2.cvtColor(np.array(im), cv2.COLOR_RGB2BGR)
  elif isinstance(im, str) and os.path.exists(im):
    im = cv2.imread(im)
    im = cv2.resize(im, (512, 512))
    im = im[:, :, ::-1]
  else:
    return None

  detections = detector.detect(im)

  if detections.shape[0] != 1:
    return None
  xmin, ymin, xmax, ymax, _ = [int(i) + 1 for i in detections.tolist()[0]]
  ymin = max(0, ymin - 50)
  ymax = min(512, ymax + 50)
  xmin = max(0, xmin - 50)
  xmax = min(xmax + 50, 512)
  cropped_img = im[ymin:ymax, xmin:xmax]

  im_pil = Image.fromarray(cropped_img)
  img = im_pil.resize((512, 512))
  return img


def generate(original_image, mode):
    print("Generating image...")

    cropped_image = crop_img('our_img.jpg')

    if not cropped_image:
        return None

    edited_image = models[mode](
        prompt="Refashion the photo into a sticker.",
        image=cropped_image,
        ip_adapter_image=cropped_image,
        num_inference_steps=4,
        image_guidance_scale=1,
        guidance_scale=2,
        num_images_per_prompt=3
    ).images
    
    torch.cuda.empty_cache()

    
    for idx, img in enumerate(edited_image):
      img.save(f"result{idx}.webp", "webp")

    return image_grid(edited_image, 3, 1)


def get_styles_markup():
    default_btn = types.InlineKeyboardButton(text="Default 🤫🧏‍", callback_data="default")
    flowers_btn = types.InlineKeyboardButton(text="Flowers 🌸🌺", callback_data="flowers")
    cat_btn = types.InlineKeyboardButton(text="Cat ears 🐈🐱", callback_data="cat")
    butterfly_btn = types.InlineKeyboardButton(text="Butterflies 🦋🌈", callback_data="butterfly")
    clown_btn = types.InlineKeyboardButton(text="Clown 🤡🤣", callback_data="clown")
    pink_btn = types.InlineKeyboardButton(text="Pink hair 🩷✨", callback_data="pink")
    animate_btn = types.InlineKeyboardButton(text="Animate 1️⃣4️⃣8️⃣8️⃣", callback_data="animate")
    markup = types.InlineKeyboardMarkup(
        inline_keyboard=[[default_btn, flowers_btn],
                         [cat_btn, butterfly_btn],
                         [clown_btn, pink_btn],
                        [animate_btn]]
    )
    return markup

async def handle_selection(message: types.Message):
    index = int(message.text) - 1
    chat_id = message.chat.id
    await bot.send_sticker(
                chat_id=chat_id,
                sticker=FSInputFile(f"result{index}.webp"),
                emoji="🎁",
            )
    await bot.send_photo(chat_id, photo=FSInputFile(path=f"result{index}.webp"), caption="1488")


async def handle_start(message: types.Message):
    await message.reply("Welcome to Stickerify bot! 🥶\nSend me a photo and I will create your own sticker 👠")


async def handle_photo(message: types.Message):
    photo = message.photo[-1]
    file_id = photo.file_id
    chat_id = message.chat.id
    if chat_id not in photo_storage.keys():
        photo_storage[chat_id] = deque()
    photo_storage[chat_id].append(file_id)

    await message.reply("Choose your sticker style:", reply_markup=get_styles_markup())


async def handle_debug(message: types.Message):
    print(photo_storage)


async def process_stickerify_callback(callback_query: types.CallbackQuery):
    chat_id = callback_query.from_user.id
    sticker_style = callback_query.data
    print(sticker_style)
    if chat_id in photo_storage.keys() and len(photo_storage[chat_id]) > 0:
        file_id = photo_storage[chat_id].popleft()
        try:
            file = await bot.get_file(file_id)
            file_path = file.file_path
            contents = await bot.download_file(file_path)

            img = Image.open(BytesIO(contents.getvalue()))
            img.save('our_img.jpg')

            if sticker_style != 'animate':
                await bot.send_message(chat_id, "Started generating your sticker! 👨‍🔬")
                stickerified_images = generate(img, sticker_style)
                if not stickerified_images:
                    await bot.send_message(chat_id, "Unfortunately, we couldn't find a human face on your "
                                                    "photo, or there were too many of them 😰 Please, "
                                                    "send another photo.")
                    return



                stickerified_images.save(f"{chat_id}_result.jpeg")
                await bot.send_photo(chat_id, photo=FSInputFile(path=f"{chat_id}_result.jpeg"), caption="Type number from 1 to 9 to pick up sticker.")
            else:
                await bot.send_message(chat_id, "Started animating your sticker! 👨‍🔬")
                generate_animation('our_img.jpg', 'wow-grey.mp4', 'result.mp4')
                torch.cuda.empty_cache()
                await bot.send_video(chat_id, FSInputFile('result.mp4'))
                


        except Exception as e:
            print(e)
            await bot.send_message(chat_id, f"Sorry, an error occurred.\n{e}")

    else:
        await bot.send_message(chat_id, "We couldn't find your photo. Please send it again.")


def setup_handlers(router: Router):
    router.message.register(handle_selection, F.text.lower().in_(['1', '2', '3', '4', '5', '6', '7', '8', '9']))
    router.message.register(handle_start, CommandStart())
    router.message.register(handle_photo, F.content_type.in_({'photo'}))
    router.message.register(handle_debug, Command("debug"))
    router.callback_query.register(process_stickerify_callback)


bot = Bot("5818667076:AAEJw19A7hTV6XjZOu38SYS3T9Jm8ZOicZ8", parse_mode=ParseMode.HTML)


async def main():
    router = Router()
    setup_handlers(router)

    dispatcher = Dispatcher()
    dispatcher.include_router(router)
    await dispatcher.start_polling(bot)

Use `default=DefaultBotProperties(...)` instead.
  bot = Bot("5818667076:AAEJw19A7hTV6XjZOu38SYS3T9Jm8ZOicZ8", parse_mode=ParseMode.HTML)


In [21]:
await main()

animate


  source_image = imageio.imread(source_image)
100%|██████████| 26/26 [00:03<00:00,  8.04it/s]
