<a href="https://colab.research.google.com/github/eyaler/avatars4all/blob/master/fomm_live.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Demo for paper "First Order Motion Model for Image Animation"

## **Live webcam in the browser!**

### Original project: https://aliaksandrsiarohin.github.io/first-order-model-website

#### Made just a little bit more accessible by Eyal Gruss ([@eyaler](https://twitter.com/eyaler) / [eyalgruss.com](https://eyalgruss.com) / [eyalgruss@gmail.com](mailto:eyalgruss@gmail.com))

#### Short link here: https://j.mp/cam2head

##### Click below for more references:

##### Original notebook: https://colab.research.google.com/github/AliaksandrSiarohin/first-order-model/blob/master/demo.ipynb

##### Faceswap notebook: https://colab.research.google.com/github/AliaksandrSiarohin/motion-cosegmentation/blob/master/part_swap.ipynb

##### Notebook with video enhancement: https://colab.research.google.com/github/tg-bomze/Face-Image-Motion-Model/blob/master/Face_Image_Motion_Model_(Photo_2_Video)_Eng.ipynb

##### Avatarify - a live version (requires local installation): https://github.com/alievk/avatarify

##### This live Colab solution is heavily based on the WebSocket implementation: https://github.com/a2kiti/webCamGoogleColab, https://qiita.com/a2kiti/items/f32de4f51a31d609e5a5

##### Other notable attempts based on WebRTC and aioRTC (https://github.com/aiortc/aiortc):
##### https://github.com/thefonseca/colabrtc
##### https://github.com/l4rz/first-order-model/tree/master/webrtc
##### https://gist.github.com/myagues/aac0c597f8ad0fa7ebe7d017b0c5603b
##### https://colab.research.google.com/github/eyaler/avatars4all/blob/master/incomplete_webrtc_fomm_live.ipynb (EG)

##### Randomly generated images from:
##### https://thispersondoesnotexist.com
##### https://fakeface.rest
##### https://www.thiswaifudoesnotexist.net
##### https://thisfursonadoesnotexist.com
##### https://eyalgruss.com/thismuppetdoesnotexist (@norod78, EG)

#### **Stuff I made**:
##### Avatars4all repository: https://github.com/eyaler/avatars4all
##### Notebook for live webcam in the browser: https://colab.research.google.com/github/eyaler/avatars4all/blob/master/fomm_live.ipynb
##### Notebook for talking head model: https://colab.research.google.com/github/eyaler/avatars4all/blob/master/fomm_bibi.ipynb
##### Notebook for full body models (FOMM): https://colab.research.google.com/github/eyaler/avatars4all/blob/master/fomm_fufu.ipynb
##### Notebook for full body models (impersonator): https://colab.research.google.com/github/eyaler/avatars4all/blob/master/ganozli.ipynb
##### Notebook for full body models (impersonator++): https://colab.research.google.com/github/eyaler/avatars4all/blob/master/ganivut.ipynb
##### Notebook for Wav2Lip audio based lip syncing: https://colab.research.google.com/github/eyaler/avatars4all/blob/master/melaflefon.ipynb
##### List of more generative tools (outdated): https://j.mp/generativetools

# Run me!

In [None]:
#@title Setup
#@markdown For best performance make sure you have a good internet connection.
# Capture the output of `nvidia-smi -L` so the runtime's GPU can be shown to the user.
machine = !nvidia-smi -L
print(machine)

# Fetch the model code and the checkpoint into /content.
# Each download is attempted from two URLs; `wget -nc` (no-clobber) skips the
# second one when the first already produced the file, so it acts as a mirror.
%cd /content
!git clone --depth 1 https://github.com/eyaler/first-order-model
!wget --no-check-certificate -nc https://openavatarify.s3.amazonaws.com/weights/vox-adv-cpk.pth.tar
!wget --no-check-certificate -nc https://eyalgruss.com/fomm/vox-adv-cpk.pth.tar

# Pre-populate the torch hub checkpoint cache with the two weight files that
# the "Prepare assets" cell deletes and retries on failure.
!mkdir -p /root/.cache/torch/hub/checkpoints
%cd /root/.cache/torch/hub/checkpoints
!wget --no-check-certificate -nc https://eyalgruss.com/fomm/s3fd-619a316812.pth
!wget --no-check-certificate -nc https://eyalgruss.com/fomm/2DFAN4-11f355bf06.pth.tar
%cd /content

# face-alignment is pinned to the v1.0.1 tag.
!pip install git+https://github.com/1adrianb/face-alignment@v1.0.1

# Tunnel binary: the ngrok download is disabled, cloudflared is fetched instead
# (primary URL first, mirror second) and made executable.
# NOTE(review): the tunnel itself is started outside this chunk -- presumably to
# expose the websocket server whose URL is passed to use_cam().
# !wget --no-check-certificate -nc https://bin.equinox.io/c/bNyj1mQVY4c/ngrok-v3-stable-linux-amd64.tgz
# !wget --no-check-certificate -nc https://eyalgruss.com/fomm/ngrok-v3-stable-linux-amd64.tgz
# !tar xvzf ngrok-v3-stable-linux-amd64.tgz
!wget --no-check-certificate -nc https://github.com/cloudflare/cloudflared/releases/latest/download/cloudflared-linux-amd64
!wget --no-check-certificate -nc https://eyalgruss.com/fomm/cloudflared/releases/latest/download/cloudflared-linux-amd64
!mv cloudflared-linux-amd64 cloudflared
!chmod +x cloudflared

# Python packages for the websocket server side (server code is not in this chunk).
!pip install bottle
!pip install bottle_websocket
!pip install wsaccel ujson
!pip install gevent

import warnings
warnings.filterwarnings("ignore")
from IPython.display import display, Javascript
from google.colab.output import eval_js

