# Feature Extraction for Graph-based Surgical Scene Understanding

### Project: Learning Domain Generalization with Graph Neural Network for Surgical Scene Understanding. Lab: MMLAB, National University of Singapore. Contributors: Lalith, Mobarak.

## Features used by Graph Neural Network

In [4]:
# data = {}

# data['global_id'] = global_id                                       # image name
# data['img_name']     = global_id + '.jpg'                           # image name
# data['node_num']    = single_app_data['node_num'].value             # total node number

# data['roi_labels']      = single_app_data['classes'][:]             # node labels
# data['edge_labels'] = single_app_data['edge_labels'][:]             # edge  labels

# data['det_boxes'] = single_app_data['boxes'][:]                     # box
# data['roi_scores'] = single_app_data['scores'][:]                   # detection score

# data['edge_num']    = data['edge_labels'].shape[0]                  # edge number
# data['features']        = single_app_data['feature'][:]             # features
# data['spatial_feat'] = single_spatial_data[:]                       # spatial features

# data['word2vec']     = self._get_word2vec(data['roi_labels'])       # word2vec, from roi_labels

## Utils

In [1]:
import numpy as np


def center_offset(box1, box2, im_wh):
    '''
    Offset between the centres of two boxes, normalised by the image size.

    Args:
        box1, box2: [x1, y1, x2, y2] box corners.
        im_wh: [W, H]; the x offset is divided by im_wh[0], the y offset by im_wh[1].

    Returns:
        np.ndarray of shape (2,): [(cx1 - cx2) / W, (cy1 - cy2) / H].
    '''
    c1 = np.array([(box1[2] + box1[0]) / 2, (box1[3] + box1[1]) / 2])
    c2 = np.array([(box2[2] + box2[0]) / 2, (box2[3] + box2[1]) / 2])
    # The difference must be parenthesised: `c1 - c2 / im_wh` would only
    # normalise c2 and leave c1 in raw pixels.
    return (c1 - c2) / np.array(im_wh)


def box_with_respect_to_img(box, im_wh):
    '''
    Normalise a box by the image size.

    Args:
        box: [x1, y1, x2, y2] box corners.
        im_wh: [W, H] used as the normalising width and height.

    Returns:
        list of 5 values [x1/W, y1/H, x2/W, y2/H, A_box/A_img]; a 1e-6
        epsilon is added to every denominator to avoid division by zero.
    '''
    eps = 1e-6
    width, height = im_wh[0], im_wh[1]
    x1, y1, x2, y2 = box[0], box[1], box[2], box[3]
    ratios = [x1 / (width + eps), y1 / (height + eps),
              x2 / (width + eps), y2 / (height + eps)]
    area_ratio = (x2 - x1) * (y2 - y1) / (width * height + eps)
    return ratios + [area_ratio]


def box1_with_respect_to_box2(box1, box2):
    '''
    Describe box1 relative to box2.

    Args:
        box1, box2: [x1, y1, x2, y2] box corners.

    Returns:
        list of 4 values:
          [(x1_1 - x1_2) / w2, (y1_1 - y1_2) / h2, log(w1 / w2), log(h1 / h2)]
        where w/h are box widths/heights and w2/h2 carry a 1e-6 epsilon.
    '''
    eps = 1e-6
    w1 = box1[2] - box1[0]
    h1 = box1[3] - box1[1]
    w2 = box2[2] - box2[0] + eps
    h2 = box2[3] - box2[1] + eps
    feats = [(box1[0] - box2[0]) / w2,
             (box1[1] - box2[1]) / h2,
             # Clamp the numerator so a zero-width/height box1 yields a large but
             # finite value instead of log(0) = -inf; positive sizes are unchanged.
             np.log(max(w1, eps) / w2),
             np.log(max(h1, eps) / h2)]
    return feats


def calculate_spatial_feats(det_boxes, im_wh):
    '''
    Stand-alone extraction of pairwise spatial features.

    For every ordered pair (i, j) with i != j, the feature row is
        box_i wrt image (5) + box_j wrt image (5) + box_i wrt box_j (4)
        + centre offset (2) = 16 values.

    Args:
        det_boxes: array of shape (node, 4) with [x1, y1, x2, y2] rows.
        im_wh: [W, H] passed through to the helpers.

    Returns:
        np.ndarray of shape (node * (node - 1), 16); (0, 16) when there are
        fewer than two boxes.
    '''
    num_boxes = det_boxes.shape[0]
    # The image-relative features depend on one box only: compute them once
    # per box instead of once per pair.
    wrt_img = [box_with_respect_to_img(det_boxes[k], im_wh) for k in range(num_boxes)]

    spatial_feats = []
    for i in range(num_boxes):
        for j in range(num_boxes):
            if j == i:
                continue
            spatial_feats.append(
                wrt_img[i] + wrt_img[j]
                + box1_with_respect_to_box2(det_boxes[i], det_boxes[j])
                + center_offset(det_boxes[i], det_boxes[j], im_wh).tolist())

    # reshape keeps the result 2-D even when no pairs exist
    return np.array(spatial_feats).reshape(-1, 16)

## Instrument Segmentation Challenge dataset

In [None]:
import os
import sys
import cv2
import h5py
import argparse

import torch
import torchvision.models

import numpy as np
from PIL import Image
from glob import glob
from models.resnet import *

if sys.version_info[0] == 2: import xml.etree.cElementTree as ET
else: import xml.etree.ElementTree as ET

# input data and IO folder location
mlist = [1, 2, 3, 4, 5, 6, 7, 9, 10, 11, 12, 14, 15, 16]  # sequence ids (8 and 13 are not included)
dir_root_gt = '../datasets/instruments18/seq_'

# Gather every annotation XML of the selected sequences, sequence by sequence
# (within one sequence the order is whatever glob returns).
xml_dir_list = [
    xml_path
    for seq_id in mlist
    for xml_path in glob(dir_root_gt + str(seq_id) + '/xml/' + '/*.xml')
]
    
# global variables
# NOTE: tuple order is significant. The extraction loop stores the
# `.index(...)` position of each annotated name: instrument ids go to the
# 'classes' dataset, and action ids select the one-hot column in 'edge_labels'.
# Index 0 ('kidney') is the class that the loop counts as tissue
# (`det_classes == 0`).
INSTRUMENT_CLASSES = ('kidney', 'bipolar_forceps', 'prograsp_forceps', 'large_needle_driver',
                      'monopolar_curved_scissors', 'ultrasound_probe', 'suction', 'clip_applier',
                      'stapler', 'maryland_dissector', 'spatulated_monopolar_cautery')

# Interaction names; position = action id (column of the edge label matrix).
ACTION_CLASSES = (  'Idle', 'Grasping', 'Retraction', 'Tissue_Manipulation', 
                    'Tool_Manipulation', 'Cutting', 'Cauterization',
                    'Suction', 'Looping', 'Suturing', 'Clipping', 'Staple', 'Ultrasound_Sensing')

# Converts each resized ROI crop (a NumPy image from the RGB frame) into a
# torch tensor via torchvision's ToTensor; no normalisation is applied.
transform = torchvision.transforms.Compose([
                    torchvision.transforms.ToTensor(),
                    ])

# arguments
def _str2bool(value):
    '''
    Parse a command-line boolean.

    argparse's `type=bool` turns any non-empty string (including "False")
    into True, so booleans are parsed explicitly here.
    '''
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='feature extractor')
parser.add_argument('--use_cbs',            type=_str2bool, default=True,        help='use CBS')
parser.add_argument('--std',                type=float,     default=1.0,         help='')
parser.add_argument('--std_factor',         type=float,     default=0.9,         help='')
parser.add_argument('--cbs_epoch',          type=int,       default=5,           help='')
parser.add_argument('--kernel_size',        type=int,       default=3,           help='')
parser.add_argument('--fil1',               type=str,       default='LOG',       help='gau, LOG')
parser.add_argument('--fil2',               type=str,       default='gau',       help='gau, LOG')
parser.add_argument('--fil3',               type=str,       default='gau',       help='gau, LOG')

# Alternative configuration: 9-class model (swap with the 11-class block below)
#parser.add_argument('--savedir',            type=str,       default='vsgat/resnet18_09_cbs_ls')
#parser.add_argument('--num_classes',        type=int,       default=9,           help='11')
#parser.add_argument('--modelpath',           type=str,       default='checkpoint/base/ResNet18_cbs_ls_0_012345678.pkl')

# Active configuration: 11-class model
parser.add_argument('--savedir',            type=str,       default='vsgat/resnet18_11_cbs_ts_test')
parser.add_argument('--num_classes',        type=int,       default=11,           help='11')
parser.add_argument('--modelpath',          type=str,       default='checkpoint/incremental/inc_ResNet18_cbs_ts_0_012345678910.pkl')

