# This is a generic notebook for viewing model predictions

#### First, set up the preliminaries for viewing predictions.

In [None]:
%matplotlib inline

import os

# TF_CPP_MIN_LOG_LEVEL must be in the environment *before* TensorFlow is
# imported: the native (C++) logger reads it when it first logs, so setting it
# after the import can be too late to silence the start-up messages.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # only print errors

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Logging configuration: have the root logger report INFO-level messages and above.
import logging

logging.basicConfig(level=logging.INFO)

In [None]:
# Set the RAMP_HOME environment variable for this session; a later cell reads
# it back through os.environ. Edit the path to match your own environment.
%env RAMP_HOME=/content/drive/MyDrive/code/projects/ramp-staging-test/ramp-staging/colab

In [None]:
import sys
from pathlib import Path

# Put the parent directory on the import path (presumably so the project's
# own packages can be imported from this notebook -- confirm against the repo layout).
sys.path.append("..")

# Base directory taken from the RAMP_HOME environment variable;
# the data paths used in this notebook are built from it.
RAMP_HOME = os.environ["RAMP_HOME"]

In [None]:
from ramp.data_mgmt.data_generator import test_batches_from_gtiff_dirs_with_names
from ramp.data_mgmt.display_data import display_img_mask_pred_batch, get_mask_from_prediction

#### Step 1: Set your training data location. 

In [None]:
# List the contents of the sample dataset directory under RAMP_HOME.
!ls {RAMP_HOME}/data/TRAIN/myanmar

In [None]:
# Dataset root for this sample; the validation chips, masks and model
# checkpoints used in this notebook are all looked up beneath it.
DATA_PATH = Path(RAMP_HOME) / "data" / "TRAIN" / "myanmar"

# Directory holding all models built using this dataset.
ALL_MODELS_PATH = DATA_PATH / "model-checkpts"

In [None]:
# Name of the sub-directory of ALL_MODELS_PATH to inspect
# (presumably one directory per training run -- confirm).
timestamp = "20220826-125633"
# Show the files under that directory so a saved model can be picked.
!tree {ALL_MODELS_PATH}/{timestamp}

#### Step 2: Set your saved model location.

In [None]:
# sample model path
# COMMENT ON MODEL TYPE HERE: multimask, extra augmentation, loss function weighting by class
# The file name reuses `timestamp` instead of repeating the literal, so the
# directory and the file name cannot drift apart when `timestamp` is changed
# (assumes saved models are named model_<timestamp>_<suffix>.tf -- confirm).
# Adjust the trailing "_002_0.962" part to match the file you want to load.
MODEL_PATH = ALL_MODELS_PATH / timestamp / f"model_{timestamp}_002_0.962.tf"

#### Step 3: Set batchsize, input and output image sizes.

`BATCH_SIZE` is the number of chips you'd like to display in a single image. I recommend keeping `BATCH_SIZE` small for good conversion to pdfs (using 'nbconvert' or any other jupyter notebook rendering tool). 

`NUMBATCHES` is the number of image batches displayed by the loop at the end of this notebook (one additional batch is displayed before the loop). This can be as large as you like. 

Image sizes should be set to their values during training. 



In [None]:
# ---- display settings (sample values) ----
NUMBATCHES = 10   # how many batches to display
BATCH_SIZE = 4    # chips per batch; small values give better pdf output
INPUT_IMAGE_SIZE = (256, 256)
OUTPUT_IMAGE_SIZE = (256, 256)

# ---- validation data locations ----
test_img_dir = str(DATA_PATH / "valchips")

# Mask type: keep exactly one of the two assignments below active
# (multichannel vs. binary masks) and comment out the other.
test_mask_dir = str(DATA_PATH / "val-multimasks")
# test_mask_dir = str(DATA_PATH / "val-binmasks")

### END CONFIGURATION, begin code for viewing model predictions

In [None]:
# Build the streaming batches (chips and masks, each with their file names) used for display.
test_batches = test_batches_from_gtiff_dirs_with_names(
    test_img_dir,
    test_mask_dir,
    BATCH_SIZE,
    INPUT_IMAGE_SIZE,
    OUTPUT_IMAGE_SIZE,
)

