208 changes: 208 additions & 0 deletions mode/README.md
@@ -0,0 +1,208 @@
# MoDE: CLIP Data Experts via Clustering

This repository contains the code for the Mixture of Data Experts (MoDE), described in the paper [MoDE: CLIP Data Experts via Clustering](https://arxiv.org/abs/2309.16671), which provides the first multi-modal understanding system built from independent CLIP models. The main contributions are:
- Introducing the concept of a **data expert** and proposing the MoDE framework, in which several small models are learned separately but adaptively ensembled for each task.
- Studying how to build a **wider** system rather than a deeper network. The system is scalable and can integrate new data experts without compromising the established ability, so it can be applied to online data and updated continuously.
- Investigating the quality of negative samples in contrastive language-image pretraining, in particular the false negatives in web-crawled image-caption pairs.
- Demonstrating that a set of small data experts can be comparable with a single large model. As the data experts can be trained asynchronously, MoDE significantly reduces the maximum computation requirement, shedding light on research with limited computational resources.

We conclude that:
- Effective pretraining should **carefully examine the data distribution** instead of aggressively learning from the whole dataset.
- Data can be used to explain the model capability and to determine the ensemble of models (deep learning is data driven).
- Our algorithm is simple and easily scalable to consume data from the whole Internet.

MoDE is trained with face-blurred images.

```bibtex
@inproceedings{ma2024mode,
title={MoDE: CLIP Data Experts via Clustering},
author={Ma, Jiawei and Huang, Po-Yao and Xie, Saining and Li, Shang-Wen and Zettlemoyer, Luke and Chang, Shih-Fu and Yih, Wen-Tau and Xu, Hu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2024}
}

@inproceedings{xu2023metaclip,
title={Demystifying CLIP Data},
author={Xu, Hu and Xie, Saining and Tan, Xiaoqing and Huang, Po-Yao and Howes, Russell and Sharma, Vasu and Li, Shang-Wen and Ghosh, Gargi and Zettlemoyer, Luke and Feichtenhofer, Christoph},
booktitle={The Twelfth International Conference on Learning Representations},
year={2023}
}
```

## Quick Links

- [Getting Started](#getting-started)
- [Data Preparation](#data-preparation)
- [Clustering](#clustering)
- [Training](#training)
- [Inference-Time Task Adaptation (Ensemble)](#ensemble)
- [Bugs or Questions?](#bugs-or-questions)
- [Citation](#citation)
- [Reference](#reference)


## Getting Started

This code is developed with minimal changes on top of [MetaCLIP](https://github.com/facebookresearch/MetaCLIP). The following command installs the requirements for MetaCLIP together with `submitit=1.2.1`, which is used by this repo:

```bash
conda create -n mode python=3.10 pytorch torchvision pytorch-cuda=11.7 tqdm ftfy braceexpand webdataset regex pandas submitit=1.2.1 \
-c pytorch-nightly \
-c nvidia \
-c conda-forge \
-c anaconda
```

Then, install the k-means clustering code from the repository below (e.g., from source):
```bash
# https://github.com/subhadarship/kmeans_pytorch/tree/master
git clone https://github.com/subhadarship/kmeans_pytorch
cd kmeans_pytorch && pip install --editable .
```

Finally, move the config-related files from this folder to the project root:
```bash
mv move2root/* ../
rm -r move2root
```

## Data Preparation

In this example code, we assume the dataset is called `demo` and all of the image-caption pairs are saved in a set of tarfiles, while the tarfiles are organized in sharded folders:
```
'demo':
'0':
'0.tar'
'100.tar'
...
'1':
'1.tar'
'101.tar'
...
...
'99':
'99.tar'
'199.tar'
...
```
Within each tarfile, the image-caption pairs are saved in sequence.
```
., json, jpeg, json, jpeg ...
```
where, for each pair, the text is first stored in a `json` file and the image is then saved as a `jpeg`.
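
To make this layout concrete, below is a minimal Python sketch (not part of this repo) that iterates over the image-caption pairs in one tarfile. The alternating `.json` / `.jpeg` members follow the layout above, while the `caption` key inside each `json` is an assumption that may differ for your data.

```python
# Minimal sketch (not part of this repo): iterate the image-caption pairs in one tarfile.
import json
import tarfile

def iter_pairs(tar_path):
    caption = None
    with tarfile.open(tar_path) as tar:
        for member in tar.getmembers():
            if member.name.endswith(".json"):
                # the `caption` key is an assumption
                caption = json.load(tar.extractfile(member)).get("caption", "")
            elif member.name.endswith(".jpeg"):
                image_bytes = tar.extractfile(member).read()
                yield caption, image_bytes

for caption, image_bytes in iter_pairs("demo/0/0.tar"):
    print(len(image_bytes), caption[:50])
```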

For the following steps, we provide detailed command examples under `prep-steps` in `run_mode.sh` for explanation and usage.
The configuration and the paths for storing intermediate data are summarized in `mode/get_prep_parser.py`. When you run the code, please make sure you are in the root directory of the whole project. To customize for your own data, you can also modify the `get_default_paths` function in that file, as sketched below.
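
For example, a customized `get_default_paths` could register a new dataset alongside `demo`, as in the sketch below; the `mydata` entry and its paths are placeholders for your own setup.

```python
# Sketch of a customized get_default_paths() in mode/get_prep_parser.py.
# The 'mydata' entry and its paths are placeholders for your own dataset.
PATH_TO_MYDATA = '/path/to/mydata'

def get_default_paths():
    demo = {
        'root': 'data/demo/{0..200000}.tar',
        'caption': '/demo/caption/',
        'feature': '/demo/feature/',
        'assign': '/demo/assign/',
        'cluster': '/demo/cluster_center/',
    }
    mydata = {
        'root': 'data/mydata/{0..999}.tar',              # brace-expanded tarfile shards
        'caption': f'{PATH_TO_MYDATA}/caption/',         # extracted captions (step 0)
        'feature': f'{PATH_TO_MYDATA}/feature/',         # language embeddings (step 1)
        'assign': f'{PATH_TO_MYDATA}/assign/',           # cluster assignments (step 3)
        'cluster': f'{PATH_TO_MYDATA}/cluster_center/',  # cluster centers (step 2)
    }
    return {'demo': demo, 'mydata': mydata}
```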

## Clustering

Data clustering is performed on the language embeddings of the captions. This section mainly explains feature extraction and data clustering.
For large-scale data processing, we provide the optimized code below, which separates the steps and enables multi-thread processing.

### Step 0 Preparing Captions

This step reads the tarfiles in which the image-caption pairs are stored together and collects the captions.
As caption extraction is CPU-only, we provide the script below to enable multi-thread caption collection (highly recommended for large-scale data processing).

```bash
python mode/prep_caption.py
```
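
For reference, the sketch below illustrates the idea of parallel caption collection (it is not `mode/prep_caption.py` itself): each worker reads one tarfile and dumps its captions to the caption directory. The output naming and the `caption` key are assumptions.

```python
# Illustrative sketch only (mode/prep_caption.py is the actual implementation):
# collect the captions of each tarfile in parallel and dump one caption file per tarfile.
import glob
import json
import os
import tarfile
from multiprocessing import Pool

CAPTION_DIR = "/demo/caption"  # the 'caption' path in get_default_paths()

def collect_captions(tar_path):
    captions = []
    with tarfile.open(tar_path) as tar:
        for member in tar.getmembers():
            if member.name.endswith(".json"):
                captions.append(json.load(tar.extractfile(member)).get("caption", ""))
    out_path = os.path.join(CAPTION_DIR, os.path.basename(tar_path).replace(".tar", ".json"))
    with open(out_path, "w") as f:
        json.dump(captions, f)

if __name__ == "__main__":
    os.makedirs(CAPTION_DIR, exist_ok=True)
    with Pool(processes=16) as pool:
        pool.map(collect_captions, sorted(glob.glob("data/demo/*/*.tar")))
```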

### Step 1 Preparing Features

This step extracts the language embeddings of the captions; the features for the captions in one tarfile are stored in a single `.pth` file. Following the organization of the tarfiles, we also organize the features in sharded folders.

When the captions are pre-collected (via step 0), run the command below to extract the caption features, where each worker is allocated one GPU.

```bash
torchrun --nproc_per_node=8 mode/prep_feature.py --file-mode caption
```

As an alternative, you can skip step 0 and extract the features directly from the tarfiles.

```bash
torchrun --nproc_per_node=8 mode/prep_feature.py --file-mode tarfile
```
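
Conceptually, step 1 encodes each tarfile's captions with a CLIP text tower and saves one `.pth` of features per tarfile. The sketch below illustrates this with the pip-installable `open_clip_torch` package as a stand-in for the repo's own model code; the model name, checkpoint, and paths are assumptions.

```python
# Illustrative sketch only (mode/prep_feature.py is the actual implementation): encode the captions
# of one tarfile with a CLIP text tower and save one .pth of features per tarfile.
import json
import torch
import torch.nn.functional as F
import open_clip  # pip package `open_clip_torch`, used here only as a stand-in

model, _, _ = open_clip.create_model_and_transforms("ViT-B-32", pretrained="openai")
tokenizer = open_clip.get_tokenizer("ViT-B-32")
model = model.eval().cuda()

with open("/demo/caption/0.json") as f:   # captions collected in step 0
    captions = json.load(f)

features = []
with torch.no_grad():
    for i in range(0, len(captions), 400):            # chunked, cf. --chunk-size
        tokens = tokenizer(captions[i:i + 400]).cuda()
        feat = model.encode_text(tokens)
        features.append(F.normalize(feat, dim=-1).cpu())

torch.save(torch.cat(features), "/demo/feature/0.pth")
```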

### Step 2 Two-Step Clustering

Once the features are ready, perform two-step clustering to obtain the fine-grained clusters and the coarse-grained conditions. Note that we only use a fraction of the whole dataset to do the clustering on a single GPU. Once finished, both the fine-grained and the coarse-grained cluster centers are provided.

```bash
torchrun --nproc_per_node=1 mode/prep_hrchy.py
```
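
The sketch below illustrates the two-step procedure with the `kmeans_pytorch` package installed above: fine-grained k-means on a subset of the caption features, followed by coarse-grained k-means over the fine-grained centers. The defaults `--cm 1024` and `--cn 4` follow `mode/get_prep_parser.py`; the file paths and saved keys are assumptions.

```python
# Illustrative sketch of the two-step clustering (mode/prep_hrchy.py is the actual implementation).
import glob
import torch
from kmeans_pytorch import kmeans

# use only a fraction of the data, e.g. the first 50 feature files
features = torch.cat([torch.load(p) for p in sorted(glob.glob("/demo/feature/*.pth"))[:50]])

# step 2a: fine-grained clusters (--cm)
fine_ids, fine_centers = kmeans(
    X=features, num_clusters=1024, distance="euclidean", device=torch.device("cuda")
)

# step 2b: coarse-grained clusters over the fine centers (--cn); each one conditions a data expert
fine_to_coarse, coarse_centers = kmeans(
    X=fine_centers, num_clusters=4, distance="euclidean", device=torch.device("cuda")
)

torch.save(
    {"fine": fine_centers, "coarse": coarse_centers, "fine_to_coarse": fine_to_coarse},
    "/demo/cluster_center/centers.pth",
)
```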

### Step 3 Cluster Assignment

Once the cluster centers are obtained, use nearest-neighbor search to determine the cluster assignment for each pair. This process is CPU-only, and the code below supports multi-thread processing.

```bash
python mode/prep_inference.py
```
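
The assignment itself boils down to a nearest-neighbor search against the fine-grained centers, as in the sketch below (file paths and dictionary keys are assumptions).

```python
# Illustrative sketch of the cluster assignment (mode/prep_inference.py is the actual implementation):
# assign each caption feature to its nearest fine-grained cluster center.
import torch

fine_centers = torch.load("/demo/cluster_center/centers.pth")["fine"]   # [1024, d]
features = torch.load("/demo/feature/0.pth")                            # [n, d], captions of 0.tar

assign = torch.cdist(features, fine_centers).argmin(dim=1)              # nearest-neighbor search
torch.save(assign, "/demo/assign/0.pth")
```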

## Training

Once the cluster assignment is ready, we train each data expert as a normal CLIP model and only alter the data sampling. Please check the config file `run_configs_mode.py` and manually set the expert ID via `coarse_idx` to determine which data expert model is trained (a conceptual sketch of the altered sampling follows the command below).

```bash
torchrun --nproc_per_node=8 src/training/main.py b32_mode
```
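
Conceptually, the altered sampling keeps only the pairs whose fine-grained cluster is mapped to the chosen coarse cluster, as in the sketch below; variable names and file paths are illustrative, and the actual logic lives in the modified data loading code.

```python
# Conceptual sketch of the altered data sampling: an expert only trains on the pairs whose
# fine-grained cluster maps to its coarse cluster (coarse_idx).
import torch

coarse_idx = 0                                                # expert set in run_configs_mode.py
centers = torch.load("/demo/cluster_center/centers.pth")
fine_to_coarse = centers["fine_to_coarse"]                    # [1024] fine -> coarse mapping

assign = torch.load("/demo/assign/0.pth")                     # per-pair fine cluster ids for 0.tar
keep = (fine_to_coarse[assign] == coarse_idx).nonzero(as_tuple=True)[0]
print(f"expert {coarse_idx} trains on {len(keep)}/{len(assign)} pairs from this tarfile")
```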

## Ensemble

Given the well-trained expert models, for a comprehensive evaluation we gather the outputs from each expert model as well as the ensembled output, and summarize them as a report in the original experiment log folder.

First, we evaluate each model and gather its outputs for ensembling.

```bash
torchrun --master_port=29600 --nproc_per_node=4 mode/post_expert_eval.py b32_mode
```

Then, as preparation for ensembling, we extract the language embeddings of the task metadata, e.g., the class names. We reuse the feature extraction file but pass different arguments.

```bash
# reuse the feature extraction entry point; the `clipeval` option is defined in mode/get_prep_parser.py
torchrun --nproc_per_node=1 mode/prep_feature.py --dataset clipeval
```

Lastly, we use the similarity between the metadata embeddings and the cluster centers to determine the ensembling weights for evaluation. By running the command below, all results will be summarized in a CSV file.

```bash
python mode/post_report_ensemble.py b32_mode ${DIR_CLIPEVAL}
```
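
As a rough sketch of the idea (file names and the temperature below are assumptions, not the repo's actual outputs): each expert is weighted by the similarity between the metadata embeddings and its coarse cluster center, and the experts' logits are averaged with these weights.

```python
# Rough sketch of the ensembling idea (mode/post_report_ensemble.py is the actual implementation).
import torch
import torch.nn.functional as F

coarse_centers = torch.load("/demo/cluster_center/centers.pth")["coarse"]   # [n_experts, d]
class_feats = torch.load("./logs/clip_eval/metadata_features.pth")          # [n_classes, d]
expert_logits = torch.load("./logs/clip_eval/expert_logits.pth")            # [n_experts, n_samples, n_classes]

# similarity between task metadata and each coarse cluster center, softmaxed into weights
sim = F.normalize(class_feats, dim=-1) @ F.normalize(coarse_centers, dim=-1).T   # [n_classes, n_experts]
weights = torch.softmax(sim.mean(dim=0) / 0.1, dim=0)                            # [n_experts]

ensembled = (weights[:, None, None] * expert_logits).sum(dim=0)                  # [n_samples, n_classes]
print(ensembled.argmax(dim=1)[:10])                                              # ensembled predictions
```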

## Bugs or Questions?

If you have any questions related to the code or the paper, feel free to email Jiawei Ma (`jiawei.m@columbia.edu`) or Hu Xu (`huxu@meta.com`).


## Citation

Please cite our papers (accepted at CVPR 2024 and ICLR 2024, respectively) if MoDE helps your work:

```bibtex
@inproceedings{ma2024mode,
title={MoDE: CLIP Data Experts via Clustering},
author={Ma, Jiawei and Huang, Po-Yao and Xie, Saining and Li, Shang-Wen and Zettlemoyer, Luke and Chang, Shih-Fu and Yih, Wen-Tau and Xu, Hu},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
year={2024}
}
@inproceedings{xu2023metaclip,
title={Demystifying CLIP Data},
author={Xu, Hu and Xie, Saining and Tan, Xiaoqing and Huang, Po-Yao and Howes, Russell and Sharma, Vasu and Li, Shang-Wen and Ghosh, Gargi and Zettlemoyer, Luke and Feichtenhofer, Christoph},
booktitle={The Twelfth International Conference on Learning Representations},
year={2023}
}
```

## Reference

The code is based on [MetaCLIP](https://github.com/facebookresearch/MetaCLIP), and only the data loading & sampling is modified.

## TODO
- We welcome your use cases or suggestions so that we can update this codebase regularly.


## License

MoDE is licensed under CC-BY-NC.

## Acknowledgement
We gratefully acknowledge the [OpenCLIP](https://github.com/mlfoundations/open_clip) team for the initial CLIP codebase and [MetaCLIP](https://github.com/facebookresearch/MetaCLIP) for the careful examination of the data distribution.
Empty file added mode/__init__.py
Empty file.
45 changes: 45 additions & 0 deletions mode/get_prep_parser.py
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates

import argparse

PATH_TO_DEMO = '/demo'
def get_default_paths():
    demo = {
        'root': 'data/demo/{0..200000}.tar',
        'caption': f'{PATH_TO_DEMO}/caption/',
        'feature': f'{PATH_TO_DEMO}/feature/',
        'assign': f'{PATH_TO_DEMO}/assign/',
        'cluster': f'{PATH_TO_DEMO}/cluster_center/',
    }
    return {'demo': demo}


def get_args_parser():
    parser = argparse.ArgumentParser(description='MoDE Data Preparation', add_help=False)
    parser.add_argument('--dataset', default='demo', type=str, choices=['clipeval', 'demo'])
    parser.add_argument('--root', default="data/demo/{0..200000}.tar", type=str,
                        help='path to dataset root')
    parser.add_argument('--caption-dir', default='caption/', type=str, help='caption dir, highly recommended')
    parser.add_argument('--feature-dir', default='feature/', type=str, help='feature output dir')

    # Below arguments are only for pre-processing pre-train data on feature extraction
    parser.add_argument('--file-mode', default='tarfile', type=str, choices=['caption', 'tarfile'],
                        help='process extracted captions or tarfiles directly')
    parser.add_argument('--tar-init', default=0, type=int, help='tarfile_id to start')
    parser.add_argument('--tar-end', default=-1, type=int, help='tarfile_id to end')
    parser.add_argument('--tar-per-gpu', default=-1, type=int, help='number of tarfiles to process per GPU')
    parser.add_argument('--chunk-size', default=400, type=int, help='number of captions to be processed per chunk')
    parser.add_argument('--horovod', default=False, type=bool, help='placeholder, needed to pass ddp initialization')
    parser.add_argument('--dist-url', default="env://", type=str, help='placeholder, needed to pass ddp initialization')
    parser.add_argument('--dist-backend', default="nccl", type=str, help='placeholder, needed to pass ddp initialization')
    parser.add_argument('--no-set-device-rank', default=False, type=bool, help='placeholder, needed to pass ddp initialization')

    # Arguments on clustering and assignment
    parser.add_argument('--cm', default=1024, type=int, help='number of fine-grained cluster centers')
    parser.add_argument('--cn', default=4, type=int, help='number of coarse-grained cluster centers')
    parser.add_argument('--cd', default='euclidean', type=str, help='cluster distance, euc or cos')
    parser.add_argument('--cassign-dir', default='assign/', type=str, help='dir for cluster assignment')
    parser.add_argument('--ccenter-dir', default='cluster_center/', type=str, help='dir for cluster centers')

    # Arguments on intermediate variables at inference time
    parser.add_argument('--logits-dir', default='./logs/clip_eval', type=str, help='dir for inference-time logits')
    return parser
127 changes: 127 additions & 0 deletions mode/move2root/configs_mode.py
@@ -0,0 +1,127 @@
import os
import inspect

from collections import OrderedDict
from dataclasses import dataclass

import sys
sys.path.append("src")

from training.params import get_default_params
from mode.get_prep_parser import get_default_paths

@dataclass
class Config:
    train_data = None
    val_data = None
    train_num_samples = None
    val_num_samples = None
    dataset_type = "auto"
    dataset_resampled = False
    csv_separator = "\t"
    csv_img_key = "filepath"
    csv_caption_key = "title"
    imagenet_val = "/datasets01/imagenet_full_size/061417/val"
    imagenet_v2 = None
    logs = "./logs/"
    log_local = False
    name = None
    workers = 8
    batch_size = 64
    epochs = 32
    lr = None
    beta1 = None
    beta2 = None
    eps = None
    wd = 0.2
    warmup = 2000  # 10000
    use_bn_sync = False
    skip_scheduler = False
    save_frequency = 1
    save_most_recent = True  # False
    zeroshot_frequency = 1
    val_frequency = 1
    resume = None
    precision = "amp"
    clip_model = "CLIP"
    model = "RN50"
    pretrained = ''
    pretrained_image = False
    lock_image = False
    lock_image_unlocked_groups = 0
    lock_image_freeze_bn_stats = False
    grad_checkpointing = False
    local_loss = False
    gather_with_grad = False
    force_quick_gelu = False
    torchscript = False
    trace = False
    dist_url = "env://"
    dist_backend = "nccl"
    report_to = ""
    wandb_notes = ''
    debug = False
    copy_codebase = False
    horovod = False
    ddp_static_graph = False
    no_set_device_rank = False
    seed = 0
    norm_gradient_clip = None

    fine_index = ''
    hrchy_assign = ''
    ooc_ratio = 0.02  # slightly better than 0.0
    dist_type = 'euclidean'

    def __post_init__(self):
        args = self
        args.name = self.__class__.__name__

        for name, val in get_default_params(args.model).items():
            if getattr(args, name) is None:
                setattr(args, name, val)

        if 'mode' in args.name:
            assert args.coarse_idx >= 0 and args.coarse_idx < args.mode_size
            sub_str = f'expert_{args.coarse_idx}'
            args.name = '{}_n{}m{}/{}'.format(args.name, args.mode_size, args.mode_fine, sub_str)

        if args.train_data == '':
            datakey = 'demo'
            args.train_data = get_default_paths()[datakey]['root']
        else:
            # args.train_data and 'root' of get_default_paths in get_prep_parser.py should be the same
            # the data dir is named by the dataset
            datakey = args.train_data.split('/')[-2]
        paths = get_default_paths()[datakey]
        args.fine_index = paths['assign']
        args.hrchy_assign = paths['cluster']

        if args.resume is None:
            # As the checkpoint for data expert initialization is trained via the MetaCLIP repo,
            # the same format is applied to determine the checkpoint path.
            args.resume = os.path.join(args.seed_exp, 'checkpoints', f'epoch_{args.quick_init}.pt')

        args.output_dir = os.path.join(args.logs, args.name)


def parse_start_end(shards):
    start, end = os.path.basename(shards).split("{")[1].split("}")[0].split("..")
    return int(start), int(end)


def search_config(config_name):
    import importlib
    project_dir = os.path.dirname(__file__)
    all_configs = {}
    for code in os.listdir(project_dir):
        if code.endswith(".py") and code.startswith("run_configs"):
            module = importlib.import_module(code[:-3])
            for _config_name in dir(module):
                if _config_name in ["Config"] or _config_name.startswith("__") or _config_name.startswith("run_config"):
                    continue
                if _config_name not in all_configs:
                    all_configs[_config_name] = module
    print(f"launching {config_name} from {all_configs[config_name].__file__}")
    config = getattr(all_configs[config_name], config_name)()
    return config