<a href="https://colab.research.google.com/github/eyaler/avatars4all/blob/master/incomplete_webrtc_fomm_live.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Demo for paper "First Order Motion Model for Image Animation"

## **Live webcam in the browser**

### Made just a little bit more accessible by Eyal Gruss (eyalgruss@gmail.com)

##### Original project: https://aliaksandrsiarohin.github.io/first-order-model-website

##### Original notebook: https://colab.research.google.com/github/AliaksandrSiarohin/first-order-model/blob/master/demo.ipynb

##### Faceswap notebook: https://colab.research.google.com/github/AliaksandrSiarohin/motion-cosegmentation/blob/master/part_swap.ipynb

##### Notebook with video enhancement: https://colab.research.google.com/github/tg-bomze/Face-Image-Motion-Model/blob/master/Face_Image_Motion_Model_(Photo_2_Video)_Eng.ipynb

##### Avatarify - a live version (requires local installation): https://github.com/alievk/avatarify

##### This live Colab solution is heavily based on the WebRTC implementation: https://github.com/thefonseca/colabrtc, https://github.com/aiortc/aiortc

##### Other WebRTC implementations: https://github.com/l4rz/first-order-model/tree/master/webrtc, https://gist.github.com/myagues/aac0c597f8ad0fa7ebe7d017b0c5603b

#### **Stuff I made**:
##### Avatars4all repository: https://github.com/eyaler/avatars4all
##### Notebook for talking head model: https://colab.research.google.com/github/eyaler/avatars4all/blob/master/fomm_bibi.ipynb
##### Notebook for full body models: https://colab.research.google.com/github/eyaler/avatars4all/blob/master/fomm_fufu.ipynb
##### Notebook for live webcam in the browser: https://colab.research.google.com/github/eyaler/avatars4all/blob/master/fomm_live.ipynb
##### Notebook for Wav2Lip audio based lip syncing: https://colab.research.google.com/github/eyaler/avatars4all/blob/master/melaflefon.ipynb
##### List of more generative tools: https://j.mp/generativetools

In [None]:
#@title Setup
#@markdown For best performance make sure the output shows Tesla P100 or V100. Otherwise you can do: Runtime -> Reset all runtimes

# Record the GPU reported by nvidia-smi; it is printed here and again at the end of the cell.
machine = !nvidia-smi -L
print(machine)

# Fetch the model code and the vox-adv checkpoint into /content.
# Both wget lines target the same file name; with -nc the second one is skipped
# when the first already produced the file, so it only acts as a fallback mirror.
%cd /content
!git clone --depth 1 https://github.com/eyaler/first-order-model
!wget --no-check-certificate -nc https://openavatarify.s3.amazonaws.com/weights/vox-adv-cpk.pth.tar -P /content
!wget --no-check-certificate -nc https://eyalgruss.com/fomm/vox-adv-cpk.pth.tar

# Pre-populate the torch hub checkpoint cache with the face detector / landmark weights.
# NOTE(review): presumably so face_alignment finds them without downloading itself -- confirm.
!mkdir -p /root/.cache/torch/hub/checkpoints
%cd /root/.cache/torch/hub/checkpoints
!wget --no-check-certificate -nc https://eyalgruss.com/fomm/s3fd-619a316812.pth
!wget --no-check-certificate -nc https://eyalgruss.com/fomm/2DFAN4-11f355bf06.pth.tar
%cd /content

!pip install imageio==2.9.0
!pip install git+https://github.com/1adrianb/face-alignment@v1.0.1

# Clone colabrtc without checking out files (-n), then pin it to a fixed commit.
!git clone -n https://github.com/thefonseca/colabrtc
%cd /content/colabrtc
!git checkout 90d14e0
!pip install fire
!pip install av
!pip install aiortc
!pip install nest_asyncio

# Make the colabrtc package and the first-order-model sources importable from later cells.
import sys
sys.path.extend(['/content/colabrtc/colabrtc','/content/first-order-model'])

print(machine)

In [None]:
#@title Get the Avatar images from the web
#@markdown 1. You can change the URLs to your **own** stuff!
#@markdown 2. Alternatively, you can upload **local** files in the next cell

image1_url = 'https://www.beat.com.au/wp-content/uploads/2018/05/ilana.jpg' #@param {type:"string"}
image2_url = 'https://img.zeit.de/zeit-magazin/2017-03/marina-abramovic-performance-kuenstlerin-the-cleaner-monografie-oevre-bilder/marina-abramovic-performance-kuenstlerin-the-cleaner-monografie-oevre-10.jpg/imagegroup/original__620x620__desktop' #@param {type:"string"}
image3_url = 'https://i.pinimg.com/originals/27/86/58/2786580674b7c9b20ead54f53bf0be9e.jpg' #@param {type:"string"}

# Each non-empty URL is downloaded to a fixed, extensionless path /content/imageN;
# an empty URL leaves whatever file is already at that path untouched.
if image1_url:
  !wget "$image1_url" -O /content/image1

if image2_url:
  !wget "$image2_url" -O /content/image2

if image3_url:
  !wget "$image3_url" -O /content/image3

