In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
import segmentation_models_pytorch as smp

from neural_compressor.utils.pytorch import load
from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion, AccuracyCriterion
from neural_compressor.quantization import fit

from helper_datasets import *
from helper_model import *


### Config

In [None]:
# CONFIG
# Stage toggles: run the post-training quantization search (writes QUANT_MODEL_DIR)
# and/or evaluate the quantized model previously saved there.
FIT_QUANTIZE = False
EVAL_QUANTIZE = True

# Trained model, restored with ResModel.load_from_checkpoint further down.
CHECKPOINT_PATH = './models/best/l8_sr_v21.ckpt'

# Input data root (contains img_dir/<split> and ann_dir/<split>) and the
# normalisation statistics handed to the datasets as `mean_std`.
# Alternative (ls7) statistics file: './data/mean_stds/mean_std_ls7_v9.npy'
DATA_DIR = './data/landsat8_v9_sr/'
MEAN_STD_PATH = './data/mean_stds/mean_std_ls8_v9.npy'

# Output .npy files: original-model predictions, quantized-model predictions,
# and the ground-truth masks they are scored against.
OUT_NPY_OG = './data/preds/ls8_2017_preds_val_og.npy'
OUT_NPY_QUANT = './data/preds/ls8_2017_preds_val_quant.npy'
OUT_NPY_MASKS = './data/preds/val_masks.npy'

# Directory written by q_model.save(); 'best_model.pt' inside it is loaded back.
QUANT_MODEL_DIR = './models/best/quantized_model_l8/'

### Run Prep

In [None]:
# Normalisation statistics, passed to Dataset / DatasetImageOnly as `mean_std`.
mean_std = np.load(file=MEAN_STD_PATH)

In [None]:

# Which split to evaluate: 'val', 'test' or 'train'. Images live under
# img_dir/<split> and annotations under ann_dir/<split>. Change this value
# instead of toggling commented-out path lines.
EVAL_SPLIT = 'val'
x_valid_dir = os.path.join(DATA_DIR, f'img_dir/{EVAL_SPLIT}')
y_valid_dir = os.path.join(DATA_DIR, f'ann_dir/{EVAL_SPLIT}')


