# 02 Predict
In this step, we demonstrate how to feed the data produced in `01 Data Process` into the model for prediction, display the prediction results, and save them to a prediction file.

## Import the necessary packages
`config_util.py` was introduced earlier and is not described again here. `model_util.py` contains the model architecture and the computations performed at run time. `dataset_util.py` contains the data-processing operations used while the model runs, including batch retrieval and the display of intermediate results. All of these source files are provided, and readers can consult them as needed.

```python
import torch  # used directly in predict() and run()

from model_util import *
from dataset_util import *
from config_util import Config
```

## predict() method
The `predict()` method takes a `Config` instance and runs prediction with it. Inside the function, `model_index = 'ESM_06131214'` selects the pre-trained model; this value is hard-coded and corresponds to GloEC-3. `dataset_type` specifies which dataset to predict on (`'input_sample'` here). Calling `predict()` loads the pre-trained parameters from `../Save_model/<model_index>.pth`, runs inference through the `run()` method, and prints the predicted EC number for each entry.

```python
def predict(config):
    # Checkpoint to load; ESM_06131214 corresponds to GloEC-3.
    model_index = 'ESM_06131214'
    label_map = get_label_map(config)

    # Build the data loader for the samples to be predicted.
    dataset_type = 'input_sample'
    nc_data_loader = get_type_dataloader(config, label_map, type=dataset_type)
    dataset_name = dataset_type + '_' + model_index

    # Build the model and load the pre-trained weights onto the CPU.
    model = get_model(config, label_map, class_num=len(label_map))
    state_dict = torch.load('../Save_model/' + model_index + '.pth', map_location=torch.device('cpu'))
    model.load_state_dict(state_dict)

    # Inference only: put layers such as dropout in evaluation mode and disable gradient tracking.
    model.eval()
    with torch.no_grad():
        perform_dict, name_list, predict_result = run(nc_data_loader, model, label_map, config, dataset_name=dataset_name)

    print("-------------Show predict result---------------")
    print("entry_name           predict EC number")
    for name, ec_number in zip(name_list, predict_result):
        print(name + "           " + ec_number)
    print("----------------------------------------------")
```
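
The print loop in `predict()` pads every entry name with the same fixed run of spaces, so the EC column drifts when names differ in length (compare `A1XSY8` with `A0A067YMX8` in the output at the end of this step). A minimal sketch of an aligned alternative follows; `format_predictions` is a hypothetical helper, not part of the provided source files.

```python
def format_predictions(names, ec_numbers, gap=4):
    """Return the prediction table as lines with the EC column aligned.

    Hypothetical helper; not part of the provided source files.
    """
    header_name, header_ec = "entry_name", "predict EC number"
    # Column width = longest entry name (or the header) plus a fixed gap.
    width = max([len(header_name)] + [len(n) for n in names]) + gap
    lines = [header_name.ljust(width) + header_ec]
    for name, ec in zip(names, ec_numbers):
        lines.append(name.ljust(width) + ec)
    return lines


if __name__ == "__main__":
    for line in format_predictions(["A0A067YMX8", "A1XSY8"], ["2.4.1.207", "2.3.2.27"]):
        print(line)
```

Inside `predict()`, the loop could then become `for line in format_predictions(name_list, predict_result): print(line)`.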

## run() method
To run the model for prediction, `run()` requires `data_loader`, `model`, `label_map`, `config`, and `dataset_name`: the data loader, the trained model, the enzyme label set, the configuration parameters, and the dataset name, respectively. The optional `split_area` and `label_num_dict` arguments are used only for k-fold evaluation. The method returns the performance metrics (`perform_dict`), the entry names, and the predicted EC numbers.

```python
def run(data_loader, model, label_map, config, dataset_name, split_area=None, label_num_dict=None):
    # split_area and label_num_dict are only needed for k-fold evaluation;
    # leave them as None for a standalone dataset such as input_sample.
    predict_probs = []
    true_label = []
    name_list = []

    for batch_esm, batch_str_label, name in data_loader:
        logits = model(batch_esm)
        # Independent per-class probabilities (multi-label output).
        predict_results = torch.sigmoid(logits)
        predict_probs.extend(predict_results.tolist())
        true_label += batch_str_label
        name_list += name

    # Write the prediction file and collect the predicted EC numbers.
    # GET_DEYAIL is spelled exactly as get_predict_file() expects.
    predict_result = get_predict_file(label_map, predict_probs, true_label, name_list, config, dataset_name=dataset_name, GET_DEYAIL=False)

    # Compute performance metrics from the saved prediction file.
    predict_file = '../Data/predict_result/' + dataset_name + '.csv'
    if split_area is None:
        perform_dict = get_other_dataset_perform(label_map, predict_file=predict_file)
    else:
        perform_dict = get_kfold_dataset_perform(label_map, predict_file=predict_file,
                                                 split_area=split_area, label_num_dict=label_num_dict)

    return perform_dict, name_list, predict_result
```
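
`run()` applies `torch.sigmoid` rather than a softmax, so each EC class receives an independent probability in [0, 1] and several classes can score high for the same protein. How `get_predict_file()` turns these probabilities into the final EC number is defined in `dataset_util.py` and not reproduced here; since the output lists one EC number per entry, its rule likely differs from the one below. This pure-Python sketch illustrates the sigmoid step, the batch-by-batch accumulation, and one common decoding rule, a fixed 0.5 threshold, which is an assumption. `decode_batch` and `index_to_label` are hypothetical names.

```python
import math


def sigmoid(x):
    """Numerically stable logistic function 1 / (1 + exp(-x))."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)  # x < 0, so exp(x) cannot overflow
    return z / (1.0 + z)


def decode_batch(logits, index_to_label, threshold=0.5):
    """Convert one batch of per-class logits into probabilities and label lists.

    Assumption: a class is predicted when its probability is >= threshold.
    The rule actually used by get_predict_file() may differ.
    """
    probs = [[sigmoid(v) for v in row] for row in logits]
    labels = [
        [index_to_label[j] for j, p in enumerate(row) if p >= threshold]
        for row in probs
    ]
    return probs, labels


if __name__ == "__main__":
    # Hypothetical stand-in for label_map: class index -> EC number.
    index_to_label = ["1.1.1.1", "2.7.11.1", "3.6.4.12"]
    # Two batches of logits, accumulated the same way run() extends its lists.
    batches = [[[-2.0, 3.0, 0.1]], [[1.5, -1.0, -4.0]]]
    all_probs, all_labels = [], []
    for batch in batches:
        probs, labels = decode_batch(batch, index_to_label)
        all_probs.extend(probs)
        all_labels.extend(labels)
    print(all_labels)  # [['2.7.11.1', '3.6.4.12'], ['1.1.1.1']]
```

The first sample ends up with two labels because both `sigmoid(3.0)` and `sigmoid(0.1)` exceed 0.5, which is exactly the multi-label behaviour a sigmoid output allows.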

## Calling the method
Call `predict()` inside an `if __name__ == '__main__':` block rather than at the top level of the script. This avoids multiprocessing problems: if the data loader starts worker processes, each one may re-import the script, and the guard keeps those imports from launching prediction again. Create a `Config` instance and pass it to `predict()`. The output is shown below; the prediction file is saved to the target directory, which can be changed in `config`.
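
Why the guard matters: PyTorch's `DataLoader` launches worker processes when `num_workers > 0`, and under the `spawn` start method (the default on Windows and macOS) each worker re-imports the main script. Whether `get_type_dataloader` actually uses worker processes is not shown here, so treat this as an assumption. A minimal standard-library sketch of the same pattern, with illustrative `square` and `main` functions:

```python
from multiprocessing import Pool


def square(x):
    # Worker functions must live at module top level so child processes can import them.
    return x * x


def main():
    # Start two worker processes and map the function over a small range.
    with Pool(processes=2) as pool:
        return pool.map(square, range(5))


if __name__ == "__main__":
    # Only the parent process enters this block. Under "spawn", children re-import
    # this file as a module, skip the block, and so do not start pools of their own.
    print(main())  # [0, 1, 4, 9, 16]
```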

```python
if __name__ == '__main__':
    config = Config()
    print("------start to predict--------")
    predict(config)
```

```text
using CPU training
------start to predict--------
Predict result is saved to --> ../Data/predict_result/input_sample_ESM_06131214.csv
-------------Show predict result---------------
entry_name           predict EC number
A0A067YMX8           2.4.1.207
A0A0K3AV08           2.7.11.1
A0A1D6K6U5           5.5.1.13
A1XSY8           2.3.2.27
A1ZA55           2.7.7.-
A2A5Z6           2.3.2.26
A2CEI6           3.1.26.-
A2TK72           3.4.24.-
A3KPQ7           3.2.1.35
A4FUD9           3.6.4.12
----------------------------------------------
```
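
Several predicted EC numbers end in `-` (for example `2.7.7.-` for `A1ZA55` and `3.1.26.-` for `A2CEI6`). In EC nomenclature a `-` marks a level that is not specified, so these predictions stop at the third level (class, subclass, sub-subclass) without a serial number. A small illustrative helper, `ec_depth` (hypothetical, not part of the provided source), counts how many levels a prediction fixes, which can help when summarising the prediction file:

```python
def ec_depth(ec_number):
    """Return how many of the four EC levels are specified.

    In EC nomenclature a '-' marks an unspecified level, e.g. '2.7.7.-'.
    """
    parts = ec_number.split(".")
    if len(parts) != 4:
        raise ValueError(f"expected four dot-separated fields, got {ec_number!r}")
    depth = 0
    for part in parts:
        if part == "-":
            break  # every later level is unspecified as well
        depth += 1
    return depth


if __name__ == "__main__":
    for ec in ["2.4.1.207", "2.7.7.-", "3.1.26.-"]:
        print(ec, ec_depth(ec))
```

Run on the three examples, it reports depth 4 for `2.4.1.207` and depth 3 for the two partial predictions.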
