# Visualization

Test ClusterWay to automatically detect waypoints from an occupancy grid map of a row-based crop and cluster them.

In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import os
import numpy as np
import tensorflow as tf

from utils.tools import load_config, deepWayLoss, clusterLoss
from utils.models import build_deepway, build_clusterway
from utils.dataset import load_dataset_test
from utils.train import Trainer
from utils.visualization import RotationalPredictor

In [3]:
# Select a GPU and enable memory growth on it.
# GPU_INDEX is the preferred device. When that index does not exist (e.g. a
# single-GPU machine) the first GPU is used instead of raising IndexError, and
# when no GPU is present nothing is configured so TensorFlow runs on the CPU.
GPU_INDEX = 1
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    if GPU_INDEX < len(gpus):
        selected_gpu = gpus[GPU_INDEX]
    else:
        print(f'GPU {GPU_INDEX} not available ({len(gpus)} GPU(s) found): using GPU 0.')
        selected_gpu = gpus[0]
    # Both calls must run before the GPU is initialized by any TensorFlow op.
    tf.config.set_visible_devices(selected_gpu, 'GPU')
    tf.config.experimental.set_memory_growth(selected_gpu, True)
else:
    print('No GPU detected: running on CPU.')

In [4]:
# important paths and names
# Every location is resolved against the directory the notebook is run from.
PATH_DIR = os.path.abspath(os.curdir)


def _from_notebook_dir(relative_path):
    """Return `relative_path` joined onto PATH_DIR."""
    return os.path.join(PATH_DIR, relative_path)


# folder passed to the Trainer as its checkpoint directory
PATH_WEIGHTS = _from_notebook_dir('bin')

# test datasets: straight / curved, and their satellite counterparts
TEST_DATA_PATH = _from_notebook_dir('../Datasets/straight/test')
TEST_CURVED_DATA_PATH = _from_notebook_dir('../Datasets/curved/test') # curved
SATELLITE_DATA_PATH = _from_notebook_dir('../Datasets/satellite/')
SATELLITE_CURVED_DATA_PATH = _from_notebook_dir('../Datasets/satellite_curved') # curved

# configuration shared by the dataset loader, the model builders and the trainer
config_file = 'config.json'
config = load_config(config_file)

In [5]:
# select model

# Base name of the network to visualize; uncomment the first line (and comment
# the second) to use the deep_way model instead of cluster_way.
#name_model = 'deep_way_pretrained'
name_model = 'cluster_way_pretrained'
CURVED = False
I = 0  # trailing index of the model name (presumably the training run — confirm)

# Full name has the form '<base>[_curved]_<I>'.
curved_suffix = '_curved' if CURVED else ''
name_model = f'{name_model}{curved_suffix}_{I}'
name_model

'cluster_way_pretrained_0'

# 1 Import the Test Dataset

In [6]:
# Choose the target dataset folder (one of the *_DATA_PATH constants).
img_folder = SATELLITE_DATA_PATH

# The loader returns four items; unpack them, then print the shapes of the
# image array and of the two target arrays.
test_set = load_dataset_test(img_folder, config)
X_test, y_test, y_cluster_test, df_waypoints_test = test_set
print(*(array.shape for array in (X_test, y_test, y_cluster_test)))

Loading test set: 100%|█████████████████████████████████████████████| 100/100 [00:00<00:00, 163.17it/s]


(100, 800, 800) (100, 100, 100, 3) (100, 100, 100)


# 2 Create the model

In [7]:
# create model
# Start from a clean Keras session so re-running this cell does not stack models.
tf.keras.backend.clear_session()

if 'deep_way' in name_model:
    deepway_net = build_deepway(
        name_model,
        config['FILTERS'],
        config['KERNEL_SIZE'],
        config['R'],
        config['MASK_DIM'],
    )
elif 'cluster_way' in name_model:
    # Name of the underlying deep_way model: same name with the first
    # 'cluster_way' swapped for 'deep_way'.
    name_classic = name_model.replace('cluster_way', 'deep_way', 1)
    # NOTE(review): the meaning of the trailing positional True is not visible
    # here — check build_deepway's signature.
    model_classic = build_deepway(
        name_classic,
        config['FILTERS'],
        config['KERNEL_SIZE'],
        config['R'],
        config['MASK_DIM'],
        True,
    )

    # The cluster_way network is built on top of the deep_way model above.
    deepway_net = build_clusterway(
        name_model,
        model_classic,
        config['FILTERS'],
        config['KERNEL_SIZE'],
        out_feats=config['OUT_FEATS'],
    )
else:
    raise ValueError(f'Wrong model {name_model}.')

In [8]:
# Print the layer-by-layer summary (output shapes, parameter counts, connections) of the selected model.
deepway_net.summary()

Model: "cluster_way_pretrained_0"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 800, 800)]   0           []                               
                                                                                                  
 lambda (Lambda)                (None, 800, 800)     0           ['input_1[0][0]']                
                                                                                                  
 lambda_1 (Lambda)              (None, 800, 800, 1)  0           ['lambda[0][0]']                 
                                                                                                  
 conv2d (Conv2D)                (None, 400, 400, 16  800         ['lambda_1[0][0]']               
                                )                                          

In [9]:
# load weights
# Losses are keyed by name: 'mask' is always present, 'features' is added only
# for cluster_way models (keys presumably match the model outputs — confirm).
loss = {'mask': deepWayLoss('none')}
if 'cluster_way' in name_model:
    loss.update(features=clusterLoss('none'))

# Learning rate 0: the notebook only visualizes predictions, it does not train.
optimizer = tf.keras.optimizers.Adam(0.)

# Per the recorded output, constructing the Trainer restores the model from
# the checkpoint found in PATH_WEIGHTS.
trainer = Trainer(
    deepway_net,
    config,
    loss=loss,
    optimizer=optimizer,
    checkpoint_dir=PATH_WEIGHTS,
)

Model cluster_way_pretrained_0 restored from checkpoint at step 33286.


# 3 Visualize predictions

In [12]:
# Build the widget-based predictor over the test images and launch it.
# NOTE(review): `dim` is forwarded as-is; its meaning (possibly the figure
# size) is not visible from this notebook — confirm in utils.visualization.
predictor_dim = (10, 10)
predictor = RotationalPredictor(deepway_net, X_test, config, dim=predictor_dim)
predictor.start()

HBox(children=(BoundedIntText(value=0, description='Image:', max=99), Button(description='Reset', style=Button…

Output()