
Initial commit of Python collections wrapper and passage retrieval setup (#645)
emmileaf authored and lintool committed May 15, 2019
1 parent 949cd39 commit 284d0191166da8aa3125fa4d9e4831bb75fc4ee4
@@ -0,0 +1,48 @@

## Python Collections Wrapper

### Example usage

```
import sys
sys.path += ['src/main/python/io/anserini']
from collection import pycollection, pygenerator
```

```
# Fetching collection given class and path to directory
collection = pycollection.Collection('TrecCollection', '/path/to/disk45')
# Get file segment in collection
fs = next(collection.segments)
# Get doc in file segment
doc = next(fs)
# Document id
doc.id
# Raw document contents
doc.contents
```

```
# Fetching Lucene document generator given generator class
generator = pygenerator.Generator('JsoupGenerator')
```

### Iterating over a collection and processing documents

```
collection = pycollection.Collection(collection_class, input_path)

for (i, fs) in enumerate(collection.segments):
    for (j, doc) in enumerate(fs):
        # process doc, for example:
        parsed_doc = generator.generator.createDocument(doc.document)
        id = parsed_doc.get('id')              # FIELD_ID
        raw = parsed_doc.get('raw')            # FIELD_RAW
        contents = parsed_doc.get('contents')  # FIELD_BODY
```
@@ -0,0 +1,73 @@
from .pyjnius_utils import JCollections, JPaths
from .threading_utils import Counters

import logging
logger = logging.getLogger(__name__)

class Collection:

    def __init__(self, collection_class, collection_path):
        self.counters = Counters()
        self.collection_class = collection_class
        self.collection_path = JPaths.get(collection_path)
        self.collection = self._get_collection()
        self.collection.setCollectionPath(self.collection_path)
        self.segment_paths = self.collection.getFileSegmentPaths()
        # lazily wrap each Java file segment as it is requested
        self.segments = (FileSegment(self,
                                     self.collection.createFileSegment(path),
                                     path) for path in self.segment_paths.toArray())

    def _get_collection(self):
        try:
            return JCollections[self.collection_class].value()
        except Exception:
            raise ValueError(self.collection_class)


class FileSegment:

    def __init__(self, collection, segment, segment_path):
        self.collection = collection
        self.segment = segment
        self.segment_path = segment_path
        self.segment_name = segment_path.getFileName().toString()

    def __iter__(self):
        return self

    def __next__(self):
        if self.segment.hasNext():
            try:
                d = self.segment.next()
            except Exception:
                logger.error(self.segment_name +
                             ": Error fetching iter.next(), skipping...")
                self.collection.counters.skipped.increment()
                return self.__next__()
            if not d.indexable():
                logger.error(self.segment_name +
                             ": Document not indexable, skipping...")
                self.collection.counters.unindexable.increment()
                return self.__next__()
            return Document(self, d)
        else:
            if self.segment.getNextRecordStatus().toString() == 'ERROR':
                logger.error(self.segment_name +
                             ": EOF - Error from getNextRecordStatus()")
                self.collection.counters.errors.increment()
            self.segment.close()
            raise StopIteration


class Document:

    def __init__(self, segment, document):
        self.segment = segment
        self.document = document
        self.id = document.id()
        self.contents = document.content()

@@ -0,0 +1,19 @@
from .pyjnius_utils import JIndexHelpers, JGenerators

import logging
logger = logging.getLogger(__name__)

class Generator:

    def __init__(self, generator_class):
        self.counters = JIndexHelpers.JCounters()
        self.args = JIndexHelpers.JArgs()
        self.generator_class = generator_class
        self.generator = self._get_generator()

    def _get_generator(self):
        try:
            return JGenerators[self.generator_class].value(self.args, self.counters)
        except Exception:
            raise ValueError(self.generator_class)

@@ -0,0 +1,52 @@
import jnius_config
jnius_config.set_classpath("target/anserini-0.4.1-SNAPSHOT-fatjar.jar")

from jnius import autoclass
from jnius import cast

from enum import Enum


JString = autoclass('java.lang.String')
JPath = autoclass('java.nio.file.Path')
JPaths = autoclass('java.nio.file.Paths')
JList = autoclass('java.util.List')


class JIndexHelpers:

    @staticmethod
    def JArgs():
        args = autoclass('io.anserini.index.IndexCollection$Args')()
        args.storeRawDocs = True  # store raw text as an option
        args.dryRun = True        # so that indexing will be skipped
        return args

    @staticmethod
    def JCounters():
        IndexCollection = autoclass('io.anserini.index.IndexCollection')
        Counters = autoclass('io.anserini.index.IndexCollection$Counters')
        return Counters(IndexCollection)


class JCollections(Enum):
    CarCollection = autoclass('io.anserini.collection.CarCollection')
    ClueWeb09Collection = autoclass('io.anserini.collection.ClueWeb09Collection')
    ClueWeb12Collection = autoclass('io.anserini.collection.ClueWeb12Collection')
    HtmlCollection = autoclass('io.anserini.collection.HtmlCollection')
    JsonCollection = autoclass('io.anserini.collection.JsonCollection')
    NewYorkTimesCollection = autoclass('io.anserini.collection.NewYorkTimesCollection')
    TrecCollection = autoclass('io.anserini.collection.TrecCollection')
    TrecwebCollection = autoclass('io.anserini.collection.TrecwebCollection')
    TweetCollection = autoclass('io.anserini.collection.TweetCollection')
    WashingtonPostCollection = autoclass('io.anserini.collection.WashingtonPostCollection')
    WikipediaCollection = autoclass('io.anserini.collection.WikipediaCollection')


class JGenerators(Enum):
    LuceneDocumentGenerator = autoclass('io.anserini.index.generator.LuceneDocumentGenerator')
    JsoupGenerator = autoclass('io.anserini.index.generator.JsoupGenerator')
    NekoGenerator = autoclass('io.anserini.index.generator.NekoGenerator')
    TweetGenerator = autoclass('io.anserini.index.generator.TweetGenerator')
    WapoGenerator = autoclass('io.anserini.index.generator.WapoGenerator')
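
These enums are what let `pycollection` and `pygenerator` resolve classes from plain string names. A minimal sketch of the lookup, mirroring `_get_collection` above (illustrative only):

```
# Illustrative: resolve and instantiate a collection class by name,
# as pycollection.Collection._get_collection() does.
cls = JCollections['TrecCollection'].value  # the autoclass'd Java class
collection = cls()                          # instantiate through Pyjnius
```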



@@ -0,0 +1,23 @@
import threading


class ThreadSafeCount:

    def __init__(self):
        self.value = 0
        self.lock = threading.Lock()

    def increment(self, inc=1):
        with self.lock:
            self.value += inc
            return self.value


class Counters:

    def __init__(self):
        self.indexable = ThreadSafeCount()
        self.unindexable = ThreadSafeCount()
        self.skipped = ThreadSafeCount()
        self.errors = ThreadSafeCount()
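
A quick usage sketch (not part of the commit): because `increment` holds the lock while updating, concurrent increments from worker threads cannot be lost.

```
# Illustration only: four threads each increment the counter 1000 times.
counters = Counters()

def work():
    for _ in range(1000):
        counters.indexable.increment()

workers = [threading.Thread(target=work) for _ in range(4)]
for w in workers:
    w.start()
for w in workers:
    w.join()
assert counters.indexable.value == 4000  # exact, thanks to the lock
```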

@@ -0,0 +1,40 @@
### Segmenting collection

`segment.py` can be called from the command line with the following arguments:
- `--input {path to input directory containing collection}`
- `--collection {collection class}`
- `--generator {generator class}`
- `--output {path to create output collection directory}`
- `--threads {max number of threads}`
- `--tokenize {tokenizing function to call}`
- `--raw` to use raw text instead of the transformed body contents

An example run with sentence segmentation on Robust04 can be found [here](example/example.md); a hypothetical invocation is sketched below.
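
The script location and the `split_sentences` tokenizer name below are illustrative; substitute whatever function `document_tokenizer.py` actually defines:

```
python segment.py \
  --input /path/to/disk45 \
  --collection TrecCollection \
  --generator JsoupGenerator \
  --output /path/to/segmented \
  --threads 4 \
  --tokenize split_sentences
```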

### Python-Java Bridging

`collection_iterator.py` replicates `IndexCollection` logic for iterating over collections and generating Lucene documents with:
- id (FIELD_ID)
- parsed contents (FIELD_BODY)
- raw contents (FIELD_RAW)

`collection` contains Python wrapper code for accessing Anserini's Java collection and generator classes with Pyjnius.

### Document Tokenizing

Instead of performing the indexing steps in `IndexCollection`, a tokenizer function defined in `document_tokenizer.py`
can be called on either the parsed or raw document content to split it into segments, which are written out as JSON arrays that form a JsonCollection:

```
[
  {
    "id": "{$DOCID}.000000",
    "content": "{segment-content}"
  },
  {
    "id": "{$DOCID}.000001",
    "content": "{segment-content}"
  }
]
```
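
`document_tokenizer.py` itself is not shown in this diff. Below is a minimal sketch of a compatible tokenizer, assuming the interface implied by `collection_iterator.py` further down: `DocumentTokenizer(name).tokenizer` is a callable taking `(id, contents)` and returning a list of segment records. The period-based splitting is a stand-in for a real sentence segmenter.

```
# Illustrative sketch only; not the actual document_tokenizer.py.
class DocumentTokenizer:

    def __init__(self, tokenizer_name):
        try:
            # look up the requested tokenizing function by name
            self.tokenizer = getattr(self, tokenizer_name)
        except AttributeError:
            raise ValueError(tokenizer_name)

    def split_sentences(self, id, contents):
        # naive period-based splitting, standing in for a real segmenter
        sentences = [s.strip() for s in contents.split('.') if s.strip()]
        return [{'id': '{}.{:06d}'.format(id, i), 'content': s}
                for (i, s) in enumerate(sentences)]
```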

@@ -0,0 +1,108 @@
import os
import json
import time
import datetime
from concurrent.futures import ThreadPoolExecutor
from document_tokenizer import DocumentTokenizer

import sys
sys.path += ['src/main/python/io/anserini']
from collection import pycollection, pygenerator

import logging
logger = logging.getLogger(__name__)


def IterSegment(fs, generator, output_path, tokenizer, raw):

    results = []
    doc_count = 0

    for (i, d) in enumerate(fs):
        # generate Lucene document, then fetch fields
        try:
            doc = generator.generator.createDocument(d.document)
            if doc is None:
                logger.warning("Generator did not return document, skipping...")
                fs.collection.counters.skipped.increment()
                continue
            id = doc.get('id')
            contents = doc.get('raw') if raw else doc.get('contents')
            doc = {'id': id, 'contents': contents}
        except Exception:
            logger.error("Error generating Lucene document, skipping...")
            fs.collection.counters.skipped.increment()
            continue

        # append resulting json to list of documents
        if tokenizer is None:
            results.append(doc)
        else:
            # split document into segments
            try:
                array = tokenizer(id, contents)
                results += array  # merge lists
            except Exception:
                fs.collection.counters.skipped.increment()
                logger.error("Error tokenizing document, skipping...")
                continue

        doc_count += 1

    # count number of full documents parsed
    fs.collection.counters.indexable.increment(doc_count)
    logger.info(fs.segment_name + ": " + str(doc_count) + " documents parsed")

    count = len(results)
    if count > 0:
        # write json array to output dir (either as an array of docs,
        # or many arrays of doc tokens)
        with open(os.path.join(output_path, '{}.json'.format(fs.segment_name)), 'w') as f:
            jsonstr = json.dumps(results, separators=(',', ':'), indent=2)
            f.write(jsonstr)

        logger.info("Finished iterating over segment: " +
                    fs.segment_name + " with " +
                    str(count) + " results.")
    else:
        logger.info("No documents parsed from segment: " + fs.segment_name)


def IterCollection(input_path, collection_class,
                   generator_class, output_path,
                   threads=1, tokenize=None, raw=False):

    start = time.time()
    logger.info("Begin reading collection.")

    # check and create tokenizer
    tokenizer = None
    if tokenize is not None:
        try:
            tokenizer = DocumentTokenizer(tokenize).tokenizer
        except Exception:
            raise ValueError(tokenize)

    collection = pycollection.Collection(collection_class, input_path)
    generator = pygenerator.Generator(generator_class)

    if not os.path.exists(output_path):
        logger.info("making directory...")
        os.mkdir(output_path)

    # each file segment is handled by a worker thread; the context manager
    # blocks until all submitted segments have finished
    with ThreadPoolExecutor(max_workers=threads) as executor:
        for (seg_num, fs) in enumerate(collection.segments):
            executor.submit(IterSegment, fs, generator, output_path, tokenizer, raw)

    end = time.time()
    elapsed = end - start

    print("all threads complete")
    logger.info("# Final Counter Values")
    logger.info("indexable: {:12d}".format(collection.counters.indexable.value))
    logger.info("unindexable: {:12d}".format(collection.counters.unindexable.value))
    logger.info("skipped: {:12d}".format(collection.counters.skipped.value))
    logger.info("errors: {:12d}".format(collection.counters.errors.value))

    logger.info("Total duration: %s", str(datetime.timedelta(seconds=elapsed)))
