In [None]:
from datetime import datetime
from typing import List

from Bio import Entrez
from pydash import get
from toolz import thread_first
from typeguard import typechecked

from proseflow.spec import *
from proseflow.utils import get_paths, tree_select_kv


@typechecked
def _get_pubmed_records(
    pmids: List[str],
    shape=None,
):
    """Fetch and parse PubMed records for *pmids* via NCBI Entrez ``efetch``.

    Args:
        pmids: PubMed identifiers. If the FIRST entry matches
            ``SPEC[PUBMED_IDS]``, every entry is split on ``"/"`` and its
            second-to-last segment is used as the ID (presumably the entries
            are PubMed URLs with a trailing slash -- TODO confirm against the
            spec). Otherwise the entries are sent unchanged.
        shape: Unused; kept so existing callers keep working.

    Returns:
        The structure produced by ``Entrez.read``; callers in this module
        read its ``"PubmedArticle"`` entry. For an empty *pmids* a plain
        ``{"PubmedArticle": []}`` is returned without contacting NCBI.
    """
    if not pmids:
        # ``pmids[0]`` below would raise IndexError, and there is nothing to
        # fetch anyway, so short-circuit with an empty result.
        return {"PubmedArticle": []}

    # ! TODO: move this into auto convert
    pmids_to_fetch = (
        [pmid.split("/")[-2] for pmid in pmids]
        if SPEC[PUBMED_IDS].match(pmids[0])
        else pmids
    )
    # NOTE(review): hard-coded contact address sent to NCBI with each request;
    # consider making it configurable.
    Entrez.email = "strasser.ms@gmail.com"

    # handle type is http.client.HTTPResponse
    handle = Entrez.efetch(
        db="pubmed",
        id=",".join(map(str, pmids_to_fetch)),
        rettype="xml",
        retmode="text",
    )
    try:
        return Entrez.read(handle)
    finally:
        # Release the HTTP connection even when parsing fails.
        handle.close()


def get_info(record, keys_wanted):
    """Select the wanted entries of one PubMed *record* and normalise them.

    The record's paths are collected with ``get_paths``, handed to
    ``tree_select_kv`` together with *keys_wanted*, and the selection is
    then post-processed by ``m_parse_flat_pubmed``.

    Args:
        record: A single parsed PubMed article record.
        keys_wanted: Keys passed through to ``tree_select_kv``.

    Returns:
        Whatever ``m_parse_flat_pubmed`` returns for the selection.
    """
    all_paths = list(get_paths(record))
    selection = tree_select_kv(record, all_paths, keys_wanted)
    return m_parse_flat_pubmed(selection)


def _get_pubmed_info(
    pmids: List[str],
    keys_wanted=None,
):
    """Fetch PubMed records for *pmids* and reduce each to its wanted keys.

    Args:
        pmids: PubMed identifiers, forwarded to ``_get_pubmed_records``.
        keys_wanted: Keys to extract from every record. Defaults to
            ``PMID``, ``DateCompleted``, ``Journal``, ``PubDate``,
            ``AbstractText`` and ``ChemicalList``.

    Returns:
        One ``get_info`` result per entry of the fetched ``"PubmedArticle"``
        list, in the order they were returned.
    """
    if keys_wanted is None:
        # Build a fresh list per call: a list literal in the signature would
        # be shared (and could be mutated) across calls.
        keys_wanted = [
            "PMID",
            "DateCompleted",
            "Journal",
            "PubDate",
            "AbstractText",
            "ChemicalList",
        ]
    records = _get_pubmed_records(pmids)
    return [get_info(r, keys_wanted) for r in records["PubmedArticle"]]


