<div align="center">
<a href="http://camma.u-strasbg.fr/">
<img src="lib/camma_logo.png" width="18%">
</a>
</div>



Weakly-supervised ConvLSTM Surgical Tool Tracker
================
------
**A re-implementation of the surgical tool tracker in** :<br>
<i>Nwoye, C. I., Mutter, D., Marescaux, J., & Padoy, N. (2019). 
    Weakly supervised convolutional LSTM approach for tool tracking in laparoscopic videos. 
    International journal of computer assisted radiology and surgery, 14(6), 1059-1067.<br></i>
(c) Research Group CAMMA, University of Strasbourg, France<br>
Website: http://camma.u-strasbg.fr<br>
Code author: Chinedu Nwoye <br>
    
-----

The model is built using the `tf.contrib` lib. Hence, TensorFlow version > 1.15 is discouraged.

<br> Download code and libraries

In [1]:
# !git clone https://github.com/CAMMA-public/ConvLSTM-Surgical-Tool-Tracker.git
# %cd ConvLSTM-Surgical-Tool-Tracker

# print("Repo cloned and extracted ...")

%cd ConvLSTM-Surgical-Tool-Tracker

[WinError 2] The system cannot find the file specified: 'ConvLSTM-Surgical-Tool-Tracker'
c:\Users\liams\Documents\GitHub\ProjectSurgeryHernia\ConvLSTM-Surgical-Tool-Tracker


<br> Download sample video data

In [2]:
import requests
import os
import zipfile

url = 'https://s3.unistra.fr/camma_public/github/convlstm_tracker/data.zip'
filename = 'data.zip'
print("Downloading data file...")
# Check if data directory already exists with files
if os.path.exists('data') and any(os.listdir('data')):
    print("Data files already exist locally. Skipping download.")
else:
    response = requests.get(url, stream=True)
    with open(filename, 'wb') as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)

    print("Download completed. Extracting files...")
    with zipfile.ZipFile(filename, 'r') as zip_ref:
        zip_ref.extractall('.')
        
    print("Download and extraction completed ...")

Downloading data file...
Data files already exist locally. Skipping download.


<br>Download model weights

In [3]:
import os

print("Downloading model weights...")
response = requests.get('https://s3.unistra.fr/camma_public/github/convlstm_tracker/ckpt.zip', stream=True)
# Check if checkpoint directory already exists with files
if os.path.exists('ckpt') and any(file.endswith('.index') for file in os.listdir('ckpt') if os.path.isfile(os.path.join('ckpt', file))):
    print("Checkpoint files already exist locally. Skipping download.")
else:
    with open('ckpt.zip', 'wb') as file:
        for chunk in response.iter_content(chunk_size=8192):
            file.write(chunk)

    print("Download completed. Extracting files...")
    with zipfile.ZipFile('ckpt.zip', 'r') as zip_ref:
        zip_ref.extractall('.')

    print("Download and extraction completed ...")

Downloading model weights...
Checkpoint files already exist locally. Skipping download.


<br> Some important installationns

In [4]:
if 'google.colab' in str(get_ipython()):  # colab installs tf.2.2 on default.
    !pip uninstall -y tensorflow
    !pip install tensorflow-gpu==1.14
!pip install imageio
!pip install imageio-ffmpeg

print("Installations completed ...")

Installations completed ...


<br> Imports

In [5]:

from tf_compat import tf


import model
import os
import numpy as np
import cv2
import imageio
import sys
from matplotlib import animation, rc, pyplot as plt
plt.rcParams['animation.ffmpeg_path'] = '/usr/bin/ffmpeg'
from IPython.display import HTML
tf.logging.set_verbosity(tf.logging.ERROR)
%matplotlib inline

print("imports success...")

Instructions for updating:
non-resource variables are not supported in the long term
imports success...


<br> Variables & Device setup

In [6]:
img_height   = 480 #@param {type:"integer"}
img_width    = 854 #@param {type:"integer"}
img_channel  = 3   #@param {type:"integer"}
num_classes  = 7   #@param {type:"integer"}
offset_x     = 20  #@param {type:"integer"}
offset_y     = 11  #@param {type:"integer"}
data_path    = 'data/surgical_video.avi' #@param {type:"string"} you can modify this if you evaluate on a different video
ckpt_path    = 'ckpt' #@param {type:"string"}

print("Model and device variables set .. ")

Model and device variables set .. 


<br> Model architecture

