| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,71 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import argparse | ||
|
|
||
| import multiprocessing as mp | ||
|
|
||
| import pprint | ||
| import yaml | ||
|
|
||
| from app.scaffold import main as app_main | ||
| from src.utils.distributed import init_distributed | ||
|
|
||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument( | ||
| '--fname', type=str, | ||
| help='name of config file to load', | ||
| default='configs.yaml') | ||
| parser.add_argument( | ||
| '--devices', type=str, nargs='+', default=['cuda:0'], | ||
| help='which devices to use on local machine') | ||
|
|
||
|
|
||
| def process_main(rank, fname, world_size, devices): | ||
| import os | ||
| os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) | ||
|
|
||
| import logging | ||
| from src.utils.logging import get_logger | ||
| logger = get_logger(force=True) | ||
| if rank == 0: | ||
| logger.setLevel(logging.INFO) | ||
| else: | ||
| logger.setLevel(logging.ERROR) | ||
|
|
||
| logger.info(f'called-params {fname}') | ||
|
|
||
| # Load config | ||
| params = None | ||
| with open(fname, 'r') as y_file: | ||
| params = yaml.load(y_file, Loader=yaml.FullLoader) | ||
| logger.info('loaded params...') | ||
|
|
||
| # Log config | ||
| if rank == 0: | ||
| pprint.PrettyPrinter(indent=4).pprint(params) | ||
| dump = os.path.join(params['logging']['folder'], 'params-pretrain.yaml') | ||
| with open(dump, 'w') as f: | ||
| yaml.dump(params, f) | ||
|
|
||
| # Init distributed (access to communication between GPUs on the same machine) | ||
| world_size, rank = init_distributed(rank_and_world_size=(rank, world_size)) | ||
| logger.info(f'Running... (rank: {rank}/{world_size})') | ||
|
|
||
| # Launch the app with loaded config | ||
| app_main(params['app'], args=params) | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| args = parser.parse_args() | ||
| num_gpus = len(args.devices) | ||
| mp.set_start_method('spawn') | ||
| for rank in range(num_gpus): | ||
| mp.Process( | ||
| target=process_main, | ||
| args=(rank, args.fname, num_gpus, args.devices) | ||
| ).start() |
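For reference, a minimal single-GPU debug sketch of driving this launcher's worker directly, without spawning processes. The module path `app.main` and the config path are assumptions; file names are not shown in this diff.

```python
# Single-process debug sketch (assumptions: the file above is app/main.py and the
# repo root is on PYTHONPATH; the config path below is illustrative).
from app.main import process_main  # hypothetical import path

process_main(
    rank=0,
    fname='configs/pretrain/vitl16.yaml',  # illustrative config location
    world_size=1,
    devices=['cuda:0'],
)
```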
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import argparse | ||
| import os | ||
| import pprint | ||
| import yaml | ||
|
|
||
| import submitit | ||
|
|
||
| from app.scaffold import main as app_main | ||
| from src.utils.logging import get_logger | ||
|
|
||
| logger = get_logger(force=True) | ||
|
|
||
|
|
||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument( | ||
| '--folder', type=str, | ||
| help='location to save submitit logs', | ||
| default='/fsx-jepa/massran/submitit/') | ||
| parser.add_argument( | ||
| '--exclude', type=str, | ||
| help='nodes to exclude from training', | ||
| default=None) | ||
| parser.add_argument( | ||
| '--batch-launch', action='store_true', | ||
| help='whether fname points to a file to batch-launch several config files') | ||
| parser.add_argument( | ||
| '--fname', type=str, | ||
| help='yaml file containing config file names to launch', | ||
| default='configs.yaml') | ||
| parser.add_argument( | ||
| '--partition', type=str, | ||
| help='cluster partition to submit jobs on') | ||
| parser.add_argument( | ||
| '--time', type=int, default=4300, | ||
| help='time in minutes to run job') | ||
|
|
||
|
|
||
| class Trainer: | ||
|
|
||
| def __init__(self, args_pretrain, load_model=None): | ||
| self.app = args_pretrain['app'] | ||
| self.args_pretrain = args_pretrain | ||
| self.load_model = load_model | ||
|
|
||
| def __call__(self): | ||
| app = self.app | ||
| params = self.args_pretrain | ||
| load_model = self.load_model | ||
|
|
||
| logger.info('loaded pretrain params...') | ||
| pp = pprint.PrettyPrinter(indent=4) | ||
| pp.pprint(params) | ||
|
|
||
| # Launch app with loaded config | ||
| resume_preempt = False if load_model is None else load_model | ||
| app_main(app, args=params, resume_preempt=resume_preempt) | ||
|
|
||
| def checkpoint(self): | ||
| fb_trainer = Trainer(self.args_pretrain, True) | ||
| return submitit.helpers.DelayedSubmission(fb_trainer,) | ||
|
|
||
|
|
||
| def launch_app_with_parsed_args( | ||
| args_for_pretrain, | ||
| submitit_folder, | ||
| partition, | ||
| timeout=4300, | ||
| nodes=1, | ||
| tasks_per_node=1, | ||
| exclude_nodes=None | ||
| ): | ||
| executor = submitit.AutoExecutor( | ||
| folder=os.path.join(submitit_folder, 'job_%j'), | ||
| slurm_max_num_timeout=20) | ||
| executor.update_parameters( | ||
| slurm_partition=partition, | ||
| slurm_mem_per_gpu='55G', | ||
| timeout_min=timeout, | ||
| nodes=nodes, | ||
| tasks_per_node=tasks_per_node, | ||
| cpus_per_task=12, | ||
| gpus_per_node=tasks_per_node) | ||
|
|
||
| if exclude_nodes is not None: | ||
| executor.update_parameters(slurm_exclude=exclude_nodes) | ||
|
|
||
| jobs, trainers = [], [] | ||
| with executor.batch(): | ||
| for ap in args_for_pretrain: | ||
| fb_trainer = Trainer(ap) | ||
| job = executor.submit(fb_trainer,) | ||
| trainers.append(fb_trainer) | ||
| jobs.append(job) | ||
|
|
||
| for job in jobs: | ||
| print(job.job_id) | ||
|
|
||
|
|
||
| def launch(): | ||
|
|
||
| # ---------------------------------------------------------------------- # | ||
| # 1. Put config file names in a list | ||
| # ---------------------------------------------------------------------- # | ||
| config_fnames = [args.fname] | ||
|
|
||
| # -- If batch-launch is True, then the args.fname yaml file is not a | ||
| # -- config, but actually specifies a list of other config files | ||
| # -- to run in a slurm job array | ||
| if args.batch_launch: | ||
| with open(args.fname, 'r') as y_file: | ||
| config_fnames = yaml.load(y_file, Loader=yaml.FullLoader) | ||
| # ---------------------------------------------------------------------- # | ||
|
|
||
| # ---------------------------------------------------------------------- # | ||
| # 2. Parse each yaml config file as a dict and place in list | ||
| # ---------------------------------------------------------------------- # | ||
| nodes, tasks_per_node = None, None | ||
| configs = [] | ||
| for f in config_fnames: | ||
| with open(f, 'r') as y_file: | ||
| _params = yaml.load(y_file, Loader=yaml.FullLoader) | ||
| nodes = int(_params.get('nodes')) | ||
| tasks_per_node = int(_params.get('tasks_per_node')) | ||
| configs += [_params] | ||
| logger.info(f'Loaded {len(configs)} config files') | ||
| logger.info(f'Running all jobs with {nodes=} / {tasks_per_node=}') | ||
| # ---------------------------------------------------------------------- # | ||
|
|
||
| # ---------------------------------------------------------------------- # | ||
| # 3. Launch evals with parsed config files | ||
| # ---------------------------------------------------------------------- # | ||
| launch_app_with_parsed_args( | ||
| args_for_pretrain=configs, | ||
| submitit_folder=args.folder, | ||
| partition=args.partition, | ||
| timeout=args.time, | ||
| nodes=nodes, | ||
| tasks_per_node=tasks_per_node, | ||
| exclude_nodes=args.exclude) | ||
| # ---------------------------------------------------------------------- # | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| args = parser.parse_args() | ||
| launch() |
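A sketch of submitting one pre-training config programmatically instead of via the CLI. It assumes this file is importable as `app.main_distributed` (the name is not shown in the diff) and that the YAML carries top-level `nodes` and `tasks_per_node` keys, as `launch()` expects; paths are illustrative.

```python
# Programmatic submission sketch, under the assumptions stated above.
import yaml
from app.main_distributed import launch_app_with_parsed_args  # hypothetical import path

with open('configs/pretrain/vitl16.yaml', 'r') as y_file:  # illustrative path
    params = yaml.load(y_file, Loader=yaml.FullLoader)

launch_app_with_parsed_args(
    args_for_pretrain=[params],
    submitit_folder='/tmp/submitit-logs',   # illustrative log folder
    partition='your_partition',
    timeout=4300,
    nodes=int(params['nodes']),
    tasks_per_node=int(params['tasks_per_node']),
)
```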
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import importlib | ||
| import logging | ||
| import sys | ||
|
|
||
| logging.basicConfig(stream=sys.stdout, level=logging.INFO) | ||
| logger = logging.getLogger() | ||
|
|
||
|
|
||
| def main(app, args, resume_preempt=False): | ||
|
|
||
| logger.info(f'Running pre-training of app: {app}') | ||
| return importlib.import_module(f'app.{app}.train').main( | ||
| args=args, | ||
| resume_preempt=resume_preempt) |
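In other words, with `app: vjepa` in the config (as in the pre-training YAMLs below), the dynamic import above resolves to `app/vjepa/train.py`. A sketch of the expansion, where `params` stands for the full parsed config:

```python
# Equivalent expansion of the scaffold's dynamic dispatch for app='vjepa' (sketch).
import importlib

train_module = importlib.import_module('app.vjepa.train')  # from f'app.{app}.train'
train_module.main(args=params, resume_preempt=False)
```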
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,153 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import torch | ||
| import torchvision.transforms as transforms | ||
|
|
||
| import src.datasets.utils.video.transforms as video_transforms | ||
| from src.datasets.utils.video.randerase import RandomErasing | ||
|
|
||
|
|
||
| def make_transforms( | ||
| random_horizontal_flip=True, | ||
| random_resize_aspect_ratio=(3/4, 4/3), | ||
| random_resize_scale=(0.3, 1.0), | ||
| reprob=0.0, | ||
| auto_augment=False, | ||
| motion_shift=False, | ||
| crop_size=224, | ||
| normalize=((0.485, 0.456, 0.406), | ||
| (0.229, 0.224, 0.225)) | ||
| ): | ||
|
|
||
| _frames_augmentation = VideoTransform( | ||
| random_horizontal_flip=random_horizontal_flip, | ||
| random_resize_aspect_ratio=random_resize_aspect_ratio, | ||
| random_resize_scale=random_resize_scale, | ||
| reprob=reprob, | ||
| auto_augment=auto_augment, | ||
| motion_shift=motion_shift, | ||
| crop_size=crop_size, | ||
| normalize=normalize, | ||
| ) | ||
| return _frames_augmentation | ||
|
|
||
|
|
||
| class VideoTransform(object): | ||
|
|
||
| def __init__( | ||
| self, | ||
| random_horizontal_flip=True, | ||
| random_resize_aspect_ratio=(3/4, 4/3), | ||
| random_resize_scale=(0.3, 1.0), | ||
| reprob=0.0, | ||
| auto_augment=False, | ||
| motion_shift=False, | ||
| crop_size=224, | ||
| normalize=((0.485, 0.456, 0.406), | ||
| (0.229, 0.224, 0.225)) | ||
| ): | ||
|
|
||
| self.random_horizontal_flip = random_horizontal_flip | ||
| self.random_resize_aspect_ratio = random_resize_aspect_ratio | ||
| self.random_resize_scale = random_resize_scale | ||
| self.auto_augment = auto_augment | ||
| self.motion_shift = motion_shift | ||
| self.crop_size = crop_size | ||
| self.mean = torch.tensor(normalize[0], dtype=torch.float32) | ||
| self.std = torch.tensor(normalize[1], dtype=torch.float32) | ||
| if not self.auto_augment: | ||
| # Without auto-augment the frames stay in uint8 [0, 255] range, so scale the normalization stats by 255. | ||
| self.mean *= 255. | ||
| self.std *= 255. | ||
|
|
||
| self.autoaug_transform = video_transforms.create_random_augment( | ||
| input_size=(crop_size, crop_size), | ||
| auto_augment='rand-m7-n4-mstd0.5-inc1', | ||
| interpolation='bicubic', | ||
| ) | ||
|
|
||
| self.spatial_transform = video_transforms.random_resized_crop_with_shift \ | ||
| if motion_shift else video_transforms.random_resized_crop | ||
|
|
||
| self.reprob = reprob | ||
| self.erase_transform = RandomErasing( | ||
| reprob, | ||
| mode='pixel', | ||
| max_count=1, | ||
| num_splits=1, | ||
| device='cpu', | ||
| ) | ||
|
|
||
| def __call__(self, buffer): | ||
|
|
||
| if self.auto_augment: | ||
| buffer = [transforms.ToPILImage()(frame) for frame in buffer] | ||
| buffer = self.autoaug_transform(buffer) | ||
| buffer = [transforms.ToTensor()(img) for img in buffer] | ||
| buffer = torch.stack(buffer) # T C H W | ||
| buffer = buffer.permute(0, 2, 3, 1) # T H W C | ||
| else: | ||
| buffer = torch.tensor(buffer, dtype=torch.float32) | ||
|
|
||
| buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W | ||
|
|
||
| buffer = self.spatial_transform( | ||
| images=buffer, | ||
| target_height=self.crop_size, | ||
| target_width=self.crop_size, | ||
| scale=self.random_resize_scale, | ||
| ratio=self.random_resize_aspect_ratio, | ||
| ) | ||
| if self.random_horizontal_flip: | ||
| buffer, _ = video_transforms.horizontal_flip(0.5, buffer) | ||
|
|
||
| buffer = _tensor_normalize_inplace(buffer, self.mean, self.std) | ||
| if self.reprob > 0: | ||
| buffer = buffer.permute(1, 0, 2, 3) | ||
| buffer = self.erase_transform(buffer) | ||
| buffer = buffer.permute(1, 0, 2, 3) | ||
|
|
||
| return buffer | ||
|
|
||
|
|
||
| def tensor_normalize(tensor, mean, std): | ||
| """ | ||
| Normalize a given tensor by subtracting the mean and dividing the std. | ||
| Args: | ||
| tensor (tensor): tensor to normalize. | ||
| mean (tensor or list): mean value to subtract. | ||
| std (tensor or list): std to divide. | ||
| """ | ||
| if tensor.dtype == torch.uint8: | ||
| tensor = tensor.float() | ||
| tensor = tensor / 255.0 | ||
| if isinstance(mean, list): | ||
| mean = torch.tensor(mean) | ||
| if isinstance(std, list): | ||
| std = torch.tensor(std) | ||
| tensor = tensor - mean | ||
| tensor = tensor / std | ||
| return tensor | ||
|
|
||
|
|
||
| def _tensor_normalize_inplace(tensor, mean, std): | ||
| """ | ||
| Normalize a given tensor by subtracting the mean and dividing the std. | ||
| Args: | ||
| tensor (tensor): tensor to normalize (with dimensions C, T, H, W). | ||
| mean (tensor): mean value to subtract (in 0 to 255 floats). | ||
| std (tensor): std to divide (in 0 to 255 floats). | ||
| """ | ||
| if tensor.dtype == torch.uint8: | ||
| tensor = tensor.float() | ||
|
|
||
| C, T, H, W = tensor.shape | ||
| tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension | ||
| tensor.sub_(mean).div_(std) | ||
| tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front | ||
| return tensor |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,210 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import logging | ||
| import sys | ||
| import warnings | ||
| import yaml | ||
|
|
||
|
|
||
| import torch | ||
|
|
||
| import src.models.vision_transformer as video_vit | ||
| import src.models.predictor as vit_pred | ||
| from src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper | ||
| from src.utils.schedulers import ( | ||
| WarmupCosineSchedule, | ||
| CosineWDSchedule) | ||
| from src.utils.tensors import trunc_normal_ | ||
|
|
||
| logging.basicConfig(stream=sys.stdout, level=logging.INFO) | ||
| logger = logging.getLogger() | ||
|
|
||
|
|
||
| def load_checkpoint( | ||
| r_path, | ||
| encoder, | ||
| predictor, | ||
| target_encoder, | ||
| opt, | ||
| scaler, | ||
| ): | ||
| try: | ||
| checkpoint = torch.load(r_path, map_location=torch.device('cpu')) | ||
| except Exception as e: | ||
| logger.info(f'Encountered exception when loading checkpoint {e}') | ||
|
|
||
| epoch = 0 | ||
| try: | ||
| epoch = checkpoint['epoch'] | ||
|
|
||
| # -- loading encoder | ||
| pretrained_dict = checkpoint['encoder'] | ||
| msg = encoder.load_state_dict(pretrained_dict) | ||
| logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}') | ||
|
|
||
| # -- loading predictor | ||
| pretrained_dict = checkpoint['predictor'] | ||
| msg = predictor.load_state_dict(pretrained_dict) | ||
| logger.info(f'loaded pretrained predictor from epoch {epoch} with msg: {msg}') | ||
|
|
||
| # -- loading target_encoder | ||
| if target_encoder is not None: | ||
| print(list(checkpoint.keys())) | ||
| pretrained_dict = checkpoint['target_encoder'] | ||
| msg = target_encoder.load_state_dict(pretrained_dict) | ||
| logger.info( | ||
| f'loaded pretrained target encoder from epoch {epoch} with msg: {msg}' | ||
| ) | ||
|
|
||
| # -- loading optimizer | ||
| opt.load_state_dict(checkpoint['opt']) | ||
| if scaler is not None: | ||
| scaler.load_state_dict(checkpoint['scaler']) | ||
| logger.info(f'loaded optimizers from epoch {epoch}') | ||
| logger.info(f'read-path: {r_path}') | ||
| del checkpoint | ||
|
|
||
| except Exception as e: | ||
| logger.info(f'Encountered exception when loading checkpoint {e}') | ||
| epoch = 0 | ||
|
|
||
| return ( | ||
| encoder, | ||
| predictor, | ||
| target_encoder, | ||
| opt, | ||
| scaler, | ||
| epoch, | ||
| ) | ||
|
|
||
|
|
||
| def init_video_model( | ||
| device, | ||
| patch_size=16, | ||
| num_frames=16, | ||
| tubelet_size=2, | ||
| model_name='vit_base', | ||
| crop_size=224, | ||
| pred_depth=6, | ||
| pred_embed_dim=384, | ||
| uniform_power=False, | ||
| use_mask_tokens=False, | ||
| num_mask_tokens=2, | ||
| zero_init_mask_tokens=True, | ||
| use_sdpa=False, | ||
| ): | ||
| encoder = video_vit.__dict__[model_name]( | ||
| img_size=crop_size, | ||
| patch_size=patch_size, | ||
| num_frames=num_frames, | ||
| tubelet_size=tubelet_size, | ||
| uniform_power=uniform_power, | ||
| use_sdpa=use_sdpa, | ||
| ) | ||
| encoder = MultiMaskWrapper(encoder) | ||
| predictor = vit_pred.__dict__['vit_predictor']( | ||
| img_size=crop_size, | ||
| use_mask_tokens=use_mask_tokens, | ||
| patch_size=patch_size, | ||
| num_frames=num_frames, | ||
| tubelet_size=tubelet_size, | ||
| embed_dim=encoder.backbone.embed_dim, | ||
| predictor_embed_dim=pred_embed_dim, | ||
| depth=pred_depth, | ||
| num_heads=encoder.backbone.num_heads, | ||
| uniform_power=uniform_power, | ||
| num_mask_tokens=num_mask_tokens, | ||
| zero_init_mask_tokens=zero_init_mask_tokens, | ||
| use_sdpa=use_sdpa, | ||
| ) | ||
| predictor = PredictorMultiMaskWrapper(predictor) | ||
|
|
||
| def init_weights(m): | ||
| if isinstance(m, torch.nn.Linear): | ||
| trunc_normal_(m.weight, std=0.02) | ||
| if m.bias is not None: | ||
| torch.nn.init.constant_(m.bias, 0) | ||
| elif isinstance(m, torch.nn.LayerNorm): | ||
| torch.nn.init.constant_(m.bias, 0) | ||
| torch.nn.init.constant_(m.weight, 1.0) | ||
|
|
||
| for m in encoder.modules(): | ||
| init_weights(m) | ||
|
|
||
| for m in predictor.modules(): | ||
| init_weights(m) | ||
|
|
||
| encoder.to(device) | ||
| predictor.to(device) | ||
| logger.info(encoder) | ||
| logger.info(predictor) | ||
|
|
||
| def count_parameters(model): | ||
| return sum(p.numel() for p in model.parameters() if p.requires_grad) | ||
|
|
||
| logger.info(f'Encoder number of parameters: {count_parameters(encoder)}') | ||
| logger.info(f'Predictor number of parameters: {count_parameters(predictor)}') | ||
|
|
||
| return encoder, predictor | ||
|
|
||
|
|
||
| def init_opt( | ||
| encoder, | ||
| predictor, | ||
| iterations_per_epoch, | ||
| start_lr, | ||
| ref_lr, | ||
| warmup, | ||
| num_epochs, | ||
| wd=1e-6, | ||
| final_wd=1e-6, | ||
| final_lr=0.0, | ||
| mixed_precision=False, | ||
| ipe_scale=1.25, | ||
| betas=(0.9, 0.999), | ||
| eps=1e-8, | ||
| zero_init_bias_wd=True, | ||
| ): | ||
| param_groups = [ | ||
| { | ||
| 'params': (p for n, p in encoder.named_parameters() | ||
| if ('bias' not in n) and (len(p.shape) != 1)) | ||
| }, { | ||
| 'params': (p for n, p in predictor.named_parameters() | ||
| if ('bias' not in n) and (len(p.shape) != 1)) | ||
| }, { | ||
| 'params': (p for n, p in encoder.named_parameters() | ||
| if ('bias' in n) or (len(p.shape) == 1)), | ||
| 'WD_exclude': zero_init_bias_wd, | ||
| 'weight_decay': 0, | ||
| }, { | ||
| 'params': (p for n, p in predictor.named_parameters() | ||
| if ('bias' in n) or (len(p.shape) == 1)), | ||
| 'WD_exclude': zero_init_bias_wd, | ||
| 'weight_decay': 0, | ||
| }, | ||
| ] | ||
|
|
||
| logger.info('Using AdamW') | ||
| optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps) | ||
| scheduler = WarmupCosineSchedule( | ||
| optimizer, | ||
| warmup_steps=int(warmup * iterations_per_epoch), | ||
| start_lr=start_lr, | ||
| ref_lr=ref_lr, | ||
| final_lr=final_lr, | ||
| T_max=int(ipe_scale * num_epochs * iterations_per_epoch), | ||
| ) | ||
| wd_scheduler = CosineWDSchedule( | ||
| optimizer, | ||
| ref_wd=wd, | ||
| final_wd=final_wd, | ||
| T_max=int(ipe_scale * num_epochs * iterations_per_epoch), | ||
| ) | ||
| scaler = torch.cuda.amp.GradScaler() if mixed_precision else None | ||
| return optimizer, scaler, scheduler, wd_scheduler |
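A sketch of how the helpers above are typically wired together during pre-training setup. The hyper-parameters mirror the ViT-L/16 pre-training config that appears later in this diff (ipe 300, warmup 40, weight decay 0.04 to 0.4); treat them as illustrative, not prescriptive.

```python
# Wiring sketch for init_video_model + init_opt, using values from the ViT-L/16 config.
import torch

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

encoder, predictor = init_video_model(
    device=device,
    model_name='vit_large',
    crop_size=224,
    patch_size=16,
    num_frames=16,
    tubelet_size=2,
    pred_depth=12,
    pred_embed_dim=384,
    uniform_power=True,
    use_mask_tokens=True,
)

optimizer, scaler, scheduler, wd_scheduler = init_opt(
    encoder=encoder,
    predictor=predictor,
    iterations_per_epoch=300,
    start_lr=2.0e-4,
    ref_lr=6.25e-4,
    warmup=40,
    num_epochs=300,
    wd=0.04,
    final_wd=0.4,
)
```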
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: in1k-16f | ||
| eval_name: image_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ | ||
| image_folder: imagenet_full_size/061417/ | ||
| num_classes: 1000 | ||
| resolution: 384 | ||
| dataset_name: ImageNet | ||
| optimization: | ||
| num_epochs: 20 | ||
| batch_size: 16 | ||
| weight_decay: 0.001 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_huge | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_sdpa: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: inat-16f | ||
| eval_name: image_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ | ||
| image_folder: iNaturalist-2021/110421/ | ||
| num_classes: 10000 | ||
| resolution: 384 | ||
| dataset_name: iNat21 | ||
| optimization: | ||
| num_epochs: 20 | ||
| batch_size: 16 | ||
| weight_decay: 0.001 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_huge | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_sdpa: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: k400-16x8x3 | ||
| eval_name: video_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| dataset_train: /your_path_to_kinetics400_train_csv_file_index.csv | ||
| dataset_val: /your_path_to_kinetics400_val_csv_file_index.csv | ||
| dataset_type: VideoDataset | ||
| num_classes: 400 | ||
| frames_per_clip: 16 | ||
| num_segments: 8 | ||
| num_views_per_segment: 3 | ||
| frame_step: 4 | ||
| optimization: | ||
| attend_across_segments: true | ||
| num_epochs: 20 | ||
| resolution: 384 | ||
| batch_size: 4 | ||
| weight_decay: 0.01 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_huge | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| use_sdpa: true | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: places-16f | ||
| eval_name: image_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ | ||
| image_folder: places205/121517/pytorch/ | ||
| num_classes: 205 | ||
| resolution: 384 | ||
| dataset_name: Places205 | ||
| optimization: | ||
| num_epochs: 20 | ||
| batch_size: 16 | ||
| weight_decay: 0.001 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_huge | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_sdpa: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: ssv2-16x2x3 | ||
| eval_name: video_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| dataset_train: /your_path_to_ssv2_train_csv_file_index.csv | ||
| dataset_val: /your_path_to_ssv2_val_csv_file_index.csv | ||
| dataset_type: VideoDataset | ||
| num_classes: 174 | ||
| frames_per_clip: 16 | ||
| num_segments: 2 | ||
| num_views_per_segment: 3 | ||
| frame_step: 4 | ||
| optimization: | ||
| attend_across_segments: true | ||
| num_epochs: 20 | ||
| resolution: 384 | ||
| batch_size: 4 | ||
| weight_decay: 0.01 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_huge | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| use_sdpa: true | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: in1k-16f | ||
| eval_name: image_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ | ||
| image_folder: imagenet_full_size/061417/ | ||
| num_classes: 1000 | ||
| resolution: 224 | ||
| dataset_name: ImageNet | ||
| optimization: | ||
| num_epochs: 20 | ||
| batch_size: 16 | ||
| weight_decay: 0.001 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_huge | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_sdpa: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: inat-16f | ||
| eval_name: image_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ | ||
| image_folder: iNaturalist-2021/110421/ | ||
| num_classes: 10000 | ||
| resolution: 224 | ||
| dataset_name: iNat21 | ||
| optimization: | ||
| num_epochs: 20 | ||
| batch_size: 16 | ||
| weight_decay: 0.001 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_huge | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_sdpa: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: k400-16x8x3 | ||
| eval_name: video_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| dataset_train: /your_path_to_kinetics400_train_csv_file_index.csv | ||
| dataset_val: /your_path_to_kinetics400_val_csv_file_index.csv | ||
| dataset_type: VideoDataset | ||
| num_classes: 400 | ||
| frames_per_clip: 16 | ||
| num_segments: 8 | ||
| num_views_per_segment: 3 | ||
| frame_step: 4 | ||
| optimization: | ||
| attend_across_segments: true | ||
| num_epochs: 20 | ||
| resolution: 224 | ||
| batch_size: 4 | ||
| weight_decay: 0.01 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_huge | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| use_sdpa: true | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: places-16f | ||
| eval_name: image_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ | ||
| image_folder: places205/121517/pytorch/ | ||
| num_classes: 205 | ||
| resolution: 224 | ||
| dataset_name: Places205 | ||
| optimization: | ||
| num_epochs: 20 | ||
| batch_size: 16 | ||
| weight_decay: 0.001 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_huge | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_sdpa: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: ssv2-16x2x3 | ||
| eval_name: video_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| dataset_train: /your_path_to_ssv2_train_csv_file_index.csv | ||
| dataset_val: /your_path_to_ssv2_val_csv_file_index.csv | ||
| dataset_type: VideoDataset | ||
| num_classes: 174 | ||
| frames_per_clip: 16 | ||
| num_segments: 2 | ||
| num_views_per_segment: 3 | ||
| frame_step: 4 | ||
| optimization: | ||
| attend_across_segments: true | ||
| num_epochs: 20 | ||
| resolution: 224 | ||
| batch_size: 4 | ||
| weight_decay: 0.01 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_huge | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| use_sdpa: true | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: in1k-16f | ||
| eval_name: image_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ | ||
| image_folder: imagenet_full_size/061417/ | ||
| num_classes: 1000 | ||
| resolution: 224 | ||
| dataset_name: ImageNet | ||
| optimization: | ||
| num_epochs: 20 | ||
| batch_size: 16 | ||
| weight_decay: 0.001 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_large | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_sdpa: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: inat-16f | ||
| eval_name: image_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ | ||
| image_folder: iNaturalist-2021/110421/ | ||
| num_classes: 10000 | ||
| resolution: 224 | ||
| dataset_name: iNat21 | ||
| optimization: | ||
| num_epochs: 20 | ||
| batch_size: 16 | ||
| weight_decay: 0.001 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_large | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_sdpa: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: k400-16x8x3 | ||
| eval_name: video_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| dataset_train: /your_path_to_kinetics400_train_csv_file_index.csv | ||
| dataset_val: /your_path_to_kinetics400_val_csv_file_index.csv | ||
| dataset_type: VideoDataset | ||
| num_classes: 400 | ||
| frames_per_clip: 16 | ||
| num_segments: 8 | ||
| num_views_per_segment: 3 | ||
| frame_step: 4 | ||
| optimization: | ||
| attend_across_segments: true | ||
| num_epochs: 20 | ||
| resolution: 224 | ||
| batch_size: 4 | ||
| weight_decay: 0.01 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_large | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| use_sdpa: true | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: places-16f | ||
| eval_name: image_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/ | ||
| image_folder: places205/121517/pytorch/ | ||
| num_classes: 205 | ||
| resolution: 224 | ||
| dataset_name: Places205 | ||
| optimization: | ||
| num_epochs: 20 | ||
| batch_size: 16 | ||
| weight_decay: 0.001 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_large | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_sdpa: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| nodes: 8 | ||
| tasks_per_node: 8 | ||
| tag: ssv2-16x2x3 | ||
| eval_name: video_classification_frozen | ||
| resume_checkpoint: false | ||
| data: | ||
| dataset_train: /your_path_to_ssv2_train_csv_file_index.csv | ||
| dataset_val: /your_path_to_ssv2_val_csv_file_index.csv | ||
| dataset_type: VideoDataset | ||
| num_classes: 174 | ||
| frames_per_clip: 16 | ||
| num_segments: 2 | ||
| num_views_per_segment: 3 | ||
| frame_step: 4 | ||
| optimization: | ||
| attend_across_segments: true | ||
| num_epochs: 20 | ||
| resolution: 224 | ||
| batch_size: 4 | ||
| weight_decay: 0.01 | ||
| lr: 0.001 | ||
| start_lr: 0.001 | ||
| final_lr: 0.0 | ||
| warmup: 0. | ||
| use_bfloat16: true | ||
| pretrain: | ||
| model_name: vit_large | ||
| checkpoint_key: target_encoder | ||
| clip_duration: null | ||
| frames_per_clip: 16 | ||
| tubelet_size: 2 | ||
| uniform_power: true | ||
| use_silu: false | ||
| tight_silu: false | ||
| use_sdpa: true | ||
| patch_size: 16 | ||
| folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/ | ||
| checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder | ||
| write_tag: jepa |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| app: vjepa | ||
| nodes: 16 | ||
| tasks_per_node: 8 | ||
| data: | ||
| dataset_type: VideoDataset | ||
| datasets: | ||
| - /your_path_to_kinetics710_csv_file_index.csv | ||
| - /your_path_to_ssv2_csv_file_index.csv | ||
| - /your_path_to_howto100m_csv_file_index.csv | ||
| decode_one_clip: true | ||
| batch_size: 24 | ||
| num_clips: 1 | ||
| num_frames: 16 | ||
| tubelet_size: 2 | ||
| sampling_rate: 4 | ||
| crop_size: 224 | ||
| patch_size: 16 | ||
| pin_mem: true | ||
| num_workers: 12 | ||
| filter_short_videos: false | ||
| clip_duration: null | ||
| data_aug: | ||
| auto_augment: false | ||
| motion_shift: false | ||
| random_resize_aspect_ratio: | ||
| - 0.75 | ||
| - 1.35 | ||
| random_resize_scale: | ||
| - 0.3 | ||
| - 1.0 | ||
| reprob: 0.0 | ||
| logging: | ||
| folder: /your_absolute_file_path_for_saving_logs_and_checkpoints/ | ||
| write_tag: jepa | ||
| loss: | ||
| loss_exp: 1.0 | ||
| reg_coeff: 0.0 | ||
| mask: | ||
| - aspect_ratio: | ||
| - 0.75 | ||
| - 1.5 | ||
| num_blocks: 8 | ||
| spatial_scale: | ||
| - 0.15 | ||
| - 0.15 | ||
| temporal_scale: | ||
| - 1.0 | ||
| - 1.0 | ||
| max_temporal_keep: 1.0 | ||
| max_keep: null | ||
| - aspect_ratio: | ||
| - 0.75 | ||
| - 1.5 | ||
| num_blocks: 2 | ||
| spatial_scale: | ||
| - 0.7 | ||
| - 0.7 | ||
| temporal_scale: | ||
| - 1.0 | ||
| - 1.0 | ||
| max_temporal_keep: 1.0 | ||
| max_keep: null | ||
| meta: | ||
| load_checkpoint: false | ||
| read_checkpoint: null | ||
| seed: 234 | ||
| eval_freq: 100 | ||
| use_sdpa: true | ||
| dtype: bfloat16 | ||
| model: | ||
| model_name: vit_huge | ||
| pred_depth: 12 | ||
| pred_embed_dim: 384 | ||
| uniform_power: true | ||
| use_mask_tokens: true | ||
| zero_init_mask_tokens: true | ||
| optimization: | ||
| ipe: 300 | ||
| ipe_scale: 1.25 | ||
| clip_grad: 10.0 | ||
| weight_decay: 0.04 | ||
| final_weight_decay: 0.4 | ||
| epochs: 300 | ||
| warmup: 40 | ||
| start_lr: 0.0002 | ||
| lr: 0.000625 | ||
| final_lr: 1.0e-06 | ||
| ema: | ||
| - 0.998 | ||
| - 1.0 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| app: vjepa | ||
| nodes: 30 | ||
| tasks_per_node: 8 | ||
| data: | ||
| dataset_type: VideoDataset | ||
| datasets: | ||
| - /your_path_to_kinetics710_csv_file_index.csv | ||
| - /your_path_to_ssv2_csv_file_index.csv | ||
| - /your_path_to_howto100m_csv_file_index.csv | ||
| decode_one_clip: true | ||
| batch_size: 10 | ||
| num_clips: 1 | ||
| num_frames: 16 | ||
| tubelet_size: 2 | ||
| sampling_rate: 4 | ||
| crop_size: 384 | ||
| patch_size: 16 | ||
| pin_mem: true | ||
| num_workers: 12 | ||
| filter_short_videos: false | ||
| clip_duration: null | ||
| data_aug: | ||
| auto_augment: false | ||
| motion_shift: false | ||
| random_resize_aspect_ratio: | ||
| - 0.75 | ||
| - 1.35 | ||
| random_resize_scale: | ||
| - 0.3 | ||
| - 1.0 | ||
| reprob: 0.0 | ||
| logging: | ||
| folder: /your_absolute_file_path_for_saving_logs_and_checkpoints/ | ||
| write_tag: jepa | ||
| loss: | ||
| loss_exp: 1.0 | ||
| reg_coeff: 0.0 | ||
| mask: | ||
| - aspect_ratio: | ||
| - 0.75 | ||
| - 1.5 | ||
| num_blocks: 8 | ||
| spatial_scale: | ||
| - 0.15 | ||
| - 0.15 | ||
| temporal_scale: | ||
| - 1.0 | ||
| - 1.0 | ||
| max_temporal_keep: 1.0 | ||
| max_keep: null | ||
| - aspect_ratio: | ||
| - 0.75 | ||
| - 1.5 | ||
| num_blocks: 2 | ||
| spatial_scale: | ||
| - 0.7 | ||
| - 0.7 | ||
| temporal_scale: | ||
| - 1.0 | ||
| - 1.0 | ||
| max_temporal_keep: 1.0 | ||
| max_keep: null | ||
| meta: | ||
| load_checkpoint: false | ||
| read_checkpoint: null | ||
| seed: 234 | ||
| eval_freq: 100 | ||
| use_sdpa: true | ||
| dtype: bfloat16 | ||
| model: | ||
| model_name: vit_huge | ||
| pred_depth: 12 | ||
| pred_embed_dim: 384 | ||
| uniform_power: true | ||
| use_mask_tokens: true | ||
| zero_init_mask_tokens: true | ||
| optimization: | ||
| ipe: 300 | ||
| ipe_scale: 1.25 | ||
| clip_grad: 10.0 | ||
| weight_decay: 0.04 | ||
| final_weight_decay: 0.4 | ||
| epochs: 300 | ||
| warmup: 40 | ||
| start_lr: 0.0002 | ||
| lr: 0.000625 | ||
| final_lr: 1.0e-06 | ||
| ema: | ||
| - 0.998 | ||
| - 1.0 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| app: vjepa | ||
| nodes: 16 | ||
| tasks_per_node: 8 | ||
| data: | ||
| dataset_type: VideoDataset | ||
| datasets: | ||
| - /your_path_to_kinetics710_csv_file_index.csv | ||
| - /your_path_to_ssv2_csv_file_index.csv | ||
| - /your_path_to_howto100m_csv_file_index.csv | ||
| decode_one_clip: true | ||
| batch_size: 24 | ||
| num_clips: 1 | ||
| num_frames: 16 | ||
| tubelet_size: 2 | ||
| sampling_rate: 4 | ||
| crop_size: 224 | ||
| patch_size: 16 | ||
| pin_mem: true | ||
| num_workers: 12 | ||
| filter_short_videos: false | ||
| clip_duration: null | ||
| data_aug: | ||
| auto_augment: false | ||
| motion_shift: false | ||
| random_resize_aspect_ratio: | ||
| - 0.75 | ||
| - 1.35 | ||
| random_resize_scale: | ||
| - 0.3 | ||
| - 1.0 | ||
| reprob: 0.0 | ||
| logging: | ||
| folder: /your_absolute_file_path_for_saving_logs_and_checkpoints/ | ||
| write_tag: jepa | ||
| loss: | ||
| loss_exp: 1.0 | ||
| reg_coeff: 0.0 | ||
| mask: | ||
| - aspect_ratio: | ||
| - 0.75 | ||
| - 1.5 | ||
| num_blocks: 8 | ||
| spatial_scale: | ||
| - 0.15 | ||
| - 0.15 | ||
| temporal_scale: | ||
| - 1.0 | ||
| - 1.0 | ||
| max_temporal_keep: 1.0 | ||
| max_keep: null | ||
| - aspect_ratio: | ||
| - 0.75 | ||
| - 1.5 | ||
| num_blocks: 2 | ||
| spatial_scale: | ||
| - 0.7 | ||
| - 0.7 | ||
| temporal_scale: | ||
| - 1.0 | ||
| - 1.0 | ||
| max_temporal_keep: 1.0 | ||
| max_keep: null | ||
| meta: | ||
| load_checkpoint: false | ||
| read_checkpoint: null | ||
| seed: 234 | ||
| eval_freq: 100 | ||
| use_sdpa: true | ||
| dtype: bfloat16 | ||
| model: | ||
| model_name: vit_large | ||
| pred_depth: 12 | ||
| pred_embed_dim: 384 | ||
| uniform_power: true | ||
| use_mask_tokens: true | ||
| zero_init_mask_tokens: true | ||
| optimization: | ||
| ipe: 300 | ||
| ipe_scale: 1.25 | ||
| clip_grad: 10.0 | ||
| weight_decay: 0.04 | ||
| final_weight_decay: 0.4 | ||
| epochs: 300 | ||
| warmup: 40 | ||
| start_lr: 0.0002 | ||
| lr: 0.000625 | ||
| final_lr: 1.0e-06 | ||
| ema: | ||
| - 0.998 | ||
| - 1.0 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import argparse | ||
|
|
||
| import multiprocessing as mp | ||
|
|
||
| import pprint | ||
| import yaml | ||
|
|
||
| from src.utils.distributed import init_distributed | ||
|
|
||
| from evals.scaffold import main as eval_main | ||
|
|
||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument( | ||
| '--fname', type=str, | ||
| help='name of config file to load', | ||
| default='configs.yaml') | ||
| parser.add_argument( | ||
| '--devices', type=str, nargs='+', default=['cuda:0'], | ||
| help='which devices to use on local machine') | ||
|
|
||
|
|
||
| def process_main(rank, fname, world_size, devices): | ||
| import os | ||
| os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1]) | ||
|
|
||
| import logging | ||
| logging.basicConfig() | ||
| logger = logging.getLogger() | ||
| if rank == 0: | ||
| logger.setLevel(logging.INFO) | ||
| else: | ||
| logger.setLevel(logging.ERROR) | ||
|
|
||
| logger.info(f'called-params {fname}') | ||
|
|
||
| # Load config | ||
| params = None | ||
| with open(fname, 'r') as y_file: | ||
| params = yaml.load(y_file, Loader=yaml.FullLoader) | ||
| logger.info('loaded params...') | ||
| pp = pprint.PrettyPrinter(indent=4) | ||
| pp.pprint(params) | ||
|
|
||
| # Init distributed (access to communication between GPUs on the same machine) | ||
| world_size, rank = init_distributed(rank_and_world_size=(rank, world_size)) | ||
| logger.info(f'Running... (rank: {rank}/{world_size})') | ||
|
|
||
| # Launch the eval with loaded config | ||
| eval_main(params['eval_name'], args_eval=params) | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| args = parser.parse_args() | ||
| num_gpus = len(args.devices) | ||
| mp.set_start_method('spawn') | ||
| for rank in range(num_gpus): | ||
| mp.Process( | ||
| target=process_main, | ||
| args=(rank, args.fname, num_gpus, args.devices) | ||
| ).start() |
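As with the pre-training launcher, a single-GPU debug sketch that bypasses the `mp.Process` spawn above. The module path `evals.main` and the config path are assumptions, since file names are not shown in this diff.

```python
# Single-process debug sketch for one of the eval configs above (paths illustrative).
from evals.main import process_main  # hypothetical import path

process_main(
    rank=0,
    fname='configs/evals/vitl16_in1k.yaml',  # illustrative path to an eval YAML
    world_size=1,
    devices=['cuda:0'],
)
```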
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,162 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import argparse | ||
| import logging | ||
| import os | ||
| import pprint | ||
| import sys | ||
| import time | ||
| import yaml | ||
|
|
||
| import submitit | ||
|
|
||
| from evals.scaffold import main as eval_main | ||
|
|
||
| logging.basicConfig(stream=sys.stdout, level=logging.INFO) | ||
| logger = logging.getLogger() | ||
|
|
||
| parser = argparse.ArgumentParser() | ||
| parser.add_argument( | ||
| '--folder', type=str, | ||
| help='location to save submitit logs', | ||
| default='/fsx-jepa/massran/submitit/') | ||
| parser.add_argument( | ||
| '--exclude', type=str, | ||
| help='nodes to exclude from training', | ||
| default=None) | ||
| parser.add_argument( | ||
| '--batch-launch', action='store_true', | ||
| help='whether fname points to a file to batch-launch several config files') | ||
| parser.add_argument( | ||
| '--fname', type=str, | ||
| help='yaml file containing config file names to launch', | ||
| default='configs.yaml') | ||
| parser.add_argument( | ||
| '--partition', type=str, | ||
| help='cluster partition to submit jobs on') | ||
| parser.add_argument( | ||
| '--time', type=int, default=4300, | ||
| help='time in minutes to run job') | ||
|
|
||
|
|
||
| class Trainer: | ||
|
|
||
| def __init__(self, args_eval=None, resume_preempt=None): | ||
| self.eval_name = args_eval['eval_name'] | ||
| self.args_eval = args_eval | ||
| self.resume_preempt = resume_preempt | ||
|
|
||
| def __call__(self): | ||
| eval_name = self.eval_name | ||
| args_eval = self.args_eval | ||
| resume_preempt = self.resume_preempt | ||
|
|
||
| logger.info('loaded eval params...') | ||
| pp = pprint.PrettyPrinter(indent=4) | ||
| pp.pprint(args_eval) | ||
|
|
||
| eval_main( | ||
| eval_name, | ||
| args_eval=args_eval, | ||
| resume_preempt=resume_preempt) | ||
|
|
||
| def checkpoint(self): | ||
| fb_trainer = Trainer(self.args_eval, True) | ||
| return submitit.helpers.DelayedSubmission(fb_trainer,) | ||
|
|
||
|
|
||
| def launch_evals_with_parsed_args( | ||
| args_for_evals, | ||
| submitit_folder, | ||
| partition='learnlab,learnfair', | ||
| timeout=4300, | ||
| nodes=1, | ||
| tasks_per_node=1, | ||
| delay_seconds=10, | ||
| exclude_nodes=None | ||
| ): | ||
| if not isinstance(args_for_evals, list): | ||
| logger.info(f'Passed in eval-args of type {type(args_for_evals)}') | ||
| args_for_evals = [args_for_evals] | ||
|
|
||
| time.sleep(delay_seconds) | ||
| logger.info('Launching evaluations in separate jobs...') | ||
| executor = submitit.AutoExecutor( | ||
| folder=os.path.join(submitit_folder, 'job_%j'), | ||
| slurm_max_num_timeout=20) | ||
| executor.update_parameters( | ||
| slurm_partition=partition, | ||
| slurm_mem_per_gpu='55G', | ||
| timeout_min=timeout, | ||
| nodes=nodes, | ||
| tasks_per_node=tasks_per_node, | ||
| cpus_per_task=12, | ||
| gpus_per_node=tasks_per_node) | ||
|
|
||
| if exclude_nodes is not None: | ||
| executor.update_parameters(slurm_exclude=exclude_nodes) | ||
|
|
||
| jobs, trainers = [], [] | ||
| with executor.batch(): | ||
| for ae in args_for_evals: | ||
| fb_trainer = Trainer(ae) | ||
| job = executor.submit(fb_trainer,) | ||
| trainers.append(fb_trainer) | ||
| jobs.append(job) | ||
|
|
||
| for job in jobs: | ||
| logger.info(f'Launched eval job with id {job.job_id}') | ||
|
|
||
|
|
||
| def launch_evals(): | ||
|
|
||
| # ---------------------------------------------------------------------- # | ||
| # 1. Put config file names in a list | ||
| # ---------------------------------------------------------------------- # | ||
| config_fnames = [args.fname] | ||
|
|
||
| # -- If batch-launch is True, then the args.fname yaml file is not a | ||
| # -- config, but actually specifies a list of other config files | ||
| # -- to run in a slurm job array | ||
| if args.batch_launch: | ||
| with open(args.fname, 'r') as y_file: | ||
| config_fnames = yaml.load(y_file, Loader=yaml.FullLoader) | ||
| # ---------------------------------------------------------------------- # | ||
|
|
||
| # ---------------------------------------------------------------------- # | ||
| # 2. Parse each yaml config file as a dict and place in list | ||
| # ---------------------------------------------------------------------- # | ||
| nodes, tasks_per_node = None, None | ||
| configs = [] | ||
| for f in config_fnames: | ||
| with open(f, 'r') as y_file: | ||
| _params = yaml.load(y_file, Loader=yaml.FullLoader) | ||
| nodes = int(_params.get('nodes')) | ||
| tasks_per_node = int(_params.get('tasks_per_node')) | ||
| configs += [_params] | ||
| logger.info(f'Loaded {len(configs)} config files') | ||
| logger.info(f'Running all jobs with {nodes=} / {tasks_per_node=}') | ||
| # ---------------------------------------------------------------------- # | ||
|
|
||
| # ---------------------------------------------------------------------- # | ||
| # 3. Launch evals with parsed config files | ||
| # ---------------------------------------------------------------------- # | ||
| launch_evals_with_parsed_args( | ||
| args_for_evals=configs, | ||
| submitit_folder=args.folder, | ||
| partition=args.partition, | ||
| timeout=args.time, | ||
| nodes=nodes, | ||
| tasks_per_node=tasks_per_node, | ||
| exclude_nodes=args.exclude) | ||
| # ---------------------------------------------------------------------- # | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| args = parser.parse_args() | ||
| launch_evals() |
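A sketch of preparing a batch-launch file for `--batch-launch` mode: in that mode `launch_evals()` treats `--fname` as a YAML list of eval config paths rather than a single config. The file names and module path below are assumptions.

```python
# Build a YAML list of eval configs for --batch-launch (paths illustrative).
import yaml

eval_config_paths = [
    'configs/evals/vitl16_in1k.yaml',
    'configs/evals/vitl16_k400.yaml',
]
with open('batch_evals.yaml', 'w') as f:
    yaml.dump(eval_config_paths, f)

# Then (assumed module path, not shown in the diff):
#   python -m evals.main_distributed --fname batch_evals.yaml --batch-launch --partition <partition>
```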
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import importlib | ||
| import logging | ||
| import sys | ||
|
|
||
| logging.basicConfig(stream=sys.stdout, level=logging.INFO) | ||
| logger = logging.getLogger() | ||
|
|
||
|
|
||
| def main( | ||
| eval_name, | ||
| args_eval, | ||
| resume_preempt=False | ||
| ): | ||
| logger.info(f'Running evaluation: {eval_name}') | ||
| return importlib.import_module(f'evals.{eval_name}.eval').main( | ||
| args_eval=args_eval, | ||
| resume_preempt=resume_preempt) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,343 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import numpy as np | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
| import torchvision.transforms as transforms | ||
|
|
||
| import src.datasets.utils.video.transforms as video_transforms | ||
| import src.datasets.utils.video.volume_transforms as volume_transforms | ||
|
|
||
| from src.datasets.utils.video.randerase import RandomErasing | ||
|
|
||
| from src.models.utils.pos_embs import get_1d_sincos_pos_embed | ||
| from src.masks.utils import apply_masks | ||
|
|
||
|
|
||
| class FrameAggregation(nn.Module): | ||
| """ | ||
| Process each frame independently and concatenate all tokens | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| model, | ||
| max_frames=10000, | ||
| use_pos_embed=False, | ||
| attend_across_segments=False | ||
| ): | ||
| super().__init__() | ||
| self.model = model | ||
| self.embed_dim = embed_dim = model.embed_dim | ||
| self.num_heads = model.num_heads | ||
| self.attend_across_segments = attend_across_segments | ||
| # 1D-temporal pos-embedding | ||
| self.pos_embed = None | ||
| if use_pos_embed: | ||
| self.pos_embed = nn.Parameter( | ||
| torch.zeros(1, max_frames, embed_dim), | ||
| requires_grad=False) | ||
| sincos = get_1d_sincos_pos_embed(embed_dim, max_frames) | ||
| self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) | ||
|
|
||
| def forward(self, x, clip_indices=None): | ||
|
|
||
| # TODO: implement attend_across_segments=False | ||
| # num_clips = len(x) | ||
| num_views_per_clip = len(x[0]) | ||
|
|
||
| # Concatenate views along batch dimension | ||
| x = [torch.cat(xi, dim=0) for xi in x] | ||
| # Concatenate clips along temporal dimension | ||
| x = torch.cat(x, dim=2) | ||
| B, C, T, H, W = x.size() | ||
|
|
||
| # Put each frame along the batch dimension | ||
| x = x.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W) | ||
|
|
||
| outputs = self.model(x) | ||
| _, N, D = outputs.size() | ||
| outputs = outputs.reshape(B, T, N, D).flatten(1, 2) | ||
|
|
||
| # Separate views into list | ||
| B = B // num_views_per_clip | ||
| all_outputs = [] | ||
| for i in range(num_views_per_clip): | ||
| o = outputs[i*B:(i+1)*B] | ||
| # Compute positional embedding | ||
| if (self.pos_embed is not None) and (clip_indices is not None): | ||
| pos_embed = self.pos_embed.repeat(B, 1, 1) # [B, F, D] | ||
| pos_embed = apply_masks(pos_embed, clip_indices, concat=False) # list(Tensor([B, T, D])) | ||
| pos_embed = torch.cat(pos_embed, dim=1) # concatenate along temporal dimension | ||
| pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1) # [B, T*num_clips, N, D] | ||
| pos_embed = pos_embed.flatten(1, 2) | ||
| o += pos_embed | ||
| all_outputs += [o] | ||
|
|
||
| return all_outputs | ||
|
|
||
|
|
||
| class ClipAggregation(nn.Module): | ||
| """ | ||
| Process each clip independently and concatenate all tokens | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| model, | ||
| tubelet_size=2, | ||
| max_frames=10000, | ||
| use_pos_embed=False, | ||
| attend_across_segments=False | ||
| ): | ||
| super().__init__() | ||
| self.model = model | ||
| self.tubelet_size = tubelet_size | ||
| self.embed_dim = embed_dim = model.embed_dim | ||
| self.num_heads = model.num_heads | ||
| self.attend_across_segments = attend_across_segments | ||
| # 1D-temporal pos-embedding | ||
| self.pos_embed = None | ||
| if use_pos_embed: | ||
| max_T = max_frames // tubelet_size | ||
| self.pos_embed = nn.Parameter( | ||
| torch.zeros(1, max_T, embed_dim), | ||
| requires_grad=False) | ||
| sincos = get_1d_sincos_pos_embed(embed_dim, max_T) | ||
| self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0)) | ||
|
|
||
| def forward(self, x, clip_indices=None): | ||
|
|
||
| num_clips = len(x) | ||
| num_views_per_clip = len(x[0]) | ||
| B, C, T, H, W = x[0][0].size() | ||
|
|
||
| # Concatenate all spatial and temporal views along batch dimension | ||
| x = [torch.cat(xi, dim=0) for xi in x] | ||
| x = torch.cat(x, dim=0) | ||
| outputs = self.model(x) | ||
| _, N, D = outputs.size() | ||
|
|
||
| T = T // self.tubelet_size # Num temporal tokens | ||
| N = N // T # Num spatial tokens | ||
|
|
||
| # Unroll outputs into a 2D array [spatial_views x temporal_views] | ||
| eff_B = B * num_views_per_clip | ||
| all_outputs = [[] for _ in range(num_views_per_clip)] | ||
| for i in range(num_clips): | ||
| o = outputs[i*eff_B:(i+1)*eff_B] | ||
| for j in range(num_views_per_clip): | ||
| all_outputs[j].append(o[j*B:(j+1)*B]) | ||
|
|
||
| if not self.attend_across_segments: | ||
| return all_outputs | ||
|
|
||
| for i, outputs in enumerate(all_outputs): | ||
|
|
||
| # Concatenate along temporal dimension | ||
| outputs = [o.reshape(B, T, N, D) for o in outputs] | ||
| outputs = torch.cat(outputs, dim=1).flatten(1, 2) | ||
|
|
||
| # Compute positional embedding | ||
| if (self.pos_embed is not None) and (clip_indices is not None): | ||
| clip_indices = [c[:, ::self.tubelet_size] for c in clip_indices] | ||
| pos_embed = self.pos_embed.repeat(B, 1, 1) # [B, F, D] | ||
| pos_embed = apply_masks(pos_embed, clip_indices, concat=False) # list(Tensor([B, T, D])) | ||
| pos_embed = torch.cat(pos_embed, dim=1) # concatenate along temporal dimension | ||
| pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1) # [B, T*num_clips, N, D] | ||
| pos_embed = pos_embed.flatten(1, 2) | ||
| outputs += pos_embed | ||
|
|
||
| all_outputs[i] = outputs | ||
|
|
||
| return all_outputs | ||
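| # Shapes (following the unroll above): without attend_across_segments the return | ||
| # value is a [view][clip] nested list of per-clip token tensors [B, T*N, D]; with | ||
| # it, each view's clips are concatenated along time into one [B, num_clips*T*N, D] tensor. | ||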
|
|
||
|
|
||
| def make_transforms( | ||
| training=True, | ||
| random_horizontal_flip=True, | ||
| random_resize_aspect_ratio=(3/4, 4/3), | ||
| random_resize_scale=(0.3, 1.0), | ||
| reprob=0.0, | ||
| auto_augment=False, | ||
| motion_shift=False, | ||
| crop_size=224, | ||
| num_views_per_clip=1, | ||
| normalize=((0.485, 0.456, 0.406), | ||
| (0.229, 0.224, 0.225)) | ||
| ): | ||
|
|
||
| if not training and num_views_per_clip > 1: | ||
| print('Making EvalVideoTransform, multi-view') | ||
| _frames_augmentation = EvalVideoTransform( | ||
| num_views_per_clip=num_views_per_clip, | ||
| short_side_size=crop_size, | ||
| normalize=normalize, | ||
| ) | ||
|
|
||
| else: | ||
| _frames_augmentation = VideoTransform( | ||
| training=training, | ||
| random_horizontal_flip=random_horizontal_flip, | ||
| random_resize_aspect_ratio=random_resize_aspect_ratio, | ||
| random_resize_scale=random_resize_scale, | ||
| reprob=reprob, | ||
| auto_augment=auto_augment, | ||
| motion_shift=motion_shift, | ||
| crop_size=crop_size, | ||
| normalize=normalize, | ||
| ) | ||
| return _frames_augmentation | ||
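| # Example (illustrative values): building train-time and multi-view eval transforms. | ||
| #   train_transform = make_transforms(training=True, reprob=0.25, auto_augment=True) | ||
| #   eval_transform = make_transforms(training=False, num_views_per_clip=3, crop_size=224) | ||
| # Both return a callable mapping a buffer of frames to a list of clip tensors. | ||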
|
|
||
|
|
||
| class VideoTransform(object): | ||
|
|
||
| def __init__( | ||
| self, | ||
| training=True, | ||
| random_horizontal_flip=True, | ||
| random_resize_aspect_ratio=(3/4, 4/3), | ||
| random_resize_scale=(0.3, 1.0), | ||
| reprob=0.0, | ||
| auto_augment=False, | ||
| motion_shift=False, | ||
| crop_size=224, | ||
| normalize=((0.485, 0.456, 0.406), | ||
| (0.229, 0.224, 0.225)) | ||
| ): | ||
|
|
||
| self.training = training | ||
|
|
||
| short_side_size = int(crop_size * 256 / 224) | ||
| self.eval_transform = video_transforms.Compose([ | ||
| video_transforms.Resize(short_side_size, interpolation='bilinear'), | ||
| video_transforms.CenterCrop(size=(crop_size, crop_size)), | ||
| volume_transforms.ClipToTensor(), | ||
| video_transforms.Normalize(mean=normalize[0], std=normalize[1]) | ||
| ]) | ||
|
|
||
| self.random_horizontal_flip = random_horizontal_flip | ||
| self.random_resize_aspect_ratio = random_resize_aspect_ratio | ||
| self.random_resize_scale = random_resize_scale | ||
| self.auto_augment = auto_augment | ||
| self.motion_shift = motion_shift | ||
| self.crop_size = crop_size | ||
| self.normalize = torch.tensor(normalize) | ||
|
|
||
| self.autoaug_transform = video_transforms.create_random_augment( | ||
| input_size=(crop_size, crop_size), | ||
| auto_augment='rand-m7-n4-mstd0.5-inc1', | ||
| interpolation='bicubic', | ||
| ) | ||
|
|
||
| self.spatial_transform = video_transforms.random_resized_crop_with_shift \ | ||
| if motion_shift else video_transforms.random_resized_crop | ||
|
|
||
| self.reprob = reprob | ||
| self.erase_transform = RandomErasing( | ||
| reprob, | ||
| mode='pixel', | ||
| max_count=1, | ||
| num_splits=1, | ||
| device='cpu', | ||
| ) | ||
|
|
||
| def __call__(self, buffer): | ||
|
|
||
| if not self.training: | ||
| return [self.eval_transform(buffer)] | ||
|
|
||
| buffer = [transforms.ToPILImage()(frame) for frame in buffer] | ||
|
|
||
| if self.auto_augment: | ||
| buffer = self.autoaug_transform(buffer) | ||
|
|
||
| buffer = [transforms.ToTensor()(img) for img in buffer] | ||
| buffer = torch.stack(buffer) # T C H W | ||
| buffer = buffer.permute(0, 2, 3, 1) # T H W C | ||
|
|
||
| buffer = tensor_normalize(buffer, self.normalize[0], self.normalize[1]) | ||
| buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W | ||
|
|
||
| buffer = self.spatial_transform( | ||
| images=buffer, | ||
| target_height=self.crop_size, | ||
| target_width=self.crop_size, | ||
| scale=self.random_resize_scale, | ||
| ratio=self.random_resize_aspect_ratio, | ||
| ) | ||
| if self.random_horizontal_flip: | ||
| buffer, _ = video_transforms.horizontal_flip(0.5, buffer) | ||
|
|
||
| if self.reprob > 0: | ||
| buffer = buffer.permute(1, 0, 2, 3) | ||
| buffer = self.erase_transform(buffer) | ||
| buffer = buffer.permute(1, 0, 2, 3) | ||
|
|
||
| return [buffer] | ||
|
|
||
|
|
||
| class EvalVideoTransform(object): | ||
|
|
||
| def __init__( | ||
| self, | ||
| num_views_per_clip=1, | ||
| short_side_size=224, | ||
| normalize=((0.485, 0.456, 0.406), | ||
| (0.229, 0.224, 0.225)) | ||
| ): | ||
| self.views_per_clip = num_views_per_clip | ||
| self.short_side_size = short_side_size | ||
| self.spatial_resize = video_transforms.Resize(short_side_size, interpolation='bilinear') | ||
| self.to_tensor = video_transforms.Compose([ | ||
| volume_transforms.ClipToTensor(), | ||
| video_transforms.Normalize(mean=normalize[0], std=normalize[1]) | ||
| ]) | ||
|
|
||
| def __call__(self, buffer): | ||
|
|
||
| # Sample several spatial views of each clip | ||
| buffer = np.array(self.spatial_resize(buffer)) | ||
| T, H, W, C = buffer.shape | ||
|
|
||
| num_views = self.views_per_clip | ||
| side_len = self.short_side_size | ||
| spatial_step = (max(H, W) - side_len) // (num_views - 1) | ||
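| # note: num_views >= 2 is assumed here (make_transforms only builds this | ||
| # transform when num_views_per_clip > 1); otherwise the step above divides by zero | ||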
|
|
||
| all_views = [] | ||
| for i in range(num_views): | ||
| start = i*spatial_step | ||
| if H > W: | ||
| view = buffer[:, start:start+side_len, :, :] | ||
| else: | ||
| view = buffer[:, :, start:start+side_len, :] | ||
| view = self.to_tensor(view) | ||
| all_views.append(view) | ||
|
|
||
| return all_views | ||
|
|
||
|
|
||
| def tensor_normalize(tensor, mean, std): | ||
| """ | ||
| Normalize a given tensor by subtracting the mean and dividing the std. | ||
| Args: | ||
| tensor (tensor): tensor to normalize. | ||
| mean (tensor or list): mean value to subtract. | ||
| std (tensor or list): std to divide. | ||
| """ | ||
| if tensor.dtype == torch.uint8: | ||
| tensor = tensor.float() | ||
| tensor = tensor / 255.0 | ||
| if type(mean) == list: | ||
| mean = torch.tensor(mean) | ||
| if type(std) == list: | ||
| std = torch.tensor(std) | ||
| tensor = tensor - mean | ||
| tensor = tensor / std | ||
| return tensor |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| torch>=2 | ||
| torchvision | ||
| pyyaml | ||
| numpy | ||
| opencv-python | ||
| submitit | ||
| braceexpand | ||
| webdataset | ||
| timm | ||
| decord | ||
| pandas | ||
| einops | ||
| beartype |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,25 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import os | ||
| from setuptools import setup | ||
|
|
||
| VERSION = "0.0.1" | ||
|
|
||
| def get_requirements(): | ||
| with open("./requirements.txt") as reqsf: | ||
| reqs = reqsf.readlines() | ||
| return reqs | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| setup( | ||
| name="jepa", | ||
| version=VERSION, | ||
| description="JEPA research code.", | ||
| python_requires=">=3.9", | ||
| install_requires=get_requirements(), | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,90 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| from logging import getLogger | ||
|
|
||
|
|
||
| _GLOBAL_SEED = 0 | ||
| logger = getLogger() | ||
|
|
||
|
|
||
| def init_data( | ||
| batch_size, | ||
| transform=None, | ||
| shared_transform=None, | ||
| data='ImageNet', | ||
| collator=None, | ||
| pin_mem=True, | ||
| num_workers=8, | ||
| world_size=1, | ||
| rank=0, | ||
| root_path=None, | ||
| image_folder=None, | ||
| training=True, | ||
| copy_data=False, | ||
| drop_last=True, | ||
| tokenize_txt=True, | ||
| subset_file=None, | ||
| clip_len=8, | ||
| frame_sample_rate=2, | ||
| duration=None, | ||
| num_clips=1, | ||
| random_clip_sampling=True, | ||
| allow_clip_overlap=False, | ||
| filter_short_videos=False, | ||
| filter_long_videos=int(1e9), | ||
| decode_one_clip=True, | ||
| datasets_weights=None, | ||
| persistent_workers=False, | ||
| repeat_wds=False, | ||
| ipe=300, | ||
| log_dir=None, | ||
| ): | ||
|
|
||
| if (data.lower() == 'imagenet') \ | ||
| or (data.lower() == 'inat21') \ | ||
| or (data.lower() == 'places205'): | ||
| from src.datasets.image_dataset import make_imagedataset | ||
| dataset, data_loader, dist_sampler = make_imagedataset( | ||
| transform=transform, | ||
| batch_size=batch_size, | ||
| collator=collator, | ||
| pin_mem=pin_mem, | ||
| training=training, | ||
| num_workers=num_workers, | ||
| world_size=world_size, | ||
| rank=rank, | ||
| root_path=root_path, | ||
| image_folder=image_folder, | ||
| persistent_workers=persistent_workers, | ||
| copy_data=copy_data, | ||
| drop_last=drop_last, | ||
| subset_file=subset_file) | ||
|
|
||
| elif data.lower() == 'videodataset': | ||
| from src.datasets.video_dataset import make_videodataset | ||
| dataset, data_loader, dist_sampler = make_videodataset( | ||
| data_paths=root_path, | ||
| batch_size=batch_size, | ||
| frames_per_clip=clip_len, | ||
| frame_step=frame_sample_rate, | ||
| duration=duration, | ||
| num_clips=num_clips, | ||
| random_clip_sampling=random_clip_sampling, | ||
| allow_clip_overlap=allow_clip_overlap, | ||
| filter_short_videos=filter_short_videos, | ||
| filter_long_videos=filter_long_videos, | ||
| shared_transform=shared_transform, | ||
| transform=transform, | ||
| datasets_weights=datasets_weights, | ||
| collator=collator, | ||
| num_workers=num_workers, | ||
| world_size=world_size, | ||
| rank=rank, | ||
| log_dir=log_dir) | ||
|
|
||
| return (data_loader, dist_sampler) |
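| # Example (illustrative values; root_path is forwarded to make_videodataset as data_paths): | ||
| #   loader, sampler = init_data( | ||
| #       batch_size=8, data='VideoDataset', root_path=['/path/to/train_paths.csv'], | ||
| #       clip_len=16, num_clips=2, transform=make_transforms(training=True), | ||
| #       world_size=world_size, rank=rank) | ||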
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,79 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import os | ||
|
|
||
| from logging import getLogger | ||
|
|
||
| import torch | ||
| import torchvision | ||
|
|
||
| _GLOBAL_SEED = 0 | ||
| logger = getLogger() | ||
|
|
||
|
|
||
| class ImageFolder(torchvision.datasets.ImageFolder): | ||
|
|
||
| def __init__( | ||
| self, | ||
| root, | ||
| image_folder='imagenet_full_size/061417/', | ||
| transform=None, | ||
| train=True, | ||
| ): | ||
| """ | ||
| ImageFolder | ||
| :param root: root network directory for ImageFolder data | ||
| :param image_folder: path to images inside root network directory | ||
| :param train: whether to load train data (or validation) | ||
| """ | ||
|
|
||
| suffix = 'train/' if train else 'val/' | ||
| data_path = os.path.join(root, image_folder, suffix) | ||
| logger.info(f'data-path {data_path}') | ||
| super(ImageFolder, self).__init__(root=data_path, transform=transform) | ||
| logger.info('Initialized ImageFolder') | ||
|
|
||
|
|
||
| def make_imagedataset( | ||
| transform, | ||
| batch_size, | ||
| collator=None, | ||
| pin_mem=True, | ||
| num_workers=8, | ||
| world_size=1, | ||
| rank=0, | ||
| root_path=None, | ||
| image_folder=None, | ||
| training=True, | ||
| copy_data=False, | ||
| drop_last=True, | ||
| persistent_workers=False, | ||
| subset_file=None | ||
| ): | ||
| dataset = ImageFolder( | ||
| root=root_path, | ||
| image_folder=image_folder, | ||
| transform=transform, | ||
| train=training) | ||
| logger.info('ImageFolder dataset created') | ||
| dist_sampler = torch.utils.data.distributed.DistributedSampler( | ||
| dataset=dataset, | ||
| num_replicas=world_size, | ||
| rank=rank) | ||
| data_loader = torch.utils.data.DataLoader( | ||
| dataset, | ||
| collate_fn=collator, | ||
| sampler=dist_sampler, | ||
| batch_size=batch_size, | ||
| drop_last=drop_last, | ||
| pin_memory=pin_mem, | ||
| num_workers=num_workers, | ||
| persistent_workers=persistent_workers) | ||
| logger.info('ImageFolder unsupervised data loader created') | ||
|
|
||
| return dataset, data_loader, dist_sampler |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,96 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import numbers | ||
| import cv2 | ||
| import numpy as np | ||
| import PIL | ||
| import torch | ||
|
|
||
|
|
||
| def _is_tensor_clip(clip): | ||
| return torch.is_tensor(clip) and clip.ndimension() == 4 | ||
|
|
||
|
|
||
| def crop_clip(clip, min_h, min_w, h, w): | ||
| if isinstance(clip[0], np.ndarray): | ||
| cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip] | ||
|
|
||
| elif isinstance(clip[0], PIL.Image.Image): | ||
| cropped = [ | ||
| img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip | ||
| ] | ||
| else: | ||
| raise TypeError('Expected numpy.ndarray or PIL.Image' + | ||
| ' but got list of {0}'.format(type(clip[0]))) | ||
| return cropped | ||
|
|
||
|
|
||
| def resize_clip(clip, size, interpolation='bilinear'): | ||
| if isinstance(clip[0], np.ndarray): | ||
| if isinstance(size, numbers.Number): | ||
| im_h, im_w, im_c = clip[0].shape | ||
| # Min spatial dim already matches minimal size | ||
| if (im_w <= im_h and im_w == size) or (im_h <= im_w | ||
| and im_h == size): | ||
| return clip | ||
| new_h, new_w = get_resize_sizes(im_h, im_w, size) | ||
| size = (new_w, new_h) | ||
| else: | ||
| size = size[0], size[1] | ||
| if interpolation == 'bilinear': | ||
| np_inter = cv2.INTER_LINEAR | ||
| else: | ||
| np_inter = cv2.INTER_NEAREST | ||
| scaled = [ | ||
| cv2.resize(img, size, interpolation=np_inter) for img in clip | ||
| ] | ||
| elif isinstance(clip[0], PIL.Image.Image): | ||
| if isinstance(size, numbers.Number): | ||
| im_w, im_h = clip[0].size | ||
| # Min spatial dim already matches minimal size | ||
| if (im_w <= im_h and im_w == size) or (im_h <= im_w | ||
| and im_h == size): | ||
| return clip | ||
| new_h, new_w = get_resize_sizes(im_h, im_w, size) | ||
| size = (new_w, new_h) | ||
| else: | ||
| size = size[1], size[0] | ||
| if interpolation == 'bilinear': | ||
| pil_inter = PIL.Image.BILINEAR | ||
| else: | ||
| pil_inter = PIL.Image.NEAREST | ||
| scaled = [img.resize(size, pil_inter) for img in clip] | ||
| else: | ||
| raise TypeError('Expected numpy.ndarray or PIL.Image' + | ||
| ' but got list of {0}'.format(type(clip[0]))) | ||
| return scaled | ||
|
|
||
|
|
||
| def get_resize_sizes(im_h, im_w, size): | ||
| if im_w < im_h: | ||
| ow = size | ||
| oh = int(size * im_h / im_w) | ||
| else: | ||
| oh = size | ||
| ow = int(size * im_w / im_h) | ||
| return oh, ow | ||
|
|
||
|
|
||
| def normalize(clip, mean, std, inplace=False): | ||
| if not _is_tensor_clip(clip): | ||
| raise TypeError('tensor is not a torch clip.') | ||
|
|
||
| if not inplace: | ||
| clip = clip.clone() | ||
|
|
||
| dtype = clip.dtype | ||
| mean = torch.as_tensor(mean, dtype=dtype, device=clip.device) | ||
| std = torch.as_tensor(std, dtype=dtype, device=clip.device) | ||
| clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None]) | ||
|
|
||
| return clip |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,180 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| """ | ||
| This implementation is based on | ||
| https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py | ||
| published under an Apache License 2.0. | ||
| """ | ||
| import math | ||
| import random | ||
| import torch | ||
|
|
||
|
|
||
| def _get_pixels( | ||
| per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda" | ||
| ): | ||
| # NOTE I've seen CUDA illegal memory access errors being caused by the normal_() | ||
| # paths, flip the order so normal is run on CPU if this becomes a problem | ||
| # Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508 | ||
| if per_pixel: | ||
| return torch.empty(patch_size, dtype=dtype, device=device).normal_() | ||
| elif rand_color: | ||
| return torch.empty( | ||
| (patch_size[0], 1, 1), dtype=dtype, device=device | ||
| ).normal_() | ||
| else: | ||
| return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device) | ||
|
|
||
|
|
||
| class RandomErasing: | ||
| """Randomly selects a rectangle region in an image and erases its pixels. | ||
| 'Random Erasing Data Augmentation' by Zhong et al. | ||
| See https://arxiv.org/pdf/1708.04896.pdf | ||
| This variant of RandomErasing is intended to be applied to either a batch | ||
| or single image tensor after it has been normalized by dataset mean and std. | ||
| Args: | ||
| probability: Probability that the Random Erasing operation will be performed. | ||
| min_area: Minimum percentage of erased area wrt input image area. | ||
| max_area: Maximum percentage of erased area wrt input image area. | ||
| min_aspect: Minimum aspect ratio of erased area. | ||
| mode: pixel color mode, one of 'const', 'rand', or 'pixel' | ||
| 'const' - erase block is constant color of 0 for all channels | ||
| 'rand' - erase block is same per-channel random (normal) color | ||
| 'pixel' - erase block is per-pixel random (normal) color | ||
| max_count: maximum number of erasing blocks per image, area per box is scaled by count. | ||
| per-image count is randomly chosen between 1 and this value. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| probability=0.5, | ||
| min_area=0.02, | ||
| max_area=1 / 3, | ||
| min_aspect=0.3, | ||
| max_aspect=None, | ||
| mode="const", | ||
| min_count=1, | ||
| max_count=None, | ||
| num_splits=0, | ||
| device="cuda", | ||
| cube=True, | ||
| ): | ||
| self.probability = probability | ||
| self.min_area = min_area | ||
| self.max_area = max_area | ||
| max_aspect = max_aspect or 1 / min_aspect | ||
| self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) | ||
| self.min_count = min_count | ||
| self.max_count = max_count or min_count | ||
| self.num_splits = num_splits | ||
| mode = mode.lower() | ||
| self.rand_color = False | ||
| self.per_pixel = False | ||
| self.cube = cube | ||
| if mode == "rand": | ||
| self.rand_color = True # per block random normal | ||
| elif mode == "pixel": | ||
| self.per_pixel = True # per pixel random normal | ||
| else: | ||
| assert not mode or mode == "const" | ||
| self.device = device | ||
|
|
||
| def _erase(self, img, chan, img_h, img_w, dtype): | ||
| if random.random() > self.probability: | ||
| return | ||
| area = img_h * img_w | ||
| count = ( | ||
| self.min_count | ||
| if self.min_count == self.max_count | ||
| else random.randint(self.min_count, self.max_count) | ||
| ) | ||
| for _ in range(count): | ||
| for _ in range(10): | ||
| target_area = ( | ||
| random.uniform(self.min_area, self.max_area) * area / count | ||
| ) | ||
| aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) | ||
| h = int(round(math.sqrt(target_area * aspect_ratio))) | ||
| w = int(round(math.sqrt(target_area / aspect_ratio))) | ||
| if w < img_w and h < img_h: | ||
| top = random.randint(0, img_h - h) | ||
| left = random.randint(0, img_w - w) | ||
| img[:, top:top + h, left:left + w] = _get_pixels( | ||
| self.per_pixel, | ||
| self.rand_color, | ||
| (chan, h, w), | ||
| dtype=dtype, | ||
| device=self.device, | ||
| ) | ||
| break | ||
|
|
||
| def _erase_cube( | ||
| self, | ||
| img, | ||
| batch_start, | ||
| batch_size, | ||
| chan, | ||
| img_h, | ||
| img_w, | ||
| dtype, | ||
| ): | ||
| if random.random() > self.probability: | ||
| return | ||
| area = img_h * img_w | ||
| count = ( | ||
| self.min_count | ||
| if self.min_count == self.max_count | ||
| else random.randint(self.min_count, self.max_count) | ||
| ) | ||
| for _ in range(count): | ||
| for _ in range(100): | ||
| target_area = ( | ||
| random.uniform(self.min_area, self.max_area) * area / count | ||
| ) | ||
| aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) | ||
| h = int(round(math.sqrt(target_area * aspect_ratio))) | ||
| w = int(round(math.sqrt(target_area / aspect_ratio))) | ||
| if w < img_w and h < img_h: | ||
| top = random.randint(0, img_h - h) | ||
| left = random.randint(0, img_w - w) | ||
| for i in range(batch_start, batch_size): | ||
| img_instance = img[i] | ||
| img_instance[ | ||
| :, top:top + h, left:left + w | ||
| ] = _get_pixels( | ||
| self.per_pixel, | ||
| self.rand_color, | ||
| (chan, h, w), | ||
| dtype=dtype, | ||
| device=self.device, | ||
| ) | ||
| break | ||
|
|
||
| def __call__(self, input): | ||
| if len(input.size()) == 3: | ||
| self._erase(input, *input.size(), input.dtype) | ||
| else: | ||
| batch_size, chan, img_h, img_w = input.size() | ||
| # skip first slice of batch if num_splits is set (for clean portion of samples) | ||
| batch_start = ( | ||
| batch_size // self.num_splits if self.num_splits > 1 else 0 | ||
| ) | ||
| if self.cube: | ||
| self._erase_cube( | ||
| input, | ||
| batch_start, | ||
| batch_size, | ||
| chan, | ||
| img_h, | ||
| img_w, | ||
| input.dtype, | ||
| ) | ||
| else: | ||
| for i in range(batch_start, batch_size): | ||
| self._erase(input[i], chan, img_h, img_w, input.dtype) | ||
| return input |
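| # Example (mirrors how VideoTransform uses this class): per-frame erasing on a | ||
| # clip tensor stored as [C, T, H, W]. | ||
| #   erase = RandomErasing(probability=0.25, mode='pixel', max_count=1, device='cpu') | ||
| #   clip = clip.permute(1, 0, 2, 3)  # [T, C, H, W]: frames act as the batch dim | ||
| #   clip = erase(clip) | ||
| #   clip = clip.permute(1, 0, 2, 3)  # back to [C, T, H, W] | ||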
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| import numpy as np | ||
| from PIL import Image | ||
|
|
||
| import torch | ||
|
|
||
|
|
||
| def convert_img(img): | ||
| """Converts (H, W, C) numpy.ndarray to (C, W, H) format""" | ||
| if len(img.shape) == 3: | ||
| img = img.transpose(2, 0, 1) | ||
| if len(img.shape) == 2: | ||
| img = np.expand_dims(img, 0) | ||
| return img | ||
|
|
||
|
|
||
| class ClipToTensor(object): | ||
| """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] | ||
| to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0] | ||
| """ | ||
|
|
||
| def __init__(self, channel_nb=3, div_255=True, numpy=False): | ||
| self.channel_nb = channel_nb | ||
| self.div_255 = div_255 | ||
| self.numpy = numpy | ||
|
|
||
| def __call__(self, clip): | ||
| """ | ||
| Args: clip (list of numpy.ndarray): clip (list of images) | ||
| to be converted to tensor. | ||
| """ | ||
| # Retrieve shape | ||
| if isinstance(clip[0], np.ndarray): | ||
| h, w, ch = clip[0].shape | ||
| assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) | ||
| elif isinstance(clip[0], Image.Image): | ||
| w, h = clip[0].size | ||
| else: | ||
| raise TypeError( | ||
| "Expected numpy.ndarray or PIL.Image\ | ||
| but got list of {0}".format( | ||
| type(clip[0]) | ||
| ) | ||
| ) | ||
|
|
||
| np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) | ||
|
|
||
| # Convert | ||
| for img_idx, img in enumerate(clip): | ||
| if isinstance(img, np.ndarray): | ||
| pass | ||
| elif isinstance(img, Image.Image): | ||
| img = np.array(img, copy=False) | ||
| else: | ||
| raise TypeError( | ||
| "Expected numpy.ndarray or PIL.Image\ | ||
| but got list of {0}".format( | ||
| type(clip[0]) | ||
| ) | ||
| ) | ||
| img = convert_img(img) | ||
| np_clip[:, img_idx, :, :] = img | ||
| if self.numpy: | ||
| if self.div_255: | ||
| np_clip = np_clip / 255.0 | ||
| return np_clip | ||
|
|
||
| else: | ||
| tensor_clip = torch.from_numpy(np_clip) | ||
|
|
||
| if not isinstance(tensor_clip, torch.FloatTensor): | ||
| tensor_clip = tensor_clip.float() | ||
| if self.div_255: | ||
| tensor_clip = torch.div(tensor_clip, 255) | ||
| return tensor_clip | ||
|
|
||
|
|
||
| # Note this norms data to -1/1 | ||
| class ClipToTensor_K(object): | ||
| """Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255] | ||
| to a torch.FloatTensor of shape (C x m x H x W) in the range [-1.0, 1.0] | ||
| """ | ||
|
|
||
| def __init__(self, channel_nb=3, div_255=True, numpy=False): | ||
| self.channel_nb = channel_nb | ||
| self.div_255 = div_255 | ||
| self.numpy = numpy | ||
|
|
||
| def __call__(self, clip): | ||
| """ | ||
| Args: clip (list of numpy.ndarray): clip (list of images) | ||
| to be converted to tensor. | ||
| """ | ||
| # Retrieve shape | ||
| if isinstance(clip[0], np.ndarray): | ||
| h, w, ch = clip[0].shape | ||
| assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch) | ||
| elif isinstance(clip[0], Image.Image): | ||
| w, h = clip[0].size | ||
| else: | ||
| raise TypeError( | ||
| "Expected numpy.ndarray or PIL.Image\ | ||
| but got list of {0}".format( | ||
| type(clip[0]) | ||
| ) | ||
| ) | ||
|
|
||
| np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)]) | ||
|
|
||
| # Convert | ||
| for img_idx, img in enumerate(clip): | ||
| if isinstance(img, np.ndarray): | ||
| pass | ||
| elif isinstance(img, Image.Image): | ||
| img = np.array(img, copy=False) | ||
| else: | ||
| raise TypeError( | ||
| "Expected numpy.ndarray or PIL.Image\ | ||
| but got list of {0}".format( | ||
| type(clip[0]) | ||
| ) | ||
| ) | ||
| img = convert_img(img) | ||
| np_clip[:, img_idx, :, :] = img | ||
| if self.numpy: | ||
| if self.div_255: | ||
| np_clip = (np_clip - 127.5) / 127.5 | ||
| return np_clip | ||
|
|
||
| else: | ||
| tensor_clip = torch.from_numpy(np_clip) | ||
|
|
||
| if not isinstance(tensor_clip, torch.FloatTensor): | ||
| tensor_clip = tensor_clip.float() | ||
| if self.div_255: | ||
| tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5) | ||
| return tensor_clip | ||
|
|
||
|
|
||
| class ToTensor(object): | ||
| """Converts numpy array to tensor""" | ||
|
|
||
| def __call__(self, array): | ||
| tensor = torch.from_numpy(array) | ||
| return tensor |
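| # Example: 'frames' is a list of T uint8 numpy arrays of shape (H, W, 3). | ||
| #   clip = ClipToTensor()(frames)  # torch.FloatTensor of shape [3, T, H, W] in [0, 1] | ||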
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,97 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| # All rights reserved. | ||
| # | ||
| # This source code is licensed under the license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
| # | ||
|
|
||
| from typing import Iterator, Optional | ||
| from operator import itemgetter | ||
| import numpy as np | ||
|
|
||
| import torch | ||
| from torch.utils.data import ( | ||
| Dataset, | ||
| Sampler, | ||
| DistributedSampler, | ||
| WeightedRandomSampler | ||
| ) | ||
|
|
||
|
|
||
| class DatasetFromSampler(Dataset): | ||
|
|
||
| def __init__(self, sampler: Sampler): | ||
| self.sampler = sampler | ||
| self.sampler_list = None | ||
|
|
||
| def __getitem__(self, index: int): | ||
| if self.sampler_list is None: | ||
| self.sampler_list = list(self.sampler) | ||
| return self.sampler_list[index] | ||
|
|
||
| def __len__(self) -> int: | ||
| return len(self.sampler) | ||
|
|
||
|
|
||
| class DistributedSamplerWrapper(DistributedSampler): | ||
| """ Convert any Pytorch Sampler to a DistributedSampler """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| sampler, | ||
| num_replicas: Optional[int] = None, | ||
| rank: Optional[int] = None, | ||
| shuffle: bool = True, | ||
| ): | ||
| super(DistributedSamplerWrapper, self).__init__( | ||
| DatasetFromSampler(sampler), | ||
| num_replicas=num_replicas, | ||
| rank=rank, | ||
| shuffle=shuffle, | ||
| ) | ||
| self.sampler = sampler | ||
|
|
||
| def __iter__(self) -> Iterator[int]: | ||
| self.dataset = DatasetFromSampler(self.sampler) | ||
| indexes_of_indexes = super().__iter__() | ||
| subsampler_indexes = self.dataset | ||
| return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) | ||
|
|
||
|
|
||
| class CustomWeightedRandomSampler(WeightedRandomSampler): | ||
| """ Generalized WeightedRandomSampler to allow for more than 2^24 samples """ | ||
|
|
||
| def __init__(self, *args, **kwargs): | ||
| super().__init__(*args, **kwargs) | ||
|
|
||
| def __iter__(self): | ||
| rand_tensor = np.random.choice( | ||
| range(0, len(self.weights)), | ||
| size=self.num_samples, | ||
| p=self.weights.numpy() / torch.sum(self.weights).numpy(), | ||
| replace=self.replacement | ||
| ) | ||
| rand_tensor = torch.from_numpy(rand_tensor) | ||
| return iter(rand_tensor.tolist()) | ||
|
|
||
|
|
||
| class DistributedWeightedSampler(DistributedSamplerWrapper): | ||
|
|
||
| def __init__( | ||
| self, | ||
| weights, | ||
| num_replicas: Optional[int] = None, | ||
| rank: Optional[int] = None, | ||
| shuffle: bool = True, | ||
| ): | ||
| weighted_sampler = CustomWeightedRandomSampler( | ||
| weights=weights, | ||
| num_samples=len(weights), | ||
| replacement=False) | ||
|
|
||
| super(DistributedWeightedSampler, self).__init__( | ||
| sampler=weighted_sampler, | ||
| num_replicas=num_replicas, | ||
| rank=rank, | ||
| shuffle=shuffle, | ||
| ) |
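| # Example (illustrative): weighted sampling split across ranks. | ||
| #   weights = torch.ones(len(dataset))  # per-sample weights, e.g. dataset re-balancing | ||
| #   sampler = DistributedWeightedSampler(weights, num_replicas=world_size, rank=rank) | ||
| #   loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=sampler) | ||
| #   sampler.set_epoch(epoch)  # reshuffle each epoch | ||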