# SPLADE for Portuguese

Inspired by https://github.com/naver/splade

In [1]:
!pip install git+https://github.com/naver/splade.git -q

  Preparing metadata (setup.py) ... [?25l[?25hdone
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
hydra-core 1.3.2 requires antlr4-python3-runtime==4.9.*, but you have antlr4-python3-runtime 4.8 which is incompatible.
hydra-core 1.3.2 requires omegaconf<2.4,>=2.2, but you have omegaconf 2.1.2 which is incompatible.[0m[31m
[0m

In [2]:
!pip install hydra-core --upgrade

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting omegaconf<2.4,>=2.2 (from hydra-core)
  Using cached omegaconf-2.3.0-py3-none-any.whl (79 kB)
Collecting antlr4-python3-runtime==4.9.* (from hydra-core)
  Using cached antlr4_python3_runtime-4.9.3-py3-none-any.whl
Installing collected packages: antlr4-python3-runtime, omegaconf
  Attempting uninstall: antlr4-python3-runtime
    Found existing installation: antlr4-python3-runtime 4.8
    Uninstalling antlr4-python3-runtime-4.8:
      Successfully uninstalled antlr4-python3-runtime-4.8
  Attempting uninstall: omegaconf
    Found existing installation: omegaconf 2.1.2
    Uninstalling omegaconf-2.1.2:
      Successfully uninstalled omegaconf-2.1.2
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
splade 2.1 requires omegaconf==2.1.2, but you have omegac

In [6]:
!pip install pytrec_eval

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pytrec_eval
  Downloading pytrec_eval-0.5.tar.gz (15 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pytrec_eval
  Building wheel for pytrec_eval (setup.py) ... [?25l[?25hdone
  Created wheel for pytrec_eval: filename=pytrec_eval-0.5-cp310-cp310-linux_x86_64.whl size=293453 sha256=819e564284a5d34f3ca801637bfebb93bea3b6722ed8006abad94967b5705a64
  Stored in directory: /root/.cache/pip/wheels/51/3a/cd/dcc1ddfc763987d5cb237165d8ac249aa98a23ab90f67317a8
Successfully built pytrec_eval
Installing collected packages: pytrec_eval
Successfully installed pytrec_eval-0.5


In [4]:
import os

CONFIG_NAME = None
CONFIG_PATH = "../conf"

##############################################################
# Provide (as env var), either:
# * 'SPLADE_CONFIG_NAME', this config in splade/conf will be used
# * or 'SPLADE_CONFIG_FULLPATH' (full path, from an exp, such as '/my/path/to/exp/config.yaml'

# if nothing is provided, 'config_default' is used
##############################################################

assert sum([v in os.environ.keys() for v in ["SPLADE_CONFIG_NAME", "SPLADE_CONFIG_FULLPATH"]]) <= 1

if "SPLADE_CONFIG_NAME" in os.environ.keys():
    CONFIG_NAME = os.environ["SPLADE_CONFIG_NAME"]
elif "SPLADE_CONFIG_FULLPATH" in os.environ.keys():
    CONFIG_FULLPATH = os.environ["SPLADE_CONFIG_FULLPATH"]
    CONFIG_PATH, CONFIG_NAME = os.path.split(CONFIG_FULLPATH)
else:
    CONFIG_NAME = "config_default"

if ".yaml" in CONFIG_NAME:
    CONFIG_NAME = CONFIG_NAME.split(".yaml")[0]


In [7]:
import os

import hydra
import torch
from omegaconf import DictConfig, open_dict
from torch.utils import data

from splade.datasets.dataloaders import CollectionDataLoader, SiamesePairsDataLoader, DistilSiamesePairsDataLoader
from splade.datasets.datasets import PairsDatasetPreLoad, DistilPairsDatasetPreLoad, MsMarcoHardNegatives, \
    CollectionDatasetPreLoad
from splade.losses.regularization import init_regularizer, RegWeightScheduler
from splade.models.models_utils import get_model
from splade.optim.bert_optim import init_simple_bert_optim
from splade.tasks.transformer_evaluator import SparseApproxEvalWrapper
from splade.tasks.transformer_trainer import SiameseTransformerTrainer
from splade.utils.utils import set_seed, restore_model, get_initialize_config, get_loss, set_seed_from_config

In [8]:
@hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)
def train(exp_dict: DictConfig):
    exp_dict, config, init_dict, _ = get_initialize_config(exp_dict, train=True)
    model = get_model(config, init_dict)
    random_seed = set_seed_from_config(config)

    optimizer, scheduler = init_simple_bert_optim(model, lr=config["lr"], warmup_steps=config["warmup_steps"],
                                                  weight_decay=config["weight_decay"],
                                                  num_training_steps=config["nb_iterations"])

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    ################################################################
    # CHECK IF RESUME TRAINING
    ################################################################
    iterations = (1, config["nb_iterations"] + 1)  # tuple with START and END
    regularizer = None
    if os.path.exists(os.path.join(config["checkpoint_dir"], "model_ckpt/model_last.tar")):
        print("@@@@ RESUMING TRAINING @@@")
        print("WARNING: change seed to change data order when restoring !")
        set_seed(random_seed + 666)
        if device == torch.device("cuda"):
            ckpt = torch.load(os.path.join(config["checkpoint_dir"], "model_ckpt/model_last.tar"))
        else:
            ckpt = torch.load(os.path.join(config["checkpoint_dir"], "model_ckpt/model_last.tar"), map_location=device)
        print("starting from step", ckpt["step"])
        print("{} remaining iterations".format(iterations[1] - ckpt["step"]))
        iterations = (ckpt["step"] + 1, config["nb_iterations"])
        restore_model(model, ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        if device == torch.device("cuda"):
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = v.cuda()
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        if "regularizer" in ckpt:
            print("loading regularizer")
            regularizer = ckpt.get("regularizer", None)

    if torch.cuda.device_count() > 1:
        print(" --- use {} GPUs --- ".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
    model.to(device)

    loss = get_loss(config)

    # initialize regularizer dict
    if "regularizer" in config and regularizer is None:  # else regularizer is loaded
        output_dim = model.module.output_dim if hasattr(model, "module") else model.output_dim
        regularizer = {"eval": {"L0": {"loss": init_regularizer("L0")},
                                "sparsity_ratio": {"loss": init_regularizer("sparsity_ratio",
                                                                            output_dim=output_dim)}},
                       "train": {}}
        if config["regularizer"] == "eval_only":
            # just in the case we train a model without reg but still want the eval metrics like L0
            pass
        else:
            for reg in config["regularizer"]:
                temp = {"loss": init_regularizer(config["regularizer"][reg]["reg"]),
                        "targeted_rep": config["regularizer"][reg]["targeted_rep"]}
                d_ = {}
                if "lambda_q" in config["regularizer"][reg]:
                    d_["lambda_q"] = RegWeightScheduler(config["regularizer"][reg]["lambda_q"],
                                                        config["regularizer"][reg]["T"])
                if "lambda_d" in config["regularizer"][reg]:
                    d_["lambda_d"] = RegWeightScheduler(config["regularizer"][reg]["lambda_d"],
                                                        config["regularizer"][reg]["T"])
                temp["lambdas"] = d_  # it is possible to have reg only on q or d if e.g. you only specify lambda_q
                # in the reg config
                # targeted_rep is just used to indicate which rep to constrain (if e.g. the model outputs several
                # representations)
                # the common case: model outputs "rep" (in forward) and this should be the value for this targeted_rep
                regularizer["train"][reg] = temp

    # fix for current in batch neg losses that break on last batch
    if config["loss"] in ("InBatchNegHingeLoss", "InBatchPairwiseNLL"):
        drop_last = True
    else:
        drop_last = False

    if exp_dict["data"].get("type", "") == "triplets":
        data_train = PairsDatasetPreLoad(data_dir=exp_dict["data"]["TRAIN_DATA_DIR"])
        train_mode = "triplets"
    elif exp_dict["data"].get("type", "") == "triplets_with_distil":
        data_train = DistilPairsDatasetPreLoad(data_dir=exp_dict["data"]["TRAIN_DATA_DIR"])
        train_mode = "triplets_with_distil"
    elif exp_dict["data"].get("type", "") == "hard_negatives":
        data_train = MsMarcoHardNegatives(
            dataset_path=exp_dict["data"]["TRAIN"]["DATASET_PATH"],
            document_dir=exp_dict["data"]["TRAIN"]["D_COLLECTION_PATH"],
            query_dir=exp_dict["data"]["TRAIN"]["Q_COLLECTION_PATH"],
            qrels_path=exp_dict["data"]["TRAIN"]["QREL_PATH"])
        train_mode = "triplets_with_distil"
    else:
        raise ValueError("provide valid data type for training")

    val_loss_loader = None  # default
    if "VALIDATION_SIZE_FOR_LOSS" in exp_dict["data"]:
        print("initialize loader for validation loss")
        print("split train, originally {} pairs".format(len(data_train)))
        data_train, data_val = torch.utils.data.random_split(data_train, lengths=[
            len(data_train) - exp_dict["data"]["VALIDATION_SIZE_FOR_LOSS"],
            exp_dict["data"]["VALIDATION_SIZE_FOR_LOSS"]])
        print("train: {} pairs ~~ val: {} pairs".format(len(data_train), len(data_val)))
        if train_mode == "triplets":
            val_loss_loader = SiamesePairsDataLoader(dataset=data_val, batch_size=config["eval_batch_size"],
                                                     shuffle=False,
                                                     num_workers=4,
                                                     tokenizer_type=config["tokenizer_type"],
                                                     max_length=config["max_length"], drop_last=drop_last)
        elif train_mode == "triplets_with_distil":
            val_loss_loader = DistilSiamesePairsDataLoader(dataset=data_val, batch_size=config["eval_batch_size"],
                                                           shuffle=False,
                                                           num_workers=4,
                                                           tokenizer_type=config["tokenizer_type"],
                                                           max_length=config["max_length"], drop_last=drop_last)
        else:
            raise NotImplementedError

    if train_mode == "triplets":
        train_loader = SiamesePairsDataLoader(dataset=data_train, batch_size=config["train_batch_size"], shuffle=True,
                                              num_workers=4,
                                              tokenizer_type=config["tokenizer_type"],
                                              max_length=config["max_length"], drop_last=drop_last)
    elif train_mode == "triplets_with_distil":
        train_loader = DistilSiamesePairsDataLoader(dataset=data_train, batch_size=config["train_batch_size"],
                                                    shuffle=True,
                                                    num_workers=4,
                                                    tokenizer_type=config["tokenizer_type"],
                                                    max_length=config["max_length"], drop_last=drop_last)
    else:
        raise NotImplementedError

    val_evaluator = None
    if "VALIDATION_FULL_RANKING" in exp_dict["data"]:
        with open_dict(config):
            config["val_full_rank_qrel_path"] = exp_dict["data"]["VALIDATION_FULL_RANKING"]["QREL_PATH"]
        full_ranking_d_collection = CollectionDatasetPreLoad(
            data_dir=exp_dict["data"]["VALIDATION_FULL_RANKING"]["D_COLLECTION_PATH"], id_style="row_id")
        full_ranking_d_loader = CollectionDataLoader(dataset=full_ranking_d_collection,
                                                     tokenizer_type=config["tokenizer_type"],
                                                     max_length=config["max_length"],
                                                     batch_size=config["eval_batch_size"],
                                                     shuffle=False, num_workers=4)
        full_ranking_q_collection = CollectionDatasetPreLoad(
            data_dir=exp_dict["data"]["VALIDATION_FULL_RANKING"]["Q_COLLECTION_PATH"], id_style="row_id")
        full_ranking_q_loader = CollectionDataLoader(dataset=full_ranking_q_collection,
                                                     tokenizer_type=config["tokenizer_type"],
                                                     max_length=config["max_length"], batch_size=1,
                                                     # TODO fix: bs currently set to 1
                                                     shuffle=False, num_workers=4)
        val_evaluator = SparseApproxEvalWrapper(model,
                                                config={"top_k": exp_dict["data"]["VALIDATION_FULL_RANKING"]["TOP_K"],
                                                        "out_dir": os.path.join(config["checkpoint_dir"],
                                                                                "val_full_ranking")
                                                        },
                                                collection_loader=full_ranking_d_loader,
                                                q_loader=full_ranking_q_loader,
                                                restore=False)

    # #################################################################
    # # TRAIN
    # #################################################################
    print("+++++ BEGIN TRAINING +++++")
    trainer = SiameseTransformerTrainer(model=model, iterations=iterations, loss=loss, optimizer=optimizer,
                                        config=config, scheduler=scheduler,
                                        train_loader=train_loader, validation_loss_loader=val_loss_loader,
                                        validation_evaluator=val_evaluator,
                                        regularizer=regularizer)
    trainer.train()


The version_base parameter is not specified.
Please specify a compatability version level, or None.
Will assume defaults for version 1.1
  @hydra.main(config_path=CONFIG_PATH, config_name=CONFIG_NAME)


In [9]:

train()

usage: ipykernel_launcher.py [--help] [--hydra-help] [--version]
                             [--cfg {job,hydra,all}] [--resolve]
                             [--package PACKAGE] [--run] [--multirun]
                             [--shell-completion] [--config-path CONFIG_PATH]
                             [--config-name CONFIG_NAME]
                             [--config-dir CONFIG_DIR]
                             [--experimental-rerun EXPERIMENTAL_RERUN]
                             [--info [{all,config,defaults,defaults-tree,plugins,searchpath}]]
                             [overrides ...]
ipykernel_launcher.py: error: unrecognized arguments: -f


SystemExit: ignored

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