In [None]:
#@title Optionally upload local Avatar images { run: "auto" }
manually_upload_images = False #@param {type:"boolean"}
if manually_upload_images:
  from google.colab import files
  import shutil

  %cd /content/sample_data
  try:
    uploaded = files.upload()
  except Exception as e:
    %cd /content
    raise e

  for i,fn in enumerate(uploaded, start=1):
    shutil.move('/content/sample_data/'+fn, '/content/image%d'%i)
    if i==3:
      break


In [None]:
#@title Prepare assets
# Per-avatar form options; they are passed to get_crop() later in this cell:
#   crop_imageN_to_head           - crop a square around the detected face box
#   center_imageN_to_head         - when not cropping, take the largest square of the
#                                   image centred on the face (clamped to the image)
#   imageN_crop_expansion_factor  - scales the face box about its centre (1 = tight box)
center_image1_to_head = True #@param {type:"boolean"}
crop_image1_to_head = False #@param {type:"boolean"}
image1_crop_expansion_factor = 2.5 #@param {type:"number"}

center_image2_to_head = True #@param {type:"boolean"}
crop_image2_to_head = True #@param {type:"boolean"}
image2_crop_expansion_factor = 2.5 #@param {type:"number"}

center_image3_to_head = True #@param {type:"boolean"}
crop_image3_to_head = False #@param {type:"boolean"}
image3_crop_expansion_factor = 2.5 #@param {type:"number"}

# Tuples indexed by avatar number (0-based) so the loading loop can look options up by index.
center_image_to_head = (center_image1_to_head, center_image2_to_head, center_image3_to_head)
crop_image_to_head = (crop_image1_to_head, crop_image2_to_head, crop_image3_to_head)
image_crop_expansion_factor = (image1_crop_expansion_factor, image2_crop_expansion_factor, image3_crop_expansion_factor)

import imageio
import numpy as np
from google.colab.patches import cv2_imshow
from skimage.transform import resize

import face_alignment
import torch

# Save the library's own transform exactly once, so that re-running this cell
# does not end up wrapping the wrapper installed below.
if not hasattr(face_alignment.utils, '_original_transform'):
    face_alignment.utils._original_transform = face_alignment.utils.transform

def patched_transform(point, center, scale, resolution, invert=False):
    """Call the saved original transform with ``scale`` and ``resolution`` as float32 tensors.

    NOTE(review): presumably a compatibility shim for the installed torch /
    face_alignment combination -- confirm which versions need it.
    """
    return face_alignment.utils._original_transform(
        point, center, torch.tensor(scale, dtype=torch.float32), torch.tensor(resolution, dtype=torch.float32), invert)

# Route all later face_alignment.utils.transform lookups through the wrapper.
face_alignment.utils.transform = patched_transform

# Build the landmark detector on the GPU. If construction fails for any reason,
# delete the two cached weight files and try once more.
# NOTE(review): presumably a corrupt/partial download is the expected failure and
# face_alignment re-fetches the missing weights itself -- confirm.
try:
  fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
                                      device='cuda')
except Exception:
  !rm -rf /root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth
  !rm -rf /root/.cache/torch/hub/checkpoints/2DFAN4-11f355bf06.pth.tar
  fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
                                      device='cuda')

def create_bounding_box(target_landmarks, expansion_factor=1):
    """Return one ``[x, y, width, height]`` box per 68-point landmark set.

    Args:
        target_landmarks: array-like that reshapes to ``(n_faces, 68, 2)``
            holding ``(x, y)`` landmark coordinates.
        expansion_factor: scale applied to each box about its centre; ``1``
            keeps the tight box around the landmarks.

    Returns:
        Array of shape ``(n_faces, 4)`` with the top-left corner followed by
        the width and height of each (expanded) box.
    """
    landmarks = np.asarray(target_landmarks)
    # Integer coordinates cannot absorb the fractional margins in place below,
    # so promote them; floating inputs keep their original dtype.
    if not np.issubdtype(landmarks.dtype, np.floating):
        landmarks = landmarks.astype(float)

    points = landmarks.reshape(-1, 68, 2)
    x_y_min = points.min(axis=1)
    x_y_max = points.max(axis=1)

    # Half of the extra size goes on each side of the box, per axis.
    margin = (x_y_max - x_y_min) * ((expansion_factor - 1) / 2)
    x_y_min -= margin
    x_y_max += margin
    return np.hstack((x_y_min, x_y_max - x_y_min))

def fix_dims(im):
    """Return ``im`` with a channel axis of at most three channels.

    A 2-D (single-channel) image is replicated into three identical channels;
    any channels beyond the third (e.g. alpha) are dropped.
    """
    is_grayscale = im.ndim == 2
    with_channels = np.tile(im[..., None], [1, 1, 3]) if is_grayscale else im
    return with_channels[..., :3]