# Restore the trained model from the location stored in MODEL_PATH.
model = tf.keras.models.load_model(MODEL_PATH)

### Coding notes

#### This notebook uses a data generator that returns the filenames of the chips and masks as well as the chips and masks. 

You can iterate through batches multiple times to display more results. In this code, I first display a single batch on its own, and then display `NUMBATCHES` further batches in a loop. 

##### Iterating over batches

You have to be a bit careful how you iterate through batches, or you'll get confusing bugs like I did. 

Iterate through batches using this code:

```
test_batches = test_batches_from_gtiff_dirs_with_names(
                                            test_img_dir, 
                                            test_mask_dir, 
                                            BATCH_SIZE, 
                                            INPUT_IMAGE_SIZE, 
                                            OUTPUT_IMAGE_SIZE)
                                            
# these test batches are streaming. Create an iterator for them.
iterator = iter(test_batches)
batch = iterator.get_next()
```

When you want to get a new batch, call iterator.get_next() again.

##### Structure of data in each batch

In each batch, batch[0] contains all the data associated with the image chips. It's a 2-tuple: the first element, batch[0][0], is the image batch tensor, and batch[0][1] is a list of image names in the batch. 

Similarly, `batch[1]` holds the data associated with the masks: `batch[1][0]` is the mask batch tensor and `batch[1][1]` is a list of mask names.

In [None]:
# Pull a single batch from the data generator.
iterator = iter(test_batches)
batch = iterator.get_next()

# batch[0] carries the chip data and batch[1] the mask data; within each,
# index 0 is the batched tensor and index 1 the matching file names.
chips, chipnames = batch[0][0], batch[0][1]
masks, masknames = batch[1][0], batch[1][1]

prediction = model.predict(chips)

# Predictions are one-hot encoded by default:
# two channels for binary predictions, four for multichannel predictions.
print(f"Prediction shape: {prediction.shape}")

# Reduce the prediction's channels to a single-channel mask
# (the reduction itself is done inside get_mask_from_prediction).
predmask = get_mask_from_prediction(prediction)
print(f"Sparse prediction shape: {predmask.shape}")


In [None]:
# Show the trained model's output: print the chip file names, then hand the
# chips, masks and predicted masks to the display helper.
print("LIST OF CHIP FILES")
print(*(chip_name.numpy().decode('utf-8') for chip_name in chipnames), sep='\n')
display_img_mask_pred_batch(chips, masks, predmask)

In [None]:
def display_batch(model, iterator):
    """Fetch the next batch from ``iterator``, predict on it and display the result.

    Args:
        model: object whose ``predict`` method is called on the batch of chips.
        iterator: batch iterator exposing ``get_next()``; each call to this
            function consumes one batch from it.

    Returns:
        None. The chip file names are printed and the chips, masks and
        predicted masks are passed to ``display_img_mask_pred_batch``.
    """
    next_batch = iterator.get_next()

    # next_batch[0] is the chip data and next_batch[1] the mask data; within each,
    # index 0 is the batched tensor and index 1 the file names.
    chip_entry, mask_entry = next_batch[0], next_batch[1]
    chips, chipnames = chip_entry[0], chip_entry[1]
    masks = mask_entry[0]

    predmask = get_mask_from_prediction(model.predict(chips))

    print("LIST OF CHIP FILES")
    print(*(chip_name.numpy().decode('utf-8') for chip_name in chipnames), sep='\n')
    display_img_mask_pred_batch(chips, masks, predmask)

In [None]:
# Display up to NUMBATCHES batches from the iterator. get_next() raises
# tf.errors.OutOfRangeError once the data is exhausted, so stop cleanly
# rather than ending the notebook with a traceback when fewer batches
# are available than were requested.
for ii in range(NUMBATCHES):
    try:
        display_batch(model, iterator)
    except tf.errors.OutOfRangeError:
        print(f"No more batches to display: showed {ii} of {NUMBATCHES} requested.")
        break