# Load model pytorch only


In [None]:
# Bizarre error: the order of imports matters here. If DefaultPredictor isn't imported after numpy, etc., the kernel restarts.
import os
import skimage.io as skio
from skimage import img_as_ubyte, exposure
from PIL import Image as pilimg
import warnings
import numpy as np
from detectron2.config import get_cfg
from detectron2.engine.defaults import DefaultPredictor
from detectron2.modeling import build_model
from detectron2.checkpoint import DetectionCheckpointer
import detectron2.data.transforms as T
import torch

def load_and_edit_cfg_for_inference(
        cfg_path="../app/pytorch_api/config.yaml",
        model_weights_path="../app/pytorch_api/best-mrcnn-nebraska-model-rgb-jpeg-split-geo-nebraska-freeze0-withseed.pth"):
    """Build a detectron2 config set up for inference.

    Starts from detectron2's default config, registers two custom keys,
    merges the YAML file at ``cfg_path`` on top, points the weights at
    ``model_weights_path`` and fixes the test-time resize bounds.

    Args:
        cfg_path: Path to the YAML config file to merge.
        model_weights_path: Path stored in ``cfg.MODEL.WEIGHTS``.

    Returns:
        A clone of the assembled config.
    """
    base = get_cfg()

    # Custom keys are declared before the merge -- presumably so that
    # merge_from_file accepts them if the YAML defines them.
    base.CONFIG_NAME = ''
    base.DATASET_PATH = ''
    base.merge_from_file(cfg_path)

    # Overrides applied after the merge take precedence over the file's values.
    base.MODEL.WEIGHTS = model_weights_path
    base.INPUT.MIN_SIZE_TEST, base.INPUT.MAX_SIZE_TEST = 800, 1333

    return base.clone()

def load_model_for_inference(cfg):
    """Build the model described by ``cfg`` and load its weights.

    Args:
        cfg: detectron2 config; ``cfg.MODEL.WEIGHTS`` names the checkpoint
            to load.

    Returns:
        The built model, switched to eval mode, with the checkpoint loaded.
    """
    print('pytorch_classifier.py: Loading model...')

    net = build_model(cfg)  # a torch.nn.Module, per the original author's note
    net.eval()

    # The checkpointer loads the weights into `net` in place; its return
    # value is not needed here.
    DetectionCheckpointer(net).load(cfg.MODEL.WEIGHTS)
    return net