# Alternative configuration: vanilla model
# parser.add_argument('--savedir',            type=str,       default='vsgat/resnet18_vanilla')

# Parse an empty argv so the notebook cell always uses the defaults above.
args = parser.parse_args(args=[])

# declare feature extraction model
# `nn` is otherwise only in scope through `from models.resnet import *`.
import torch.nn as nn

vanilla_model = False
if vanilla_model:
    feature_network = ResNet18_vanilla()
else:
    feature_network = ResNet18(args)

# Set data parallel based on GPU
num_gpu = torch.cuda.device_count()
if num_gpu > 0:
    device_ids = np.arange(num_gpu).tolist()
    feature_network = nn.DataParallel(feature_network, device_ids=device_ids)

# Load the trained weights, then drop the trailing child module(s) for feature
# extraction: two children are removed when CBS is used, one otherwise.
# NOTE: `.module` only exists after the DataParallel wrap above, and `.cuda()`
# below is unconditional, so this cell requires at least one GPU.
if not vanilla_model:
    if args.use_cbs:
        feature_network.module.get_new_kernels(0)
    feature_network.load_state_dict(torch.load(args.modelpath))
    children = list(feature_network.module.children())
    feature_network = nn.Sequential(*(children[:-2] if args.use_cbs else children[:-1]))

# Use Cuda. eval() puts mode-dependent layers (e.g. BatchNorm/Dropout, if the
# model has them) into inference mode, so a single ROI's features do not
# depend on that crop's own batch statistics.
feature_network = feature_network.cuda()
feature_network.eval()

        
# Per-frame extraction. For every annotation XML: parse the objects, build
# edge labels, extract one feature vector per ROI plus pairwise spatial
# features, then write <sequence>/<args.savedir>/<frame>_features.hdf5.
for _xml_dir in xml_dir_list:
    img_name = os.path.basename(_xml_dir[:-4])  # strip '.xml'
    seq_dir = os.path.dirname(os.path.dirname(_xml_dir))
    _img_dir = seq_dir + '/left_frames/' + img_name + '.png'
    save_data_path = os.path.join(seq_dir, args.savedir)

    if not os.path.exists(save_data_path):
        os.makedirs(save_data_path)

    _xml = ET.parse(_xml_dir).getroot()

    # node class ids, per-node interaction ids and [xmin, ymin, xmax, ymax] boxes
    det_classes = []
    act_classes = []
    det_boxes_all = []
    for obj in _xml.iter('objects'):
        name = obj.find('name').text.strip()
        interact = obj.find('interaction').text.strip()
        det_classes.append(INSTRUMENT_CLASSES.index(str(name)))
        act_classes.append(ACTION_CLASSES.index(str(interact)))

        bbox = obj.find('bndbox')
        det_boxes_all.append(np.array([int(bbox.find(pt).text)
                                       for pt in ('xmin', 'ymin', 'xmax', 'ymax')]))

    # Class 0 is counted as tissue. Tissue node t gets an edge to every later
    # node t+1 .. node_num-1. Tissue nodes are addressed as indices
    # 0..tissue_num-1, i.e. they are assumed to come first in the XML
    # (TODO confirm for new annotations).
    tissue_num = det_classes.count(0)
    node_num = len(det_classes)
    edges = sum(node_num - t - 1 for t in range(tissue_num))

    # one-hot edge labels: the action annotated on the edge's second node
    edge_labels = np.zeros((edges, len(ACTION_CLASSES)))
    edge_index = 0
    for tissue in range(tissue_num):
        for obj_index in range(tissue + 1, node_num):
            # Index by obj_index. `tissue_num + edge_index` only equals it when
            # tissue_num == 1; otherwise it mislabels or runs past act_classes.
            edge_labels[edge_index, act_classes[obj_index]] = 1
            edge_index += 1

    # node features: crop each box, resize to 224x224, run feature_network
    node_features = np.zeros((node_num, 512))
    with Image.open(_img_dir) as pil_img:
        _img = np.array(pil_img.convert('RGB'))
    with torch.no_grad():  # inference only: no autograd graph per ROI
        for idx, bndbox in enumerate(det_boxes_all):
            roi = np.array(bndbox).astype(int)
            # boxes are treated as inclusive pixel ranges, hence the +1
            roi_image = _img[roi[1]:roi[3] + 1, roi[0]:roi[2] + 1, :]
            roi_image = transform(cv2.resize(roi_image, (224, 224), interpolation=cv2.INTER_LINEAR))
            feature = feature_network(roi_image.unsqueeze(0).cuda())
            node_features[idx] = feature.reshape(-1).cpu().numpy()

    # spatial_features
    # NOTE(review): im_wh is [1024, 1280], but the helpers treat im_wh[0] as
    # the width. If frames are 1280 wide x 1024 high this looks swapped;
    # confirm before changing, since it alters every saved spatial feature.
    spatial_features = calculate_spatial_feats(np.array(det_boxes_all), [1024, 1280])

    # save to file (the context manager closes it even if a write fails)
    with h5py.File(os.path.join(save_data_path, '{}_features.hdf5'.format(img_name)), 'w') as hdf5_file:
        hdf5_file.create_dataset('img_name', data=img_name)
        hdf5_file.create_dataset('node_num', data=node_num)
        hdf5_file.create_dataset('classes', data=det_classes)
        hdf5_file.create_dataset('boxes', data=det_boxes_all)
        hdf5_file.create_dataset('edge_labels', data=edge_labels)
        hdf5_file.create_dataset('node_features', data=node_features)
        hdf5_file.create_dataset('spatial_features', data=spatial_features)
    print('edges', edge_labels.shape, 'node_feat', node_features.shape, 'spatial_feat', spatial_features.shape)

