In [2]:
%load_ext autoreload
%autoreload 2
#%matplotlib notebook

In [3]:
import sys
sys.path.append("..")

In [11]:
import glob
from pathlib import Path
import csv
from matplotlib import pyplot as plt
#plt.ioff()
import os
import odeon
from odeon.data.data_module import Input
from odeon.models.change.arch.changeformer.ChangeFormer import ChangeFormerV6
import torch
from PIL import Image
import torchvision.transforms.functional as TF
import numpy as np


In [49]:

#from models.networks import *
#from misc.metric_tool import ConfuseMatrixMeter
#from misc.logger_tool import Logger
#from utils import de_norm
#import utils


# Decide which device we want to run on
# torch.cuda.current_device()

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def de_norm(tensor_data):
    return tensor_data * 0.5 + 0.5


def get_confuse_matrix(num_classes, label_gts, label_preds):
    """计算一组预测的混淆矩阵"""
    def __fast_hist(label_gt, label_pred):
        """
        Collect values for Confusion Matrix
        For reference, please see: https://en.wikipedia.org/wiki/Confusion_matrix
        :param label_gt: <np.array> ground-truth
        :param label_pred: <np.array> prediction
        :return: <np.ndarray> values for confusion matrix
        """
        mask = (label_gt >= 0) & (label_gt < num_classes)
        hist = np.bincount(num_classes * label_gt[mask].astype(int) + label_pred[mask],
                           minlength=num_classes**2).reshape(num_classes, num_classes)
        return hist
    confusion_matrix = np.zeros((num_classes, num_classes))
    for lt, lp in zip(label_gts, label_preds):
        confusion_matrix += __fast_hist(lt.flatten(), lp.flatten())
    return confusion_matrix

def cm2F1(confusion_matrix):
    hist = confusion_matrix
    n_class = hist.shape[0]
    tp = np.diag(hist)
    sum_a1 = hist.sum(axis=1)
    sum_a0 = hist.sum(axis=0)
    # ---------------------------------------------------------------------- #
    # 1. Accuracy & Class Accuracy
    # ---------------------------------------------------------------------- #
    acc = tp.sum() / (hist.sum() + np.finfo(np.float32).eps)

    # recall
    recall = tp / (sum_a1 + np.finfo(np.float32).eps)
    # acc_cls = np.nanmean(recall)

    # precision
    precision = tp / (sum_a0 + np.finfo(np.float32).eps)

    # F1 score
    F1 = 2 * recall * precision / (recall + precision + np.finfo(np.float32).eps)
    mean_F1 = np.nanmean(F1)
    return mean_F1


###################       metrics      ###################
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.initialized = False
        self.val = None
        self.avg = None
        self.sum = None
        self.count = None

    def initialize(self, val, weight):
        self.val = val
        self.avg = val
        self.sum = val * weight
        self.count = weight
        self.initialized = True

    def update(self, val, weight=1):
        if not self.initialized:
            self.initialize(val, weight)
        else:
            self.add(val, weight)

    def add(self, val, weight):
        self.val = val
        self.sum += val * weight
        self.count += weight
        self.avg = self.sum / self.count

    def value(self):
        return self.val

    def average(self):
        return self.avg

    def get_scores(self):
        scores_dict = cm2score(self.sum)
        return scores_dict

    def clear(self):
        self.initialized = False

###################      cm metrics      ###################
class ConfuseMatrixMeter(AverageMeter):
    """Computes and stores the average and current value"""
    def __init__(self, n_class):
        super(ConfuseMatrixMeter, self).__init__()
        self.n_class = n_class

    def update_cm(self, pr, gt, weight=1):
        """获得当前混淆矩阵，并计算当前F1得分，并更新混淆矩阵"""
        val = get_confuse_matrix(num_classes=self.n_class, label_gts=gt, label_preds=pr)
        self.update(val, weight)
        current_score = cm2F1(val)
        return current_score

    def get_scores(self):
        scores_dict = cm2score(self.sum)
        return scores_dict




In [50]:

