Skip to content

Commit

Permalink
Add chunk-wise processing for conversion of file to PyTorch DataSets (#…
Browse files Browse the repository at this point in the history
…88)

This moves the `multiprocessing` bits from the `Processor` to the `DataSilo`. All dynamic updates of Class attributes in the `Processor` have been removed. 

The per chunk processing lowers the memory footprint as compared to reading an entire file and then converting to PyTorch DataSets.
  • Loading branch information
tanaysoni committed Sep 19, 2019
1 parent 05eab4b commit 5c93e1e
Show file tree
Hide file tree
Showing 29 changed files with 366 additions and 403 deletions.
9 changes: 6 additions & 3 deletions examples/doc_classification.py
Expand Up @@ -44,9 +44,9 @@
processor = TextClassificationProcessor(tokenizer=tokenizer,
max_seq_len=128,
data_dir="../data/germeval18",
labels=label_list,
label_list=label_list,
metric=metric,
source_field="coarse_label"
label_column_name="coarse_label"
)

# 3. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them and calculates a few descriptive statistics of our datasets
Expand All @@ -58,7 +58,10 @@
# a) which consists of a pretrained language model as a basis
language_model = Bert.load(lang_model)
# b) and a prediction head on top that is suited for our task => Text classification
prediction_head = TextClassificationHead(layer_dims=[768, len(processor.tasks["text_classification"]["label_list"])])
prediction_head = TextClassificationHead(layer_dims=[768, len(processor.tasks["text_classification"]["label_list"])],
class_weights=data_silo.calculate_class_weights(task_name="text_classification"))



model = AdaptiveModel(
language_model=language_model,
Expand Down
4 changes: 3 additions & 1 deletion examples/doc_classification_multilabel.py
Expand Up @@ -45,10 +45,12 @@
processor = TextClassificationProcessor(tokenizer=tokenizer,
max_seq_len=128,
data_dir="../data/toxic-comments",
labels=label_list,
label_list=label_list,
label_column_name="label",
metric=metric,
quote_char='"',
multilabel=True,
train_filename="train.tsv",
dev_filename="val.tsv",
test_filename=None,
dev_split=0
Expand Down
1 change: 1 addition & 0 deletions examples/doc_regression.py
Expand Up @@ -40,6 +40,7 @@
processor = RegressionProcessor(tokenizer=tokenizer,
max_seq_len=128,
data_dir="../data/<YOUR-DATASET>",
label_column_name="label"
)

# 3. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them and calculates a few descriptive statistics of our datasets
Expand Down
6 changes: 3 additions & 3 deletions examples/ner.py
Expand Up @@ -37,11 +37,11 @@
)

# 2. Create a DataProcessor that handles all the conversion from raw text into a pytorch Dataset
ner_labels = ["[PAD]", "X", "O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-OTH", "I-OTH"]

processor = NERProcessor(
tokenizer=tokenizer, max_seq_len=128, data_dir="../data/conll03-de"
tokenizer=tokenizer, max_seq_len=128, data_dir="../data/conll03-de", metric="seq_f1",label_list=ner_labels
)
ner_labels = ["[PAD]", "X", "O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-OTH", "I-OTH"]
processor.add_task("ner", "seq_f1", ner_labels)

# 3. Create a DataSilo that loads several datasets (train/dev/test), provides DataLoaders for them and calculates a few descriptive statistics of our datasets
data_silo = DataSilo(processor=processor, batch_size=batch_size)
Expand Down
2 changes: 1 addition & 1 deletion experiments/ner/conll2003_de_config.json
Expand Up @@ -27,7 +27,7 @@
"dev_filename": {"value": null, "default": "dev.txt", "desc": "Filename for development. Missing in case of GermEval2018."},
"test_filename": {"value": null, "default": "test.txt", "desc": "Filename for testing. It is the submission file from competition."},
"delimiter": {"value": null, "default": "\t", "desc": "Delimiter used to seprate columns in input data."},
"labels": {"value": null, "default": ["[PAD]", "X", "O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-OTH", "I-OTH"], "desc": ""},
"label_list": {"value": null, "default": ["[PAD]", "X", "O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-OTH", "I-OTH"], "desc": ""},
"metric": {"value": null, "default": "seq_f1", "desc": "Metric used. A f1 scored tailored to sequences of labels."}
},