def get_crop(im, center_face=True, crop_face=True, expansion_factor=1, landmarks=None):
    """Compute a square crop window ``(x0, x1, y0, y1)`` for ``im``.

    Args:
        im: image array; grayscale/alpha inputs are normalised with ``fix_dims``.
        center_face: when not cropping to the face, centre the largest square
            that fits in the image on the face box (clamped to the image).
        crop_face: crop a square whose side is the longer side of the face
            box; the window may extend beyond the image borders.
        expansion_factor: forwarded to ``create_bounding_box``.
        landmarks: optional precomputed landmark sets (list or array); when
            missing or empty they are detected with the module-level ``fa``.

    Returns:
        Integer tuple ``(x0, x1, y0, y1)``. Without a usable face (or with
        both options off) this is the centred largest square of the image.
    """
    im = fix_dims(im)
    wants_face = center_face or crop_face

    # Test emptiness explicitly: the truth value of a multi-element NumPy
    # array is ambiguous and raises ValueError.
    if wants_face and (landmarks is None or len(landmarks) == 0):
        landmarks = fa.get_landmarks_from_image(im)

    if wants_face and landmarks is not None and len(landmarks) > 0:
        rects = create_bounding_box(landmarks, expansion_factor=expansion_factor)
        # Use the face with the largest box area.
        x0,y0,w,h = sorted(rects, key=lambda x: x[2]*x[3])[-1]
        if crop_face:
            # Grow the shorter side symmetrically so the window is square.
            s = max(h, w)
            x0 += (w-s)//2
            x1 = x0 + s
            y0 += (h-s)//2
            y1 = y0 + s
        else:
            # Largest square of the image, centred on the face and clamped
            # so that it stays inside the image.
            img_h,img_w = im.shape[:2]
            img_s = min(img_h,img_w)
            x0 = min(max(0, x0+(w-img_s)//2), img_w-img_s)
            x1 = x0 + img_s
            y0 = min(max(0, y0+(h-img_s)//2), img_h-img_s)
            y1 = y0 + img_s
    else:
        # No face information: plain centred square crop.
        h,w = im.shape[:2]
        s = min(h,w)
        x0 = (w-s)//2
        x1 = x0 + s
        y0 = (h-s)//2
        y1 = y0 + s
    return int(x0),int(x1),int(y0),int(y1)

def pad_crop_resize(im, x0=None, x1=None, y0=None, y1=None, new_h=256, new_w=256):
    """Crop ``im`` to the window ``[y0:y1, x0:x1]``, edge-padding where it leaves the image.

    Missing bounds default to the image borders. When either ``new_h`` or
    ``new_w`` is given the crop is resized; a ``None`` target keeps the
    cropped size along that axis, and with both ``None`` no resize happens.
    """
    im = fix_dims(im)
    h, w = im.shape[:2]
    x0 = 0 if x0 is None else x0
    x1 = w if x1 is None else x1
    y0 = 0 if y0 is None else y0
    y1 = h if y1 is None else y1

    # Amount by which the window sticks out on each side (0 when inside).
    pad_top, pad_bottom = max(-y0, 0), max(y1 - h, 0)
    pad_left, pad_right = max(-x0, 0), max(x1 - w, 0)
    if pad_top or pad_bottom or pad_left or pad_right:
        im = np.pad(im, pad_width=[(pad_top, pad_bottom), (pad_left, pad_right), (0, 0)], mode='edge')

    # After padding, original coordinates are shifted by the top/left padding.
    im = im[max(y0, 0):y1 + pad_top, max(x0, 0):x1 + pad_left]

    if not (new_h is None and new_w is None):
        target_h = im.shape[0] if new_h is None else new_h
        target_w = im.shape[1] if new_w is None else new_w
        im = resize(im, (target_h, target_w))
    return im

# Load /content/image1 .. /content/image3 and crop each according to its form options.
# All three files must exist: the loop reads them unconditionally.
#   orig_image   - cropped images at their cropped resolution (new_h/new_w=None skips resizing)
#   source_image - the same crops resized to 256x256
source_image = []
orig_image = []
for i in range(3):
    img = imageio.imread('/content/image%d'%(i+1))
    img = pad_crop_resize(img, *get_crop(img, center_face=center_image_to_head[i], crop_face=crop_image_to_head[i], expansion_factor=image_crop_expansion_factor[i]), new_h=None, new_w=None)
    orig_image.append(img)
    source_image.append(resize(img, (256,256)))
num_avatars = len(source_image)

# Preview the avatars side by side; the channel axis is reversed and values scaled by 255 for display.
# NOTE(review): assumes resize() yields RGB floats in [0, 1] and cv2_imshow expects BGR -- confirm.
cv2_imshow(np.hstack(source_image)[...,::-1]*255)

In [None]:
#@title Modify signaling.py

%%writefile /content/colabrtc/colabrtc/signaling.py
import json
import logging
import random
import IPython
import asyncio

from aiortc import RTCIceCandidate, RTCSessionDescription
from aiortc.contrib.signaling import object_from_string, object_to_string, BYE
from aiortc.contrib.signaling import ApprtcSignaling

from server import FilesystemRTCServer

try:
    import aiohttp
    import websockets
except ImportError:  # pragma: no cover
    aiohttp = None
    websockets = None

logger = logging.getLogger("colabrtc.signaling")

try:
    from google.colab import output
except ImportError:
    output = None
    logger.info('google.colab not available')


class ColabApprtcSignaling(ApprtcSignaling):
    """AppRTC signaling with blocking ``*_sync`` wrappers usable as Colab output callbacks.

    ``connect``, ``receive`` and ``send`` are redefined here and keep their
    state in double-underscore attributes, which Python name-mangles to this
    class. ``close`` is not defined here and comes from ``ApprtcSignaling``.
    """

    def __init__(self, room=None, javacript_callable=False):
        """Create the signaling object and optionally register Colab callbacks.

        Args:
            room: room identifier passed to ``ApprtcSignaling``.
            javacript_callable: when true (and ``google.colab.output`` was
                importable) the blocking wrappers are registered under
                ``'<room>.colab.signaling.<action>'`` callback names and
                return ``IPython.display.JSON`` objects.
        """
        super().__init__(room)

        self._javascript_callable = javacript_callable

        if output and javacript_callable:
            output.register_callback(f'{room}.colab.signaling.connect', self.connect_sync)
            output.register_callback(f'{room}.colab.signaling.send', self.send_sync)
            output.register_callback(f'{room}.colab.signaling.receive', self.receive_sync)
            output.register_callback(f'{room}.colab.signaling.close', self.close_sync)

    @property
    def room(self):
        """Room identifier used for this session."""
        return self._room

    async def connect(self):
        """Join the AppRTC room, open the websocket and return the room parameters.

        Raises:
            AssertionError: if the join response does not report ``SUCCESS``.
        """
        join_url = self._origin + "/join/" + self._room

        # fetch room parameters
        self._http = aiohttp.ClientSession()
        async with self._http.post(join_url) as response:
            # we cannot use response.json() due to:
            # https://github.com/webrtc/apprtc/issues/562
            data = json.loads(await response.text())
        assert data["result"] == "SUCCESS"
        params = data["params"]

        self.__is_initiator = params["is_initiator"] == "true"
        self.__messages = params["messages"]
        self.__post_url = (
            self._origin + "/message/" + self._room + "/" + params["client_id"]
        )

        # connect to websocket
        self._websocket = await websockets.connect(
            params["wss_url"], extra_headers={"Origin": self._origin}
        )
        await self._websocket.send(
            json.dumps(
                {
                    "clientid": params["client_id"],
                    "cmd": "register",
                    "roomid": params["room_id"],
                }
            )
        )

        print(f"AppRTC room is {params['room_id']} {params['room_link']}")

        return params

    def connect_sync(self):
        """Blocking wrapper around :meth:`connect`; JSON-wrapped for JavaScript callers."""
        loop = asyncio.get_event_loop()
        result = loop.run_until_complete(self.connect())
        if self._javascript_callable:
            return IPython.display.JSON(result)
        return result

    def close_sync(self):
        """Blocking wrapper around the inherited ``close`` coroutine."""
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self.close())

    def recv_nowait(self):
        """Return the oldest buffered websocket message, or ``None`` if none is queued."""
        try:
            # Neither deque nor list has a ``popright`` method, so the previous
            # call raised AttributeError; ``popleft`` takes the oldest entry.
            # Assumes ``messages`` is a collections.deque that the websocket
            # appends to -- TODO confirm against the installed websockets version.
            return self._websocket.messages.popleft()
        except IndexError:
            return None

    async def receive(self):
        """Return the next signaling object without waiting, or ``None`` if there is none.

        Messages delivered with the join response are consumed before
        anything buffered on the websocket.
        """
        if self.__messages:
            # NOTE(review): pop() takes the most recently listed join message
            # first -- confirm that this ordering is intended.
            message = self.__messages.pop()
        else:
            message = self.recv_nowait()
            if message:
                message = json.loads(message)["msg"]

        if message:
            logger.debug("< " + message)
            return object_from_string(message)

    def receive_sync(self):
        """Blocking wrapper around :meth:`receive`; JSON-wrapped for JavaScript callers."""
        loop = asyncio.get_event_loop()
        message = loop.run_until_complete(self.receive())
        if message and self._javascript_callable:
            message = object_to_string(message)
            print('receive:', message)
            message = json.loads(message)
            message = IPython.display.JSON(message)
        return message

    async def send(self, obj):
        """Serialise ``obj`` and send it via HTTP POST (initiator) or the websocket."""
        message = object_to_string(obj)
        logger.debug("> " + message)
        if self.__is_initiator:
            await self._http.post(self.__post_url, data=message)
        else:
            await self._websocket.send(json.dumps({"cmd": "send", "msg": message}))

    def send_sync(self, message):
        """Blocking wrapper around :meth:`send`.

        A JSON string that contains a ``candidate`` key is rewritten into the
        candidate message layout and parsed with ``object_from_string`` first.
        """
        print('send:', message)
        if isinstance(message, str):
            message_json = json.loads(message)
            if 'candidate' in message_json:
                message_json['type'] = 'candidate'
                message_json["id"] = message_json["sdpMid"]
                message_json["label"] = message_json["sdpMLineIndex"]
                message = json.dumps(message_json)
                message = object_from_string(message)
            # NOTE(review): a JSON string without a 'candidate' key reaches
            # send() still as a str -- confirm that this is intended.
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self.send(message))


class ColabSignaling:
    """Signaling channel backed by a WebRTC server object (filesystem-based by default).

    Every coroutine has a blocking ``*_sync`` twin; with
    ``javacript_callable`` set these twins are registered as Colab output
    callbacks and exchange JSON strings instead of aiortc objects.
    """

    def __init__(self, signaling_folder=None, webrtc_server=None, room=None, javacript_callable=False):
        """Set up the server backend, pick a room and optionally register callbacks.

        Args:
            signaling_folder: folder handed to ``FilesystemRTCServer`` when no
                ``webrtc_server`` is supplied.
            webrtc_server: object providing ``join``, ``receive_message`` and
                ``send_message``.
            room: room identifier; a random 10-digit string when ``None``.
            javacript_callable: register the blocking wrappers under
                ``'<room>.colab.signaling.<action>'`` (only if
                ``google.colab.output`` was importable).

        Raises:
            ValueError: if neither a server nor a signaling folder is given.
        """
        if room is None:
            room = "".join([random.choice("0123456789") for x in range(10)])

        if webrtc_server is None and signaling_folder is None:
            raise ValueError('Either a WebRTC server or a signaling folder must be provided.')
        if webrtc_server is None:
            self._webrtc_server = FilesystemRTCServer(folder=signaling_folder)
        else:
            self._webrtc_server = webrtc_server

        self._room = room
        self._javascript_callable = javacript_callable

        if output and javacript_callable:
            output.register_callback(f'{room}.colab.signaling.connect', self.connect_sync)
            output.register_callback(f'{room}.colab.signaling.send', self.send_sync)
            output.register_callback(f'{room}.colab.signaling.receive', self.receive_sync)
            output.register_callback(f'{room}.colab.signaling.close', self.close_sync)

    @property
    def room(self):
        """Room identifier used for this session."""
        return self._room

    async def connect(self):
        """Join the room on the server and return the parameters it reports.

        Raises:
            AssertionError: if the join result is not ``SUCCESS``.
        """
        data = self._webrtc_server.join(self._room)
        assert data["result"] == "SUCCESS"
        params = data["params"]

        self.__is_initiator = params["is_initiator"] == "true"
        self.__messages = params["messages"]
        self.__peer_id = params["peer_id"]

        logger.info(f"Room ID: {params['room_id']}")
        logger.info(f"Peer ID: {self.__peer_id}")
        return params

    def connect_sync(self):
        """Blocking wrapper around :meth:`connect`; JSON-wrapped for JavaScript callers."""
        loop = asyncio.get_event_loop()
        result = loop.run_until_complete(self.connect())
        if self._javascript_callable:
            return IPython.display.JSON(result)
        return result

    async def close(self):
        """Tell the other peer that this side is leaving by sending ``BYE``."""
        # Await send() directly in every mode. Going through send_sync() here
        # would call run_until_complete() from inside a coroutine, i.e. on a
        # loop that is already running. send() serialises BYE the same way in
        # both modes because BYE is not a str.
        await self.send(BYE)

    def close_sync(self):
        """Blocking wrapper around :meth:`close`."""
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self.close())

    async def receive(self):
        """Fetch the next message addressed to this peer from the server.

        Python peers get string messages parsed with ``object_from_string``;
        JavaScript-callable instances return the message untouched.
        """
        message = self._webrtc_server.receive_message(self._room, self.__peer_id)
        if message and isinstance(message, str) and not self._javascript_callable:
            message = object_from_string(message)
        return message

    def receive_sync(self):
        """Blocking wrapper around :meth:`receive`; JSON-wrapped for JavaScript callers."""
        loop = asyncio.get_event_loop()
        message = loop.run_until_complete(self.receive())
        if message and self._javascript_callable:
            message = json.loads(message)
            message = IPython.display.JSON(message)
        return message

    async def send(self, message):
        """Hand ``message`` to the server for the other peer.

        Strings coming from a JavaScript-callable instance are forwarded
        as-is; everything else is serialised with ``object_to_string``.
        """
        if not self._javascript_callable or not isinstance(message, str):
            message = object_to_string(message)
        self._webrtc_server.send_message(self._room, self.__peer_id, message)

    def send_sync(self, message):
        """Blocking wrapper around :meth:`send`."""
        loop = asyncio.get_event_loop()
        return loop.run_until_complete(self.send(message))


In [None]:
#@title Modify peer-ui.js

%%writefile /content/colabrtc/colabrtc/js/peer-ui.js
// Builds the call UI (local and remote video, spinner, control buttons) inside
// the element with id `container_id`, or inside a new <div> appended to the
// body when no id is given, and immediately starts the call for `room`.
var PeerUI = function(room, container_id) {
    // Define initial start time of the call (defined as connection between peers).
    // NOTE: startTime and constraints are implicit globals; they are kept that
    // way because code outside this constructor may read them.
    startTime = null;
    constraints = {audio: false, video: true};

    let peerDiv = null;

    if (container_id) {
        peerDiv = document.getElementById(container_id);
    } else {
        peerDiv = document.createElement('div');
        document.body.appendChild(peerDiv);
    }

    // Spinner shown while waiting for the remote stream.
    var style = document.createElement('style');
    style.type = 'text/css';
    style.innerHTML = `
        .loader {
          position: absolute;
          left: 38%;
          top: 60%;
          z-index: 1;
          width: 50px;
          height: 50px;
          margin: -75px 0 0 -75px;
          border: 16px solid #f3f3f3;
          border-radius: 50%;
          border-top: 16px solid #3498db;
          -webkit-animation: spin 2s linear infinite;
          animation: spin 2s linear infinite;
        }

        @keyframes spin {
          0% { transform: rotate(0deg); }
          100% { transform: rotate(360deg); }
        }
    `;
    document.getElementsByTagName('head')[0].appendChild(style);

    // Load the WebRTC adapter script.
    var adapter = document.createElement('script');
    adapter.setAttribute('src','https://webrtc.github.io/adapter/adapter-latest.js');
    document.getElementsByTagName('head')[0].appendChild(adapter);

    // Define video elements: local and remote views side by side, hidden until connect().
    const videoDiv = document.createElement('div');
    videoDiv.style.display = 'none';
    videoDiv.style.textAlign = '-webkit-center';
    const localView = document.createElement('video');
    const remoteView = document.createElement('video');
    remoteView.autoplay = true;
    localView.style.display = 'inline';
    remoteView.style.display = 'inline';
    localView.height = 240;
    localView.width = 320;
    remoteView.height = 240;
    remoteView.width = 320;
    videoDiv.appendChild(localView);
    videoDiv.appendChild(remoteView);
    const loader = document.createElement('div');
    loader.style.display = 'none';
    loader.className = 'loader';
    videoDiv.appendChild(loader);

    // Logs a message with the id and size of a video element and hides the
    // spinner once the remote video metadata is available.
    function logVideoLoaded(event) {
        const video = event.target;
        trace(`${video.id} videoWidth: ${video.videoWidth}px, ` +
            `videoHeight: ${video.videoHeight}px.`);

        localView.style.display = 'inline';
        remoteView.style.display = 'inline';
        loader.style.display = 'none';
    }

    remoteView.addEventListener('loadedmetadata', logVideoLoaded);

    // Define action buttons.
    const controlDiv = document.createElement('div');
    controlDiv.style.textAlign = 'center';
    const startButton = document.createElement('button');
    const fullscreenButton = document.createElement('button');
    const hangupButton = document.createElement('button');
    startButton.textContent = 'Join room: ' + room;
    startButton.style.display = 'none';
    fullscreenButton.textContent = 'Fullscreen';
    hangupButton.textContent = 'Hangup';
    controlDiv.appendChild(startButton);
    controlDiv.appendChild(fullscreenButton);
    controlDiv.appendChild(hangupButton);

    // Set up initial action buttons status: hangup and fullscreen stay hidden.
    hangupButton.style.display = 'none';
    fullscreenButton.style.display = 'none';

    peerDiv.appendChild(videoDiv);
    peerDiv.appendChild(controlDiv);

    this.localView = localView;
    this.remoteView = remoteView;
    this.peerDiv = peerDiv;
    this.videoDiv = videoDiv;
    this.loader = loader;
    this.startButton = startButton;
    this.fullscreenButton = fullscreenButton;
    this.hangupButton = hangupButton;
    this.constraints = constraints;
    this.room = room;

    // Local binding for the handlers below; a bare `self = this` would
    // overwrite the global `self`.
    const self = this;

    // Click handler: `this` is the clicked button here, so the room is read
    // from the PeerUI instance.
    async function start() {
        await self.connect(self.room);
    }

    // Handles hangup action: ends up call, closes connections and resets peers.
    async function hangup() {
        await self.disconnect();
    }

    // Requests fullscreen for the remote video using whichever API variant exists.
    function openFullscreen() {
      let elem = remoteView;
      if (elem.requestFullscreen) {
        elem.requestFullscreen();
      } else if (elem.mozRequestFullScreen) { /* Firefox */
        elem.mozRequestFullScreen();
      } else if (elem.webkitRequestFullscreen) { /* Chrome, Safari & Opera */
        elem.webkitRequestFullscreen();
      } else if (elem.msRequestFullscreen) { /* IE/Edge */
        elem.msRequestFullscreen();
      }
    }

    // Add click event handlers for buttons.
    this.startButton.addEventListener('click', start);
    this.fullscreenButton.addEventListener('click', openFullscreen);
    this.hangupButton.addEventListener('click', hangup);
    // The (hidden) start button is clicked programmatically so the call begins right away.
    this.startButton.click()
};


// Captures the local camera, shows the video area and runs the call: creates a
// Peer, connects it to this.room, attaches the first remote stream to the
// remote view and sends the local stream. The `room` argument is accepted for
// compatibility but the room actually used is this.room.
PeerUI.prototype.connect = async function(room) {
    // Use the constraints stored on this instance rather than a global variable.
    const stream = await navigator.mediaDevices.getUserMedia(this.constraints);
    this.localView.srcObject = stream;
    this.localView.play();
    trace('Received local stream.');

    this.loader.style.display = 'block';
    this.startButton.style.display = 'none';
    this.videoDiv.style.display = 'block';

    // `typeof` keeps this from throwing a ReferenceError when the page does
    // not define `google` at all.
    if (typeof google !== 'undefined' && google.colab) {
      // Resize the output to fit the video element.
      google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);
    }

    try {
        this.hangupButton.style.display = 'inline';

        trace('Starting call.');
        this.startTime = window.performance.now();

        this.peer = new Peer();
        await this.peer.connect(this.room);

        this.peer.pc.ontrack = ({track, streams}) => {
            // once media for a remote track arrives, show it in the remote video element
            track.onunmute = () => {
                // don't set srcObject again if it is already set.
                if (this.remoteView.srcObject) return;
                console.log(streams);
                this.remoteView.srcObject = streams[0];
                trace('Remote peer connection received remote stream.');
                this.remoteView.play();
            };
        };

        const localStream = this.localView.srcObject;
        console.log('adding local stream');
        await this.peer.addLocalStream(localStream);

        await this.peer.waitMessage();

    } catch (err) {
        // Call setup errors are logged rather than rethrown.
        console.error(err);
    }
};

