Merge pull request #17 from cmacdonald/main
move all source to splade/ module
thibault-formal committed Jul 20, 2022
2 parents 571c794 + 939e886 commit 3f50461
Showing 43 changed files with 97 additions and 73 deletions.
18 changes: 9 additions & 9 deletions README.md
@@ -118,7 +118,7 @@ In order to perform all steps (here on toy data, i.e. `config_default.yaml`), go
conda activate splade_env
export PYTHONPATH=$PYTHONPATH:$(pwd)
export SPLADE_CONFIG_NAME="config_default.yaml"
-python3 -m src.all \
+python3 -m splade.all \
config.checkpoint_dir=experiments/debug/checkpoint \
config.index_dir=experiments/debug/index \
config.out_dir=experiments/debug/out
@@ -129,14 +129,14 @@ python3 -m src.all \
We provide additional examples that can be plugged into the above code. See [conf/README.md](conf/README.md) for details
on how to change experiment settings.

-* you can similarly run training `python3 -m src.train` (same for indexing or retrieval)
+* you can similarly run training `python3 -m splade.train` (same for indexing or retrieval)
* to create Anserini readable files (after training),
-run `SPLADE_CONFIG_FULLPATH=/path/to/checkpoint/dir/config.yaml python3 -m src.create_anserini +quantization_factor_document=100 +quantization_factor_query=100`
+run `SPLADE_CONFIG_FULLPATH=/path/to/checkpoint/dir/config.yaml python3 -m splade.create_anserini +quantization_factor_document=100 +quantization_factor_query=100`
* config files for various settings (distillation etc.) are available in `/conf`. For instance, to run the `SelfDistil`
setting:
* change to `SPLADE_CONFIG_NAME=config_splade++_selfdistil.yaml`
* to further change parameters (e.g. lambdas) *outside* the config,
-run: `python3 -m src.all config.regularizer.FLOPS.lambda_q=0.06 config.regularizer.FLOPS.lambda_d=0.02`
+run: `python3 -m splade.all config.regularizer.FLOPS.lambda_q=0.06 config.regularizer.FLOPS.lambda_d=0.02`

We provide several base configurations corresponding to the experiments in the v2bis and "efficiency" papers. Please note that these are
suited to our hardware setting, i.e. 4 Tesla V100 GPUs with 32GB memory. In order to train models with e.g. one GPU,
@@ -153,24 +153,24 @@ or [Anserini](https://github.com/castorini/anserini). Let's perform these steps
conda activate splade_env
export PYTHONPATH=$PYTHONPATH:$(pwd)
export SPLADE_CONFIG_NAME="config_splade++_cocondenser_ensembledistil"
-python3 -m src.index \
+python3 -m splade.index \
init_dict.model_type_or_dir=naver/splade-cocondenser-ensembledistil \
config.pretrained_no_yamlconfig=true \
config.index_dir=experiments/pre-trained/index
-python3 -m src.retrieve \
+python3 -m splade.retrieve \
init_dict.model_type_or_dir=naver/splade-cocondenser-ensembledistil \
config.pretrained_no_yamlconfig=true \
config.index_dir=experiments/pre-trained/index \
config.out_dir=experiments/pre-trained/out
# pretrained_no_yamlconfig indicates that we solely rely on a HF-valid model path
```

-* To change the data, simply override the hydra retrieve_evaluate package, e.g. add `retrieve_evaluate=msmarco` as argument of `src.retrieve`.
+* To change the data, simply override the hydra retrieve_evaluate package, e.g. add `retrieve_evaluate=msmarco` as argument of `splade.retrieve`.

You can similarly build the files that will be ingested by Anserini:

```bash
-python3 -m src.create_anserini \
+python3 -m splade.create_anserini \
init_dict.model_type_or_dir=naver/splade-cocondenser-ensembledistil \
config.pretrained_no_yamlconfig=true \
config.index_dir=experiments/pre-trained/index \
@@ -192,7 +192,7 @@ export PYTHONPATH=$PYTHONPATH:$(pwd)
export SPLADE_CONFIG_FULLPATH="/path/to/checkpoint/dir/config.yaml"
for dataset in arguana fiqa nfcorpus quora scidocs scifact trec-covid webis-touche2020 climate-fever dbpedia-entity fever hotpotqa nq
do
-python3 -m src.beir_eval \
+python3 -m splade.beir_eval \
+beir.dataset=$dataset \
+beir.dataset_path=data/beir \
config.index_retrieve_batch_size=100
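The `quantization_factor` arguments passed to `create_anserini` in the README hunks above convert real-valued SPLADE term weights into the integer impact scores an Anserini index stores. A minimal conceptual sketch with illustrative weights (not the repository's actual export code):

```python
# Sketch: quantizing real-valued sparse term weights to integers by
# scaling and rounding, as done before exporting to Anserini.
# The document representation below is purely illustrative.
def quantize(weights, factor=100):
    return {term: int(round(w * factor)) for term, w in weights.items()}

doc_rep = {"retrieval": 1.37, "sparse": 0.82, "model": 0.05}  # hypothetical weights
print(quantize(doc_rep))
```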
2 changes: 1 addition & 1 deletion inference_splade.ipynb
@@ -36,7 +36,7 @@
"source": [
"import torch\n",
"from transformers import AutoModelForMaskedLM, AutoTokenizer\n",
-"from src.models.transformer_rep import Splade"
+"from splade.models.transformer_rep import Splade"
]
},
{
23 changes: 23 additions & 0 deletions setup.py
@@ -0,0 +1,23 @@
from setuptools import setup, find_packages

with open('README.md') as f:
    readme = f.read()

setup(
    name='SPLADE',
    version='2.1',
    description='SParse Lexical AnD Expansion Model for First Stage Ranking',
    url='https://github.com/naver/splade',
    classifiers=[
        'Intended Audience :: Science/Research',
        'Programming Language :: Python :: 3.7',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
    ],
    packages=['splade'] + ['splade.' + i for i in find_packages('splade')],
    license="Creative Commons Attribution-NonCommercial-ShareAlike",
    long_description=readme,
    install_requires=[
        'transformers==4.18.0',
        'omegaconf==2.1.2'
    ],
)
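The `packages=` line above lists the top-level package plus every subpackage found under `splade/`, since `find_packages('splade')` returns names relative to that directory. A small sketch of what the expression computes, using a hypothetical `find_packages('splade')` result:

```python
# Sketch: expanding setup.py's packages= expression by hand.
# `subpackages` stands in for find_packages('splade'), which returns
# dotted names relative to the splade/ directory (e.g. 'tasks.base').
subpackages = ['datasets', 'models', 'tasks', 'tasks.base', 'utils']  # hypothetical
packages = ['splade'] + ['splade.' + i for i in subpackages]
print(packages)
```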
Empty file added splade/__init__.py
Empty file.
12 changes: 6 additions & 6 deletions src/all.py → splade/all.py
@@ -2,12 +2,12 @@
from omegaconf import DictConfig

from conf.CONFIG_CHOICE import CONFIG_NAME, CONFIG_PATH
-from src.flops import flops
-from src.index import index
-from src.retrieve import retrieve_evaluate
-from src.train import train
-from src.utils.hydra import hydra_chdir
-from src.utils.index_figure import index_figure
+from .flops import flops
+from .index import index
+from .retrieve import retrieve_evaluate
+from .train import train
+from .utils.hydra import hydra_chdir
+from .utils.index_figure import index_figure


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
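The switch from `from src.x import y` to `from .x import y` throughout this commit ties the modules to the `splade` package: a relative import only resolves when the module is loaded as part of its package (hence `python3 -m splade.all`), not when the file is run directly as a script. A self-contained sketch of that behavior:

```python
# Sketch: relative imports work under "python -m pkg.module" but fail
# when the same file is executed directly. Builds a throwaway package.
import os
import subprocess
import sys
import tempfile

tmp = tempfile.mkdtemp()
pkg = os.path.join(tmp, 'pkg')
os.makedirs(pkg)
open(os.path.join(pkg, '__init__.py'), 'w').close()
with open(os.path.join(pkg, 'helper.py'), 'w') as f:
    f.write("VALUE = 42\n")
with open(os.path.join(pkg, 'main.py'), 'w') as f:
    f.write("from .helper import VALUE\nprint(VALUE)\n")

# Loaded as part of the package: the relative import resolves.
ok = subprocess.run([sys.executable, '-m', 'pkg.main'], cwd=tmp,
                    capture_output=True, text=True)
print(ok.stdout.strip())

# Run as a bare script: "attempted relative import with no known parent package".
bad = subprocess.run([sys.executable, os.path.join(pkg, 'main.py')],
                     capture_output=True, text=True)
print(bad.returncode != 0)
```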
10 changes: 5 additions & 5 deletions src/beir_eval.py → splade/beir_eval.py
@@ -11,11 +11,11 @@
from tqdm.auto import tqdm

from conf.CONFIG_CHOICE import CONFIG_NAME, CONFIG_PATH
-from src.datasets.dataloaders import CollectionDataLoader
-from src.datasets.datasets import BeirDataset
-from src.models.models_utils import get_model
-from src.tasks.transformer_evaluator import SparseIndexing, SparseRetrieval
-from src.utils.utils import get_initialize_config
+from .datasets.dataloaders import CollectionDataLoader
+from .datasets.datasets import BeirDataset
+from .models.models_utils import get_model
+from .tasks.transformer_evaluator import SparseIndexing, SparseRetrieval
+from .utils.utils import get_initialize_config


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
10 changes: 5 additions & 5 deletions src/create_anserini.py → splade/create_anserini.py
@@ -2,11 +2,11 @@
from omegaconf import DictConfig

from conf.CONFIG_CHOICE import CONFIG_NAME, CONFIG_PATH
-from src.datasets.dataloaders import TextCollectionDataLoader
-from src.datasets.datasets import CollectionDatasetPreLoad
-from src.models.models_utils import get_model
-from src.tasks.transformer_evaluator import EncodeAnserini
-from src.utils.utils import get_initialize_config
+from .datasets.dataloaders import TextCollectionDataLoader
+from .datasets.datasets import CollectionDatasetPreLoad
+from .models.models_utils import get_model
+from .tasks.transformer_evaluator import EncodeAnserini
+from .utils.utils import get_initialize_config


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
Empty file added splade/datasets/__init__.py
Empty file.
@@ -6,7 +6,7 @@
from torch.utils.data.dataloader import DataLoader
from transformers import AutoTokenizer

-from src.utils.utils import rename_keys
+from ..utils.utils import rename_keys


class DataLoaderWrapper(DataLoader):
File renamed without changes.
4 changes: 2 additions & 2 deletions src/evaluate.py → splade/evaluate.py
@@ -5,8 +5,8 @@
from omegaconf import DictConfig

from conf.CONFIG_CHOICE import CONFIG_NAME, CONFIG_PATH
-from src.evaluation.eval import load_and_evaluate
-from src.utils.utils import get_dataset_name
+from .evaluation.eval import load_and_evaluate
+from .utils.utils import get_dataset_name


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
Empty file added splade/evaluation/__init__.py
Empty file.
2 changes: 1 addition & 1 deletion src/evaluation/eval.py → splade/evaluation/eval.py
@@ -1,7 +1,7 @@
import argparse
import json

-from src.utils.metrics import mrr_k, evaluate
+from ..utils.metrics import mrr_k, evaluate


def load_and_evaluate(qrel_file_path, run_file_path, metric):
10 changes: 5 additions & 5 deletions src/flops.py → splade/flops.py
@@ -6,11 +6,11 @@
from omegaconf import DictConfig

from conf.CONFIG_CHOICE import CONFIG_NAME, CONFIG_PATH
-from src.datasets.dataloaders import CollectionDataLoader
-from src.datasets.datasets import CollectionDatasetPreLoad
-from src.models.models_utils import get_model
-from src.tasks.transformer_evaluator import SparseIndexing
-from src.utils.utils import get_initialize_config
+from .datasets.dataloaders import CollectionDataLoader
+from .datasets.datasets import CollectionDatasetPreLoad
+from .models.models_utils import get_model
+from .tasks.transformer_evaluator import SparseIndexing
+from .utils.utils import get_initialize_config


def estim_act_prob(dist, collection_size, voc_size=30522):
10 changes: 5 additions & 5 deletions src/index.py → splade/index.py
@@ -2,11 +2,11 @@
from omegaconf import DictConfig

from conf.CONFIG_CHOICE import CONFIG_NAME, CONFIG_PATH
-from src.datasets.dataloaders import CollectionDataLoader
-from src.datasets.datasets import CollectionDatasetPreLoad
-from src.models.models_utils import get_model
-from src.tasks.transformer_evaluator import SparseIndexing
-from src.utils.utils import get_initialize_config
+from .datasets.dataloaders import CollectionDataLoader
+from .datasets.datasets import CollectionDatasetPreLoad
+from .models.models_utils import get_model
+from .tasks.transformer_evaluator import SparseIndexing
+from .utils.utils import get_initialize_config


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
Empty file added splade/indexing/__init__.py
Empty file.
File renamed without changes.
Empty file added splade/losses/__init__.py
Empty file.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file added splade/models/__init__.py
Empty file.
@@ -1,6 +1,6 @@
from omegaconf import DictConfig

-from src.models.transformer_rep import Splade, SpladeDoc
+from ..models.transformer_rep import Splade, SpladeDoc


def get_model(config: DictConfig, init_dict: DictConfig):
@@ -3,8 +3,8 @@
import torch
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel

-from src.tasks.amp import NullContextManager
-from src.utils.utils import generate_bow, normalize
+from ..tasks.amp import NullContextManager
+from ..utils.utils import generate_bow, normalize

"""
we provide abstraction classes from which we can easily derive representation-based models with transformers like SPLADE
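The docstring above refers to representation-based models such as SPLADE, whose document representation is commonly described as a max pooling of positive MLM logits over input tokens, w_j = max_i log(1 + ReLU(logit_ij)). A toy, framework-free sketch with illustrative logits (the repository's actual implementation uses transformers and torch):

```python
# Sketch: computing a SPLADE-style sparse representation from a tiny
# tokens-x-vocab logit matrix. Values are illustrative.
import math

def splade_rep(logits):
    # For each vocab term j, max-pool log(1 + relu(logit)) over tokens i.
    n_tokens, vocab = len(logits), len(logits[0])
    return [max(math.log1p(max(logits[i][j], 0.0)) for i in range(n_tokens))
            for j in range(vocab)]

toy_logits = [[1.0, -2.0, 0.0],
              [0.5,  3.0, -1.0]]
rep = splade_rep(toy_logits)
print([round(v, 3) for v in rep])
```

Negative logits contribute nothing (ReLU), which is what makes the resulting vocabulary-sized vector sparse.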
Empty file added splade/optim/__init__.py
Empty file.
File renamed without changes.
12 changes: 6 additions & 6 deletions src/retrieve.py → splade/retrieve.py
@@ -2,12 +2,12 @@
from omegaconf import DictConfig

from conf.CONFIG_CHOICE import CONFIG_NAME, CONFIG_PATH
-from src.datasets.dataloaders import CollectionDataLoader
-from src.datasets.datasets import CollectionDatasetPreLoad
-from src.evaluate import evaluate
-from src.models.models_utils import get_model
-from src.tasks.transformer_evaluator import SparseRetrieval
-from src.utils.utils import get_dataset_name, get_initialize_config
+from .datasets.dataloaders import CollectionDataLoader
+from .datasets.datasets import CollectionDatasetPreLoad
+from .evaluate import evaluate
+from .models.models_utils import get_model
+from .tasks.transformer_evaluator import SparseRetrieval
+from .utils.utils import get_dataset_name, get_initialize_config


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
Empty file added splade/tasks/__init__.py
Empty file.
File renamed without changes.
Empty file added splade/tasks/base/__init__.py
Empty file.
File renamed without changes.
@@ -2,7 +2,7 @@

import torch

-from src.utils.utils import restore_model
+from ...utils.utils import restore_model


class Evaluator:
File renamed without changes.
6 changes: 3 additions & 3 deletions src/tasks/base/trainer.py → splade/tasks/base/trainer.py
@@ -7,9 +7,9 @@
from omegaconf import open_dict
from torch.utils.tensorboard import SummaryWriter

-from src.tasks.base.early_stopping import EarlyStopping
-from src.tasks.base.saver import ValidationSaver
-from src.utils.utils import makedir, remove_old_ckpt
+from .early_stopping import EarlyStopping
+from .saver import ValidationSaver
+from ...utils.utils import makedir, remove_old_ckpt


class BaseTrainer:
@@ -9,10 +9,10 @@
import torch
from tqdm.auto import tqdm

-from src.indexing.inverted_index import IndexDictOfArray
-from src.losses.regularization import L0
-from src.tasks.base.evaluator import Evaluator
-from src.utils.utils import makedir, to_list
+from ..indexing.inverted_index import IndexDictOfArray
+from ..losses.regularization import L0
+from ..tasks.base.evaluator import Evaluator
+from ..utils.utils import makedir, to_list


class SparseIndexing(Evaluator):
@@ -6,10 +6,10 @@
from omegaconf import open_dict
from tqdm.auto import tqdm

-from src.tasks import amp
-from src.tasks.base.trainer import TrainerIter
-from src.utils.metrics import init_eval
-from src.utils.utils import parse
+from ..tasks import amp
+from ..tasks.base.trainer import TrainerIter
+from ..utils.metrics import init_eval
+from ..utils.utils import parse


class TransformerTrainer(TrainerIter):
16 changes: 8 additions & 8 deletions src/train.py → splade/train.py
@@ -6,15 +6,15 @@
from torch.utils import data

from conf.CONFIG_CHOICE import CONFIG_NAME, CONFIG_PATH
-from src.datasets.dataloaders import CollectionDataLoader, SiamesePairsDataLoader, DistilSiamesePairsDataLoader
-from src.datasets.datasets import PairsDatasetPreLoad, DistilPairsDatasetPreLoad, MsMarcoHardNegatives, \
+from .datasets.dataloaders import CollectionDataLoader, SiamesePairsDataLoader, DistilSiamesePairsDataLoader
+from .datasets.datasets import PairsDatasetPreLoad, DistilPairsDatasetPreLoad, MsMarcoHardNegatives, \
    CollectionDatasetPreLoad
-from src.losses.regularization import init_regularizer, RegWeightScheduler
-from src.models.models_utils import get_model
-from src.optim.bert_optim import init_simple_bert_optim
-from src.tasks.transformer_evaluator import SparseApproxEvalWrapper
-from src.tasks.transformer_trainer import SiameseTransformerTrainer
-from src.utils.utils import set_seed, restore_model, get_initialize_config, get_loss, set_seed_from_config
+from .losses.regularization import init_regularizer, RegWeightScheduler
+from .models.models_utils import get_model
+from .optim.bert_optim import init_simple_bert_optim
+from .tasks.transformer_evaluator import SparseApproxEvalWrapper
+from .tasks.transformer_trainer import SiameseTransformerTrainer
+from .utils.utils import set_seed, restore_model, get_initialize_config, get_loss, set_seed_from_config


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
Empty file added splade/utils/__init__.py
Empty file.
File renamed without changes.
2 changes: 1 addition & 1 deletion src/utils/index_figure.py → splade/utils/index_figure.py
@@ -6,7 +6,7 @@

import hydra
from conf.CONFIG_CHOICE import CONFIG_NAME, CONFIG_PATH
-from src.utils.utils import get_initialize_config
+from .utils import get_initialize_config


@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
File renamed without changes.
File renamed without changes.
7 changes: 4 additions & 3 deletions src/utils/utils.py → splade/utils/utils.py
@@ -5,9 +5,8 @@
import torch
from omegaconf import DictConfig, OmegaConf

-from src.losses.pairwise import DistilKLLoss, PairwiseNLL, DistilMarginMSE, InBatchPairwiseNLL
-from src.losses.pointwise import BCEWithLogitsLoss
-from src.utils.hydra import hydra_chdir
+from ..losses.pairwise import DistilKLLoss, PairwiseNLL, DistilMarginMSE, InBatchPairwiseNLL
+from ..losses.pointwise import BCEWithLogitsLoss


def parse(d, name):
@@ -108,6 +107,8 @@ def get_dataset_name(path):


def get_initialize_config(exp_dict: DictConfig, train=False):
+    # delay import to reduce dependencies
+    from ..utils.hydra import hydra_chdir
    hydra_chdir(exp_dict)
    exp_dict["init_dict"]["fp16"] = exp_dict["config"].get("fp16", False)
    config = exp_dict["config"]
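The deferred import added to `get_initialize_config` is a standard pattern: moving an import into the function body postpones loading that dependency until the function is first called, so merely importing the utils module stays lightweight. A minimal illustration, with a stdlib module standing in for the hydra helper:

```python
# Sketch: a deferred (lazy) import. The module is loaded only on the
# first call, not when this file is imported.
import sys

def hue_of(r, g, b):
    # delay import to reduce dependencies (stand-in for the hydra import)
    import colorsys
    return colorsys.rgb_to_hsv(r, g, b)[0]

before = 'colorsys' in sys.modules   # typically False in a fresh interpreter
h = hue_of(1.0, 0.0, 0.0)
after = 'colorsys' in sys.modules    # True: the call triggered the import
print(before, h, after)
```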