def use_cam(url, quality=0.8):
  """Build the live-avatar GUI in the cell output and stream webcam frames to ``url``.

  The function injects a JavaScript program into the notebook output and runs
  it with ``eval_js``.  The script opens the webcam, shows a mirrored, cropped
  256x256 preview next to the canvas that displays the frames received back
  from the server, and creates the controls (avatar buttons, calibrate/reset,
  option checkboxes, sliders and keyboard shortcuts).

  On every animation frame, while the websocket is open and fewer than
  "Message buffer" frames are awaiting a reply, it sends one text message:

      calib_flg + avatar + exaggeration + alpha + crop
      + auto + adaptive_scale + relative_movement + relative_jacobian
      + show_landmarks + <JPEG data URL of the preview canvas>

  Images dropped onto the page are sent as ``'drag' + <data URL>`` (local
  files) or ``'url' + <dropped text>`` (web images).  Every message received
  from the server is used as an image source and drawn on the output canvas.

  Args:
    url: Websocket URL of the server; it is passed to ``new WebSocket(url)``.
    quality: JPEG quality passed to ``canvas.toDataURL('image/jpeg', quality)``.
  """
  import json  # local import: only needed to quote ``url`` safely for JavaScript

  print("start camera")
  js = Javascript('''
    console.clear();
    async function useCam(url, quality) {

      const fps = document.createElement('div');
      fps.style.marginTop = "16px";
      document.body.appendChild(fps);
      const panel = document.createElement('div');

      // KeyboardEvent.code is a string such as "Enter" and never a number, so
      // Enter/Space are matched on KeyboardEvent.key instead.
      function is_activation_key (e)
      {
          return e.key == 'Enter' || e.key == ' ';
      }

      function on_dragover (event)
      {
          event.preventDefault();
          event.dataTransfer.dropEffect = 'copy';
          document.body.style.backgroundColor = 'yellow';
      }

      function on_dragleave (event)
      {
          event.preventDefault();
          document.body.style.backgroundColor = 'initial';
      }

      function on_drop (event)
      {
          event.preventDefault();
          // always clear the drag highlight, even when the drop is ignored
          document.body.style.backgroundColor = 'initial';
          // connection is still unset while the camera permission prompt is pending
          if (!connection || connection.readyState !== WebSocket.OPEN) {return}
          if (avatar!=last) {
            if (last=="1") {av1_btn.click();}
            else if (last=="2") {av2_btn.click();}
            else {av3_btn.click();}
          }
          var imageUrl = event.dataTransfer.getData("text/html")||event.dataTransfer.getData("url");
          var file = event.dataTransfer.files ? event.dataTransfer.files[0] : null;
          if (file) {
            console.log('retrieving image from file...');
            let reader = new FileReader();
            reader.onload = function (event)
            {
              connection.send('drag' + event.target.result);
            };
            reader.readAsDataURL(file);
          } else if (imageUrl) {
            console.log('retrieving image from URL: ' + imageUrl);
            connection.send('url'+imageUrl);
          }
      }
      document.body.addEventListener ('dragover',  on_dragover, false);
      document.body.addEventListener ('dragleave', on_dragleave, false);
      document.body.addEventListener ('drop' ,     on_drop, false);

      const div = document.createElement('div');
      const div1 = document.createElement('div');
      const div2 = document.createElement('div');
      div2.style.textAlign = 'right';
      div.appendChild(div1);
      div.appendChild(div2);
      div.style.marginTop = "16px";
      var display_size = 256;
      panel.style.width = (display_size*2+16).toString()+"px";
      div.style.display= "flex";
      div.style.justifyContent= "space-between";
      panel.appendChild(div);
      document.body.appendChild(panel);
      //video element
      const video = document.createElement('video');
      video.style.display = 'None';
      const stream = await navigator.mediaDevices.getUserMedia({audio: false, video: { width:{min:256} , height: {min:256} , frameRate:24}});
      div.appendChild(video);
      video.srcObject = stream;
      await video.play();

      //canvas for display. frame rate is depending on display size and jpeg quality.
      const src_canvas = document.createElement('canvas');
      src_canvas.height  = display_size;
      src_canvas.width = display_size; // * video.videoWidth / video.videoHeight;
      const src_canvasCtx = src_canvas.getContext('2d');

      //mirror the preview horizontally
      src_canvasCtx.translate(src_canvas.width, 0);
      src_canvasCtx.scale(-1, 1);
      div1.appendChild(src_canvas);

      const dst_canvas = document.createElement('canvas');
      dst_canvas.width  = src_canvas.width;
      dst_canvas.height = src_canvas.height;
      const dst_canvasCtx = dst_canvas.getContext('2d');
      div2.appendChild(dst_canvas);

      //crop sliders: vsld1 crops the webcam preview, vsld2 is sent to the server
      const vsld1 = document.createElement('input');
      const vsld2 = document.createElement('input');
      vsld1.style.marginTop = "16px";
      vsld2.style.marginTop = "16px";
      vsld1.type = "range";
      vsld1.min = "0";
      vsld1.max = "0.6";
      vsld1.step = "0.01";
      vsld1.defaultValue = "0.2";
      vsld1.style.width = "95%";
      vsld2.style.width = "95%";
      vsld2.type = "range";
      vsld2.min = "0";
      vsld2.max = "0.6";
      vsld2.step = "0.01";
      vsld2.defaultValue = "0";
      div1.appendChild(vsld1);
      div2.appendChild(vsld2);

      //exit button (currently disabled: never attached, exit_flg stays true)
      const btn_div = document.createElement('div');
      //document.body.appendChild(btn_div);
      const exit_btn = document.createElement('button');
      exit_btn.innerHTML = '<u>E</u>xit';
      var exit_flg = true;
      //exit_btn.onclick = function() {exit_flg = false;};
      //btn_div.appendChild(exit_btn);

      //one horizontal row of controls
      function make_row() {
          const row = document.createElement('div');
          row.style.marginTop = "16px";
          row.style.display= "flex";
          row.style.justifyContent= "space-between";
          return row;
      }

      const btn3_div = make_row();
      panel.appendChild(btn3_div);

      const btn1_div = make_row();
      panel.appendChild(btn1_div);

      //row for buttons 5-8: built but not shown (keys 5-8 still work)
      const btn2_div = make_row();
      //panel.appendChild(btn2_div);

      const btn2b_div = make_row();
      panel.appendChild(btn2b_div);

      const btn4_div = make_row();
      panel.appendChild(btn4_div);

      //avatar code sent with the next frame, and the last uploaded-avatar slot chosen
      var avatar = "1";
      var last = avatar;
      const avatar_btns = [];

      function toggle(btn) {
          avatar_btns.forEach(function(b) {b.style.fontWeight='normal';});
          btn.style.fontWeight='bold';
      }

      //creates one avatar button; remember=true also records the choice in
      //`last` (used for the three uploaded-avatar slots)
      function make_avatar_btn(label, code, parent, remember) {
          const btn = document.createElement('button');
          btn.innerHTML = label;
          const select = function() {
              avatar = code;
              if (remember) {last=avatar;}
              toggle(btn);
          };
          btn.onclick = select;
          btn.onkeydown = function(e) {if (is_activation_key(e)) {select();}};
          btn.style.width = "22.5%";
          parent.appendChild(btn);
          avatar_btns.push(btn);
          return btn;
      }

      //uploaded avatars
      const av1_btn = make_avatar_btn('Avatar <u>1</u>', "1", btn1_div, true);
      const av2_btn = make_avatar_btn('Avatar <u>2</u>', "2", btn1_div, true);
      const av3_btn = make_avatar_btn('Avatar <u>3</u>', "3", btn1_div, true);
      //random human button
      const av4_btn = make_avatar_btn('Human (<u>4</u>)', "4", btn1_div, false);
      //random man / woman / boy / girl buttons
      const av5_btn = make_avatar_btn('Man (<u>5</u>)', "5", btn2_div, false);
      const av6_btn = make_avatar_btn('Woman (<u>6</u>)', "6", btn2_div, false);
      const av7_btn = make_avatar_btn('Boy (<u>7</u>)', "7", btn2_div, false);
      const av8_btn = make_avatar_btn('Girl (<u>8</u>)', "8", btn2_div, false);
      //random waifu / fursona / muppet buttons
      const av9_btn = make_avatar_btn('Waifu (<u>9</u>)', "9", btn2b_div, false);
      const av10_btn = make_avatar_btn('Fursona (<u>0</u>)', "0", btn2b_div, false);
      const av11_btn = make_avatar_btn('Muppet (<u>-</u>)', "-", btn2b_div, false);
      //you button
      const av12_btn = make_avatar_btn('You (<u>=</u>)', "=", btn2b_div, false);

      toggle(av1_btn);

      function reset() {
          vsld1.value = vsld1.defaultValue;
          vsld2.value = vsld2.defaultValue;
          sld.value = sld.defaultValue;
          alp.value = alp.defaultValue;
          msg.value = msg.defaultValue;
          auto_btn.checked = auto_btn.defaultChecked;
          kp_btn.checked = kp_btn.defaultChecked;
          adam_btn.checked = adam_btn.defaultChecked;
          relm_btn.checked = relm_btn.defaultChecked;
          relj_btn.checked = relj_btn.defaultChecked;
          sld_out.innerHTML = parseFloat(sld.value).toFixed(1);
          alp_out.innerHTML = parseFloat(alp.value).toFixed(1);
          msg_out.innerHTML = msg.value;
          real_frame_count = 0;
          if (start!=null) {start=performance.now();}
          calib_btn.click();
      }

      //keyboard shortcuts (the Hebrew letters are the same physical keys on a Hebrew layout)
      document.addEventListener('keydown', function (event) {
        if ( event.key == '1' ) { av1_btn.click();  }
        else if ( event.key == '2' ) { av2_btn.click();  }
        else if ( event.key == '3' ) { av3_btn.click();  }
        else if ( event.key == '4' ) { av4_btn.click();  }
        else if ( event.key == '5' ) { av5_btn.click();  }
        else if ( event.key == '6' ) { av6_btn.click();  }
        else if ( event.key == '7' ) { av7_btn.click();  }
        else if ( event.key == '8' ) { av8_btn.click();  }
        else if ( event.key == '9' ) { av9_btn.click();  }
        else if ( event.key == '0' ) { av10_btn.click();  }
        else if ( event.key == '-' ) { av11_btn.click();  }
        else if ( event.key == '=' ) { av12_btn.click();  }
        else if ( event.key.toLowerCase() == 'c' || event.key == 'ב' || event.key == '`' || event.key == ';') { calib_btn.click();  }
        else if ( event.key.toLowerCase() == 'r' || event.key == 'ר' || event.key == 'Escape' || event.key == 'Backspace') {reset();          }
        else if ( event.key.toLowerCase() == 's' || event.key == 'ד') { adam_btn.click();  }
        else if ( event.key.toLowerCase() == 'm' || event.key == 'צ') { relm_btn.click();  }
        else if ( event.key.toLowerCase() == 'j' || event.key == 'ח') { relj_btn.click();  }
        else if ( event.key.toLowerCase() == 'l' || event.key == 'ך') { kp_btn.click();  }
        else if ( event.key.toLowerCase() == 'b' || event.key == 'נ') { alp.value=(parseFloat(alp.value)==0)?"0.5":"0"; alp_out.innerHTML = parseFloat(alp.value).toFixed(1);}
        else if ( event.key.toLowerCase() == 'a' || event.key == 'ש') { auto_btn.click();}
      });

      //calib button
      const calib_btn = document.createElement('button');
      calib_btn.innerHTML = '<u>C</u>alibrate (<u>`</u>)';
      var calib_flg = "1";
      calib_btn.style.width = "48.33%";
      calib_btn.onclick = function() {calib_flg = "1";};
      calib_btn.onkeydown = function(e) {if (is_activation_key(e)) {calib_flg = "1";}};
      btn3_div.appendChild(calib_btn);
      calib_btn.focus();

      //creates a labelled checkbox inside parent and returns the <input>
      function make_checkbox(parent, label_html, checked) {
          const label = document.createElement('label');
          parent.appendChild(label);
          const box = document.createElement('input');
          box.type = "checkbox";
          box.defaultChecked = checked;
          label.style.width = "22.5%";
          label.innerHTML = label_html;
          label.style.textAlign = 'center';
          box.style.marginRight = '10px';
          label.insertBefore(box, label.firstChild);
          return box;
      }

      //auto button
      const auto_btn = make_checkbox(btn3_div, '<u>A</u>uto<br>calibrate', false);

      //reset button
      const reset_btn = document.createElement('button');
      reset_btn.innerHTML = '<u>R</u>eset (<u>ESC</u>/<u>BS</u>)';
      reset_btn.onclick = function() {reset();};
      reset_btn.onkeydown = function(e) {if (is_activation_key(e)) {reset();}};
      reset_btn.style.width = "22.5%";
      btn3_div.appendChild(reset_btn);

      //adam, relm, relj and kp checkboxes
      const adam_btn = make_checkbox(btn4_div, 'Adaptive<br><u>s</u>cale', true);
      const relm_btn = make_checkbox(btn4_div, 'Relative<br><u>m</u>ovement', true);
      const relj_btn = make_checkbox(btn4_div, 'Relative<br><u>J</u>acobian', true);
      const kp_btn = make_checkbox(btn4_div, 'Show<br><u>l</u>andmarks', false);


      //slider
      const btm_div = document.createElement('div');
      btm_div.style.display= "flex";
      btm_div.style.justifyContent= "space-between";
      const btm0_div = document.createElement('div');
      const btm1_div = document.createElement('div');
      const btm2_div = document.createElement('div');
      btm0_div.style.display= "flex";
      btm0_div.style.flexDirection = "column";
      btm0_div.style.justifyContent= "space-around";
      btm1_div.style.display= "flex";
      btm1_div.style.flexDirection = "column";
      btm1_div.style.justifyContent= "space-around";
      btm2_div.style.display= "flex";
      btm2_div.style.width= "69%";
      btm2_div.style.textAlign= "right";
      btm2_div.style.flexDirection = "column";
      btm2_div.style.justifyContent= "space-around";
      panel.appendChild(btm_div);
      btm_div.appendChild(btm0_div);
      btm_div.appendChild(btm1_div);
      btm_div.appendChild(btm2_div);

      const sld = document.createElement('input');
      const sld_out = document.createElement('div');
      const sld_text = document.createElement('div');
      sld.type = "range";
      sld.min = "0.1";
      sld.max = "5.0";
      sld.step = "0.1";
      btm_div.style.marginTop = "16px";
      sld.defaultValue = "1.0";
      sld_text.innerHTML = "Exaggeration&nbsp;factor:";
      sld_out.innerHTML = parseFloat(sld.value).toFixed(1);
      sld.oninput = function(event) {sld_out.innerHTML = parseFloat(this.value).toFixed(1);};
      btm0_div.appendChild(sld_text);
      btm1_div.appendChild(sld_out);
      btm2_div.appendChild(sld);

      //alpha
      const alp = document.createElement('input');
      const alp_out = document.createElement('div');
      const alp_text = document.createElement('div');
      alp.type = "range";
      alp.min = "0";
      alp.max = "1";
      alp.step = "0.1";
      alp.defaultValue = "0";
      alp.style.marginTop = "16px";
      alp_out.style.marginTop = "16px";
      alp_text.style.marginTop = "16px";
      alp_text.innerHTML = "Alpha&nbsp;<u>b</u>lend:";
      alp_out.innerHTML = parseFloat(alp.value).toFixed(1);
      alp.oninput = function(event) {alp_out.innerHTML = parseFloat(this.value).toFixed(1);};
      btm0_div.appendChild(alp_text);
      btm1_div.appendChild(alp_out);
      btm2_div.appendChild(alp);

      //msg
      var real_frame_count = 0;
      var start = null;
      const msg = document.createElement('input');
      const msg_out = document.createElement('div');
      const msg_text = document.createElement('div');
      msg.type = "range";
      msg.min = "1";
      msg.max = "20";
      msg.step = "1";
      msg.defaultValue = "6";
      msg.style.marginTop = "16px";
      msg_out.style.marginTop = "16px";
      msg_text.style.marginTop = "16px";
      msg_text.innerHTML = "Message&nbsp;buffer:";
      msg_out.innerHTML = msg.value;
      msg.oninput = function(event) {msg_out.innerHTML = msg.value; real_frame_count = 0; start = null;};
      btm0_div.appendChild(msg_text);
      btm1_div.appendChild(msg_out);
      btm2_div.appendChild(msg);

      //log
      let jsLog = function(abc) {
        document.querySelector("#output-area").appendChild(document.createTextNode(`${abc} `));
      };
      // Resize the output to fit the video element.
      google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);

      //for websocket connection.
      var connection = 0;
      var in_transit_count = 0;
      var payload_size = 0;

      var socketOnOpen = function(e) {
        console.log("websocket open");
        jsLog(" Websocket open. ");
        // replies to frames sent on a previous connection will never arrive;
        // forget them so that sending cannot stall after a reconnect
        in_transit_count = 0;
        start=performance.now();
      }

      var socketOnMessage = function(e) {
        in_transit_count-=1;
        var image = new Image();
        image.src = e.data;
        //image.onload = function() {dst_canvasCtx.drawImage(image,parseInt(vsld2.value), parseInt(vsld2.value), display_size-2*parseInt(vsld2.value), display_size-2*parseInt(vsld2.value),0,0, display_size, display_size);};
        image.onload = function() {dst_canvasCtx.drawImage(image,0,0); real_frame_count+=1;};
        if (start) {fps.innerHTML = "payload=" + payload_size + " fps="+(real_frame_count*1000/(performance.now()-start)).toFixed(1)+" --- Drag & drop local/web images to upload new avatars!";}
      };

      var socketOnClose = function(e) {
        console.log('websocket disconnected - waiting for connection');
        websocketWaiter();
      };

      //(re)connect after one second; called again from onclose
      function websocketWaiter() {
        setTimeout(function() {
          connection = new WebSocket(url);
          connection.onopen = socketOnOpen;
          connection.onmessage = socketOnMessage;
          connection.onclose = socketOnClose;
        }, 1000);
      };

      websocketWaiter();
      jsLog("camera="+video.videoWidth+"x"+video.videoHeight+".");

      // loop
      async function _canvasUpdate() {
        var s = Math.min(video.videoWidth, video.videoHeight) * (1-vsld1.value); // adapted from https://github.com/alievk/avatarify
        src_canvasCtx.drawImage(video,Math.round(video.videoWidth-s)/2, Math.round(video.videoHeight-s)/2, Math.round(s), Math.round(s),0,0, display_size, display_size);

        if (connection.readyState === WebSocket.OPEN && in_transit_count<parseInt(msg.value))
        {
          in_transit_count+=1;
          var img = src_canvas.toDataURL('image/jpeg', quality);
          var sld_str = parseFloat(sld.value).toFixed(1);
          var alpha = parseFloat(alp.value).toFixed(1);
          var crop = parseFloat(vsld2.value).toFixed(2);
          var auto_flg = (auto_btn.checked)?"1":"0";
          var adam_flg = (adam_btn.checked)?"1":"0";
          var relm_flg = (relm_btn.checked)?"1":"0";
          var relj_flg = (relj_btn.checked)?"1":"0";
          var kp_flg = (kp_btn.checked)?"1":"0";
          var payload = calib_flg+avatar+sld_str+alpha+crop+auto_flg+adam_flg+relm_flg+relj_flg+kp_flg+img;
          payload_size = payload.length;
          connection.send(payload);
          //avatar and calibration requests are one-shot: after sending, fall
          //back to "`" and "0" until the user asks again
          avatar="`";
          calib_flg = "0";
        }
        if (exit_flg) {
            requestAnimationFrame(_canvasUpdate);
        }
        else {
          stream.getVideoTracks()[0].stop();
          connection.close();
        }
      }
      _canvasUpdate();
    }
    ''')
  display(js)
  # json.dumps yields a correctly quoted and escaped JavaScript string literal,
  # so quotes or backslashes in the URL cannot break the call.
  eval_js('useCam({}, {})'.format(json.dumps(str(url)), quality))

