# Install Dependencies

In [None]:
# Uninstall Colab's preinstalled PyTorch packages and install the versions compatible with the detectron2 wheels below
# (note: torchtext is uninstalled but not reinstalled)
!pip uninstall torch torchvision torchtext torchaudio -y
!pip install torch==1.10.1 torchvision==0.11.2 torchaudio==0.10.1

# Install pyyaml (version 5.1)
!pip install pyyaml==5.1

# Import torch and derive the version tags used to select the matching detectron2 wheel index
import torch
# Major.minor torch version, e.g. "1.10.1+cu102" -> "1.10"
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
# Build suffix after "+", e.g. "1.10.1+cu102" -> "cu102".
# NOTE(review): if the installed torch version string has no "+<suffix>", this yields the whole
# version string and the wheel URL below will be wrong — confirm the installed build carries a CUDA tag.
CUDA_VERSION = torch.__version__.split("+")[-1]

# Ensure the detectron2 installation matches the PyTorch/CUDA versions
# (IPython expands $CUDA_VERSION and $TORCH_VERSION from the Python variables above)
!pip install detectron2 -f https://dl.fbaipublicfiles.com/detectron2/wheels/$CUDA_VERSION/torch$TORCH_VERSION/index.html

## ⚠️  **Attention:** 
After detectron2 is installed by the cell above, **you may need to restart the Colab runtime** so that the newly installed package versions are loaded. You can do this by selecting `Runtime > Restart runtime` from the menu bar, or by running `exit(0)` in a cell.

Once detectron2 has been installed successfully, continue with the cells below.

In [None]:
# Import detectron2
import detectron2
from detectron2.utils.logger import setup_logger

# Set up detectron2 logger
setup_logger()

# Import common libraries
# NOTE(review): np, json, cv2, random and cv2_imshow are not referenced by any cell in this notebook;
# presumably leftovers from a visualization notebook — confirm before removing.
import numpy as np
import os, json, cv2, random
from google.colab.patches import cv2_imshow

# Import detectron2 utilities
# NOTE(review): none of these names are referenced by later cells in this notebook.
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

# sys is used later to prepend the repo's sub-folders to sys.path so its custom utility modules can be imported
import sys

# detectron2 trainer class.
# NOTE(review): nothing in this notebook trains a model or uses DefaultTrainer — likely left over from
# a training notebook (the checkpoint configured below is an X_101 model, not R50-FPN).
from detectron2.engine import DefaultTrainer

In [None]:
# Mount Google Drive so the dataset, repo and checkpoint paths under /content/drive are reachable
from google.colab import drive
drive.mount('/content/drive')
# NOTE(review): os and sys are already imported in the imports cell above; re-importing is harmless
import os
import sys
# Automatically reload edited project modules (e.g. trajectory_utils) before running each cell
%load_ext autoreload
%autoreload 2

# Specify options

In [None]:
# Root of the mounted Google Drive (mounted at /content/drive above);
# change this single value to relocate every path below.
DRIVE_ROOT = "/content/drive/MyDrive"

# Project folder holding the model checkpoints and auxiliary datasets
RHEXIS_PROJECT_LOC = f"{DRIVE_ROOT}/Stanford/Spring2022/RhexisProject/Rhexis"

# Specify the location of the dataset folder
# NOTE(review): unlike the checkpoint/annotation paths, this is NOT under RHEXIS_PROJECT_LOC —
# confirm the differing root is intended.
DATA_LOC = f"{DRIVE_ROOT}/Rhexis/datasets/test_pulls"

# Specify the location of the repo folder
REPO_LOC = f"{DRIVE_ROOT}/Stanford/rhexis-trajectory"

# Specify the path location of the best keypoint detection model checkpoint
BEST_MODEL_CHECKPOINT = (
  f"{RHEXIS_PROJECT_LOC}/checkpoints_BEST/BEST/"
  "keypoint_rcnn_X_101_32x8d_FPN_3x_LR_0.001_OKS_SIGMAS_0.03,0.03_LRDECAY_NONEaugmentation_YES/"
  "model_final.pth"
)

