Skip to content

Commit

Permalink
merge ava updates
Browse files Browse the repository at this point in the history
  • Loading branch information
Gunnar Sigurdsson committed May 7, 2019
1 parent ded24bd commit 5211f56
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 25 deletions.
47 changes: 28 additions & 19 deletions datasets/ava_mp4.py
Expand Up @@ -131,25 +131,34 @@ def parse_ava_csv(filename, cls2int):
return labels

@classmethod
def get(cls, args):
def get(cls, args, splits=('train', 'val', 'val_video')):
train_file = args.train_file
val_file = args.val_file
train_dataset = cls(
args, args.data, 'train', train_file, args.cache,
transform=transforms.Compose([
videotransforms.ScaledCenterCrop(args.input_size),
]),
input_size=args.input_size)
val_dataset = cls(
args, args.valdata, 'val', val_file, args.cache,
transform=transforms.Compose([
videotransforms.ScaledCenterCrop(args.input_size),
]),
input_size=args.input_size)
valvideo_dataset = cls(
args, args.valdata, 'val_video', val_file, args.cache,
transform=transforms.Compose([
videotransforms.ScaledCenterCrop(args.input_size),
]),
input_size=args.input_size)
if 'train' in splits:
train_dataset = cls(
args, args.data, 'train', train_file, args.cache,
transform=transforms.Compose([
videotransforms.ScaledCenterCrop(args.input_size),
]),
input_size=args.input_size)
else:
train_dataset = None
if 'val' in splits:
val_dataset = cls(
args, args.valdata, 'val', val_file, args.cache,
transform=transforms.Compose([
videotransforms.ScaledCenterCrop(args.input_size),
]),
input_size=args.input_size)
else:
val_dataset = None
if 'val_video' in splits:
valvideo_dataset = cls(
args, args.valdata, 'val_video', val_file, args.cache,
transform=transforms.Compose([
videotransforms.ScaledCenterCrop(args.input_size),
]),
input_size=args.input_size)
else:
valvideo_dataset = None
return train_dataset, val_dataset, valvideo_dataset
14 changes: 8 additions & 6 deletions models/wrappers/maskrcnn_wrapper.py
Expand Up @@ -61,12 +61,14 @@ def __init__(self, basenet, args):
self.freeze_head = args.freeze_head

# for visualizing bounding boxes
this_dir = os.path.dirname(__file__)
lib_path = os.path.join(this_dir, '../../external/Detectron.pytorch/lib')
sys.path.insert(0, lib_path)
import utils.vis as vis_utils
self.vis_utils = vis_utils
sys.path.pop(0)
# this_dir = os.path.dirname(__file__)
# lib_path = os.path.join(this_dir, '../../external/Detectron.pytorch/lib')
# sys.path.insert(0, lib_path)
# import utils.vis as vis_utils
# self.vis_utils = vis_utils
# sys.path.pop(0)

# for full i3d model
#for i, end_point in enumerate(self.basenet.VALID_ENDPOINTS):
# if end_point == 'Mixed_4f': # first half should include Mixed_4f
# self.first_layers = self.basenet.VALID_ENDPOINTS[:i+1]
Expand Down

0 comments on commit 5211f56

Please sign in to comment.