def _get_pubmed_abstracts(pmids: List[str]):
    """Return the first abstract section of each fetched PubMed article.

    Args:
        pmids: PubMed identifiers, forwarded to ``_get_pubmed_records``.

    Returns:
        A list with one entry per fetched article: the first element of
        ``MedlineCitation.Article.Abstract.AbstractText``, or ``None`` when
        the article has no abstract (previously this raised ``TypeError``
        and aborted the whole batch).
    """
    abstracts = []
    for record in _get_pubmed_records(pmids)["PubmedArticle"]:
        sections = get(record, "MedlineCitation.Article.Abstract.AbstractText")
        # NOTE(review): only the first section is kept; abstracts with
        # several sections are presumably truncated here -- confirm intent.
        abstracts.append(sections[0] if sections else None)
    return abstracts


"""
Caution: the two date fields encode the month differently.

PubDate has the form:
{
"Year":"1999",
"Month":"Nov",  # abbreviated month name, not a number
"Day":"09"
}
and DateCompleted has the form:
{
"Year":"1999",
"Month":"12",
"Day":"09"
}

"""

# Upper-case three-letter month abbreviation -> month number (1-12),
# e.g. "NOV" -> 11.
MONTH_MAP = {
    abbreviation: number
    for number, abbreviation in enumerate(
        (
            "JAN",
            "FEB",
            "MAR",
            "APR",
            "MAY",
            "JUN",
            "JUL",
            "AUG",
            "SEP",
            "OCT",
            "NOV",
            "DEC",
        ),
        start=1,
    )
}


def m_parse_flat_pubmed(flat_pm):
    """Normalise selected fields of a flattened PubMed record in place.

    Every field is optional: an entry that is missing or falsy is left
    untouched, so no key needs to be present.

    - ``ChemicalList``: replaced by ``str(NameOfSubstance)`` of each item.
    - ``MeshHeadingList``: replaced by the ``DescriptorName`` of each item.
    - ``DateCompleted``: mapping of numeric strings such as
      ``{"Year": "1999", "Month": "12", "Day": "09"}`` -> ``"1999-12-09"``.
    - ``AbstractText``: replaced by its first element.
    - ``Journal``: removed; its ``"Title"`` (if any) is stored under
      ``JournalTitle``.
    - ``PMID``: replaced by a full slice of itself.
    - ``PubDate``: not converted yet (see TODO below).

    Args:
        flat_pm: Mutable mapping holding the selected record fields.

    Returns:
        The same *flat_pm* object, after modification.
    """
    if flat_pm.get("ChemicalList"):
        # Entries are string elements; coerce them to plain ``str``.
        flat_pm["ChemicalList"] = [
            str(chem["NameOfSubstance"]) for chem in flat_pm["ChemicalList"]
        ]

    if flat_pm.get("MeshHeadingList"):
        flat_pm["MeshHeadingList"] = [
            mesh["DescriptorName"] for mesh in flat_pm["MeshHeadingList"]
        ]

    if flat_pm.get("DateCompleted"):
        # {"Day": "11"} -> {"day": 11}, so the parts can be passed as
        # keyword arguments to ``datetime``.
        date_comp = {
            key.lower(): int(num)
            for key, num in flat_pm["DateCompleted"].items()
            if key
        }
        flat_pm["DateCompleted"] = str(datetime(**date_comp).date())

    # TODO: convert "PubDate" as well. Its "Month" is an abbreviated name
    # such as "Nov", so int() fails on it; it has to be translated through
    # MONTH_MAP (Nov => 11) before a date can be built.

    if flat_pm.get("AbstractText"):
        flat_pm["AbstractText"] = flat_pm["AbstractText"][0]

    if flat_pm.get("PMID"):
        # NOTE(review): a full slice; presumably meant to turn a str
        # subclass into a plain str -- confirm whether ``[0]`` was intended.
        flat_pm["PMID"] = flat_pm["PMID"][0:]

    journal = flat_pm.get("Journal")
    if journal:
        # Fix: the title used to be extracted only when "JournalTitle" was
        # already set, which never happens, so it was lost when "Journal"
        # was deleted.
        title = journal.get("Title") if isinstance(journal, dict) else None
        if title is not None:
            flat_pm["JournalTitle"] = title
        del flat_pm["Journal"]

    return flat_pm