Expand Down
2 changes: 1 addition & 1 deletion experiments/ner/germEval14_config.json
Expand Up @@ -27,7 +27,7 @@
"dev_filename": {"value": null, "default": "dev.txt", "desc": "Filename for development. Missing in case of GermEval2018."},
"test_filename": {"value": null, "default": "test.txt", "desc": "Filename for testing. It is the submission file from competition."},
"delimiter": {"value": null, "default": " ", "desc": "Delimiter used to seprate columns in input data."},
"labels": {"value": null, "default": ["[PAD]", "X", "O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-OTH", "I-OTH"], "desc": ""},
"label_list": {"value": null, "default": ["[PAD]", "X", "O", "B-MISC", "I-MISC", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-OTH", "I-OTH"], "desc": ""},
"metric": {"value": null, "default": "seq_f1", "desc": "Metric used. A f1 scored tailored to sequences of labels."}
},

Expand Down
2 changes: 1 addition & 1 deletion experiments/qa/squad20_config.json
Expand Up @@ -26,7 +26,7 @@
"train_filename": {"value": null, "default": "train-v2.0.json", "desc": "Filename for training."},
"dev_filename": {"value": null, "default": "dev-v2.0.json", "desc": "Filename for development. Contains multiple possible answers."},
"test_filename": {"value": null, "default": null, "desc": "Filename for testing. It is the submission file from competition."},
"labels": {"value": null, "default": ["start_token", "end_token"], "desc": ""},
"label_list": {"value": null, "default": ["start_token", "end_token"], "desc": ""},
"metric": {"value": null, "default": "squad", "desc": "Metric used. A f1 scored tailored to sequences of labels."}
},

Expand Down
4 changes: 2 additions & 2 deletions experiments/text_classification/germEval18Coarse_config.json
Expand Up @@ -28,9 +28,9 @@
"test_filename": {"value": null, "default": "test.tsv", "desc": "Filename for testing. It is the submission file from competition."},
"delimiter": {"value": null, "default": "\t", "desc": "Filename for testing. It is the submission file from competition."},
"columns": {"value": null, "default": ["text", "label", "unused"], "desc": "Columns specifying position of text and labels in data files."},
"labels": {"value": null, "default": ["OTHER", "OFFENSE"], "desc": "List of possible labels."},
"label_list": {"value": null, "default": ["OTHER", "OFFENSE"], "desc": "List of possible labels."},
"metric": {"value": null, "default": "f1_macro", "desc": "Metric used. The competition uses macro averaged f1 score."},
"source_field": {"value": null, "default": "coarse_label", "desc":"Name of field that the label comes from in datasource"},
"label_column_name": {"value": null, "default": "coarse_label", "desc":"Name of field that the label comes from in datasource"},
"skiprows": {"value": null, "default": null, "desc":""}
},
"parameter": {
Expand Down
4 changes: 2 additions & 2 deletions experiments/text_classification/germEval18Fine_config.json
Expand Up @@ -28,9 +28,9 @@
"test_filename": {"value": null, "default": "test.tsv", "desc": "Filename for testing. It is the submission file from competition."},
"delimiter": {"value": null, "default": "\t", "desc": "Filename for testing. It is the submission file from competition."},
"columns": {"value": null, "default": ["text", "unused", "label"], "desc": "Columns specifying position of text and labels in data files."},
"labels": {"value": null, "default": ["OTHER", "INSULT", "ABUSE", "PROFANITY"],"desc": "List of possible labels."},
"label_list": {"value": null, "default": ["OTHER", "INSULT", "ABUSE", "PROFANITY"],"desc": "List of possible labels."},
"metric": {"value": null, "default": "f1_macro", "desc": "Metric used. The competition uses macro averaged f1 score."},
"source_field": {"value": null, "default": "fine_label", "desc":"Name of field that the label comes from in datasource"},
"label_column_name": {"value": null, "default": "fine_label", "desc":"Name of field that the label comes from in datasource"},
"skiprows": {"value": null, "default": null, "desc":""}
},

