Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions run_configs_2_5B.py → config/metaclip_2_5b.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# usage:
# python src/training/main.py b32_fullcc
# torchrun --nproc_per_node=8 src/training/main.py b32_fullcc
# python submitit_openclip.py b32_fullcc
# python submit.py b32_fullcc

from dataclasses import dataclass
from configs import Config
Expand All @@ -18,7 +18,7 @@ class b32_fullcc(Config):
save_frequency=1
train_data="data/metaclip_v1_2_5B/{0..200000}.tar"
workers=8
train_num_samples=400000000
train_num_samples=400000000 # assume same freq. of validation as 400M.
batch_size=512
epochs=32
model="ViT-B-32-quickgelu"
Expand Down
2 changes: 1 addition & 1 deletion run_configs_400m.py → config/metaclip_400m.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# usage:
# python src/training/main.py b32_400m
# torchrun --nproc_per_node=8 src/training/main.py b32_400m
# python submitit_openclip.py b32_400m
# python submit.py b32_400m

from dataclasses import dataclass
from configs import Config
Expand Down
1 change: 1 addition & 0 deletions run_configs_data.py → config/metaclip_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ class metaclip_2_5b:
start_shard = 0
end_shard = 200000
max_match = 170000

27 changes: 6 additions & 21 deletions configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,7 @@
from collections import OrderedDict
from dataclasses import dataclass

import sys
sys.path.append("src")

from training.params import get_default_params
from src.training.params import get_default_params


@dataclass
Expand Down Expand Up @@ -35,7 +32,7 @@ class Config:
eps = None
wd = 0.2
warmup = 2000 # 10000
use_bn_sync = False
min_ratio = 0.
skip_scheduler = False
save_frequency = 1
save_most_recent = True # False
Expand All @@ -47,9 +44,6 @@ class Config:
model = "RN50"
pretrained = ''
pretrained_image = False
lock_image = False
lock_image_unlocked_groups = 0
lock_image_freeze_bn_stats = False
grad_checkpointing = False
local_loss = False
gather_with_grad = False
Expand All @@ -58,11 +52,8 @@ class Config:
trace = False
dist_url = "env://"
dist_backend = "nccl"
report_to = ""
wandb_notes = ''
debug = False
copy_codebase = False
horovod = False
report_to = ""
ddp_static_graph = False
no_set_device_rank = False
seed = 0
Expand All @@ -78,18 +69,12 @@ def __post_init__(self):
setattr(args, name, val)


def parse_start_end(shards):
start, end = os.path.basename(shards).split("{")[1].split("}")[0].split("..")
return int(start), int(end)


def search_config(config_name):
import importlib
project_dir = os.path.dirname(__file__)
all_configs = {}
for code in os.listdir(project_dir):
if code.endswith(".py") and code.startswith("run_configs"):
module = importlib.import_module(code[:-3])
for code in os.listdir("config"):
if code.endswith(".py"):
module = importlib.import_module(f"config.{code[:-3]}")
for _config_name in dir(module):
if _config_name in ["Config"] or _config_name.startswith("__") or _config_name.startswith("run_config"):
continue
Expand Down
Loading