# Echo the `nvidia-smi -L` output captured at the top of this cell once more,
# after the install logs.
print(machine)

In [None]:
#@title Get the Avatar images from the web
#@markdown 1. You can change the URLs to your **own** stuff!
#@markdown 2. Alternatively, you can upload **local** files in the next cell
#@markdown 3. You can later also **drag and drop** images on the GUI to upload new avatars!

image1_url = 'https://www.beat.com.au/wp-content/uploads/2018/05/ilana.jpg' #@param {type:"string"}
image2_url = 'https://img.zeit.de/zeit-magazin/2017-03/marina-abramovic-performance-kuenstlerin-the-cleaner-monografie-oevre-bilder/marina-abramovic-performance-kuenstlerin-the-cleaner-monografie-oevre-10.jpg/imagegroup/original__620x620__desktop' #@param {type:"string"}
image3_url = 'https://i.pinimg.com/originals/27/86/58/2786580674b7c9b20ead54f53bf0be9e.jpg' #@param {type:"string"}

# Each non-empty URL is downloaded to a fixed, extension-less path
# (/content/image1..3); the "Prepare assets" cell reads these paths.
# An empty URL leaves whatever file is already at that path untouched.
if image1_url:
  !wget "$image1_url" -O /content/image1

if image2_url:
  !wget "$image2_url" -O /content/image2

