diff --git a/mode/README.md b/mode/README.md
new file mode 100644
index 0000000..3922074
--- /dev/null
+++ b/mode/README.md
@@ -0,0 +1,208 @@
+# MoDE: CLIP Data Experts via Clustering
+
+This repository contains the code for the Mixture of Data Experts (MoDE), described in the paper [MoDE: CLIP Data Experts via Clustering](https://arxiv.org/abs/2404.16030), which provides the first multi-modal understanding system based on independent CLIP models. The main contributions are:
+ - Introducing the concept of a **data expert** and building the MoDE framework, where several small models are learned separately but adaptively ensembled for each task.
+ - Studying how to build a **wider** system rather than a deeper network. The system is scalable and can integrate new data experts without compromising the established ability, so it can be applied to online data and continuously updated.
+ - Investigating the quality of negative samples in contrastive language-image pretraining, and in particular the false negatives in web-crawled image-caption pairs.
+ - Demonstrating that a set of small data experts can be comparable to a single large model. As the data experts can be trained asynchronously, MoDE significantly reduces the maximum computation requirement, shedding light on research under limited computation resources.
+
+We conclude that:
+ - Effective pretraining should **carefully examine the data distribution** instead of aggressively learning from the whole dataset.
+ - Data can be used to explain the model capability and to determine the ensemble of models (deep learning is data driven).
+ - Our algorithm is simple and easily scalable to consume data from the whole Internet.
+
+MoDE is trained with face-blurred images.
+
+```bibtex
+@inproceedings{ma2024mode,
+  title={MoDE: CLIP Data Experts via Clustering},
+  author={Ma, Jiawei and Huang, Po-Yao and Xie, Saining and Li, Shang-Wen and Zettlemoyer, Luke and Chang, Shih-Fu and Yih, Wen-Tau and Xu, Hu},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  year={2024}
+}
+
+@inproceedings{xu2023metaclip,
+  title={Demystifying CLIP Data},
+  author={Xu, Hu and Xie, Saining and Tan, Xiaoqing and Huang, Po-Yao and Howes, Russell and Sharma, Vasu and Li, Shang-Wen and Ghosh, Gargi and Zettlemoyer, Luke and Feichtenhofer, Christoph},
+  booktitle={The Twelfth International Conference on Learning Representations},
+  year={2023}
+}
+```
+
+## Quick Links
+
+ - [Getting Started](#getting-started)
+ - [Data Preparation](#data-preparation)
+ - [Clustering](#clustering)
+ - [Training](#training)
+ - [Inference-Time Task Adaptation (Ensemble)](#ensemble)
+ - [Bugs or Questions?](#bugs-or-questions)
+ - [Citation](#citation)
+ - [Reference](#reference)
+
+
+## Getting Started
+
+This code is developed with minimal changes on top of [MetaCLIP](https://github.com/facebookresearch/MetaCLIP). The following command installs the requirements for MetaCLIP and the `submitit=1.2.1` package used by this repo:
+
+```bash
+conda create -n metaclip python=3.10 pytorch torchvision pytorch-cuda=11.7 tqdm ftfy braceexpand webdataset regex pandas submitit=1.2.1 \
+    -c pytorch-nightly \
+    -c nvidia \
+    -c conda-forge \
+    -c anaconda
+```
+
+Then, please install the code for k-means clustering from the repository below:
+```bash
+https://github.com/subhadarship/kmeans_pytorch/tree/master
+```
+
+Finally, please move the config-related files from this folder to the repository root:
+```bash
+mv move2root/* ../
+rm -r move2root
+```
+
+## Data Preparation
+
+In this example code, we assume the dataset is called `demo` and all image-caption pairs are saved in a set of tarfiles, while the tarfiles are organized in sharded folders:
+```
+'demo':
+    '0':
+        '0.tar'
+        '100.tar'
+        ...
+    '1':
+        '1.tar'
+        '101.tar'
+        ...
+    ...
+    '99':
+        '99.tar'
+        '199.tar'
+        ...
+```
+Within each tarfile, the image-caption pairs are saved in sequence,
+```
+., json, jpeg, json, jpeg ...
+```
+where, for each pair, the caption is stored first in a `json` file and the image is then saved as a `jpeg`.
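+
+If you need to build the `demo` tarfiles yourself, the sketch below shows one possible way to produce this layout. It is only an illustration and not part of the pipeline: the `pairs` variable and output path are placeholders, the shard folder follows `tar_id % 100` (matching `get_tarfile_path` in `mode/prep_caption.py`), and the demo caption reader expects each `json` member to hold a `"texts"` list.
+
+```python
+import io
+import json
+import os
+import tarfile
+
+def write_shard(pairs, tar_id, out_dir="data/demo"):
+    """pairs: list of (caption: str, jpeg_bytes: bytes) tuples (placeholder input)."""
+    # Shard folders are indexed by tar_id % 100, as in mode/prep_caption.py.
+    shard_dir = os.path.join(out_dir, str(tar_id % 100))
+    os.makedirs(shard_dir, exist_ok=True)
+    with tarfile.open(os.path.join(shard_dir, f"{tar_id}.tar"), "w") as tar:
+        for i, (caption, jpeg_bytes) in enumerate(pairs):
+            uuid = f"{tar_id}_{i}"
+            # For each pair, the json member is written first, then the jpeg member.
+            meta = json.dumps({"texts": [caption]}).encode("utf-8")
+            for name, payload in ((f"{uuid}.json", meta), (f"{uuid}.jpeg", jpeg_bytes)):
+                info = tarfile.TarInfo(name=name)
+                info.size = len(payload)
+                tar.addfile(info, io.BytesIO(payload))
+```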
+
+For the following steps, we provide detailed command examples under the `Prep-Step` comments in `run_mode.sh` for explanation and usage.
+The configuration and the paths for intermediate data storage are summarized in `mode/get_prep_parser.py`. When you run the code, please make sure you are in the root directory of the whole project. To customize for your own data, you can also modify the `get_default_paths` function in that file.
+
+## Clustering
+
+Data clustering is performed on the language embeddings of captions. This section mainly explains feature extraction and data clustering.
+For large-scale data processing, we provide the code below to separate the steps and enable multi-thread processing.
+
+### Step 0: Preparing Captions
+
+This step reads the tarfiles where the image-caption pairs are stored together.
+As caption extraction is CPU-only, we provide the script below to enable multi-thread caption collection (highly recommended for large-scale data processing).
+
+```bash
+python mode/prep_caption.py
+```
+
+### Step 1: Preparing Features
+
+This step extracts the language embeddings of the captions; the features for the captions in one tarfile are stored in a single `pth` file. Following the organization of the tarfiles, we also organize the features in sharded folders.
+
+When the captions are pre-collected (via Step 0), run the command below to extract the caption features, where each process is allocated one GPU.
+
+```bash
+torchrun --nproc_per_node=8 mode/prep_feature.py --file-mode caption
+```
+
+Alternatively, you can skip Step 0 and extract features directly from the tarfiles.
+
+```bash
+torchrun --nproc_per_node=8 mode/prep_feature.py --file-mode tarfile
+```
+
+### Step 2: Two-Step Clustering
+
+Once the features are ready, perform two-step clustering to obtain the fine-grained clusters and the coarse-grained conditions. Note that we only use a fraction of the whole dataset and run the clustering on a single GPU. Once finished, both the fine-grained and coarse-grained cluster centers are produced.
+
+```bash
+torchrun --nproc_per_node=1 mode/prep_hrchy.py
+```
+
+### Step 3: Cluster Assignment
+
+Once the cluster centers are obtained, use nearest-neighbor search to determine the cluster assignment for each pair. This process is CPU-only, and the script below supports multi-thread processing.
+
+```bash
+python mode/prep_inference.py
+```
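+
+Conceptually, the assignment in `mode/prep_inference.py` is a nearest-neighbor lookup between the normalized caption embeddings and the fine-grained cluster centers. The sketch below only illustrates that core computation for the default Euclidean setting; the file paths are examples that follow the default naming in `mode/get_prep_parser.py` and may differ in your setup.
+
+```python
+import torch
+import torch.nn.functional as F
+
+# Normalized caption embeddings of one tarfile and the fine-grained centers.
+feats = F.normalize(torch.load("feature/0/0_feat.pth")["feat"], dim=-1)   # (N, D)
+centers = torch.load("cluster_center/euclidean/F1024.pth")["center"]      # (1024, D)
+
+dist = torch.cdist(feats[None], centers[None])[0]  # (N, 1024) pairwise distances
+min_dist, assign = dist.min(dim=-1)                # nearest fine-grained center per caption
+```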
+
+## Training
+
+Once the cluster assignment is ready, we perform standard CLIP training and only alter the data sampling. Please check the config file `run_configs_mode.py` and manually change the expert ID via `coarse_idx` to select the data expert model to be trained.
+
+```bash
+torchrun --nproc_per_node=8 src/training/main.py b32_mode
+```
+
+## Ensemble
+
+Given the trained data expert models, for comprehensive evaluation we gather the outputs from each expert model as well as the ensembled output, and summarize them as a report in the original experiment log folder.
+
+First, we evaluate each model and gather its outputs for ensembling.
+
+```bash
+torchrun --master_port=29600 --nproc_per_node=4 mode/post_expert_eval.py b32_mode
+```
+
+Then, as a preparation for ensembling, we extract the language embeddings of the task metadata, e.g., class names. We reuse the feature extraction script but pass different arguments.
+
+```bash
+torchrun --master_port=29500 --nproc_per_node=1 mode/prep_feature.py --dataset clipeval --feature-dir ${DIR_CLIPEVAL}
+```
+
+Lastly, we use the similarity between the metadata embeddings and the cluster centers to determine the ensembling weights for evaluation. By running the command below, all results are summarized in a CSV file.
+
+```bash
+python mode/post_report_ensemble.py b32_mode ${DIR_CLIPEVAL}
+```
+
+## Bugs or Questions?
+
+If you have any questions related to the code or the paper, feel free to email Jiawei Ma (`jiawei.m@columbia.edu`) or Hu Xu (`huxu@meta.com`).
+
+
+## Citation
+
+Please cite our papers (accepted by CVPR 2024 & ICLR 2024) if MoDE helps your work:
+
+```bibtex
+@inproceedings{ma2024mode,
+  title={MoDE: CLIP Data Experts via Clustering},
+  author={Ma, Jiawei and Huang, Po-Yao and Xie, Saining and Li, Shang-Wen and Zettlemoyer, Luke and Chang, Shih-Fu and Yih, Wen-Tau and Xu, Hu},
+  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
+  year={2024}
+}
+@inproceedings{xu2023metaclip,
+  title={Demystifying CLIP Data},
+  author={Xu, Hu and Xie, Saining and Tan, Xiaoqing and Huang, Po-Yao and Howes, Russell and Sharma, Vasu and Li, Shang-Wen and Ghosh, Gargi and Zettlemoyer, Luke and Feichtenhofer, Christoph},
+  booktitle={The Twelfth International Conference on Learning Representations},
+  year={2023}
+}
+```
+
+## Reference
+
+The code is based on [MetaCLIP](https://github.com/facebookresearch/MetaCLIP), and only the data loading & sampling is modified.
+
+## TODO
+- We welcome your use cases and suggestions, and will update this codebase regularly.
+
+
+## License
+
+MoDE is licensed under CC-BY-NC.
+
+## Acknowledgement
+We gratefully acknowledge the [OpenCLIP](https://github.com/mlfoundations/open_clip) team for the initial CLIP codebase and [MetaCLIP](https://github.com/facebookresearch/MetaCLIP) for the careful examination of the data distribution.
diff --git a/mode/__init__.py b/mode/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/mode/get_prep_parser.py b/mode/get_prep_parser.py
new file mode 100644
index 0000000..40c42fd
--- /dev/null
+++ b/mode/get_prep_parser.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. 
and affiliates + +import argparse + +PATH_TO_DEMO = '/demo' +def get_default_paths(): + demo = { + 'root':'data/demo/{0..200000}.tar', + 'caption':f'{PATH_TO_DEMO}/caption/', + 'feature':f'{PATH_TO_DEMO}/feature/', + 'assign':f'{PATH_TO_DEMO}/assign/', + 'cluster':f'{PATH_TO_DEMO}/cluster_center/', + } + return {'demo': demo} + +def get_args_parser(): + parser = argparse.ArgumentParser(description='MoDE Data Preparation', add_help=False) + parser.add_argument('--dataset', default='demo', type=str, choices=['clipeval', 'demo']) + parser.add_argument('--root', default="data/demo/{0..200000}.tar", type=str, + help='path to dataset root') + parser.add_argument('--caption-dir', default='caption/', type=str, help='caption dir, highly recommended') + parser.add_argument('--feature-dir', default='feature/', type=str, help='feature output dir') + + # Below arguments are only for pre-processing pre-train data on feature extraction + parser.add_argument('--file-mode', default='tarfile', type=str, choices=['caption', 'tarfile'], + help='processing extracted captions or tarfiles direction') + parser.add_argument('--tar-init', default=0, type=int, help='tarfile_id to start') + parser.add_argument('--tar-end', default=-1, type=int, help='tarfile_id to end') + parser.add_argument('--tar-per-gpu', default=-1, type=int, help='number of tarfiles to process per GPU') + parser.add_argument('--chunk-size', default=400, type=int, help='number of captions to be processed') + parser.add_argument('--horovod', default=False, type=bool, help='placeholder, needed to pass ddp initialization') + parser.add_argument('--dist-url', default="env://", type=str, help='placeholder, needed to pass ddp initialization') + parser.add_argument('--dist-backend', default="nccl", type=str, help='placeholder, needed to pass ddp initialization') + parser.add_argument('--no-set-device-rank', default=False, type=bool, help='placeholder, needed to pass ddp initialization') + + # Arguments on clustering and assignment + parser.add_argument('--cm', default=1024, type=int, help='number of fine-grained cluster centers') + parser.add_argument('--cn', default=4, type=int, help='number of coarse-grained cluster centers') + parser.add_argument('--cd', default='euclidean', type=str, help='cluster distance, euc or cos') + parser.add_argument('--cassign-dir', default='assign/', type=str, help='dir for cluster assignment') + parser.add_argument('--ccenter-dir', default='cluster_center/', type=str, help='dir for cluster centers') + + # Arguments on intermediate variables at inference time + parser.add_argument('--logits-dir', default='./logs/clip_eval', type=str, help='cluster center') + return parser \ No newline at end of file diff --git a/mode/move2root/configs_mode.py b/mode/move2root/configs_mode.py new file mode 100644 index 0000000..c8377be --- /dev/null +++ b/mode/move2root/configs_mode.py @@ -0,0 +1,127 @@ +import os +import inspect + +from collections import OrderedDict +from dataclasses import dataclass + +import sys +sys.path.append("src") + +from training.params import get_default_params +from mode.get_prep_parser import get_default_paths + +@dataclass +class Config: + train_data = None + val_data = None + train_num_samples = None + val_num_samples = None + dataset_type = "auto" + dataset_resampled = False + csv_separator = "\t" + csv_img_key = "filepath" + csv_caption_key = "title" + imagenet_val = "/datasets01/imagenet_full_size/061417/val" + imagenet_v2 = None + logs = "./logs/" + log_local = False + name = None + workers = 8 + 
batch_size = 64 + epochs = 32 + lr = None + beta1 = None + beta2 = None + eps = None + wd = 0.2 + warmup = 2000 # 10000 + use_bn_sync = False + skip_scheduler = False + save_frequency = 1 + save_most_recent = True # False + zeroshot_frequency = 1 + val_frequency = 1 + resume = None + precision = "amp" + clip_model = "CLIP" + model = "RN50" + pretrained = '' + pretrained_image = False + lock_image = False + lock_image_unlocked_groups = 0 + lock_image_freeze_bn_stats = False + grad_checkpointing = False + local_loss = False + gather_with_grad = False + force_quick_gelu = False + torchscript = False + trace = False + dist_url = "env://" + dist_backend = "nccl" + report_to = "" + wandb_notes = '' + debug = False + copy_codebase = False + horovod = False + ddp_static_graph = False + no_set_device_rank = False + seed = 0 + norm_gradient_clip = None + + fine_index = '' + hrchy_assign = '' + ooc_ratio = 0.02 # slightly better than 0.0 + dist_type = 'euclidean' + + def __post_init__(self): + args = self + args.name = self.__class__.__name__ + + for name, val in get_default_params(args.model).items(): + if getattr(args, name) is None: + setattr(args, name, val) + + if 'mode' in args.name: + assert args.coarse_idx >=0 and args.coarse_idx < args.mode_size + sub_str=f'expert_{args.coarse_idx}' + args.name = '{}_n{}m{}/{}'.format(args.name, args.mode_size, args.mode_fine, sub_str) + + + if args.train_data == '': + datakey = 'demo' + args.train_data = get_default_paths()[datakey]['root'] + else: + # args.train_data and 'root' of get_default_paths in get_prep_parser.py should be the same + # the data dir is named by the dataset + datakey = args.train_data.split('/')[-2] + paths = get_default_paths()[datakey] + args.fine_index = paths['assign'] + args.hrchy_assign = paths['cluster'] + + if args.resume is None: + # As the checkpoint for data expert initialization is trained via MetaCLIP repo, + # the same format is applied to determine the checkpoint path. + args.resume = os.path.join(args.seed_exp, 'checkpoints', f'epoch_{args.quick_init}.pt') + + args.output_dir = os.path.join(args.logs, args.name) + +def parse_start_end(shards): + start, end = os.path.basename(shards).split("{")[1].split("}")[0].split("..") + return int(start), int(end) + + +def search_config(config_name): + import importlib + project_dir = os.path.dirname(__file__) + all_configs = {} + for code in os.listdir(project_dir): + if code.endswith(".py") and code.startswith("run_configs"): + module = importlib.import_module(code[:-3]) + for _config_name in dir(module): + if _config_name in ["Config"] or _config_name.startswith("__") or _config_name.startswith("run_config"): + continue + if _config_name not in all_configs: + all_configs[_config_name] = module + print(f"launching {config_name} from {all_configs[config_name].__file__}") + config = getattr(all_configs[config_name], config_name)() + return config diff --git a/mode/move2root/run_configs_mode.py b/mode/move2root/run_configs_mode.py new file mode 100644 index 0000000..b37dee8 --- /dev/null +++ b/mode/move2root/run_configs_mode.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates + +# usage: +# python src/training/main.py b32_mode +# torchrun --nproc_per_node=8 src/training/main.py b32_mode +# python submitit_mode.py b32_mode + +from dataclasses import dataclass +from configs_mode import Config + +# ================================ +# Configs for MoDE +# set the following checkpoint paths before running +# more explanation can be found in configs_mode.py +# ================================ + +b32_demo_full='' +b16_demo_full='' +l14_demo_full='' + +@dataclass +class b32_mode(Config): + one_iter=True + inmem=True + engine="train_one_epoch_ex" + eval_steps=5000 + save_frequency=1 + train_data="data/demo/{0..200000}.tar" + workers=8 + train_num_samples=400000000 + batch_size=512 + epochs=32 + model="ViT-B-32-quickgelu" + name="ViT-B-32" + force_quick_gelu=True + warmup=2000 + seed=0 + local_loss=True + gather_with_grad=True + nodes=8 + ngpus=8 + # Configs for MoDE + dataset_type='cluster' + mode_size=4 + coarse_idx=1 # (change from 0 to mode_size-1 index the data experts) + mode_fine=1024 + quick_init=27 + seed_exp=f'logs/{b32_demo_full}' + +@dataclass +class b16_mode(b32_mode): + model="ViT-B-16-quickgelu" + name="ViT-B-16" + grad_checkpointing=True + seed_exp=f'logs/{b16_demo_full}' + +@dataclass +class l14_mode(b32_mode): + model="ViT-L-14-quickgelu" + name="ViT-L-14" + lr=0.0004 + batch_size=256 + grad_checkpointing=True + nodes=16 + ngpus=8 + seed_exp=f'logs/{l14_demo_full}' + + +if __name__ == "__main__": + import inspect + import sys + for name, obj in inspect.getmembers(sys.modules[__name__]): + if inspect.isfunction(obj): + print(name) diff --git a/mode/move2root/run_mode.sh b/mode/move2root/run_mode.sh new file mode 100644 index 0000000..cf9e41f --- /dev/null +++ b/mode/move2root/run_mode.sh @@ -0,0 +1,44 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates + +# To change the data path, please check get_default_paths func in mode/get_prep_parser.py +# ============================================== +# Run the commands below to prepare data clustering +# ============================================== + +DATASET=demo +DIR_CLIPEVAL=demo_verify/feature_clipeval/ +TEND=-1 # choose a small positive number for test purpose if needed + +## Prep-Step 0 Prepare captions +# (only for pre-train data, not necessary but highly recommended) +# python mode/prep_caption.py --dataset ${DATASET} --tar-end ${TEND} + +## Prep-Step 1 Prepare features +# 1.1 for pre-train dataset (if you skiped step 0, remove --file-mode caption) +# torchrun --master_port=29600 --nproc_per_node=8 mode/prep_feature.py \ +# --dataset ${DATASET} --tar-end ${TEND} --file-mode caption +# 1.2 for downstream dataset +# torchrun --master_port=29500 --nproc_per_node=1 mode/prep_feature.py \ +# --dataset clipeval --feature-dir ${DIR_CLIPEVAL} + +## Prep-Step 2 Two-Level Clustering +# torchrun --master_port=29500 --nproc_per_node=1 mode/prep_hrchy.py --dataset ${DATASET} + +## Prep-Step 3 Fine-grained cluster Assignment +# python mode/prep_inference.py --dataset ${DATASET} --tar-end ${TEND} + +# ============================================== +# Run the commands below to prepare data clustering +# ============================================== +## Training-Step 1 +# For details, please refer to run_configs_mode.py +# torchrun --nproc_per_node=8 src/training/main.py b32_mode + +# ============================================== +# Run the commands below when all ckpts are ready +# ============================================== +## Post-Step 1 +# torchrun --master_port=29600 --nproc_per_node=4 mode/post_expert_eval.py b32_mode + +## Post-Step 2 +# python mode/post_report_ensemble.py b32_mode ${DIR_CLIPEVAL} diff --git a/mode/move2root/submitit_mode.py b/mode/move2root/submitit_mode.py new file mode 100644 index 0000000..85ecbb6 --- /dev/null +++ b/mode/move2root/submitit_mode.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# -------------------------------------------------------- +# A script to run multinode training with submitit. +# -------------------------------------------------------- + +import argparse +import os +import uuid +from pathlib import Path + +import sys +sys.path.append("src") + +import training.main as main +import submitit + + +def parse_args(): + parser = argparse.ArgumentParser("Submitit for openclip") + parser.add_argument("config_name", type=str, help="name of the config.") + parser.add_argument("--ngpus", default=None, type=int, help="Number of gpus to request on each node") + parser.add_argument("--nodes", default=None, type=int, help="Number of nodes to request") + parser.add_argument("--resume", default="", type=str, help="resume a checkpoint.") + parser.add_argument("--timeout", default=4320, type=int, help="Duration of the job") + parser.add_argument("--job_dir", default="", type=str, help="Job dir. 
Leave empty for automatic.") + + parser.add_argument("--partition", default="learnlab", type=str, help="Partition where to submit") + parser.add_argument("--use_volta32", action='store_true', help="Request 32G V100 GPUs") + parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler") + args = parser.parse_args() + return args + + +def get_shared_folder() -> Path: + user = os.getenv("USER") + if Path("/checkpoint/").is_dir(): + p = Path(f"/checkpoint/{user}/openclip") + p.mkdir(exist_ok=True) + return p + raise RuntimeError("No shared folder available") + + +def get_init_file(): + # Init file must not exist, but it's parent dir must exist. + os.makedirs(str(get_shared_folder()), exist_ok=True) + init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" + if init_file.exists(): + os.remove(str(init_file)) + return init_file + + +class Trainer(object): + def __init__(self, args): + self.args = args + self.args.config.dist_url = get_init_file().as_uri() + + def __call__(self): + import sys + sys.path.append("src") + import training.main as main + self._setup_gpu_args() + main.main(self.args.config) + + def checkpoint(self): + import os + import submitit + + self.args.config.dist_url = get_init_file().as_uri() + checkpoint_file = os.path.join(self.args.config.output_dir, "checkpoints", "epoch_latest.pt") + if os.path.exists(checkpoint_file): + self.args.config.resume = checkpoint_file + print("Requeuing ", self.args) + empty_trainer = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty_trainer) + + def _setup_gpu_args(self): + import submitit + from pathlib import Path + + job_env = submitit.JobEnvironment() + if self.args.ngpus >= 1: + self.args.config.local_rank = job_env.local_rank + self.args.config.rank = job_env.global_rank + self.args.config.world_size = job_env.num_tasks + print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + + +def main(args): + if args.job_dir == "": + args.job_dir = get_shared_folder() + + assert args.job_dir != "" + if os.path.exists(args.job_dir) and len(args.resume) == 0 and not hasattr(args.config, "eval"): + raise ValueError(f"{args.job_dir} existed, rm -rf {args.job_dir} ?") + + args.job_dir = Path(args.job_dir) / "%j" + + # Note that the folder will depend on the job_id, to easily track experiments + executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30) + + num_gpus_per_node = args.ngpus + nodes = args.nodes + timeout_min = args.timeout + + partition = args.partition + kwargs = {} + if args.use_volta32: + kwargs['slurm_constraint'] = 'volta32gb' + if args.comment: + kwargs['slurm_comment'] = args.comment + + executor.update_parameters( + gpus_per_node=num_gpus_per_node, + tasks_per_node=num_gpus_per_node, # one task per GPU + cpus_per_task=10, + nodes=nodes, + timeout_min=timeout_min, + # Below are cluster dependent parameters + slurm_partition=partition, + slurm_signal_delay_s=120, + **kwargs + ) + + executor.update_parameters(name=args.config.name) + + args.dist_url = get_init_file().as_uri() + args.output_dir = args.job_dir + + trainer = Trainer(args) + job = executor.submit(trainer) + + print("Submitted job_id:", job.job_id, "@", str(args.job_dir).replace("%j", job.job_id)) + + +def submit(): + args = parse_args() + from configs_mode import search_config + from copy import deepcopy + + config = search_config(args.config_name) + _args = deepcopy(args) + if len(args.resume): + checkpoint_file = os.path.join(config.output_dir, "checkpoints", 
args.resume) + args.resume = checkpoint_file + config.resume = checkpoint_file + + setattr(_args, "config", config) + if args.ngpus is not None: + _args.ngpus = args.ngpus + elif hasattr(config, "ngpus"): + _args.ngpus = config.ngpus + else: + raise ValueError("must specify ngpus in arg or config.") + if args.nodes is not None: + _args.nodes = args.nodes + elif hasattr(config, "nodes"): + _args.nodes = config.nodes + else: + raise ValueError("must specify ngpus in arg or config.") + _args.job_dir = config.output_dir + main(_args) + + +if __name__ == "__main__": + submit() diff --git a/mode/post_expert_eval.py b/mode/post_expert_eval.py new file mode 100644 index 0000000..346d68c --- /dev/null +++ b/mode/post_expert_eval.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import sys +sys.path.append("./") +sys.path.append("src") + +import os +import logging +import torch +import json + +from open_clip import tokenize +from open_clip import create_model_and_transforms, get_mean_std + +from training.distributed import init_distributed_device +from training.logger import setup_logging +from clipeval.eval_zeroshot import validate_zeroshot, mean_per_class, accuracy, roc_auc + + +def evaluate_logits(d, val_loader, templates, labels, model, tokenizer, classnorm=False): + print('Evaluating {}'.format(d)) + + outputs = validate_zeroshot(val_loader, templates, labels, model, tokenizer, False, classnorm) + + if d in ['FGVCAircraft', 'OxfordPets', 'Caltech101', 'Flowers102']: + metric = mean_per_class(*outputs) + elif d == 'Kinetics700': + top1, top5 = accuracy(*outputs, topk=(1, 5)) + metric = (top1 + top5) / 2 + metric = metric.item() + elif d == 'HatefulMemes': + metric = roc_auc(*outputs) + else: + pred = outputs[0].argmax(dim=1) + correct = pred.eq(outputs[1]).sum() + metric = correct.item() / float(pred.size(0)) * 100.0 + + return metric, outputs + + +@torch.no_grad() +def slip_evaluate_expert(args, model, val_transform, tokenizer, idx): + from clipeval import datasets, eval_zeroshot + + catalog, all_templates, all_labels = eval_zeroshot.load_metadata("clipeval") + + if hasattr(model, "module"): + model = model.module + + metrics = {} + for d in catalog: + result_fn = os.path.join(args.output_dir, 'eval_outputs', f'{d}_pred-{idx}.pth') + if os.path.exists(result_fn): + 'logits' in torch.load(result_fn) + continue + + val_dataset = datasets.get_downstream_dataset( + catalog, d, is_train=False, transform=val_transform) + templates = all_templates[d] + labels = all_labels[d] + + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=args.batch_size//2, shuffle=False, + num_workers=args.workers, pin_memory=False, drop_last=False) + + metric, logits = evaluate_logits(d, val_loader, templates, labels, model, tokenizer) + metrics[d] = metric + json_str = json.dumps({"model": idx, "task": d, "acc": metric}) + torch.save({'logits':logits[0], 'targets':logits[1]}, result_fn) + logging.info(json_str) + return metrics + + +def main(args): + + device = init_distributed_device(args) + mean, std = get_mean_std(args) + args.log_path = os.path.join(args.output_dir, f'expert.log') + args.log_level = logging.INFO + setup_logging(args.log_path, args.log_level) + os.makedirs(os.path.join(args.output_dir,'eval_outputs'), exist_ok=True) + + model, _, preprocess_val = create_model_and_transforms( + args.model, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + pretrained_image=args.pretrained_image, + 
mean=mean, std=std, + inmem=hasattr(args, "inmem"), + clip_model=args.clip_model + ) + + ckpt_list = [os.path.join(args.output_dir,f'expert_{i}','checkpoints','epoch_latest.pt') for i in range(args.mode_size)] + logging.info('There are {} ckpts to be ensembled for exp {}.'.format(len(ckpt_list),os.path.dirname(args.name))) + idxes = [i for i in range(len(ckpt_list))] + if args.world_size > 1: + assert len(ckpt_list) % args.world_size == 0 + seg = len(ckpt_list) // args.world_size + idxes = idxes[args.rank*seg:(args.rank+1)*seg] + ckpt_list = ckpt_list[args.rank*seg:(args.rank+1)*seg] + + for idx, ckpt_path in zip(idxes,ckpt_list): + logging.info(f'Loading Model {ckpt_path}') + ckpt = torch.load(ckpt_path, map_location=torch.device('cuda')) + + if next(iter(ckpt['state_dict'].items()))[0].startswith('_orig_mod'): + ckpt['state_dict'] = {k[len('_orig_mod.'):]: v for k, v in ckpt['state_dict'].items()} + if next(iter(ckpt['state_dict'].items()))[0].startswith('module'): + ckpt['state_dict'] = {k[len('module.'):]: v for k, v in ckpt['state_dict'].items()} + + model.load_state_dict(ckpt['state_dict']) + model.eval() + slip_evaluate_expert(args, model, preprocess_val, tokenize, idx) + + +if __name__ == "__main__": + from configs_mode import search_config + config = search_config(sys.argv[1]) + config.output_dir = os.path.dirname(config.output_dir) + main(config) diff --git a/mode/post_report_ensemble.py b/mode/post_report_ensemble.py new file mode 100644 index 0000000..2618bfc --- /dev/null +++ b/mode/post_report_ensemble.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import sys +sys.path.append("src") +sys.path.append("./") +import os + +import numpy as np +import pandas as pd +import json +import torch +from clipeval.eval_zeroshot import mean_per_class, accuracy, roc_auc + +from scipy.special import softmax + +def evaluate_dataset(d,acc_or_outputs): + if d in ['FGVCAircraft', 'OxfordPets', 'Caltech101', 'Flowers102']: + metric = mean_per_class(*acc_or_outputs) + elif d == 'Kinetics700': + top1, top5 = accuracy(*acc_or_outputs, topk=(1, 5)) + metric = (top1 + top5) / 2 + metric = metric.item() + elif d == 'HatefulMemes': + metric = roc_auc(*acc_or_outputs) + else: + pred = acc_or_outputs[0].argmax(dim=1) + correct = pred.eq(acc_or_outputs[1]).sum() + metric = correct.item() / float(pred.size(0)) * 100.0 + return metric + +def process(dataset, num_cluster, result_path, ensemble_weight): + outputs = [] + units_acc = [] + result = {} + for i in range(num_cluster): + data = torch.load(os.path.join(result_path,'{}_pred-{}.pth'.format(dataset,i)),map_location='cpu') + outputs.append(data) + units_acc.append(evaluate_dataset(dataset,(data['logits'],data['targets']))) + + result = {f'expert-{i}':acc for i,acc in enumerate(units_acc)} + result['expert-max'] = max(units_acc) + all_logits = torch.stack([item['logits'] for item in outputs]) + result['all_unit'] = evaluate_dataset(dataset,(all_logits.mean(dim=0), data['targets'])) + result['ensemble'] = evaluate_dataset(dataset,((all_logits * torch.from_numpy(ensemble_weight).view(-1,1,1)).sum(dim=0), data['targets'])) + + return result + +def main(opts): + + csv_path = os.path.join(opts.output_dir,'result.csv') + tasks = json.load(open( + '{}/clipeval/dataset_catalog.json'.format('..' 
if os.path.abspath('.').endswith('mode') else '.'),'r' + )) + fine_fn = os.path.join(opts.hrchy_assign,opts.dist_type,'F{}.pth'.format(opts.mode_fine)) + fine_cluster = torch.load(fine_fn)['center'].cpu() + hrchy_fn = os.path.join(opts.hrchy_assign,opts.dist_type,'F{}-C{}.pth'.format(opts.mode_fine,opts.mode_size)) + hrchy_assign = torch.load(hrchy_fn)['assign'].numpy() + + results = {'dataset':[],'num_classes':[]} + for dataset in tasks: + + # prepare raw distance between task embedding and cluster centers + close_fn = os.path.join(opts.metadata_dir,f'{dataset}.json') + if os.path.exists(close_fn): + with open(close_fn,'r') as json_file: + fine_closeness = json.load(json_file) + [dist,assign] = fine_closeness + else: + task_embedding = torch.load(os.path.join(opts.metadata_dir,f'{dataset}.pth')) + # top 1 filtering + dist,assign = torch.cdist(task_embedding,fine_cluster).min(dim=-1) + dist,assign = dist.tolist(),assign.tolist() + with open(close_fn,'w') as json_file: + json.dump([dist,assign],json_file) + + # grouping along coarse cluster + n_class = len(assign) + hrchy_avg = hrchy_assign[np.array(assign)] + soft_avg = {key: np.array(dist)[hrchy_avg==key] for key in np.unique(hrchy_assign).tolist()} + + # Ensembling + weight_soft_unit = np.zeros(len(soft_avg)) + for key,item in soft_avg.items(): + if len(item) > 0: + sharpen_add = np.exp(0.5-np.sqrt(n_class)) if n_class < 10 else 0.0 + sharpen_mul = opts.smooth_weight[0] * np.log10(max(10,n_class-opts.too_many_class)) + weight_soft_unit[key] = np.exp((sharpen_add-item) * sharpen_mul).sum() + weight_soft_unit = softmax(weight_soft_unit/opts.smooth_weight[1]) + unit_result = process(dataset, opts.mode_size, opts.result_dir, weight_soft_unit) + + # summary + results['dataset'].append(dataset) + results['num_classes'].append(n_class) + for key in unit_result: + if key not in results: + results[key] = [] + results[key].append(unit_result[key]) + + # stat average and report + for key in results: + if key == 'dataset': + results[key].append('average') + else: + results[key].append(np.mean(results[key])) + result_df = pd.DataFrame.from_dict(results) + result_df.to_csv(csv_path) + print('Ensembling Done, please check', csv_path) + +if __name__ == "__main__": + from configs_mode import search_config + config = search_config(sys.argv[1]) + + config.output_dir = os.path.dirname(config.output_dir) + config.result_dir = os.path.join(config.output_dir,'eval_outputs') + config.metadata_dir = sys.argv[2] + config.smooth_weight = [5.0,8.0] + config.too_many_class = 200 + + main(config) + diff --git a/mode/prep_caption.py b/mode/prep_caption.py new file mode 100644 index 0000000..34685a3 --- /dev/null +++ b/mode/prep_caption.py @@ -0,0 +1,165 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates + +import sys +sys.path.append("src") +sys.path.append("./") +import os +import random + +from tqdm import tqdm +from multiprocessing import Pool +from get_prep_parser import get_args_parser, get_default_paths + +import json +import argparse +import tarfile + +import pdb + +def get_tarfile_path(wds_dir,tar_id): + # File Organization + return os.path.join(wds_dir, str(tar_id % 100), f'{tar_id}.tar') + + +def gather_caption_from_tarfile(root_dir, tar_id): + tarball_path = get_tarfile_path(root_dir, tar_id) + + if not os.path.exists(tarball_path): + return None, 'file not exists' + + captions,tmems,imems = [],[],[] + with tarfile.open(tarball_path) as tar: + members = tar.getmembers() + + json_cnt = 0 + json_mid,img_mid = -1,-1 + iuid, uuid = None, None + for midx, member in enumerate(members): + if member.name.endswith(".json"): + uuid = member.name[:-len(".json")] + with tar.extractfile(member) as f: + text_json = json.load(f) + + if 'demo' in root_dir: + txt = random.choice(text_json["texts"]) + else: + raise ValueError('Please Implement by yourself and uncomment this line in prep_caption.py') + json_cnt += 1 + json_mid = midx + + if member.name.endswith(".jpeg") or member.name.endswith(".jpg"): + suffix = len(member.name.split('.')[-1]) + 1 + iuid = member.name[:-suffix] + img_mid = midx + + if iuid is not None and iuid == uuid: + if txt is None or len(txt)==0 or txt in ['"',]: + continue + if json_mid in tmems or img_mid in imems: + continue + captions.append(txt) + tmems.append(json_mid) + imems.append(img_mid) + if len(set(tmems)) == len(set(imems)) and len(set(imems)) == len(captions): + return {'tmems':tmems, 'imems':imems, 'caption':captions}, 'success' + else: + return None, 'fail' + + +def build_caption(wds_dir, shard_id, caption_dir, overwrite=True): + shard_folder = shard_id % 100 + os.makedirs(os.path.join(caption_dir, f"{shard_folder}"), exist_ok=True) + + output_fn_group = os.path.join(caption_dir, f"{shard_folder}", f"{shard_id}_caption.json") + if os.path.exists(output_fn_group) and not overwrite: + return True + + data, status = gather_caption_from_tarfile(wds_dir, shard_id) + + if status == 'success': + with open(output_fn_group,'w') as json_file: + json.dump(data, json_file) + print('Newly write {} with {} items'.format(shard_id, len(data['caption']))) + return True + else: + return None + + +def func(args, _start, _end): + missing_shards = [] + + if isinstance(_start, list): + warc_iter = _start + else: + warc_iter = ( + tqdm(range(_start, _end)) if _start == 0 else range(_start, _end) + ) + + for idx, shard_id in enumerate(warc_iter): + + wds_fn = get_tarfile_path(wds_dir, shard_id) + if not os.path.exists(wds_fn): + missing_shards.append(shard_id) + continue + + status = build_caption( + wds_dir, shard_id, args.caption_dir, overwrite=False, + ) + if status: + pass + elif status is None: + missing_shards.append(shard_id) + else: + raise ValueError('No Implementation Error') + + return missing_shards + + +def main(args): + global wds_dir + + shard_ids = [[] for _ in range(args.num_threads)] + for shard_id in range(args.tar_init, args.tar_end): + group_offset = shard_id % args.num_threads + shard_ids[group_offset].append(shard_id) + + print(f"shard_ids[0]={shard_ids[0]}") + starts = shard_ids + ends = [None for _ in range(len(starts))] + + argss = [args for _ in range(len(starts))] + assert len(argss) == len(starts) == len(ends) + assert len(starts) <= args.num_threads + wds_dir = os.path.dirname(args.root) + + with Pool(len(starts)) as p: + results = 
p.starmap( + func, + zip( + argss, + starts, + ends + ), + ) + + all_results = [] + for result in results: + all_results.extend(result) + print("missing file", len(all_results), all_results) + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser('Clustering evaluation', parents=[get_args_parser()]) + config = parser.parse_args() + + paths = get_default_paths()[config.dataset] + config.root = paths['root'] + config.caption_dir = paths['caption'] + + config.num_threads = 40 + if config.tar_end == -1: + config.tar_end = int(os.path.basename(config.root).split("{")[1].split("}")[0].split("..")[1]) + + os.makedirs(config.caption_dir, exist_ok=True) + main(config) diff --git a/mode/prep_feature.py b/mode/prep_feature.py new file mode 100644 index 0000000..6f4bd88 --- /dev/null +++ b/mode/prep_feature.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import sys +sys.path.append("src") +sys.path.append("./") +import os +import random + +import torch +import torch.nn.functional as F +import json +from tqdm import tqdm +import argparse +from transformers import AutoModel, AutoTokenizer + +from src.training.distributed import init_distributed_device + +from get_prep_parser import get_args_parser, get_default_paths +from prep_caption import gather_caption_from_tarfile, get_tarfile_path +from clipeval import eval_zeroshot + + +@torch.no_grad() +def build_text_indfeatures(templates, labels, model, tokenizer): + text_features = [] + for i,label in enumerate(labels): + if isinstance(label, list): + texts = [t.format(l) for t in templates for l in label] + else: + texts = [t.format(label) for t in templates] + + texts = tokenizer(texts, padding=True, truncation=True, return_tensors="pt") + texts = {key:item.cuda() for key,item in texts.items()} + class_embeddings = model(**texts, output_hidden_states=True, return_dict=True).pooler_output.cpu() + + class_embeddings = class_embeddings / class_embeddings.norm(dim=-1, keepdim=True) + text_features.append(F.normalize(class_embeddings.mean(dim=0),dim=-1)) + text_features = torch.stack(text_features,dim=0) + return text_features + + +def main(args): + device = init_distributed_device(args) + os.makedirs(args.feature_dir,exist_ok=True) + + if 'demo' in config.dataset: + _, tar_end = os.path.basename(config.root).split("{")[1].split("}")[0].split("..") + if config.tar_end == -1: + config.tar_end = int(tar_end) + else: + config.tar_end = min(config.tar_end, int(tar_end)) + if config.tar_per_gpu == -1: + config.tar_per_gpu = int((config.tar_end - config.tar_init) / config.world_size) + + tokenizer = AutoTokenizer.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased") + model = AutoModel.from_pretrained("princeton-nlp/sup-simcse-bert-base-uncased").cuda() + + stat = {'non_exist':[], 'failed':[], 'success':[]} + if 'demo' in args.dataset: + + current_init = args.tar_init+args.rank*args.tar_per_gpu + current_end = current_init + args.tar_per_gpu + shard_id_list = [i for i in range(current_init,min(current_end,args.tar_end))] + random.shuffle(shard_id_list) + + for shard_id in shard_id_list: + save_path = os.path.join(args.feature_dir, str(shard_id % 100),'{}_feat.pth'.format(shard_id)) + if os.path.exists(save_path): + stat['success'].append(shard_id) + print(shard_id, f'already written in {save_path}') + continue + + txtfeats = {'feat':[]} + with torch.no_grad(): + os.makedirs(os.path.dirname(save_path),exist_ok=True) + + if args.file_mode == 'caption': + caption_file = os.path.join(args.caption_dir,str(shard_id % 
100),f'{shard_id}_caption.json') + if not os.path.isfile(caption_file): + tarpath = get_tarfile_path(args.root, shard_id) + if os.path.isfile(tarpath): + stat['failed'].append(shard_id) + else: + stat['non_exist'].append(shard_id) + continue + captions = json.load(open(caption_file,'r')) + else: + captions, status = gather_caption_from_tarfile(os.path.dirname(args.root), shard_id) + if captions is None: + stat['non_exist'].append(shard_id) + continue + + chunks = [captions['caption'][x:x+args.chunk_size] for x in range(0, len(captions['caption']), args.chunk_size)] + for chunk in chunks: + inputs = tokenizer(chunk, padding=True, truncation=True, return_tensors="pt") + inputs = {key:item.cuda() for key,item in inputs.items()} + embeddings = model(**inputs, output_hidden_states=True, return_dict=True).pooler_output.cpu() + txtfeats['feat'].append(embeddings) + + txtfeats['filekeys'] = captions['tmems'] + txtfeats['img_midx'] = captions['imems'] + txtfeats['feat'] = torch.cat(txtfeats['feat'],dim=0) + if len(txtfeats['feat']) != len(txtfeats['filekeys']): + print(f'check {shard_id}') + stat['failed'].append(shard_id) + else: + torch.save(txtfeats, save_path) + stat['success'].append(shard_id) + print('write feature for {} with {} items'.format(shard_id, len(txtfeats['feat']))) + + if args.dataset == 'clipeval': + + catalog, all_templates, all_labels = eval_zeroshot.load_metadata("clipeval") + for d in catalog: + feat_file = os.path.join(args.feature_dir, f'{d}.pth') + if os.path.exists(feat_file): + continue + templates = all_templates[d] + labels = all_labels[d] + text_embeddings = build_text_indfeatures(templates, labels, model, tokenizer) + torch.save(text_embeddings, feat_file) + + else: + raise ValueError('Please comment this command and customize the code for yourself') + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser('Clustering evaluation', parents=[get_args_parser()]) + config = parser.parse_args() + + if config.dataset != 'clipeval': + paths = get_default_paths()[config.dataset] + config.root = paths['root'] + config.caption_dir = paths['caption'] + config.feature_dir = paths['feature'] + + os.makedirs(config.feature_dir, exist_ok=True) + main(config) diff --git a/mode/prep_hrchy.py b/mode/prep_hrchy.py new file mode 100644 index 0000000..36c30d1 --- /dev/null +++ b/mode/prep_hrchy.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates + +import sys +sys.path.append("src") +sys.path.append("./") +import os +import random + +import torch +import torch.nn.functional as F +import argparse, pickle, pdb +import numpy as np + +from kmeans_pytorch import KMeans as BalancedKMeans +from get_prep_parser import get_args_parser, get_default_paths + + +if torch.cuda.is_available(): + device = torch.device('cuda') +else: + device = torch.device('cpu') + + +def cluster_fine(n_clusters, args, balanced=1): + + path_to_fine = os.path.join(args.ccenter_dir, args.cd, f'F{n_clusters}.pth') + os.makedirs(os.path.dirname(path_to_fine), exist_ok=True) + if os.path.exists(path_to_fine): + print(f'{path_to_fine} is written') + return True + + print(f'Preparing data for file {path_to_fine}') + file_for_run = [] + for i in range(100): + feat_files = os.listdir(os.path.join(args.feature_dir,str(i))) + num_files = len(feat_files) + num_fun_files = int(num_files*0.05) + 1 # num_files + files = np.random.choice(feat_files,num_fun_files).tolist() + file_for_run.extend([os.path.join(args.feature_dir,str(i),file) for file in files]) + + np.random.shuffle(file_for_run) + print('{} files are selected'.format(len(file_for_run))) + kmeans = BalancedKMeans(n_clusters=n_clusters, device=device, balanced=(balanced==1)) + total_size = 0 + for i,file in enumerate(file_for_run): + print(i, file) + feat = F.normalize(torch.load(file)['feat'].cuda(),dim=-1) + + total_size += feat.size(0) + if 'cos' in args.cd: + kmeans.fit(feat, distance='cosine', iter_limit=50, online=True, iter_k=i) + elif 'euc' in args.cd.lower(): # euclidean + kmeans.fit(feat, distance='euclidean', iter_limit=50, online=True, iter_k=i) + else: + raise ValueError('Not Implemented') + if (i+1) % 100 == 0: + print(f'checkpointing at step {i}') + torch.save({'center':kmeans.cluster_centers.cpu()},path_to_fine) + + print('there are {} files involved in clustering'.format(total_size)) + with open(path_to_fine.replace('.pth','.pkl'), 'wb+') as f: + _ = pickle.dump(kmeans, f) + torch.save({'center':kmeans.cluster_centers.cpu()},path_to_fine) + return True + + +def cluster_coarse(n_clusters, args, balanced=1): + + path_to_fine = os.path.join(args.ccenter_dir, args.cd, f'F{args.cm}.pth') + centers = torch.load(path_to_fine)['center'] + + path_to_coarse = os.path.join(args.ccenter_dir, args.cd, f'F{args.cm}-C{n_clusters}.pth') + if os.path.exists(path_to_coarse): + print(f'{path_to_coarse} is written') + return True + + kmeans = BalancedKMeans(n_clusters=n_clusters, device=device, balanced=(balanced==1)) + + if 'cos' in args.cd: + kmeans.fit(F.normalize(centers.cuda(),dim=-1), distance='cosine', iter_limit=100, online=False) + elif 'euc' in args.cd.lower(): # euclidean + kmeans.fit(centers.cuda(), distance='euclidean', iter_limit=100, online=False) + else: + raise ValueError('Not Implemented') + + assign = kmeans.predict(centers.cuda(), args.cd) + torch.save({'coarse':kmeans.cluster_centers.cpu(),'assign':assign.cpu()}, path_to_coarse) + return True + +parser = argparse.ArgumentParser('Clustering Evaluation', parents=[get_args_parser()]) +args = parser.parse_args() +paths = get_default_paths()[args.dataset] +args.feature_dir = paths['feature'] +args.ccenter_dir = paths['cluster'] + +cluster_fine(args.cm, args) +cluster_coarse(args.cn, args) diff --git a/mode/prep_inference.py b/mode/prep_inference.py new file mode 100644 index 0000000..bbfed1d --- /dev/null +++ b/mode/prep_inference.py @@ -0,0 +1,150 @@ +# Copyright (c) Meta Platforms, Inc. 
and affiliates + +import sys +sys.path.append("src") +sys.path.append("./") +import os + +import torch +import torch.nn.functional as F +import argparse +from get_prep_parser import get_args_parser, get_default_paths +from prep_caption import get_tarfile_path + +import json,pdb + +from multiprocessing import Pool +from tqdm import tqdm + + +def build_assignment(feat_dir, shard_id, assign_dir, overwrite=True): + + shard_folder = shard_id % 100 + output_fn_group = os.path.join(assign_dir, f"{shard_folder}", f'{shard_id}_assign_dist.json') + os.makedirs(os.path.dirname(output_fn_group), exist_ok=True) + + feat_fn = os.path.join(feat_dir, f"{shard_folder}", f'{shard_id}_feat.pth') + if not os.path.exists(feat_fn): + print(feat_fn, 'Not Found') + return None + + if os.path.exists(output_fn_group) and not overwrite: + print(f'{output_fn_group} Written already') + return True + + feature = torch.load(feat_fn, map_location='cpu') + assign = {'key':feature['filekeys'],'image':feature['img_midx']} + feature = F.normalize(feature['feat'],dim=-1) + + for key,ccenter in ccenters.items(): + if key[0] == 'E': # for euclidean + dist = torch.cdist(feature[None], ccenter[None])[0] + min_dist,assign_tensor = dist.min(dim=-1) + min_dist = min_dist.numpy().tolist() + elif key[0] == 'C': # for cosine + sim = torch.mm(feature, ccenter.T) + max_sim,assign_tensor = sim.max(dim=-1) + min_dist = (1.0 - max_sim).numpy().tolist() + # add "'dist':min_dist" in the dict if needed + assign[key] = {'assign':assign_tensor.numpy().tolist()} + + with open(output_fn_group,'w') as json_file: + json.dump(assign, json_file) + print('Newly written', shard_id) + return assign + + +def func(args, _start, _end): + missing_shards = [] + + if isinstance(_start, list): + warc_iter = _start + else: + warc_iter = ( + tqdm(range(_start, _end)) if _start == 0 else range(_start, _end) + ) + + for idx, shard_id in enumerate(warc_iter): + + wds_fn = get_tarfile_path(wds_dir, shard_id) + if not os.path.exists(wds_fn): + continue + + status = build_assignment( + args.feature_dir, shard_id, args.cassign_dir, overwrite=False, + ) + if status: + pass + elif status is None: + missing_shards.append(shard_id) + else: + raise ValueError('No Implementation Error') + + return missing_shards + + +def main(args): + + shard_ids = [[] for _ in range(args.num_threads)] + for shard_id in range(args.tar_init, args.tar_end): + group_offset = shard_id % args.num_threads + shard_ids[group_offset].append(shard_id) + + print(f"shard_ids[0]={shard_ids[0]}") + starts = shard_ids + ends = [None for _ in range(len(starts))] + + argss = [args for _ in range(len(starts))] + assert len(argss) == len(starts) == len(ends) + assert len(starts) <= args.num_threads + + global wds_dir + wds_dir = os.path.dirname(args.root) + + global ccenters + ccenters = {} + # MoDE originally uses euclidean dist in clustering + # The dict structure below provides flexibility + # for cluster assginment with different cm and and cosine + for dist_type in ['euclidean']: + for cm in [args.cm,]: + path = os.path.join(args.ccenter_dir,dist_type,f'F{cm}.pth') + if os.path.exists(path): + key = '{}{}'.format(dist_type[0].upper(),args.cm) + ccenters[key] = torch.load(path)['center'] + if 'cos' in dist_type: + ccenters[key] = F.normalize(ccenters[key],dim=-1) + + with Pool(len(starts)) as p: + results = p.starmap( + func, + zip( + argss, + starts, + ends + ), + ) + + all_results = [] + for result in results: + all_results.extend(result) + print("missing npy", len(all_results)) + + +if __name__ == 
'__main__': + + parser = argparse.ArgumentParser('Clustering evaluation', parents=[get_args_parser()]) + config = parser.parse_args() + + paths = get_default_paths()[config.dataset] + config.root = paths['root'] + config.feature_dir = paths['feature'] + config.cassign_dir = paths['assign'] + config.ccenter_dir = paths['cluster'] + + config.num_threads = 40 + if config.tar_end == -1: + config.tar_end = int(os.path.basename(config.root).split("{")[1].split("}")[0].split("..")[1]) + + os.makedirs(config.cassign_dir, exist_ok=True) + main(config) diff --git a/src/training/data.py b/src/training/data.py index 80d8226..3a70287 100644 --- a/src/training/data.py +++ b/src/training/data.py @@ -428,11 +428,19 @@ def get_metaclip_dataset(args, preprocess_fn, is_train, epoch=0): return get_metaclip_iter_wds_dataset(args, preprocess_fn, is_train, epoch) +def get_mode_dataset(args, preprocess_fn, is_train, epoch=0): + # a switcher func for different versions of dataloader. + from .mode_wds import get_mode_iter_wds_dataset + return get_mode_iter_wds_dataset(args, preprocess_fn, is_train, epoch) + + def get_dataset_fn(data_path, dataset_type): if dataset_type == "webdataset": return get_wds_dataset elif dataset_type == "csv": return get_csv_dataset + elif dataset_type == "cluster": + return get_mode_dataset elif dataset_type == "auto": ext = data_path.split('.')[-1] if ext in ['csv', 'tsv']: