In [1]:
import tensorflow as tf

In [2]:
def make_dataset(episodes, config):
  example = episodes[next(iter(episodes.keys()))]
  types = {k: v.dtype for k, v in example.items()}
  shapes = {k: (None,) + v.shape[1:] for k, v in example.items()}
  generator = lambda: tools.sample_episodes(
      episodes, config.batch_length, config.oversample_ends)
  dataset = tf.data.Dataset.from_generator(generator, types, shapes)
  dataset = dataset.batch(config.batch_size, drop_remainder=True)
  dataset = dataset.prefetch(10)
  return dataset

In [3]:
class VideoFolder(tf.data.Dataset):

  # tf.data.Dataset.from_generator(generator, types, shapes)

  def __init__(self, args, root, json_file_input, json_file_labels, clip_size,
                 nclips, step_size, is_val, num_tasks=174, transform_pre=None, transform_post=None,
                 augmentation_mappings_json=None, augmentation_types_todo=None,
                 is_test=False, robot_demo_transform=None):
    self.num_tasks = num_tasks
    self.is_val = is_val

    # Gets data from json files?
    self.dataset_object = WebmDataset(args, json_file_input, json_file_labels,
                                      root, num_tasks=self.num_tasks, is_test=is_test, is_val=is_val)
        
    self.json_data = self.dataset_object.json_data # json data from the webm dataset
    self.classes = self.dataset_object.classes
    self.classes_dict = self.dataset_object.classes_dict
    self.root = root
    self.transform_pre = transform_pre
    self.transform_post = transform_post
    self.im_size = args.im_size
    self.batch_size = args.batch_size

    #self.augmentor = Augmentor(augmentation_mappings_json, augmentation_types_todo)

    self.traj_length = clip_size
    self.nclips = nclips
    self.step_size = step_size
    self.similarity = args.similarity
    self.add_demos = args.add_demos 
    if self.add_demos:
            self.robot_demo_transform = robot_demo_transform
            self.demo_batch_val = args.demo_batch_val
        
    # add keys to list called classes if they are not ints
    classes = []
    for key in self.classes_dict.keys():
        if not isinstance(key, int):
            classes.append(key)

    self.classes = classes
    num_occur = defaultdict(int)
    # make a dict with key = class, value = num_occurances
    for c in self.classes:
        for video in self.json_data:
            if video.label == c:
                num_occur[c] += 1

    # dump the occrance dict to a file
    if not self.is_val:
        with open(args.log_dir + '/human_data_tasks.txt', 'w') as f:
            json.dump(num_occur, f, indent=2)
    else:
        with open(args.log_dir + '/val_human_data_tasks.txt', 'w') as f:
            json.dump(num_occur, f, indent=2)
                
    # Every sample in batch: anchor (randomly selected class A), positive (randomly selected class A), 
    # and negative (randomly selected class not A)
    # Make dictionary for similarity triplets
    
    self.json_dict = defaultdict(list)
    for data in self.json_data:
        self.json_dict[data.label].append(data)

    # Make separate robot dictionary:
    self.robot_json_dict = defaultdict(list)
    self.total_robot = [] # all robot demos
    
    for data in self.json_data:
        if data.id == 300000: # robot video
            self.robot_json_dict[data.label].append(data)
            self.total_robot.append(data)
            
    print("Number of human videos: ", len(self.json_data), len(self.classes), "Total:", self.__len__())
        
    # Tasks used
    self.tasks = args.human_tasks
    if self.add_demos:
        self.robot_tasks = args.robot_tasks
    assert(sum(num_occur.values()) == len(self.json_data))        

  def process_video(self, item):
    # Open video file
    try: 
        reader = av.open(item.path)
    except:
        print("Issue with opening the video, path:", item.path)
        assert(False)

    try:
        imgs = []
        imgs = [f.to_rgb().to_ndarray() for f in reader.decode(video=0)]
    except (RuntimeError, ZeroDivisionError) as exception:
        print('{}: WEBM reader cannot open {}. Empty '
                  'list returned.'.format(type(exception).__name__, item.path))
    
    orig_imgs = np.array(imgs).copy() 
        
    target_idx = self.classes_dict[item.label] 
    # not sure what this does
    if not self.num_tasks == 174:
        target_idx = self.tasks.index(target_idx)
            
    # If robot demonstration
    # get trajectory length clips from video
    if self.add_demos and item.id == 300000: 
            imgs = self.robot_demo_transform(imgs)
            frame = random.randint(0, max(len(imgs) - self.traj_length, 0))
            length = min(self.traj_length, len(imgs))
            imgs = imgs[frame: length + frame]
            imgs_copy = tf.stack(imgs)
            imgs_copy = imgs_copy.permute(1, 0, 2, 3)
            return imgs_copy
        
    imgs = self.transform_pre(imgs)
    imgs, label = self.augmentor(imgs, item.label)
    imgs = self.transform_post(imgs)
        
    num_frames = len(imgs)        
    if self.nclips > -1:
        num_frames_necessary = self.traj_length * self.nclips * self.step_size
    else:
        num_frames_necessary = num_frames
    offset = 0
    if num_frames_necessary < num_frames:
        # If there are more frames, then sample starting offset.
        diff = (num_frames - num_frames_necessary)
        # temporal augmentation
        offset = np.random.randint(0, diff)

    imgs = imgs[offset: num_frames_necessary + offset: self.step_size]
    if len(imgs) < (self.traj_length * self.nclips):
        imgs.extend([imgs[-1]] *
                        ((self.traj_length * self.nclips) - len(imgs)))

    # format data to torch
    data = tf.stack(imgs)
    data = data.permute(1, 0, 2, 3)
    return data
    
            
    def __getitem__(self, index):
        """
        [!] FPS jittering doesn't work with AV dataloader as of now
        """
            
        if self.similarity:
            # Need triplet for each sample
            if self.add_demos and np.random.uniform(0.0, 1.0) < self.demo_batch_val:
                item = random.choice(self.total_robot)
            else:
                item = random.choice(self.json_data) 
            
            # Get random anchor
            # If adding demos, get 1/2 robot anchors for a more balanced batch
            if self.add_demos and (self.classes_dict[item.label] in self.robot_tasks) and (np.random.uniform(0.0, 1.0) < self.demo_batch_val): 
                anchor = random.choice(self.robot_json_dict[item.label])
            else:
                anchor = random.choice(self.json_dict[item.label])
            
            # Get negative 
            neg = random.choice(self.json_data)
            if self.add_demos and np.random.uniform(0.0, 1.0) < self.demo_batch_val: 
                neg = random.choice(self.total_robot)
            while neg.label == item.label:
                neg = random.choice(self.json_data)
                
            pos_data = self.process_video(item)  
            anchor_data  = self.process_video(anchor)
            neg_data = self.process_video(neg)

            # return teo clips per task
            return (pos_data, anchor_data, neg_data)
            
    def __len__(self):
        self.total_files = len(self.json_data)
        if self.similarity and not self.is_val and self.num_tasks <= 12:
            self.total_files = self.batch_size * 200 
        return self.total_file