In [7]:
tf.reset_default_graph()
with tf.device('/GPU:0'):  
    img_ph  = tf.placeholder(dtype=tf.float32, shape=[None,None,3], name='inputs')
    x       = tf.expand_dims(img_ph, 0)   
    x       = tf.image.resize_bilinear(x, size=(480,854))             
    seek_ph = tf.placeholder(dtype=tf.int64, shape=[None], name='inputs')
    network = model.Model(images=x, seek=seek_ph, num_classes=num_classes)
    logits, lhmaps  = network.build_model() 
    logits  = tf.cast(tf.round(tf.sigmoid(logits)), tf.int32)
    lhmaps  = lhmaps * tf.cast(logits, tf.float32)

print("Model loaded successfully...")

Model blocks:  [2, 2, 2, 2]
	Receiving image:: (1, 480, 854, 3)
Constructing ResNet backbone:
	Building units: conv1 -> (1, 120, 214, 64)
	Building unit: conv2_1: (1, 120, 214, 64)
	Building unit: conv2_2: (1, 120, 214, 64)
	Building unit: conv3_1: (1, 120, 214, 64)
	Building unit: conv3_2: (1, 60, 107, 128)
	Building unit: conv4_1: (1, 60, 107, 128)
	Building unit: conv4_2: (1, 60, 107, 256)
	Building unit: conv5_1: (1, 60, 107, 256)
	Building unit: conv5_2: (1, 60, 107, 512)
	Building units: ExtraNet/spatio-temporal/convlstm/convlstm: -> (1, 60, 107, 512)
	Building units: ExtraNet/FCN: -> (1, 60, 107, 7)
Model loaded successfully...


  kernel_initializer=tf.contrib.layers.variance_scaling_initializer())


<br> Saver and weights

In [8]:
with tf.name_scope("saver_and_writer"):                  
    saver = tf.train.Saver()  
    state = tf.train.get_checkpoint_state(ckpt_path)
    ckpt  = state.model_checkpoint_path

print('Loading checkpoint from :',ckpt)

Loading checkpoint from : ckpt\model.ckpt


<br> Evaluate on video dataset

In [9]:
from tqdm import tqdm

PREDICTIONS    = []
CLASS_LHMAPS   = []
reader         = imageio.get_reader(data_path)
gpu_opts = tf.GPUOptions(
    allow_growth              = True,   # grow memory instead of pre-allocating all
    per_process_gpu_memory_fraction = 0.95
)
sess_config = tf.ConfigProto(
    gpu_options          = gpu_opts,
    allow_soft_placement = False,       # hard-fail if an op can’t sit on GPU
    log_device_placement = False
)
with tf.Session(config=sess_config) as sess:   
    sess.run([tf.local_variables_initializer(), tf.global_variables_initializer()])
    saver.restore(sess, ckpt)
    for seek, frame in enumerate(tqdm(reader, desc="Processing video frames")):
        predict, lhmap = sess.run([logits, lhmaps], feed_dict={img_ph:frame, seek_ph:[seek]})
        PREDICTIONS.append(predict)
        CLASS_LHMAPS.append(lhmap)
        
print("Evaluation done...")   

InvalidArgumentError: Graph execution error:

Detected at node 'ExpandDims' defined at (most recent call last):
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\runpy.py", line 193, in _run_module_as_main
      "__main__", mod_spec)
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\runpy.py", line 85, in _run_code
      exec(code, run_globals)
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\ipykernel_launcher.py", line 17, in <module>
      app.launch_new_instance()
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\traitlets\config\application.py", line 1043, in launch_instance
      app.start()
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\ipykernel\kernelapp.py", line 712, in start
      self.io_loop.start()
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\tornado\platform\asyncio.py", line 215, in start
      self.asyncio_loop.run_forever()
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\asyncio\base_events.py", line 541, in run_forever
      self._run_once()
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\asyncio\base_events.py", line 1786, in _run_once
      handle._run()
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\asyncio\events.py", line 88, in _run
      self._context.run(self._callback, *self._args)
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\ipykernel\kernelbase.py", line 510, in dispatch_queue
      await self.process_one()
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\ipykernel\kernelbase.py", line 499, in process_one
      await dispatch(*args)
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\ipykernel\kernelbase.py", line 406, in dispatch_shell
      await result
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\ipykernel\kernelbase.py", line 730, in execute_request
      reply_content = await reply_content
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\ipykernel\ipkernel.py", line 387, in do_execute
      cell_id=cell_id,
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\ipykernel\zmqshell.py", line 528, in run_cell
      return super().run_cell(*args, **kwargs)
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\IPython\core\interactiveshell.py", line 2975, in run_cell
      raw_cell, store_history, silent, shell_futures, cell_id
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\IPython\core\interactiveshell.py", line 3029, in _run_cell
      return runner(coro)
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\IPython\core\async_helpers.py", line 78, in _pseudo_sync_runner
      coro.send(None)
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\IPython\core\interactiveshell.py", line 3257, in run_cell_async
      interactivity=interactivity, compiler=compiler, result=result)
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\IPython\core\interactiveshell.py", line 3472, in run_ast_nodes
      if (await self.run_code(code, result,  async_=asy)):
    File "c:\Users\liams\anaconda3\envs\Keras_env\lib\site-packages\IPython\core\interactiveshell.py", line 3552, in run_code
      exec(code_obj, self.user_global_ns, self.user_ns)
    File "C:\Users\liams\AppData\Local\Temp\ipykernel_10616\1984492305.py", line 4, in <module>
      x       = tf.expand_dims(img_ph, 0)
