# Predict
Similar to training, prediction can be done via three interfaces:
- in Python, via `dss.predict.predict`
- on the command line, via `dss predict`, with audio data from a wav file
- in the GUI; see the [GUI tutorial](/tutorials_gui/predict)

Prediction will:

- load the audio data and the network
- run inference to produce confidence scores (`class_probabilities`)
- post-process the confidence scores to extract the times of events and to label segments.
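To make the post-processing step more concrete, here is a toy sketch for events. It is not the DeepSS implementation; it only assumes that event times are taken as peaks of an event type's confidence trace that exceed a threshold (analogous to `event_thres`) and are at least a minimum distance apart in seconds (analogous to `event_dist`). The helper `detect_events` is hypothetical.

```python
import numpy as np
from scipy.signal import find_peaks


def detect_events(confidence, samplerate, thres=0.5, min_dist_seconds=0.01):
    """Return the times (seconds) of confidence peaks above `thres`.

    Hypothetical helper for illustration, not part of DeepSS. When two
    peaks are closer than `min_dist_seconds`, find_peaks keeps only
    the higher one.
    """
    distance = max(1, int(min_dist_seconds * samplerate))  # in samples
    peaks, _ = find_peaks(confidence, height=thres, distance=distance)
    return peaks / samplerate


# Toy confidence trace for one event type, sampled at 1 kHz.
samplerate = 1000
t = np.arange(1000) / samplerate
confidence = (np.exp(-((t - 0.2) / 0.005) ** 2)
              + 0.3 * np.exp(-((t - 0.5) / 0.005) ** 2)  # stays below 0.5
              + np.exp(-((t - 0.8) / 0.005) ** 2))
print(detect_events(confidence, samplerate))  # events at 0.2 s and 0.8 s
```

Lowering `thres` to 0.2 would also pick up the weaker bump at 0.5 s, which is the trade-off a detection threshold controls.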


## Prediction using python

```python
import numpy as np
from pprint import pprint
import scipy.io.wavfile
import dss.predict
help(dss.predict.predict)
```

````
Help on function predict in module dss.predict:

predict(x: <built-in function array>, model_save_name: str = None, verbose: int = 1, batch_size: int = None, model: tensorflow.python.keras.engine.training.Model = None, params: dict = None, event_thres: float = 0.5, event_dist: float = 0.01, event_dist_min: float = 0, event_dist_max: float = None, segment_thres: float = 0.5, segment_minlen: float = None, segment_fillgap: float = None, prepend_padding: bool = True)
    [summary]
    
    Usage:
    Calling predict with the path to the model will load the model and the
    associated params and run inference:
    `dss.predict.predict(x=data, model_save_name='tata')`
    
    To re-use the same model with multiple recordings, load the modal and params
    once and pass them to `predict`
    ```my_model, my_params = dss.utils.load_model_and_params(model_save_name)
    for data in data_list:
        dss.predict.predict(x=data, model=my_model, params=my_params)
    ```
    
    Args:
````
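The `segment_*` arguments in the signature control how the confidence scores of segment types become labelled segments. As a rough, hypothetical illustration (not DeepSS code), segmentation can be pictured as thresholding a confidence trace (cf. `segment_thres`), closing short gaps between segments (cf. `segment_fillgap`), and discarding segments that are too short (cf. `segment_minlen`), with both durations assumed to be in seconds. The function `segment_trace` is made up for this sketch.

```python
import numpy as np


def segment_trace(confidence, samplerate, thres=0.5, fillgap=None, minlen=None):
    """Return (onsets, offsets) in seconds of supra-threshold runs.

    Hypothetical helper, not part of DeepSS. `fillgap` and `minlen`
    are in seconds; None disables the corresponding step.
    """
    # Pad with False so that every run has a rising and a falling edge.
    mask = np.concatenate([[False], confidence > thres, [False]])
    edges = np.diff(mask.astype(int))
    onsets = np.flatnonzero(edges == 1)    # first sample inside a run
    offsets = np.flatnonzero(edges == -1)  # first sample after a run

    if fillgap is not None and len(onsets) > 1:
        # Merge neighbouring runs separated by gaps shorter than `fillgap`.
        gaps = (onsets[1:] - offsets[:-1]) / samplerate
        keep = gaps >= fillgap
        onsets = np.concatenate([onsets[:1], onsets[1:][keep]])
        offsets = np.concatenate([offsets[:-1][keep], offsets[-1:]])

    if minlen is not None:
        # Drop runs shorter than `minlen`.
        long_enough = (offsets - onsets) / samplerate >= minlen
        onsets, offsets = onsets[long_enough], offsets[long_enough]

    return onsets / samplerate, offsets / samplerate


# Toy trace at 100 Hz: two runs separated by a 0.02 s gap, plus a 0.02 s blip.
samplerate = 100
confidence = np.zeros(100)
confidence[10:30] = 0.9
confidence[32:50] = 0.9
confidence[70:72] = 0.9
onsets, offsets = segment_trace(confidence, samplerate, fillgap=0.05, minlen=0.1)
print(onsets, offsets)  # one segment from 0.1 s to 0.5 s
```

Filling gaps before removing short segments means that two brief fragments can survive as one segment once merged; the order of these steps in DeepSS itself is not documented here.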

```python
%%time
samplerate, x = scipy.io.wavfile.read('dat/dmel_song_rt.wav')
print(f"DeepSS requires [T, channels], but single-channel wave files are loaded with shape [T,] (data shape is {x.shape}).")
x = np.atleast_2d(x).T  # mono data: [T,] -> [T, 1]
events, segments, class_probabilities = dss.predict.predict(x,
                                                            model_save_name='models/dmel_single_rt/20200430_201821',
                                                            verbose=2)
```

```
DeepSS requires [T, channels], but single-channel wave files are loaded with shape [T,] (data shape is (909003,)).
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
443/443 - 20s
CPU times: user 1min 36s, sys: 5.24 s, total: 1min 41s
Wall time: 24.7 s
```
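Note that `np.atleast_2d(x).T` is only correct for single-channel files: for multi-channel files, `scipy.io.wavfile.read` already returns data with shape `[T, channels]`, and the transpose would turn it into `[channels, T]`. A small guard, sketched here with a hypothetical helper name, adds the channel axis only when it is missing.

```python
import numpy as np


def as_time_by_channels(x):
    """Return `x` with shape [T, channels] (hypothetical helper).

    Mono wav data of shape [T,] gets a trailing channel axis;
    data that is already 2-D ([T, channels]) is returned unchanged.
    """
    x = np.asarray(x)
    if x.ndim == 1:
        x = x[:, np.newaxis]
    elif x.ndim != 2:
        raise ValueError(f"Expected 1-D or 2-D audio, got shape {x.shape}.")
    return x


mono = np.zeros(1000, dtype=np.int16)
stereo = np.zeros((1000, 2), dtype=np.int16)
print(as_time_by_channels(mono).shape)    # (1000, 1)
print(as_time_by_channels(stereo).shape)  # (1000, 2)
```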


### Outputs of `predict`
- `class_probabilities`: confidence scores with shape `[T, nb_classes]`, including the noise class.
- `segments`: labelled segments
    - `samplerate_Hz`: sampling rate in Hz
    - `names`: names of all segment types
    - `index`: indices of all segment types into `class_probabilities`
    - `probabilities`: confidence scores of all segment types, `class_probabilities[:, index]`
    - `sequence`: sequence of segment names (one entry per detected segment); excludes noise
    - `samples`: labelled sample trace (label of the segment occupying each sample)
    - `onsets_seconds`, `offsets_seconds`, `durations_seconds`: onsets, offsets, and durations of individual segments
- `events`: detected events
    - `samplerate_Hz`: sampling rate in Hz
    - `index`: indices of all event types into `class_probabilities`
    - `names`: names of all event types
    - `probabilities`: confidence scores of the detected events, i.e. the value of `class_probabilities` for the event's class index at each event time
    - `seconds`: times (in seconds) of detected events
    - `sequence`: sequence of event names (one entry per detected event)
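The per-event and per-segment keys listed above can be combined into a single annotation table, for instance to export to CSV. The sketch below assumes that `events` and `segments` can be indexed like dicts with those keys; the helper `annotations_to_rows` and the toy names and times are made up for illustration and are not real predictions.

```python
import csv
import io


def annotations_to_rows(events, segments):
    """Flatten event and segment outputs into (name, start_s, stop_s) rows.

    Hypothetical helper. Events are instantaneous, so start == stop.
    """
    rows = [(name, t, t) for name, t in zip(events["sequence"], events["seconds"])]
    rows += [
        (name, on, off)
        for name, on, off in zip(
            segments["sequence"], segments["onsets_seconds"], segments["offsets_seconds"]
        )
    ]
    return sorted(rows, key=lambda row: row[1])  # order by start time


# Toy values for illustration only.
events = {"sequence": ["pulse", "pulse"], "seconds": [0.50, 0.54]}
segments = {"sequence": ["sine"], "onsets_seconds": [1.2], "offsets_seconds": [1.9]}

buffer = io.StringIO()
writer = csv.writer(buffer)
writer.writerow(["name", "start_seconds", "stop_seconds"])
writer.writerows(annotations_to_rows(events, segments))
print(buffer.getvalue())
```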

## Prediction using command-line scripts
The command saves the results to an HDF5 file whose name ends in `_dss.h5`, or to the path given via the `--save-filename` argument.

See [cli](/technical/cli) for a full list of arguments.
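In the example run, `dat/dmel_song_rt.wav` is saved as `dat/dmel_song_rt_dss.h5`, which suggests that the default name is the input path with its extension replaced by `_dss.h5`. The hypothetical helper below encodes that inferred rule (an assumption, not taken from the DeepSS source), e.g. to locate the output from a script; passing `--save-filename` avoids relying on it.

```python
import os


def default_save_filename(wav_path):
    """Guess the default output path of `dss predict` (assumed rule).

    Assumes the extension of the input file is replaced by `_dss.h5`.
    """
    root, _ = os.path.splitext(wav_path)
    return root + "_dss.h5"


print(default_save_filename("dat/dmel_song_rt.wav"))  # dat/dmel_song_rt_dss.h5
```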

```shell
dss predict dat/dmel_song_rt.wav models/dmel_single_rt/20200430_201821
```

```
INFO:root:   Loading data from dat/dmel_song_rt.wav.
INFO:root:   Annotating using model at models/dmel_single_rt/20200430_201821.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
Instructions for updating:
If using Keras pass *_constraint arguments to layers.
INFO:root:   Saving results to dat/dmel_song_rt_dss.h5.
INFO:root:Done.
```

```python
import h5py

with h5py.File('dat/dmel_song_rt_dss.h5', mode='r') as f:
    print(list(f.keys()))
```

```
['class_probabilities', 'events', 'segments']
```
