Commit: #161 update query size

fr.branchaud-charron committed Dec 1, 2021
1 parent 9f19313 · commit a1afb8b
Showing 17 changed files with 1,293 additions and 943 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -96,7 +96,7 @@ model = ModelWrapper(model, your_criterion)
active_loop = ActiveLearningLoop(dataset,
get_probabilities=model.predict_on_dataset,
heuristic=heuristics.BALD(shuffle_prop=0.1),
ndata_to_label=NDATA_TO_LABEL)
query_size=NDATA_TO_LABEL)
for al_step in range(N_ALSTEP):
model.train_on_dataset(dataset, optimizer, BATCH_SIZE, use_cuda=use_cuda)
if not active_loop.step():
17 changes: 13 additions & 4 deletions baal/active/active_loop.py
@@ -1,14 +1,17 @@
import os
import pickle
import types
import warnings
from typing import Callable

import numpy as np
import structlog
import torch.utils.data as torchdata

from . import heuristics
from .dataset import ActiveLearningDataset

log = structlog.get_logger(__name__)
pjoin = os.path.join


@@ -20,7 +23,7 @@ class ActiveLearningLoop:
get_probabilities (Function): Dataset -> **kwargs ->
ndarray [n_samples, n_outputs, n_iterations].
heuristic (Heuristic): Heuristic from baal.active.heuristics.
ndata_to_label (int): Number of samples to label per step.
query_size (int): Number of samples to label per step.
max_sample (int): Limit the number of samples used (-1 is no limit).
uncertainty_folder (Optional[str]): If provided, will store uncertainties on disk.
**kwargs: Parameters forwarded to `get_probabilities`.
@@ -31,12 +34,18 @@ def __init__(
dataset: ActiveLearningDataset,
get_probabilities: Callable,
heuristic: heuristics.AbstractHeuristic = heuristics.Random(),
ndata_to_label: int = 1,
query_size: int = 1,
max_sample=-1,
uncertainty_folder=None,
ndata_to_label=None,
**kwargs,
) -> None:
self.ndata_to_label = ndata_to_label
if ndata_to_label is not None:
warnings.warn(
"`ndata_to_label` is deprecated, please use `query_size`.", DeprecationWarning
)
query_size = ndata_to_label
self.query_size = query_size
self.get_probabilities = get_probabilities
self.heuristic = heuristic
self.dataset = dataset
@@ -88,6 +97,6 @@ def step(self, pool=None) -> bool:
open(pjoin(self.uncertainty_folder, uncertainty_name), "wb"),
)
if len(to_label) > 0:
self.dataset.label(to_label[: self.ndata_to_label])
self.dataset.label(to_label[: self.query_size])
return True
return False
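
As context for the shim added above, a minimal sketch (not part of the commit) of how the deprecated keyword now behaves: passing ndata_to_label still works, but emits a DeprecationWarning and is remapped to query_size. Here my_dataset and predict_fn are hypothetical stand-ins for a real ActiveLearningDataset and prediction function.

    import warnings

    from baal.active import heuristics
    from baal.active.active_loop import ActiveLearningLoop

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        loop = ActiveLearningLoop(
            dataset=my_dataset,            # hypothetical ActiveLearningDataset
            get_probabilities=predict_fn,  # hypothetical predict function
            heuristic=heuristics.Random(),
            ndata_to_label=100,            # deprecated spelling
        )

    assert loop.query_size == 100  # old keyword is mapped onto the new one
    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
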
15 changes: 11 additions & 4 deletions baal/utils/pytorch_lightning.py
@@ -1,5 +1,6 @@
import sys
import types
import warnings
from collections.abc import Sequence
from typing import Dict, Any, Optional

@@ -19,6 +20,12 @@

log = structlog.get_logger("PL testing")

warnings.warn(
"baal.utils.pytorch_lightning is deprecated. BaaL is now integrated into Lightning Flash!"
" Please see experiments/pytorch_lightning/lightning_flash_example.py for a new tutorial!",
DeprecationWarning,
)


class BaaLDataModule(LightningDataModule):
def __init__(self, active_dataset: ActiveLearningDataset, batch_size=1, **kwargs):
@@ -99,7 +106,7 @@ class BaalTrainer(Trainer):
Args:
dataset (ActiveLearningDataset): Dataset with some sample already labelled.
heuristic (Heuristic): Heuristic from baal.active.heuristics.
ndata_to_label (int): Number of samples to label per step.
query_size (int): Number of samples to label per step.
max_sample (int): Limit the number of samples used (-1 is no limit).
**kwargs: Parameters forwarded to `get_probabilities`
and to pytorch_lightning Trainer.__init__
@@ -109,12 +116,12 @@ def __init__(
self,
dataset: ActiveLearningDataset,
heuristic: heuristics.AbstractHeuristic = heuristics.Random(),
ndata_to_label: int = 1,
query_size: int = 1,
**kwargs
) -> None:

super().__init__(**kwargs)
self.ndata_to_label = ndata_to_label
self.query_size = query_size
self.heuristic = heuristic
self.dataset = dataset
self.kwargs = kwargs
@@ -194,6 +201,6 @@ def step(self, model=None, datamodule: Optional[BaaLDataModule] = None) -> bool:
if probs is not None and (isinstance(probs, types.GeneratorType) or len(probs) > 0):
to_label = self.heuristic(probs)
if len(to_label) > 0:
self.dataset.label(to_label[: self.ndata_to_label])
self.dataset.label(to_label[: self.query_size])
return True
return False
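
Unlike ActiveLearningLoop, BaalTrainer gets no compatibility shim in this diff: ndata_to_label is replaced outright, so existing callers must rename the keyword. A hedged usage sketch, where active_set and heuristic are placeholders for real objects:

    # Importing this module now also emits the DeprecationWarning added above.
    from baal.utils.pytorch_lightning import BaalTrainer

    trainer = BaalTrainer(
        dataset=active_set,   # placeholder ActiveLearningDataset
        heuristic=heuristic,  # placeholder baal heuristic
        query_size=100,       # formerly ndata_to_label=100
        max_epochs=10,        # forwarded to pytorch_lightning.Trainer
    )
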
2 changes: 1 addition & 1 deletion experiments/mlp_mcdropout.py
@@ -44,7 +44,7 @@
dataset=al_dataset,
get_probabilities=wrapper.predict_on_dataset,
heuristic=bald,
ndata_to_label=100, # We will label 100 examples per step.
query_size=100, # We will label 100 examples per step.
# KWARGS for predict_on_dataset
iterations=20, # 20 sampling for MC-Dropout
batch_size=32,
2 changes: 1 addition & 1 deletion experiments/mlp_regression_mcdropout.py
@@ -84,7 +84,7 @@ def __getitem__(self, item):
dataset=al_dataset,
get_probabilities=wrapper.predict_on_dataset,
heuristic=variance,
ndata_to_label=250, # We will label 250 examples per step.
query_size=250, # We will label 250 examples per step.
# KWARGS for predict_on_dataset
iterations=20, # 20 sampling for MC-Dropout
batch_size=16,
4 changes: 2 additions & 2 deletions experiments/nlp_bert_mcdropout.py
@@ -30,7 +30,7 @@ def parse_args():
parser.add_argument("--batch_size", default=32, type=int)
parser.add_argument("--initial_pool", default=1000, type=int)
parser.add_argument("--model", default="bert-base-uncased", type=str)
parser.add_argument("--n_data_to_label", default=100, type=int)
parser.add_argument("--query_size", default=100, type=int)
parser.add_argument("--heuristic", default="bald", type=str)
parser.add_argument("--iterations", default=20, type=int)
parser.add_argument("--shuffle_prop", default=0.05, type=float)
@@ -112,7 +112,7 @@ def main():
active_set,
model.predict_on_dataset,
heuristic,
hyperparams.get("n_data_to_label", 1),
hyperparams.get("query_size", 1),
iterations=hyperparams["iterations"],
)

(file name not shown)
@@ -153,7 +153,7 @@ def main():
max_epochs=args.training_duration,
logger=logger,
heuristic=heuristic,
ndata_to_label=args.query_size,
query_size=args.query_size,
)

AL_STEPS = 100
4 changes: 2 additions & 2 deletions experiments/segmentation/unet_mcdropout_pascal.py
@@ -39,7 +39,7 @@ def parse_args():
parser.add_argument("--al_step", default=200, type=int)
parser.add_argument("--batch_size", default=8, type=int)
parser.add_argument("--initial_pool", default=40, type=int)
parser.add_argument("--n_data_to_label", default=20, type=int)
parser.add_argument("--query_size", default=20, type=int)
parser.add_argument("--lr", default=0.001)
parser.add_argument("--heuristic", default="random", type=str)
parser.add_argument("--reduce", default="sum", type=str)
@@ -134,7 +134,7 @@ def main():
active_set,
model.predict_on_dataset_generator,
heuristic=heuristic,
ndata_to_label=hyperparams["n_data_to_label"],
query_size=hyperparams["query_size"],
# Instead of predicting on the entire pool, only a subset is used
max_sample=1000,
batch_size=batch_size,
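
For readers wondering what max_sample=1000 buys here: per the docstring above, it caps how much of the unlabelled pool is scored each step. An illustrative sketch of that idea (not the library's exact code):

    import numpy as np

    def subsample_pool(n_pool: int, max_sample: int) -> np.ndarray:
        """Indices of pool items to score; at most max_sample of them."""
        if max_sample < 0 or max_sample >= n_pool:
            return np.arange(n_pool)  # -1 means "no limit"
        return np.random.choice(n_pool, size=max_sample, replace=False)
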
2 changes: 1 addition & 1 deletion experiments/ssl_experiments/pimodel_mcdropout_cifar10.py
@@ -100,7 +100,7 @@ def add_model_specific_args(parent_parser):
callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))],
dataset=active_set,
heuristic=heuristic,
ndata_to_label=params.query_size,
query_size=params.query_size,
)

AL_STEPS = 2000
8 changes: 4 additions & 4 deletions experiments/vgg_mcdropout_cifar10.py
@@ -12,8 +12,8 @@
from torchvision.transforms import transforms
from tqdm import tqdm

from baal import get_heuristic, ActiveLearningDataset
from baal import ActiveLearningLoop
from baal.active import get_heuristic, ActiveLearningDataset
from baal.active.active_loop import ActiveLearningLoop
from baal.bayesian.dropout import patch_module
from baal import ModelWrapper

@@ -27,7 +27,7 @@ def parse_args():
parser.add_argument("--epoch", default=100, type=int)
parser.add_argument("--batch_size", default=32, type=int)
parser.add_argument("--initial_pool", default=1000, type=int)
parser.add_argument("--n_data_to_label", default=100, type=int)
parser.add_argument("--query_size", default=100, type=int)
parser.add_argument("--lr", default=0.001)
parser.add_argument("--heuristic", default="bald", type=str)
parser.add_argument("--iterations", default=20, type=int)
@@ -107,7 +107,7 @@ def main():
active_set,
model.predict_on_dataset,
heuristic,
hyperparams.get("n_data_to_label", 1),
hyperparams.get("query_size", 1),
batch_size=10,
iterations=hyperparams["iterations"],
use_cuda=use_cuda,
2 changes: 1 addition & 1 deletion notebooks/compatibility/pytorch_lightning.ipynb
@@ -246,7 +246,7 @@
"source": [
"trainer = BaalTrainer(dataset=data_module.active_dataset,\n",
" heuristic=heuristic,\n",
" ndata_to_label=hparams.query_size,\n",
" query_size=hparams.query_size,\n",
" max_epochs=10, default_root_dir=hparams.data_root,\n",
" gpus=hparams.gpus,\n",
" callbacks=[ResetCallback(copy.deepcopy(model.state_dict()))])"
4 changes: 2 additions & 2 deletions notebooks/compatibility/sklearn_tutorial.ipynb
@@ -261,9 +261,9 @@
" print(\"Test Accuracy :: \", accuracy_score(test_y, predictions))\n",
" probs = predict(dataset.pool, clf)\n",
" to_label = heuristic(probs)\n",
" ndata_to_label = 10\n",
" query_size = 10\n",
" if len(to_label) > 0:\n",
" dataset.label(to_label[: ndata_to_label])\n",
" dataset.label(to_label[: query_size])\n",
" else:\n",
" break"
],
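
The notebook cell above does the query step by hand. A self-contained sketch of that pattern (framework-free; predictive entropy stands in for the notebook's heuristic, and all names are illustrative):

    import numpy as np

    def query(probs: np.ndarray, query_size: int = 10) -> np.ndarray:
        """Return indices of the query_size most uncertain pool items,
        ranked by predictive entropy; these are what dataset.label receives."""
        entropy = -(probs * np.log(probs + 1e-12)).sum(axis=1)
        return np.argsort(-entropy)[:query_size]

    probs = np.array([[0.5, 0.5], [0.9, 0.1], [0.6, 0.4]])  # [n_pool, n_classes]
    print(query(probs, query_size=2))  # -> [0 2], the two most uncertain rows
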
(remaining changed files not shown)
