# Visualization

Test ClusterWay, which automatically detects waypoints in an occupancy grid map of a row-based crop and clusters them.

In [None]:
# Automatically reload edited modules (e.g. the local `utils` package) before each cell runs.
%reload_ext autoreload
%autoreload 2

In [1]:
import os
import numpy as np
import tensorflow as tf

from utils.tools import load_config, deepWayLoss, clusterLoss
from utils.models import build_deepway, build_clusterway
from utils.dataset import load_dataset_test
from utils.train import Trainer
from utils.visualization import RotationalPredictor

In [2]:
# Select the first GPU (if any) and enable memory growth so TensorFlow allocates
# device memory on demand. This must run before any op initializes the GPU.
# Without a GPU, indexing gpus[0] would raise IndexError, so fall back to CPU instead.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    tf.config.set_visible_devices(gpus[0], 'GPU')
    tf.config.experimental.set_memory_growth(gpus[0], True)
else:
    print('No GPU found: running on CPU.')

In [3]:
# Important paths and names. Everything is resolved against the kernel's
# current working directory.
PATH_DIR = os.path.abspath('.')


def _from_root(relative):
    """Return `relative` joined onto PATH_DIR."""
    return os.path.join(PATH_DIR, relative)


PATH_WEIGHTS = _from_root('bin')

# Test datasets: straight and curved variants, synthetic and satellite.
TEST_DATA_PATH = _from_root('../Datasets/straight/test')
TEST_CURVED_DATA_PATH = _from_root('../Datasets/curved/test')
SATELLITE_DATA_PATH = _from_root('../Datasets/satellite/')
SATELLITE_CURVED_DATA_PATH = _from_root('../Datasets/satellite_curved')

config_file = 'config.json'
config = load_config(config_file)

In [None]:
# Select the model: the family is 'cluster_way_pretrained' or 'deep_way_pretrained'.
# CURVED picks the curved variant and I the run index.
MODEL_FAMILY = 'cluster_way_pretrained'
CURVED = False
I = 0

name_model = f"{MODEL_FAMILY}{'_curved' if CURVED else ''}_{I}"
name_model

# 1 Import the Test Dataset

In [None]:
# Choose the target dataset, e.g. one of the *_DATA_PATH constants defined earlier.
img_folder = SATELLITE_DATA_PATH

test_split = load_dataset_test(img_folder, config)
X_test, y_test, y_cluster_test, df_waypoints_test = test_split
print(*(arr.shape for arr in (X_test, y_test, y_cluster_test)))

# 2 Create the model

In [None]:
# Create the model selected by `name_model`. ClusterWay is built on top of a
# DeepWay backbone of the same variant (same CURVED flag and index I).
tf.keras.backend.clear_session()

use_deepway = 'deep_way' in name_model
use_clusterway = 'cluster_way' in name_model
if not (use_deepway or use_clusterway):
    raise ValueError(f'Wrong model {name_model}.')

filters, kernel_size, n_param, mask_dim = (
    config[key] for key in ('FILTERS', 'KERNEL_SIZE', 'N', 'MASK_DIM')
)

if use_deepway:
    deepway_net = build_deepway(name_model, filters, kernel_size, n_param, mask_dim)
else:
    # Backbone name mirrors name_model, e.g. 'deep_way_0' or 'deep_way_curved_0'.
    name_classic = 'deep_way' + ('_curved' if CURVED else '') + f'_{I}'
    model_classic = build_deepway(name_classic, filters, kernel_size,
                                  n_param, mask_dim, True)
    deepway_net = build_clusterway(name_model, model_classic, filters,
                                   kernel_size, mask_dim,
                                   out_feats=config['OUT_FEATS'])

In [None]:
# Print the Keras layer-by-layer summary of the network built above.
deepway_net.summary()

In [None]:
# Load weights: build the Trainer with checkpoint_dir=PATH_WEIGHTS. The loading
# itself presumably happens inside utils.train.Trainer (TODO confirm). The learning
# rate is 0, presumably because nothing is trained in this notebook.
loss = {
    'mask': deepWayLoss('none'),
    **({'features': clusterLoss('none')} if 'cluster_way' in name_model else {}),
}
zero_lr_optimizer = tf.keras.optimizers.Adam(0.)
trainer = Trainer(deepway_net, config, loss=loss,
                  optimizer=zero_lr_optimizer, checkpoint_dir=PATH_WEIGHTS)