// Ends the call: disconnects the peer, hides the controls, stops the local
// camera and removes the UI container from the page.
PeerUI.prototype.disconnect = async function() {
    // The hangup button becomes visible before the Peer object exists, so
    // tolerate a missing peer instead of throwing and leaving the UI half torn down.
    if (this.peer) {
        await this.peer.disconnect();
    }
    this.startButton.style.display = 'none';
    this.hangupButton.style.display = 'none';
    this.fullscreenButton.style.display = 'none';
    this.videoDiv.style.display = 'none';

    trace('Ending call.');
    // Stop every local video track so the camera is released.
    const localStream = this.localView.srcObject;
    if (localStream) {
        localStream.getVideoTracks().forEach((track) => track.stop());
    }
    this.peerDiv.remove();
};

// Writes the trimmed message to the console, prefixed with the seconds elapsed
// since page start (three decimals).
function trace(text) {
  const message = text.trim();
  const elapsedSeconds = window.performance.now() / 1000;
  console.log(elapsedSeconds.toFixed(3), message);
}

In [None]:
#@title Create fomm_live.py

%%writefile /content/colabrtc/examples/fomm_live.py
import numpy as np
import torch

def normalize_kp(kp):
    """Centre keypoints on their mean and scale x/y by the square root of their hull area.

    The input array is left untouched; a new array is returned. Only the
    first two columns are rescaled.
    """
    centered = kp - kp.mean(axis=0, keepdims=True)
    # For 2-D points ConvexHull.volume is the enclosed area.
    scale = np.sqrt(ConvexHull(centered[:, :2]).volume)
    centered[:, :2] = centered[:, :2] / scale
    return centered