if image3_url:
  !wget "$image3_url" -O /content/image3

In [None]:
#@title Optionally upload local Avatar images { run: "auto" }
#@markdown Instructions: mark the checkbox + press play if it doesn't start by itself + click the upload button that will appear below

manually_upload_images = False #@param {type:"boolean"}
if manually_upload_images:
  from google.colab import files
  import shutil

  %cd /content/sample_data
  try:
    uploaded = files.upload()
  except Exception as e:
    %cd /content
    raise e

  for i,fn in enumerate(uploaded, start=1):
    shutil.move('/content/sample_data/'+fn, '/content/image%d'%i)
    if i==3:
      break
  %cd /content


In [None]:
#@title Prepare assets
# Per-avatar form parameters. For image N:
#   center_imageN_to_head        -> get_crop(center_face=...)
#   crop_imageN_to_head          -> get_crop(crop_face=...)
#   imageN_crop_expansion_factor -> get_crop(expansion_factor=...)
center_image1_to_head = True #@param {type:"boolean"}
crop_image1_to_head = False #@param {type:"boolean"}
image1_crop_expansion_factor = 2.5 #@param {type:"number"}

center_image2_to_head = True #@param {type:"boolean"}
crop_image2_to_head = True #@param {type:"boolean"}
image2_crop_expansion_factor = 2.5 #@param {type:"number"}

center_image3_to_head = True #@param {type:"boolean"}
crop_image3_to_head = False #@param {type:"boolean"}
image3_crop_expansion_factor = 2.5 #@param {type:"number"}

