In [1]:
from argparse import ArgumentParser
import utils
import torch
from models.basic_model import CDEvaluator
import os
from dask.distributed import Client
from dask_cuda import LocalCUDACluster
cluster = LocalCUDACluster(threads_per_worker=4)
client = Client(cluster)
# The scheduler's 'dashboard' service entry is its port; print the
# Jupyter server-proxy path that serves the dashboard status page.
scheduler_services = client.scheduler_info()['services']
print("/proxy/{}/status".format(scheduler_services['dashboard']))

/proxy/8787/status


2022-05-13 03:21:20,628 - distributed.diskutils - INFO - Found stale lock file and directory '/home/dylan/code/planet-regrid-poc/BIT_CD/dask-worker-space/worker-y9wu0nu5', purging
2022-05-13 03:21:20,628 - distributed.preloading - INFO - Import preload module: dask_cuda.initialize


In [21]:
from typing import TypedDict

In [2]:
def get_args(argv=None):
    """
    Build the argument namespace used to configure the BIT change-detection evaluator.

    Arguments are parsed from ``argv`` rather than ``sys.argv``. Inside a Jupyter
    kernel, ``sys.argv`` would otherwise pick up the kernel's own launch flags.

    :param argv: Optional sequence of command-line style overrides, e.g.
        ``['--gpu_ids', '-1', '--batch_size', '4']``. ``None`` (the default)
        parses an empty list, so every option keeps its default value.
    :return: ``argparse.Namespace`` holding the options declared below.
    """
    # ------------
    # args
    # ------------
    parser = ArgumentParser()
    parser.add_argument('--project_name', default='BIT_LEVIR', type=str)
    parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
    parser.add_argument('--checkpoint_root', default='checkpoints', type=str)
    parser.add_argument('--output_folder', default='samples/predict', type=str)

    # data
    parser.add_argument('--num_workers', default=0, type=int)
    parser.add_argument('--dataset', default='CDDataset', type=str)
    parser.add_argument('--data_name', default='quick_start', type=str)

    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--split', default="demo", type=str)
    parser.add_argument('--img_size', default=256, type=int)

    # model
    parser.add_argument('--n_class', default=2, type=int)
    parser.add_argument('--net_G', default='base_transformer_pos_s4_dd8_dedim8', type=str,
                        help='base_resnet18 | base_transformer_pos_s4_dd8 | base_transformer_pos_s4_dd8_dedim8|')
    parser.add_argument('--checkpoint_name', default='best_ckpt.pt', type=str)

    # Never fall back to sys.argv (parse_args(None) would read it).
    overrides = [] if argv is None else list(argv)
    return parser.parse_args(args=overrides)

In [3]:
# Parse the default options (there is no real command line inside a notebook).
args = get_args()
# NOTE(review): the return value is discarded, so this presumably configures
# GPU/device fields on `args` in place -- confirm in utils.get_device.
utils.get_device(args)

In [4]:
def load_model():
    """
    Build the BIT change-detection evaluator and load its checkpoint.

    Reads the notebook-level ``args`` namespace produced by ``get_args``. The
    notebook runs this on a Dask worker via ``client.submit(load_model)``. The
    loaded model is therefore returned, so the caller can obtain it from the future.

    Side effects:
        * sets ``args.checkpoint_dir`` to ``<checkpoint_root>/<project_name>``;
        * creates ``args.output_folder`` if it does not already exist.

    :return: The ``CDEvaluator`` with ``args.checkpoint_name`` loaded and
        switched to evaluation mode.
    """
    args.checkpoint_dir = os.path.join(args.checkpoint_root, args.project_name)
    os.makedirs(args.output_folder, exist_ok=True)

    model = CDEvaluator(args)
    # NOTE(review): checkpoint_name is presumably resolved relative to
    # args.checkpoint_dir inside CDEvaluator -- confirm.
    model.load_checkpoint(args.checkpoint_name)
    # Device placement is left to CDEvaluator; an explicit model.to(device)
    # was intentionally not applied here.
    model.eval()
    return model

In [5]:
# Run load_model on a Dask worker. `remote_model` is a Future; calling
# `remote_model.result()` blocks and transfers load_model's return value
# back into this kernel.
remote_model = client.submit(load_model)

#### Each item passed to the model must have this format: a dict with keys `A`, `B` and `name`

In [None]:
# def __getitem__(self, index):
#         name = self.img_name_list[index]
#         A_path = get_img_path(self.root_dir, self.img_name_list[index % self.A_size])
#         B_path = get_img_post_path(self.root_dir, self.img_name_list[index % self.A_size])

#         img = np.asarray(Image.open(A_path).convert('RGB'))
#         img_B = np.asarray(Image.open(B_path).convert('RGB'))

#         [img, img_B], _ = self.augm.transform([img, img_B],[], to_tensor=self.to_tensor)

#         return {'A': img, 'B': img_B, 'name': name}

# Load Planet Data

In [6]:
import xarray
import rioxarray
import boto3
import rasterio as rio
from rasterio.session import AWSSession
import os

In [7]:
# AWS credentials are read from environment variables, never stored in the notebook.
session = boto3.Session(**{
    'aws_access_key_id': os.getenv('makepath_pb_id'),
    'aws_secret_access_key': os.getenv('makepath_pb_key'),
})

