In [1]:
import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp
import numpy as np
import albumentations as A
from albumentations.pytorch import ToTensorV2
from geonet import tiler, mask, raster, dataset
from geonet.visualizations import plotImagePair

In [2]:
# Augmentation pipeline: random 90-degree rotation and horizontal flip,
# each applied with probability 0.6. The Normalize and ToTensorV2 steps
# are left commented out, so this pipeline does neither.
# NOTE(review): `aug` is not referenced by any later cell visible in this notebook.
aug = A.Compose([
    #A.Normalize(mean=(0.0095, 0.0087, 0.0078), std=(0.0075, 0.0070, 0.0060)),
    A.RandomRotate90(p=0.6),
    A.HorizontalFlip(p=0.6),
    #ToTensorV2()
])

In [3]:
def to_tensor(x, **kwargs):
    """Reorder an image array from (H, W, C) to (C, H, W) and cast it to float32.

    Args:
        x: array with three axes; the last axis is moved to the front.
        **kwargs: ignored; accepted so the function can be used as a
            transform callback that receives extra keyword arguments.

    Returns:
        A float32 array with axes ordered (2, 0, 1) relative to the input.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')

def preprocess_input(x, mean=[0.187, 0.182, 0.139, 0.215], std=[0.015, 0.035, 0.038, 0.067], input_space="RGB", input_range=[0,1], **kwargs):
    """Normalise an image array: optional channel flip, rescale, then standardise.

    Args:
        x: image array whose last axis holds the channels.
        mean: per-channel values subtracted from the image; ``None`` skips
            the subtraction.
        std: per-channel values the image is divided by; ``None`` skips
            the division.
        input_space: when ``"BGR"``, the last axis is reversed first.
        input_range: when not ``None`` and its upper bound is 1, an image
            whose maximum exceeds 1 is divided by 5000.0.
        **kwargs: ignored; accepted so the function can be used as a
            transform callback that receives extra keyword arguments.

    Returns:
        The processed array (the input object itself if no step applies).
    """
    # Reverse the channel axis for BGR input, materialised as a copy.
    out = x[..., ::-1].copy() if input_space == "BGR" else x

    # Values above 1 are scaled down by a fixed 5000 when a unit range
    # is requested.
    if input_range is not None and out.max() > 1 and input_range[1] == 1:
        out = out / 5000.0

    if mean is not None:
        out = out - np.array(mean)

    if std is not None:
        out = out / np.array(std)

    return out


def get_inference_preprocessing(preprocessing_fn):
    """Build the image preprocessing pipeline used for the raster datasets.

    Args:
        preprocessing_fn (callable): data normalization function. It is
            accepted but currently NOT used: normalization is always done
            by ``preprocess_input`` with its default arguments.
    Return:
        transform: albumentations.Compose that applies ``preprocess_input``
            and then ``to_tensor`` to the image.

    """
    normalise = A.Lambda(image=preprocess_input)
    channels_first = A.Lambda(image=to_tensor)
    return A.Compose([normalise, channels_first])

In [10]:
# initialize model
ENCODER = 'se_resnext50_32x4d'
ENCODER_WEIGHTS = 'imagenet'
CLASSES = ['agro']
ACTIVATION = 'sigmoid'  # could be None for logits or 'softmax2d' for multiclass segmentation
DEVICE = 'cuda'

# create segmentation model with pretrained encoder: 4 input channels and
# one output channel per entry in CLASSES.
# encoder_weights is passed explicitly so the model and preprocessing_fn
# below are both driven by ENCODER_WEIGHTS instead of the library default.
model = smp.FPN(
    encoder_name=ENCODER,
    encoder_weights=ENCODER_WEIGHTS,
    classes=len(CLASSES),
    activation=ACTIVATION,
    in_channels=4,
)

# Normalization function associated with the chosen encoder and weights.
# NOTE(review): get_inference_preprocessing receives this function but
# normalizes with preprocess_input instead.
preprocessing_fn = smp.encoders.get_preprocessing_fn(ENCODER, ENCODER_WEIGHTS)

In [23]:
# Load the saved agro network; it is later passed to cnn_predict, which calls `.predict` on it.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
agro_model = torch.load("./weights/best/agronet_v1.2.2 - 0.7559994951487025.pth")

In [6]:
# Load the saved forest network.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a trusted source.
# NOTE(review): this global is not used by the visible top-level cells; complex_predict
# loads its own copy from a different path.
forest_model = torch.load("./weights/best/forestnet_v1.0 - 0.6568768178222868.pth")

In [36]:
# Read the source raster and stack the returned arrays along the third axis.
# presumably get_array_from_tiff yields one 2-D array per band, giving (H, W, bands) — TODO confirm
img = raster.get_array_from_tiff("../../data/_mvp/vrn_train_3857.tif")
img = np.dstack(img)
# Keep only the first array of the mask raster.
# NOTE(review): this rebinds `mask`, shadowing the `mask` module imported from geonet
# in the imports cell; the module is unreachable under that name afterwards.
mask = raster.get_array_from_tiff("../../data/_mvp/vrn_3857_mask.tif")[0]

In [37]:
# Dataset over the raster and its mask for the 'agro' class; each item goes through
# the preprocess_input + to_tensor pipeline built by get_inference_preprocessing.
# NOTE(review): step=384 has to match the step cnn_predict uses when it stitches
# predictions back together; presumably the default tile size is 512 — TODO confirm.
# The commented-out line is the disabled forest configuration (tile_size=1024, step=768).
#forestData = dataset.RasterDataset(img, mask, classes=['agro'], tile_size=1024, step=768, preprocessing=get_inference_preprocessing(preprocessing_fn))
agroData = dataset.RasterDataset(img, mask, classes=['agro'], step=384, preprocessing=get_inference_preprocessing(preprocessing_fn))

  "Using lambda is incompatible with multiprocessing. "


In [38]:
# One tile per batch, in dataset order. shuffle must stay False: cnn_predict
# derives each tile's position in the output from its index in the loader.
agro_loader = DataLoader(agroData, batch_size=1, shuffle=False, num_workers=2)
#forest_loader = DataLoader(forestData, batch_size=1, shuffle=False, num_workers=2)

In [28]:
def cnn_predict(model, img, test_loader, step=384, tile_size=512):
    """Run a model over a loader of tiles and stitch the predictions together.

    Tiles are written into a zero-initialised (H, W) float32 canvas. Where
    tiles overlap, the element-wise maximum of the predictions is kept.

    Args:
        model: object exposing ``predict(batch)``; the result must support
            ``.cpu().numpy()``.
        img: array whose first two dimensions give the output height and
            width (only ``img.shape`` is read).
        test_loader: iterable of batches; ``batch[0]`` must support
            ``.cuda()``. The i-th batch is placed at row index ``i % xc``
            and column index ``i // xc`` where
            ``xc = round(H / step) + 1``, so the loader must yield tiles in
            exactly that order (no shuffling).
            NOTE(review): ``xc`` has to equal the number of tile rows the
            dataset actually produces — confirm against RasterDataset.
        step: stride between consecutive tile origins, in pixels. Must
            match the step the dataset was built with.
        tile_size: side length of each square tile, in pixels. Must match
            the spatial size of the model output.

    Returns:
        numpy.ndarray of shape ``(img.shape[0], img.shape[1])``, float32,
        holding predictions rounded to two decimals.
    """
    height, width = img.shape[0], img.shape[1]
    ext_x = np.zeros(shape=(height, width), dtype=np.float32)
    xc = round(height / step) + 1

    # A tile that would run past the bottom/right edge is shifted back so
    # that it ends exactly on that edge.
    last_row_start = height - tile_size
    last_col_start = width - tile_size

    for i, batch in enumerate(test_loader):
        m = i % xc   # tile index along axis 0
        j = i // xc  # tile index along axis 1

        pr_mask = model.predict(batch[0].cuda())
        pr_mask = pr_mask.cpu().numpy().round(decimals=2)

        row_start = min(step * m, last_row_start)
        col_start = min(step * j, last_col_start)
        rows = slice(row_start, row_start + tile_size)
        cols = slice(col_start, col_start + tile_size)

        # Leading singleton dimensions of pr_mask (batch, class) are
        # absorbed by NumPy when assigning into the 2-D window.
        ext_x[rows, cols] = np.maximum(ext_x[rows, cols], pr_mask)

    return ext_x

In [14]:
def complex_predict(img):
    """Combine CNN and gradient-boosting predictions for one raster.

    Loads the forest network and a booster model from disk, runs both on
    the raster at ``img`` and returns the element-wise maximum of the two
    prediction maps.

    Args:
        img: path of the input raster; passed to
            ``raster.get_array_from_tiff`` and to ``flatten_file``.

    Returns:
        Element-wise maximum of the CNN prediction map and the reshaped
        booster prediction.

    NOTE(review): ``lgb`` (presumably lightgbm) and ``flatten_file`` are
    neither imported nor defined in the visible part of this notebook, so
    this function raises NameError on a fresh kernel unless they are
    provided elsewhere.
    """
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    forest_model = torch.load('../geonet/weights/forestnet_v1.0 - 0.6568768178222868.pth')
    image = raster.get_array_from_tiff(img)
    cnn_img = np.dstack(image)
    # The second argument stands in for a mask at inference time.
    # NOTE(review): cnn_img[0] is the first row of the stacked array, not the
    # first band (that would be image[0]) — confirm which one is intended.
    # NOTE(review): the dataset is tiled with tile_size=1024 / step=768, but
    # cnn_predict is called below without those values and stitches with its
    # own step of 384 and tile size of 512, so the tiles will not line up.
    testData = dataset.RasterDataset(cnn_img.astype(float), cnn_img[0].astype(float), classes=['agro'], tile_size=1024, step=768, preprocessing=get_inference_preprocessing(preprocessing_fn))
    test_loader = DataLoader(testData, batch_size=1, shuffle=False, num_workers=1)
    gbm = lgb.Booster(model_file="../../data_science/geonet/weights/forest_gbm_v0.1.2.txt")
    cnn_pred = cnn_predict(forest_model, cnn_img, test_loader)
    gbm_pred = gbm.predict(flatten_file(img))
    # Bring the flat booster output back to the shape of the first array read from the raster.
    gbm_pred = np.reshape(gbm_pred, image[0].shape)
    preds = np.maximum(cnn_pred, gbm_pred)
    return preds

In [39]:
# Read the raster again (same file that was loaded into `img` earlier) and run
# tiled inference with the agro model.
# ext_x is the stitched float32 map returned by cnn_predict, with the same
# height and width as `data`.
data = np.dstack(raster.get_array_from_tiff("../../data/_mvp/vrn_train_3857.tif"))
ext_x = cnn_predict(agro_model, data, agro_loader)

In [40]:
# Write the stitched prediction map to a new TIFF.
# presumably the third argument is a reference raster that supplies the
# georeferencing for the output — TODO confirm against geonet.raster.
raster.get_raster_from_array(ext_x, "../mvp/vrn_agro_predicted_fin.tif", "../../data/_mvp/vrn_train_3857.tif")