# Gather the per-image settings into tuples indexed by avatar number - 1,
# as consumed by the loading loop at the end of this cell.
center_image_to_head = (center_image1_to_head, center_image2_to_head, center_image3_to_head)
crop_image_to_head = (crop_image1_to_head, crop_image2_to_head, crop_image3_to_head)
image_crop_expansion_factor = (image1_crop_expansion_factor, image2_crop_expansion_factor, image3_crop_expansion_factor)

import imageio
import numpy as np
from google.colab.patches import cv2_imshow
from skimage.transform import resize

import face_alignment
import torch

# Stash the library's own transform exactly once, so that re-running this cell
# never ends up wrapping the wrapper.
if not hasattr(face_alignment.utils, '_original_transform'):
    face_alignment.utils._original_transform = face_alignment.utils.transform

def patched_transform(point, center, scale, resolution, invert=False):
    """Call face_alignment's original ``transform`` with tensor arguments.

    ``scale`` and ``resolution`` are converted to float32 torch tensors; the
    remaining arguments are forwarded unchanged and the original function's
    result is returned as is.
    """
    scale_t = torch.tensor(scale, dtype=torch.float32)
    resolution_t = torch.tensor(resolution, dtype=torch.float32)
    original = face_alignment.utils._original_transform
    return original(point, center, scale_t, resolution_t, invert)

# Install the wrapper in place of the library function.
face_alignment.utils.transform = patched_transform

# Create the 2D landmark detector on the GPU.
# If construction fails for any reason, delete the two cached weight files
# (the ones pre-downloaded by the setup cell) and try once more --
# presumably so that face_alignment fetches fresh copies itself; a failure of
# the second attempt propagates.
try:
  fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
                                      device='cuda')
except Exception:
  !rm -rf /root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth
  !rm -rf /root/.cache/torch/hub/checkpoints/2DFAN4-11f355bf06.pth.tar
  fa = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=True,
                                      device='cuda')

def create_bounding_box(target_landmarks, expansion_factor=1):
    """Return one ``[x, y, width, height]`` box per set of 68 landmarks.

    Args:
        target_landmarks: Anything ``np.array`` can reshape to ``(-1, 68, 2)``,
            i.e. one or more sets of 68 ``(x, y)`` points.
        expansion_factor: Scale applied to each box around its centre;
            ``1`` yields the tight box, ``2`` doubles width and height.

    Returns:
        Array of shape ``(n_sets, 4)`` with the columns x_min, y_min, width, height.
    """
    points = np.array(target_landmarks).reshape(-1, 68, 2)
    lower = points.min(axis=1)
    upper = points.max(axis=1)
    # Half of the extra size is added on each side of the tight box.
    margin = (upper - lower) * ((expansion_factor - 1) / 2)
    lower -= margin
    upper += margin
    return np.hstack((lower, upper - lower))

def fix_dims(im):
    """Return ``im`` with exactly three channels on its last axis.

    * ``(H, W)`` grayscale input is replicated into three identical channels.
    * ``(H, W, 1)`` and ``(H, W, 2)`` input (grayscale, grayscale + alpha) is
      handled the same way, using the first channel.
    * Channels beyond the third (e.g. alpha of an RGBA image) are dropped.

    The dtype of the input is preserved.
    """
    if im.ndim == 2:
        im = im[..., None]
    if im.ndim == 3 and im.shape[-1] < 3:
        # Fewer than three channels: treat channel 0 as luminance.
        im = np.repeat(im[..., :1], 3, axis=-1)
    return im[..., :3]

