# Full inference run for PlanetScope Segmentation

### Imports 

In [None]:
from pathlib import Path
import torch
import pandas as pd
import os
import numpy as np
import tqdm
from joblib import delayed, Parallel
import shutil
from tqdm.notebook import tqdm

### Settings 

In [None]:
# Local checkout of the inference code
CODE_DIR = Path('/isipd/projects/p_aicore_pf/initze/code/aicore_inference')
# Location of raw data
RAW_DATA_DIR = Path('/isipd/projects/p_aicore_pf/initze/data/planet/planet_data_inference_grid/tiles')
# Working directory for data processing (holds the 'input' and 'tiles' subdirectories used below)
PROCESSING_DIR = Path('/isipd/projects/p_aicore_pf/initze/processing')
# Target directory for inference results (one subdirectory per model)
INFERENCE_DIR = Path('/isipd/projects/p_aicore_pf/initze/processed/inference')

# Model directory and model name - RTS (thaw slumps)
#MODEL_DIR = Path('/isipd/projects/p_aicore_pf/initze/models/thaw_slumps')
#MODEL='RTS_v4'

# Model directory and model name - Water
#MODEL_DIR = Path('/isipd/projects/p_aicore_pf/initze/models/water')
#MODEL='Water_v5_1024'

# Model directory and model name - Pingos (currently active)
MODEL_DIR = Path('/isipd/projects/p_aicore_pf/initze/models/pingos')
MODEL='pingo_UnetPP_v1_2021-12-12_09-56-50'

# GPU ids to use; the work is split into len(USE_GPU) * RUNS_PER_GPU parallel jobs
#USE_GPU = [0,1,2,3,4,5,6,7]
USE_GPU = [0,1,2,3,4,5]
RUNS_PER_GPU = 5
# Maximum number of images to process; None means no limit (used as a slice bound)
MAX_IMAGES = None

In [None]:
def run_inference(df, gpu=0, run=False, patch_size=1024, margin_size=256):
    """Build, print and optionally execute the inference command for a set of tiles.

    The command calls ``inference.py`` for the module-level ``MODEL`` on every
    tile named in ``df['name']`` and restricts it to one GPU through
    ``CUDA_VISIBLE_DEVICES``.

    Args:
        df: DataFrame with a ``name`` column holding the tile names to process.
        gpu: GPU id exported as ``CUDA_VISIBLE_DEVICES``.
        run: If False (default) the command is only printed (dry run);
            if True it is also executed through the shell.
        patch_size: Value passed to ``--patch_size``.
        margin_size: Value passed to ``--margin_size``.

    Returns:
        None. A non-zero exit status of the command is reported on stdout.
    """
    import shlex  # local import keeps the notebook's import cell untouched

    if len(df) == 0:
        print('Empty dataframe')
        return

    # Quote everything that is interpolated into the shell command so that
    # paths or names containing spaces/special characters stay one argument.
    # shlex.quote leaves plain strings unchanged, so typical commands look
    # exactly as before.
    quote = shlex.quote
    tiles = ' '.join(quote(str(name)) for name in df['name'].values)
    run_string = (
        f"CUDA_VISIBLE_DEVICES='{gpu}' python inference.py"
        f" -n {quote(str(MODEL))}"
        f" --data_dir {quote(str(PROCESSING_DIR))}"
        f" --inference_dir {quote(str(INFERENCE_DIR))}"
        f" --patch_size {patch_size} --margin_size {margin_size}"
        f" {quote(str(MODEL_DIR / MODEL))} {tiles}"
    )
    print(run_string)
    if run:
        status = os.system(run_string)
        if status != 0:
            # Previously the exit status was dropped and failures went unnoticed.
            print(f'Inference command on GPU {gpu} failed with exit status {status}')

def listdirs(rootdir):
    """Return the immediate subdirectories of ``rootdir`` as ``Path`` objects.

    Plain files are skipped; the order is whatever ``Path.iterdir`` yields.
    """
    return [entry for entry in Path(rootdir).iterdir() if entry.is_dir()]

def listdirs2(rootdir, depth=0):
    """Return the directories found ``depth`` levels below the subdirectories of ``rootdir``.

    Args:
        rootdir: Directory to scan (string or ``Path``).
        depth: 0 returns the immediate subdirectories of ``rootdir``,
            1 returns their subdirectories, 2 the level below that, and so on.
            Values below 0 are treated like 0.

    Returns:
        List of ``Path`` objects; plain files are ignored at every level.
    """
    current_level = [entry for entry in Path(rootdir).iterdir() if entry.is_dir()]
    # Descend one directory level per iteration; range() is empty for
    # depth <= 0, which keeps the top-level result.
    for _ in range(depth):
        current_level = [
            child
            for parent in current_level
            for child in parent.iterdir()
            if child.is_dir()
        ]
    return current_level

def get_PS_products_type(name):
    """Classify a dataset name by its number of underscore-separated fields.

    Args:
        name: Dataset (directory) name, e.g. '4767269_0370913_2021-08-05_227e'.

    Returns:
        'PSScene' for three fields, 'PSOrthoTile' for four fields,
        otherwise None.
    """
    n_fields = len(name.split('_'))
    if n_fields == 3:
        return 'PSScene'
    if n_fields == 4:
        return 'PSOrthoTile'
    # Unknown naming scheme: make the None result explicit.
    return None
        
