In [4]:
from pymongo import MongoClient
from pprint import pprint
from termcolor import colored
import json
from bson import BSON
import functools as f

In [8]:
username = 'admin'
password = 'adminadmin'
host = 'localhost'
port = 27017
auth_source = 'admin'

uri = f"mongodb://{username}:{password}@{host}:{port}/?authSource={auth_source}"

In [22]:
client = MongoClient()


In [23]:
client.list_database_names()
db = client['epilepsy-data']

In [24]:
for x in db['events'].find():
    pprint(x)

{'_id': ObjectId('666bf6ae39b9736541209513'),
 'crisisType': 'FWIA',
 'description': 'Verbose',
 'eventID': 'YYYY',
 'lobe': 'temporal',
 'offset': '2018-01-10T06:16:30Z',
 'onset': '2018-01-10T06:14:00Z',
 'recordID': 'XXXX'}


In [None]:
''' 
collection: collection to query
field: data field to be queried
value: value of field to find
fetch: which level to fetch (record, session, patient)'''
def query_event_field(collection, field, value, fetch):
    # For now, only finds stuff in events and fetches records.
    data = db['events']
    query_res = data.find({field:value})
    query_records = [x['recordID'] for x in query_res]
    query_records = [db['records'].find({'recordID':recordID}) for recordID in query_records]
    pprint(query_records)

    if fetch=='record':
        return query_records
    

In [None]:
for x in query_event_field('events', 'crisisType', 'FWIA', 'record'):
    pprint([y for y in x])

[<pymongo.cursor.Cursor object at 0x112667220>]
[{'_id': ObjectId('666bf78539b9736541209522'),
  'duration_s': 2000,
  'events': 'YYYY',
  'recordFileName': 'example.txt',
  'recordID': 'XXXX',
  'samplingFreq_Hz': 100,
  'sessionID': 'MSOL',
  'startTime': '2018-01-10T06:14:00Z'}]


In [13]:
levels = ['patients', 'sessions', 'records', 'events']
def go_up_level(record: dict, start_level: str) -> dict:
    above_level = levels[levels.index(start_level)-1]
    above_level_key = above_level[:-1]+'ID'
    above_level_ID = record[above_level_key]
    above_level_record = db[above_level].find({above_level_key:above_level_ID})[0]
    return above_level_record

In [14]:
def go_up_to_level(record: dict, from_level: str, to_level: str) -> dict:
    from_level_ind = levels.index(from_level)
    to_level_ind = levels.index(to_level)
    for i in range(from_level_ind, to_level_ind, -1):
        record = go_up_level(record, levels[i])
    return record

In [37]:
def go_down_level(record: dict, start_level: str) -> list[dict]:
    below_level_index = levels.index(start_level)+1
    if below_level_index >= len(levels):
        print('No level below')
        return
    below_level = levels[below_level_index]
    
    below_level_key = below_level
    below_level_IDs = record[below_level_key]
    print(below_level_IDs)
    print(below_level_key)
    key_in_below_level = below_level[:-1]+'ID'
    below_level_records = [db[below_level].find({key_in_below_level:below_level_ID})[0] for below_level_ID in below_level_IDs]
    return below_level_records

In [35]:
def go_down_to_level(record: dict, from_level: str, to_level: str) -> list[dict]:
    below_key = levels[levels.index(from_level)+1]
    from_level_ind = levels.index(from_level)
    to_level_ind = levels.index(to_level)
    records = record[below_key]
    for i in range(from_level_ind, to_level_ind):
        records_lists = [go_down_level(record, levels[i]) for record in records]
        records = sum(records_lists, [])
    return records



In [39]:
record = db['sessions'].find({'patientID': 'MSOL'})[0]
go_down_level(record, 'sessions')

['XXXX']
records


[{'_id': ObjectId('666bf78539b9736541209522'),
  'recordID': 'XXXX',
  'samplingFreq_Hz': 100,
  'startTime': '2018-01-10T06:14:00Z',
  'duration_s': 2000,
  'recordFileName': 'example.txt',
  'sessionID': 'MSOL',
  'events': 'YYYY'}]

In [28]:
record = db['records'].find({'recordID':'XXXX'})[0]
go_up_to_level(record, 'records', 'patients')

{'_id': ObjectId('666bf75839b9736541209519'),
 'age': '+18',
 'gender': 'f',
 'diagnosis': {'type': 'refractory', 'other': 'mesial sclerosis'},
 'aura': False,
 'commorbidities': ['hipertension',
  'depression',
  'dermatitis',
  'consanguinity'],
 'sessions': ['MSOL'],
 'patientID': 'MSOL'}

In [None]:
# query: number of crisis

def get_crisis_number(query_operator: str, no: int):
    # won't sum number of crises over multiple sessions of the same patient
    return get_items_numerical('sessions', 'patients', 'no_clinical_seizures', query_operator, no)
    

In [None]:
get_session_duration('$gt', 1)

TypeError: get_items_numerical() missing 3 required positional arguments: 'field', 'query_operator', and 'no'

In [None]:
def get_session_duration(query_operator: str, no: int):
    return get_items_numerical('sessions', 'patients')

In [40]:
def get_items_numerical(source_lvl: str, return_lvl: str, field: str, query_operator: str, no: int):
    '''Returns list of patients (dict) matching given query of crisis number
    query_operator options: 
    - "$lt"
    - "$lte"
    - "$gt"
    - "$gte"
    - "$eq"
    - ...
    '''
    source_match = db[source_lvl].find({field: {query_operator: no}})
    return_match = [go_up_to_level(item, source_lvl, return_lvl) for item in source_match]
    return return_match