# (OPTIONAL) Specify manual keypoint annotations to use instead of predicted ones
MANUAL_ANNOTATION_LOC = f"{RHEXIS_PROJECT_LOC}/datasets/manual_trajectories"

In [None]:
# When True, run only on a single hand-picked test folder instead of the substring-based folder selection
test_mode: bool = False

In [None]:
# Make the repo's Semantic_Segmentation and Trajectory_Generation folders importable.
# Each folder is prepended only if it is not already on sys.path, so re-running this
# cell no longer keeps adding duplicate entries. Insert order matches the original:
# Trajectory_Generation ends up ahead of Semantic_Segmentation.
for _folder in ("Semantic_Segmentation", "Trajectory_Generation"):
  _path = f"{REPO_LOC}/{_folder}"
  if _path not in sys.path:
    sys.path.insert(0, _path)
import trajectory_utils as tutil

In [None]:
# Choose which dataset sub-folders the pipeline below runs on.
if test_mode:
  # AC4_rhexis: folder with a complete set of labels
  # (SQ13_rhexis is a good double-instrument test case)
  subdir_list = ['AC4_rhexis']
else:
  # Folder-name prefixes by experience level:
  #   Attendings: 'CataractCoach', 'SQ'
  #   Senior residents: 'KY', 'AC'
  #   Junior residents: 'Medi' (these are long)
  substring_list = ['AC1', 'CataractCoach1']  # e.g. ['Medi_08.18']

  subdir_list = []
  for substring in substring_list:
    for folder in tutil.get_folders_from_substrings(DATA_LOC, substring):
      # A folder may match more than one substring; keep only its first
      # occurrence (order preserved) so it is not processed twice.
      if folder not in subdir_list:
        subdir_list.append(folder)

print("Code will run on these folders:")
print(subdir_list)

# Segmentation Generation

In [None]:
# One-time step: set to True to (re)generate the semantic-segmentation labels.
# Progress note: started but did not finish 4 (NOTE(review): unclear what "4" refers to — confirm).
generate_labels: bool = False

In [None]:
if generate_labels:
  base = os.getcwd()
  cwd = os.path.join(REPO_LOC, "Semantic_Segmentation")
  # Run from the Semantic_Segmentation folder (its code presumably uses paths relative
  # to it) and always restore the original working directory, even if segmentation
  # raises part-way. os.chdir replaces the non-standard `% cd $dir` magic form.
  os.chdir(cwd)
  try:
    # semantic_segmentation is the only name this cell needs from utils
    from utils import semantic_segmentation
    if test_mode:
      semantic_segmentation(DATA_LOC, subdir_list, task = 2, test_mode = test_mode, use_image_subdir = False)
    else:
      # Split the folders into roughly equal parts of ~3 folders each to go easy on RAM.
      # max(1, ...) avoids a ZeroDivisionError in divmod when fewer than 2 folders are selected.
      num_parts = max(1, round(len(subdir_list) / 3))
      k, m = divmod(len(subdir_list), num_parts)
      part_list = [subdir_list[i*k+min(i, m):(i+1)*k+min(i+1, m)] for i in range(num_parts)]

      for part in part_list:
        # Only possible when no folders were selected at all
        if not part:
          continue
        semantic_segmentation(DATA_LOC, part, task = 2, test_mode = test_mode, use_image_subdir = False)
  finally:
    os.chdir(base)

# Feature Extraction

In [None]:
# When True, run trajectory generation on the selected folders using the keypoint-detection checkpoint
generate_trajectories: bool = True

In [None]:
if generate_trajectories:
  base = os.getcwd()
  # Import and run from the repo root (trajectory_generation_functions is presumably
  # resolved via the current directory). Always return to the original working
  # directory afterwards, even if trajectory generation raises; os.chdir replaces
  # the non-standard `% cd $base` magic form.
  os.chdir(REPO_LOC)
  try:
    import trajectory_generation_functions as tgf
    tgf.generate_trajectories(DATA_LOC, subdir_list, BEST_MODEL_CHECKPOINT, use_image_folder = False, MANUAL_ANNOTATION_LOC=MANUAL_ANNOTATION_LOC)
  finally:
    os.chdir(base)