399 changes: 399 additions & 0 deletions LICENSE

Large diffs are not rendered by default.

400 changes: 400 additions & 0 deletions README.md

Large diffs are not rendered by default.

71 changes: 71 additions & 0 deletions app/main.py
@@ -0,0 +1,71 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse

import multiprocessing as mp

import pprint
import yaml

from app.scaffold import main as app_main
from src.utils.distributed import init_distributed

parser = argparse.ArgumentParser()
parser.add_argument(
'--fname', type=str,
help='name of config file to load',
default='configs.yaml')
parser.add_argument(
'--devices', type=str, nargs='+', default=['cuda:0'],
help='which devices to use on local machine')


def process_main(rank, fname, world_size, devices):
import os
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1])

import logging
from src.utils.logging import get_logger
logger = get_logger(force=True)
if rank == 0:
logger.setLevel(logging.INFO)
else:
logger.setLevel(logging.ERROR)

logger.info(f'called-params {fname}')

# Load config
params = None
with open(fname, 'r') as y_file:
params = yaml.load(y_file, Loader=yaml.FullLoader)
logger.info('loaded params...')

# Log config
if rank == 0:
pprint.PrettyPrinter(indent=4).pprint(params)
dump = os.path.join(params['logging']['folder'], 'params-pretrain.yaml')
with open(dump, 'w') as f:
yaml.dump(params, f)

    # Init distributed (access to communication between GPUs on the same machine)
world_size, rank = init_distributed(rank_and_world_size=(rank, world_size))
logger.info(f'Running... (rank: {rank}/{world_size})')

# Launch the app with loaded config
app_main(params['app'], args=params)


if __name__ == '__main__':
args = parser.parse_args()
num_gpus = len(args.devices)
mp.set_start_method('spawn')
for rank in range(num_gpus):
mp.Process(
target=process_main,
args=(rank, args.fname, num_gpus, args.devices)
).start()
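
For a quick single-GPU debug run, the launcher above can also be driven directly from Python instead of spawning one process per device. A minimal sketch, assuming one of the pretraining configs added in this PR (the config path is illustrative):

# Single-process debug launch (sketch); app/main.py remains the normal entry point.
from app.main import process_main

process_main(
    rank=0,
    fname='configs/pretrain/vitl16.yaml',
    world_size=1,
    devices=['cuda:0'])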
152 changes: 152 additions & 0 deletions app/main_distributed.py
@@ -0,0 +1,152 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
import os
import pprint
import yaml

import submitit

from app.scaffold import main as app_main
from src.utils.logging import get_logger

logger = get_logger(force=True)


parser = argparse.ArgumentParser()
parser.add_argument(
'--folder', type=str,
help='location to save submitit logs',
default='/fsx-jepa/massran/submitit/')
parser.add_argument(
'--exclude', type=str,
help='nodes to exclude from training',
default=None)
parser.add_argument(
'--batch-launch', action='store_true',
    help='whether fname points to a file to batch-launch several config files')
parser.add_argument(
'--fname', type=str,
help='yaml file containing config file names to launch',
default='configs.yaml')
parser.add_argument(
'--partition', type=str,
help='cluster partition to submit jobs on')
parser.add_argument(
'--time', type=int, default=4300,
help='time in minutes to run job')


class Trainer:

def __init__(self, args_pretrain, load_model=None):
self.app = args_pretrain['app']
self.args_pretrain = args_pretrain
self.load_model = load_model

def __call__(self):
app = self.app
params = self.args_pretrain
load_model = self.load_model

logger.info('loaded pretrain params...')
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(params)

# Launch app with loaded config
resume_preempt = False if load_model is None else load_model
app_main(app, args=params, resume_preempt=resume_preempt)

def checkpoint(self):
fb_trainer = Trainer(self.args_pretrain, True)
return submitit.helpers.DelayedSubmission(fb_trainer,)


def launch_app_with_parsed_args(
args_for_pretrain,
submitit_folder,
partition,
timeout=4300,
nodes=1,
tasks_per_node=1,
exclude_nodes=None
):
executor = submitit.AutoExecutor(
folder=os.path.join(submitit_folder, 'job_%j'),
slurm_max_num_timeout=20)
executor.update_parameters(
slurm_partition=partition,
slurm_mem_per_gpu='55G',
timeout_min=timeout,
nodes=nodes,
tasks_per_node=tasks_per_node,
cpus_per_task=12,
gpus_per_node=tasks_per_node)

    if exclude_nodes is not None:
        executor.update_parameters(slurm_exclude=exclude_nodes)

jobs, trainers = [], []
with executor.batch():
for ap in args_for_pretrain:
fb_trainer = Trainer(ap)
job = executor.submit(fb_trainer,)
trainers.append(fb_trainer)
jobs.append(job)

for job in jobs:
print(job.job_id)


def launch():

# ---------------------------------------------------------------------- #
# 1. Put config file names in a list
# ---------------------------------------------------------------------- #
config_fnames = [args.fname]

# -- If batch-launch is True, then the args.fname yaml file is not a
# -- config, but actually specifies a list of other config files
# -- to run in a slurm job array
if args.batch_launch:
with open(args.fname, 'r') as y_file:
config_fnames = yaml.load(y_file, Loader=yaml.FullLoader)
# ---------------------------------------------------------------------- #

# ---------------------------------------------------------------------- #
# 2. Parse each yaml config file as a dict and place in list
# ---------------------------------------------------------------------- #
nodes, tasks_per_node = None, None
configs = []
for f in config_fnames:
with open(f, 'r') as y_file:
_params = yaml.load(y_file, Loader=yaml.FullLoader)
nodes = int(_params.get('nodes'))
tasks_per_node = int(_params.get('tasks_per_node'))
configs += [_params]
logger.info(f'Loaded {len(configs)} config files')
logger.info(f'Running all jobs with {nodes=} / {tasks_per_node=}')
# ---------------------------------------------------------------------- #

# ---------------------------------------------------------------------- #
# 3. Launch evals with parsed config files
# ---------------------------------------------------------------------- #
launch_app_with_parsed_args(
args_for_pretrain=configs,
submitit_folder=args.folder,
partition=args.partition,
timeout=args.time,
nodes=nodes,
tasks_per_node=tasks_per_node,
exclude_nodes=args.exclude)
# ---------------------------------------------------------------------- #


if __name__ == '__main__':
args = parser.parse_args()
launch()
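
When `--batch-launch` is passed, the file given by `--fname` is read as a plain YAML list of config paths rather than as a single config. A hedged sketch of building such an index file (the file name and the listed configs are illustrative):

import yaml

# Hypothetical batch-launch index: a YAML list of pretraining config paths.
job_index = [
    'configs/pretrain/vitl16.yaml',
    'configs/pretrain/vith16.yaml',
]
with open('pretrain_jobs.yaml', 'w') as f:
    yaml.dump(job_index, f)
# Launching with `--fname pretrain_jobs.yaml --batch-launch` then submits one job per listed config.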
21 changes: 21 additions & 0 deletions app/scaffold.py
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import importlib
import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()


def main(app, args, resume_preempt=False):

logger.info(f'Running pre-training of app: {app}')
return importlib.import_module(f'app.{app}.train').main(
args=args,
resume_preempt=resume_preempt)
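
The scaffold resolves `app.<app>.train` dynamically, so a new app only needs to provide that module with a matching `main` signature. A hypothetical minimal example (the `my_app` name is made up), saved as `app/my_app/train.py` and selected by `app: my_app` in a config:

# app/my_app/train.py -- hypothetical plugin satisfying the scaffold contract
def main(args, resume_preempt=False):
    # `args` is the full parsed YAML config dict forwarded by app/main.py
    print('config sections:', list(args.keys()), 'resume_preempt:', resume_preempt)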
586 changes: 586 additions & 0 deletions app/vjepa/train.py

Large diffs are not rendered by default.

153 changes: 153 additions & 0 deletions app/vjepa/transforms.py
@@ -0,0 +1,153 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import torchvision.transforms as transforms

import src.datasets.utils.video.transforms as video_transforms
from src.datasets.utils.video.randerase import RandomErasing


def make_transforms(
random_horizontal_flip=True,
random_resize_aspect_ratio=(3/4, 4/3),
random_resize_scale=(0.3, 1.0),
reprob=0.0,
auto_augment=False,
motion_shift=False,
crop_size=224,
normalize=((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))
):

_frames_augmentation = VideoTransform(
random_horizontal_flip=random_horizontal_flip,
random_resize_aspect_ratio=random_resize_aspect_ratio,
random_resize_scale=random_resize_scale,
reprob=reprob,
auto_augment=auto_augment,
motion_shift=motion_shift,
crop_size=crop_size,
normalize=normalize,
)
return _frames_augmentation


class VideoTransform(object):

def __init__(
self,
random_horizontal_flip=True,
random_resize_aspect_ratio=(3/4, 4/3),
random_resize_scale=(0.3, 1.0),
reprob=0.0,
auto_augment=False,
motion_shift=False,
crop_size=224,
normalize=((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))
):

self.random_horizontal_flip = random_horizontal_flip
self.random_resize_aspect_ratio = random_resize_aspect_ratio
self.random_resize_scale = random_resize_scale
self.auto_augment = auto_augment
self.motion_shift = motion_shift
self.crop_size = crop_size
self.mean = torch.tensor(normalize[0], dtype=torch.float32)
self.std = torch.tensor(normalize[1], dtype=torch.float32)
if not self.auto_augment:
            # Without auto-augment the buffer keeps raw [0, 255] values (ToTensor's
            # 1/255 rescaling is not applied), so scale the normalization stats to match.
self.mean *= 255.
self.std *= 255.

self.autoaug_transform = video_transforms.create_random_augment(
input_size=(crop_size, crop_size),
auto_augment='rand-m7-n4-mstd0.5-inc1',
interpolation='bicubic',
)

self.spatial_transform = video_transforms.random_resized_crop_with_shift \
if motion_shift else video_transforms.random_resized_crop

self.reprob = reprob
self.erase_transform = RandomErasing(
reprob,
mode='pixel',
max_count=1,
num_splits=1,
device='cpu',
)

def __call__(self, buffer):

if self.auto_augment:
buffer = [transforms.ToPILImage()(frame) for frame in buffer]
buffer = self.autoaug_transform(buffer)
buffer = [transforms.ToTensor()(img) for img in buffer]
buffer = torch.stack(buffer) # T C H W
buffer = buffer.permute(0, 2, 3, 1) # T H W C
else:
buffer = torch.tensor(buffer, dtype=torch.float32)

buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W

buffer = self.spatial_transform(
images=buffer,
target_height=self.crop_size,
target_width=self.crop_size,
scale=self.random_resize_scale,
ratio=self.random_resize_aspect_ratio,
)
if self.random_horizontal_flip:
buffer, _ = video_transforms.horizontal_flip(0.5, buffer)

buffer = _tensor_normalize_inplace(buffer, self.mean, self.std)
if self.reprob > 0:
buffer = buffer.permute(1, 0, 2, 3)
buffer = self.erase_transform(buffer)
buffer = buffer.permute(1, 0, 2, 3)

return buffer


def tensor_normalize(tensor, mean, std):
"""
Normalize a given tensor by subtracting the mean and dividing the std.
Args:
tensor (tensor): tensor to normalize.
mean (tensor or list): mean value to subtract.
std (tensor or list): std to divide.
"""
if tensor.dtype == torch.uint8:
tensor = tensor.float()
tensor = tensor / 255.0
if type(mean) == list:
mean = torch.tensor(mean)
if type(std) == list:
std = torch.tensor(std)
tensor = tensor - mean
tensor = tensor / std
return tensor


def _tensor_normalize_inplace(tensor, mean, std):
"""
Normalize a given tensor by subtracting the mean and dividing the std.
Args:
tensor (tensor): tensor to normalize (with dimensions C, T, H, W).
mean (tensor): mean value to subtract (in 0 to 255 floats).
std (tensor): std to divide (in 0 to 255 floats).
"""
if tensor.dtype == torch.uint8:
tensor = tensor.float()

C, T, H, W = tensor.shape
tensor = tensor.view(C, -1).permute(1, 0) # Make C the last dimension
tensor.sub_(mean).div_(std)
tensor = tensor.permute(1, 0).view(C, T, H, W) # Put C back in front
return tensor
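
A short usage sketch of the transform above, assuming a raw uint8 clip in (T, H, W, C) layout as produced by the video datasets; the expected output layout follows the permutes in `__call__`:

import numpy as np
from app.vjepa.transforms import make_transforms

transform = make_transforms(crop_size=224, auto_augment=False, reprob=0.0)
clip = np.random.randint(0, 256, size=(16, 256, 320, 3), dtype=np.uint8)  # (T, H, W, C)
out = transform(clip)
print(out.shape)  # expected: torch.Size([3, 16, 224, 224]), i.e. (C, T, H, W)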
210 changes: 210 additions & 0 deletions app/vjepa/utils.py
@@ -0,0 +1,210 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import logging
import sys
import warnings
import yaml


import torch

import src.models.vision_transformer as video_vit
import src.models.predictor as vit_pred
from src.models.utils.multimask import MultiMaskWrapper, PredictorMultiMaskWrapper
from src.utils.schedulers import (
WarmupCosineSchedule,
CosineWDSchedule)
from src.utils.tensors import trunc_normal_

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()


def load_checkpoint(
r_path,
encoder,
predictor,
target_encoder,
opt,
scaler,
):
try:
checkpoint = torch.load(r_path, map_location=torch.device('cpu'))
except Exception as e:
logger.info(f'Encountered exception when loading checkpoint {e}')

epoch = 0
try:
epoch = checkpoint['epoch']

# -- loading encoder
pretrained_dict = checkpoint['encoder']
msg = encoder.load_state_dict(pretrained_dict)
logger.info(f'loaded pretrained encoder from epoch {epoch} with msg: {msg}')

# -- loading predictor
pretrained_dict = checkpoint['predictor']
msg = predictor.load_state_dict(pretrained_dict)
logger.info(f'loaded pretrained predictor from epoch {epoch} with msg: {msg}')

# -- loading target_encoder
if target_encoder is not None:
print(list(checkpoint.keys()))
pretrained_dict = checkpoint['target_encoder']
msg = target_encoder.load_state_dict(pretrained_dict)
logger.info(
f'loaded pretrained target encoder from epoch {epoch} with msg: {msg}'
)

# -- loading optimizer
opt.load_state_dict(checkpoint['opt'])
if scaler is not None:
scaler.load_state_dict(checkpoint['scaler'])
logger.info(f'loaded optimizers from epoch {epoch}')
logger.info(f'read-path: {r_path}')
del checkpoint

except Exception as e:
logger.info(f'Encountered exception when loading checkpoint {e}')
epoch = 0

return (
encoder,
predictor,
target_encoder,
opt,
scaler,
epoch,
)


def init_video_model(
device,
patch_size=16,
num_frames=16,
tubelet_size=2,
model_name='vit_base',
crop_size=224,
pred_depth=6,
pred_embed_dim=384,
uniform_power=False,
use_mask_tokens=False,
num_mask_tokens=2,
zero_init_mask_tokens=True,
use_sdpa=False,
):
encoder = video_vit.__dict__[model_name](
img_size=crop_size,
patch_size=patch_size,
num_frames=num_frames,
tubelet_size=tubelet_size,
uniform_power=uniform_power,
use_sdpa=use_sdpa,
)
encoder = MultiMaskWrapper(encoder)
predictor = vit_pred.__dict__['vit_predictor'](
img_size=crop_size,
use_mask_tokens=use_mask_tokens,
patch_size=patch_size,
num_frames=num_frames,
tubelet_size=tubelet_size,
embed_dim=encoder.backbone.embed_dim,
predictor_embed_dim=pred_embed_dim,
depth=pred_depth,
num_heads=encoder.backbone.num_heads,
uniform_power=uniform_power,
num_mask_tokens=num_mask_tokens,
zero_init_mask_tokens=zero_init_mask_tokens,
use_sdpa=use_sdpa,
)
predictor = PredictorMultiMaskWrapper(predictor)

def init_weights(m):
if isinstance(m, torch.nn.Linear):
trunc_normal_(m.weight, std=0.02)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
elif isinstance(m, torch.nn.LayerNorm):
torch.nn.init.constant_(m.bias, 0)
torch.nn.init.constant_(m.weight, 1.0)

for m in encoder.modules():
init_weights(m)

for m in predictor.modules():
init_weights(m)

encoder.to(device)
predictor.to(device)
logger.info(encoder)
logger.info(predictor)

def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)

logger.info(f'Encoder number of parameters: {count_parameters(encoder)}')
logger.info(f'Predictor number of parameters: {count_parameters(predictor)}')

return encoder, predictor


def init_opt(
encoder,
predictor,
iterations_per_epoch,
start_lr,
ref_lr,
warmup,
num_epochs,
wd=1e-6,
final_wd=1e-6,
final_lr=0.0,
mixed_precision=False,
ipe_scale=1.25,
betas=(0.9, 0.999),
eps=1e-8,
zero_init_bias_wd=True,
):
param_groups = [
{
'params': (p for n, p in encoder.named_parameters()
if ('bias' not in n) and (len(p.shape) != 1))
}, {
'params': (p for n, p in predictor.named_parameters()
if ('bias' not in n) and (len(p.shape) != 1))
}, {
'params': (p for n, p in encoder.named_parameters()
if ('bias' in n) or (len(p.shape) == 1)),
'WD_exclude': zero_init_bias_wd,
'weight_decay': 0,
}, {
'params': (p for n, p in predictor.named_parameters()
if ('bias' in n) or (len(p.shape) == 1)),
'WD_exclude': zero_init_bias_wd,
'weight_decay': 0,
},
]

logger.info('Using AdamW')
optimizer = torch.optim.AdamW(param_groups, betas=betas, eps=eps)
scheduler = WarmupCosineSchedule(
optimizer,
warmup_steps=int(warmup * iterations_per_epoch),
start_lr=start_lr,
ref_lr=ref_lr,
final_lr=final_lr,
T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
)
wd_scheduler = CosineWDSchedule(
optimizer,
ref_wd=wd,
final_wd=final_wd,
T_max=int(ipe_scale * num_epochs * iterations_per_epoch),
)
scaler = torch.cuda.amp.GradScaler() if mixed_precision else None
return optimizer, scaler, scheduler, wd_scheduler
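
A hedged wiring sketch for the helpers above, with hyper-parameters taken from configs/pretrain/vitl16.yaml. The EMA target-encoder copy mirrors standard JEPA training (app/vjepa/train.py is not rendered in this diff), and the schedulers are assumed to expose a per-iteration `step()`:

import copy
import torch
from app.vjepa.utils import init_video_model, init_opt

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
encoder, predictor = init_video_model(
    device,
    model_name='vit_large',
    crop_size=224,
    num_frames=16,
    tubelet_size=2,
    pred_depth=12,
    pred_embed_dim=384)
target_encoder = copy.deepcopy(encoder)  # EMA target (assumption, as in I-JEPA)

optimizer, scaler, scheduler, wd_scheduler = init_opt(
    encoder, predictor,
    iterations_per_epoch=300,
    start_lr=2.0e-4, ref_lr=6.25e-4, final_lr=1.0e-6,
    warmup=40, num_epochs=300,
    wd=0.04, final_wd=0.4,
    mixed_precision=True)

# each training iteration would advance the LR / weight-decay schedules before optimizer.step()
scheduler.step()
wd_scheduler.step()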
34 changes: 34 additions & 0 deletions configs/evals/vith16_384_in1k.yaml
@@ -0,0 +1,34 @@
nodes: 8
tasks_per_node: 8
tag: in1k-16f
eval_name: image_classification_frozen
resume_checkpoint: false
data:
root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/
image_folder: imagenet_full_size/061417/
num_classes: 1000
resolution: 384
dataset_name: ImageNet
optimization:
num_epochs: 20
batch_size: 16
weight_decay: 0.001
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_huge
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_sdpa: true
use_silu: false
tight_silu: false
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
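
Each of these eval configs is consumed whole by evals/main.py or evals/main_distributed.py; the `eval_name` field selects the eval package to import. A hedged single-process sketch of that dispatch (distributed init and GPU selection, handled by evals/main.py, are omitted here):

import yaml
from evals.scaffold import main as eval_main

with open('configs/evals/vith16_384_in1k.yaml', 'r') as y_file:
    params = yaml.load(y_file, Loader=yaml.FullLoader)

# 'eval_name: image_classification_frozen' resolves to evals/image_classification_frozen/eval.py
eval_main(params['eval_name'], args_eval=params)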
34 changes: 34 additions & 0 deletions configs/evals/vith16_384_inat.yaml
@@ -0,0 +1,34 @@
nodes: 8
tasks_per_node: 8
tag: inat-16f
eval_name: image_classification_frozen
resume_checkpoint: false
data:
root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/
image_folder: iNaturalist-2021/110421/
num_classes: 10000
resolution: 384
dataset_name: iNat21
optimization:
num_epochs: 20
batch_size: 16
weight_decay: 0.001
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_huge
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_sdpa: true
use_silu: false
tight_silu: false
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
39 changes: 39 additions & 0 deletions configs/evals/vith16_384_k400_16x8x3.yaml
@@ -0,0 +1,39 @@
nodes: 8
tasks_per_node: 8
tag: k400-16x8x3
eval_name: video_classification_frozen
resume_checkpoint: false
data:
dataset_train: /your_path_to_kinetics400_train_csv_file_index.csv
dataset_val: /your_path_to_kinetics400_val_csv_file_index.csv
dataset_type: VideoDataset
num_classes: 400
frames_per_clip: 16
num_segments: 8
num_views_per_segment: 3
frame_step: 4
optimization:
attend_across_segments: true
num_epochs: 20
resolution: 384
batch_size: 4
weight_decay: 0.01
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_huge
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_silu: false
tight_silu: false
use_sdpa: true
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
34 changes: 34 additions & 0 deletions configs/evals/vith16_384_places.yaml
@@ -0,0 +1,34 @@
nodes: 8
tasks_per_node: 8
tag: places-16f
eval_name: image_classification_frozen
resume_checkpoint: false
data:
root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/
image_folder: places205/121517/pytorch/
num_classes: 205
resolution: 384
dataset_name: Places205
optimization:
num_epochs: 20
batch_size: 16
weight_decay: 0.001
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_huge
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_sdpa: true
use_silu: false
tight_silu: false
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
39 changes: 39 additions & 0 deletions configs/evals/vith16_384_ssv2_16x2x3.yaml
@@ -0,0 +1,39 @@
nodes: 8
tasks_per_node: 8
tag: ssv2-16x2x3
eval_name: video_classification_frozen
resume_checkpoint: false
data:
dataset_train: /your_path_to_ssv2_train_csv_file_index.csv
dataset_val: /your_path_to_ssv2_val_csv_file_index.csv
dataset_type: VideoDataset
num_classes: 174
frames_per_clip: 16
num_segments: 2
num_views_per_segment: 3
frame_step: 4
optimization:
attend_across_segments: true
num_epochs: 20
resolution: 384
batch_size: 4
weight_decay: 0.01
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_huge
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_silu: false
tight_silu: false
use_sdpa: true
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
34 changes: 34 additions & 0 deletions configs/evals/vith16_in1k.yaml
@@ -0,0 +1,34 @@
nodes: 8
tasks_per_node: 8
tag: in1k-16f
eval_name: image_classification_frozen
resume_checkpoint: false
data:
root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/
image_folder: imagenet_full_size/061417/
num_classes: 1000
resolution: 224
dataset_name: ImageNet
optimization:
num_epochs: 20
batch_size: 16
weight_decay: 0.001
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_huge
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_sdpa: true
use_silu: false
tight_silu: false
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
34 changes: 34 additions & 0 deletions configs/evals/vith16_inat.yaml
@@ -0,0 +1,34 @@
nodes: 8
tasks_per_node: 8
tag: inat-16f
eval_name: image_classification_frozen
resume_checkpoint: false
data:
root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/
image_folder: iNaturalist-2021/110421/
num_classes: 10000
resolution: 224
dataset_name: iNat21
optimization:
num_epochs: 20
batch_size: 16
weight_decay: 0.001
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_huge
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_sdpa: true
use_silu: false
tight_silu: false
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
39 changes: 39 additions & 0 deletions configs/evals/vith16_k400_16x8x3.yaml
@@ -0,0 +1,39 @@
nodes: 8
tasks_per_node: 8
tag: k400-16x8x3
eval_name: video_classification_frozen
resume_checkpoint: false
data:
dataset_train: /your_path_to_kinetics400_train_csv_file_index.csv
dataset_val: /your_path_to_kinetics400_val_csv_file_index.csv
dataset_type: VideoDataset
num_classes: 400
frames_per_clip: 16
num_segments: 8
num_views_per_segment: 3
frame_step: 4
optimization:
attend_across_segments: true
num_epochs: 20
resolution: 224
batch_size: 4
weight_decay: 0.01
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_huge
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_silu: false
tight_silu: false
use_sdpa: true
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
34 changes: 34 additions & 0 deletions configs/evals/vith16_places.yaml
@@ -0,0 +1,34 @@
nodes: 8
tasks_per_node: 8
tag: places-16f
eval_name: image_classification_frozen
resume_checkpoint: false
data:
root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/
image_folder: places205/121517/pytorch/
num_classes: 205
resolution: 224
dataset_name: Places205
optimization:
num_epochs: 20
batch_size: 16
weight_decay: 0.001
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_huge
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_sdpa: true
use_silu: false
tight_silu: false
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
39 changes: 39 additions & 0 deletions configs/evals/vith16_ssv2_16x2x3.yaml
@@ -0,0 +1,39 @@
nodes: 8
tasks_per_node: 8
tag: ssv2-16x2x3
eval_name: video_classification_frozen
resume_checkpoint: false
data:
dataset_train: /your_path_to_ssv2_train_csv_file_index.csv
dataset_val: /your_path_to_ssv2_val_csv_file_index.csv
dataset_type: VideoDataset
num_classes: 174
frames_per_clip: 16
num_segments: 2
num_views_per_segment: 3
frame_step: 4
optimization:
attend_across_segments: true
num_epochs: 20
resolution: 224
batch_size: 4
weight_decay: 0.01
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_huge
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_silu: false
tight_silu: false
use_sdpa: true
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
34 changes: 34 additions & 0 deletions configs/evals/vitl16_in1k.yaml
@@ -0,0 +1,34 @@
nodes: 8
tasks_per_node: 8
tag: in1k-16f
eval_name: image_classification_frozen
resume_checkpoint: false
data:
root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/
image_folder: imagenet_full_size/061417/
num_classes: 1000
resolution: 224
dataset_name: ImageNet
optimization:
num_epochs: 20
batch_size: 16
weight_decay: 0.001
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_large
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_sdpa: true
use_silu: false
tight_silu: false
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
34 changes: 34 additions & 0 deletions configs/evals/vitl16_inat.yaml
@@ -0,0 +1,34 @@
nodes: 8
tasks_per_node: 8
tag: inat-16f
eval_name: image_classification_frozen
resume_checkpoint: false
data:
root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/
image_folder: iNaturalist-2021/110421/
num_classes: 10000
resolution: 224
dataset_name: iNat21
optimization:
num_epochs: 20
batch_size: 16
weight_decay: 0.001
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_large
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_sdpa: true
use_silu: false
tight_silu: false
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
39 changes: 39 additions & 0 deletions configs/evals/vitl16_k400_16x8x3.yaml
@@ -0,0 +1,39 @@
nodes: 8
tasks_per_node: 8
tag: k400-16x8x3
eval_name: video_classification_frozen
resume_checkpoint: false
data:
dataset_train: /your_path_to_kinetics400_train_csv_file_index.csv
dataset_val: /your_path_to_kinetics400_val_csv_file_index.csv
dataset_type: VideoDataset
num_classes: 400
frames_per_clip: 16
num_segments: 8
num_views_per_segment: 3
frame_step: 4
optimization:
attend_across_segments: true
num_epochs: 20
resolution: 224
batch_size: 4
weight_decay: 0.01
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_large
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_silu: false
tight_silu: false
use_sdpa: true
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
34 changes: 34 additions & 0 deletions configs/evals/vitl16_places.yaml
@@ -0,0 +1,34 @@
nodes: 8
tasks_per_node: 8
tag: places-16f
eval_name: image_classification_frozen
resume_checkpoint: false
data:
root_path: /your_absolute_file_path_to_directory_where_image_datasets_are_stored/
image_folder: places205/121517/pytorch/
num_classes: 205
resolution: 224
dataset_name: Places205
optimization:
num_epochs: 20
batch_size: 16
weight_decay: 0.001
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_large
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_sdpa: true
use_silu: false
tight_silu: false
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
39 changes: 39 additions & 0 deletions configs/evals/vitl16_ssv2_16x2x3.yaml
@@ -0,0 +1,39 @@
nodes: 8
tasks_per_node: 8
tag: ssv2-16x2x3
eval_name: video_classification_frozen
resume_checkpoint: false
data:
dataset_train: /your_path_to_ssv2_train_csv_file_index.csv
dataset_val: /your_path_to_ssv2_val_csv_file_index.csv
dataset_type: VideoDataset
num_classes: 174
frames_per_clip: 16
num_segments: 2
num_views_per_segment: 3
frame_step: 4
optimization:
attend_across_segments: true
num_epochs: 20
resolution: 224
batch_size: 4
weight_decay: 0.01
lr: 0.001
start_lr: 0.001
final_lr: 0.0
warmup: 0.
use_bfloat16: true
pretrain:
model_name: vit_large
checkpoint_key: target_encoder
clip_duration: null
frames_per_clip: 16
tubelet_size: 2
uniform_power: true
use_silu: false
tight_silu: false
use_sdpa: true
patch_size: 16
folder: /your_absolute_file_path_to_directory_where_pretrained_models_are_contained/
checkpoint: jepa-latest.pth.tar # name of pretrained model file inside folder
write_tag: jepa
90 changes: 90 additions & 0 deletions configs/pretrain/vith16.yaml
@@ -0,0 +1,90 @@
app: vjepa
nodes: 16
tasks_per_node: 8
data:
dataset_type: VideoDataset
datasets:
- /your_path_to_kinetics710_csv_file_index.csv
- /your_path_to_ssv2_csv_file_index.csv
- /your_path_to_howto100m_csv_file_index.csv
decode_one_clip: true
batch_size: 24
num_clips: 1
num_frames: 16
tubelet_size: 2
sampling_rate: 4
crop_size: 224
patch_size: 16
pin_mem: true
num_workers: 12
filter_short_videos: false
clip_duration: null
data_aug:
auto_augment: false
motion_shift: false
random_resize_aspect_ratio:
- 0.75
- 1.35
random_resize_scale:
- 0.3
- 1.0
reprob: 0.0
logging:
folder: /your_absolute_file_path_for_saving_logs_and_checkpoints/
write_tag: jepa
loss:
loss_exp: 1.0
reg_coeff: 0.0
mask:
- aspect_ratio:
- 0.75
- 1.5
num_blocks: 8
spatial_scale:
- 0.15
- 0.15
temporal_scale:
- 1.0
- 1.0
max_temporal_keep: 1.0
max_keep: null
- aspect_ratio:
- 0.75
- 1.5
num_blocks: 2
spatial_scale:
- 0.7
- 0.7
temporal_scale:
- 1.0
- 1.0
max_temporal_keep: 1.0
max_keep: null
meta:
load_checkpoint: false
read_checkpoint: null
seed: 234
eval_freq: 100
use_sdpa: true
dtype: bfloat16
model:
model_name: vit_huge
pred_depth: 12
pred_embed_dim: 384
uniform_power: true
use_mask_tokens: true
zero_init_mask_tokens: true
optimization:
ipe: 300
ipe_scale: 1.25
clip_grad: 10.0
weight_decay: 0.04
final_weight_decay: 0.4
epochs: 300
warmup: 40
start_lr: 0.0002
lr: 0.000625
final_lr: 1.0e-06
ema:
- 0.998
- 1.0
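
These pretraining configs are passed through uninterpreted: app/main.py loads the YAML and forwards the whole dict, and the top-level `app` field selects the training package. A hedged sketch of that dispatch (a real run goes through app/main.py, which also sets up the distributed environment):

import yaml
from app.scaffold import main as app_main

with open('configs/pretrain/vith16.yaml', 'r') as y_file:
    params = yaml.load(y_file, Loader=yaml.FullLoader)

# 'app: vjepa' resolves to app/vjepa/train.py
app_main(params['app'], args=params)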
90 changes: 90 additions & 0 deletions configs/pretrain/vith16_384.yaml
@@ -0,0 +1,90 @@
app: vjepa
nodes: 30
tasks_per_node: 8
data:
dataset_type: VideoDataset
datasets:
- /your_path_to_kinetics710_csv_file_index.csv
- /your_path_to_ssv2_csv_file_index.csv
- /your_path_to_howto100m_csv_file_index.csv
decode_one_clip: true
batch_size: 10
num_clips: 1
num_frames: 16
tubelet_size: 2
sampling_rate: 4
crop_size: 384
patch_size: 16
pin_mem: true
num_workers: 12
filter_short_videos: false
clip_duration: null
data_aug:
auto_augment: false
motion_shift: false
random_resize_aspect_ratio:
- 0.75
- 1.35
random_resize_scale:
- 0.3
- 1.0
reprob: 0.0
logging:
folder: /your_absolute_file_path_for_saving_logs_and_checkpoints/
write_tag: jepa
loss:
loss_exp: 1.0
reg_coeff: 0.0
mask:
- aspect_ratio:
- 0.75
- 1.5
num_blocks: 8
spatial_scale:
- 0.15
- 0.15
temporal_scale:
- 1.0
- 1.0
max_temporal_keep: 1.0
max_keep: null
- aspect_ratio:
- 0.75
- 1.5
num_blocks: 2
spatial_scale:
- 0.7
- 0.7
temporal_scale:
- 1.0
- 1.0
max_temporal_keep: 1.0
max_keep: null
meta:
load_checkpoint: false
read_checkpoint: null
seed: 234
eval_freq: 100
use_sdpa: true
dtype: bfloat16
model:
model_name: vit_huge
pred_depth: 12
pred_embed_dim: 384
uniform_power: true
use_mask_tokens: true
zero_init_mask_tokens: true
optimization:
ipe: 300
ipe_scale: 1.25
clip_grad: 10.0
weight_decay: 0.04
final_weight_decay: 0.4
epochs: 300
warmup: 40
start_lr: 0.0002
lr: 0.000625
final_lr: 1.0e-06
ema:
- 0.998
- 1.0
90 changes: 90 additions & 0 deletions configs/pretrain/vitl16.yaml
@@ -0,0 +1,90 @@
app: vjepa
nodes: 16
tasks_per_node: 8
data:
dataset_type: VideoDataset
datasets:
- /your_path_to_kinetics710_csv_file_index.csv
- /your_path_to_ssv2_csv_file_index.csv
- /your_path_to_howto100m_csv_file_index.csv
decode_one_clip: true
batch_size: 24
num_clips: 1
num_frames: 16
tubelet_size: 2
sampling_rate: 4
crop_size: 224
patch_size: 16
pin_mem: true
num_workers: 12
filter_short_videos: false
clip_duration: null
data_aug:
auto_augment: false
motion_shift: false
random_resize_aspect_ratio:
- 0.75
- 1.35
random_resize_scale:
- 0.3
- 1.0
reprob: 0.0
logging:
folder: /your_absolute_file_path_for_saving_logs_and_checkpoints/
write_tag: jepa
loss:
loss_exp: 1.0
reg_coeff: 0.0
mask:
- aspect_ratio:
- 0.75
- 1.5
num_blocks: 8
spatial_scale:
- 0.15
- 0.15
temporal_scale:
- 1.0
- 1.0
max_temporal_keep: 1.0
max_keep: null
- aspect_ratio:
- 0.75
- 1.5
num_blocks: 2
spatial_scale:
- 0.7
- 0.7
temporal_scale:
- 1.0
- 1.0
max_temporal_keep: 1.0
max_keep: null
meta:
load_checkpoint: false
read_checkpoint: null
seed: 234
eval_freq: 100
use_sdpa: true
dtype: bfloat16
model:
model_name: vit_large
pred_depth: 12
pred_embed_dim: 384
uniform_power: true
use_mask_tokens: true
zero_init_mask_tokens: true
optimization:
ipe: 300
ipe_scale: 1.25
clip_grad: 10.0
weight_decay: 0.04
final_weight_decay: 0.4
epochs: 300
warmup: 40
start_lr: 0.0002
lr: 0.000625
final_lr: 1.0e-06
ema:
- 0.998
- 1.0
503 changes: 503 additions & 0 deletions evals/image_classification_frozen/eval.py

Large diffs are not rendered by default.

67 changes: 67 additions & 0 deletions evals/main.py
@@ -0,0 +1,67 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse

import multiprocessing as mp

import pprint
import yaml

from src.utils.distributed import init_distributed

from evals.scaffold import main as eval_main

parser = argparse.ArgumentParser()
parser.add_argument(
'--fname', type=str,
help='name of config file to load',
default='configs.yaml')
parser.add_argument(
'--devices', type=str, nargs='+', default=['cuda:0'],
help='which devices to use on local machine')


def process_main(rank, fname, world_size, devices):
import os
os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1])

import logging
logging.basicConfig()
logger = logging.getLogger()
if rank == 0:
logger.setLevel(logging.INFO)
else:
logger.setLevel(logging.ERROR)

logger.info(f'called-params {fname}')

# Load config
params = None
with open(fname, 'r') as y_file:
params = yaml.load(y_file, Loader=yaml.FullLoader)
logger.info('loaded params...')
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(params)

    # Init distributed (access to communication between GPUs on the same machine)
world_size, rank = init_distributed(rank_and_world_size=(rank, world_size))
logger.info(f'Running... (rank: {rank}/{world_size})')

# Launch the eval with loaded config
eval_main(params['eval_name'], args_eval=params)


if __name__ == '__main__':
args = parser.parse_args()
num_gpus = len(args.devices)
mp.set_start_method('spawn')
for rank in range(num_gpus):
mp.Process(
target=process_main,
args=(rank, args.fname, num_gpus, args.devices)
).start()
162 changes: 162 additions & 0 deletions evals/main_distributed.py
@@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import argparse
import logging
import os
import pprint
import sys
import time
import yaml

import submitit

from evals.scaffold import main as eval_main

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()

parser = argparse.ArgumentParser()
parser.add_argument(
'--folder', type=str,
help='location to save submitit logs',
default='/fsx-jepa/massran/submitit/')
parser.add_argument(
'--exclude', type=str,
help='nodes to exclude from training',
default=None)
parser.add_argument(
'--batch-launch', action='store_true',
    help='whether fname points to a file to batch-launch several config files')
parser.add_argument(
'--fname', type=str,
help='yaml file containing config file names to launch',
default='configs.yaml')
parser.add_argument(
'--partition', type=str,
help='cluster partition to submit jobs on')
parser.add_argument(
'--time', type=int, default=4300,
help='time in minutes to run job')


class Trainer:

def __init__(self, args_eval=None, resume_preempt=None):
self.eval_name = args_eval['eval_name']
self.args_eval = args_eval
self.resume_preempt = resume_preempt

def __call__(self):
eval_name = self.eval_name
args_eval = self.args_eval
resume_preempt = self.resume_preempt

logger.info('loaded eval params...')
pp = pprint.PrettyPrinter(indent=4)
pp.pprint(args_eval)

eval_main(
eval_name,
args_eval=args_eval,
resume_preempt=resume_preempt)

def checkpoint(self):
fb_trainer = Trainer(self.args_eval, True)
return submitit.helpers.DelayedSubmission(fb_trainer,)


def launch_evals_with_parsed_args(
args_for_evals,
submitit_folder,
partition='learnlab,learnfair',
timeout=4300,
nodes=1,
tasks_per_node=1,
delay_seconds=10,
exclude_nodes=None
):
if not isinstance(args_for_evals, list):
logger.info(f'Passed in eval-args of type {type(args_for_evals)}')
args_for_evals = [args_for_evals]

time.sleep(delay_seconds)
logger.info('Launching evaluations in separate jobs...')
executor = submitit.AutoExecutor(
folder=os.path.join(submitit_folder, 'job_%j'),
slurm_max_num_timeout=20)
executor.update_parameters(
slurm_partition=partition,
slurm_mem_per_gpu='55G',
timeout_min=timeout,
nodes=nodes,
tasks_per_node=tasks_per_node,
cpus_per_task=12,
gpus_per_node=tasks_per_node)

if exclude_nodes is not None:
executor.update_parameters(slurm_exclude=exclude_nodes)

jobs, trainers = [], []
with executor.batch():
for ae in args_for_evals:
fb_trainer = Trainer(ae)
job = executor.submit(fb_trainer,)
trainers.append(fb_trainer)
jobs.append(job)

for job in jobs:
logger.info(f'Launched eval job with id {job.job_id}')


def launch_evals():

# ---------------------------------------------------------------------- #
# 1. Put config file names in a list
# ---------------------------------------------------------------------- #
config_fnames = [args.fname]

# -- If batch-launch is True, then the args.fname yaml file is not a
# -- config, but actually specifies a list of other config files
# -- to run in a slurm job array
if args.batch_launch:
with open(args.fname, 'r') as y_file:
config_fnames = yaml.load(y_file, Loader=yaml.FullLoader)
# ---------------------------------------------------------------------- #

# ---------------------------------------------------------------------- #
# 2. Parse each yaml config file as a dict and place in list
# ---------------------------------------------------------------------- #
nodes, tasks_per_node = None, None
configs = []
for f in config_fnames:
with open(f, 'r') as y_file:
_params = yaml.load(y_file, Loader=yaml.FullLoader)
nodes = int(_params.get('nodes'))
tasks_per_node = int(_params.get('tasks_per_node'))
configs += [_params]
logger.info(f'Loaded {len(configs)} config files')
logger.info(f'Running all jobs with {nodes=} / {tasks_per_node=}')
# ---------------------------------------------------------------------- #

# ---------------------------------------------------------------------- #
# 3. Launch evals with parsed config files
# ---------------------------------------------------------------------- #
launch_evals_with_parsed_args(
args_for_evals=configs,
submitit_folder=args.folder,
partition=args.partition,
timeout=args.time,
nodes=nodes,
tasks_per_node=tasks_per_node,
exclude_nodes=args.exclude)
# ---------------------------------------------------------------------- #


if __name__ == '__main__':
args = parser.parse_args()
launch_evals()
24 changes: 24 additions & 0 deletions evals/scaffold.py
@@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import importlib
import logging
import sys

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
logger = logging.getLogger()


def main(
eval_name,
args_eval,
resume_preempt=False
):
logger.info(f'Running evaluation: {eval_name}')
return importlib.import_module(f'evals.{eval_name}.eval').main(
args_eval=args_eval,
resume_preempt=resume_preempt)
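
As with app/scaffold.py, new evaluations only need to expose this interface. A hypothetical minimal eval module (the `my_probe` name is made up), saved as `evals/my_probe/eval.py` and selected via `eval_name: my_probe` in an eval config:

# evals/my_probe/eval.py -- hypothetical eval satisfying the scaffold contract
def main(args_eval, resume_preempt=False):
    # `args_eval` is the full eval config dict forwarded by evals/main.py
    print('running probe with tag:', args_eval.get('tag'), 'resume:', resume_preempt)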
561 changes: 561 additions & 0 deletions evals/video_classification_frozen/eval.py

Large diffs are not rendered by default.

343 changes: 343 additions & 0 deletions evals/video_classification_frozen/utils.py
@@ -0,0 +1,343 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np

import torch
import torch.nn as nn
import torchvision.transforms as transforms

import src.datasets.utils.video.transforms as video_transforms
import src.datasets.utils.video.volume_transforms as volume_transforms

from src.datasets.utils.video.randerase import RandomErasing

from src.models.utils.pos_embs import get_1d_sincos_pos_embed
from src.masks.utils import apply_masks


class FrameAggregation(nn.Module):
"""
Process each frame independently and concatenate all tokens
"""

def __init__(
self,
model,
max_frames=10000,
use_pos_embed=False,
attend_across_segments=False
):
super().__init__()
self.model = model
self.embed_dim = embed_dim = model.embed_dim
self.num_heads = model.num_heads
self.attend_across_segments = attend_across_segments
# 1D-temporal pos-embedding
self.pos_embed = None
if use_pos_embed:
self.pos_embed = nn.Parameter(
torch.zeros(1, max_frames, embed_dim),
requires_grad=False)
sincos = get_1d_sincos_pos_embed(embed_dim, max_frames)
self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0))

def forward(self, x, clip_indices=None):

        # TODO: implement attend_across_segments=False
# num_clips = len(x)
num_views_per_clip = len(x[0])

# Concatenate views along batch dimension
x = [torch.cat(xi, dim=0) for xi in x]
# Concatenate clips along temporal dimension
x = torch.cat(x, dim=2)
B, C, T, H, W = x.size()

# Put each frame along the batch dimension
x = x.permute(0, 2, 1, 3, 4).reshape(B*T, C, H, W)

outputs = self.model(x)
_, N, D = outputs.size()
outputs = outputs.reshape(B, T, N, D).flatten(1, 2)

# Separate views into list
B = B // num_views_per_clip
all_outputs = []
for i in range(num_views_per_clip):
o = outputs[i*B:(i+1)*B]
# Compute positional embedding
if (self.pos_embed is not None) and (clip_indices is not None):
pos_embed = self.pos_embed.repeat(B, 1, 1) # [B, F, D]
pos_embed = apply_masks(pos_embed, clip_indices, concat=False) # list(Tensor([B, T, D]))
pos_embed = torch.cat(pos_embed, dim=1) # concatenate along temporal dimension
pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1) # [B, T*num_clips, N, D]
pos_embed = pos_embed.flatten(1, 2)
o += pos_embed
all_outputs += [o]

return all_outputs


class ClipAggregation(nn.Module):
"""
    Process each clip independently and concatenate all tokens
"""

def __init__(
self,
model,
tubelet_size=2,
max_frames=10000,
use_pos_embed=False,
attend_across_segments=False
):
super().__init__()
self.model = model
self.tubelet_size = tubelet_size
self.embed_dim = embed_dim = model.embed_dim
self.num_heads = model.num_heads
self.attend_across_segments = attend_across_segments
# 1D-temporal pos-embedding
self.pos_embed = None
if use_pos_embed:
max_T = max_frames // tubelet_size
self.pos_embed = nn.Parameter(
torch.zeros(1, max_T, embed_dim),
requires_grad=False)
sincos = get_1d_sincos_pos_embed(embed_dim, max_T)
self.pos_embed.copy_(torch.from_numpy(sincos).float().unsqueeze(0))

def forward(self, x, clip_indices=None):

num_clips = len(x)
num_views_per_clip = len(x[0])
B, C, T, H, W = x[0][0].size()

# Concatenate all spatial and temporal views along batch dimension
x = [torch.cat(xi, dim=0) for xi in x]
x = torch.cat(x, dim=0)
outputs = self.model(x)
_, N, D = outputs.size()

T = T // self.tubelet_size # Num temporal tokens
N = N // T # Num spatial tokens

# Unroll outputs into a 2D array [spatial_views x temporal_views]
eff_B = B * num_views_per_clip
all_outputs = [[] for _ in range(num_views_per_clip)]
for i in range(num_clips):
o = outputs[i*eff_B:(i+1)*eff_B]
for j in range(num_views_per_clip):
all_outputs[j].append(o[j*B:(j+1)*B])

if not self.attend_across_segments:
return all_outputs

for i, outputs in enumerate(all_outputs):

# Concatenate along temporal dimension
outputs = [o.reshape(B, T, N, D) for o in outputs]
outputs = torch.cat(outputs, dim=1).flatten(1, 2)

# Compute positional embedding
if (self.pos_embed is not None) and (clip_indices is not None):
clip_indices = [c[:, ::self.tubelet_size] for c in clip_indices]
pos_embed = self.pos_embed.repeat(B, 1, 1) # [B, F, D]
pos_embed = apply_masks(pos_embed, clip_indices, concat=False) # list(Tensor([B, T, D]))
pos_embed = torch.cat(pos_embed, dim=1) # concatenate along temporal dimension
pos_embed = pos_embed.unsqueeze(2).repeat(1, 1, N, 1) # [B, T*num_clips, N, D]
pos_embed = pos_embed.flatten(1, 2)
outputs += pos_embed

all_outputs[i] = outputs

return all_outputs
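
A self-contained shape check for ClipAggregation, using a toy backbone in place of the ViT; the toy module only mimics the attributes and the (B, C, T, H, W) -> (B, N, D) mapping the wrapper relies on, so the numbers below are illustrative:

import torch
import torch.nn as nn
from evals.video_classification_frozen.utils import ClipAggregation

class ToyBackbone(nn.Module):
    """Stand-in exposing the attributes/shape contract ClipAggregation expects."""
    def __init__(self, embed_dim=32, num_heads=4, patch=16, tubelet=2):
        super().__init__()
        self.embed_dim, self.num_heads = embed_dim, num_heads
        self.patch, self.tubelet = patch, tubelet
        self.proj = nn.Linear(3 * tubelet * patch * patch, embed_dim)

    def forward(self, x):  # [B, C, T, H, W] -> [B, N, D]
        B, C, T, H, W = x.shape
        p, tb = self.patch, self.tubelet
        x = x.reshape(B, C, T // tb, tb, H // p, p, W // p, p)
        x = x.permute(0, 2, 4, 6, 1, 3, 5, 7).reshape(B, -1, C * tb * p * p)
        return self.proj(x)

wrapper = ClipAggregation(ToyBackbone(), tubelet_size=2, attend_across_segments=True)
# two temporal clips, one spatial view each, batch size 2
x = [[torch.randn(2, 3, 16, 224, 224)] for _ in range(2)]
out = wrapper(x)     # one entry per spatial view
print(out[0].shape)  # expected: [2, 2 * (16 // 2) * 196, 32] = [2, 3136, 32]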


def make_transforms(
training=True,
random_horizontal_flip=True,
random_resize_aspect_ratio=(3/4, 4/3),
random_resize_scale=(0.3, 1.0),
reprob=0.0,
auto_augment=False,
motion_shift=False,
crop_size=224,
num_views_per_clip=1,
normalize=((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))
):

if not training and num_views_per_clip > 1:
print('Making EvalVideoTransform, multi-view')
_frames_augmentation = EvalVideoTransform(
num_views_per_clip=num_views_per_clip,
short_side_size=crop_size,
normalize=normalize,
)

else:
_frames_augmentation = VideoTransform(
training=training,
random_horizontal_flip=random_horizontal_flip,
random_resize_aspect_ratio=random_resize_aspect_ratio,
random_resize_scale=random_resize_scale,
reprob=reprob,
auto_augment=auto_augment,
motion_shift=motion_shift,
crop_size=crop_size,
normalize=normalize,
)
return _frames_augmentation


class VideoTransform(object):

def __init__(
self,
training=True,
random_horizontal_flip=True,
random_resize_aspect_ratio=(3/4, 4/3),
random_resize_scale=(0.3, 1.0),
reprob=0.0,
auto_augment=False,
motion_shift=False,
crop_size=224,
normalize=((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))
):

self.training = training

short_side_size = int(crop_size * 256 / 224)
self.eval_transform = video_transforms.Compose([
video_transforms.Resize(short_side_size, interpolation='bilinear'),
video_transforms.CenterCrop(size=(crop_size, crop_size)),
volume_transforms.ClipToTensor(),
video_transforms.Normalize(mean=normalize[0], std=normalize[1])
])

self.random_horizontal_flip = random_horizontal_flip
self.random_resize_aspect_ratio = random_resize_aspect_ratio
self.random_resize_scale = random_resize_scale
self.auto_augment = auto_augment
self.motion_shift = motion_shift
self.crop_size = crop_size
self.normalize = torch.tensor(normalize)

self.autoaug_transform = video_transforms.create_random_augment(
input_size=(crop_size, crop_size),
auto_augment='rand-m7-n4-mstd0.5-inc1',
interpolation='bicubic',
)

self.spatial_transform = video_transforms.random_resized_crop_with_shift \
if motion_shift else video_transforms.random_resized_crop

self.reprob = reprob
self.erase_transform = RandomErasing(
reprob,
mode='pixel',
max_count=1,
num_splits=1,
device='cpu',
)

def __call__(self, buffer):

if not self.training:
return [self.eval_transform(buffer)]

buffer = [transforms.ToPILImage()(frame) for frame in buffer]

if self.auto_augment:
buffer = self.autoaug_transform(buffer)

buffer = [transforms.ToTensor()(img) for img in buffer]
buffer = torch.stack(buffer) # T C H W
buffer = buffer.permute(0, 2, 3, 1) # T H W C

buffer = tensor_normalize(buffer, self.normalize[0], self.normalize[1])
buffer = buffer.permute(3, 0, 1, 2) # T H W C -> C T H W

buffer = self.spatial_transform(
images=buffer,
target_height=self.crop_size,
target_width=self.crop_size,
scale=self.random_resize_scale,
ratio=self.random_resize_aspect_ratio,
)
if self.random_horizontal_flip:
buffer, _ = video_transforms.horizontal_flip(0.5, buffer)

if self.reprob > 0:
buffer = buffer.permute(1, 0, 2, 3)
buffer = self.erase_transform(buffer)
buffer = buffer.permute(1, 0, 2, 3)

return [buffer]


class EvalVideoTransform(object):

def __init__(
self,
num_views_per_clip=1,
short_side_size=224,
normalize=((0.485, 0.456, 0.406),
(0.229, 0.224, 0.225))
):
self.views_per_clip = num_views_per_clip
self.short_side_size = short_side_size
self.spatial_resize = video_transforms.Resize(short_side_size, interpolation='bilinear')
self.to_tensor = video_transforms.Compose([
volume_transforms.ClipToTensor(),
video_transforms.Normalize(mean=normalize[0], std=normalize[1])
])

def __call__(self, buffer):

# Sample several spatial views of each clip
buffer = np.array(self.spatial_resize(buffer))
T, H, W, C = buffer.shape

num_views = self.views_per_clip
side_len = self.short_side_size
spatial_step = (max(H, W) - side_len) // (num_views - 1)

all_views = []
for i in range(num_views):
start = i*spatial_step
if H > W:
view = buffer[:, start:start+side_len, :, :]
else:
view = buffer[:, :, start:start+side_len, :]
view = self.to_tensor(view)
all_views.append(view)

return all_views
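
A quick sketch of the multi-view eval transform in isolation (it is only built by `make_transforms` when `num_views_per_clip > 1`); the frame sizes are arbitrary:

import numpy as np
from evals.video_classification_frozen.utils import EvalVideoTransform

transform = EvalVideoTransform(num_views_per_clip=3, short_side_size=224)
frames = [np.random.randint(0, 256, size=(256, 340, 3), dtype=np.uint8) for _ in range(16)]
views = transform(frames)
print(len(views), views[0].shape)  # expected: 3 views, each roughly [3, 16, 224, 224]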


def tensor_normalize(tensor, mean, std):
"""
Normalize a given tensor by subtracting the mean and dividing the std.
Args:
tensor (tensor): tensor to normalize.
mean (tensor or list): mean value to subtract.
std (tensor or list): std to divide.
"""
if tensor.dtype == torch.uint8:
tensor = tensor.float()
tensor = tensor / 255.0
if type(mean) == list:
mean = torch.tensor(mean)
if type(std) == list:
std = torch.tensor(std)
tensor = tensor - mean
tensor = tensor / std
return tensor
13 changes: 13 additions & 0 deletions requirements.txt
@@ -0,0 +1,13 @@
torch>=2
torchvision
pyyaml
numpy
opencv-python
submitit
braceexpand
webdataset
timm
decord
pandas
einops
beartype
25 changes: 25 additions & 0 deletions setup.py
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os
from setuptools import setup

VERSION = "0.0.1"

def get_requirements():
with open("./requirements.txt") as reqsf:
reqs = reqsf.readlines()
return reqs


if __name__ == "__main__":
setup(
name="jepa",
version=VERSION,
description="JEPA research code.",
python_requires=">=3.9",
install_requires=get_requirements(),
)
90 changes: 90 additions & 0 deletions src/datasets/data_manager.py
@@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from logging import getLogger


_GLOBAL_SEED = 0
logger = getLogger()


def init_data(
batch_size,
transform=None,
shared_transform=None,
data='ImageNet',
collator=None,
pin_mem=True,
num_workers=8,
world_size=1,
rank=0,
root_path=None,
image_folder=None,
training=True,
copy_data=False,
drop_last=True,
tokenize_txt=True,
subset_file=None,
clip_len=8,
frame_sample_rate=2,
duration=None,
num_clips=1,
random_clip_sampling=True,
allow_clip_overlap=False,
filter_short_videos=False,
filter_long_videos=int(1e9),
decode_one_clip=True,
datasets_weights=None,
persistent_workers=False,
repeat_wds=False,
ipe=300,
log_dir=None,
):

if (data.lower() == 'imagenet') \
or (data.lower() == 'inat21') \
or (data.lower() == 'places205'):
from src.datasets.image_dataset import make_imagedataset
dataset, data_loader, dist_sampler = make_imagedataset(
transform=transform,
batch_size=batch_size,
collator=collator,
pin_mem=pin_mem,
training=training,
num_workers=num_workers,
world_size=world_size,
rank=rank,
root_path=root_path,
image_folder=image_folder,
persistent_workers=persistent_workers,
copy_data=copy_data,
drop_last=drop_last,
subset_file=subset_file)

elif data.lower() == 'videodataset':
from src.datasets.video_dataset import make_videodataset
dataset, data_loader, dist_sampler = make_videodataset(
data_paths=root_path,
batch_size=batch_size,
frames_per_clip=clip_len,
frame_step=frame_sample_rate,
duration=duration,
num_clips=num_clips,
random_clip_sampling=random_clip_sampling,
allow_clip_overlap=allow_clip_overlap,
filter_short_videos=filter_short_videos,
filter_long_videos=filter_long_videos,
shared_transform=shared_transform,
transform=transform,
datasets_weights=datasets_weights,
collator=collator,
num_workers=num_workers,
world_size=world_size,
rank=rank,
log_dir=log_dir)

return (data_loader, dist_sampler)
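
A hedged example of building a video loader through this entry point, reusing the placeholder CSV-index path from the configs. In actual pretraining, `transform` would come from `app.vjepa.transforms.make_transforms` and `collator` would be a masking collator (an assumption here, since app/vjepa/train.py is not rendered in this diff):

from src.datasets.data_manager import init_data

data_loader, dist_sampler = init_data(
    data='VideoDataset',
    root_path=['/your_path_to_kinetics710_csv_file_index.csv'],
    batch_size=24,
    clip_len=16,
    frame_sample_rate=4,
    num_clips=1,
    transform=None,          # e.g. app.vjepa.transforms.make_transforms(...)
    shared_transform=None,
    collator=None,
    num_workers=12,
    world_size=1,
    rank=0)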
79 changes: 79 additions & 0 deletions src/datasets/image_dataset.py
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import os

from logging import getLogger

import torch
import torchvision

_GLOBAL_SEED = 0
logger = getLogger()


class ImageFolder(torchvision.datasets.ImageFolder):

def __init__(
self,
root,
image_folder='imagenet_full_size/061417/',
transform=None,
train=True,
):
"""
ImageFolder
:param root: root network directory for ImageFolder data
:param image_folder: path to images inside root network directory
:param train: whether to load train data (or validation)
"""

suffix = 'train/' if train else 'val/'
data_path = os.path.join(root, image_folder, suffix)
logger.info(f'data-path {data_path}')
super(ImageFolder, self).__init__(root=data_path, transform=transform)
logger.info('Initialized ImageFolder')


def make_imagedataset(
transform,
batch_size,
collator=None,
pin_mem=True,
num_workers=8,
world_size=1,
rank=0,
root_path=None,
image_folder=None,
training=True,
copy_data=False,
drop_last=True,
persistent_workers=False,
subset_file=None
):
dataset = ImageFolder(
root=root_path,
image_folder=image_folder,
transform=transform,
train=training)
logger.info('ImageFolder dataset created')
dist_sampler = torch.utils.data.distributed.DistributedSampler(
dataset=dataset,
num_replicas=world_size,
rank=rank)
data_loader = torch.utils.data.DataLoader(
dataset,
collate_fn=collator,
sampler=dist_sampler,
batch_size=batch_size,
drop_last=drop_last,
pin_memory=pin_mem,
num_workers=num_workers,
persistent_workers=persistent_workers)
logger.info('ImageFolder unsupervised data loader created')

return dataset, data_loader, dist_sampler
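
A minimal sketch of building the image loader above on a single process; the root path is a placeholder and the transform is a generic torchvision pipeline, not necessarily the one used elsewhere in this codebase. Passing world_size=1 and rank=0 explicitly means no process group needs to be initialized for the DistributedSampler.

import torchvision.transforms as T
from src.datasets.image_dataset import make_imagedataset

transform = T.Compose([
    T.RandomResizedCrop(224),
    T.ToTensor()])

dataset, loader, sampler = make_imagedataset(
    transform=transform,
    batch_size=64,
    root_path='/datasets',                       # hypothetical root directory
    image_folder='imagenet_full_size/061417/',   # layout expected by ImageFolder above
    training=True,
    world_size=1,
    rank=0,
    num_workers=4)

sampler.set_epoch(0)   # reshuffle the distributed sampler at the start of each epoch
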
96 changes: 96 additions & 0 deletions src/datasets/utils/video/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numbers
import cv2
import numpy as np
import PIL
import torch


def _is_tensor_clip(clip):
return torch.is_tensor(clip) and clip.ndimension() == 4


def crop_clip(clip, min_h, min_w, h, w):
if isinstance(clip[0], np.ndarray):
cropped = [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]

elif isinstance(clip[0], PIL.Image.Image):
cropped = [
img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip
]
else:
raise TypeError('Expected numpy.ndarray or PIL.Image ' +
'but got list of {0}'.format(type(clip[0])))
return cropped


def resize_clip(clip, size, interpolation='bilinear'):
if isinstance(clip[0], np.ndarray):
if isinstance(size, numbers.Number):
im_h, im_w, im_c = clip[0].shape
# Min spatial dim already matches minimal size
if (im_w <= im_h and im_w == size) or (im_h <= im_w
and im_h == size):
return clip
new_h, new_w = get_resize_sizes(im_h, im_w, size)
size = (new_w, new_h)
else:
size = size[0], size[1]
if interpolation == 'bilinear':
np_inter = cv2.INTER_LINEAR
else:
np_inter = cv2.INTER_NEAREST
scaled = [
cv2.resize(img, size, interpolation=np_inter) for img in clip
]
elif isinstance(clip[0], PIL.Image.Image):
if isinstance(size, numbers.Number):
im_w, im_h = clip[0].size
# Min spatial dim already matches minimal size
if (im_w <= im_h and im_w == size) or (im_h <= im_w
and im_h == size):
return clip
new_h, new_w = get_resize_sizes(im_h, im_w, size)
size = (new_w, new_h)
else:
size = size[1], size[0]
if interpolation == 'bilinear':
pil_inter = PIL.Image.BILINEAR
else:
pil_inter = PIL.Image.NEAREST
scaled = [img.resize(size, pil_inter) for img in clip]
else:
raise TypeError('Expected numpy.ndarray or PIL.Image ' +
'but got list of {0}'.format(type(clip[0])))
return scaled


def get_resize_sizes(im_h, im_w, size):
if im_w < im_h:
ow = size
oh = int(size * im_h / im_w)
else:
oh = size
ow = int(size * im_w / im_h)
return oh, ow


def normalize(clip, mean, std, inplace=False):
if not _is_tensor_clip(clip):
raise TypeError('clip is not a 4-D torch tensor.')

if not inplace:
clip = clip.clone()

dtype = clip.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=clip.device)
std = torch.as_tensor(std, dtype=dtype, device=clip.device)
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])

return clip
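
The helpers above operate on a list of frames; below is a small sketch, using synthetic numpy frames and arbitrary sizes, of chaining resize, crop, and normalize. Note that normalize() expects a 4-D float tensor laid out as (C, T, H, W).

import numpy as np
import torch
from src.datasets.utils.video.functional import crop_clip, resize_clip, normalize

# 8 synthetic RGB frames, each (H, W, C) = (240, 320, 3), uint8 in [0, 255]
clip = [np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8) for _ in range(8)]

clip = resize_clip(clip, 112, interpolation='bilinear')  # short side -> 112, aspect kept
clip = crop_clip(clip, 0, 0, 112, 112)                   # top-left 112x112 crop

# stack to (T, H, W, C), move channels first -> (C, T, H, W), scale to [0, 1]
arr = np.stack(clip).transpose(3, 0, 1, 2).astype(np.float32) / 255.0
tensor = normalize(torch.from_numpy(arr),
                   mean=[0.485, 0.456, 0.406],
                   std=[0.229, 0.224, 0.225])
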
518 changes: 518 additions & 0 deletions src/datasets/utils/video/randaugment.py

Large diffs are not rendered by default.

180 changes: 180 additions & 0 deletions src/datasets/utils/video/randerase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

"""
This implementation is based on
https://github.com/rwightman/pytorch-image-models/blob/master/timm/data/random_erasing.py
published under the Apache License 2.0.
"""
import math
import random
import torch


def _get_pixels(
per_pixel, rand_color, patch_size, dtype=torch.float32, device="cuda"
):
# NOTE I've seen CUDA illegal memory access errors being caused by the normal_()
# paths, flip the order so normal is run on CPU if this becomes a problem
# Issue has been fixed in master https://github.com/pytorch/pytorch/issues/19508
if per_pixel:
return torch.empty(patch_size, dtype=dtype, device=device).normal_()
elif rand_color:
return torch.empty(
(patch_size[0], 1, 1), dtype=dtype, device=device
).normal_()
else:
return torch.zeros((patch_size[0], 1, 1), dtype=dtype, device=device)


class RandomErasing:
"""Randomly selects a rectangle region in an image and erases its pixels.
'Random Erasing Data Augmentation' by Zhong et al.
See https://arxiv.org/pdf/1708.04896.pdf
This variant of RandomErasing is intended to be applied to either a batch
or single image tensor after it has been normalized by dataset mean and std.
Args:
probability: Probability that the Random Erasing operation will be performed.
min_area: Minimum percentage of erased area wrt input image area.
max_area: Maximum percentage of erased area wrt input image area.
min_aspect: Minimum aspect ratio of erased area.
mode: pixel color mode, one of 'const', 'rand', or 'pixel'
'const' - erase block is constant color of 0 for all channels
'rand' - erase block is same per-channel random (normal) color
'pixel' - erase block is per-pixel random (normal) color
max_count: maximum number of erasing blocks per image, area per box is scaled by count.
per-image count is randomly chosen between 1 and this value.
"""

def __init__(
self,
probability=0.5,
min_area=0.02,
max_area=1 / 3,
min_aspect=0.3,
max_aspect=None,
mode="const",
min_count=1,
max_count=None,
num_splits=0,
device="cuda",
cube=True,
):
self.probability = probability
self.min_area = min_area
self.max_area = max_area
max_aspect = max_aspect or 1 / min_aspect
self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect))
self.min_count = min_count
self.max_count = max_count or min_count
self.num_splits = num_splits
mode = mode.lower()
self.rand_color = False
self.per_pixel = False
self.cube = cube
if mode == "rand":
self.rand_color = True # per block random normal
elif mode == "pixel":
self.per_pixel = True # per pixel random normal
else:
assert not mode or mode == "const"
self.device = device

def _erase(self, img, chan, img_h, img_w, dtype):
if random.random() > self.probability:
return
area = img_h * img_w
count = (
self.min_count
if self.min_count == self.max_count
else random.randint(self.min_count, self.max_count)
)
for _ in range(count):
for _ in range(10):
target_area = (
random.uniform(self.min_area, self.max_area) * area / count
)
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < img_w and h < img_h:
top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w)
img[:, top:top + h, left:left + w] = _get_pixels(
self.per_pixel,
self.rand_color,
(chan, h, w),
dtype=dtype,
device=self.device,
)
break

def _erase_cube(
self,
img,
batch_start,
batch_size,
chan,
img_h,
img_w,
dtype,
):
if random.random() > self.probability:
return
area = img_h * img_w
count = (
self.min_count
if self.min_count == self.max_count
else random.randint(self.min_count, self.max_count)
)
for _ in range(count):
for _ in range(100):
target_area = (
random.uniform(self.min_area, self.max_area) * area / count
)
aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio))
h = int(round(math.sqrt(target_area * aspect_ratio)))
w = int(round(math.sqrt(target_area / aspect_ratio)))
if w < img_w and h < img_h:
top = random.randint(0, img_h - h)
left = random.randint(0, img_w - w)
for i in range(batch_start, batch_size):
img_instance = img[i]
img_instance[
:, top:top + h, left:left + w
] = _get_pixels(
self.per_pixel,
self.rand_color,
(chan, h, w),
dtype=dtype,
device=self.device,
)
break

def __call__(self, input):
if len(input.size()) == 3:
self._erase(input, *input.size(), input.dtype)
else:
batch_size, chan, img_h, img_w = input.size()
# skip first slice of batch if num_splits is set (for clean portion of samples)
batch_start = (
batch_size // self.num_splits if self.num_splits > 1 else 0
)
if self.cube:
self._erase_cube(
input,
batch_start,
batch_size,
chan,
img_h,
img_w,
input.dtype,
)
else:
for i in range(batch_start, batch_size):
self._erase(input[i], chan, img_h, img_w, input.dtype)
return input
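
A minimal sketch of applying RandomErasing to a normalized clip tensor. The tensor shape here treats the frame dimension as the batch dimension, so with cube=True the same region is erased in every frame; the shape, probability, and CPU device are illustrative choices, not values taken from this repository.

import torch
from src.datasets.utils.video.randerase import RandomErasing

# normalized clip of 16 frames, seen by RandomErasing as (batch, C, H, W)
clip = torch.randn(16, 3, 224, 224)

eraser = RandomErasing(
    probability=1.0,   # always erase, just for demonstration
    mode='pixel',      # per-pixel random normal fill
    device='cpu',      # default is 'cuda'; CPU keeps the sketch self-contained
    cube=True)         # erase the same region in every frame

clip = eraser(clip)
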
1,184 changes: 1,184 additions & 0 deletions src/datasets/utils/video/transforms.py

Large diffs are not rendered by default.

151 changes: 151 additions & 0 deletions src/datasets/utils/video/volume_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import numpy as np
from PIL import Image

import torch


def convert_img(img):
"""Converts (H, W, C) numpy.ndarray to (C, W, H) format"""
if len(img.shape) == 3:
img = img.transpose(2, 0, 1)
if len(img.shape) == 2:
img = np.expand_dims(img, 0)
return img


class ClipToTensor(object):
"""Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
to a torch.FloatTensor of shape (C x m x H x W) in the range [0, 1.0]
"""

def __init__(self, channel_nb=3, div_255=True, numpy=False):
self.channel_nb = channel_nb
self.div_255 = div_255
self.numpy = numpy

def __call__(self, clip):
"""
Args: clip (list of numpy.ndarray): clip (list of images)
to be converted to tensor.
"""
# Retrieve shape
if isinstance(clip[0], np.ndarray):
h, w, ch = clip[0].shape
assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch)
elif isinstance(clip[0], Image.Image):
w, h = clip[0].size
else:
raise TypeError(
"Expected numpy.ndarray or PIL.Image "
"but got list of {0}".format(type(clip[0]))
)

np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])

# Convert
for img_idx, img in enumerate(clip):
if isinstance(img, np.ndarray):
pass
elif isinstance(img, Image.Image):
img = np.array(img, copy=False)
else:
raise TypeError(
"Expected numpy.ndarray or PIL.Image "
"but got list of {0}".format(type(img))
)
img = convert_img(img)
np_clip[:, img_idx, :, :] = img
if self.numpy:
if self.div_255:
np_clip = np_clip / 255.0
return np_clip

else:
tensor_clip = torch.from_numpy(np_clip)

if not isinstance(tensor_clip, torch.FloatTensor):
tensor_clip = tensor_clip.float()
if self.div_255:
tensor_clip = torch.div(tensor_clip, 255)
return tensor_clip


# Note: this normalizes data to the range [-1, 1]
class ClipToTensor_K(object):
"""Convert a list of m (H x W x C) numpy.ndarrays in the range [0, 255]
to a torch.FloatTensor of shape (C x m x H x W) in the range [-1, 1]
"""

def __init__(self, channel_nb=3, div_255=True, numpy=False):
self.channel_nb = channel_nb
self.div_255 = div_255
self.numpy = numpy

def __call__(self, clip):
"""
Args: clip (list of numpy.ndarray): clip (list of images)
to be converted to tensor.
"""
# Retrieve shape
if isinstance(clip[0], np.ndarray):
h, w, ch = clip[0].shape
assert ch == self.channel_nb, "Got {0} instead of 3 channels".format(ch)
elif isinstance(clip[0], Image.Image):
w, h = clip[0].size
else:
raise TypeError(
"Expected numpy.ndarray or PIL.Image "
"but got list of {0}".format(type(clip[0]))
)

np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])

# Convert
for img_idx, img in enumerate(clip):
if isinstance(img, np.ndarray):
pass
elif isinstance(img, Image.Image):
img = np.array(img, copy=False)
else:
raise TypeError(
"Expected numpy.ndarray or PIL.Image "
"but got list of {0}".format(type(img))
)
img = convert_img(img)
np_clip[:, img_idx, :, :] = img
if self.numpy:
if self.div_255:
np_clip = (np_clip - 127.5) / 127.5
return np_clip

else:
tensor_clip = torch.from_numpy(np_clip)

if not isinstance(tensor_clip, torch.FloatTensor):
tensor_clip = tensor_clip.float()
if self.div_255:
tensor_clip = torch.div(torch.sub(tensor_clip, 127.5), 127.5)
return tensor_clip


class ToTensor(object):
"""Converts numpy array to tensor"""

def __call__(self, array):
tensor = torch.from_numpy(array)
return tensor
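
A short sketch of ClipToTensor on a list of synthetic PIL frames; ClipToTensor_K is used the same way but lands in [-1, 1] instead of [0, 1].

import numpy as np
from PIL import Image
from src.datasets.utils.video.volume_transforms import ClipToTensor

# 8 synthetic RGB frames as PIL images
frames = [Image.fromarray(np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8))
          for _ in range(8)]

to_tensor = ClipToTensor(channel_nb=3, div_255=True)
clip = to_tensor(frames)   # torch.FloatTensor of shape (3, 8, 224, 224), values in [0, 1]
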
97 changes: 97 additions & 0 deletions src/datasets/utils/weighted_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

from typing import Iterator, Optional
from operator import itemgetter
import numpy as np

import torch
from torch.utils.data import (
Dataset,
Sampler,
DistributedSampler,
WeightedRandomSampler
)


class DatasetFromSampler(Dataset):

def __init__(self, sampler: Sampler):
self.sampler = sampler
self.sampler_list = None

def __getitem__(self, index: int):
if self.sampler_list is None:
self.sampler_list = list(self.sampler)
return self.sampler_list[index]

def __len__(self) -> int:
return len(self.sampler)


class DistributedSamplerWrapper(DistributedSampler):
""" Convert any Pytorch Sampler to a DistributedSampler """

def __init__(
self,
sampler,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
):
super(DistributedSamplerWrapper, self).__init__(
DatasetFromSampler(sampler),
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
)
self.sampler = sampler

def __iter__(self) -> Iterator[int]:
self.dataset = DatasetFromSampler(self.sampler)
indexes_of_indexes = super().__iter__()
subsampler_indexes = self.dataset
return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))


class CustomWeightedRandomSampler(WeightedRandomSampler):
""" Generalized WeightedRandomSampler to allow for more than 2^24 samples """

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def __iter__(self):
rand_tensor = np.random.choice(
range(0, len(self.weights)),
size=self.num_samples,
p=self.weights.numpy() / torch.sum(self.weights).numpy(),
replace=self.replacement
)
rand_tensor = torch.from_numpy(rand_tensor)
return iter(rand_tensor.tolist())


class DistributedWeightedSampler(DistributedSamplerWrapper):

def __init__(
self,
weights,
num_replicas: Optional[int] = None,
rank: Optional[int] = None,
shuffle: bool = True,
):
weighted_sampler = CustomWeightedRandomSampler(
weights=weights,
num_samples=len(weights),
replacement=False)

super(DistributedWeightedSampler, self).__init__(
sampler=weighted_sampler,
num_replicas=num_replicas,
rank=rank,
shuffle=shuffle,
)
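
A hedged sketch of plugging DistributedWeightedSampler into a DataLoader; the dataset and weights are synthetic, and num_replicas/rank are passed explicitly so the sketch runs without initializing a process group.

import torch
from torch.utils.data import DataLoader, TensorDataset
from src.datasets.utils.weighted_sampler import DistributedWeightedSampler

# synthetic dataset of 1000 samples; weight the second half 4x more heavily
dataset = TensorDataset(torch.randn(1000, 8))
weights = torch.cat([torch.ones(500), 4.0 * torch.ones(500)])

sampler = DistributedWeightedSampler(
    weights=weights,
    num_replicas=1,    # explicit values avoid needing torch.distributed here
    rank=0,
    shuffle=True)

loader = DataLoader(dataset, batch_size=32, sampler=sampler)
sampler.set_epoch(0)   # reshuffle per epoch, as with a standard DistributedSampler
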