Adapt Scribble-OSVOS to DAVIS 2019 interactive challenge
kmaninis committed Apr 24, 2019
1 parent 1fccc4e commit 0126434
Showing 8 changed files with 128 additions and 59,879 deletions.
23 changes: 14 additions & 9 deletions analyze_report.py
@@ -1,27 +1,32 @@
""" Analyse Global Summary
"""
import os
import json

import matplotlib.pyplot as plt
from mypath import Path

METRIC_TXT = {'J': 'J',
'F': 'F',
'J_AND_F': 'J&F'}


def main():
with open('results/summary.json', 'r') as fp:
summary = json.load(fp)
metric = 'J_AND_F'

with open(os.path.join(Path.save_root_dir(), 'summary.json'), 'r') as f:
summary = json.load(f)

print('AUC: \t{:.3f}'.format(summary['auc']))
th = summary['jaccard_at_threshold']['threshold']
jac = summary['jaccard_at_threshold']['jaccard']
th = summary['metric_at_threshold']['threshold']
jac = summary['metric_at_threshold'][metric]
print('J@{}: \t{:.3f}'.format(th, jac))

time = summary['curve']['time']
jaccard = summary['curve']['jaccard']
jaccard = summary['curve'][metric]

plt.plot(time, jaccard)
plt.ylim([0, 1])
plt.xlim([0, max(time)])
plt.xlabel('Accumulated Time (s)')
plt.ylabel(r'Jaccard ($\mathcal{J}$)')
plt.ylabel(r'$\mathcal{' + METRIC_TXT[metric] + '}$')
plt.axvline(th, c='k')
plt.show()

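For context, the rewritten script reads the global summary produced by the interactive session. A minimal sketch of the layout it relies on (only the fields accessed above are shown; values are placeholders, and the exact schema comes from the davisinteractive package):

summary = {
    'auc': 0.0,                                   # area under the metric-vs-time curve
    'metric_at_threshold': {
        'threshold': 60,                          # evaluation time threshold in seconds
        'J': 0.0, 'F': 0.0, 'J_AND_F': 0.0        # metric values at that threshold
    },
    'curve': {
        'time': [0.0],                            # accumulated interaction time
        'J': [0.0], 'F': [0.0], 'J_AND_F': [0.0]  # metric value after each interaction
    }
}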
18 changes: 1 addition & 17 deletions dataloaders/custom_transforms.py
@@ -127,7 +127,7 @@ def __init__(self, scribbles,
nocare_area=None,
bresenham=True,
use_previous_mask=False,
previous_mask_path='/media/eec/external/Databases/Segmentation/DAVIS-2017/Results/OSVOS-scribble-180-1'):
previous_mask_path=None):

self.scribbles = scribbles
self.dilation = dilation
@@ -203,22 +203,6 @@ def __call__(self, sample):

sample['scribble_gt'] = scr_gt
sample['scribble_void_pixels'] = scr_nocare
# from matplotlib import pyplot as plt
# f, ax_arr = plt.subplots(2, 2)
# ax_arr[0, 0].cla()
# ax_arr[0, 1].cla()
# ax_arr[1, 0].cla()
# ax_arr[1, 1].cla()
# ax_arr[0, 0].set_title("Input Image")
# ax_arr[0, 1].set_title("Ground Truth")
# ax_arr[1, 0].set_title("Nocare")
# ax_arr[1, 1].set_title("Negative")
# ax_arr[0, 0].imshow((sample['image'][:, :, ::-1] - sample['image'].min()) / (sample['image'].max() - sample['image'].min()))
# ax_arr[0, 1].imshow(scr_gt)
# ax_arr[1, 0].imshow(scr_nocare)
# ax_arr[1, 1].imshow(scr_gt_neg)
# plt.show()
# plt.close('all')

return sample

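With the hard-coded default removed, previous_mask_path now has to be supplied by the caller whenever use_previous_mask is enabled. A rough usage sketch (the transform's class name is not visible in this hunk, so ScribblesToMask below is a placeholder; the keyword arguments follow the signature shown above):

transform = ScribblesToMask(scribbles,
                            nocare_area=None,
                            bresenham=True,
                            use_previous_mask=True,
                            previous_mask_path='/path/to/previous/results')
sample = transform(sample)  # adds 'scribble_gt' and 'scribble_void_pixels' to the sample dict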
125 changes: 77 additions & 48 deletions demo_interactive.py
@@ -1,56 +1,85 @@
import os

import torch, cv2
import os
import timeit

from davisinteractive.session import DavisInteractiveSession
from davisinteractive import utils as interactive_utils
from davisinteractive.dataset import Davis

from osvos_scribble import OsvosScribble
from osvos_scribble import OSVOSScribble
from mypath import Path

# General parameters
gpu_id = 0

# Interactive parameters
max_nb_interactions = 5
max_time = None # Maximum time for each interaction
subset = 'val'
host = 'localhost' # 'localhost' for subsets train and val.

# OSVOS parameters
time_budget_per_object = 60
parent_model = 'osvos_parent.pth'
prev_mask = True # Use previous mask as no-care area when fine-tuning

save_model_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'models')
report_save_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'results')
save_result_dir = report_save_dir # 'None' to not save the results

model = OsvosScribble(parent_model, save_model_dir, gpu_id, time_budget_per_object, save_result_dir=save_result_dir)

seen_seq = {}
with DavisInteractiveSession(host='localhost', davis_root=Path.db_root_dir(), subset=subset,
report_save_dir=report_save_dir, max_nb_interactions=max_nb_interactions,
max_time=max_time) as sess:
while sess.next():
# Get the current iteration scribbles
sequence, scribbles, first_scribble = sess.get_scribbles()
if first_scribble:
n_interaction = 1
n_objects = Davis.dataset[sequence]['num_objects']
first_frame = interactive_utils.scribbles.annotated_frames(scribbles)[0]
seen_seq[sequence] = 1 if sequence not in seen_seq.keys() else seen_seq[sequence]+1
else:
n_interaction += 1
pred_masks = []
print('\nRunning sequence {} in interaction {} and scribble iteration {}'
.format(sequence, n_interaction, seen_seq[sequence]))
for obj_id in range(1, n_objects+1):
model.train(first_frame, n_interaction, obj_id, scribbles, seen_seq[sequence], subset=subset, use_previous_mask=prev_mask)
pred_masks.append(model.test(sequence, n_interaction, obj_id, subset=subset, scribble_iter=seen_seq[sequence]))

final_masks = interactive_utils.mask.combine_masks(pred_masks)

# Submit your prediction
sess.submit_masks(final_masks)

def main():
# General parameters
gpu_id = 1

# Configuration used in the challenges
max_nb_interactions = 8 # Maximum number of interactions
max_time_per_interaction = 30 # Maximum time per interaction per object

# Total time available to interact with a sequence and an initial set of scribbles
max_time = max_nb_interactions * max_time_per_interaction # Maximum time per object

# Interactive parameters
subset = 'val'
host = 'localhost' # 'localhost' for subsets train and val.

# OSVOS parameters
time_budget_per_object = 20
parent_model = 'osvos_parent.pth'
prev_mask = True # Use previous mask as no-care area when fine-tuning

save_model_dir = Path.models_dir()
report_save_dir = Path.save_root_dir()
save_result_dir = report_save_dir

model = OSVOSScribble(parent_model, save_model_dir, gpu_id, time_budget_per_object,
save_result_dir=save_result_dir)

seen_seq = {}
with DavisInteractiveSession(host=host,
davis_root=Path.db_root_dir(),
subset=subset,
report_save_dir=report_save_dir,
max_nb_interactions=max_nb_interactions,
max_time=max_time) as sess:
while sess.next():
t_total = timeit.default_timer()

# Get the current iteration scribbles
sequence, scribbles, first_scribble = sess.get_scribbles()
if first_scribble:
n_interaction = 1
n_objects = Davis.dataset[sequence]['num_objects']
first_frame = interactive_utils.scribbles.annotated_frames(scribbles)[0]
seen_seq[sequence] = 1 if sequence not in seen_seq.keys() else seen_seq[sequence]+1
else:
n_interaction += 1
pred_masks = []
print('\nRunning sequence {} in interaction {} and scribble iteration {}'
.format(sequence, n_interaction, seen_seq[sequence]))
for obj_id in range(1, n_objects+1):
model.train(first_frame, n_interaction, obj_id, scribbles, seen_seq[sequence],
subset=subset,
use_previous_mask=prev_mask)
pred_masks.append(model.test(sequence, n_interaction, obj_id,
subset=subset,
scribble_iter=seen_seq[sequence]))

final_masks = interactive_utils.mask.combine_masks(pred_masks)

# Submit your prediction
sess.submit_masks(final_masks)
t_end = timeit.default_timer()
print('Total time (training and testing) for single interaction: ' + str(t_end - t_total))

# Get the DataFrame report
report = sess.get_report()

# Get the global summary
summary = sess.get_global_summary(save_file=os.path.join(report_save_dir, 'summary.json'))


if __name__ == '__main__':
main()
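With the challenge configuration above, the per-object interaction budget works out to max_time = 8 * 30 = 240 seconds, while OSVOS fine-tuning is capped separately at 20 seconds per object per interaction. A minimal sketch (an assumed follow-up step, not part of this commit) of inspecting the summary the session writes at the end:

import json
import os

from mypath import Path

with open(os.path.join(Path.save_root_dir(), 'summary.json'), 'r') as f:
    summary = json.load(f)

print('AUC: {:.3f}'.format(summary['auc']))
print('J&F at {}s: {:.3f}'.format(summary['metric_at_threshold']['threshold'],
                                  summary['metric_at_threshold']['J_AND_F']))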
4 changes: 2 additions & 2 deletions mypath.py
@@ -1,11 +1,11 @@
class Path(object):
@staticmethod
def db_root_dir():
return '/path/to/DAVIS2017'
return '/path/to/DAVIS'

@staticmethod
def save_root_dir():
return './models'
return './results'

@staticmethod
def models_dir():
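After this change the path helper points at a generic DAVIS root and a ./results directory. A sketch of the full class for reference (the models_dir() body is cut off in this diff, so './models' below is an assumption):

class Path(object):
    @staticmethod
    def db_root_dir():
        return '/path/to/DAVIS'   # root folder of the DAVIS dataset

    @staticmethod
    def save_root_dir():
        return './results'        # where reports and result masks are written

    @staticmethod
    def models_dir():
        return './models'         # assumed: location of osvos_parent.pth and saved models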
30 changes: 15 additions & 15 deletions networks/vgg_osvos.py
@@ -74,21 +74,6 @@ def forward(self, x):
return side_out

def _initialize_weights(self, pretrained):
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.001)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
elif isinstance(m, nn.ConvTranspose2d):
m.weight.data.zero_()
m.weight.data = interp_surgery(m)

if pretrained == 1:
print("Loading weights from PyTorch VGG")
vgg_structure = [64, 64, 'M', 128, 128, 'M', 256, 256, 256,
@@ -123,6 +108,21 @@ def _initialize_weights(self, pretrained):
assert (layer.data.shape == c_b.shape)
layer.data = c_b
caffe_ind += 1
else:
for m in self.modules():
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0, 0.001)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
m.bias.data.zero_()
elif isinstance(m, nn.ConvTranspose2d):
m.weight.data.zero_()
m.weight.data = interp_surgery(m)


def find_conv_layers(_vgg):
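The net effect of this hunk is that the generic initialization loop (Gaussian init for Conv2d/Linear, constants for BatchNorm2d, bilinear surgery for ConvTranspose2d) now sits in an else branch, so it only runs when no pretrained VGG/Caffe weights are loaded. The interp_surgery call is the usual bilinear-upsampling initialization for transposed convolutions; a self-contained sketch of that idea (illustrative, not the repository's exact interp_surgery):

import numpy as np
import torch
import torch.nn as nn

def bilinear_kernel(channels, kernel_size):
    # Bilinear interpolation weights: one kernel on the channel diagonal, zeros elsewhere.
    factor = (kernel_size + 1) // 2
    center = factor - 1 if kernel_size % 2 == 1 else factor - 0.5
    og = np.ogrid[:kernel_size, :kernel_size]
    filt = (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor)
    weight = torch.zeros(channels, channels, kernel_size, kernel_size)
    for c in range(channels):
        weight[c, c] = torch.from_numpy(filt).float()
    return weight

# Example: a x2 upsampling layer initialized to plain bilinear interpolation
up = nn.ConvTranspose2d(16, 16, kernel_size=4, stride=2, bias=False)
up.weight.data.copy_(bilinear_kernel(16, 4))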
44 changes: 18 additions & 26 deletions osvos_scribble.py
@@ -1,5 +1,6 @@
import os
import timeit
import copy

import torch
import torch.optim as optim
@@ -18,7 +19,7 @@
from layers.osvos_layers import class_balanced_cross_entropy_loss


class OsvosScribble(object):
class OSVOSScribble(object):
def __init__(self, parent_model, save_model_dir, gpu_id, time_budget, save_result_dir=None):
self.save_model_dir = save_model_dir
self.parent_model = parent_model
@@ -32,9 +33,11 @@ def __init__(self, parent_model, save_model_dir, gpu_id, time_budget, save_resul
self.meanval = (104.00699, 116.66877, 122.67892)
self.train_batch = 4
self.test_batch = 4
self.prev_models = {}
self.parent_model_state = torch.load(os.path.join(Path.models_dir(), self.parent_model),
map_location=lambda storage, loc: storage)

def train(self, first_frame, n_interaction, obj_id, scribbles_data, scribble_iter, subset, use_previous_mask=False):
print('Training Network for obj_id={}'.format(obj_id))
nAveGrad = 1
num_workers = 4
train_batch = min(n_interaction, self.train_batch)
@@ -43,20 +46,19 @@ def train(self, first_frame, n_interaction, obj_id, scribbles_data, scribble_ite
scribbles_list = scribbles_data['scribbles']
seq_name = scribbles_data['sequence']

if obj_id == 1 and n_interaction == 1:
self.prev_models = {}

# Network definition
save_dir = os.path.join(self.save_model_dir, seq_name)
if not os.path.exists(save_dir):
os.makedirs(os.path.join(save_dir))
if n_interaction == 1:
print('Loading weights from: {}'.format(self.parent_model))
self.net.load_state_dict(torch.load(os.path.join(Path.models_dir(), self.parent_model),
map_location=lambda storage, loc: storage))
self.net.load_state_dict(self.parent_model_state)
self.prev_models[obj_id] = None
else:
print('Loading weights from previous network: objId-{}_interaction-{}_scribble-{}.pth'
.format(obj_id, n_interaction-1, scribble_iter))
self.net.load_state_dict(torch.load(os.path.join(save_dir, 'objId-{}_interaction-{}_scribble-{}.pth'
.format(obj_id, n_interaction-1, scribble_iter)),
map_location=lambda storage, loc: storage))
self.net.load_state_dict(self.prev_models[obj_id])

lr = 1e-8
wd = 0.0002
optimizer = optim.SGD([
@@ -88,8 +90,6 @@ def train(self, first_frame, n_interaction, obj_id, scribbles_data, scribble_ite
loss_tr = []
aveGrad = 0

# print("Start of Online Training, sequence: " + seq_name)
# iter_start_time = timeit.default_timer()
start_time = timeit.default_timer()
# Main Training and Testing Loop
epoch = 0
@@ -132,22 +132,16 @@ def train(self, first_frame, n_interaction, obj_id, scribbles_data, scribble_ite
optimizer.step()
optimizer.zero_grad()
aveGrad = 0
# iter_stop_time = timeit.default_timer()
# print("Iteration timing: {} seconds".format(str(iter_stop_time - iter_start_time)))

epoch += train_batch
stop_time = timeit.default_timer()
# iter_start_time = timeit.default_timer()
if stop_time - start_time > self.time_budget:
break

# Save the model
torch.save(self.net.state_dict(), os.path.join(save_dir, 'objId-{}_interaction-{}_scribble-{}.pth'
.format(obj_id, n_interaction, scribble_iter)))
stop_time = timeit.default_timer()
print('Online training time: ' + str(stop_time - start_time))
# Save the model into dictionary
self.prev_models[obj_id] = copy.deepcopy(self.net.state_dict())

def test(self, sequence, n_interaction, obj_id, subset, scribble_iter=0):
save_dir = os.path.join(self.save_model_dir, sequence)
if self.save_res_dir:
save_dir_res = os.path.join(self.save_res_dir, 'interaction-{}'.format(n_interaction),
'scribble-{}'.format(scribble_iter),
@@ -165,9 +159,7 @@ def test(self, sequence, n_interaction, obj_id, subset, scribble_iter=0):
print('Testing Network for obj_id={}'.format(obj_id))
print('Loading weights from objId-{}_interaction-{}_scribble-{}.pth'
.format(obj_id, n_interaction, scribble_iter))
self.net.load_state_dict(torch.load(os.path.join(save_dir, 'objId-{}_interaction-{}_scribble-{}.pth'
.format(obj_id, n_interaction, scribble_iter)),
map_location=lambda storage, loc: storage))

# Main Testing Loop
masks = []
for ii, sample_batched in enumerate(testloader):
@@ -179,10 +171,10 @@ def test(self, sequence, n_interaction, obj_id, subset, scribble_iter=0):
if self.gpu_id >= 0:
inputs, gts = inputs.cuda(), gts.cuda()

outputs = self.net.forward(inputs)
outputs = self.net.forward(inputs)[-1].cpu().data.numpy()

for jj in range(int(inputs.size()[0])):
pred = np.transpose(outputs[-1].cpu().data.numpy()[jj, :, :, :], (1, 2, 0))
pred = np.transpose(outputs[jj, :, :, :], (1, 2, 0))
pred = 1 / (1 + np.exp(-pred))
pred = np.squeeze(pred)

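The key change in this file: the parent model is read from disk once in __init__, and the weights fine-tuned for each object are kept in the self.prev_models dictionary (a deep copy of the state dict) instead of being written to and re-loaded from .pth files at every interaction, removing disk I/O from the interaction loop. A minimal standalone sketch of that pattern (illustrative; in the commit it lives inside OSVOSScribble):

import copy

prev_models = {}   # obj_id -> state_dict fine-tuned in the previous interaction

def load_weights(net, obj_id, parent_state):
    # First interaction for an object: start from the parent model;
    # later interactions resume from the weights stored for that object.
    state = prev_models.get(obj_id)
    net.load_state_dict(state if state is not None else parent_state)

def store_weights(net, obj_id):
    # Deep-copy so subsequent optimizer steps do not mutate the stored tensors.
    prev_models[obj_id] = copy.deepcopy(net.state_dict())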