def run_model_single_image(
        model,
        image_path="../images/aoi_restricted_LT05_CU_015009_20050722_20190102_C01_V0_-169665_1934999.jpg",
        cfg=None):
    """Run ``model`` on one image file and return its CPU-side predictions.

    Args:
        model: Callable model that accepts a list of input dicts with the
            keys ``"image"``, ``"height"`` and ``"width"``.
        image_path: Image file to read. Defaults to the sample tile that was
            previously hard-coded in this function.
        cfg: Config providing ``INPUT.MIN_SIZE_TEST`` and
            ``INPUT.MAX_SIZE_TEST``. When omitted, the module-level ``cfg``
            is used, which is what this function always did implicitly.

    Returns:
        ``(instances, rgb_img)``: the ``"instances"`` entry of the first
        prediction moved to the CPU, and the channel-reversed image array
        that was fed to the resize transform.

    Raises:
        NameError: If ``cfg`` is not passed and no module-level ``cfg`` exists.
    """
    if cfg is None:
        # Backward compatibility: the original body read the global `cfg`.
        try:
            cfg = globals()["cfg"]
        except KeyError:
            raise NameError(
                "name 'cfg' is not defined; pass cfg=... explicitly"
            ) from None

    img = skio.imread(image_path)
    # The model is given the ORIGINAL size so it can report results at that scale.
    height, width = img.shape[:2]

    # Reverse the channel axis.
    # NOTE(review): the original comment said this turns BGR into RGB, but
    # skimage.io.imread normally yields RGB, so this likely produces BGR
    # despite the variable name. Confirm against cfg.INPUT.FORMAT before
    # changing anything; behaviour is intentionally left as it was.
    rgb_img = img[:, :, ::-1]
    print(rgb_img.dtype)

    aug = T.ResizeShortestEdge(
        [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST
    )
    resized = aug.get_transform(rgb_img).apply_image(rgb_img)
    # HWC -> CHW float32 tensor.
    image_tensor = torch.as_tensor(resized.astype("float32").transpose(2, 0, 1))
    inputs = {"image": image_tensor, "height": height, "width": width}

    with torch.no_grad():  # inference only, no autograd bookkeeping
        predictions = model([inputs])

    cpu_output = predictions[0]["instances"].to("cpu")
    return cpu_output, rgb_img

In [None]:
# Run your model function
# Build the inference config, load the model from it, then run the model on
# the sample image.
# `cfg` has to remain a module-level name: run_model_single_image looks it up
# globally when it is called with only the model.
cfg = load_and_edit_cfg_for_inference()
model = load_model_for_inference(cfg)
# `predictions` and `rgb_img` are consumed by the visualization cell below.
predictions, rgb_img = run_model_single_image(model)

In [None]:
# Draw the predictions from the previous cell on top of the input image.
from detectron2.utils.visualizer import ColorMode
from detectron2.data import MetadataCatalog
from detectron2.utils.visualizer import Visualizer
import matplotlib.pyplot as plt
# Metadata is looked up under the name "test".
metadata = MetadataCatalog.get("test")
# NOTE(review): rgb_img is the channel-reversed array returned by
# run_model_single_image -- confirm its channel order is what Visualizer expects.
vis_p = Visualizer(rgb_img, metadata, instance_mode=ColorMode.SEGMENTATION)

# `predictions` was already moved to the CPU inside run_model_single_image,
# so it can be drawn directly; get_image() yields the rendered overlay.
vis_pred_im = vis_p.draw_instance_predictions(predictions).get_image()

def show_im(image, ax, taskID):
    """Draw ``image`` on ``ax`` with the axes hidden and save the figure.

    Args:
        image: Image data passed to ``ax.imshow``.
        ax: Axes to draw on.
        taskID: Output target passed to ``savefig`` (used as the file name).

    Returns:
        The same ``ax`` that was passed in.
    """
    ax.axis('off')
    ax.imshow(image)
    # Save the figure that owns `ax`. pyplot's implicit "current figure" can
    # be a different one if other figures were created in the meantime.
    ax.figure.savefig(taskID)
    return ax

# Global plot styling for the prediction figure.
plt.rcParams['font.size'] = 20
plt.rcParams['axes.linewidth'] = 2
# Matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8*" and
# later releases removed the old names, so use whichever name this
# installation provides.
# NOTE(review): style.use() runs after the rcParams assignments above and may
# override them -- confirm the intended order.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
fig,ax = plt.subplots(figsize=(10,10))
# Render the overlay and write it to a file named "taskid".
show_im(vis_pred_im,ax, "taskid")

In [None]:
# Environment check: print whether torch reports CUDA as available, together
# with the CUDA_HOME value exposed by torch.utils.cpp_extension.
import torch
from torch.utils.cpp_extension import CUDA_HOME
print(torch.cuda.is_available(), CUDA_HOME)

In [None]:
def read_and_rescale_tif(img_tile, clamp_low=262, clamp_high=1775):
    """Read an image tile, stretch its intensities and convert it to 8-bit.

    The default clamp values are, per the original author's note, the
    Landsat 5 ARD .25 and 97.75 percentile range for Nebraska.

    Args:
        img_tile: Path of the image to read with ``skimage.io.imread``.
        clamp_low: Lower bound of the ``in_range`` used for rescaling.
        clamp_high: Upper bound of the ``in_range`` used for rescaling.

    Returns:
        The rescaled image converted with ``img_as_ubyte``.
    """
    raw = skio.imread(img_tile)
    stretched = exposure.rescale_intensity(raw, in_range=(clamp_low, clamp_high))
    # Warnings raised during the 8-bit conversion are deliberately silenced.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        return img_as_ubyte(stretched)

def save_rescaled_jpeg(old_tif_path, jpeg_dir, rescaled_arr):
    """Write ``rescaled_arr`` as a JPEG named after the source TIFF.

    The output file is ``<jpeg_dir>/<name>.jpg``, where ``<name>`` is the
    base name of ``old_tif_path`` up to its first ``".tif"``.

    Args:
        old_tif_path: Path of the source TIFF; only its base name is used.
        jpeg_dir: Directory to write the JPEG into.
        rescaled_arr: Image array accepted by ``PIL.Image.fromarray``.
    """
    img_pil = pilimg.fromarray(rescaled_arr)
    fid = os.path.basename(old_tif_path).split(".tif")[0]
    jpeg_path = os.path.join(jpeg_dir, fid + ".jpg")
    # Export chip images.
    # JPEG data is binary, so the file must be opened in 'wb' mode; text mode
    # makes PIL's write of bytes fail. The earlier Path(...) wrapper was
    # dropped because Path is not imported and open() accepts a plain string.
    with open(jpeg_path, 'wb') as dst:
        img_pil.save(dst, format='JPEG', subsampling=0, quality=100)

In [None]:
# Show the installed detectron2 version (bare expression, echoed as the
# cell's output when run in a notebook).
import detectron2
detectron2.__version__

In [None]:
# Show the installed torch version (bare expression, echoed as the cell's
# output when run in a notebook).
import torch
torch.__version__

In [None]:
# Show the installed torchvision version (bare expression, echoed as the
# cell's output when run in a notebook).
import torchvision
torchvision.__version__