## SGH TORS Dataset

In [None]:
import os
import sys
import cv2
import h5py
import argparse

import torch
import torchvision.models

import numpy as np
from PIL import Image
from glob import glob

if sys.version_info[0] == 2: import xml.etree.cElementTree as ET
else: import xml.etree.ElementTree as ET

# input data and IO folder location
mlist = list(range(1, 23))  # sequences 1..22
dir_root_gt = '../datasets/SGH_dataset_2020/'

# Gather every annotation XML, sequence by sequence.
xml_dir_list = []
for seq_id in mlist:
    xml_dir_list.extend(glob(dir_root_gt + str(seq_id) + '/xml/' + '/*.xml'))
    
# global variables
# NOTE: tuple order is significant. The extraction loop stores the
# `.index(...)` position of each annotated name: instrument ids go to the
# 'classes' dataset, and action ids select the one-hot column in 'edge_labels'.
# In this dataset index 0 is 'tissue' (instead of 'kidney'); it is the class
# the loop counts as tissue (`det_classes == 0`).
INSTRUMENT_CLASSES = ('tissue', 'bipolar_forceps', 'prograsp_forceps', 'large_needle_driver',
                      'monopolar_curved_scissors', 'ultrasound_probe', 'suction', 'clip_applier',
                      'stapler', 'maryland_dissector', 'spatulated_monopolar_cautery')

# Interaction names; position = action id (column of the edge label matrix).
ACTION_CLASSES = (  'Idle', 'Grasping', 'Retraction', 'Tissue_Manipulation', 
                    'Tool_Manipulation', 'Cutting', 'Cauterization',
                    'Suction', 'Looping', 'Suturing', 'Clipping', 'Staple', 'Ultrasound_Sensing')

# Converts each resized ROI crop (a NumPy image from the RGB frame) into a
# torch tensor via torchvision's ToTensor; no normalisation is applied.
transform = torchvision.transforms.Compose([
                    torchvision.transforms.ToTensor(),
                    ])

# arguments
def _str2bool(value):
    '''
    Parse a command-line boolean.

    argparse's `type=bool` turns any non-empty string (including "False")
    into True, so booleans are parsed explicitly here.
    '''
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got {!r}'.format(value))


parser = argparse.ArgumentParser(description='feature extractor')
parser.add_argument('--use_cbs',            type=_str2bool, default=True,        help='use CBS')
parser.add_argument('--std',                type=float,     default=1.0,         help='')
parser.add_argument('--std_factor',         type=float,     default=0.9,         help='')
parser.add_argument('--cbs_epoch',          type=int,       default=5,           help='')
parser.add_argument('--kernel_size',        type=int,       default=3,           help='')
parser.add_argument('--fil1',               type=str,       default='LOG',       help='gau, LOG')
parser.add_argument('--fil2',               type=str,       default='gau',       help='gau, LOG')
parser.add_argument('--fil3',               type=str,       default='gau',       help='gau, LOG')

# Alternative configuration: 9-class model (swap with the 11-class block below)
#parser.add_argument('--savedir',            type=str,       default='vsgat/resnet18_09_cbs_ls')
#parser.add_argument('--num_classes',        type=int,       default=9,           help='11')
#parser.add_argument('--modelpath',           type=str,       default='checkpoint/base/ResNet18_cbs_ls_0_012345678.pkl')