In [14]:
# Expose the boto3 credentials to rasterio/GDAL so s3:// paths can be opened.
# The Env is entered by hand (not with `with`) so it stays active for later
# cells; call rio_env.__exit__(None, None, None) to deactivate it.
# NOTE(review): this configures only the kernel process. Dask worker processes
# started by LocalCUDACluster presumably do not inherit it -- confirm that
# chunk reads executed on workers are authenticated.
rio_env = rio.Env(session=AWSSession(session))
rio_env.__enter__()

<rasterio.env.Env at 0x7fca2de28b50>

In [9]:
def get_file_names(bucket_name, prefix):
    """
    Return the S3 URIs of every ``.tif`` object whose key starts with ``prefix``.

    The listing is paginated, so prefixes holding more than 1000 objects are
    fully enumerated. The prefix is also applied server-side, so unrelated keys
    are never transferred.

    :param bucket_name: Name of the S3 bucket.
    :param prefix: Only return keys that start with this prefix (folder name).
    :return: List of ``s3://<bucket_name>/<key>`` strings in S3 listing order
        (UTF-8 binary key order); an empty list if nothing matches.
    """
    s3_client = boto3.client('s3', aws_access_key_id=os.getenv('makepath_pb_id'),
                             aws_secret_access_key=os.getenv('makepath_pb_key'))
    paginator = s3_client.get_paginator('list_objects_v2')
    shortlisted_files = []
    for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
        # A page with no matching objects carries no 'Contents' key at all.
        for obj in page.get('Contents', []):
            key = obj['Key']
            if key.endswith('.tif'):
                shortlisted_files.append(f"s3://{bucket_name}/{key}")
    return shortlisted_files

In [10]:
# NOTE(review): despite its name, this holds every .tif URI under Full/,
# not just the most recent scenes.
latest_filenames = get_file_names(bucket_name='makepath-planet-data',prefix = 'Full/')

In [11]:
# Inspect the first matching S3 URI.
latest_filenames[0]

's3://makepath-planet-data/Full/22_05/694c5805-f9fd-4cba-b4dc-046954166dde/PSScene/20220507_170051_95_241c_3B_Visual.tif'

In [17]:
# Open the first scene lazily as a Dask-backed array of shape (band, y, x).
# Chunks of (4, 8192, 8192) keep all 4 bands together in each chunk; lock=False
# lets chunks be read in parallel without a shared read lock.
ds1 = rioxarray.open_rasterio(latest_filenames[0], chunks=(4,8192,8192), lock=False)

In [18]:
# Show the lazy array summary (size, shape, chunking) for the first scene.
ds1

Unnamed: 0,Array,Chunk
Bytes,463.65 MiB,256.00 MiB
Shape,"(4, 9281, 13096)","(4, 8192, 8192)"
Count,5 Tasks,4 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 463.65 MiB 256.00 MiB Shape (4, 9281, 13096) (4, 8192, 8192) Count 5 Tasks 4 Chunks Type uint8 numpy.ndarray",13096  9281  4,

Unnamed: 0,Array,Chunk
Bytes,463.65 MiB,256.00 MiB
Shape,"(4, 9281, 13096)","(4, 8192, 8192)"
Count,5 Tasks,4 Chunks
Type,uint8,numpy.ndarray


In [19]:
# Open the second scene with the same chunking as the first.
# NOTE(review): the displayed shape (4, 9283, 13096) differs from ds1's
# (4, 9281, 13096), so the scenes are not on an identical grid. They need
# aligning/regridding before being paired as model inputs A/B. Also confirm
# that latest_filenames[0] and [1] cover the same area at different dates;
# listing order alone does not guarantee it.
ds2 = rioxarray.open_rasterio(latest_filenames[1], chunks=(4,8192,8192), lock=False)

In [20]:
# Show the lazy array summary (size, shape, chunking) for the second scene.
ds2

Unnamed: 0,Array,Chunk
Bytes,463.75 MiB,256.00 MiB
Shape,"(4, 9283, 13096)","(4, 8192, 8192)"
Count,5 Tasks,4 Chunks
Type,uint8,numpy.ndarray
"Array Chunk Bytes 463.75 MiB 256.00 MiB Shape (4, 9283, 13096) (4, 8192, 8192) Count 5 Tasks 4 Chunks Type uint8 numpy.ndarray",13096  9283  4,

Unnamed: 0,Array,Chunk
Bytes,463.75 MiB,256.00 MiB
Shape,"(4, 9283, 13096)","(4, 8192, 8192)"
Count,5 Tasks,4 Chunks
Type,uint8,numpy.ndarray


In [23]:
import dask.array
class ChipPair(TypedDict):
    # One model input: the A image and the B ("post") image of a chip pair,
    # plus an identifier for the pair.
    A: torch.Tensor
    B: torch.Tensor
    name: str


def predict_chips(data: ChipPair, model) -> torch.Tensor:
    """
    Run the change-detection model on one chip pair and return its output on the CPU.

    Gradient tracking is disabled because this is inference only. Without
    this, autograd would keep every intermediate activation alive during the
    forward pass.

    :param data: Dict with keys ``'A'``, ``'B'`` (image tensors) and ``'name'``.
    :param model: Loaded evaluator exposing ``_forward_pass(data) -> torch.Tensor``.
    :return: The model output moved to CPU memory.
    """
    with torch.no_grad():
        return model._forward_pass(data).to("cpu")

In [None]:
def copy_and_predict_chunked(