# FitsFileGroup and Database

Currently astropop uses a simple [Astropy Table](https://docs.astropy.org/en/stable/table/) to store header data in `FitsFileGroup`. This can be fast, but it has the drawback of being in-memory only.

So, we will investigate using an SQL database to store and filter the header table.

Fortunately, Python has a built-in sqlite module called [sqlite3](https://docs.python.org/3/library/sqlite3.html), that we can use here.

## SQLite3 Wrapper

First of all, let's implement a wrapper to do more object-oriented SQL work. This is limited to a single-table database and just some functions and interactions.

In [None]:
from astropy.io import fits
from astropop.file_collection import list_fits_files
from astropop.logger import logger
from astropop._db import SQLDatabase

# Set astropop's logger to INFO so the per-file `logger.debug` progress
# messages in the cells below are not printed.
logger.setLevel('INFO')

Is everything working?

In [None]:
# ':memory:' is the sqlite3 name for a temporary in-memory database
# (nothing is written to disk). Create a single table called 'files'.
db = SQLDatabase(':memory:')
db.add_table('files')
tab = db['files']

In [None]:
# List the compressed FITS files of the test night.
files = list_fits_files('/home/julio/19jan30', fits_extensions=['fits.gz'])
for i, f in enumerate(files):
    logger.debug("file %i from %i", i+1, len(files))
    header = dict(fits.getheader(f))
    # COMMENT/HISTORY are multi-card commentary keywords; skip them.
    header.pop('COMMENT', None)
    header.pop('HISTORY', None)
    # add_columns=True lets the table grow new columns for keywords not
    # seen in previous files.
    db.add_row('files', dict(header), add_columns=True)

In [None]:
# Sanity check: column names created from the header keywords and the
# number of rows (one per file).
print('columns: ', db.column_names('files'))
print('files:', len(db['files']))

In [None]:
# Exercise the different indexing forms of the wrapper. Presumably these are:
# (table, column) pair, row number on a table, column name on a table, and
# (table, row) pair. TODO confirm against SQLDatabase.
print(db['files', 'jd'][0:10])
print(db['files'][34])
print(db['files']['lamina'][0:10])
print(db['files', 0]['lamina'])
print(db['files', 'lamina'][0:10])

It looks like everything is working properly. All 427 files were added and all columns are present. Integer values seem to be returned as int and float values as float. Filtering is working too.

Notes:

- SQL tables don't like the `-` character in names, so it is replaced by `_`. The same is done for spaces. This can create name conflicts, but it is probably a minor problem in practice.

# FitsFileGroup implementations

Let's try our two `FitsFileGroup` implementations: one using `astropy.table.Table`, another using our SQL wrapper.

## Table implementation

In [None]:
import numpy as np
from astropy.table import Table
from astropop.file_collection import list_fits_files

from astropop.fits_utils import _fits_extensions, \
                                _fits_extensions_with_compress
from astropop.framedata import check_framedata
from astropop.py_utils import check_iterable
from astropop.logger import logger


def create_table_summary(headers, n):
    """Create a table summary of headers.

    Each header keyword becomes a lowercase column. Files that lack a
    keyword get ``None`` in that column. ``HISTORY`` and ``COMMENT`` cards
    are skipped.

    Parameters
    ----------
    headers: iterator
        Iterator for a list of header files.
    n: int
        Number of headers to iterate.

    Returns
    -------
    summary: `~astropy.table.Table`
        Table built from the collected keyword columns.
    """
    summary_dict = {}
    for i, head in enumerate(headers):
        # i is 0-based; report 1-based progress.
        logger.debug('Reading file %i from %i', i+1, n)
        for k in head.keys():
            k_lower = k.lower()
            if k_lower in ('history', 'comment'):
                logger.debug('%s key ignored', k)
                continue
            if k_lower not in summary_dict:
                # New keyword: pre-fill with None for files that lack it.
                summary_dict[k_lower] = [None]*n
            summary_dict[k_lower][i] = head.get(k)

    return Table(summary_dict)


def gen_mask(table, keywords):
    """Generate a mask to be applied in the filtering.

    Parameters
    ----------
    table: `~astropy.table.Table`
        Summary table with lowercase column names.
    keywords: dict or None
        Maps keyword names (case-insensitive) to the accepted value, or to
        an iterable of accepted values. A keyword missing from `table` is
        treated as ``None`` in every row. ``None`` or an empty dict keeps
        all rows.

    Returns
    -------
    mask: `~numpy.ndarray` of bool
        True for the rows that match all keywords. An empty list is
        returned for an empty table.
    """
    if len(table) == 0:
        return []

    mask = np.ones(len(table), dtype=bool)
    for k, v in (keywords or {}).items():
        if not check_iterable(v):
            v = [v]
        k = k.lower()
        if k not in table.colnames:
            # A missing column is all None: every row matches only if None
            # is accepted. No need to copy the table to add the column.
            mask &= None in v
        else:
            mask &= np.array([val in v for val in table[k]], dtype=bool)

    return mask


class FitsFileGroup_Table():
    """Easy handle groups of fits files.

    The header of every file is collected into an in-memory
    `~astropy.table.Table` (the *summary*), with one row per file and one
    lowercase column per header keyword.

    Parameters
    ----------
    location: str, optional
        Directory where FITS files are searched. Used only if `files` is
        not given.
    files: list, optional
        Explicit list of file names.
    ext: int or str, optional
        HDU used to read headers and data. Default is 0.
    compression: bool, optional
        If `fits_ext` is not given, also search compressed FITS extensions.
    **kwargs:
        ``fits_ext``, ``glob_include``, ``glob_exclude`` and ``keywords``.

    Raises
    ------
    ValueError
        If neither `location` nor `files` is given.
    """

    def __init__(self, location=None, files=None, ext=0,
                 compression=False, **kwargs):
        self._ext = ext
        self._extensions = kwargs.get('fits_ext',
                                      _fits_extensions_with_compress
                                      if compression else _fits_extensions)

        self._include = kwargs.get('glob_include')
        self._exclude = kwargs.get('glob_exclude')
        self._keywords = kwargs.get('keywords')

        if location is None and files is None:
            raise ValueError("You must specify a 'location' "
                             "or a list of 'files'")
        if files is None:
            files = list_fits_files(location, self._extensions,
                                    self._include, self._exclude)

        # Always keep a list, so copying and indexing work even if a tuple
        # or another sequence was given.
        self._files = list(files)
        self._location = location

        self._summary = create_table_summary(self.headers(), len(self))

    def __len__(self):
        return len(self._files)

    @property
    def files(self):
        return self._files.copy()

    @property
    def location(self):
        return self._location

    @property
    def keywords(self):
        return self._keywords

    @property
    def summary(self):
        # Return a copy so callers cannot modify the internal summary.
        return Table(self._summary)

    def __copy__(self, files=None, summary=None):
        """Return a new group with the same settings as this one.

        `files` and `summary` replace the corresponding attributes if given.
        Otherwise, copies of this group's ones are used.
        """
        nfg = type(self).__new__(type(self))
        nfg.__dict__.update(self.__dict__)
        # Explicit None checks: an empty Table is falsy, so `summary or ...`
        # would silently fall back to the full summary.
        nfg._files = list(files) if files is not None else list(self._files)
        nfg._summary = (summary if summary is not None
                        else Table(self._summary))
        return nfg

    def __getitem__(self, item):
        """Get a summary column or a subgroup of files.

        - str: the summary column with that (case-insensitive) name.
        - int: a group with the single file at that position.
        - slice, list, tuple or array of indices/booleans: a group with the
          selected files.
        """
        if isinstance(item, str):
            # string will be interpreted as column name
            key = item.lower()
            if key not in self._summary.colnames:
                raise KeyError(f'Column {item} not found.')
            return self._summary.columns[key]

        # returning FitsFileGroups
        if isinstance(item, (int, np.integer)):
            # Single index: a one-file group. Index the summary with a list
            # so it stays a one-row Table instead of a Row.
            return self.__copy__(files=[self._files[item]],
                                 summary=self._summary[[item]])
        if isinstance(item, slice):
            return self.__copy__(files=self._files[item],
                                 summary=self._summary[item])
        if isinstance(item, (np.ndarray, list, tuple)):
            item = np.asarray(item)
            if item.size == 0:
                return self.__copy__(files=[], summary=self._summary[0:0])
            if item.dtype == bool:
                # Boolean mask -> integer positions (np.take would treat
                # True/False as the indices 1/0).
                item = np.flatnonzero(item)
            files = [self._files[i] for i in item]
            return self.__copy__(files=files, summary=self._summary[item])

        raise KeyError(f'{item}')

    def filtered(self, keywords=None):
        """Create a new FileGroup with only filtered files.

        `keywords` maps header keywords to the accepted value, or to an
        iterable of accepted values. With no keywords, all files are kept.
        """
        where = np.where(gen_mask(self._summary, keywords or {}))[0]
        return self[where]

    def values(self, keyword, unique=False):
        """Return the values of a keyword in the summary.

        The keyword is case-insensitive. If unique, only unique values are
        returned. A keyword missing from the summary gives ``None`` values.
        """
        keyword = keyword.lower()
        if keyword not in self._summary.colnames:
            return [None] if unique else [None]*len(self._summary)
        vals = self._summary[keyword].tolist()
        if unique:
            return list(set(vals))
        return vals

    def add_column(self, name, values, mask=None):
        """Add a new column to the summary.

        Parameters
        ----------
        name: str
            Column name. It is stored lowercase, like the header keywords.
        values: scalar or iterable
            Column values. Scalars, and iterables whose length differs from
            the number of rows, are repeated for every row.
        mask: array-like of bool, optional
            If given, a masked column is created with this mask.
        """
        n = len(self._summary)
        if not check_iterable(values) or len(values) != n:
            values = [values]*n
        name = name.lower()
        # Write to the internal table: `self.summary` returns a copy.
        if mask is not None:
            self._summary[name] = Table.MaskedColumn(values, mask=mask)
        else:
            self._summary[name] = values

    def _intern_yelder(self, files=None, ext=None, ret_type=None,
                       **kwargs):
        """Iter over files, yielding the `ret_type` content of each one.

        `ret_type` must be 'header', 'data', 'hdu' or 'framedata'.
        """
        if ret_type not in ('header', 'data', 'hdu', 'framedata'):
            raise ValueError(f'Unknown ret_type: {ret_type}')
        ext = ext if ext is not None else self._ext
        files = files if files is not None else self._files
        for fname in files:
            if ret_type == 'header':
                # Close the file right away; the header lives in memory.
                with fits.open(fname, **kwargs) as hdul:
                    content = hdul[ext].header
                yield content
            elif ret_type == 'data':
                # The returned array keeps its data valid after closing.
                with fits.open(fname, **kwargs) as hdul:
                    content = hdul[ext].data
                yield content
            elif ret_type == 'hdu':
                # NOTE: the file is left open on purpose, so lazily loaded
                # data stays readable through the returned HDU.
                yield fits.open(fname, **kwargs)[ext]
            else:
                yield check_framedata(fname, hdu=ext, **kwargs)

    def hdus(self, ext=None, **kwargs):
        """Read the files and iterate over their HDUs."""
        return self._intern_yelder(ext=ext, ret_type='hdu', **kwargs)

    def headers(self, ext=None, **kwargs):
        """Read the files and iterate over their headers."""
        return self._intern_yelder(ext=ext, ret_type='header', **kwargs)

    def data(self, ext=None, **kwargs):
        """Read the files and iterate over their data."""
        return self._intern_yelder(ext=ext, ret_type='data', **kwargs)

    def framedata(self, ext=None, **kwargs):
        """Read the files and iterate over them, converted by check_framedata."""
        return self._intern_yelder(ext=ext, ret_type='framedata', **kwargs)

In [None]:
# Build the Table-based group from the same file list (rebinding `f`, the
# loop variable above) and count the files whose 'object' keyword is
# 'HD126593'.
f = FitsFileGroup_Table(files=files)
print(len(f.filtered({'object': 'HD126593'}).summary))

## SQL implementation

In [None]:
import os
from pathlib import Path

# Table holding one row of header values per file.
_headers = 'headers'
# Table holding the group settings (glob patterns, location, compression,
# extensions, HDU) written when the database is first created.
_metadata = 'astropop_metadata'
# Column of the headers table that stores each file path.
_files_col = '__file'

class FitsFileGroup_SQL():
    """Easy handle groups of fits files, storing headers in SQLite.

    Parameters
    ----------
    location: str, optional
        Directory where FITS files are searched.
    files: list, optional
        Explicit list of file names. Cannot be combined with `location`.
    ext: int or str, optional
        HDU used to read headers and data. Default is 0.
    compression: bool, optional
        If `fits_ext` is not given, also search compressed FITS extensions.
    database: str, optional
        SQLite database path, or ':memory:' (default) for an in-memory
        database. For on-disk databases, file paths are stored relative to
        the database directory.
    **kwargs:
        ``fits_ext``, ``glob_include``, ``glob_exclude`` and ``update``
        (re-read the files even if the database already exists).
    """

    def __init__(self, location=None, files=None, ext=0,
                 compression=False, database=':memory:', **kwargs):
        self._ext = ext
        self._extensions = kwargs.get('fits_ext')
        self._include = kwargs.get('glob_include')
        self._exclude = kwargs.get('glob_exclude')

        self._db = SQLDatabase(database)
        if database == ':memory:':
            self._db_dir = None
        else:
            self._db_dir = Path(database).resolve().parent

        self._read_db(files, location, compression, kwargs.get('update', 0))

    def __len__(self):
        return len(self.files)

    def _list_files(self, files, location, compression):
        """Return `files`, or the FITS files found in `location`.

        Returns None if neither is given.
        """
        extensions = self._extensions
        if extensions is None:
            extensions = (_fits_extensions_with_compress if compression
                          else _fits_extensions)

        if files is not None and location is not None:
            raise ValueError('You can only specify either files or location.')
        if files is None and location is not None:
            files = list_fits_files(location, extensions,
                                    self._include, self._exclude)
        return files

    def _read_db(self, files, location, compression, update=False):
        """Read the database and generate the summary if needed."""
        initialized = _metadata in self._db.table_names

        if not initialized:
            self._db.add_table(_metadata)
            self._db.add_row(_metadata, {'DB_API_MAJ': 1,
                                         'DB_API_MIN': 0,
                                         'GLOB_INCLUDE': self._include,
                                         'GLOB_EXCLUDE': self._exclude,
                                         'LOCATION': location,
                                         'COMPRESSION': compression,
                                         'FITS_EXT': self._extensions,
                                         'EXT': self._ext},
                             add_columns=True)

        # The stored metadata is the source of truth, also for reopened
        # databases.
        self._include = self._db[_metadata, 'glob_include'][0]
        self._exclude = self._db[_metadata, 'glob_exclude'][0]
        self._extensions = self._db[_metadata, 'fits_ext'][0]
        self._ext = self._db[_metadata, 'ext'][0]
        self._location = self._db[_metadata, 'location'][0]
        self._compression = self._db[_metadata, 'compression'][0]

        if update or not initialized:
            self.update(files, location, compression)

    def _column_values(self, name):
        """Return the values stored in a column of the headers table."""
        values = self._db[_headers, name].values
        # NOTE(review): accept `values` either as a property or as a method
        # of the column object. TODO confirm against the SQLColumn API.
        if callable(values):
            values = values()
        return values

    @property
    def files(self):
        files = self._column_values(_files_col)
        if self._db_dir is not None:
            return [os.path.join(self._db_dir, f) for f in files]
        return files

    @property
    def summary(self):
        return self._db[_headers].as_table()

    def __copy__(self, files=None, db=None):
        raise NotImplementedError

    def filtered(self, keywords):
        """Create a new FileGroup with only filtered files."""
        raise NotImplementedError

    def update(self, files=None, location=None, compression=False):
        """Rebuild the headers table from `files` or from `location`.

        If neither is given, the location stored in the database metadata
        is used.

        Raises
        ------
        ValueError
            If both `files` and `location` are given, or if no source of
            files is available.
        """
        if files is None:
            # Fall back to the stored location only when no explicit file
            # list is given; otherwise _list_files would reject both.
            location = location or self._location
        compression = compression or self._compression
        files = self._list_files(files, location, compression)
        if files is None:
            raise ValueError("You must specify a 'location' "
                             "or a list of 'files'")
        files = list(files)

        # Validate the inputs before dropping the current table.
        if _headers in self._db.table_names:
            self._db.drop_table(_headers)
        self._db.add_table(_headers)

        n_files = len(files)
        for i, f in enumerate(files):
            logger.debug('reading file %i from %i', i+1, n_files)
            self.add_file(f)

    def values(self, keyword, unique=False):
        """Return the values of a keyword in the summary.

        If unique, only unique values returned.
        """
        vals = self._column_values(keyword)
        if unique:
            vals = list(set(vals))
        return vals

    def add_column(self, name, values=None):
        """Add a new column to the summary."""
        self._db.add_column(_headers, name, data=values)

    def add_file(self, file):
        """Add a new file to the group, reading its header at `self._ext`."""
        logger.debug('reading file %s', file)
        header = fits.getheader(file, ext=self._ext)
        # Accept str and path-like objects; store plain strings.
        file = os.fspath(file)
        if self._db_dir is not None:
            # Stored relative to the database directory; the `files`
            # property joins them back.
            file = os.path.relpath(file, self._db_dir)
        hdr = {_files_col: file}
        hdr.update(dict(header))
        hdr.pop('COMMENT', None)
        hdr.pop('HISTORY', None)
        self._db.add_row(_headers, hdr, add_columns=True)

    def _intern_yelder(self, ext=None, ret_type=None, **kwargs):
        """Iter over files, yielding the `ret_type` content of each one.

        `ret_type` must be 'header', 'data', 'hdu' or 'framedata'.
        """
        if ret_type not in ('header', 'data', 'hdu', 'framedata'):
            raise ValueError(f'Unknown ret_type: {ret_type}')
        ext = ext if ext is not None else self._ext
        for fname in self.files:
            if ret_type == 'header':
                # Close the file right away; the header lives in memory.
                with fits.open(fname, **kwargs) as hdul:
                    content = hdul[ext].header
                yield content
            elif ret_type == 'data':
                # The returned array keeps its data valid after closing.
                with fits.open(fname, **kwargs) as hdul:
                    content = hdul[ext].data
                yield content
            elif ret_type == 'hdu':
                # NOTE: the file is left open on purpose, so lazily loaded
                # data stays readable through the returned HDU.
                yield fits.open(fname, **kwargs)[ext]
            else:
                yield check_framedata(fname, hdu=ext, **kwargs)

    def hdus(self, ext=None, **kwargs):
        """Read the files and iterate over their HDUs."""
        return self._intern_yelder(ext=ext, ret_type='hdu', **kwargs)

    def headers(self, ext=None, **kwargs):
        """Read the files and iterate over their headers."""
        return self._intern_yelder(ext=ext, ret_type='header', **kwargs)

    def data(self, ext=None, **kwargs):
        """Read the files and iterate over their data."""
        return self._intern_yelder(ext=ext, ret_type='data', **kwargs)

    def framedata(self, ext=None, **kwargs):
        """Read the files and iterate over them, converted by check_framedata."""
        return self._intern_yelder(ext=ext, ret_type='framedata', **kwargs)

In [None]:
logger.setLevel('INFO')
# Open (or create) an on-disk database in the data directory. With
# update=False, an existing database is reused as is; the files are read
# only when the database is new.
fg = FitsFileGroup_SQL(location='/home/julio/19jan30', database='/home/julio/19jan30/test.db', compression=True, update=False)

In [None]:
# Show the headers table (built by `as_table()`) as the cell output.
fg.summary