# Referential integrity checker (prototype)

## Prerequisites

Before running this notebook, make sure you have done the following:

1. Run `$ make up-dev`
2. Map `localhost:27018` to the Mongo server you want to use
3. Load a recent dump of the production Mongo database into that Mongo server (see `$ make mongorestore-nmdc-db` for an example)
4. In the `.env` file, set `MONGO_HOST` to `mongodb://localhost:27018`
5. Run `$ export $(grep -v '^#' .env | xargs)` to load the environment variables defined in `.env` into your shell environment
6. Run `$ make init` to ensure a consistent Python kernel for this notebook.

Once you've done all of those things, you can run this notebook (e.g. via `$ jupyter notebook`).


In [1]:
!echo $MONGO_HOST

mongodb://localhost:27018


## Enable automatic reloading of modules

Reference: https://ipython.readthedocs.io/en/stable/config/extensions/autoreload.html#autoreload

In [2]:
# Ensure code changes in this notebook will be import-able  
# without needing to restart the kernel and lose state
%load_ext autoreload
%autoreload 2

## Import Python modules

Be sure you're using the version of `nmdc-schema` you think you are!

In [3]:
from importlib.metadata import version

version("nmdc-schema")

'11.0.0rc22'

In [4]:
from collections import defaultdict
import concurrent.futures
from itertools import chain
import os
import re

from linkml_runtime.utils.schemaview import SchemaView
from pymongo import InsertOne
from toolz import dissoc, assoc
from tqdm.notebook import tqdm

from nmdc_runtime.api.core.util import pick
from nmdc_runtime.api.db.mongo import get_mongo_db, get_nonempty_nmdc_schema_collection_names, get_collection_names_from_schema
from nmdc_runtime.util import collection_name_to_class_names, populated_schema_collection_names_with_id_field, nmdc_schema_view, nmdc_database_collection_instance_class_names, get_nmdc_jsonschema_dict
from nmdc_schema.nmdc import Database as NMDCDatabase 
from nmdc_schema.get_nmdc_view import ViewGetter

mdb = get_mongo_db()
schema_view = nmdc_schema_view()

## Check for errors in the database

The `get_nonempty_nmdc_schema_collection_names` function returns the populated (i.e., having at least one document) subset of the intersection of (a) the set of collection names present in the Mongo database and (b) the set of `Database` slots in the schema that correspond to collections (i.e., slots that are multivalued and whose values are inlined as a list).
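The set logic can be sketched in plain Python, assuming the per-collection document counts have already been fetched from Mongo. The helper `nonempty_schema_collection_names` and the toy inputs below are hypothetical illustrations, not the `nmdc_runtime` implementation.

```python
def nonempty_schema_collection_names(db_doc_counts, schema_collection_slots):
    """Hypothetical stand-in: names present both in the database and among the
    schema's collection slots, keeping only those with at least one document."""
    shared = set(db_doc_counts) & set(schema_collection_slots)
    return {name for name in shared if db_doc_counts[name] > 0}


# Toy inputs (made up for illustration): Mongo collection name -> document count,
# and the schema's multivalued, inlined-as-list `Database` slots.
db_doc_counts = {"study_set": 3, "biosample_set": 0, "alldocs": 10, "data_object_set": 7}
schema_collection_slots = {"study_set", "biosample_set", "data_object_set", "instrument_set"}

print(sorted(nonempty_schema_collection_names(db_doc_counts, schema_collection_slots)))
# ['data_object_set', 'study_set']
```

Here `alldocs` is excluded because it is not a schema slot, `biosample_set` because it is empty, and `instrument_set` because it is absent from the database.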

In [5]:
collection_names = get_nonempty_nmdc_schema_collection_names(mdb)
print(collection_names)

{'study_set', 'workflow_execution_set', 'material_processing_set', 'instrument_set', 'data_object_set', 'configuration_set', 'biosample_set', 'functional_annotation_agg', 'calibration_set', 'processed_sample_set', 'field_research_site_set', 'data_generation_set'}


Collect all possible classes of documents across all schema collections. `collection_name_to_class_names` is a mapping from collection name to a list of class names allowable for that collection's documents.

In [6]:
document_class_names = set(chain.from_iterable(collection_name_to_class_names.values()))

Map each document-class name to a map of slot name to slot definition. Class slots here are (to quote the LinkML SchemaView documentation) "all slots that are asserted or inferred for [the] class, with their inferred semantics."
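To illustrate what "inferred semantics" means here, the sketch below computes induced slot ranges for a toy hierarchy: a class inherits its ancestors' slots, and `slot_usage` on a more specific class can narrow a slot's range. This is a deliberately simplified assumption-laden model (toy class names and ranges, no mixins or other LinkML features), not LinkML's actual algorithm; in the notebook, `schema_view.class_induced_slots` does the real work.

```python
# Toy schema (hypothetical): each class has an optional parent (is_a),
# the slots it asserts, and slot_usage overrides of slot ranges.
SLOT_RANGES = {"id": "string", "has_input": "NamedThing", "has_output": "NamedThing"}
CLASSES = {
    "NamedThing": {"is_a": None, "slots": ["id"], "slot_usage": {}},
    "PlannedProcess": {"is_a": "NamedThing", "slots": ["has_input", "has_output"], "slot_usage": {}},
    "DataGeneration": {"is_a": "PlannedProcess", "slots": [], "slot_usage": {"has_input": "Sample"}},
}


def induced_slot_ranges(cls_name):
    """Simplified induced slots: inherited plus asserted slots, with the most
    specific class's slot_usage winning. Not LinkML's full algorithm."""
    lineage = []
    while cls_name is not None:
        lineage.append(cls_name)
        cls_name = CLASSES[cls_name]["is_a"]
    ranges = {}
    # Walk from the root down so that more specific classes override.
    for name in reversed(lineage):
        for slot in CLASSES[name]["slots"]:
            ranges[slot] = SLOT_RANGES[slot]
        ranges.update(CLASSES[name]["slot_usage"])
    return ranges


print(induced_slot_ranges("DataGeneration"))
# {'id': 'string', 'has_input': 'Sample', 'has_output': 'NamedThing'}
```

This is why the later checks look up a slot's range per class (`cls_slot_map[cls_name][field].range`) rather than once per slot.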

In [7]:
cls_slot_map = {
    cls_name : {slot.name: slot
                for slot in schema_view.class_induced_slots(cls_name)
               }
    for cls_name in document_class_names
}

In [8]:
def collect_errors(note_doc_field_errors):
    errors = {"bad_type": [], "no_type": [], "bad_slot": [], "is_null": []}
    n_docs_total = sum(mdb[coll_name].estimated_document_count() for coll_name in collection_names)
    pbar = tqdm(total=n_docs_total)
    n_errors_cache = 0
    for coll_name in sorted(collection_names):
        cls_names = collection_name_to_class_names[coll_name]
        pbar.set_description(f"processing {coll_name}...")
        # Iterate over each document (as a dictionary) in this collection.
        for doc in mdb[coll_name].find():
            doc = dissoc(doc, "_id")
            
            # Ensure we know the document's type.
            cls_name = None
            cls_type_match = re.match(r"^nmdc:(?P<name>.+)", doc.get("type", ""))
            if cls_type_match is not None:
                cls_name = cls_type_match.group("name")
                if cls_name not in cls_names:
                    # Use doc.get('id'): some documents (e.g. in functional_annotation_agg) have no `id`.
                    errors["bad_type"].append(f"{coll_name} doc {doc.get('id')}: doc type {cls_name} not in those allowed for {coll_name}, i.e. {cls_names}.")
                    cls_name = None
            elif len(cls_names) == 1:
                cls_name = cls_names[0]
            else:
                errors["no_type"].append(f"{coll_name} doc {doc.get('id')}: 'type' not set.")

            if cls_name is not None:
                slot_map = cls_slot_map[cls_name]
                # Iterate over each key/value pair in the dictionary (document).
                for field, value in doc.items():
                    if field in slot_map:
                        if not isinstance(value, list):
                            value = [value]
                        for v in value:
                            note_doc_field_errors(value=v, field=field, doc=doc, coll_name=coll_name, errors=errors)
                    else:
                        errors["bad_slot"].append(f"{coll_name} doc {doc.get('id')}: field '{field}' not a valid slot")
            pbar.update(1)
            n_errors = sum([len(v) for v in errors.values()])
            if n_errors > n_errors_cache:
                print(f"{n_errors} errors so far...")
                n_errors_cache = n_errors
    pbar.close()
    return errors

In [9]:
def note_doc_field_errors(value=None, field=None, doc=None, coll_name=None, errors=None):
    # No fields should be null-valued.
    # Example of how this may happen: JSON serialization from pydantic models may set optional fields to `null`.
    if value is None:
        errors["is_null"].append(f"{coll_name} doc {doc.get('id')}: field {field} is null.")

In [10]:
errors = collect_errors(note_doc_field_errors)
print(errors)

  0%|          | 0/2351449 [00:00<?, ?it/s]

{'bad_type': [], 'no_type': [], 'bad_slot': [], 'is_null': []}
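The type-resolution step inside `collect_errors` can be pulled out into a small function and exercised without a database. The sketch below mirrors that logic under the assumption that types are written as `nmdc:`-prefixed class names; `resolve_doc_class` is a hypothetical helper, not part of `nmdc_runtime`.

```python
import re

TYPE_PATTERN = re.compile(r"^nmdc:(?P<name>.+)")


def resolve_doc_class(doc, allowed_class_names):
    """Hypothetical mirror of the type checks in `collect_errors`.

    Returns (class_name, error_kind), where error_kind is None, "bad_type" or "no_type".
    """
    match = TYPE_PATTERN.match(doc.get("type", ""))
    if match is not None:
        cls_name = match.group("name")
        if cls_name in allowed_class_names:
            return cls_name, None
        return None, "bad_type"
    if len(allowed_class_names) == 1:
        # Unambiguous: the collection admits exactly one class.
        return allowed_class_names[0], None
    return None, "no_type"


print(resolve_doc_class({"id": "x", "type": "nmdc:Study"}, ["Study"]))      # ('Study', None)
print(resolve_doc_class({"id": "x", "type": "nmdc:Biosample"}, ["Study"]))  # (None, 'bad_type')
print(resolve_doc_class({"id": "x"}, ["Study"]))                            # ('Study', None)
print(resolve_doc_class({"id": "x"}, ["A", "B"]))                           # (None, 'no_type')
```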


## Materialize single-collection view of database

The `alldocs` collection associates each database document's `id` not only with its class (via that document's `type` field) but also with all ancestors of the document's class.

The set-of-classes association is done by setting the `type` field in an `alldocs` document to a list, which allows filtering by type using the same structured query forms as for upstream schema collections. The first element of the `type` list *must* correspond to the source document's asserted class; this is so that validation code can determine the expected range of document slots, as slot ranges may be specialized by a class (via LinkML "slot_usage").
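A toy version of the `type` list construction follows. The parent map is a hypothetical, hand-written subset of the hierarchy (it matches the `alldocs` document shown near the end of this notebook); the notebook itself obtains the list from `schema_view.class_ancestors`. MongoDB matches an equality query such as `{"type": "PlannedProcess"}` against an array field when any element equals the value, which `matches_type` imitates here.

```python
# Toy is_a hierarchy (hypothetical subset): child -> parent.
PARENT = {
    "MassSpectrometry": "DataGeneration",
    "DataGeneration": "PlannedProcess",
    "PlannedProcess": "NamedThing",
    "NamedThing": None,
}


def type_list(cls_name):
    """Asserted class first, then its ancestors from nearest to root."""
    lineage = []
    while cls_name is not None:
        lineage.append(cls_name)
        cls_name = PARENT[cls_name]
    return lineage


def matches_type(alldocs_doc, cls_name):
    """Imitates a Mongo equality filter {"type": cls_name} on an array field."""
    return cls_name in alldocs_doc["type"]


doc = {"id": "nmdc:omprc-11-sxze4w22", "type": type_list("MassSpectrometry")}
print(doc["type"])                          # ['MassSpectrometry', 'DataGeneration', 'PlannedProcess', 'NamedThing']
print(doc["type"][0])                       # MassSpectrometry (the asserted class)
print(matches_type(doc, "PlannedProcess"))  # True
```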

To keep the `alldocs` collection focused on supporting referential-integrity checking, only the `id` and the document-reference-ranged slots of each source document are copied to that document's `alldocs` materialization (alongside the `type` list described above).
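The per-document projection can be sketched with a plain dictionary comprehension, assuming it behaves like the `pick` helper used below (keeping only listed keys that are present). `project_for_alldocs` is a hypothetical name, and the `name` field in the toy source document is made up to show that non-reference slots are dropped.

```python
def project_for_alldocs(doc, reference_slots, ancestors):
    """Hypothetical helper: keep `id` plus document-reference-ranged slots,
    then set `type` to the asserted class followed by its ancestors."""
    keep = ["id"] + reference_slots
    new_doc = {k: doc[k] for k in keep if k in doc}  # absent slots are skipped
    new_doc["type"] = ancestors
    return new_doc


source = {
    "id": "nmdc:omprc-11-sxze4w22",
    "type": "nmdc:MassSpectrometry",
    "name": "toy name, not copied",
    "has_input": ["nmdc:bsm-11-978cs285"],
    "has_output": ["nmdc:dobj-11-1epz0d53"],
}
projected = project_for_alldocs(
    source,
    reference_slots=["has_input", "has_output", "instrument_used"],
    ancestors=["MassSpectrometry", "DataGeneration", "PlannedProcess", "NamedThing"],
)
print(sorted(projected))  # ['has_input', 'has_output', 'id', 'type']
```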

In [11]:
# Any ancestor of a document class is a document-referenceable range, i.e., a valid range of a document-reference-ranged slot.
document_referenceable_ranges = set(chain.from_iterable(schema_view.class_ancestors(cls_name) for cls_name in document_class_names))

document_reference_ranged_slots = defaultdict(list)
for cls_name, slot_map in cls_slot_map.items():
    for slot_name, slot in slot_map.items():
        if str(slot.range) in document_referenceable_ranges:
            document_reference_ranged_slots[cls_name].append(slot_name)

In [12]:
def doc_cls(doc, coll_name=None):
    """Return the unprefixed name of a document's class.

    Try to get it from doc['type'] (removing the 'nmdc:' prefix).
    Else, if the class can be inferred unambiguously from coll_name, use that.
    Else, return None.
    """
    if 'type' in doc:
        return doc['type'].removeprefix("nmdc:")  # str.removeprefix requires Python 3.9+
    elif coll_name and len(collection_name_to_class_names[coll_name]) == 1:
        return collection_name_to_class_names[coll_name][0]
    return None

In [13]:
# Drop any existing `alldocs` collection (e.g. from previous use of this notebook).
mdb.alldocs.drop()

# Set up progress bar
n_docs_total = sum(mdb[name].estimated_document_count() for name in collection_names)
pbar = tqdm(total=n_docs_total)

for coll_name in collection_names:
    pbar.set_description(f"processing {coll_name}...")
    requests = []
    for doc in mdb[coll_name].find():
        doc_type = doc_cls(doc, coll_name=coll_name)
        slots_to_include = ["id"] + document_reference_ranged_slots[doc_type]
        new_doc = pick(slots_to_include, doc)
        new_doc["type"] = schema_view.class_ancestors(doc_type)
        requests.append(InsertOne(new_doc))
        if len(requests) == 1000: # ensure bulk-write batches aren't too huge
            result = mdb.alldocs.bulk_write(requests, ordered=False)
            pbar.update(result.inserted_count)
            requests.clear()
    if len(requests) > 0:
        result = mdb.alldocs.bulk_write(requests, ordered=False)
        pbar.update(result.inserted_count)
pbar.close()

# Prior to re-ID-ing, some IDs were not unique across Mongo collections (e.g. nmdc:0078a0f981ad3f92693c2bc3b6470791), which would make the unique index below fail to build.

# Ensure unique id index for `alldocs` collection.
# The index is sparse because e.g. nmdc:FunctionalAnnotationAggMember documents don't have an "id".
mdb.alldocs.create_index("id", unique=True, sparse=True)

print("refreshed `alldocs` collection")

  0%|          | 0/2351449 [00:00<?, ?it/s]

refreshed `alldocs` collection


The resulting `alldocs` collection contains a copy of every document from every Mongo collection identified earlier. Each copy contains a subset of the original document's key-value pairs, and its `type` field contains a list of the names of its own class and all of its ancestor classes (whereas the original document's `type` field is either unset or contains only its own class).

## Validate

### Check referential integrity

In this cell, we populate two lists:

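- `errors["not_found"]`: references to IDs that do not exist in `alldocs` (a "naive" check that ignores types)
- `errors["invalid_type"]`: hierarchy-aware type errors (the referenced document was found, but neither its class nor any of its ancestor classes matches the range of the referencing slot)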
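The classification rule can be sketched without Mongo, assuming the `id` to `type`-list mapping has already been fetched from `alldocs`. `classify_references` is a hypothetical helper and the `nmdc:*-toy-*` IDs are made up; the notebook's own `doc_field_value_errors` applies the same rule per batch.

```python
def classify_references(assertions, id_to_types):
    """Split reference assertions into not_found / invalid_type.

    `assertions` are dicts with "value" (the referenced id) and "slot_range";
    `id_to_types` maps each existing id to its alldocs `type` list.
    """
    errors = {"not_found": [], "invalid_type": []}
    for a in assertions:
        types = id_to_types.get(a["value"])
        if types is None:
            errors["not_found"].append(a)
        elif a["slot_range"] not in types:
            # Found, but neither its class nor any ancestor is the expected range.
            errors["invalid_type"].append(a)
    return errors


id_to_types = {
    "nmdc:dgen-toy-1": ["MassSpectrometry", "DataGeneration", "PlannedProcess", "NamedThing"],
}
assertions = [
    {"field": "was_generated_by", "value": "nmdc:dgen-toy-1", "slot_range": "WorkflowExecution"},
    {"field": "has_input", "value": "nmdc:dobj-toy-missing", "slot_range": "NamedThing"},
    {"field": "has_input", "value": "nmdc:dgen-toy-1", "slot_range": "NamedThing"},
]
result = classify_references(assertions, id_to_types)
print({k: len(v) for k, v in result.items()})  # {'not_found': 1, 'invalid_type': 1}
```

The third assertion passes because `NamedThing` is an ancestor of the referenced document's class, which is exactly what the ancestor list in `alldocs` makes cheap to check.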

Reference: https://linkml.io/linkml/developers/schemaview.html#linkml_runtime.utils.schemaview.SchemaView.class_induced_slots

In [14]:
def doc_assertions(limit=0):
    """Yield assertions in batches of up to 1000, so that each batch can be checked with a single query."""
    # Initialize progress bar.
    pbar = tqdm(total=(mdb.alldocs.estimated_document_count() if limit == 0 else limit))
    rv = []
    for doc in mdb.alldocs.find(limit=limit):
        # Iterate over each key/value pair in the dictionary (document).
        for field, value in doc.items():
            if field in ("_id", "id", "type"):
                continue
            slot_range = str(cls_slot_map[doc["type"][0]][field].range) # assumes upstream doc type is listed first.
            if not isinstance(value, list):
                value = [value]
            for v in value:
                rv.append({
                    "id": doc.get("id", doc["_id"]),
                    "id_is_nmdc_id": "id" in doc,
                    "field": field,
                    "value": v,
                    "slot_range": slot_range,
                })
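                if len(rv) == 1000:
                    yield rv
                    # Start a new list rather than clearing this one: the consumer
                    # may still be processing the yielded batch in another thread.
                    rv = []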
        pbar.update(1)
    yield rv
    pbar.close()
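The batching pattern in `doc_assertions` can be isolated as a generic generator. The sketch below (a hypothetical `batched` helper) yields a *new* list for each batch; calling `.clear()` on a list that has already been yielded would mutate data that a worker thread submitted via `executor.submit` may still be reading. Python 3.12+ also provides `itertools.batched`, which yields tuples.

```python
def batched(items, size):
    """Yield lists of up to `size` items, each one a new list object."""
    batch = []
    for item in items:
        batch.append(item)
        if len(batch) == size:
            yield batch
            batch = []  # rebind instead of batch.clear()
    if batch:
        yield batch


batches = list(batched(range(5), 2))
print(batches)  # [[0, 1], [2, 3], [4]]
```

With `batch.clear()` in place of the rebinding, every yielded batch would be the same object, so `list(...)` would show three copies of the last batch's contents.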

In [15]:
from pprint import pprint

def doc_field_value_errors(assertions):
    errors = {"not_found": [], "invalid_type": []}
    assertions_by_referenced_id_value = defaultdict(list)
    for a in assertions:
        assertions_by_referenced_id_value[a["value"]].append(a)
    doc_id_types = {}
    for d in list(mdb.alldocs.find({"id": {"$in": list(assertions_by_referenced_id_value.keys())}}, {"_id": 0, "id": 1, "type": 1})):
        doc_id_types[d["id"]] = d["type"]

    for id_value, id_value_assertions in assertions_by_referenced_id_value.items():
        if id_value not in doc_id_types:
            errors["not_found"].extend(id_value_assertions)
        else:
            for a in id_value_assertions:
                if a["slot_range"] not in doc_id_types[a["value"]]:
                    errors["invalid_type"].append(a)

    return errors


# Initialize "global" error lists.
errors = {"not_found": [], "invalid_type": []}

# Use a with statement to ensure threads are cleaned up promptly
with concurrent.futures.ThreadPoolExecutor(max_workers=None) as executor:
    future_to_errors = {executor.submit(doc_field_value_errors, das): das for das in doc_assertions()}
    for future in concurrent.futures.as_completed(future_to_errors):
        doc_asserts = future_to_errors[future]
        try:
            data = future.result()
        except Exception as exc:
            print("exception:", str(exc))
        else:
            errors["not_found"].extend(data["not_found"])
            errors["invalid_type"].extend(data["invalid_type"])

  0%|          | 0/3039449 [00:00<?, ?it/s]

## Results

Display the number of errors in each list.

In [16]:
len(errors["not_found"]), len(errors["invalid_type"])
# results prior to re-id-ing: (4857, 23503)
# results prior to v10.5.5: (33, 20488)
# results with v10.5.5: (33, 6900)

(5, 45604)

Display a few errors from one of the lists, as an example.

In [17]:
{e["value"] for e in errors["not_found"]}

{'nmdc:dobj-11-cvcxxr53', 'nmdc:dobj-11-fg28a080', 'nmdc:dobj-11-gxgpbv06'}

In [18]:
errors["not_found"][:5]

[{'id': 'nmdc:wfmgan-11-w1d6gy98.1',
  'id_is_nmdc_id': True,
  'field': 'has_input',
  'value': 'nmdc:dobj-11-cvcxxr53',
  'slot_range': 'NamedThing'},
 {'id': 'nmdc:wfmgan-11-fmymf551.1',
  'id_is_nmdc_id': True,
  'field': 'has_input',
  'value': 'nmdc:dobj-11-fg28a080',
  'slot_range': 'NamedThing'},
 {'id': 'nmdc:wfmgan-11-3nkefn97.1',
  'id_is_nmdc_id': True,
  'field': 'has_input',
  'value': 'nmdc:dobj-11-gxgpbv06',
  'slot_range': 'NamedThing'},
 {'id': 'nmdc:wfmgan-11-fmymf551.1',
  'id_is_nmdc_id': True,
  'field': 'has_input',
  'value': 'nmdc:dobj-11-fg28a080',
  'slot_range': 'NamedThing'},
 {'id': 'nmdc:wfmgan-11-3nkefn97.1',
  'id_is_nmdc_id': True,
  'field': 'has_input',
  'value': 'nmdc:dobj-11-gxgpbv06',
  'slot_range': 'NamedThing'}]

Display one example `invalid_type` error for each expected type (slot range) that was not satisfied:

In [19]:
slot_range_examples = {}
for e in errors["invalid_type"]:
    slot_range_examples[e["slot_range"]] = e

for ex in slot_range_examples.values():
    print(ex)

{'id': 'nmdc:dobj-11-xt088e26', 'id_is_nmdc_id': True, 'field': 'was_generated_by', 'value': 'nmdc:omprc-11-ymxzx274', 'slot_range': 'WorkflowExecution'}
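Beyond one example per slot range, a count per `(field, slot_range)` pair can show which slots account for most of the `invalid_type` errors. The sketch below uses `collections.Counter` on made-up toy errors; `summarize_by_field_and_range` is a hypothetical helper, and in this notebook it could be called on `errors["invalid_type"]`.

```python
from collections import Counter


def summarize_by_field_and_range(errors):
    """Count errors per (field, slot_range) pair, most common first."""
    return Counter((e["field"], e["slot_range"]) for e in errors).most_common()


toy_invalid = [
    {"id": "nmdc:dobj-toy-1", "field": "was_generated_by", "value": "nmdc:dgen-toy-1", "slot_range": "WorkflowExecution"},
    {"id": "nmdc:dobj-toy-2", "field": "was_generated_by", "value": "nmdc:dgen-toy-2", "slot_range": "WorkflowExecution"},
    {"id": "nmdc:procsm-toy-1", "field": "has_input", "value": "nmdc:procsm-toy-0", "slot_range": "Biosample"},
]
print(summarize_by_field_and_range(toy_invalid))
# [(('was_generated_by', 'WorkflowExecution'), 2), (('has_input', 'Biosample'), 1)]
```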


Spot check one of those errors.

In [20]:
# DataGeneration (e.g. MassSpectrometry) is not a subclass of WorkflowExecution
mdb.alldocs.find_one({"id": "nmdc:omprc-11-sxze4w22"})

{'_id': ObjectId('66edad78007ef07eb670a09d'),
 'id': 'nmdc:omprc-11-sxze4w22',
 'has_input': ['nmdc:bsm-11-978cs285'],
 'has_output': ['nmdc:dobj-11-1epz0d53'],
 'associated_studies': ['nmdc:sty-11-28tm5d36'],
 'instrument_used': ['nmdc:inst-14-mwrrj632'],
 'type': ['MassSpectrometry',
  'DataGeneration',
  'PlannedProcess',
  'NamedThing']}

In [21]:
# ProcessedSample is not a subclass of Biosample
mdb.alldocs.find_one({"id": "nmdc:procsm-11-v5sykd35"})

{'_id': ObjectId('66edad78007ef07eb67078c8'),
 'id': 'nmdc:procsm-11-v5sykd35',
 'type': ['ProcessedSample', 'MaterialEntity', 'NamedThing']}