# **All necessary imports**

In [104]:

import cv2
import numpy as np
import os
import pandas as pd
import zipfile
import time
import sys
import torch
import torchvision as tv
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as TF
import matplotlib.pyplot as plt 
import os.path
import wget

from PIL import ImageDraw
from google.colab import drive
from torchvision.datasets import CIFAR10
from torchvision.datasets.vision import StandardTransform
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, random_split
from torchvision.transforms.functional import to_pil_image
from PIL import Image
from matplotlib import pyplot as plt
from torchvision import transforms
from datetime import timedelta
from torchvision.models import resnet18
from torch.optim import Adam, AdamW
# from torch import nn can be removed?

%matplotlib inline

!pip install wget

DATA_DIR = "./data"

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# **All variables**

In [105]:
Saving_frames_per_second = 30

#choose on which video landmarks should be predicted
Test_video_file = 'matthis_vid_v1.mp4'

#choose which state dict to load
Model_state = 'v6'

#Insert params depending on state dict that will be loaded (needs to be the same)
NET = "ResNet18"
FC_LAYER = "Lin-ReLu-Lin"

# **All used functions**

In [106]:
def device():
    return torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def add_fc(net, layer):
    net.fc = layer
    return net

def plot_image(image: Image, labeling: np.ndarray=None):
    try:
      if isinstance(image, torch.Tensor) or isinstance(image, np.ndarray):
          image = to_pil_image(image)
    except TypeError:
      pass
    finally:
      plt.imshow(image, interpolation='nearest', cmap='gray')
      
      if labeling is not None:
          print(len(labeling))
          for i in range(0, len(labeling)-1, 2):
              plt.plot(labeling[i + 0],labeling[i + 1], marker=".", color='cyan')
      plt.show()
      
def plot_predicted_facial_landmarks(*pil_images):
  coordinates = predict_facial_landmarks(*pil_images)
  for i in range(len(pil_images)):
    plot_image(pil_images[i],coordinates[i])

def return_predicted_facial_landmarks(*pil_images):
  coordinates = predict_facial_landmarks(*pil_images)
  for i in range(len(pil_images)):
    plot_image(pil_images[i],coordinates[i])
  

def predict_facial_landmarks(*pil_images):
  preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()
  ])

  tensor_images = [preprocess(image.convert('L').convert('RGB')) for image in pil_images]
  image_batch = torch.stack(tensor_images,dim=0).to(device())
  net.eval()
  labels = net(image_batch)
  return labels.cpu().detach().numpy()

def add_predictions(*pil_images):
  coordinates = predict_facial_landmarks(*pil_images)
  for i in range(len(pil_images)):
    for j in range(0, len(coordinates[i]), 2):
      radius = 1
      leftUpPoint = (coordinates[i][j] - radius, coordinates[i][j+1] - radius)
      rightDownPoint = (coordinates[i][j] + radius, coordinates[i][j+1] + radius)
      draw = ImageDraw.Draw(pil_images[i])
      draw.ellipse((leftUpPoint, rightDownPoint), fill = 'red')
      #pil_images[i].putpixel( (coordinates[i][j], coordinates[i][j+1]), (255, 0, 0))
  return pil_images

def format_timedelta(td):
    """Utility function to format timedelta objects in a cool way (e.g 00:00:20.05) 
    omitting microseconds and retaining milliseconds"""
    result = str(td)
    try:
        result, ms = result.split(".")
    except ValueError:
        return result + ".00".replace(":", "-")
    ms = int(ms)
    ms = round(ms / 1e4)
    return f"{result}.{ms:02}".replace(":", "-")


def get_saving_frames_durations(cap, saving_fps):
    """A function that returns the list of durations where to save the frames"""
    s = []
    # get the clip duration by dividing number of frames by the number of frames per second
    clip_duration = cap.get(cv2.CAP_PROP_FRAME_COUNT) / cap.get(cv2.CAP_PROP_FPS)
    # use np.arange() to make floating-point steps
    for i in np.arange(0, clip_duration, 1 / saving_fps):
        s.append(i)
    return s

