<a href="https://colab.research.google.com/github/hacksaremeta/IS-Sentence-Completion/blob/main/src/dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
# DataManager Class
- Provides save and load functionality for datasets stored in JSON format

In [None]:
import os, json, logging
from Bio import Entrez, Medline

class DataManager():
    """
    Provides save and load functionality for datasets in json format

    A dataset is a JSON file inside <root_dir> with the structure
        {"name": <dataset name>,
         "data": [{"query": <query>, "abstracts": [<abstract>, ...]}, ...]}
    Datasets are identified by the "name" field stored inside the file,
    not by the file name.
    """
    def __init__(self, email, root_dir):
        # email is handed to Entrez before every PubMed request
        self.email = email
        # Directory that holds the dataset JSON files
        self.root_dir = root_dir
        self.log = logging.getLogger(self.__class__.__name__)

    def _find_dataset(self, name):
        """
        Searches the dataset directory for a JSON file whose "name" field equals <name>
        Returns a tuple (path, content) or (None, None) if there is no such file
        Files that cannot be read/parsed or that are not dataset-shaped are skipped
        """
        if not os.path.isdir(self.root_dir):
            return None, None

        # sorted() makes the result independent of the directory listing order
        for filename in sorted(os.listdir(self.root_dir)):
            if not filename.endswith(".json"):
                continue

            path = os.path.join(self.root_dir, filename)
            try:
                with open(path, 'r', encoding="utf-8") as f:
                    content = json.load(f)
            except (OSError, ValueError) as err:
                # json.JSONDecodeError and UnicodeDecodeError are ValueErrors;
                # an unrelated/broken file must not break the lookup
                self.log.warning("Skipping unreadable file %s: %s", path, err)
                continue

            if isinstance(content, dict) and content.get("name") == name:
                return path, content
        return None, None

    def _exists_dataset(self, name):
        """Checks whether a dataset with the given name exists"""
        path, _ = self._find_dataset(name)
        return path is not None

    def _fetch_papers(self, query : str, limit : int) -> 'list[dict]':
        """
        Retrieves data from PubMed
        Searches for <query> (at most <limit> ids) and returns the
        parsed Medline records for the ids that were found
        """
        Entrez.email = self.email

        search_handle = Entrez.esearch(db="pubmed", term=query, retmax=limit)
        try:
            record = Entrez.read(search_handle)
        finally:
            search_handle.close()
        idlist = record["IdList"]

        self.log.info("\nFound %d records for %s.", len(idlist), query.strip())

        if not idlist:
            # Nothing to fetch, avoid a pointless request
            return []

        fetch_handle = Entrez.efetch(db="pubmed", id=idlist, rettype="medline", retmode = "text")
        try:
            # Medline.parse is lazy, so consume it before the handle is closed
            return list(Medline.parse(fetch_handle))
        finally:
            fetch_handle.close()

    @staticmethod
    def _extract_abstracts(papers : 'list[dict]') -> 'list[str]':
        """
        Returns the abstracts (key 'AB') of the given records
        Records with a missing or empty abstract are left out
        """
        return [p["AB"] for p in papers if p.get("AB")]

    def _fetch_abstracts(self, query : str, limit : int) -> 'list[str]':
        """
        Retrieves abstracts from PubMed
        May return fewer than <limit> abstracts since not every record has one
        """
        papers = self._fetch_papers(query, limit)
        list_of_abstracts = self._extract_abstracts(papers)

        skipped = len(papers) - len(list_of_abstracts)
        if skipped:
            self.log.debug("Skipped %d records without abstract for %s.", skipped, query.strip())

        return list_of_abstracts

    def create_dataset(self, queries : 'list[str]', name : str, limit=50, overwrite=False) -> None:
        """
        Wraps other methods in this class
        Creates a dataset from multiple queries
        Does nothing if the dataset is already present (param overwrite)
        Limits every query to <limit> results
        """
        exists_dataset = self._exists_dataset(name)
        if exists_dataset and not overwrite:
            self.log.info("Dataset already exists, skipping fetch")
            return

        if exists_dataset:
            self.log.info("Dataset already exists, overwriting with data from PubMed...")
        else:
            self.log.info("Dataset does not exist, fetching from PubMed...")

        res = {
            "name": name,
            "data": [
                {"query": q, "abstracts": self._fetch_abstracts(q, limit)}
                for q in queries
            ],
        }

        self._save_dataset(res, name)

    def _save_dataset(self, dataset: dict, name : str) -> None:
        """
        Creates a file <name>.json in the dataset directory
        If a dataset called <name> is already stored (possibly under a
        different file name), that file is replaced instead so that the
        directory never holds two datasets with the same name
        For JSON file structure see the class docstring
        Param dataset has a structure analogous to the JSON file
        """
        os.makedirs(self.root_dir, exist_ok=True)

        path, _ = self._find_dataset(name)
        if path is None:
            path = os.path.join(self.root_dir, name + ".json")

        with open(path, 'w', encoding="utf-8") as f:
            json.dump(dataset, f, indent=2)

    def _load_dataset(self, name : str) -> dict:
        """
        Returns the parsed JSON content of the dataset called <name>
        Raises FileNotFoundError if there is no such dataset
        """
        path, content = self._find_dataset(name)
        if path is None:
            raise FileNotFoundError(
                "No dataset named '%s' found in '%s'" % (name, self.root_dir))
        return content

    def load_full_dataset(self, name : str) -> 'list[str]':
        """
        Finds the file that matches given <name> in JSON information,
        parses it, loading all abstracts into a list (one string for each abstract)
        and returns it (FileNotFoundError if dataset doesn't exist)
        Abstracts keep the order of the queries in the file
        """
        dataset = self._load_dataset(name)

        abstracts = []
        for entry in dataset.get("data", []):
            abstracts.extend(entry.get("abstracts", []))
        return abstracts

    def load_query_from_dataset(self, name : str, query : str) -> 'list[str]':
        """
        Like load_full_dataset but only loads abstracts for a single query
        Raises FileNotFoundError if the dataset doesn't exist and
        KeyError if the dataset has no entry for <query>
        """
        dataset = self._load_dataset(name)

        # A query may occur more than once if it was passed twice on creation
        matches = [entry for entry in dataset.get("data", [])
                   if entry.get("query") == query]
        if not matches:
            raise KeyError("Dataset '%s' has no query '%s'" % (name, query))

        abstracts = []
        for entry in matches:
            abstracts.extend(entry.get("abstracts", []))
        return abstracts
    

In [None]:
# Usage example: build (or reuse) a small PubMed dataset and read it back
if __name__ == "__main__":
    # Log everything from DEBUG upwards, tagged with level and logger name
    logging.basicConfig(format='[%(levelname)s] %(name)s: %(message)s', level=logging.DEBUG)

    # Datasets are stored in the '../res/datasets' folder
    data_folder = os.path.join("..", "res", "datasets")
    dman = DataManager(email="hexameter.trash@gmail.com", root_dir=data_folder)

    queries = ["RNA", "mRNA", "tRNA"]
    dataset_name = "RNA Dataset"

    # Gather a maximum of 5 results for each query
    # I would suggest around 5 - 20 abstracts in total for the small data sets
    # and maybe 500 - 5000 for the final ones but we'll have to test
    # since that depends on how long it takes to train the network
    # PubMed is only queried if the data is not already present
    dman.create_dataset(queries=queries, name=dataset_name, limit=5)

    # Load the whole dataset, then only the part that belongs to the "mRNA" query
    abstracts = dman.load_full_dataset(name=dataset_name)
    abstracts_mrna = dman.load_query_from_dataset(name=dataset_name, query=queries[1])

    # Do stuff with abstracts
    pass