def get_date_from_PSfilename(name):
    """Return the third underscore-separated field of ``name``.

    For a name such as '4767269_0370913_2021-08-05_227e' this is the image
    date '2021-08-05'. Names with fewer than three fields raise IndexError.

    NOTE(review): names with only three fields (classified as 'PSScene' by
    get_PS_products_type) also return their third field here; confirm that
    this field is really the date for that product type.
    """
    fields = name.split('_')
    return fields[2]
    

def get_datasets(path, depth=0, preprocessed=False):
    """List dataset directories below ``path`` and derive properties from their names.

    Args:
        path: Root directory to scan.
        depth: Passed to ``listdirs2``; 0 uses the immediate subdirectories,
            1 the subdirectories one level further down.
        preprocessed: Constant value stored in the ``preprocessed`` column.

    Returns:
        DataFrame with one row per directory and the columns ``path``,
        ``name``, ``preprocessed``, ``PS_product_type``, ``image_date`` and
        ``tile_id`` (second underscore-separated field of the name).
        An empty directory yields an empty DataFrame with the same columns.
    """
    dirs = listdirs2(path, depth=depth)
    df = pd.DataFrame(data=dirs, columns=['path'])

    # Column-wise Series.map instead of row-wise DataFrame.apply(axis=1):
    # on an empty frame, apply does not return a Series, which breaks the
    # column assignments below when no dataset directory exists yet.
    names = df['path'].map(lambda dataset_path: dataset_path.name)
    df['name'] = names
    df['preprocessed'] = preprocessed
    df['PS_product_type'] = names.map(get_PS_products_type)
    df['image_date'] = names.map(get_date_from_PSfilename)
    df['tile_id'] = names.map(lambda name: name.split('_')[1])
    return df

def copy_unprocessed_files(row, processing_dir, quiet=True):
    """Copy one dataset directory into ``processing_dir / 'input'`` unless it is already there.

    Args:
        row: Mapping-like record whose ``'path'`` entry is the source directory.
        processing_dir: Processing root; the copy lands in its ``input`` subdirectory
            under the source directory's name.
        quiet: If False, print whether the dataset was copied or skipped.
    """
    source = row['path']
    destination = processing_dir / 'input' / source.name

    # An existing destination is never overwritten.
    if destination.exists():
        if not quiet:
            print(f'Skipped copying {source.name}')
        return

    if not quiet:
        print (f'Start copying {source.name} to {destination}')
    shutil.copytree(source, destination)

def update_DEM(vrt_target_dir):
    """
    Regenerate the elevation VRTs and copy them to ``vrt_target_dir``.

    Runs ``./create_ArcticDEM.sh`` from the current working directory, then
    copies ``elevation.vrt`` and ``slope.vrt`` from that directory to the target.
    """
    os.system('./create_ArcticDEM.sh')
    for vrt_name in ('elevation.vrt', 'slope.vrt'):
        shutil.copy(vrt_name, vrt_target_dir)

In [None]:
def get_processing_status(raw_data_dir, procesing_dir, inference_dir, model):
    """Combine raw, preprocessed and inference listings into one status table.

    Args:
        raw_data_dir: Root of the raw data; datasets are expected two
            directory levels below it.
        procesing_dir: Processing root; preprocessed datasets are read from
            its ``tiles`` subdirectory. (Parameter spelling kept for
            backward compatibility.)
        inference_dir: Root of the inference output.
        model: Model name; finished products are the entries of
            ``inference_dir / model``.

    Returns:
        DataFrame with all preprocessed datasets followed by the raw datasets
        that are not preprocessed yet, plus a boolean ``inference_finished``
        column. ``reset_index()`` keeps the previous row labels in an
        ``index`` column.
    """
    # get raw tiles
    df_raw = get_datasets(raw_data_dir, depth=1)
    # get processed
    df_processed = get_datasets(procesing_dir / 'tiles', depth=0, preprocessed=True)
    # calculate properties: raw datasets that have no preprocessed counterpart
    diff = df_raw[~df_raw['name'].isin(df_processed['name'])]
    df_merged = pd.concat([df_processed, diff]).reset_index()

    # A set gives constant-time membership, and Series.isin also works on an
    # empty table, where row-wise apply does not return a Series.
    finished_products = {prod.name for prod in (inference_dir / model).glob('*')}
    df_merged['inference_finished'] = df_merged['name'].isin(finished_products)

    return df_merged

### List all files with properties

In [None]:
# Status table of all raw and preprocessed datasets for the configured model
df_processing_status = get_processing_status(RAW_DATA_DIR, PROCESSING_DIR, INFERENCE_DIR, MODEL)

### Select Data 

#### Single file(s)

In [None]:
# Select datasets by their full name
image_ids = ['4767269_0370913_2021-08-05_227e']

