Add method to run inference from a file #107

Merged (11 commits, Oct 11, 2019)
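For orientation, here is a minimal usage sketch of the renamed and new entry points introduced by this PR; the model directory and file path are placeholders, not taken from the repository:

from farm.infer import Inferencer

model = Inferencer.load("path/to/saved_model")  # placeholder directory

# renamed: run_inference(dicts=...) -> inference_from_dicts(dicts=...)
basic_texts = [{"text": "Martin Müller spielt Handball in Berlin"}]
result = model.inference_from_dicts(dicts=basic_texts)

# new: run inference directly on a file; the processor's file_to_dicts() parses it
result_from_file = model.inference_from_file(file="path/to/test.tsv")  # placeholder path
print(result, result_from_file)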
docs/examples.rst (2 changes: 1 addition & 1 deletion)
@@ -79,5 +79,5 @@ b) and a prediction head on top that is suited for our task => Text classification
{"text": "Martin Müller spielt Fussball"},
]
model = Inferencer(save_dir)
result = model.run_inference(dicts=basic_texts)
result = model.inference_from_dicts(dicts=basic_texts)
print(result)
examples/doc_classification.py (2 changes: 1 addition & 1 deletion)
@@ -102,7 +102,7 @@
{"text": "Martin Müller spielt Handball in Berlin"},
]
model = Inferencer.load(save_dir)
result = model.run_inference(dicts=basic_texts)
result = model.inference_from_dicts(dicts=basic_texts)
print(result)

# fmt: on
examples/doc_classification_multilabel.py (2 changes: 1 addition & 1 deletion)
@@ -106,7 +106,7 @@
{"text": "What a lovely world"},
]
model = Inferencer.load(save_dir)
result = model.run_inference(dicts=basic_texts)
result = model.inference_from_dicts(dicts=basic_texts)
print(result)


examples/doc_regression.py (2 changes: 1 addition & 1 deletion)
@@ -94,7 +94,7 @@
{"text": ""},
]
model = Inferencer.load(save_dir)
result = model.run_inference(dicts=basic_texts)
result = model.inference_from_dicts(dicts=basic_texts)

print(result)

examples/ner.py (2 changes: 1 addition & 1 deletion)
@@ -96,5 +96,5 @@
{"text": "Martin Müller spielt Handball in Berlin"},
]
model = Inferencer.load(save_dir)
result = model.run_inference(dicts=basic_texts)
result = model.inference_from_dicts(dicts=basic_texts)
print(result)
examples/question_answering.py (2 changes: 1 addition & 1 deletion)
@@ -104,7 +104,7 @@
}]

model = Inferencer.load(save_dir)
result = model.run_inference(dicts=QA_input)
result = model.inference_from_dicts(dicts=QA_input)

for x in result:
pprint.pprint(x)
farm/data_handler/data_silo.py (2 changes: 1 addition & 1 deletion)
@@ -56,7 +56,7 @@ def _multiproc(cls, chunk, processor):
return dataset

def _get_dataset(self, filename):
dicts = self.processor._file_to_dicts(filename)
dicts = self.processor.file_to_dicts(filename)
#shuffle list of dicts here if we later want to have a random dev set splitted from train set
if filename == self.processor.train_filename:
if not self.processor.dev_filename:
farm/data_handler/processor.py (31 changes: 16 additions & 15 deletions)
@@ -4,6 +4,7 @@
import random
import logging
import json
import time
import inspect
from inspect import signature
import numpy as np
@@ -41,7 +42,7 @@ class Processor(ABC):
"""
Is used to generate PyTorch Datasets from input data. An implementation of this abstract class should be created
for each new data source.
Implement the abstract methods: _file_to_dicts(), _dict_to_samples(), _sample_to_features()
Implement the abstract methods: file_to_dicts(), _dict_to_samples(), _sample_to_features()
to be compatible with your data format
"""

@@ -236,7 +237,7 @@ def add_task(self, name, metric, label_list, label_column_name=None, label_name
}

@abc.abstractmethod
def _file_to_dicts(self, file: str) -> [dict]:
def file_to_dicts(self, file: str) -> [dict]:
raise NotImplementedError()

@abc.abstractmethod
@@ -248,7 +249,7 @@ def _sample_to_features(cls, sample: Sample) -> dict:
raise NotImplementedError()

def _init_baskets_from_file(self, file):
dicts = self._file_to_dicts(file)
dicts = self.file_to_dicts(file)
dataset_name = os.path.splitext(os.path.basename(file))[0]
baskets = [
SampleBasket(raw=tr, id=f"{dataset_name}-{i}") for i, tr in enumerate(dicts)
@@ -406,7 +407,7 @@ def __init__(
label_column_name=label_column_name,
task_type=task_type)

def _file_to_dicts(self, file: str) -> [dict]:
def file_to_dicts(self, file: str) -> [dict]:
column_mapping = {task["label_column_name"]: task["label_name"] for task in self.tasks.values()}
dicts = read_tsv(
filename=file,
@@ -497,7 +498,7 @@ def load_from_dir(cls, load_dir):
return processor


def _file_to_dicts(self, file: str) -> [dict]:
def file_to_dicts(self, file: str) -> [dict]:
raise NotImplementedError

def _dict_to_samples(self, dictionary: dict, **kwargs) -> [Sample]:
@@ -554,7 +555,7 @@ def __init__(
if metric and label_list:
self.add_task("ner", metric, label_list)

def _file_to_dicts(self, file: str) -> [dict]:
def file_to_dicts(self, file: str) -> [dict]:
dicts = read_ner_file(filename=file, sep=self.delimiter)
return dicts

@@ -616,7 +617,7 @@ def __init__(
self.add_task("nextsentence", "acc", ["False", "True"])


def _file_to_dicts(self, file: str) -> list:
def file_to_dicts(self, file: str) -> list:
dicts = read_docs_from_txt(filename=file, delimiter=self.delimiter, max_docs=self.max_docs)
return dicts

@@ -716,10 +717,10 @@ def __init__(
if metric and labels:
self.add_task("question_answering", metric, labels)

def dataset_from_dicts(self, dicts, index=None, from_inference=False):
if(from_inference):
dicts = [self._convert_inference(x) for x in dicts]
if(from_inference):
def dataset_from_dicts(self, dicts, index=None, rest_api_schema=False):
if rest_api_schema:
dicts = [self._convert_rest_api_dict(x) for x in dicts]
if rest_api_schema:
id_prefix = "infer"
else:
id_prefix = "train"
@@ -734,7 +735,7 @@ def dataset_from_dicts(self, dicts, index=None, from_inference=False):
dataset, tensor_names = self._create_dataset()
return dataset, tensor_names

def _convert_inference(self, infer_dict):
def _convert_rest_api_dict(self, infer_dict):
# convert input coming from inferencer to SQuAD format
converted = {}
converted["paragraphs"] = [
@@ -750,13 +751,13 @@ def _convert_inference(self, infer_dict):
]
return converted

def _file_to_dicts(self, file: str) -> [dict]:
def file_to_dicts(self, file: str) -> [dict]:
dict = read_squad_file(filename=file)
return dict

def _dict_to_samples(self, dictionary: dict, **kwargs) -> [Sample]:
if "paragraphs" not in dictionary: # TODO change this inference mode hack
dictionary = self._convert_inference(infer_dict=dictionary)
dictionary = self._convert_rest_api_dict(infer_dict=dictionary)
samples = create_samples_squad(entry=dictionary)
for sample in samples:
tokenized = tokenize_with_metadata(
@@ -822,7 +823,7 @@ def __init__(
self.add_task(name="regression", metric="mse", label_list= [scaler_mean, scaler_scale], label_column_name=label_column_name, task_type="regression", label_name=label_name)


def _file_to_dicts(self, file: str) -> [dict]:
def file_to_dicts(self, file: str) -> [dict]:
column_mapping = {task["label_column_name"]: task["label_name"] for task in self.tasks.values()}
dicts = read_tsv(
rename_columns=column_mapping,
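Since _file_to_dicts is promoted to the public file_to_dicts hook in this file, a rough sketch of a custom Processor implementing the three abstract methods named in the class docstring may help; the file format, the featurization, and the Sample import/constructor are assumptions for illustration, not code from this PR:

from farm.data_handler.processor import Processor
from farm.data_handler.samples import Sample  # assumed import path

class PlainTextProcessor(Processor):
    # toy processor: one text per line, no labels (sketch only)

    def file_to_dicts(self, file: str) -> [dict]:
        # public now, so Inferencer.inference_from_file() can call it
        with open(file, encoding="utf-8") as f:
            return [{"text": line.strip()} for line in f if line.strip()]

    def _dict_to_samples(self, dictionary: dict, **kwargs) -> [Sample]:
        return [Sample(id=None, clear_text=dictionary)]

    def _sample_to_features(self, sample: Sample) -> dict:
        # placeholder featurization; a real processor would tokenize here
        return {"text_len": len(sample.clear_text["text"])}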
farm/infer.py (92 changes: 71 additions & 21 deletions)
@@ -1,15 +1,20 @@
import os
import torch
import logging
import multiprocessing as mp
from contextlib import ExitStack
from functools import partial

from torch.utils.data.sampler import SequentialSampler
from tqdm import tqdm

from farm.data_handler.dataloader import NamedDataLoader
from farm.modeling.adaptive_model import AdaptiveModel

from farm.utils import initialize_device_settings
from farm.data_handler.processor import Processor, InferenceProcessor
from farm.utils import set_all_seeds
from farm.utils import log_ascii_workers


logger = logging.getLogger(__name__)
@@ -29,13 +34,14 @@ class Inferencer:
{"text": "Martin Müller spielt Handball in Berlin"},
]
model = Inferencer.load(your_model_dir)
model.run_inference(dicts=basic_texts)
model.inference_from_dicts(dicts=basic_texts)
# LM embeddings
model.extract_vectors(dicts=basic_texts)

"""

def __init__(self, model, processor, batch_size=4, gpu=False, name=None, return_class_probs=False):
def __init__(self, model, processor, batch_size=4, gpu=False, name=None, return_class_probs=False,
multiprocessing_chunk_size=100):
"""
Initializes inferencer from an AdaptiveModel and a Processor instance.

@@ -75,6 +81,7 @@ def __init__(self, model, processor, batch_size=4, gpu=False, name=None, return_
# raise NotImplementedError("A model with multiple prediction heads is currently not supported by the Inferencer")
self.name = name if name != None else f"anonymous-{self.prediction_type}"
self.return_class_probs = return_class_probs
self.multiprocessing_chunk_size = multiprocessing_chunk_size

model.connect_heads_with_processor(processor.tasks, require_labels=False)
set_all_seeds(42, n_gpu)
@@ -110,27 +117,45 @@ def load(cls, load_dir, batch_size=4, gpu=False, embedder_only=False, return_cla
name = os.path.basename(load_dir)
return cls(model, processor, batch_size=batch_size, gpu=gpu, name=name, return_class_probs=return_class_probs)

def run_inference(self, dicts):
"""
Runs down-stream inference using the prediction head.
def inference_from_file(self, file):
dicts = self.processor.file_to_dicts(file)

:param dicts: Samples to run inference on provided as a list of dicts. One dict per sample.
:type dicst: [dict]
:return: dict of predictions
dict_batches_to_process = int(len(dicts) / self.multiprocessing_chunk_size)
num_cpus = min(mp.cpu_count(), dict_batches_to_process) or 1

"""
if self.prediction_type == "embedder":
raise TypeError(
"You have called run_inference for a model without any prediction head! "
"If you want to: "
"a) ... extract vectors from the language model: call `Inferencer.extract_vectors(...)`"
f"b) ... run inference on a downstream task: make sure your model path {self.name} contains a saved prediction head"
with ExitStack() as stack:
p = stack.enter_context(mp.Pool(processes=num_cpus))

logger.info(
f"Got ya {num_cpus} parallel workers to do inference on {len(dicts)}dicts (chunksize = {self.multiprocessing_chunk_size})..."
)
dataset, tensor_names = self.processor.dataset_from_dicts(dicts, from_inference=True)
log_ascii_workers(num_cpus, logger)

results = p.imap(
partial(self._multiproc_dict_to_samples, processor=self.processor),
dicts,
[Inline review comment from a contributor]: Why don't we need the grouper here anymore?
chunksize=self.multiprocessing_chunk_size,
)

preds_all = []
with tqdm(total=len(dicts), unit=' Dicts') as pbar:
for dataset, tensor_names, sample in results:
preds_all.append(self._run_inference(dataset, tensor_names, sample))
pbar.update(self.multiprocessing_chunk_size)

return preds_all
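Regarding the reviewer's question above about the grouper: a plausible explanation (stated here as an assumption, not verified against the removed code) is that Pool.imap already batches the dispatch of work items via its chunksize argument, so explicit pre-chunking of the dict list is no longer necessary; the worker function is still invoked once per dict, which is why the helper below wraps its single-dict argument in a list. A standalone sketch of the two styles:

import multiprocessing as mp

def convert_one(d):
    # stand-in for the per-dict worker
    return len(d["text"])

def convert_chunk(chunk):
    # old "grouper" style: the worker receives a whole list of dicts per call
    return [len(d["text"]) for d in chunk]

if __name__ == "__main__":
    dicts = [{"text": f"sample {i}"} for i in range(1000)]
    chunk_size = 100
    with mp.Pool(processes=4) as pool:
        # explicit pre-chunking: we build and ship the chunks ourselves
        chunks = [dicts[i:i + chunk_size] for i in range(0, len(dicts), chunk_size)]
        grouped = list(pool.imap(convert_chunk, chunks))
        # chunksize: imap ships items to workers in batches of 100 internally,
        # but convert_one() is still called once per dict
        per_item = list(pool.imap(convert_one, dicts, chunksize=chunk_size))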

@classmethod
def _multiproc_dict_to_samples(cls, dicts, processor):
[Inline review comment from a contributor]: Let's change the name, since dict_to_samples has its own connotation inside FARM; here we convert the input data to samples and build the dataset at the same time. Another way would be to add an inference flag and not delete the samples after the torch dataset is created, so the exact same data would not need to be preprocessed twice. Let's do this only if it is easy to integrate; otherwise let's move forward with this PR.
dicts_list = [dicts]
dataset, tensor_names = processor.dataset_from_dicts(dicts_list)
samples = []
for dict in dicts:
samples.extend(self.processor._dict_to_samples(dict))
for d in dicts_list:
samples.extend(processor._dict_to_samples(d))

return dataset, tensor_names, samples
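The reviewer's second suggestion above (keep the samples around so the same dicts are not preprocessed twice) could look roughly like the generic sketch below; the names and flag are invented for illustration and this is not FARM code:

from dataclasses import dataclass
from typing import List

@dataclass
class ToySample:
    clear_text: dict
    features: dict

def build_dataset(dicts: List[dict], keep_samples: bool = False):
    # preprocess exactly once
    samples = [ToySample(clear_text=d, features={"len": len(d["text"])}) for d in dicts]
    dataset = [s.features for s in samples]  # stand-in for the torch dataset
    tensor_names = ["len"]
    if keep_samples:
        # hand the intermediate samples back so the caller skips a second pass
        return dataset, tensor_names, samples
    return dataset, tensor_names, None

dataset, tensor_names, samples = build_dataset([{"text": "hello"}], keep_samples=True)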

def _run_inference(self, dataset, tensor_names, samples):
data_loader = NamedDataLoader(
dataset=dataset,
sampler=SequentialSampler(dataset),
@@ -141,7 +166,7 @@ def run_inference(self, dicts):
preds_all = []
for i, batch in enumerate(data_loader):
batch = {key: batch[key].to(self.device) for key in batch}
batch_samples = samples[i * self.batch_size : (i + 1) * self.batch_size]
batch_samples = samples[i * self.batch_size: (i + 1) * self.batch_size]
with torch.no_grad():
logits = self.model.forward(**batch)
preds = self.model.formatted_preds(
@@ -155,6 +180,31 @@

return preds_all

def inference_from_dicts(self, dicts):
"""
Runs down-stream inference using the prediction head.

:param dicts: Samples to run inference on provided as a list of dicts. One dict per sample.
:type dicts: [dict]
:return: dict of predictions

"""
if self.prediction_type == "embedder":
raise TypeError(
"You have called inference_from_dicts for a model without any prediction head! "
"If you want to: "
"a) ... extract vectors from the language model: call `Inferencer.extract_vectors(...)`"
f"b) ... run inference on a downstream task: make sure your model path {self.name} contains a saved prediction head"
)
dataset, tensor_names = self.processor.dataset_from_dicts(dicts, from_inference=True)
samples = []
for dict in dicts:
samples.extend(self.processor._dict_to_samples(dict))

preds_all = self._run_inference(dataset, tensor_names, samples)

return preds_all

def extract_vectors(
self, dicts, extraction_strategy="cls_token", extraction_layer=-1
):
@@ -163,8 +213,8 @@ def extract_vectors(

:param dicts: Samples to run inference on provided as a list of dicts. One dict per sample.
:type dicts: [dict]
:param extraction_strategy: Strategy to extract vectors. Choices: 'cls_token' (sentence vector),
'reduce_mean' (sentence vector), reduce_max (sentence vector), 'per_token' (individual token vectors)
:param extraction_strategy: Strategy to extract vectors. Choices: 'cls_token' (sentence vector), 'reduce_mean'
(sentence vector), reduce_max (sentence vector), 'per_token' (individual token vectors)
:type extraction_strategy: str
:param extraction_layer: number of layer from which the embeddings shall be extracted. Default: -1 (very last layer).
:type: int
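For completeness, a short usage sketch of extract_vectors with the strategies listed in the docstring above; the model directory is a placeholder:

from farm.infer import Inferencer

model = Inferencer.load("path/to/saved_model")  # placeholder directory
vectors = model.extract_vectors(
    dicts=[{"text": "Martin Müller spielt Handball in Berlin"}],
    extraction_strategy="reduce_mean",  # sentence vector averaged over tokens
    extraction_layer=-1,                # last layer of the language model
)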
test/test_doc_classification.py (2 changes: 1 addition & 1 deletion)
@@ -79,7 +79,7 @@ def test_doc_classification(caplog):


inf = Inferencer.load(save_dir,batch_size=2)
result = inf.run_inference(dicts=basic_texts)
result = inf.inference_from_dicts(dicts=basic_texts)
assert isinstance(result[0]["predictions"][0]["probability"],np.float32)


test/test_doc_regression.py (2 changes: 1 addition & 1 deletion)
@@ -75,7 +75,7 @@ def test_doc_regression(caplog):
]

model = Inferencer.load(save_dir)
result = model.run_inference(dicts=basic_texts)
result = model.inference_from_dicts(dicts=basic_texts)
assert isinstance(result[0]["predictions"][0]["pred"], np.float32)

if(__name__=="__main__"):
test/test_ner.py (2 changes: 1 addition & 1 deletion)
@@ -75,7 +75,7 @@ def test_ner(caplog):
{"text": "Schartau sagte dem Tagesspiegel, dass Fischer ein Idiot sei"},
]
model = Inferencer.load(save_dir)
result = model.run_inference(dicts=basic_texts)
result = model.inference_from_dicts(dicts=basic_texts)
assert result[0]["predictions"][0]["context"] == "sagte"
assert isinstance(result[0]["predictions"][0]["probability"], np.float32)

test/test_question_answering.py (2 changes: 1 addition & 1 deletion)
@@ -78,7 +78,7 @@ def test_qa(caplog):
]

model = Inferencer.load(save_dir)
result = model.run_inference(dicts=QA_input)
result = model.inference_from_dicts(dicts=QA_input)
assert isinstance(result[0]["predictions"][0]["end"],int)

if(__name__=="__main__"):