def extract_frames_from(video_file):
  # read the video file    
  cap = cv2.VideoCapture(video_file)

  # Check if camera opened successfully
  if(cap.isOpened()== False):
	  print("Error opening video stream or file")
  # get the FPS of the video
  fps = cap.get(cv2.CAP_PROP_FPS)
  # if the SAVING_FRAMES_PER_SECOND is above video FPS, then set it to FPS (as maximum)
  saving_frames_per_second = min(fps, Saving_frames_per_second)
  # get the list of duration spots to save
  saving_frames_durations = get_saving_frames_durations(cap, saving_frames_per_second)
  # start the loop
  count = 0
  frames = []
  while True:
      is_read, frame = cap.read()
      if not is_read:
          # break out of the loop if there are no frames to read
          break
      # get the duration by dividing the frame count by the FPS
      frame_duration = count / fps
      try:
          # get the earliest duration to save
          closest_duration = saving_frames_durations[0]
      except IndexError:
          # the list is empty, all duration frames were saved
          break
      if frame_duration >= closest_duration:
          # if closest duration is less than or equals the frame duration, 
          # then save the frame
          frame_duration_formatted = format_timedelta(timedelta(seconds=frame_duration))
          #cv2.imwrite(os.path.join(filename, f"frame{frame_duration_formatted}.jpg"), frame) 
          ## drop the duration spot from the list, since this duration spot is already saved
          frames.append(frame)
          try:
              saving_frames_durations.pop(0)
          except IndexError:
              pass
      # increment the frame count
      count += 1
  return frames

def plot_faces(images, coordinates=None, num=5):
  for i in range(num):
    if coordinates == None or coordinates[i] == None:
      plot_face(images[i])
    else:
      plot_face(images[i],coordinates[i])


def plot_face(image, coordinates=None, num=5):
    plt.imshow(image, interpolation='nearest',cmap="gray")
    if coordinates is not None:
      for i in range(0,len(coordinates)-1,2):
        plt.plot(coordinates[i + 0],coordinates[i + 1],marker=".",color="red")
    plt.show()
    
def opencv_to_pil_image(opencv_image):
  color_coverted = cv2.cvtColor(opencv_image, cv2.COLOR_BGR2RGB)
  pil_image=Image.fromarray(color_coverted)
  return pil_image

def pil_to_opencv_image(pil_image):
  np_image=np.array(pil_image)  

  # convert to a openCV2 image, notice the COLOR_RGB2BGR which means that 
  # the color is converted from RGB to BGR format
  opencv_image=cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
  return opencv_image

def opencv_images_to_video(opencv_images,video_filename):
  height, width, layers = opencv_images[0].shape
  image_size = (width, height)

  out = cv2.VideoWriter(video_filename, cv2.VideoWriter_fourcc(*'MP4V'), Saving_frames_per_second, image_size)
  for frame in opencv_images:
    out.write(frame)
  out.release()

# **Fetching video and model state**

In [107]:
if not os.path.isfile(f"./{Test_video_file}"):
  url = f"https://github.com/ko-redtruck/facial-landmark-detection/raw/main/{Test_video_file}"
  wget.download(url, ".")

if not os.path.isfile(f"./{Model_state}"):
  url = f"https://github.com/ko-redtruck/facial-landmark-detection/raw/main/{Model_state}"
  wget.download(url, ".")

# **Defining model to load state to**

In [108]:
# Define parameter mappings
fc_layers = {
    "Lin-ReLu-Lin": nn.Sequential(
        nn.Linear(512,256),
        nn.ReLU(),
        nn.Linear(256,30)
    ),
    "Linear": nn.Linear(512, 30)
}

networks = {
    "ResNet18": add_fc(resnet18(pretrained=True), fc_layers[FC_LAYER])
}


net = networks[NET].to(device())

#Load Pre-Trained Model
net.load_state_dict(torch.load('/content/v6'))
net.eval()

  f"The parameter '{pretrained_param}' is deprecated since 0.13 and will be removed in 0.15, "


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

# **Predicting landmarks**

In [109]:
  
frames = [opencv_to_pil_image(opencv_image) for opencv_image in extract_frames_from(Test_video_file)]

transform_image = TF.CenterCrop(224)
frames = [transform_image(pil_image) for pil_image in frames]
#plot_faces(frames,None,50)  
#plot_predicted_facial_landmarks(*frames)
frames = add_predictions(*frames)
print(frames)

opencv_images_to_video([pil_to_opencv_image(pil_image) for pil_image in frames],'video_predicted.mp4')

(<PIL.Image.Image image mode=RGB size=224x224 at 0x7FDDF0FF9C10>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDDF10D47D0>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDDF0FCDBD0>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDE2E177A90>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDE2E1770D0>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDE2E1779D0>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDE2E177250>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDE2E177A10>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDE2E177950>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDE2E177910>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDE2E1774D0>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDE2E177750>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDE2E177510>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDE2E177610>, <PIL.Image.Image image mode=RGB size=224x224 at 0x7FDE2E177B50>, <PIL.Image.Image image m