# 3 Visualize predictions

In [None]:
# Interactive viewer from utils.visualization over the test images.
# Presumably it displays the model's predictions for each image (TODO confirm).
predictor = RotationalPredictor(deepway_net, X_test, config)
predictor.start()

In [8]:
from glob import glob
import shutil

In [14]:
def get_n(x):
    """Return the integer index N of an image called 'imgN'.

    `x` may be a file path ('.../imgN.png') or a bare name ('imgN').
    os.path is used instead of splitting on '/' so paths built with the
    platform's native separator (e.g. from glob on Windows) also work.
    The extension is stripped whatever its length.
    """
    stem = os.path.splitext(os.path.basename(x))[0]
    return int(stem[3:])  # drop the 'img' prefix

In [16]:
# Copy every satellite_curved image into a `correct/` sub-folder, shifting its
# index down by one (imgN.png -> correct/img{N-1}.png).
# Create the destination first: shutil.copy fails if the folder does not exist.
CORRECT_DIR = os.path.join(SATELLITE_CURVED_DATA_PATH, 'correct')
os.makedirs(CORRECT_DIR, exist_ok=True)
for img in sorted(glob(os.path.join(SATELLITE_CURVED_DATA_PATH, '*.png')), key=get_n):
    shutil.copy(img, os.path.join(CORRECT_DIR, f'img{get_n(img)-1}.png'))

In [15]:
# Sanity check on `img`, the loop variable left over from the copy cell:
# show its index and file name. (`get_name` was never defined, so use os.path.basename.)
get_n(img), os.path.basename(img)

(150, 'img150.png')

In [17]:
import pandas as pd

In [37]:
# Load the original waypoint annotations of the satellite_curved dataset.
source_csv = f'{SATELLITE_CURVED_DATA_PATH}/waypoints.csv'
df = pd.read_csv(source_csv)
df

Unnamed: 0,N_img,x_wp,y_wp,class
0,img101,315,73,0
1,img101,708,94,1
2,img101,291,118,0
3,img101,704,115,1
4,img101,279,141,0
...,...,...,...,...
1899,img150,671,662,1
1900,img150,227,455,0
1901,img150,670,689,1
1902,img150,215,480,0


In [38]:
def get_n(x):
    """Return the integer index N of an image called 'imgN'.

    Accepts a bare name ('imgN', as stored in the CSV's N_img column) as well as
    a file path or name with an extension ('.../imgN.png'). Because of this,
    redefining `get_n` here no longer breaks the earlier path-based uses of it
    (e.g. re-running the copy cell).
    """
    stem = os.path.splitext(os.path.basename(x))[0]
    return int(stem[3:])  # drop the 'img' prefix

In [39]:
# Shift every image id down by one (e.g. 'img101' -> 'img100').
N = df['N_img'].map(lambda name: f'img{get_n(name)-1}').tolist()

In [40]:
# Replace the image ids with the shifted ones. `assign` returns a new frame
# instead of mutating the one displayed by the earlier read cell.
# NOTE: the previous cell reads df['N_img'], so re-running it after this cell
# would shift the ids a second time.
df = df.assign(N_img=N)
df

Unnamed: 0,N_img,x_wp,y_wp,class
0,img100,315,73,0
1,img100,708,94,1
2,img100,291,118,0
3,img100,704,115,1
4,img100,279,141,0
...,...,...,...,...
1899,img149,671,662,1
1900,img149,227,455,0
1901,img149,670,689,1
1902,img149,215,480,0


In [42]:
# Save the re-indexed waypoints into the `correct/` folder alongside the copied images.
out_csv = f'{SATELLITE_CURVED_DATA_PATH}/correct/waypoints.csv'
df.to_csv(out_csv, index=False)

In [43]:
# Reload the file just written to check the round trip.
saved_csv = f'{SATELLITE_CURVED_DATA_PATH}/correct/waypoints.csv'
df = pd.read_csv(saved_csv)
df

Unnamed: 0,N_img,x_wp,y_wp,class
0,img100,315,73,0
1,img100,708,94,1
2,img100,291,118,0
3,img100,704,115,1
4,img100,279,141,0
...,...,...,...,...
1899,img149,671,662,1
1900,img149,227,455,0
1901,img149,670,689,1
1902,img149,215,480,0