Expand Down
4 changes: 2 additions & 2 deletions experiments/text_classification/gnad_config.json
Expand Up @@ -27,9 +27,9 @@
"dev_filename": {"value": null, "default": null, "desc": "Filename for development. Missing in case of GermEval2018."},
"test_filename": {"value": null, "default": "test.csv", "desc": "Filename for testing. It is the submission file from competition."},
"delimiter": {"value": null, "default": ";", "desc": "Filename for testing. It is the submission file from competition."},
"labels": {"value": null, "default": ["Web","Sport","International","Panorama","Wissenschaft","Wirtschaft","Kultur","Etat","Inland"], "desc": "List of possible labels."},
"label_list": {"value": null, "default": ["Web","Sport","International","Panorama","Wissenschaft","Wirtschaft","Kultur","Etat","Inland"], "desc": "List of possible labels."},
"metric": {"value": null, "default": "acc", "desc": "Metric used. Multiclass accuracy - metric used in fast.ai forum results for this dataset."},
"source_field": {"value": null, "default": "label", "desc": "Name of field that the label comes from in datasource"},
"label_column_name": {"value": null, "default": "label", "desc": "Name of field that the label comes from in datasource"},
"skiprows": {"value": null, "default": null, "desc": ""}
},

Expand Down
8 changes: 7 additions & 1 deletion farm/__init__.py
@@ -1,5 +1,11 @@
import logging
import torch

logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO)
level=logging.INFO,
)

# fix for a file descriptor issue when using multiprocessing: https://github.com/pytorch/pytorch/issues/973
torch.multiprocessing.set_sharing_strategy("file_system")
112 changes: 93 additions & 19 deletions farm/data_handler/data_silo.py
@@ -1,28 +1,36 @@
import copy
import logging

import multiprocessing as mp
import os
import copy
from contextlib import ExitStack
from functools import partial
import random

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.utils.data import ConcatDataset, DataLoader, random_split, Subset, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler
from tqdm import tqdm

from farm.data_handler.dataloader import NamedDataLoader
from farm.utils import MLFlowLogger as MlLogger
from farm.data_handler.processor import Processor
from farm.data_handler.utils import grouper
from farm.utils import MLFlowLogger as MlLogger
from farm.utils import log_ascii_workers
from farm.visual.ascii.images import TRACTOR_SMALL

logger = logging.getLogger(__name__)


class DataSilo(object):
class DataSilo:
""" Generates and stores PyTorch DataLoader objects for the train, dev and test datasets.
Relies upon functionality in the processor to do the conversion of the data. Will also
calculate and display some statistics.
"""