class CDEvaluator():

    def __init__(self, dataloader, checkpoint_dir, n_class = 2, embed_dim = 256, gpu_ids = []):

        self.dataloader = dataloader

        self.n_class = n_class
        # define G
        #self.net_G = define_G(args=args, gpu_ids=args.gpu_ids)
        self.net_G = ChangeFormerV6(embed_dim=embed_dim)
        self.device = torch.device("cuda:%s" % gpu_ids[0] if torch.cuda.is_available() and len(gpu_ids)>0
                                   else "cpu")
        print(self.device)

        # define some other vars to record the training states
        self.running_metric = ConfuseMatrixMeter(n_class=self.n_class)

        # define logger file
        #logger_path = os.path.join(args.checkpoint_dir, 'log_test.txt')
        #self.logger = Logger(logger_path)
        #self.logger.write_dict_str(args.__dict__)


        #  training log
        self.epoch_acc = 0
        self.best_val_acc = 0.0
        self.best_epoch_id = 0

        self.steps_per_epoch = len(dataloader)

        self.G_pred = None
        self.pred_vis = None
        self.batch = None
        self.is_training = False
        self.batch_id = 0
        self.epoch_id = 0
        self.checkpoint_dir = checkpoint_dir
#         self.vis_dir = args.vis_dir

        # check and create model dir
        if os.path.exists(self.checkpoint_dir) is False:
            os.mkdir(self.checkpoint_dir)
#         if os.path.exists(self.vis_dir) is False:
#             os.mkdir(self.vis_dir)


    def _load_checkpoint(self, checkpoint_name='best_ckpt.pt'):

        if os.path.exists(os.path.join(self.checkpoint_dir, checkpoint_name)):
            print('loading last checkpoint...\n')
            # load the entire checkpoint
            checkpoint = torch.load(os.path.join(self.checkpoint_dir, checkpoint_name), map_location=self.device)

            self.net_G.load_state_dict(checkpoint['model_G_state_dict'])

            self.net_G.to(self.device)

            # update some other states
            self.best_val_acc = checkpoint['best_val_acc']
            self.best_epoch_id = checkpoint['best_epoch_id']

            print('Eval Historical_best_acc = %.4f (at epoch %d)\n' %
                  (self.best_val_acc, self.best_epoch_id))
            print('\n')

        else:
            raise FileNotFoundError('no such checkpoint %s' % checkpoint_name)


    def _visualize_pred(self):
        pred = torch.argmax(self.G_pred, dim=1, keepdim=True)
        pred_vis = pred * 255
        return pred_vis


    def _update_metric(self):
        """
        update metric
        """
        target = self.batch['mask'].to(self.device).detach()
        G_pred = self.G_pred.detach()
        G_pred = torch.argmax(G_pred, dim=1)

        current_score = self.running_metric.update_cm(pr=G_pred.cpu().numpy(), gt=target.cpu().numpy())
        return current_score

    def _collect_running_batch_states(self):

        running_acc = self._update_metric()

#         m = len(self.dataloader)

#         if np.mod(self.batch_id, 100) == 1:
#             message = 'Is_training: %s. [%d,%d],  running_mf1: %.5f\n' %\
#                       (self.is_training, self.batch_id, m, running_acc)
#             self.logger.write(message)

#         if np.mod(self.batch_id, 100) == 1:
#             vis_input = utils.make_numpy_grid(de_norm(self.batch['A']))
#             vis_input2 = utils.make_numpy_grid(de_norm(self.batch['B']))

#             vis_pred = utils.make_numpy_grid(self._visualize_pred())