def get_crop(im, center_face=True, crop_face=True, expansion_factor=1, landmarks=None):
    """Compute a square crop window ``(x0, x1, y0, y1)`` for `im`.

    Args:
        im: image array; passed through `fix_dims` first.
        center_face: centre the window on the detected face.
        crop_face: size the window to the (expanded) face box rather than to
            the image's shorter side.
        expansion_factor: forwarded to `create_bounding_box`.
        landmarks: optional precomputed landmark sets; when falsy and a face
            is needed, they are detected with the global `fa`.

    Returns:
        Four ints ``x0, x1, y0, y1``. With `crop_face` the window may extend
        outside the image; with only `center_face` it is clamped inside it.
        If no face is requested or none is found, the central square of the
        image is returned.
    """
    im = fix_dims(im)
    img_h, img_w = im.shape[:2]

    needs_face = center_face or crop_face
    if needs_face and not landmarks:
        landmarks = fa.get_landmarks_from_image(im)

    if not (needs_face and landmarks):
        # No face to work with: fall back to the central square of the image.
        side = min(img_h, img_w)
        left = (img_w - side) // 2
        top = (img_h - side) // 2
    else:
        boxes = create_bounding_box(landmarks, expansion_factor=expansion_factor)
        # Use the face whose box has the largest area.
        left, top, box_w, box_h = sorted(boxes, key=lambda box: box[2] * box[3])[-1]
        if crop_face:
            # Grow the shorter box side symmetrically so the window is square.
            side = max(box_h, box_w)
            left += (box_w - side) // 2
            top += (box_h - side) // 2
        else:
            # Largest square that fits the image, centred on the face and
            # shifted back inside the image bounds where necessary.
            side = min(img_h, img_w)
            left = min(max(0, left + (box_w - side) // 2), img_w - side)
            top = min(max(0, top + (box_h - side) // 2), img_h - side)

    right = left + side
    bottom = top + side
    return int(left), int(right), int(top), int(bottom)

def pad_crop_resize(im, x0=None, x1=None, y0=None, y1=None, new_h=256, new_w=256):
    """Cut the window ``[y0:y1, x0:x1]`` out of `im` and resize it.

    Window bounds default to the image edges. Where the window reaches beyond
    the image, the image is first extended by repeating its edge pixels.
    `new_h` / `new_w` of None keep the cropped height / width.
    """
    im = fix_dims(im)
    h, w = im.shape[:2]
    x0 = 0 if x0 is None else x0
    x1 = w if x1 is None else x1
    y0 = 0 if y0 is None else y0
    y1 = h if y1 is None else y1

    # How far the window sticks out on each side (0 when it does not).
    pad_top, pad_bottom = max(-y0, 0), max(y1 - h, 0)
    pad_left, pad_right = max(-x0, 0), max(x1 - w, 0)
    if pad_top or pad_bottom or pad_left or pad_right:
        im = np.pad(im, pad_width=[(pad_top, pad_bottom), (pad_left, pad_right), (0, 0)], mode='edge')

    # Padding on the top/left shifts the image, so the window's end index
    # moves by the same amount while a negative start becomes 0.
    im = im[max(y0, 0):y1 + pad_top, max(x0, 0):x1 + pad_left]

    out_h = im.shape[0] if new_h is None else new_h
    out_w = im.shape[1] if new_w is None else new_w
    return resize(im, (out_h, out_w))

# Load the three uploaded images and prepare two versions of each:
#   orig_image   - the padded/cropped image at its own resolution
#   source_image - the same crop resized to 256x256
source_image = []
orig_image = []
for i in range(3):
    img = imageio.imread('/content/image%d'%(i+1))
    # new_h/new_w of None make pad_crop_resize keep the crop's own size.
    img = pad_crop_resize(img, *get_crop(img, center_face=center_image_to_head[i], crop_face=crop_image_to_head[i], expansion_factor=image_crop_expansion_factor[i]), new_h=None, new_w=None)
    orig_image.append(img)
    source_image.append(resize(img, (256,256)))
# Number of user-supplied avatars; later slots are appended after these.
num_avatars = len(source_image)

# Preview the avatars side by side: reverse the channel order and scale by 255
# (presumably RGB floats in [0, 1] -> BGR 0-255 for cv2_imshow - TODO confirm).
cv2_imshow(np.hstack(source_image)[...,::-1]*255)

In [None]:
#@title Go live!
#@markdown Kindly approve camera access if asked. If it seems stuck for a long time - click stop and play this cell again.
tunnel = 'argo' #@param ['argo']
# removed ngrok as it now requires authentication

import requests
import re

# Clean up anything left over from a previous run of this cell. Each step is
# best effort: on the first run the names below do not exist yet, and the
# resulting errors are deliberately ignored.
!pkill -f ngrok
!pkill -f cloudflared
try:
  # Worker pool created at the bottom of this cell.
  _pool.terminate()
except:
  pass
try:
  # Last websocket accepted by wsbin.
  save_socket.close()
except:
  pass
try:
  # Websocket server adapter created at the bottom of this cell.
  server.shutdown()
except:
  pass

# Local port the websocket server listens on and the tunnel forwards to.
port = 6006
# Start the selected tunnel client in the background.
if tunnel=='ngrok':
  !nohup /content/ngrok http --inspect=false $port &
elif tunnel=='argo':
  !nohup /content/cloudflared tunnel --url http://localhost:$port --metrics localhost:49589 &

from time import time, sleep
import json
# Poll once per second until the tunnel reports its public hostname.
# Despite its name, ngrok_url holds the hostname for either tunnel type.
ngrok_url = None
while not ngrok_url:
  try:
    if tunnel=='ngrok':
      # ngrok: read the first tunnel's public URL from its local API and strip the scheme.
      ngrok_json = !curl http://localhost:4040/api/tunnels
      ngrok_url = json.loads(ngrok_json[0])['tunnels'][0]['public_url'].split('://',1)[-1]
    elif tunnel=='argo':
      # cloudflared: extract the hostname from the metrics endpoint started above.
      argo_metrics = requests.get("http://localhost:49589/metrics").text
      ngrok_url = re.search('cloudflared_tunnel_user_hostnames_counts{userHostname="https://(.+?)"}', argo_metrics).group(1)
  except Exception as e:
    # Not ready yet (or the lookup failed): report and retry.
    print('Trying to connect tunnel...', e)
    sleep(1)
from IPython.display import clear_output
clear_output()
# Turn the hostname into a secure websocket URL.
ngrok_url = 'wss://'+ngrok_url
print(ngrok_url)

# Work from the first-order-model checkout so its modules (demo, logger) can be imported.
%cd /content/first-order-model

from demo import load_checkpoints
# Load the generator and keypoint detector from the vox-adv-256 config and checkpoint.
generator, kp_detector = load_checkpoints(config_path='/content/first-order-model/config/vox-adv-256.yaml',
                            checkpoint_path='/content/vox-adv-cpk.pth.tar')


from scipy.spatial import ConvexHull
def normalize_kp(kp):
    """Centre keypoints and rescale x/y by the square root of their hull area.

    Returns a new array: `kp` minus its per-column mean, with the first two
    columns divided by sqrt(convex-hull area of those two columns). The input
    array is not modified.
    """
    centered = kp - kp.mean(axis=0, keepdims=True)
    xy = centered[:, :2]
    # For 2-D points ConvexHull.volume is the enclosed area.
    scale = np.sqrt(ConvexHull(xy).volume)
    centered[:, :2] = xy / scale
    return centered

import torch
from skimage import img_as_ubyte
import cv2
import bottle
import gevent
from bottle.ext.websocket import GeventWebSocketServer
from bottle.ext.websocket import websocket
from multiprocessing import Pool
from PIL import Image
import contextlib
from io import BytesIO, StringIO
import base64
from logger import Visualizer
vis = Visualizer(kp_size=3, colormap='gist_rainbow')

def norm_source(i,crop=0):
    """Refresh the cached model inputs for avatar slot `i`.

    Writes `source[i]` (image tensor on the GPU), `kp_source[i]` (keypoint
    detector output) and `source_area[i]` (convex-hull area of the keypoint
    values). With a non-zero `crop`, the image is taken from the centre of
    `orig_image[i]`, shrinking its shorter side by that fraction, and resized
    to 256x256; otherwise `source_image[i]` is used as is.
    """
    with torch.no_grad():
        img = source_image[i]
        if crop:
            full = orig_image[i]
            h, w = full.shape[:2]
            side = min(h, w) * (1 - crop) # adapted from https://github.com/alievk/avatarify
            top, bottom = int((h - side) / 2), int((h + side) / 2)
            left, right = int((w - side) / 2), int((w + side) / 2)
            img = resize(full[top:bottom, left:right], (256, 256))

        # Add a batch axis and move the channel axis to position 1 before
        # sending the image to the GPU.
        batch = img[np.newaxis].astype(np.float32)
        source[i] = torch.tensor(batch).permute(0, 3, 1, 2).cuda()
        kp_source[i] = kp_detector(source[i])
        keypoints = kp_source[i]['value'][0].data.cpu().numpy()
        source_area[i] = ConvexHull(keypoints).volume

# Sources of randomly generated avatar images, one avatar slot per entry.
# URLs ending in 'example-' or 'seed' get a random numeric suffix plus '.jpg'
# appended by load_stylegan_avatar; fakeface.rest entries are skipped there.
gen_urls = ["https://thispersondoesnotexist.com/",
           "https://fakeface.rest/face/view?gender=male&minimum_age=18",
           "https://fakeface.rest/face/view?gender=female&minimum_age=18",
           "https://fakeface.rest/face/view?gender=male&maximum_age=17",
           "https://fakeface.rest/face/view?gender=female&maximum_age=17",
           "https://www.thiswaifudoesnotexist.net/example-",
           "https://thisfursonadoesnotexist.com/v2/jpgs-2x/seed",
           "https://eyalgruss.com/thismuppetdoesnotexist/seed"]

# Append one empty slot per generator URL plus one extra slot. The length
# checks keep the lists from growing again when this cell is re-run.
if len(orig_image)==num_avatars:
    orig_image += [None]*(len(gen_urls)+1)

if len(source_image)==num_avatars:
    source_image += [None]*(len(gen_urls)+1)

def load_stylegan_avatar(avatar, crop=0): # adapted from https://github.com/alievk/avatarify
    """Download a generated image and install it in avatar slot `avatar`.

    The URL is `gen_urls[avatar - num_avatars]`; entries ending in 'example-'
    or 'seed' get a random numeric suffix and '.jpg' appended. Entries on
    fakeface.rest are skipped and the function returns without doing anything.

    On success `orig_image[avatar]` holds the decoded RGB image,
    `source_image[avatar]` its 256x256 resize, and `norm_source` has refreshed
    the cached model inputs for the slot.

    Args:
        avatar: index of the avatar slot to fill.
        crop: crop fraction forwarded to `norm_source`.

    Raises:
        requests.RequestException: on network failure, timeout or HTTP error status.
        ValueError: if the response body is not a decodable image.
    """
    url = gen_urls[avatar-num_avatars]
    if url.startswith('https://fakeface.rest'):
        return
    if url.endswith('example-'):
        url += '%d.jpg'%np.random.randint(10000,100000)
    elif url.endswith('seed'):
        url += '%05d.jpg'%np.random.randint(100000)

    # A timeout keeps an unresponsive host from stalling the caller forever.
    response = requests.get(url, headers={'User-Agent': "My User Agent 1.0"}, timeout=10)
    response.raise_for_status()

    image = cv2.imdecode(np.frombuffer(response.content, np.uint8), cv2.IMREAD_COLOR)
    if image is None:
        # imdecode signals failure by returning None; report it explicitly
        # before any slot is modified.
        raise ValueError('could not decode an image from %s' % url)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    orig_image[avatar] = image
    source_image[avatar] = resize(image, (256, 256))

    norm_source(avatar, crop=crop)

# Per-slot state, indexed like orig_image / source_image:
source = [None]*len(orig_image)       # image tensor fed to the generator
kp_source = [None]*len(orig_image)    # keypoint detector output for that image
source_area = [None]*len(orig_image)  # convex-hull area of those keypoints
# True while a generated avatar has been preloaded and not yet shown.
have_gen = [False]*len(gen_urls)
# Crop fraction last applied to each slot.
crops = [0]*len(orig_image)
# Prepare every slot except the last one, which stays empty here and is
# filled at runtime from the incoming frame in wsbin.
for i in range(len(orig_image)-1):
    if i<num_avatars:
        norm_source(i)
    else:
        try:
            load_stylegan_avatar(i)
            have_gen[i-num_avatars] = True
        except Exception as e:
            # A failed download only leaves this slot unprepared.
            print(e)

def full_normalize_kp(kp_driving, driving_area, kp_driving_initial, adapt_movement_scale=False,
                 use_relative_movement=False, use_relative_jacobian=False, exaggerate_factor=1):
    """Map driving keypoints onto the currently selected avatar.

    Returns a shallow copy of `kp_driving`. With `use_relative_movement`, its
    'value' becomes the avatar's keypoints plus the driving displacement since
    `kp_driving_initial`, scaled by `exaggerate_factor` and, when
    `adapt_movement_scale` is set, by sqrt(source area) / sqrt(`driving_area`).
    With `use_relative_jacobian` as well, 'jacobian' is the driving jacobian
    times the inverse initial jacobian times the avatar's jacobian.

    Reads the globals `avatar`, `source_area` and `kp_source`.
    """
    if adapt_movement_scale:
        movement_scale = np.sqrt(source_area[avatar]) / np.sqrt(driving_area)
    else:
        movement_scale = 1

    kp_new = dict(kp_driving)
    if not use_relative_movement:
        return kp_new

    displacement = kp_driving['value'] - kp_driving_initial['value']
    displacement *= movement_scale * exaggerate_factor
    kp_new['value'] = displacement + kp_source[avatar]['value']

    if use_relative_jacobian:
        initial_inverse = torch.inverse(kp_driving_initial['jacobian'])
        relative_jacobian = torch.matmul(kp_driving['jacobian'], initial_inverse)
        kp_new['jacobian'] = torch.matmul(relative_jacobian, kp_source[avatar]['jacobian'])

    return kp_new


# Reference driving keypoints and their convex-hull area, set by make_animation.
kp_driving_initial = None
driving_area = None
def make_animation(driving_frame, adapt_movement_scale=False, use_relative_movement=False, use_relative_jacobian=False, exaggerate_factor=1, reset=False, auto=False):
    """Animate the current avatar with one driving frame and return the image.

    The reference driving pose (globals `kp_driving_initial` / `driving_area`)
    is replaced by this frame's keypoints when there is none yet, when `reset`
    is set, or - with `auto` - when this frame's keypoints are closer (summed
    squared distance) to the avatar's keypoints than the current reference.

    The remaining flags are forwarded to `full_normalize_kp`. Returns the
    generator's 'prediction' for the single frame with the channel axis last.
    """
    global kp_driving_initial, driving_area

    with torch.no_grad():
        # Add a batch axis and move the channel axis to position 1.
        frame = torch.tensor(driving_frame[np.newaxis].astype(np.float32)).permute(0, 3, 1, 2).cuda()
        kp_driving = kp_detector(frame)

        take_reference = kp_driving_initial is None or reset
        if auto and not take_reference:
            avatar_value = kp_source[avatar]['value']
            new_dist = ((avatar_value - kp_driving['value']) ** 2).sum().data.cpu().numpy()
            old_dist = ((avatar_value - kp_driving_initial['value']) ** 2).sum().data.cpu().numpy()
            take_reference = new_dist < old_dist
        if take_reference:
            kp_driving_initial = kp_driving
            driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume

        kp_norm = full_normalize_kp(
            kp_driving=kp_driving,
            driving_area=driving_area,
            kp_driving_initial=kp_driving_initial,
            adapt_movement_scale=adapt_movement_scale,
            use_relative_movement=use_relative_movement,
            use_relative_jacobian=use_relative_jacobian,
            exaggerate_factor=exaggerate_factor,
        )
        prediction = generator(source[avatar], kp_source=kp_source[avatar], kp_driving=kp_norm)['prediction']

        return np.transpose(prediction.data.cpu().numpy(), [0, 2, 3, 1])[0]

# Index of the avatar currently being animated (-1 until one is selected).
avatar = -1
anti_aliasing = False
# Most recent websocket, kept so a re-run of the cell can close it.
save_socket = None
socket = bottle.Bottle()
@socket.route('/', apply=[websocket])
def wsbin(ws):
    """Serve one websocket connection: receive frames, reply with animated frames.

    Two kinds of text messages are handled:

    * Messages starting with 'drag' or 'url' carry a replacement avatar image,
      either as a data URL or as a link (for 'url', taken from the first
      src="..." or href="..." attribute). The image is kept until the next
      frame message and no reply is sent.
    * Every other message is a frame: a 17-character header followed by the
      frame as a data URL. Header layout, as parsed below:
        [0]     '1' requests a reset of the reference driving pose
        [1]     avatar key: '`' keeps the current avatar, '1'-'9' select
                slots 0-8, '0', '-' and '=' select slots 9, 10 and 11
        [2:5]   exaggerate factor
        [5:8]   alpha, the weight of the input frame blended into the output
        [8:12]  crop fraction for the avatar
        [12:17] flags: auto, adapt_movement_scale, use_relative_movement,
                use_relative_jacobian, show_kp
      The reply is the output frame as a base64 JPEG data URL.

    A failure while handling one message is ignored so the stream keeps
    running, but the loop ends once the socket is closed.
    """
    global avatar, save_socket, have_gen
    save_socket = ws
    reset = True
    new_image = None
    wait_start = time()
    while True:
        try:
            frame_start = time()
            img_str = ws.receive()
            t1 = time()-frame_start

            if img_str is None:
                # receive() yields None once the peer is gone; without this
                # exit the loop would spin forever on a dead connection.
                break

            if img_str.startswith(('drag', 'url')):
                if img_str.startswith('url'):
                    img_str = img_str[3:].split('<img ',1)[-1]
                    if 'src="' in img_str:
                        img_str = img_str.split('src="',1)[-1]
                    else:
                        img_str = img_str.split('href="',1)[-1]
                    img_str = img_str.split('"',1)[0]
                if 'data:image/' not in img_str:
                    # A timeout keeps an unresponsive host from freezing the stream.
                    get_image = requests.get(img_str, headers={'User-Agent': "My User Agent 1.0"}, timeout=10).content
                else:
                    get_image = base64.b64decode(img_str.split(',')[1])#, validate=True)
                get_image = Image.open(BytesIO(get_image))
                new_image = np.array(get_image)
                continue

            start = time()
            decimg = base64.b64decode(img_str[17:].split(',')[1])#, validate=True)
            decimg = Image.open(BytesIO(decimg))
            decimg = (np.array(decimg)/255).astype(np.float32)
            t2 = time()-start

            new_crop = float(img_str[8:12])

            reset |= img_str[0]=="1"

            if img_str[1]=="`":
                new_avatar = -1
            elif img_str[1]=="0":
                new_avatar = 9
            elif img_str[1]=="-":
                new_avatar = 10
            elif img_str[1]=="=":
                new_avatar = 11
            else:
                new_avatar = int(img_str[1])-1
            if new_avatar>=0:
                if new_avatar==num_avatars+len(gen_urls):
                    # Last slot: the incoming frame itself becomes the avatar.
                    orig_image[new_avatar] = decimg
                    source_image[new_avatar] = decimg #resize(decimg, (256, 256))
                elif new_avatar>=num_avatars:
                    if have_gen[new_avatar-num_avatars]:
                        # Use the preloaded image once; the next selection downloads a new one.
                        have_gen[new_avatar-num_avatars]=False
                    else:
                        if new_crop != crops[new_avatar]:
                            crops[new_avatar] = new_crop
                        load_stylegan_avatar(new_avatar, crop=crops[new_avatar])
                avatar = new_avatar
                reset = True

            if new_image is not None and avatar<num_avatars:
                # Install the image received earlier in the current slot.
                new_image = pad_crop_resize(new_image, *get_crop(new_image, center_face=True, crop_face=True, expansion_factor=2.5), new_h=None, new_w=None)
                orig_image[avatar] = new_image
                source_image[avatar] = resize(new_image,(256,256))
                reset = True

            exaggerate_factor = float(img_str[2:5])
            alpha = float(img_str[5:8])
            auto = int(img_str[12])
            adapt_movement_scale = int(img_str[13])
            use_relative_movement = int(img_str[14])
            use_relative_jacobian = int(img_str[15])
            show_kp = int(img_str[16])
            if new_crop != crops[avatar] or avatar==num_avatars+len(gen_urls) or new_image is not None:
                # The avatar image or its crop changed: rebuild its cached model inputs.
                new_image = None
                crops[avatar] = new_crop
                norm_source(avatar,crop=crops[avatar])

            #h,w = decimg.shape[:2]
            #s=min(h,w)
            #decimg = resize(decimg[(h-s)//2:(h+s)//2,(w-s)//2:(w+s)//2], (256, 256), anti_aliasing=anti_aliasing)[..., :3]

            start = time()
            out_img = make_animation(decimg, adapt_movement_scale=adapt_movement_scale, use_relative_movement=use_relative_movement,
                                   use_relative_jacobian=use_relative_jacobian, exaggerate_factor=exaggerate_factor, reset=reset, auto=auto)
            t3 = time()-start
            reset = False

            out_img = np.clip(out_img, 0, 1)

            if show_kp:
                if alpha>0:
                    # The input frame is only visible when blended in, so only
                    # then is it worth drawing its landmarks.
                    with contextlib.redirect_stdout(StringIO()):
                        input_landmarks = fa.get_landmarks(255 * decimg)
                    if input_landmarks:
                        spatial_size = np.array(decimg.shape[:2][::-1])[np.newaxis]
                        decimg = vis.draw_image_with_kp(decimg, input_landmarks[0] * 2 / spatial_size - 1)
                with contextlib.redirect_stdout(StringIO()):
                    output_landmarks = fa.get_landmarks(255 * out_img)
                if output_landmarks:
                    spatial_size = np.array(out_img.shape[:2][::-1])[np.newaxis]
                    out_img = vis.draw_image_with_kp(out_img, output_landmarks[0] * 2 / spatial_size - 1)

            if alpha:
                out_img = cv2.addWeighted(out_img, 1-alpha, decimg, alpha, 0)

            out_img = (out_img * 255).astype(np.uint8)

            #encode to string
            start = time()
            _, encimg = cv2.imencode(".jpg", out_img[...,::-1], [int(cv2.IMWRITE_JPEG_QUALITY), 80])
            # tobytes() replaces the deprecated ndarray.tostring().
            rep_str = encimg.tobytes()
            rep_str = "data:image/jpeg;base64," + base64.b64encode(rep_str).decode('utf-8')
            t4 = time()-start

            start = time()
            ws.send(rep_str)
            t5 = time()-start
            tsum = t1+t2+t3+t4+t5
            tframe = time()-frame_start
            twait = frame_start-wait_start
            tcycle = time()-wait_start
            #print('receive=%d decode=%d animate=%d encode=%d send=%d sum=%d total=%d wait=%d sum=%d total=%d'%(t1*1000,t2*1000,t3*1000,t4*1000,t5*1000,tsum*1000,tframe*1000,twait*1000,(t6+t0)*1000,tcycle*1000))
            wait_start = time()
        except Exception:
            # Per-frame failures are deliberately ignored so one bad message
            # does not end the stream. If the socket has been closed, though,
            # every further receive() would fail again immediately, so stop.
            # NOTE(review): assumes the websocket object exposes `closed`, as
            # gevent-websocket does - confirm for other backends.
            if getattr(ws, 'closed', False):
                break
            #print(e)

import logging
from bottle import ServerAdapter
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler
from geventwebsocket.logging import create_logger

class MyGeventWebSocketServer(ServerAdapter):
    """Bottle server adapter running a gevent WSGI server with websocket support.

    The underlying server is kept on `self.server` so that it can be stopped
    from outside via `shutdown()`.
    """

    def run(self, handler):
        """Create the server for `handler` and block serving requests."""
        self.server = pywsgi.WSGIServer(
            (self.host, self.port),
            handler,
            handler_class=WebSocketHandler,
        )

        if not self.quiet:
            # Attach an INFO-level stream logger unless quiet mode was requested.
            ws_logger = create_logger('geventwebsocket.logging')
            ws_logger.setLevel(logging.INFO)
            ws_logger.addHandler(logging.StreamHandler())
            self.server.logger = ws_logger

        self.server.serve_forever()

    def shutdown(self):
        """Stop and close the server created by `run`."""
        running = self.server
        running.stop()
        running.close()

if __name__ == '__main__':
    # prepare multiprocess
    _pool = Pool(processes=2)
    # Run use_cam in a worker process with the public websocket URL.
    # NOTE(review): use_cam and machine are defined outside this section;
    # the meaning of the 0.8 argument is not visible here - confirm.
    _pool.apply_async(use_cam, (ngrok_url, 0.8))
    print(machine)
    server = MyGeventWebSocketServer(host='0.0.0.0', port=port)
    from IPython.utils import io
    # Serve the websocket app (blocks until the server stops) while capturing
    # the output it produces instead of printing it in the notebook.
    with io.capture_output() as captured:
        socket.run(server=server)