def full_normalize_kp(source_area, kp_source, driving_area, kp_driving, kp_driving_initial, adapt_movement_scale=False,
                      use_relative_movement=False, use_relative_jacobian=False, exaggerate_factor=1):
    """Map driving keypoints into the source's frame of reference.

    Returns a shallow copy of ``kp_driving``. With ``use_relative_movement``
    the ``'value'`` entry becomes the source value plus the (scaled) motion
    of the driving keypoints since ``kp_driving_initial``; with
    ``use_relative_jacobian`` as well, the ``'jacobian'`` entry is the
    driving change applied to the source jacobian. ``adapt_movement_scale``
    scales the motion by ``sqrt(source_area) / sqrt(driving_area)`` and
    ``exaggerate_factor`` multiplies it further.
    """
    movement_scale = np.sqrt(source_area) / np.sqrt(driving_area) if adapt_movement_scale else 1

    kp_new = dict(kp_driving)
    if not use_relative_movement:
        return kp_new

    displacement = kp_driving['value'] - kp_driving_initial['value']
    displacement *= movement_scale * exaggerate_factor
    kp_new['value'] = displacement + kp_source['value']

    if use_relative_jacobian:
        initial_inverse = torch.inverse(kp_driving_initial['jacobian'])
        relative_jacobian = torch.matmul(kp_driving['jacobian'], initial_inverse)
        kp_new['jacobian'] = torch.matmul(relative_jacobian, kp_source['jacobian'])

    return kp_new