# image_ids is already a list; wrapping it in another list would make isin()
# compare names against a nested list and never match.
df_final = df_processing_status[df_processing_status['name'].isin(image_ids)]

#### Tile ID

In [None]:
# Select every dataset whose tile id is one of the listed ids
tile_ids = ['0571410']

df_final = df_processing_status[df_processing_status['tile_id'].isin(tile_ids)]
print(f'Number of images: {len(df_final)}')
print(f'Number of preprocessed images: {df_final.preprocessed.sum()}')

#### Tile ID by prefix

In [None]:
# Select every dataset whose tile id starts with this prefix (plain prefix match)
tile_id_start = '03'

df_final = df_processing_status[df_processing_status.tile_id.str.startswith(tile_id_start)]
print(f'Number of images: {len(df_final)}')
print(f'Number of preprocessed images: {df_final.preprocessed.sum()}')
print(f'Number of finished images: {df_final.inference_finished.sum()}')

### Full Set 

In [None]:
# Use every dataset in the status table (no filtering; df_final refers to the same object)
df_final = df_processing_status
print(f'Number of images: {len(df_final)}')
print(f'Number of preprocessed images: {df_final.preprocessed.sum()}')
print(f'Number of finished images: {df_final.inference_finished.sum()}')

## Preprocessing

#### Update Arctic DEM data 

In [None]:
# Relative path: resolved against the current working directory of the notebook
vrt_target_dir = Path('../../processing/auxiliary/ArcticDEM')
# Regenerate the DEM VRTs and copy them into the target directory
update_DEM(vrt_target_dir)

#### Copy data for Preprocessing 

In [None]:
# Datasets of the current selection that have not been preprocessed yet
df_preprocess = df_final[~df_final.preprocessed]
print(f'Number of images to preprocess: {len(df_preprocess)}')

# Copying is a pure side effect, so use an explicit loop instead of
# DataFrame.apply: there is no result to collect, and nothing is executed
# for an empty selection.
for _row_label, dataset_row in df_preprocess.iterrows():
    copy_unprocessed_files(dataset_row, PROCESSING_DIR)

#### Run Preprocessing 

In [None]:
import warnings

# Install a blanket 'ignore' filter for Python warnings in this session.
warnings.filterwarnings('ignore')

print(f'Preprocessing {len(df_preprocess)} images')
# Only launch the preprocessing script when there is something to preprocess.
if not df_preprocess.empty:
    pp_string = f'python setup_raw_data.py --data_dir {PROCESSING_DIR} --nolabel'
    os.system(pp_string)

## Processing/Inference

#### Parallel runs 

In [None]:
# Datasets without inference output, capped at MAX_IMAGES (None = no cap)
#df_process = df_final[df_final.preprocessed]
df_process = df_final[~df_final.inference_finished].iloc[:MAX_IMAGES]
#df_process = df_final

# One chunk per parallel job
n_splits = len(USE_GPU) * RUNS_PER_GPU
# Split the row positions and slice with iloc instead of handing the
# DataFrame itself to np.array_split; the resulting chunks are the same
# near-equal consecutive pieces, without relying on NumPy's handling of
# DataFrame objects.
row_chunks = np.array_split(np.arange(len(df_process)), n_splits)
df_split = [df_process.iloc[rows] for rows in row_chunks]
# gpu_split[i] is the GPU id for df_split[i]; the GPU list is repeated RUNS_PER_GPU times
gpu_split = USE_GPU * RUNS_PER_GPU

In [None]:
# Report how many images each parallel job will receive
for split in df_split:
    print(f'Number of images: {len(split)}')

### Parallel Inference execution

In [None]:
# One job per chunk; df_split and gpu_split are index-aligned, so each chunk
# is paired with its GPU id directly.
Parallel(n_jobs=n_splits)(
    delayed(run_inference)(chunk, gpu=gpu_id, run=True)
    for chunk, gpu_id in zip(df_split, gpu_split)
)

### Single image

In [None]:
image_id = '4767269_0370913_2021-08-05_227e'
# Search RAW_DATA_DIR recursively for entries whose name contains image_id
dslist = list(RAW_DATA_DIR.glob(f'**/*{image_id}*'))
# Keep the first matching directory; raises IndexError if there is none
processing_ds = [ds for ds in dslist if ds.is_dir()][0]

In [None]:
# Same lookup below PROCESSING_DIR (uses image_id from the previous cell)
dslist = list(PROCESSING_DIR.glob(f'**/*{image_id}*'))
# Keep the first matching directory; raises IndexError if there is none
processed_ds = [ds for ds in dslist if ds.is_dir()][0]

#### single run 

In [None]:
# Separate the tile names with spaces so each one is its own command-line
# argument (joining without a separator fuses all names into one argument).
# NOTE(review): this uses the current df_final selection, not image_id from
# the cells above - confirm which one is intended for the single run.
tiles = ' '.join(df_final.name.values)
run_string = f"CUDA_VISIBLE_DEVICES='4' python inference.py -n {MODEL} --data_dir {PROCESSING_DIR} --inference_dir {INFERENCE_DIR} {MODEL_DIR/MODEL} {tiles}"
os.system(run_string)