# Active configuration: 11-class model
parser.add_argument('--savedir',            type=str,       default='vsgat/resnet18_11_cbs_ts')
parser.add_argument('--num_classes',        type=int,       default=11,           help='11')
parser.add_argument('--modelpath',          type=str,       default='checkpoint/incremental/inc_ResNet18_cbs_ts_0_012345678910.pkl')

# Alternative configuration: vanilla model
#parser.add_argument('--savedir',            type=str,       default='vsgat/resnet18_vanilla')

# Parse an empty argv so the notebook cell always uses the defaults above.
args = parser.parse_args(args=[])
    
# Feature extraction network
# NOTE: ResNet18 / ResNet18_vanilla are not imported in this cell; they come
# from `from models.resnet import *` in the earlier extraction cell, which must
# have been run first.
import torch.nn as nn

vanilla_model = False
if vanilla_model:
    feature_network = ResNet18_vanilla()
else:
    feature_network = ResNet18(args)

# Use dataparallel for GPU
num_gpu = torch.cuda.device_count()
if num_gpu > 0:
    device_ids = np.arange(num_gpu).tolist()
    feature_network = nn.DataParallel(feature_network, device_ids=device_ids)

# Load the trained weights, then drop the trailing child module(s) for feature
# extraction: two children are removed when CBS is used, one otherwise.
# NOTE: `.module` only exists after the DataParallel wrap above, and `.cuda()`
# below is unconditional, so this cell requires at least one GPU.
if not vanilla_model:
    if args.use_cbs:
        feature_network.module.get_new_kernels(0)
    feature_network.load_state_dict(torch.load(args.modelpath))
    children = list(feature_network.module.children())
    feature_network = nn.Sequential(*(children[:-2] if args.use_cbs else children[:-1]))

# Use Cuda. eval() puts mode-dependent layers (e.g. BatchNorm/Dropout, if the
# model has them) into inference mode, so a single ROI's features do not
# depend on that crop's own batch statistics.
feature_network = feature_network.cuda()
feature_network.eval()


# Per-frame extraction. For every annotation XML: parse the objects, build
# edge labels, extract one feature vector per ROI plus pairwise spatial
# features, then write <sequence>/<args.savedir>/<frame>_features.hdf5.
for _xml_dir in xml_dir_list:
    img_name = os.path.basename(_xml_dir[:-4])  # strip '.xml'
    seq_dir = os.path.dirname(os.path.dirname(_xml_dir))
    _img_dir = seq_dir + '/resized_frames/' + img_name + '.png'
    save_data_path = os.path.join(seq_dir, args.savedir)
    if not os.path.exists(save_data_path):
        os.makedirs(save_data_path)

    _xml = ET.parse(_xml_dir).getroot()

    # node class ids, per-node interaction ids and [xmin, ymin, xmax, ymax] boxes
    det_classes = []
    act_classes = []
    det_boxes_all = []
    for obj in _xml.iter('objects'):
        name = obj.find('name').text.strip()
        interact = obj.find('interaction').text.strip()
        det_classes.append(INSTRUMENT_CLASSES.index(str(name)))
        act_classes.append(ACTION_CLASSES.index(str(interact)))

        bbox = obj.find('bndbox')
        det_boxes_all.append(np.array([int(bbox.find(pt).text)
                                       for pt in ('xmin', 'ymin', 'xmax', 'ymax')]))

    # Class 0 is counted as tissue. Tissue node t gets an edge to every later
    # node t+1 .. node_num-1. Tissue nodes are addressed as indices
    # 0..tissue_num-1, i.e. they are assumed to come first in the XML
    # (TODO confirm for new annotations).
    tissue_num = det_classes.count(0)
    node_num = len(det_classes)
    edges = sum(node_num - t - 1 for t in range(tissue_num))

    # one-hot edge labels: the action annotated on the edge's second node
    edge_labels = np.zeros((edges, len(ACTION_CLASSES)))
    edge_index = 0
    for tissue in range(tissue_num):
        for obj_index in range(tissue + 1, node_num):
            # Index by obj_index. `tissue_num + edge_index` only equals it when
            # tissue_num == 1; otherwise it mislabels or runs past act_classes.
            edge_labels[edge_index, act_classes[obj_index]] = 1
            edge_index += 1

    # roi features extraction
    # node features: crop each box, resize to 224x224, run feature_network
    node_features = np.zeros((node_num, 512))
    with Image.open(_img_dir) as pil_img:
        _img = np.array(pil_img.convert('RGB'))
    with torch.no_grad():  # inference only: no autograd graph per ROI
        for idx, bndbox in enumerate(det_boxes_all):
            roi = np.array(bndbox).astype(int)
            # boxes are treated as inclusive pixel ranges, hence the +1
            roi_image = _img[roi[1]:roi[3] + 1, roi[0]:roi[2] + 1, :]
            roi_image = transform(cv2.resize(roi_image, (224, 224), interpolation=cv2.INTER_LINEAR))
            feature = feature_network(roi_image.unsqueeze(0).cuda())
            node_features[idx] = feature.reshape(-1).cpu().numpy()

    # spatial_features
    # NOTE(review): im_wh is [1024, 1280] for these resized frames; the helpers
    # treat im_wh[0] as the width. Confirm it matches the resized frame size
    # before changing, since it alters every saved spatial feature.
    spatial_features = calculate_spatial_feats(np.array(det_boxes_all), [1024, 1280])

    # save to file (the context manager closes it even if a write fails)
    with h5py.File(os.path.join(save_data_path, '{}_features.hdf5'.format(img_name)), 'w') as hdf5_file:
        hdf5_file.create_dataset('img_name', data=img_name)
        hdf5_file.create_dataset('node_num', data=node_num)
        hdf5_file.create_dataset('classes', data=det_classes)
        hdf5_file.create_dataset('boxes', data=det_boxes_all)
        hdf5_file.create_dataset('edge_labels', data=edge_labels)
        hdf5_file.create_dataset('node_features', data=node_features)
        hdf5_file.create_dataset('spatial_features', data=spatial_features)
    print('edges', edge_labels.shape, 'node_feat', node_features.shape, 'spatial_feat', spatial_features.shape)