def make_animation(source, source_area, kp_source, driving_area, kp_driving_initial, driving_frame, kp_detector,
                   generator, adapt_movement_scale=False, use_relative_movement=False,
                   use_relative_jacobian=False,
                   exaggerate_factor=1, reset=False):
    """Animate ``source`` with a single driving frame.

    ``driving_frame`` gets a leading batch axis and its last axis moved to
    position 1 before being sent to the GPU. When there is no reference pose
    yet (``kp_driving_initial`` is ``None``) or ``reset`` is set, the current
    frame's keypoints become the reference and ``driving_area`` is recomputed
    from them.

    Returns:
        ``(frame, driving_area, kp_driving_initial)``: the generated frame
        with the channel axis last, plus the possibly refreshed reference
        state to be passed back in on the next call.
    """
    with torch.no_grad():
        frame_batch = torch.tensor(driving_frame[np.newaxis].astype(np.float32))
        frame_batch = frame_batch.permute(0, 3, 1, 2).cuda()

        needs_reference = kp_driving_initial is None or reset
        if needs_reference:
            kp_driving_initial = kp_detector(frame_batch)
            reference_points = kp_driving_initial['value'][0].data.cpu().numpy()
            driving_area = ConvexHull(reference_points).volume

        kp_driving = kp_detector(frame_batch)
        kp_norm = full_normalize_kp(source_area, kp_source, driving_area, kp_driving, kp_driving_initial,
                                    adapt_movement_scale=adapt_movement_scale,
                                    use_relative_movement=use_relative_movement,
                                    use_relative_jacobian=use_relative_jacobian,
                                    exaggerate_factor=exaggerate_factor)
        generated = generator(source, kp_source=kp_source, kp_driving=kp_norm)

        prediction = generated['prediction'].data.cpu().numpy()
        frame_out = prediction.transpose(0, 2, 3, 1)[0]
        return frame_out, driving_area, kp_driving_initial