#             vis_gt = utils.make_numpy_grid(self.batch['L'])
#             vis = np.concatenate([vis_input, vis_input2, vis_pred, vis_gt], axis=0)
#             vis = np.clip(vis, a_min=0.0, a_max=1.0)
#             file_name = os.path.join(
#                 self.vis_dir, 'eval_' + str(self.batch_id)+'.jpg')
#             plt.imsave(file_name, vis)


    def _collect_epoch_states(self):

        scores_dict = self.running_metric.get_scores()

        np.save(os.path.join(self.checkpoint_dir, 'scores_dict.npy'), scores_dict)

        self.epoch_acc = scores_dict['mf1']

        with open(os.path.join(self.checkpoint_dir, '%s.txt' % (self.epoch_acc)),
                  mode='a') as file:
            pass

        message = ''
        for k, v in scores_dict.items():
            message += '%s: %.5f ' % (k, v)
        self.logger.write('%s\n' % message)  # save the message

        self.logger.write('\n')

    def _clear_cache(self):
        self.running_metric.clear()

    def _forward_pass(self, batch):
        self.batch = batch
        img_in1 = batch['T0'].to(self.device)
        img_in2 = batch['T1'].to(self.device)
        x = torch.stack(tensors=(img_in1, img_in2), dim=1)
        self.G_pred = self.net_G(x)

    def eval_models(self,checkpoint_name='best_ckpt.pt'):

        self._load_checkpoint(checkpoint_name)

        ################## Eval ##################
        ##########################################
        print('Begin evaluation...\n')
        self._clear_cache()
        self.is_training = False
        self.net_G.eval()

        # Iterate over data.
        for self.batch_id, batch in enumerate(self.dataloader, 0):
            with torch.no_grad():
                self._forward_pass(batch)
            self._collect_running_batch_states()
        self._collect_epoch_states()

In [28]:
batch_size = 10
num_workers = 8

cur_levir_path = Path("/mnt/store_dai/datasrc/dchan/levir-cd-repatched-256x256")

train_params = {
    'input_fields': {
        "T0": {"name": "A", "type": "raster", "dtype": "uint8", "band_indices": [1, 2, 3]},
        "T1": {"name": "B", "type": "raster", "dtype": "uint8", "band_indices": [1, 2, 3]},
        "mask": {"name": "label", "type": "mask", "encoding": "integer"}
    },
    'dataloader_options' : {"batch_size": batch_size, "num_workers": num_workers},
    'input_file': cur_levir_path / "train.csv",
    'input_files_has_header': 'infer',
    'root_dir': cur_levir_path,
    'nb_samples': 0,
    'sample_seed': 0
}

val_params = {
    'input_fields': {
        "T0": {"name": "A", "type": "raster", "dtype": "uint8", "band_indices": [1, 2, 3]},
        "T1": {"name": "B", "type": "raster", "dtype": "uint8", "band_indices": [1, 2, 3]},
        "mask": {"name": "label", "type": "mask", "encoding": "integer"}
    },
    'dataloader_options' : {"batch_size": batch_size, "num_workers": num_workers},
    'input_file': cur_levir_path / "val.csv",
    'input_files_has_header': 'infer',
    'root_dir': cur_levir_path,
    'nb_samples': 0,
    'sample_seed': 0
}

test_params = {
    'input_fields': {
        "T0": {"name": "A", "type": "raster", "dtype": "uint8", "band_indices": [1, 2, 3]},
        "T1": {"name": "B", "type": "raster", "dtype": "uint8", "band_indices": [1, 2, 3]},
        "mask": {"name": "label", "type": "mask", "encoding": "integer"}
    },
    'dataloader_options' : {"batch_size": batch_size, "num_workers": num_workers},
    'input_file': cur_levir_path / "test.csv",
    'root_dir': cur_levir_path,
    'input_files_has_header': 'infer',
    'nb_samples': 0,
    'sample_seed': 0
}


input = Input(
    fit_params=train_params,
    validate_params=val_params,
    test_params=test_params
)


In [30]:
input.prepare_data()
input.setup(stage="fit")

In [51]:
model = CDEvaluator(
    dataloader=input.validate.dataloader,
    checkpoint_dir="/mnt/store_dai/equipiers/pvoitot/dchan/checkpoints/ChangeFormerV6_LEVIR_ckpt"
)


cpu


In [52]:
model.eval_models()

loading last checkpoint...

Eval Historical_best_acc = 0.9495 (at epoch 176)



Begin evaluation...



  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)
  dataset = DatasetReader(path, driver=driver, sharing=sharing, **kwargs)


NameError: name 'cm2score' is not defined