def __init__(self, processor, batch_size, distributed=False):
def __init__(self, processor, batch_size, distributed=False, multiprocessing_chunk_size=100):
"""
:param processor: A dataset specific Processor object which will turn input (file or dict) into a Pytorch Dataset.
:type processor: Processor
Expand All @@ -37,15 +45,56 @@ def __init__(self, processor, batch_size, distributed=False):
self.data = {}
self.batch_size = batch_size
self.class_weights = None
self.multiprocessing_chunk_size = multiprocessing_chunk_size
self.max_processes = 128
self._load_data()

@classmethod
def _multiproc(cls, chunk, processor):
dicts = [d[1] for d in chunk]
index = chunk[0][0]
dataset = processor.dataset_from_dicts(dicts=dicts,index=index)
return dataset

def _get_dataset(self, filename):
dicts = self.processor._file_to_dicts(filename)
#shuffle list of dicts here if we later want to have a random dev set splitted from train set
if filename == self.processor.train_filename:
if not self.processor.dev_filename:
if self.processor.dev_split > 0.0:
dicts = random.shuffle(dicts)

dict_batches_to_process = int(len(dicts) / self.multiprocessing_chunk_size)
num_cpus = min(mp.cpu_count(), self.max_processes, dict_batches_to_process) or 1

with ExitStack() as stack:
p = stack.enter_context(mp.Pool(processes=num_cpus))

logger.info(
f"Got ya {num_cpus} parallel workers to convert dict chunks to datasets (chunksize = {self.multiprocessing_chunk_size})..."
)
log_ascii_workers(num_cpus, logger)

results = p.imap(
partial(self._multiproc, processor=self.processor),
grouper(dicts, self.multiprocessing_chunk_size),
chunksize=1,
)

datasets = []
for dataset, tensor_names in tqdm(results, total=len(dicts)/self.multiprocessing_chunk_size):
datasets.append(dataset)

concat_datasets = ConcatDataset(datasets)
return concat_datasets, tensor_names

def _load_data(self):
logger.info("\nLoading data into the data silo ..."
"{}".format(TRACTOR_SMALL))
# train data
train_file = os.path.join(self.processor.data_dir, self.processor.train_filename)
logger.info("Loading train set from: {} ".format(train_file))
self.data["train"], self.tensor_names = self.processor.dataset_from_file(train_file)
self.data["train"], self.tensor_names = self._get_dataset(train_file)

# dev data
if not self.processor.dev_filename:
Expand All @@ -58,13 +107,13 @@ def _load_data(self):
else:
dev_file = os.path.join(self.processor.data_dir, self.processor.dev_filename)
logger.info("Loading dev set from: {}".format(dev_file))
self.data["dev"], _ = self.processor.dataset_from_file(dev_file)
self.data["dev"], _ = self._get_dataset(dev_file)

# test data
if self.processor.test_filename:
test_file = os.path.join(self.processor.data_dir, self.processor.test_filename)
logger.info("Loading test set from: {}".format(test_file))
self.data["test"], _ = self.processor.dataset_from_file(test_file)
self.data["test"], _ = self._get_dataset(test_file)
else:
logger.info("No test set is being loaded")
self.data["test"] = None
Expand All @@ -73,7 +122,6 @@ def _load_data(self):
self._calculate_statistics()
#self.calculate_class_weights()
self._initialize_data_loaders()
# fmt: on

def _initialize_data_loaders(self):
if self.distributed:
Expand Down Expand Up @@ -120,14 +168,38 @@ def _create_dev_from_train(self):
n_train = len(self.data["train"]) - n_dev

# Todo: Seed
train_dataset, dev_dataset = random_split(self.data["train"], [n_train, n_dev])
# if(isinstance(self.data["train"], Dataset)):
# train_dataset, dev_dataset = random_split(self.data["train"], [n_train, n_dev])
# else:
train_dataset, dev_dataset = self.random_split_ConcatDataset(self.data["train"], lengths=[n_train, n_dev])
self.data["train"] = train_dataset
self.data["dev"] = dev_dataset
if(len(dev_dataset) > 0):
self.data["dev"] = dev_dataset
else:
logger.warning("No dev set created. Maybe adjust the dev_split parameter or the multiprocessing chunk size")

logger.info(
f"Took {n_dev} samples out of train set to create dev set (dev split = {self.processor.dev_split})"
f"Took {len(dev_dataset)} samples out of train set to create dev set (dev split is roughly {self.processor.dev_split})"
)

def random_split_ConcatDataset(self, ds, lengths):
"""
Roughly split a Concatdataset into non-overlapping new datasets of given lengths.
Samples inside Concatdataset should already be shuffled
Arguments:
ds (Dataset): Dataset to be split
lengths (sequence): lengths of splits to be produced
"""
if sum(lengths) != len(ds):
raise ValueError("Sum of input lengths does not equal the length of the input dataset!")

idx_dataset = np.where(np.array(ds.cumulative_sizes) > lengths[0])[0][0]

train = ConcatDataset(ds.datasets[:idx_dataset])
test = ConcatDataset(ds.datasets[idx_dataset:])
return train, test

def _calculate_statistics(self,):
self.counts = {
"train": len(self.data["train"])
Expand All @@ -143,12 +215,14 @@ def _calculate_statistics(self,):
else:
self.counts["test"] = 0

train_input_numpy = self.data["train"][:][0].numpy()
seq_lens = np.sum(train_input_numpy != 0, axis=1)
self.ave_len = np.mean(seq_lens)
max_seq_len = self.data["train"][:][0].shape[1]
self.clipped = np.mean(seq_lens == max_seq_len)
seq_lens = []
for dataset in self.data["train"].datasets:
train_input_numpy = dataset[:][0].numpy()
seq_lens.extend(np.sum(train_input_numpy != 0, axis=1))
max_seq_len = dataset[:][0].shape[1]

self.clipped = np.mean(np.array(seq_lens) == max_seq_len)
self.ave_len = np.mean(seq_lens)

logger.info("Examples in train: {}".format(self.counts["train"]))
logger.info("Examples in dev : {}".format(self.counts["dev"]))
Expand Down
2 changes: 1 addition & 1 deletion farm/data_handler/dataset.py
Expand Up @@ -11,7 +11,7 @@ def convert_features_to_dataset(features):
names of the type of feature and the keys are the features themselves.
:Return: a Pytorch dataset and a list of tensor names.
"""
tensor_names = features[0].keys()
tensor_names = list(features[0].keys())
all_tensors = []
for t_name in tensor_names:
try:
Expand Down

0 comments on commit 5c93e1e

Please sign in to comment.