import sys
sys.path.extend(['/content/colabrtc/colabrtc','/content/first-order-model'])
from peer import FrameTransformer
from call import ColabCall
from scipy.spatial import ConvexHull
from skimage.transform import resize
class Avatarify(FrameTransformer):
    """Frame transformer that animates a still avatar image with incoming video frames."""

    def __init__(self, freq=1. / 30, avatar=0):
        """Store the configuration; models are loaded later in :meth:`setup`.

        Args:
            freq: fraction of frames to process; roughly every
                ``int(1 / freq)``-th frame is transformed, the rest are skipped.
            avatar: 0-based avatar index; index ``i`` selects
                ``/content/image<i+1>``.
        """
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.avatar = avatar
        self.freq = freq

    @staticmethod
    def _to_model_input(image):
        """Return ``image`` as a centre-cropped square resized to 256x256 with three channels."""
        if image.ndim == 2:
            # Replicate a single-channel image into three channels.
            image = np.tile(image[..., None], [1, 1, 3])
        h, w = image.shape[:2]
        s = min(h, w)
        square = image[(h - s) // 2:(h + s) // 2, (w - s) // 2:(w + s) // 2]
        return resize(square, (256, 256))[..., :3]

    def setup(self):
        """Load the model checkpoints and precompute the source image keypoints."""
        import traceback
        from demo import load_checkpoints
        import imageio

        self.traceback = traceback
        self.reset = True
        self.kp_driving_initial = None
        self.driving_area = None
        self.generator, self.kp_detector = load_checkpoints(config_path='/content/first-order-model/config/vox-adv-256.yaml',
                                                  checkpoint_path='/content/vox-adv-cpk.pth.tar')

        # Honour the requested avatar (previously image1 was always used).
        # Going extensionless allows more image formats.
        source_image = imageio.imread('/content/image%d' % (self.avatar + 1))
        source_image = self._to_model_input(source_image)

        with torch.no_grad():
            self.source = torch.tensor(source_image[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).cuda()
            self.kp_source = self.kp_detector(self.source)
            self.source_area = ConvexHull(self.kp_source['value'][0].data.cpu().numpy()).volume

    def transform(self, frame, frame_idx=None, avatar=0):
        """Return the animated avatar for ``frame``, or ``None`` for a skipped frame.

        Args:
            frame: incoming video frame as an array (2-D or with a channel axis).
            frame_idx: running frame number used for frame skipping; without
                it every frame is processed.
            avatar: unused; the avatar is chosen at construction time.

        Returns:
            ``uint8`` image with the channel axis reversed on success,
            ``None`` when the frame is skipped, or the resized input frame if
            the animation step raised.
        """
        if self.freq and frame_idx is not None:
            # Never let the stride reach 0 (freq > 1), which would make the
            # modulo raise ZeroDivisionError.
            stride = max(1, int(1. / self.freq))
            if frame_idx % stride != 0:
                return None

        frame = self._to_model_input(frame)

        try:
            out_img, self.driving_area, self.kp_driving_initial = make_animation(self.source, self.source_area, self.kp_source, self.driving_area, self.kp_driving_initial, frame, self.kp_detector, self.generator,
                                     adapt_movement_scale=True, use_relative_movement=True,
                                     use_relative_jacobian=True,
                                     exaggerate_factor=1,
                                     reset=self.reset)
            # The first successful frame becomes the reference pose.
            self.reset = False
            out_img = (np.clip(out_img, 0, 1) * 255).astype(np.uint8)[..., ::-1]

            return out_img
        except Exception:
            self.traceback.print_exc()
            # NOTE(review): this fallback is the resized frame as produced by
            # resize(), not a uint8 image with reversed channels like the
            # success path -- confirm the consumer copes with both.
            return frame

def run(room=None, signaling_folder='/content/webrtc', avatar=0, frame_freq=1. / 10, verbose=False):
    """Command-line entry point: create the Python side of the call.

    A truthy ``room`` is converted to ``str`` before use; the call is created
    with an :class:`Avatarify` transformer and ``multiprocess=False``.
    """
    room_name = str(room) if room else room

    transformer = Avatarify(freq=frame_freq, avatar=avatar)
    ColabCall().create(
        room_name,
        signaling_folder=signaling_folder,
        verbose=verbose,
        frame_transformer=transformer,
        multiprocess=False,
    )

import fire
if __name__ == '__main__':
    fire.Fire(run)

In [None]:
#@title Go live!

#exaggerate_factor = 1 #@param {type:"slider", min:0.1, max:5, step:0.1}
#adapt_movement_scale = True #@param {type:"boolean"}
#use_relative_movement = True #@param {type:"boolean"}
#use_relative_jacobian = True #@param {type:"boolean"}

# Start from a clean slate: stop any previous peer process and remove the old
# signaling folder and log file.
!pkill -f fomm_live.py
!rm -rf /content/webrtc
!rm -f /content/nohup.txt

# Due to multiprocessing support limitations, we need to run the Python peer via commandline.
# It runs in the background with all of its output redirected to /content/nohup.txt.
!nohup python3 /content/colabrtc/examples/fomm_live.py \
--room 237 --avatar 0 > /content/nohup.txt 2>&1 &

import os
from time import sleep

# Poll the log every 10 seconds until the 'Peer ID' log line shows up, i.e. the
# Python peer has joined the room.
# NOTE(review): there is no timeout -- if the peer never gets that far, this loop
# runs until the cell is interrupted.
while True:
  if os.path.exists('/content/nohup.txt'):
    with open('/content/nohup.txt') as f:
      if 'INFO:colabrtc.signaling:Peer ID:' in f.read():
        break
  sleep(10)

# Join the same room (237) from the notebook side.
from call import ColabCall
call = ColabCall()
call.join(room='237')