# New to this MVP? 

### Watch my [Tutorial](https://youtu.be/WzGw5L0Y_N8) on YouTube before you start!

# Step No 1 - Make sure you have a GPU enabled.

**Before you start with the application:** Run the following code and make sure it reports a GPU. If it prints an error or lists no GPU instead (i.e. the runtime uses a CPU or TPU), go to "Runtime" -> "Change runtime type", select "GPU" and restart the notebook (this should happen automatically).

In [None]:
!nvidia-smi
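If you prefer a plain yes/no answer over reading the `nvidia-smi` table, the following minimal sketch could be run instead. It is an assumption, not part of the application: the hypothetical helper `gpu_visible` calls `nvidia-smi -L`, which prints one line per visible GPU, and treats a missing binary or a non-zero exit code as "no GPU".

```python
import shutil
import subprocess


def gpu_visible() -> bool:
    """Return True if `nvidia-smi -L` runs and lists at least one GPU (hypothetical helper)."""
    # On runtimes without an NVIDIA driver the binary is usually missing entirely
    if shutil.which("nvidia-smi") is None:
        return False
    try:
        result = subprocess.run(["nvidia-smi", "-L"], capture_output=True, text=True, timeout=30)
    except (OSError, subprocess.TimeoutExpired):
        return False
    # `nvidia-smi -L` prints lines such as "GPU 0: Tesla T4 (UUID: ...)"
    return result.returncode == 0 and any(
        line.startswith("GPU ") for line in result.stdout.splitlines()
    )


print("GPU available" if gpu_visible() else "No GPU found: switch the runtime type to GPU")
```

The application itself falls back to the CPU through `torch.cuda.is_available()` when assigning `device`, so this check only tells you in advance whether it will run on the GPU.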

# Step No 2 - Initialize the application.

Next, you have to initialize the application. To do so, hit the play button below the following section header. If you would like to debug the application or take a look behind the scenes, expand the section header.

### Hit the play-button below this line.

#### Import Libraries

In [None]:
import pandas as pd
import os
import cv2 as cv
import numpy as np
import PIL
import io
import html
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import pickle
from torch.autograd import Variable
from google.colab import drive
from google.colab.output import eval_js
from google.colab import output
from base64 import b64decode
from tabulate import tabulate
from IPython.display import display, Javascript

os.system("""pip install mediapipe --quiet""")
import mediapipe as mp

os.system("""pip install alive-progress -q""")
from alive_progress import alive_bar
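MediaPipe Hands, installed above, reports 21 landmarks per detected hand, each with x, y and z coordinates. `extract_keypoints` (defined in the functions cell below) keeps only the first hand and flattens it into 21 × 3 = 63 values, or returns 63 zeros when no hand is found. The sketch below reproduces that logic without MediaPipe; `flatten_first_hand` and the `SimpleNamespace` stand-ins are hypothetical and only mimic the two attributes the notebook reads (`multi_hand_landmarks` and `landmark`).

```python
from types import SimpleNamespace

import numpy as np

NUM_LANDMARKS = 21  # MediaPipe Hands reports 21 landmarks per hand


def flatten_first_hand(results) -> np.ndarray:
    """Same logic as extract_keypoints: x, y, z of the first hand's landmarks, or 63 zeros."""
    if results.multi_hand_landmarks:
        hand = results.multi_hand_landmarks[0]
        return np.array([[lm.x, lm.y, lm.z] for lm in hand.landmark]).flatten()
    return np.zeros(NUM_LANDMARKS * 3)


# Hypothetical stand-ins for mediapipe results, built only from the attributes the notebook reads
fake_hand = SimpleNamespace(
    landmark=[SimpleNamespace(x=i / 100, y=0.5, z=-0.01) for i in range(NUM_LANDMARKS)]
)
detected = SimpleNamespace(multi_hand_landmarks=[fake_hand])
missed = SimpleNamespace(multi_hand_landmarks=None)  # no hand found in the frame

print(flatten_first_hand(detected)[:6])  # x0, y0, z0, x1, y1, z1
print(flatten_first_hand(missed).shape)  # (63,)
```

Returning zeros rather than skipping the frame keeps every sequence at a fixed length, which the support set and the prediction both rely on.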

#### Assign device

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

#### Initialize functions

In [None]:
def initialize_model(n_way, k_shot, classes):
  """
  Main function to initialize both model components for a given n_way, k_shot scenario
  Args:
    n_way (int): Number of different classes to be classified
    k_shot (int): Number of support samples per class
    classes (list): Names of classes to predict
  Returns:
    feature_encoder (LSTMEncoder-object): initialized LSTMEncoder-object
    relation_network (RelationNetwork-object): initialized RelationNetwork-object
  """

  # Check if the number of classes matches n_way
  if n_way != len(classes):
    print('Error: The size of your list of gestures does not match N_WAY. Please correct this and restart the process.')
    return None, None

  # Clone the GitHub repo where the models are stored (fails harmlessly if it was already cloned)
  os.system("""git clone https://github.com/nielsschluesener/S-STRHanGe.git""")

  # Define model_name and its path
  model_name = f'{n_way}way-{k_shot}shot'
  model_path = 'S-STRHanGe/deployment/models/'
  param_file = os.path.join(model_path, f'{model_name}_deployment_param.pkl')

  # Get deployment parameters; without them the model cannot be built
  if not os.path.isfile(param_file):
    print(f'Error: No pretrained model found for the {model_name} scenario. Please choose a supported N_WAY / K_SHOT combination.')
    return None, None
  with open(param_file, 'rb') as f:
    deployment_param = pickle.load(f)

  # Create feature_encoder and relation_network objects with the params given by deployment_param
  feature_encoder = LSTMEncoder(deployment_param['feature_length'], deployment_param['num_units_lstm_encoder'], deployment_param['num_lstm_layer_encoder']).to(device)
  relation_network = RelationNetwork(deployment_param['num_units_lstm_encoder']*2, deployment_param['num_units_lstm_relationnet'], deployment_param['num_units_fc_relu'], deployment_param['num_lstm_layer_relationnet']).to(device)

  # Load the trained weights (map_location lets GPU-trained weights load on a CPU runtime as well)
  feature_encoder.load_state_dict(torch.load(os.path.join(model_path, f'{model_name}_feature_encoder.pkl'), map_location=device))
  relation_network.load_state_dict(torch.load(os.path.join(model_path, f'{model_name}_relation_network.pkl'), map_location=device))

  # Switch both modules to inference mode
  feature_encoder.eval()
  relation_network.eval()

  return feature_encoder, relation_network


class LSTMEncoder(nn.Module):
  """First part of the model, which encodes the sequence of hand gesture keypoints"""
  def __init__(self, input_size, hidden_lstm_size, num_lstm_layer):
    """
    Object initialization.
    Args:
      input_size (int): size of the input data
      hidden_lstm_size (int): number of units in lstm-cell
      num_lstm_layer (int): number of stacked lstm-layers
    """
    super(LSTMEncoder, self).__init__()
    self.input_size = input_size 
    self.hidden_lstm_size = hidden_lstm_size 
    self.num_lstm_layer = num_lstm_layer

    #create lstm layer(s)
    self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_lstm_size, num_layers = num_lstm_layer, batch_first=True)

  def forward(self,x):
    """
    Forward pass through the LSTM Encoder
    Args:
      x (torch tensor): data input (sequence of hand keypoints as batch)
    Return:
      out (torch tensor): encoded data (encoded sequence of hand keypoints as batch) with size batch, sequence, hidden_size
    """

    # Create zero tensors of size (num_layers, batch_size, hidden_size)
    # h0 is the initial hidden state for each sequence in the batch
    h0 = torch.zeros(self.num_lstm_layer, x.size(0), self.hidden_lstm_size).to(device)
    # c0 is the initial cell state for each sequence in the batch
    c0 = torch.zeros(self.num_lstm_layer, x.size(0), self.hidden_lstm_size).to(device)

    #pass data through the lstm layer(s)
    out, _ = self.lstm(x, (h0,c0))
    return out 


class RelationNetwork(nn.Module):
  """Second part of the model, which calculates the relation-scores, given the concatenated query and support samples"""
  def __init__(self,input_size, hidden_lstm_size, fc_sizes: list, num_lstm_layer):
    """
    Object initialization.
    Args:
      input_size (int): size of the input data
      hidden_lstm_size (int): number of units in lstm-cell
      fc_sizes (list): number of units per fully connected / linear layer
      num_lstm_layer (int): number of stacked lstm-layers
    """
    super(RelationNetwork, self).__init__()
    self.input_size = input_size 
    self.hidden_lstm_size = hidden_lstm_size
    self.layers = nn.ModuleList()
    self.num_lstm_layer = num_lstm_layer

    #create lstm layer(s)
    self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_lstm_size, num_layers = num_lstm_layer, batch_first=True)
    #next input_size is outputsize of the lstm layer (hidden_lstm_size)
    input_size = hidden_lstm_size

    #loop over fc_sizes, create linear layer + ReLU Activation and adjust the input_size
    for size in fc_sizes:
      self.layers.append(nn.Linear(input_size, size))
      self.layers.append(nn.ReLU())
      input_size = size  
 
    # Add a final linear layer that reduces the input to a single number, the relation score, squashed into (0, 1) by a sigmoid
    self.end_layer = nn.Sequential(nn.Linear(input_size,1), nn.Sigmoid())
    
  def forward(self,x):
    """
    Forward pass through the Relation Network
    Args:
      x (torch tensor): concatenated data input of query- and support-samples 
    Return:
      out (torch tensor): relation scores
    """
    
    # Create zero tensors of size (num_layers, batch_size, hidden_size)
    # h0 is the initial hidden state for each sequence in the batch
    h0 = torch.zeros(self.num_lstm_layer, x.size(0), self.hidden_lstm_size).to(device)
    # c0 is the initial cell state for each sequence in the batch
    c0 = torch.zeros(self.num_lstm_layer, x.size(0), self.hidden_lstm_size).to(device)

    #pass data through the lstm layer(s)
    out, _ = self.lstm(x, (h0,c0))
    #reduce output to the last sequence-state 
    out = out[:, -1, :]

    #pass data through the fc layer(s)
    for layer in self.layers:
      out = layer(out)
    
    #pass data through the final layer to calculate the relation score
    out = self.end_layer(out) 
    return out

def create_support_set(n_way, k_shot, classes):
  """
  Main function to create a support set for a given n_way, k_shot scenario
  Args:
    n_way (int): Number of different classes to be classified
    k_shot (int): Number of support samples per class
    classes (list): Names of the classes, defined by the user
  Returns:
    support_x (torch.tensor): keypoint data of the support samples (None if the recording failed)
  """

  # Create the support-set progress log (one row per shot, one column per class)
  indexes = [f'#{i}' for i in range(1, k_shot + 1)]
  supset_state = pd.DataFrame(np.empty([k_shot, n_way], dtype = str), index = indexes, columns = classes)

  # Initialize supset dir; if the user chose to keep the old support set, reuse it
  supset_dir = 'support_set'
  support_x, support_y = initialize_supset_dir(n_way, k_shot, supset_dir)
  if support_x is not None:
    return support_x

  # Loop over the classes and create k_shot-videos
  for n in range(1, n_way+1):
    for s in range(1, k_shot + 1):
      # Print state and instructions
      supset_state.iloc[s-1, n-1] = "This turn's gesture!"
      output.clear()
      print('Current state of your support set: \n')
      print(tabulate(supset_state, headers = supset_state.columns, tablefmt = 'simple', stralign="center"))
      print(f'''\nWait until the green countdown hits 0 and perform your gesture "{classes[n-1]}" in front of your webcam!\n''')
      # Give the user more time before the very first recording
      time.sleep(5 if n == 1 and s == 1 else 2)

      log = record_video(supset_dir + f'/class{n}_sample{s}/video/gesture.mp4')
      if not log:
        print('Video recording failed.')
        return None

      #Update supset_state
      supset_state.iloc[s-1,n-1] = '✓'

  output.clear()
  print('All videos have been captured. \n')
  print(tabulate(supset_state, headers = supset_state.columns, tablefmt = 'simple', stralign="center"))
  print('\nThe keypoints are now being extracted. This may take a few minutes, depending on the size of the support set. \n')

  # Loop over the videos and extract keypoints
  with alive_bar(n_way*k_shot, title=f'Extracting keypoints...', force_tty = True) as bar:
    for n in range(1, n_way+1):
      for s in range(1, k_shot + 1):
        extract_supset_frames_from_video(n, s, supset_dir)
        generate_supset_keypoints_from_frame(n, s, supset_dir)
        bar()

  # generate support_x and support_y from keypoints selected
  support_x, support_y = create_supset_from_keypoints(n_way, k_shot, supset_dir)

  print('\nAll keypoints have been extracted. Your support set is ready to go!')

  return support_x

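def record_video(filename='video.mp4'):
  """
  Records a 3-second webcam video in the browser via JavaScript and saves it to filename
  Args:
    filename (str): target path of the video file
  Returns:
    filename (str) on success, False if the recording failed
  """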
  js = Javascript("""
    async function recordVideo() {

      const div = document.createElement('div');
      const capture = document.createElement('button');
      const stopCapture = document.createElement("button");

      // Define an output element for debugging purposes
      const errorLog = document.createElement("div");

      capture.textContent = "Recording starts in 3";
      capture.style.background = "green";
      capture.style.color = "white";

      stopCapture.textContent = "Recording ends in 3";
      stopCapture.style.background = "red";
      stopCapture.style.color = "white";
      div.appendChild(capture);

      const video = document.createElement('video');
      const recordingVid = document.createElement("video");
      video.style.display = 'block';

      // Disable audio capture and limit to video
      const stream = await navigator.mediaDevices.getUserMedia({audio: false, video: true });
      // create a media recorder instance, which is an object
      // that will let you record what you stream.

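      const options = {
        videoBitsPerSecond : 2500000,
        // WebM/VP8; "video/webm;codecs=vp9" is an alternative in browsers that support it.
        // The file is later saved as .mp4, but ffmpeg detects the container from its content.
        mimeType: "video/webm;codecs=vp8"
      }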

      let recorder = new MediaRecorder(stream, options);
      document.body.appendChild(div);
      div.appendChild(video);

      // Append output element for debugging
      div.appendChild(errorLog);

      // Debugging only: Give back the videoBitsPerSecond of the option defined in the MediaRecorder constructor
      // errorLog.textContent = recorder.videoBitsPerSecond;  

      // Create a media element.  
      video.srcObject = stream;
      await video.play();

      // Resize the output to fit the video element.
      google.colab.output.setIframeHeight(document.documentElement.scrollHeight, true);
      
      // Countdown lengths in seconds (purely cosmetic)
      var recordingCountdown = 3;
      var preCountdown = 3;

      // Pre-recording countdown; starts the recording once it reaches 0
      var preCountdownTimer = setInterval(function() {
        preCountdown--;
        capture.textContent = "Recording starts in " + preCountdown;
        if (preCountdown <= 0) {
          clearInterval(preCountdownTimer);
          startRecording();
        }
      }, 1000);

      // Function that runs the recording
      function startRecording() {

        // Swap the green button for the red one; this is purely visual,
        // stopCapture has no click handler
        capture.replaceWith(stopCapture);

        // Start recording
        recorder.start();

        // Recording countdown (purely cosmetic)
        var recordingCountdownTimer = setInterval(function() {
          recordingCountdown--;
          stopCapture.textContent = "Recording ends in " + recordingCountdown;
          if (recordingCountdown <= 0) {
            clearInterval(recordingCountdownTimer);
          }
        }, 1000);

        // Automatically stop the recording after 3000 ms
        setTimeout(() => {
          recorder.stop();
        }, 3000);
      }
      
      let recData = await new Promise((resolve) => recorder.ondataavailable = resolve);
      let arrBuff = await recData.data.arrayBuffer();
      
      // stop the stream and remove the video element
      stream.getVideoTracks()[0].stop();
      div.remove();

      let binaryString = "";
      let bytes = new Uint8Array(arrBuff);
      bytes.forEach((byte) => {
        binaryString += String.fromCharCode(byte);
      })
      return btoa(binaryString);
    }
    """)
  try:
    display(js)
    data = eval_js('recordVideo()')
    binary = b64decode(data)
    with open(filename, "wb") as video_file:
      video_file.write(binary)
  except Exception as err:
    # In case any exceptions arise (e.g. no webcam access), report them and signal failure
    print(str(err))
    return False
  return filename

def extract_keypoints(results):
  """
  Extracts keypoints from a given mediapipe-Hands result
  Args:
    results (mp-hands-object): result of a mediapipe-Hands hand detection
  Returns:
    keypoints (np array): 63 values (x, y, z of the 21 hand landmarks) or 63 zeros
  """
  # If the detection was successful, return the 63 keypoints of the first hand, else return 63 zeros
  if results.multi_hand_landmarks:
    return np.array([[res.x, res.y, res.z] for res in results.multi_hand_landmarks[0].landmark]).flatten()
  return np.zeros(63)

def detect_hand(image_path):
  """
  Runs the standard mediapipe-hand-detection model on a single image
  Args:
    image_path (str): path to an image
  Returns:
    result (mp-hands-object): result of a mediapipe-Hands hand-detection
  """
  # Run the mp.solutions.hands.Hands model with standard params (at most one hand, min. detection confidence of 0.5) on a horizontally flipped image (because webcam images are mirrored as well)
  with mp.solutions.hands.Hands(
      static_image_mode = True,
      max_num_hands = 1,
      min_detection_confidence = 0.5) as hands:
      image = cv.flip(cv.imread(image_path), 1)
      result = hands.process(cv.cvtColor(image, cv.COLOR_BGR2RGB))
  return result

def move_values(data, seq_length, direction = 'back'):
  """
  Move the frames containing keypoints either towards the back or towards the front of the sequence.
  Args:
    data (np.array): Data instance of shape (seq_length, 63), usually x
    seq_length (int): Length of sequence
    direction (str): target direction, either 'back' (values move towards the back) or 'front' (values move towards the front)
  Returns:
    x (np.array): Instance with its values moved towards the intended direction
  """

  # If the instance is empty (no hand detected in any frame), return zeros
  max_values = data.max(axis = 1)
  if not np.any(max_values):
    return np.zeros((seq_length, 63))

  # Else, compute the index of the first and the last frame containing keypoints
  i_first_value = next(i for i, v in enumerate(max_values) if v != 0)
  i_last_value = next(i for i, v in reversed(list(enumerate(max_values))) if v != 0)

  # Create temporary arrays holding the values to be moved and the zeros used to pad the sequence
  non_zeros = data[i_first_value:i_last_value+1]
  zeros = np.zeros((seq_length - non_zeros.shape[0], 63))

  # Move the values
  if direction == 'back':
    x = np.concatenate((zeros, non_zeros), axis=0)
  elif direction == 'front':
    x = np.concatenate((non_zeros, zeros), axis=0)
  else:
    raise ValueError(f"direction must be 'back' or 'front', got {direction!r}")

  return x

def initialize_supset_dir(n_way, k_shot, supset_dir):
  """
  Initializes the support_set directory
  Args:
    n_way (int): Number of different classes to be classified
    k_shot (int): Number of support samples per class
    supset_dir (str): Target name of the support set directory
  Returns:
    support_x, support_y (torch.tensor): the preserved old support set, or None, None if a new one has to be recorded
  """

  # Make sure the user does not unintentionally delete the current supset_dir and its contents
  if os.path.exists(supset_dir):
    safety_check = input("Caution: This step will delete the old support set. Do you want to continue? Type Yes or No and hit Enter.")
    if safety_check.strip().lower() == 'yes':
      os.system(f"""rm -rf {supset_dir}""")
      output.clear()
    else:
      # Any answer other than Yes keeps the old support set
      support_x, support_y = create_supset_from_keypoints(n_way, k_shot, supset_dir)
      print('Procedure canceled. Old support set was preserved.')
      return support_x, support_y

  # Else, create support_set directory with sub_directories
  if not os.path.exists(supset_dir):
    os.makedirs(supset_dir)
    for n in range(1, n_way+1):
      for s in range(1, k_shot + 1):
        os.makedirs(supset_dir + f'/class{n}_sample{s}')
        os.makedirs(supset_dir + f'/class{n}_sample{s}/video')
        os.makedirs(supset_dir + f'/class{n}_sample{s}/frames')
        os.makedirs(supset_dir + f'/class{n}_sample{s}/keypoints')

  return None, None

def extract_supset_frames_from_video(n_way, k_shot, supset_dir):
  """
  Extracts frames from given support-set video, defined by its class and k'th-shot
  Args:
    n_way (int): class of the video
    k_shot (int): k'th-shot of the video
    supset_dir (str): Support set directory
  """

  # If the video exists, turn it into jpgs (24 fps, scaled to 360x240)
  sample_dir = supset_dir + f'/class{n_way}_sample{k_shot}'
  if os.path.isfile(sample_dir + '/video/gesture.mp4'):
    os.system(f''' ffmpeg -i {sample_dir}/video/gesture.mp4 -qscale:v 2 -vf "scale=360:240,fps=24" {sample_dir}/frames/frame_%02d.jpg ''')

def generate_supset_keypoints_from_frame(n_way, k_shot, supset_dir):
  """
  Generates and saves keypoints from the frames of a given support-set sample, defined by its class and k'th shot
  Args:
    n_way (int): class of the frames
    k_shot (int): k'th shot of the frames
    supset_dir (str): Support set directory
  """
  num_frames = 72
  sample_dir = supset_dir + f'/class{n_way}_sample{k_shot}'

  # Loop over frames; generate keypoints for every existing frame and save them as npy
  for frame_num in range(1, num_frames + 1):
    frame_id = str(frame_num).zfill(2)
    if os.path.isfile(sample_dir + f'/frames/frame_{frame_id}.jpg'):
      hand = detect_hand(sample_dir + f'/frames/frame_{frame_id}.jpg')
      keypoints = extract_keypoints(hand)
      np.save(sample_dir + f'/keypoints/frame_{frame_id}.npy', keypoints)

def create_supset_from_keypoints(n_ways, k_shots, supset_dir):
  """
  Collects and concatenates all keypoints from a given supset_dir
  Args:
    n_ways (int): number of classes
    k_shots (int): number of support samples per class
    supset_dir (str): Support set directory
  Returns:
    support_x (torch.tensor): keypoint data of the support samples, shape (n_ways * k_shots, 72, 63)
    support_y (torch.tensor): labels of the support samples (1 to n_ways)
  """

  num_frames = 72
  X, y = [], []

  # Loop over every supset-sample, and collect X and y data
  for n in range(1, n_ways+1):
    for s in range(1, k_shots + 1):
      sequence = []
      for frame in range (1, num_frames + 1):
        if os.path.isfile(supset_dir + f'/class{n}_sample{s}/keypoints/frame_{str(frame).zfill(2)}.npy'):
          keypoints = np.load(supset_dir + f'/class{n}_sample{s}/keypoints/frame_{str(frame).zfill(2)}.npy')
          sequence.append(keypoints)
        else:
          sequence.append(np.zeros(63))

      X.append(sequence)
      y.append(n)

  # create np.arrays for X and y and move values to the back 
  X = np.array(X)
  y = np.array(y)

  for i in range(len(X)):
    X[i] = move_values(X[i], num_frames)

  # turn X and y to torch tensors
  support_x, support_y = torch.from_numpy(X).float(), torch.from_numpy(y).long()

  return support_x, support_y

def perform_prediction(support_x, classes, feature_encoder, relation_network):
  """
  Main function for performing a prediction
  Args:
    support_x (torch.tensor): keypoint data of the support samples
    classes (list): names of the classes, defined by the user
    feature_encoder (LSTMEncoder-object): initialized LSTMEncoder-object
    relation_network (RelationNetwork-object): initialized RelationNetwork-object
  """
  # Make sure the previous steps succeeded
  if support_x is None or feature_encoder is None or relation_network is None:
    print('Error: The model or the support set is missing. Please rerun Steps 3 and 4.')
    return

  # Load deployment_params
  model_name = f'{len(classes)}way-{len(support_x) // len(classes)}shot'
  model_path = 'S-STRHanGe/deployment/models/'
  param_file = os.path.join(model_path, f'{model_name}_deployment_param.pkl')
  if not os.path.isfile(param_file):
    print(f'Error: No deployment parameters found for the {model_name} scenario.')
    return
  with open(param_file, 'rb') as f:
    deployment_param = pickle.load(f)

  # Generate x-data; stop if the recording failed
  x = create_prediction_data()
  if x is None:
    return

  print('Predicting... \n')
  time.sleep(1)
  # Calculate relations (no gradients are needed for inference)
  with torch.no_grad():
    relations = calc_prediction(feature_encoder = feature_encoder, relation_network = relation_network,
                                support_x = support_x, query_x = x,
                                num_classes = deployment_param['num_classes'], support_num_per_class = deployment_param['support_num_per_class'],
                                seq_length = deployment_param['sequence_length'], num_units_lstm_encoder = deployment_param['num_units_lstm_encoder'])

  print('Prediction finished! \n')

  # Output results
  print(f'The model predicts: {classes[torch.argmax(relations).item()]}\n')

  print('The calculated relation scores to all classes are:\n')
  print(tabulate(relations.tolist(), tablefmt = 'simple', headers=classes, numalign="right", floatfmt=".2f"))
  print('\nThe higher the score, the more similarity the model sees between your gesture and the gestures of that class in the support-set.')

  print('\nRerun the cell to perform another prediction!')

def create_prediction_data():
  """
  Creates the data necessary for a prediction
  Returns:
    x (torch.tensor): keypoint data to predict (None if the recording failed)
  """
  
  predict_dir = 'prediction'
  num_frames = 72

  # Create or empty prediction-directory 
  if os.path.exists(predict_dir):
    log = os.system("""rm -rf prediction""")
  
  os.makedirs(predict_dir)
  os.makedirs(predict_dir + f'/video')
  os.makedirs(predict_dir + f'/frames')
  os.makedirs(predict_dir + f'/keypoints')


  # Record the video
  print('Wait until the green countdown hits 0 and perform the gesture you want the model to detect in front of your webcam!')
  time.sleep(5)
  log = record_video(predict_dir + '/video/gesture.mp4')
  if not log:
    print('Video recording failed.')
    return None

  output.clear()
  print('Video has been captured. \n')

  # Turn video into frames
  if os.path.isfile(predict_dir + f'/video/gesture.mp4'):
    os.system(f''' ffmpeg -i {predict_dir}/video/gesture.mp4 -qscale:v 2 -vf "scale=360:240,fps=24" {predict_dir}/frames/frame_%02d.jpg ''')

  sequence = []
  #Generate keypoints from frames and save as npy
  with alive_bar(num_frames, title=f'Extracting keypoints...', force_tty = True) as bar:
    for frame_num in range(1, num_frames + 1):
      if os.path.isfile(predict_dir + f'/frames/frame_{str(frame_num).zfill(2)}.jpg'):
        hand = detect_hand(predict_dir + f'/frames/frame_{str(frame_num).zfill(2)}.jpg')
        keypoints = extract_keypoints(hand)
        np.save(predict_dir + f'/keypoints/frame_{str(frame_num).zfill(2)}.npy', keypoints)
        sequence.append(keypoints)      
      else:
        sequence.append(np.zeros(63))
      bar()
  print('\nKeypoints have been extracted. \n')

  # Generate x
  x = np.array(sequence)
  x = move_values(x, x.shape[0])
  x = np.expand_dims(x, axis = 0)
  x = torch.from_numpy(x).float()

  return x

def calc_prediction(feature_encoder, relation_network, support_x, query_x, num_classes, support_num_per_class, seq_length, num_units_lstm_encoder):
  """
  Prediction function
  Args:
    feature_encoder (LSTMEncoder-object): initialized LSTMEncoder-object
    relation_network (RelationNetwork-object): initialized RelationNetwork-object
    support_x (torch tensor): data for support set
    query_x (torch tensor): data for query set
    num_classes (int): number of classes in the given task
    support_num_per_class (int): number of support samples per class in support set
    seq_length (int): length of sequence
    num_units_lstm_encoder (int): number of hidden units in the LSTM-Cells of the feature_encoder
  Returns:
    relations (torch tensor): relation scores of shape (num_queries, num_classes)
  """

  # Calculate support features by passing the support set through the feature_encoder module
  support_features = feature_encoder(support_x.to(device))
  # Calculate query features by passing the query set through the feature_encoder module
  query_features = feature_encoder(query_x.to(device))
  num_queries = query_features.size(0)

  # Sum the K support encodings of each class -> (num_classes, seq_length, hidden)
  support_features = support_features.view(num_classes, support_num_per_class, seq_length, num_units_lstm_encoder)
  support_features = torch.sum(support_features, 1)
  # Repeat the class encodings once per query -> (num_queries, num_classes, seq_length, hidden)
  support_features_ext = support_features.unsqueeze(0).repeat(num_queries, 1, 1, 1)
  # Repeat each query once per class -> (num_queries, num_classes, seq_length, hidden)
  query_features_ext = query_features.unsqueeze(0).repeat(num_classes, 1, 1, 1)
  query_features_ext = torch.transpose(query_features_ext, 0, 1)

  # Concatenate encoded support & query set along the feature axis
  relation_pairs = torch.cat((support_features_ext, query_features_ext), 3).view(-1, seq_length, num_units_lstm_encoder*2)

  # Calculate relations by passing the concatenated pairs through the relation_network module
  relations = relation_network(relation_pairs).view(-1, num_classes)

  return relations
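`calc_prediction` pairs every query with every class through a sequence of `view`, `sum`, `repeat` and `cat` calls. The NumPy sketch below mirrors those shape transformations with random arrays in place of the LSTM encodings (`np.broadcast_to` plays the role of `repeat`). The sizes (5-way, 2-shot, 72 frames, 128 hidden units) are illustrative assumptions; the real values come from the deployment parameters.

```python
import numpy as np

num_classes, k_shot, seq_length, hidden = 5, 2, 72, 128  # illustrative sizes
num_queries = 1

rng = np.random.default_rng(0)
# Stand-ins for feature_encoder outputs: (batch, seq_length, hidden), support samples in class-major order
support_features = rng.normal(size=(num_classes * k_shot, seq_length, hidden))
query_features = rng.normal(size=(num_queries, seq_length, hidden))

# Group the support encodings by class and sum over the K shots
per_class = support_features.reshape(num_classes, k_shot, seq_length, hidden).sum(axis=1)

# Pair every query with every class: both become (num_queries, num_classes, seq_length, hidden)
support_ext = np.broadcast_to(per_class[None], (num_queries, num_classes, seq_length, hidden))
query_ext = np.broadcast_to(query_features[:, None], (num_queries, num_classes, seq_length, hidden))

# Concatenate along the feature axis and flatten the pairs into one batch for the relation network
relation_pairs = np.concatenate((support_ext, query_ext), axis=3).reshape(-1, seq_length, 2 * hidden)

print(per_class.shape)       # (5, 72, 128)
print(relation_pairs.shape)  # (5, 72, 256)
```

The relation network then scores each of the 5 pairs, and reshaping its `(5, 1)` output with `view(-1, num_classes)` yields one row of scores per query. Summing (rather than averaging) the K support encodings follows the notebook; with K_SHOT = 1 the sum is simply the single sample's encoding.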

# Step No 3 - Configure your model.

As a third step, you have to configure the model.

Define the number of classes you want your model to distinguish (5 or 10) in **N_WAY** and the number of support samples per class you want to provide to your model in **K_SHOT** (1, 2 or 5).

Additionally, replace the placeholders in the list GESTURES with the class names of your choice. Make sure they are comma-separated and enclosed in quotation marks, and that the list contains exactly N_WAY names.

Run the cell.


In [None]:
N_WAY = 5
K_SHOT = 1

GESTURES = ["Replace_this_String_with_the_Name_of_your_Gesture_Number_1",
            "Replace_this_String_with_the_Name_of_your_Gesture_Number_2",
            "Replace_this_String_with_the_Name_of_your_Gesture_Number_3",
            "Replace_this_String_with_the_Name_of_your_Gesture_Number_4",
            "Replace_this_String_with_the_Name_of_your_Gesture_Number_5"] 

feature_encoder, relation_network = initialize_model(N_WAY, K_SHOT, GESTURES)
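`initialize_model` only checks that GESTURES has N_WAY entries and then looks for three files named after the scenario in `S-STRHanGe/deployment/models/`. To catch configuration mistakes before cloning the repository, a stdlib sketch like the following could run first. `check_config` and `model_files` are hypothetical helpers; the allowed values are taken from the text of this step, and the placeholder check simply looks for the default strings.

```python
import os

# Values named in Step 3; a pretrained model is assumed to exist for each combination
SUPPORTED_N_WAY = (5, 10)
SUPPORTED_K_SHOT = (1, 2, 5)
MODEL_PATH = "S-STRHanGe/deployment/models/"


def check_config(n_way, k_shot, gestures):
    """Return a list of configuration problems (an empty list means the config looks valid)."""
    problems = []
    if n_way not in SUPPORTED_N_WAY:
        problems.append(f"N_WAY must be one of {SUPPORTED_N_WAY}, got {n_way}")
    if k_shot not in SUPPORTED_K_SHOT:
        problems.append(f"K_SHOT must be one of {SUPPORTED_K_SHOT}, got {k_shot}")
    if len(gestures) != n_way:
        problems.append(f"GESTURES has {len(gestures)} names, but N_WAY is {n_way}")
    if any(name.startswith("Replace_this_String") for name in gestures):
        problems.append("GESTURES still contains placeholder names")
    return problems


def model_files(n_way, k_shot):
    """Paths that initialize_model reads for a given scenario."""
    model_name = f"{n_way}way-{k_shot}shot"
    return [
        os.path.join(MODEL_PATH, f"{model_name}_{suffix}.pkl")
        for suffix in ("deployment_param", "feature_encoder", "relation_network")
    ]


print(check_config(5, 1, ["fist", "palm", "peace", "thumbs_up", "ok"]))  # []
print(model_files(5, 1)[0])  # S-STRHanGe/deployment/models/5way-1shot_deployment_param.pkl
```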

# Step No 4 - Create the model's support set.

To perform predictions, the model needs K_SHOT samples of each class. To record them, run the following cell.
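Behind the scenes, `initialize_supset_dir` creates one folder per class and shot, each with `video`, `frames` and `keypoints` subfolders, so a 5-way 1-shot set has 5 sample folders and 15 subfolders. The pathlib sketch below reproduces that layout in a temporary directory; `make_supset_layout` is a hypothetical helper and, unlike the notebook, uses `exist_ok=True` instead of deleting an existing set first.

```python
import tempfile
from pathlib import Path

SUBDIRS = ("video", "frames", "keypoints")


def make_supset_layout(root, n_way, k_shot):
    """Create <root>/class{n}_sample{s}/{video,frames,keypoints} for every class and shot."""
    root = Path(root)
    created = []
    for n in range(1, n_way + 1):
        for s in range(1, k_shot + 1):
            for sub in SUBDIRS:
                path = root / f"class{n}_sample{s}" / sub
                path.mkdir(parents=True, exist_ok=True)
                created.append(path)
    return created


with tempfile.TemporaryDirectory() as tmp:
    dirs = make_supset_layout(Path(tmp) / "support_set", n_way=5, k_shot=1)
    print(len(dirs))                 # 15
    print(dirs[0].relative_to(tmp))  # support_set/class1_sample1/video
```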

The application will loop over every class and ask you to perform the class's hand gesture K_SHOT times. Further details will be explained during runtime. Afterwards, the data gets processed, which may take a while.
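The processing step relies on three numbers that have to agree: the recording stops after 3000 ms, ffmpeg resamples it to 24 frames per second, and the keypoint functions read `num_frames = 72` images. Since 3 s × 24 fps = 72, each sample yields roughly 72 frames; missing frames are filled with 63 zeros and extra frames are ignored. The stdlib sketch below, with hypothetical helper names, checks that arithmetic and the file-naming convention.

```python
RECORDING_SECONDS = 3  # record_video stops the MediaRecorder after 3000 ms
FPS = 24               # ffmpeg video filter "fps=24"
NUM_FRAMES = 72        # num_frames in the keypoint functions


def expected_frame_count(seconds, fps):
    """Number of frames ffmpeg should produce for a clip of the given length."""
    return round(seconds * fps)


def frame_name(frame_num, ext="jpg"):
    """File name used by the notebook; str(n).zfill(2) matches ffmpeg's frame_%02d pattern."""
    return f"frame_{str(frame_num).zfill(2)}.{ext}"


print(expected_frame_count(RECORDING_SECONDS, FPS))  # 72
print(frame_name(1), frame_name(NUM_FRAMES))         # frame_01.jpg frame_72.jpg
```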

In [None]:
support_set = create_support_set(N_WAY, K_SHOT, GESTURES)
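After extraction, `create_supset_from_keypoints` stacks the per-frame vectors into an array of shape (N_WAY * K_SHOT, 72, 63) with labels 1 to N_WAY in class-major order, and `move_values` shifts each sequence so that the frames with a detected hand end up at the back. The NumPy sketch below reproduces that layout on small made-up data; `pad_to_back` is a simplified, hypothetical re-implementation of `move_values(direction='back')`, not the notebook function itself.

```python
import numpy as np

NUM_FRAMES, NUM_FEATURES = 72, 63


def pad_to_back(seq):
    """Move the span between the first and last non-empty frame to the end of the sequence."""
    frame_has_hand = seq.max(axis=1) != 0
    if not frame_has_hand.any():
        return np.zeros_like(seq)
    idx = np.flatnonzero(frame_has_hand)
    span = seq[idx[0]:idx[-1] + 1]
    padding = np.zeros((len(seq) - len(span), seq.shape[1]))
    return np.concatenate((padding, span), axis=0)


def build_support_set(sequences_per_class):
    """sequences_per_class[c][s] is a (NUM_FRAMES, NUM_FEATURES) array for class c+1, shot s+1."""
    X, y = [], []
    for class_idx, shots in enumerate(sequences_per_class, start=1):
        for seq in shots:
            X.append(pad_to_back(seq))
            y.append(class_idx)
    return np.stack(X), np.array(y)


def fake_sequence(rng):
    """Made-up sequence where a 'hand' is only visible in frames 10 to 19."""
    seq = np.zeros((NUM_FRAMES, NUM_FEATURES))
    seq[10:20] = rng.uniform(0.1, 1.0, size=(10, NUM_FEATURES))
    return seq


rng = np.random.default_rng(0)
X, y = build_support_set([[fake_sequence(rng), fake_sequence(rng)],
                          [fake_sequence(rng), fake_sequence(rng)]])
print(X.shape, y)  # (4, 72, 63) [1 1 2 2]
```

The class-major order matters: `calc_prediction` later reshapes the encoded support set with `view(num_classes, support_num_per_class, ...)`, which only groups samples correctly if all shots of class 1 come first, then class 2, and so on.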

# Step No 5 - Perform predictions!

Everything is set up! In order to perform a prediction, run the following cell and follow the instructions.

In [None]:
perform_prediction(support_set, GESTURES, feature_encoder, relation_network)
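`perform_prediction` reports only the class with the highest relation score (via `torch.argmax`). Because the relation network ends in a sigmoid applied to each support/query pair separately, every score lies between 0 and 1, but the scores are not normalized across classes. If you want the full ranking rather than just the top prediction, a stdlib sketch like the one below could be applied to `relations.tolist()[0]`; `rank_classes` is a hypothetical helper and the scores are made up for illustration.

```python
def rank_classes(scores, classes):
    """Pair each class with its relation score, highest first (hypothetical helper)."""
    if len(scores) != len(classes):
        raise ValueError("expected one relation score per class")
    return sorted(zip(classes, scores), key=lambda pair: pair[1], reverse=True)


# Made-up scores for illustration; in the notebook they would come from relations.tolist()[0]
classes = ["fist", "palm", "peace", "thumbs_up", "ok"]
scores = [0.12, 0.91, 0.05, 0.33, 0.40]

ranking = rank_classes(scores, classes)
print(ranking[0])              # ('palm', 0.91)
print(round(sum(scores), 2))   # 1.81, not 1: each score is an independent sigmoid output
```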