In [4]:
train_data = VideoFolder(root='/iris/u/asc8/workspace/humans/Humans/20bn-something-something-v2-all-videos/',
                           json_file_input='/iris/u/surajn/workspace/language_offline_rl/sthsth/something-something-v2-train.json',
                           json_file_labels='/iris/u/surajn/workspace/language_offline_rl/sthsth/something-something-v2-labels.json',
                           clip_size=72,
                           nclips=1,
                           step_size=1,
                           is_val=False,
                           get_item_id=False)

TypeError: Can't instantiate abstract class VideoFolder with abstract methods _inputs, element_spec

In [None]:
train_data = VideoFolder(root='/iris/u/asc8/workspace/humans/Humans/20bn-something-something-v2-all-videos/',
                           json_file_input='/iris/u/surajn/workspace/language_offline_rl/sthsth/something-something-v2-train.json',
                           json_file_labels='/iris/u/surajn/workspace/language_offline_rl/sthsth/something-something-v2-labels.json',

In [5]:
import json 
f = open('/iris/u/surajn/workspace/language_offline_rl/sthsth/something-something-v2-train.json')
data = json.load(f)
data 

[{'id': '78687',
  'label': 'holding potato next to vicks vaporub bottle',
  'template': 'Holding [something] next to [something]',
  'placeholders': ['potato', 'vicks vaporub bottle']},
 {'id': '42326',
  'label': 'spreading margarine onto bread',
  'template': 'Spreading [something] onto [something]',
  'placeholders': ['margarine', 'bread']},
 {'id': '100904',
  'label': 'putting pen on a surface',
  'template': 'Putting [something] on a surface',
  'placeholders': ['pen']},
 {'id': '80715',
  'label': 'lifting up one end of bottle, then letting it drop down',
  'template': 'Lifting up one end of [something], then letting it drop down',
  'placeholders': ['bottle']},
 {'id': '34899',
  'label': 'holding bulb',
  'template': 'Holding [something]',
  'placeholders': ['bulb']},
 {'id': '184568',
  'label': 'pushing strap camera from right to left',
  'template': 'Pushing [something] from right to left',
  'placeholders': ['strap camera']},
 {'id': '112783',
  'label': 'spilling mouthwa

In [29]:
class DatasetBase(object):
    """
    To read json data and construct a list containing video sample `ids`,
    `label` and `path`
    """
    def __init__(self, args, json_path_input, json_path_labels, data_root,
                 extension=None, num_tasks=None, is_test=False, is_val=False):
        self.num_tasks = num_tasks
        self.json_path_input = json_path_input
        self.json_path_labels = json_path_labels
        self.data_root = data_root
        self.extension = extension
        self.is_test = is_test
        self.is_val = is_val
#         self.just_robot = args.just_robot
#         self.sim_dir = args.sim_dir
        
        self.num_occur = defaultdict(int)
        
        self.tasks = [] # args.human_tasks
        self.add_demos = False #args.add_demos
        if self.add_demos:
            self.robot_tasks = [] #args.robot_tasks

        # preparing data and class dictionary
        self.classes = self.read_json_labels()
        self.classes_dict = self.get_two_way_dict(self.classes)
        self.json_data = self.read_json_input()
        print("Number of human videos:", self.num_occur.values())
        
        
    def read_json_input(self):
        json_data = []
        if not self.is_test:
            if not self.just_robot: #not self.triplet or not self.add_demos: #self.is_val or
                with open(self.json_path_input, 'rb') as jsonfile:
                    json_reader = json.load(jsonfile)
                    for elem in json_reader:
                        label = self.clean_template(elem['template'])
                        if label not in self.classes_dict.keys(): # or label == 'Pushing something so that it slightly moves':
                            continue
                        if label not in self.classes:
                            raise ValueError("Label mismatch! Please correct")
                        
                        label_num = self.classes_dict[label]
                        item = ListData(elem['id'],
                                        label,
                                        os.path.join(self.data_root,
                                                     elem['id'] + self.extension)
                                        )
                        json_data.append(item)
                        self.num_occur[label] += 1
            
            if self.add_demos: 
                # Add robot demonstrations or extra robot class to json_data, just use id 300000
                robot_tasks = self.robot_tasks
                root_in_dir = self.sim_dir 
                for label_num in robot_tasks: 
                    # add task demos for task label_num
                    in_dirs = [f'{root_in_dir}/env1/task{label_num}_webm', f'{root_in_dir}/env1_rearranged/task{label_num}_webm']
                        
                    for in_dir in in_dirs:
                        label = self.classes_dict[label_num]

                        num_demos = self.add_demos
                        self.num_occur[label] += num_demos
                        if not self.is_val: 
                            for j in range(num_demos):
                                item = ListData(300000,
                                            label,
                                            os.path.join(in_dir, str(j) + self.extension)
                                            )
                                json_data.append(item)
                        else:
                            for j in range(num_demos, int(1.4*num_demos)):
                                item = ListData(300000,
                                            label,
                                            os.path.join(in_dir, str(j) + self.extension)
                                            )
                                json_data.append(item)
                        

        else:
            with open(self.json_path_input, 'rb') as jsonfile:
                json_reader = json.load(jsonfile)
                for elem in json_reader:
                    # add a dummy label for all test samples
                    item = ListData(elem['id'],
                                    "Holding something",
                                    os.path.join(self.data_root,
                                                 elem['id'] + self.extension)
                                    )
                    json_data.append(item)
        return json_data

    def read_json_labels(self):
        classes = []
        with open(self.json_path_labels, 'rb') as jsonfile:
            json_reader = json.load(jsonfile)
            for elem in json_reader:
                classes.append(elem)
        return sorted(classes)

    def get_two_way_dict(self, classes):
        classes_dict = {} 
        tasks = self.tasks
        for i, item in enumerate(classes):
            if i not in tasks:
                continue
            classes_dict[item] = i
            classes_dict[i] = item
        print("Length of keys", len(classes_dict.keys()), classes_dict.keys())
        return classes_dict

    def clean_template(self, template):
        """ Replaces instances of `[something]` --> `something`"""
        template = template.replace("[", "")
        template = template.replace("]", "")
        return template


class WebmDataset(DatasetBase):
    def __init__(self, args, json_path_input, json_path_labels, data_root, num_tasks, 
                 is_test=False, is_val=False):
        EXTENSION = ".webm"
        super().__init__(args, json_path_input, json_path_labels, data_root,
                         EXTENSION, num_tasks, is_test, is_val)


In [None]:
self.dataset_object = WebmDataset(args, json_file_input, json_file_labels,
                                      root, num_tasks=self.num_tasks, is_test=is_test, is_val=is_val)

In [None]:
    args.im_size_x = int(args.im_size * 1.5)
    args.json_data_train = args.root + "something-something-v2-train.json"
    args.json_data_val = args.root + "something-something-v2-validation.json"
    args.json_data_test = args.root + "something-something-v2-test.json"
    args.json_file_labels = args.root + "something-something-v2-labels.json"

In [None]:
train_data = VideoFolder(root='/iris/u/asc8/workspace/humans/Humans/20bn-something-something-v2-all-videos/',
                           json_file_input='/iris/u/surajn/workspace/language_offline_rl/sthsth/something-something-v2-train.json',
                           json_file_labels='/iris/u/surajn/workspace/language_offline_rl/sthsth/something-something-v2-labels.json')


In [52]:
from collections import defaultdict, namedtuple
import os

ListData = namedtuple('ListData', ['id', 'label', 'path'])

def read_json_input(json_path_input, classes_dict, is_test=False, just_robot=False, 
                    data_root = "/iris/u/nivsiyer/temp", extension=".csv", sim_dir="/iris/u/nivsiyer/sim", add_demos=False, robot_tasks=[]):
        json_data = []
        num_occur = defaultdict(int)
        if not is_test:
            if not just_robot: #not self.triplet or not self.add_demos: #self.is_val or
                with open(json_path_input, 'rb') as jsonfile:
                    json_reader = json.load(jsonfile)
                    for elem in json_reader:
                        label = clean_template(elem['template'])
                        if label not in classes_dict.keys(): # or label == 'Pushing something so that it slightly moves':
                            continue
#                         if label not in classes:
#                             raise ValueError("Label mismatch! Please correct")
                        
                        label_num = classes_dict[label]
                        item = ListData(elem['id'],
                                        label,
                                        os.path.join(data_root,
                                                     elem['id'] + extension)
                                        )
                        json_data.append(item)
                        num_occur[label] += 1
            
            if add_demos: 
                # Add robot demonstrations or extra robot class to json_data, just use id 300000
                robot_tasks = robot_tasks
                root_in_dir = sim_dir 
                for label_num in robot_tasks: 
                    # add task demos for task label_num
                    in_dirs = [f'{root_in_dir}/env1/task{label_num}_webm', f'{root_in_dir}/env1_rearranged/task{label_num}_webm']
                        
                    for in_dir in in_dirs:
                        label = self.classes_dict[label_num]

                        num_demos = add_demos
                        num_occur[label] += num_demos
                        if not self.is_val: 
                            for j in range(num_demos):
                                item = ListData(300000,
                                            label,
                                            os.path.join(in_dir, str(j) + extension)
                                            )
                                json_data.append(item)
                        else:
                            for j in range(num_demos, int(1.4*num_demos)):
                                item = ListData(300000,
                                            label,
                                            os.path.join(in_dir, str(j) + extension)
                                            )
                                json_data.append(item)
                        

        else:
            with open(json_path_input, 'rb') as jsonfile:
                json_reader = json.load(jsonfile)
                for elem in json_reader:
                    # add a dummy label for all test samples
                    item = ListData(elem['id'],
                                    "Holding something",
                                    os.path.join(data_root,
                                                 elem['id'] + extension)
                                    )
                    json_data.append(item)
        return json_data
    
def read_json_labels(json_path_labels):
        classes = []
        with open(json_path_labels, 'rb') as jsonfile:
            json_reader = json.load(jsonfile)
            for elem in json_reader:
                classes.append(elem)
        return sorted(classes)

def get_two_way_dict(classes, tasks=[5, 41, 93]):
        classes_dict = {} 
        for i, item in enumerate(classes):
            if i not in tasks:
                continue
            classes_dict[item] = i
            classes_dict[i] = item
        print("Length of keys", len(classes_dict.keys()), classes_dict.keys())
        return classes_dict

def clean_template(template):
        """ Replaces instances of `[something]` --> `something`"""
        template = template.replace("[", "")
        template = template.replace("]", "")
        return template
    
class VideoFolder():

  # tf.data.Dataset.from_generator(generator, types, shapes)
  def __init__(self, root, json_file_input, json_file_labels, clip_size=None, args=None,
                 nclips=None, step_size=None, is_val=None, num_tasks=174, transform_pre=None, transform_post=None,
                 augmentation_mappings_json=None, augmentation_types_todo=None,
                 is_test=False, robot_demo_transform=None):
    self.num_tasks = num_tasks
    self.is_val = is_val
    self.classes = read_json_labels(json_file_labels)
    self.classes_dict = get_two_way_dict(self.classes)
    self.json_data = read_json_input(json_file_input, self.classes_dict)

  def process_video(self, path):
    # Open video file
    try: 
        reader = av.open(path)
    except:
        print("Issue with opening the video, path:", item.path)
        assert(False)

    try:
        imgs = []
        imgs = [f.to_rgb().to_ndarray() for f in reader.decode(video=0)]
    except (RuntimeError, ZeroDivisionError) as exception:
        print('{}: WEBM reader cannot open {}. Empty '
                  'list returned.'.format(type(exception).__name__, item.path))
    
    orig_imgs = np.array(imgs).copy() 
        
    target_idx = self.classes_dict[item.label] 
    # not sure what this does
    if not self.num_tasks == 174:
        target_idx = self.tasks.index(target_idx)
            
    # If robot demonstration
    # get trajectory length clips from video
    if self.add_demos and item.id == 300000: 
            imgs = self.robot_demo_transform(imgs)
            frame = random.randint(0, max(len(imgs) - self.traj_length, 0))
            length = min(self.traj_length, len(imgs))
            imgs = imgs[frame: length + frame]
            imgs_copy = tf.stack(imgs)
            imgs_copy = imgs_copy.permute(1, 0, 2, 3)
            return imgs_copy
        
#     imgs = self.transform_pre(imgs)
#     imgs, label = self.augmentor(imgs, item.label)
#     imgs = self.transform_post(imgs)
        
    num_frames = len(imgs)        
    if self.nclips > -1:
        num_frames_necessary = self.traj_length * self.nclips * self.step_size
    else:
        num_frames_necessary = num_frames
    offset = 0
    if num_frames_necessary < num_frames:
        # If there are more frames, then sample starting offset.
        diff = (num_frames - num_frames_necessary)
        # temporal augmentation
        offset = np.random.randint(0, diff)

    imgs = imgs[offset: num_frames_necessary + offset: self.step_size]
    if len(imgs) < (self.traj_length * self.nclips):
        imgs.extend([imgs[-1]] *
                        ((self.traj_length * self.nclips) - len(imgs)))

    data = tf.stack(imgs)
    data = data.permute(1, 0, 2, 3)
    return data
  
    
    
train_data = VideoFolder(root='/iris/u/asc8/workspace/humans/Humans/20bn-something-something-v2-all-videos/',
                           json_file_input='/iris/u/surajn/workspace/language_offline_rl/sthsth/something-something-v2-train.json',
                           json_file_labels='/iris/u/surajn/workspace/language_offline_rl/sthsth/something-something-v2-labels.json')
        

#     self.root = root
#     self.transform_pre = transform_pre
#     self.transform_post = transform_post
#     self.im_size = args.im_size
#     self.batch_size = args.batch_size

#     #self.augmentor = Augmentor(augmentation_mappings_json, augmentation_types_todo)

#     self.traj_length = clip_size
#     self.nclips = nclips
#     self.step_size = step_size
#     self.similarity = args.similarity
#     self.add_demos = args.add_demos 
#     if self.add_demos:
#             self.robot_demo_transform = robot_demo_transform
#             self.demo_batch_val = args.demo_batch_val
        
#     # add keys to list called classes if they are not ints
#     classes = []
#     for key in self.classes_dict.keys():
#         if not isinstance(key, int):
#             classes.append(key)

#     self.classes = classes
#     num_occur = defaultdict(int)
#     # make a dict with key = class, value = num_occurances
#     for c in self.classes:
#         for video in self.json_data:
#             if video.label == c:
#                 num_occur[c] += 1

#     # dump the occrance dict to a file
#     if not self.is_val:
#         with open(args.log_dir + '/human_data_tasks.txt', 'w') as f:
#             json.dump(num_occur, f, indent=2)
#     else:
#         with open(args.log_dir + '/val_human_data_tasks.txt', 'w') as f:
#             json.dump(num_occur, f, indent=2)
                
#     # Every sample in batch: anchor (randomly selected class A), positive (randomly selected class A), 
#     # and negative (randomly selected class not A)
#     # Make dictionary for similarity triplets
    
#     self.json_dict = defaultdict(list)
#     for data in self.json_data:
#         self.json_dict[data.label].append(data)

#     # Make separate robot dictionary:
#     self.robot_json_dict = defaultdict(list)
#     self.total_robot = [] # all robot demos
    
#     for data in self.json_data:
#         if data.id == 300000: # robot video
#             self.robot_json_dict[data.label].append(data)
#             self.total_robot.append(data)
            
#     print("Number of human videos: ", len(self.json_data), len(self.classes), "Total:", self.__len__())
        
#     # Tasks used
#     self.tasks = args.human_tasks
#     if self.add_demos:
#         self.robot_tasks = args.robot_tasks
#     assert(sum(num_occur.values()) == len(self.json_data))        

#   def process_video(self, item):
#     # Open video file
#     try: 
#         reader = av.open(item.path)
#     except:
#         print("Issue with opening the video, path:", item.path)
#         assert(False)

#     try:
#         imgs = []
#         imgs = [f.to_rgb().to_ndarray() for f in reader.decode(video=0)]
#     except (RuntimeError, ZeroDivisionError) as exception:
#         print('{}: WEBM reader cannot open {}. Empty '
#                   'list returned.'.format(type(exception).__name__, item.path))
    
#     orig_imgs = np.array(imgs).copy() 
        
#     target_idx = self.classes_dict[item.label] 
#     # not sure what this does
#     if not self.num_tasks == 174:
#         target_idx = self.tasks.index(target_idx)
            
#     # If robot demonstration
#     # get trajectory length clips from video
#     if self.add_demos and item.id == 300000: 
#             imgs = self.robot_demo_transform(imgs)
#             frame = random.randint(0, max(len(imgs) - self.traj_length, 0))
#             length = min(self.traj_length, len(imgs))
#             imgs = imgs[frame: length + frame]
#             imgs_copy = tf.stack(imgs)
#             imgs_copy = imgs_copy.permute(1, 0, 2, 3)
#             return imgs_copy
        
#     imgs = self.transform_pre(imgs)
#     imgs, label = self.augmentor(imgs, item.label)
#     imgs = self.transform_post(imgs)
        
#     num_frames = len(imgs)        
#     if self.nclips > -1:
#         num_frames_necessary = self.traj_length * self.nclips * self.step_size
#     else:
#         num_frames_necessary = num_frames
#     offset = 0
#     if num_frames_necessary < num_frames:
#         # If there are more frames, then sample starting offset.
#         diff = (num_frames - num_frames_necessary)
#         # temporal augmentation
#         offset = np.random.randint(0, diff)

#     imgs = imgs[offset: num_frames_necessary + offset: self.step_size]
#     if len(imgs) < (self.traj_length * self.nclips):
#         imgs.extend([imgs[-1]] *
#                         ((self.traj_length * self.nclips) - len(imgs)))

#     # format data to torch
#     data = tf.stack(imgs)
#     data = data.permute(1, 0, 2, 3)
#     return data
    
            
# #     def __getitem__(self, index):
# #         """
# #         [!] FPS jittering doesn't work with AV dataloader as of now
# #         """
            
# #         if self.similarity:
# #             # Need triplet for each sample
# #             if self.add_demos and np.random.uniform(0.0, 1.0) < self.demo_batch_val:
# #                 item = random.choice(self.total_robot)
# #             else:
#                 item = random.choice(self.json_data) 
            
#             # Get random anchor
#             # If adding demos, get 1/2 robot anchors for a more balanced batch
#             if self.add_demos and (self.classes_dict[item.label] in self.robot_tasks) and (np.random.uniform(0.0, 1.0) < self.demo_batch_val): 
#                 anchor = random.choice(self.robot_json_dict[item.label])
#             else:
#                 anchor = random.choice(self.json_dict[item.label])
            
#             # Get negative 
#             neg = random.choice(self.json_data)
#             if self.add_demos and np.random.uniform(0.0, 1.0) < self.demo_batch_val: 
#                 neg = random.choice(self.total_robot)
#             while neg.label == item.label:
#                 neg = random.choice(self.json_data)
                
#             pos_data = self.process_video(item)  
#             anchor_data  = self.process_video(anchor)
#             neg_data = self.process_video(neg)

#             # return teo clips per task
#             return (pos_data, anchor_data, neg_data)
            
#     def __len__(self):
#         self.total_files = len(self.json_data)
#         if self.similarity and not self.is_val and self.num_tasks <= 12:
#             self.total_files = self.batch_size * 200 
#         return self.total_file

Length of keys 6 dict_keys(['Closing something', 5, 'Moving something away from the camera', 41, 'Pushing something from left to right', 93])


In [47]:
train_data.json_data

[]

In [7]:
!pip install av

Collecting av
  Downloading av-9.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (28.8 MB)
[K     |████████████████████████████████| 28.8 MB 46 kB/s s eta 0:00:01
[?25hInstalling collected packages: av
Successfully installed av-9.2.0


In [8]:
import  av