## Word2Vec Features

In [None]:
import os
import h5py
import gensim

# Load Google's pre-trained Word2Vec model.
model = gensim.models.KeyedVectors.load_word2vec_format('../datasets/word2vec/GoogleNews-vectors-negative300.bin', binary=True)
# gensim >= 4.0 removed `KeyedVectors.vocab`; `index_to_key` lists the same
# words in vocabulary order. Fall back to `vocab` on older gensim.
if hasattr(model, 'index_to_key'):
    original_keys = list(model.index_to_key)
else:
    original_keys = list(model.vocab.keys())
# Upper-cased copy used for case-insensitive lookups (same order as original_keys).
upper_keys = [str.upper(x) for x in original_keys]

# class names
# NOTE: this tuple differs from the ones in the extraction cells. It has 12
# entries, with 'kidney' at index 0 and 'tissue' appended at index 11. Each
# name becomes an HDF5 dataset key in the word2vec file written below.
INSTRUMENT_CLASSES = ('kidney', 'bipolar_forceps', 'prograsp_forceps', 'large_needle_driver',
                      'monopolar_curved_scissors', 'ultrasound_probe', 'suction', 'clip_applier',
                      'stapler', 'maryland_dissector', 'spatulated_monopolar_cautery', 'tissue')

# Word2Vec vocabulary word used for the class at the same position above;
# the two sequences must stay aligned and equally long.
instrument_class_to_w2v = ['kidney', 'bipolar', 'grasp', 'needle', 
                           'scissors', 'ultrasound', 'suction', 'clipper', 
                           'stapler', 'dissector', 'cautery','tissue']

# Output HDF5 file: one dataset per class name holding its Word2Vec vector.
hico_word2vec = os.path.join('../datasets/','surgicalscene_word2vec.hdf5')

# zip() would silently truncate on a length mismatch; fail loudly instead.
if len(INSTRUMENT_CLASSES) != len(instrument_class_to_w2v):
    raise ValueError('INSTRUMENT_CLASSES and instrument_class_to_w2v must have the same length')

with h5py.File(hico_word2vec, 'w') as w2v_file:
    for name, w2v_word in zip(INSTRUMENT_CLASSES, instrument_class_to_w2v):
        print(name, ':', w2v_word)
        if name == '':
            continue
        # Case-insensitive lookup; list.index returns the first match in
        # vocabulary order and raises ValueError if the word is missing.
        index = upper_keys.index(str.upper(w2v_word))
        data = model[original_keys[index]]
        print(data.shape)
        w2v_file.create_dataset(name, data=data)