In [None]:
# helper function for data visualization
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword argument becomes one panel. The keyword name, with
    underscores turned into spaces and title-cased, is used as the panel
    title. The maximum value of every image is printed as it is drawn.
    """
    panel_count = len(images)
    plt.figure(figsize=(16, 5))
    for panel_idx, (label, img) in enumerate(images.items(), start=1):
        plt.subplot(1, panel_count, panel_idx)
        # Hide axis ticks; only the image content matters here.
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(img)
        print(img.max())
    plt.show()

In [None]:
# Let's look at the data we have: one validation sample, showing a single
# channel (index -3 of the image's last axis) next to its water mask.
dataset = Dataset(x_valid_dir, y_valid_dir, classes=['Water'], mean_std=mean_std)

batch = dataset[5]
visualize(image=batch['image'][:, :, [-3]], water_mask=batch['mask'].squeeze())

# Evaluate and save predictions: original (unquantized) model

In [None]:
# Restore the trained (unquantized) model from CHECKPOINT_PATH.
# NOTE(review): arch='' is passed explicitly; confirm ResModel ignores or expects it.
model = ResModel.load_from_checkpoint(
    CHECKPOINT_PATH,
    in_channels=6,
    out_classes=1,
    arch='',
    encoder_name='resnet34',
)

In [None]:
CLASSES = ['Water']

# Two views of the same validation split, built with identical arguments.
# Dataset is built first, then DatasetImageOnly, each with its own
# get_preprocessing() instance. DatasetImageOnly presumably omits masks
# (verify in helper_datasets).
valid_dataset, valid_dataset_imageonly = (
    dataset_cls(
        x_valid_dir,
        y_valid_dir,
        preprocessing=get_preprocessing(),
        classes=CLASSES,
        mean_std=mean_std,
    )
    for dataset_cls in (Dataset, DatasetImageOnly)
)

# shuffle=False keeps prediction order aligned with dataset index order.
valid_loader_withmasks = DataLoader(dataset=valid_dataset, batch_size=4, shuffle=False, num_workers=2)
# Image-only loader; used as the calibration loader for quantization.
valid_loader = DataLoader(dataset=valid_dataset_imageonly, batch_size=4, shuffle=False, num_workers=2)
trainer = pl.Trainer()

In [None]:
# Predict with the original model, stack all batches, keep output channel 0.
preds = np.vstack(
    trainer.predict(model, valid_loader_withmasks)
)[:, 0, :, :]
np.save(file=OUT_NPY_OG, arr=preds)

In [None]:
# Ground-truth masks in dataset index order. The prediction loader uses
# shuffle=False, so these line up with `preds`.
masks = np.vstack(list(map(lambda idx: valid_dataset[idx]['mask'], range(len(valid_dataset)))))
np.save(file=OUT_NPY_MASKS, arr=masks)

In [None]:
# Decision threshold: a pixel counts as water where the prediction exceeds this value.
cutoff = 35e-3

In [None]:
def compute_stats(tp, fp, fn, tn):
    """Summarise confusion-matrix statistics as micro-averaged scores.

    Args:
        tp, fp, fn, tn: true/false positive/negative counts, as returned by
            ``smp.metrics.get_stats``.

    Returns:
        ``np.array([iou, f1, precision, recall])``, each computed with
        ``reduction="micro"``.
    """
    metric_fns = (
        smp.metrics.iou_score,
        smp.metrics.f1_score,
        smp.metrics.precision,
        smp.metrics.recall,
    )
    return np.array([metric(tp, fp, fn, tn, reduction="micro") for metric in metric_fns])

In [None]:
# Binarise the original-model predictions at `cutoff` and score them against
# the full-size masks. The displayed result is [iou, f1, precision, recall].
preds_binary = torch.Tensor(preds > cutoff).long()
tp, fp, fn, tn = smp.metrics.get_stats(
    preds_binary,
    torch.Tensor(masks).long(),
    mode="binary",
)
compute_stats(tp, fp, fn, tn)

In [None]:
# Scratch check: predictions scaled by 254 and rounded (the byte encoding used below).
torch.Tensor(preds * 254).round()

In [None]:
# Re-score after simulating uint8 storage of the predictions (scaled by 254,
# rounded), keeping only rows/cols 10:490.
# The SAME crop must be applied to predictions and masks: get_stats requires
# matching shapes, and the cell above scores uncropped `preds` against
# uncropped `masks`, so cropping only `masks` produces a shape mismatch.
# NOTE(review): astype(np.uint8) assumes preds lie in [0, 1]; values outside
# that range do not saturate. Confirm the model outputs probabilities.
BORDER_CROP = np.s_[:, 10:490, 10:490]
preds_byte = np.round(preds * 254).astype(np.uint8)
preds_binary = torch.Tensor(preds_byte[BORDER_CROP] > (cutoff * 254)).long()
tp, fp, fn, tn = smp.metrics.get_stats(preds_binary,
                                        torch.Tensor(masks[BORDER_CROP]).long(),
                                        mode="binary")
compute_stats(tp, fp, fn, tn)

### Fit Quantized model

In [None]:
def eval_func_for_nc(model_n, trainer_n, threshold=0.5):
    """Score a candidate network for Intel Neural Compressor's tuning loop.

    The candidate is temporarily installed as ``model.model`` on the global
    wrapper so that ``trainer_n.predict`` runs it over
    ``valid_loader_withmasks``. The original network is restored afterwards,
    even if prediction raises.

    Args:
        model_n: network to evaluate (the baseline or a quantized candidate).
        trainer_n: trainer used to run prediction.
        threshold: value the predictions (channel 0) must exceed to count as
            positive. Defaults to 0.5, which differs from the ``cutoff`` used
            in the analysis cells.

    Returns:
        Micro-averaged IoU against the validation masks, as a Python float.
    """
    original_net = model.model
    setattr(model, "model", model_n)
    try:
        preds = np.vstack(trainer_n.predict(model, valid_loader_withmasks))[:, 0, :, :]
    finally:
        # Without this, the tuning loop leaves whichever candidate it evaluated
        # last on the shared `model`. The quantized-eval cell later passes
        # `model.model` to `load` as the model to restore the quantized
        # weights into.
        setattr(model, "model", original_net)
    masks_n = np.vstack([valid_dataset[i]['mask'] for i in range(len(valid_dataset))])
    tp, fp, fn, tn = smp.metrics.get_stats(torch.Tensor(preds > threshold).long(),
                                           torch.Tensor(masks_n).long(),
                                           mode="binary")
    return float(smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro"))


def eval_func(model):
    """Single-argument adapter used as ``eval_func`` by ``neural_compressor.quantization.fit``."""
    return eval_func_for_nc(model, trainer)

# Post-training quantization search settings. The search stops after
# `max_trials` and accepts candidates whose eval_func score stays within
# `tolerable_loss` of the baseline.
accuracy_criterion = AccuracyCriterion(tolerable_loss=0.01)
tuning_criterion = TuningCriterion(max_trials=5)
conf = PostTrainingQuantConfig(
    approach="auto",
    backend="default",
    tuning_criterion=tuning_criterion,
    accuracy_criterion=accuracy_criterion,
)

if FIT_QUANTIZE:
    # Calibrate on the image-only loader; rank candidates with eval_func.
    q_model = fit(
        model=model.model,
        conf=conf,
        calib_dataloader=valid_loader,
        eval_func=eval_func,
    )
    # The evaluation cell reads 'best_model.pt' back from this directory.
    q_model.save(QUANT_MODEL_DIR)

### Eval and save: Quantized

In [None]:
if EVAL_QUANTIZE:
    quant_ckpt = os.path.join(QUANT_MODEL_DIR, 'best_model.pt')
    # NOTE(review): `load` receives the current `model.model` as the model to
    # restore quantized weights into. Re-running this cell after it has already
    # swapped in the quantized network passes that network instead, so re-run
    # the checkpoint-loading cell first.
    model.model = load(quant_ckpt, model.model)
    # Overwrites `preds`, which previously held the original model's predictions.
    preds = np.vstack(
        trainer.predict(model, valid_loader_withmasks)
    )[:, 0, :, :]
    np.save(OUT_NPY_QUANT, preds)