# Command Line Interface

> Convert spectra to wavelet images.


In [None]:
#| default_exp cli

In [None]:
#| export
from fastcore.script import *
import yaml
from pathlib import Path
from typing import Optional, List

import numpy as np
from uhina.wavelets import (
    convert_to_wavelet_images, 
    create_image_target_csv, 
    estimate_percentiles)
from sklearn.pipeline import Pipeline
from uhina.preprocessing import SNV, TakeDerivative
from uhina.loading import LoaderFactory

In [None]:
#| exports
@call_parse
def main(
    config: Path, # Path to the configuration file
    ):
    "Convert spectra to wavelet images using configuration from a file."
    # Parse the YAML file and resolve it into the flat parameter dict
    # (optional keys get the defaults defined in `extract_params`).
    params = extract_params(load_config(config))

    # Load spectra and targets for the configured analytes. `load_data`
    # returns six values; the last two are not needed by this command.
    data_loader = get_loader(params)
    X, y, wavenumbers, smp_idx, _ds_name, _ds_label = data_loader.load_data(params['analytes'])

    # Preprocess the spectra, then write the target CSV and the wavelet images.
    create_output_files(preprocess_data(X), y, wavenumbers, smp_idx, params)

In [None]:
#| exports
def load_config(
    config_path: Path, # Path to the YAML configuration file
) -> dict: # Parsed top-level mapping ({} if the file is empty)
    "Load the configuration from a YAML file."
    # Explicit UTF-8 so parsing does not depend on the platform's locale encoding.
    # `safe_load` only builds plain Python objects (no arbitrary object construction).
    with open(config_path, 'r', encoding='utf-8') as f:
        cfg = yaml.safe_load(f)
    # An empty (or comment-only) YAML document parses to None; normalise it to an
    # empty mapping so the `dict` return type holds and callers get a clear KeyError
    # for missing required keys instead of a NoneType subscript error.
    if cfg is None:
        return {}
    # Raises TypeError when the top level is not a mapping (e.g. a YAML list);
    # downstream key lookups would otherwise fail with a less helpful TypeError.
    if not isinstance(cfg, dict):
        raise TypeError(
            f"Configuration file {config_path} must contain a YAML mapping at the top level, "
            f"got {type(cfg).__name__}"
        )
    return cfg

In [None]:
#| exports
def extract_params(cfg: dict) -> dict:
    "Flatten a parsed config into the CLI parameter dict, filling defaults for optional keys."
    # `src` and `dir_out` are mandatory: a missing key raises KeyError.
    params = {name: cfg[name] for name in ('src', 'dir_out')}
    # Optional keys, paired with the value used when the config omits them.
    optional = (
        ('img_dir', 'im'),
        ('dataset', 'ossl'),
        ('spectra_type', 'mir'),
        ('analytes', 'k.ext_usda.a725_cmolc.kg'),
        ('n_samples', None),
        ('batch_size', 10),
    )
    for name, fallback in optional:
        params[name] = cfg.get(name, fallback)
    return params

In [None]:
#| exports
def get_loader(params: dict):
    "Build the dataset loader described by `params` using `LoaderFactory`."
    dataset = params['dataset']
    extra_kwargs = {}
    if dataset == 'ossl':
        # `spectra_type` is forwarded only for the 'ossl' dataset.
        extra_kwargs['spectra_type'] = params['spectra_type']
    return LoaderFactory.get_loader(params['src'], dataset, **extra_kwargs)

In [None]:
#| exports
def preprocess_data(X):
    "Run the spectra through the `SNV` transform then `TakeDerivative`, returning the transformed array."
    # Step names are kept as in the original pipeline.
    steps = [('SNV', SNV()), ('Derivative', TakeDerivative())]
    return Pipeline(steps).fit_transform(X)

In [None]:
#| exports
def create_output_files(X_trans, y, wavenumbers, smp_idx, params):
    "Write the image/target CSV into `dir_out`, then the wavelet images into `dir_out/img_dir`."
    out_dir = params['dir_out']
    n_samples = params['n_samples']

    print(f'Creating image target csv in {out_dir} ...')
    create_image_target_csv(smp_idx, y, n_samples=n_samples, output_dir=out_dir)

    # Image sub-directory is resolved only after the CSV step, as before.
    img_path = Path(out_dir) / params['img_dir']
    print(f'Creating wavelet images in {img_path} ...')
    convert_to_wavelet_images(
        X_trans, smp_idx, wavenumbers,
        output_dir=img_path,
        n_samples=n_samples,
        batch_size=params['batch_size'],
    )