Node: 'ExpandDims'
Cannot assign a device for operation ExpandDims: {{node ExpandDims}} was explicitly assigned to /device:GPU:0 but available devices are [ /job:localhost/replica:0/task:0/device:CPU:0 ]. Make sure the device specification refers to a valid device. The requested device appears to be a GPU, but CUDA is not enabled.
	 [[ExpandDims]]

<br> Some visualization helper functions

In [None]:
# Get coordinates

def get_center_coordinates(lhmap):
    coord = np.where(lhmap == lhmap.max()) 
    cx    = (coord[1][0] * img_width // 107) + offset_x
    cy    = (coord[0][0] * img_height // 60) + offset_y
    return (cx, cy)

def get_box_coordinates(lhmap):
    coord = np.where(lhmap>0)
    if len(coord[0])>0 and len(coord[1])>0 :
        x0 = (coord[1].min() * img_width // 107) - offset_x
        x1 = (coord[1].max() * img_width // 107) + offset_x
        y0 = (coord[0].min() * img_height // 60) - offset_y
        y1 = (coord[0].max() * img_height // 60) + offset_y
    else:
        x0,x1,y0,y1 = -1,-1,-1,-1
    return (x0,y0,x1,y1)


# Build animators
def build_animators():
    BUFFER_BOX_CENTER = []
    colors    = [(255,0,0),(255,255,0),(0,0,255),(255,0,255),(255,128,0),(0,255,255),(0,255,0)] 
    radius    = 28
    thickness = 4
    reader    = imageio.get_reader(data_path)
    fig       = plt.figure()
    for k, (img, predict, lhmap) in enumerate(zip(reader, PREDICTIONS, CLASS_LHMAPS)):
        img_overlay     = img.copy()
        for i in range(num_classes):
            cam         = lhmap[0,:,:,i]
            x1,y1,x2,y2 = get_box_coordinates(cam)
            cx,cy       = get_center_coordinates(cam)
            color       = colors[i]
            cv2.rectangle(img_overlay, (x1,y1), (x2,y2), color, thickness)
            cv2.circle(img_overlay, (cx,cy), radius, color, -1)
        cv2.circle(img_overlay, (offset_x,offset_y), radius, (0,0,0), -1)
        BUFFER_BOX_CENTER.append([plt.imshow(img_overlay)])
    return fig, BUFFER_BOX_CENTER
        

# Colorizer
def cstr(s, color='black'):
    return "<text style=color:{}>{}</text>".format(color, s)

print("Model ready to track...")

<br>

#### Tracking the video
Build animator to display the tool trajectory (_Colormap displays the legend for the tracker_)

In [None]:
fig, OVERLAY = build_animators()

HTML('='*20+"> [  Tool Colormap:                                       "
           +cstr("Grasper", "red") +" | "+cstr("Bipolar", "yellow") +"  |  "+cstr("Hook", "blue")+"  |  "
           +cstr("Scissors", "violet")+"  |  " +cstr("Clipper", "orange") 
           +"  |  "+cstr("Irrigator", "mouve") +"  |  "+cstr("Specimen bag  ", "green")+'  ] <'+'='*20 )

<br> Let's track the instruments in the video<br>

In [None]:
anim = animation.ArtistAnimation(fig, OVERLAY, interval=160, blit=True, repeat_delay=1000)
HTML(anim.to_html5_video())

End