<a href="https://colab.research.google.com/github/fregean/Tensorflow-ver-mixed-segdec-net/blob/main/Tensorflow_version.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Tensorflowへ変換

KSDD

画像データ：1408*512

In [1]:
import os

os.chdir('/content/drive/MyDrive/14_ブラザー工業様/tensorflow_ver_mixed-segdec-net')


## models.py

In [2]:
# @title models.py

import tensorflow as tf
from tensorflow.keras import layers, Model
import math

BATCHNORM_TRACK_RUNNING_STATS = False
BATCHNORM_MOVING_AVERAGE_DECAY = 0.9997

class Conv2D_init(layers.Conv2D):
    def __init__(self, filters, kernel_size, **kwargs):
        super(Conv2D_init, self).__init__(filters, kernel_size, **kwargs)
        self.kernel_initializer = tf.keras.initializers.RandomNormal(mean=0., stddev=0.01)
        self.bias_initializer = 'zeros'

class FeatureNorm(layers.Layer):
    # PyTorch版のfeature_index=1をfeature_index=-1に変更（channel first→channel last）
    def __init__(self, num_features, feature_index=-1, rank=4, reduce_dims=(2, 3), eps=0.001, include_bias=True):
        super(FeatureNorm, self).__init__()
        self.shape = [1] * rank
        self.shape[feature_index] = num_features
        self.reduce_dims = reduce_dims
        self.scale = self.add_weight(name='scale', shape=self.shape, initializer='ones', trainable=True)
        self.eps = eps

        if include_bias:
            self.bias = self.add_weight(name='bias', shape=self.shape, initializer='zeros', trainable=True)
        else:
            self.bias = 0

    def call(self, features):
        f_std = tf.math.reduce_std(features, axis=self.reduce_dims, keepdims=True)
        f_mean = tf.math.reduce_mean(features, axis=self.reduce_dims, keepdims=True)
        return self.scale * ((features - f_mean) / tf.sqrt(f_std + self.eps)) + self.bias

def _conv_block(out_channels, kernel_size, padding):
    return tf.keras.Sequential([
        Conv2D_init(filters=out_channels, kernel_size=kernel_size, padding=padding, use_bias=False),
        FeatureNorm(num_features=out_channels, eps=0.001),
        layers.ReLU()
    ])

class SegDecNet(Model):
    def __init__(self, device, input_width, input_height, input_channels):
        super(SegDecNet, self).__init__()
        if input_width % 8 != 0 or input_height % 8 != 0:
            raise Exception(f"Input size must be divisible by 8! width={input_width}, height={input_height}")
        self.input_width = input_width
        self.input_height = input_height
        self.input_channels = input_channels
        self.volume = tf.keras.Sequential([
            _conv_block(32, 5, 'same'), 
            layers.MaxPool2D(2),
            _conv_block(64, 5, 'same'),
            _conv_block(64, 5, 'same'),
            _conv_block(64, 5, 'same'),
            layers.MaxPool2D(2),
            _conv_block(64, 5, 'same'),
            _conv_block(64, 5, 'same'),
            _conv_block(64, 5, 'same'),
            _conv_block(64, 5, 'same'),
            layers.MaxPool2D(2),
            _conv_block(1024, 15, 'same')
        ])

        self.seg_mask = tf.keras.Sequential([
            Conv2D_init(filters=1, kernel_size=1, use_bias=False),
            FeatureNorm(num_features=1, eps=0.001)
        ])

        self.extractor = tf.keras.Sequential([
            layers.MaxPool2D(pool_size=2),
            _conv_block(8, 5, padding='same'),
            layers.MaxPool2D(pool_size=2),
            _conv_block(16, 5, padding='same'),
            layers.MaxPool2D(pool_size=2),
            _conv_block(32, 5, padding='same')
        ])

        self.global_max_pool_feat = layers.GlobalMaxPooling2D(keepdims=True)
        self.global_avg_pool_feat = layers.GlobalAveragePooling2D(keepdims=True)
        self.global_max_pool_seg = layers.MaxPooling2D(pool_size=(self.input_height // 8, self.input_width // 8))
        self.global_avg_pool_seg = layers.AveragePooling2D(pool_size=(self.input_height // 8, self.input_width // 8))

        self.fc = layers.Dense(units=1)

        self.volume_lr_multiplier_layer = GradientMultiplyLayer()
        self.glob_max_lr_multiplier_layer = GradientMultiplyLayer()
        self.glob_avg_lr_multiplier_layer = GradientMultiplyLayer()

    def set_gradient_multipliers(self, multiplier):
        self.volume_lr_multiplier_mask = tf.ones((1,)) * multiplier
        self.glob_max_lr_multiplier_mask = tf.ones((1,)) * multiplier
        self.glob_avg_lr_multiplier_mask = tf.ones((1,)) * multiplier

    def call(self, input, training=False):
        volume = self.volume(input, training=training)
        seg_mask = self.seg_mask(volume, training=training)

        # 表示用
        self.volume_ = volume
        self.seg_mask_ = seg_mask

        cat = tf.concat([volume, seg_mask], axis=-1)

        cat = self.volume_lr_multiplier_layer(cat, self.volume_lr_multiplier_mask)

        features = self.extractor(cat, training=training)
        global_max_feat = self.global_max_pool_feat(features)
        global_avg_feat = self.global_avg_pool_feat(features)
        global_max_seg = self.global_max_pool_seg(seg_mask)
        global_avg_seg = self.global_avg_pool_seg(seg_mask)

        global_max_seg = self.glob_max_lr_multiplier_layer(global_max_seg, self.glob_max_lr_multiplier_mask)
        global_avg_seg = self.glob_avg_lr_multiplier_layer(global_avg_seg, self.glob_avg_lr_multiplier_mask)
        

        fc_in = tf.concat([global_max_feat, global_avg_feat, global_max_seg, global_avg_seg], axis=-1)
        fc_in = tf.reshape(fc_in, [fc_in.shape[0], -1])
        prediction = self.fc(fc_in)

        return prediction, seg_mask

class GradientMultiplyLayer(layers.Layer):
    def call(self, input, mask_bw):
        return input * mask_bw


## End2End

In [3]:
# @title end2end.py 

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
#from models import SegDecNet
import numpy as np
import os
import tensorflow as tf
import utils
import time
import datetime
import pandas as pd
from data.dataset_catalog import get_dataset
import random
import cv2
from config import Config
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.keras import backend as K

LVL_ERROR = 10
LVL_INFO = 5
LVL_DEBUG = 1

LOG = 1  # Will log all mesages with lvl greater than this
SAVE_LOG = True

WRITE_TENSORBOARD = False

class End2End:
    def __init__(self, cfg: Config):
        self.cfg: Config = cfg
        self.storage_path: str = os.path.join(self.cfg.RESULTS_PATH, self.cfg.DATASET)

    def _log(self, message, lvl=LVL_INFO):
        n_msg = f"{self.run_name} {message}"
        if lvl >= LOG:
            print(n_msg)

    def train(self):
        self._set_results_path()
        self._create_results_dirs()
        self.print_run_params()
        if self.cfg.REPRODUCIBLE_RUN:
            self._log("Reproducible run, fixing all seeds to:1337", LVL_DEBUG)
            np.random.seed(1337)
            tf.random.set_seed(1337)
            random.seed(1337)

        device = self._get_device()
        model = self._get_model()
        self.optimizer = self._get_optimizer(model)
        loss_seg, loss_dec = self._get_loss(True), self._get_loss(False)

        train_loader = get_dataset("TRAIN", self.cfg)
        validation_loader = get_dataset("VAL", self.cfg)

        # get_dataset関数でKSDDDatasetのインスタンスを取得
        # train_dataset = get_dataset("TRAIN", self.cfg)
        # validation_dataset = get_dataset("VAL", self.cfg)

        # tf.data.Datasetのインスタンスを作成
        # train_loader = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
        # validation_loader = validation_dataset.prefetch(tf.data.experimental.AUTOTUNE)

        tensorboard_writer = summary_ops_v2.create_file_writer(self.tensorboard_path) if WRITE_TENSORBOARD else None

        train_results = self._train_model(device, model, train_loader, loss_seg, loss_dec, self.optimizer, validation_loader, tensorboard_writer)
        self._save_train_results(train_results)
        self._save_model(model, os.path.join(self.model_path, "final_state.tf"))

        self.eval(model, device, self.cfg.SAVE_IMAGES, False, False)

        self._save_params()

    def eval(self, model, save_images, plot_seg, reload_final):
        if reload_final:
            model.load_weights(os.path.join(self.model_path, "final_state_dict.tf"))
        test_loader = get_dataset("TEST", self.cfg)
        self.eval_model(model, test_loader, save_folder=self.outputs_path, save_images=save_images, is_validation=False, plot_seg=plot_seg)
    
    def training_iteration(self, data, model, criterion_seg, criterion_dec, optimizer, weight_loss_seg, weight_loss_dec, iter_index):
        images, seg_masks, seg_loss_masks, is_segmented, _, _, _ = data

        batch_size = self.cfg.BATCH_SIZE
        memory_fit = self.cfg.MEMORY_FIT  # Not supported yet for >1

        num_subiters = int(batch_size / memory_fit)
        
        total_loss = 0
        total_correct = 0

        total_loss_seg = 0
        total_loss_dec = 0

        
        for sub_iter in range(num_subiters):
            with tf.GradientTape() as tape:    
                images_ = images[sub_iter * memory_fit:(sub_iter + 1) * memory_fit, :, :, :]
                seg_masks_ = seg_masks[sub_iter * memory_fit:(sub_iter + 1) * memory_fit, :, :, :]
                seg_loss_masks_ = seg_loss_masks[sub_iter * memory_fit:(sub_iter + 1) * memory_fit, :, :, :]
                is_pos_ = tf.reshape(tf.reduce_max(seg_masks_), (memory_fit, 1))

                decision, output_seg_mask = model(images_, training=True)

                if is_segmented[sub_iter]:
                    if self.cfg.WEIGHTED_SEG_LOSS:

                        loss_seg = tf.reduce_mean(criterion_seg(output_seg_mask, seg_masks_)[:,:,:,np.newaxis] * seg_loss_masks_)
                         
                    else:
                        loss_seg = criterion_seg(output_seg_mask, seg_masks_)[:,:,:,np.newaxis]
                    loss_dec = criterion_dec(decision, is_pos_)

                    total_loss_seg += loss_seg#.numpy().item()
                    total_loss_dec += loss_dec#.numpy().item()
                    
                    total_correct += tf.reduce_sum(tf.cast(tf.equal(decision.numpy() > 0.0, tf.cast(tf.cast(is_pos_.numpy(), dtype=np.int32), tf.bool)), dtype=np.float32))
                    loss = weight_loss_seg * loss_seg + weight_loss_dec * loss_dec
                else:
                    loss_dec = criterion_dec(decision, is_pos_)
                    total_loss_dec += loss_dec#.numpy().item()

                    total_correct += tf.reduce_sum(tf.cast(tf.equal(decision.numpy() > 0.0, tf.cast(tf.cast(is_pos_.numpy(), dtype=np.int32), tf.bool)), dtype=np.float32))
                    loss = weight_loss_dec * loss_dec

                total_loss += loss#.numpy().item()

            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))    
        

        return total_loss_seg, total_loss_dec, total_loss, total_correct


    def _train_model(self, device, model, train_loader, criterion_seg, criterion_dec, optimizer, validation_set, tensorboard_writer):
        losses = []
        validation_data = []
        max_validation = -1
        validation_step = self.cfg.VALIDATION_N_EPOCHS

        num_epochs = self.cfg.EPOCHS
        
        samples_per_epoch = tf.data.experimental.cardinality(train_loader).numpy() * self.cfg.BATCH_SIZE   

        self.set_dec_gradient_multiplier(model, 0.0)

        for epoch in range(num_epochs):

            self.print_epoch(epoch)

            weight_loss_seg, weight_loss_dec = self.get_loss_weights(epoch)

            dec_gradient_multiplier = self.get_dec_gradient_multiplier()
            self.set_dec_gradient_multiplier(model, dec_gradient_multiplier)
            
            epoch_loss_seg, epoch_loss_dec, epoch_loss = 0, 0, 0
            epoch_correct = 0

            for iter_index, (data) in enumerate(train_loader):

                curr_loss_seg, curr_loss_dec, curr_loss, correct = self.training_iteration(data, model,
                                                                                        criterion_seg,
                                                                                        criterion_dec,
                                                                                        optimizer, weight_loss_seg,
                                                                                        weight_loss_dec, iter_index)

                epoch_loss_seg += curr_loss_seg
                epoch_loss_dec += curr_loss_dec
                epoch_loss += curr_loss

                # loss表示
                self.print_loss(epoch_loss_seg, epoch_loss_dec)

                epoch_correct += correct

            if epoch % 5 == 0:

                self._save_model(model, os.path.join(self.model_path, f"ep_{epoch:02}.tf"))
    
            epoch_loss_seg = epoch_loss_seg / samples_per_epoch
            epoch_loss_dec = epoch_loss_dec / samples_per_epoch
            epoch_loss = epoch_loss / samples_per_epoch
            losses.append((epoch_loss_seg, epoch_loss_dec, epoch_loss, epoch))

            if self.cfg.VALIDATE and (epoch % validation_step == 0 or epoch == num_epochs - 1):
                validation_ap, validation_accuracy = self.eval_model(model, validation_set, None, False, True, False)
                validation_data.append((validation_ap, epoch))

                if validation_ap > max_validation:
                    max_validation = validation_ap
                    self._save_model(model, os.path.join(self.model_path, "best_state_dict.tf")) # .h5

        return losses, validation_data


    def eval_model(self, model, eval_loader, save_folder, save_images, is_validation, plot_seg):

        res = []
        predictions, ground_truths = [], []

        for data_point in eval_loader:

            image, seg_mask, seg_loss_mask, _, sample_name, _, _ = data_point
            is_pos = tf.reduce_max(seg_mask) > 0
            prediction, pred_seg = model(image, training=False)

            pred_seg = tf.nn.sigmoid(pred_seg)
            prediction = tf.nn.sigmoid(prediction)

            prediction = prediction.numpy().item()
            image = image.numpy()
            pred_seg = pred_seg.numpy()
            seg_mask = seg_mask.numpy()

            predictions.append(prediction)
            ground_truths.append(is_pos)
            res.append((prediction, None, None, is_pos, sample_name[0]))
            if not is_validation and save_images:
                # image saving code here
                plt.imsave(f"{save_folder}/{sample_name[0]}_prediction.png", prediction[0, :, :, 0], cmap='gray')
                plt.imsave(f"{save_folder}/{sample_name[0]}_image.png", image[0, :, :, 0], cmap='gray')
                plt.imsave(f"{save_folder}/{sample_name[0]}_pred_seg.png", pred_seg[0, :, :, 0], cmap='gray')
                plt.imsave(f"{save_folder}/{sample_name[0]}_seg_mask.png", seg_mask[0, :, :], cmap='gray')

        if is_validation:
            self.print_eval(ground_truths, predictions)
            metrics = utils.get_metrics(np.array(ground_truths), np.array(predictions))
            return metrics["AP"], metrics["accuracy"]
            
        else:
            utils.evaluate_metrics(res, self.run_path, self.run_name)


    def get_dec_gradient_multiplier(self):
        if self.cfg.GRADIENT_ADJUSTMENT:
            grad_m = 0
        else:
            grad_m = 1

        return grad_m

    def set_dec_gradient_multiplier(self, model, multiplier):
        # This function is not applicable in TensorFlow as it is specific to PyTorch.
        model.set_gradient_multipliers(multiplier)

    def get_loss_weights(self, epoch):
        total_epochs = float(self.cfg.EPOCHS)

        if self.cfg.DYN_BALANCED_LOSS:
            seg_loss_weight = 1 - (epoch / total_epochs)
            dec_loss_weight = self.cfg.DELTA_CLS_LOSS * (epoch / total_epochs)
        else:
            seg_loss_weight = 1
            dec_loss_weight = self.cfg.DELTA_CLS_LOSS

        return tf.constant(seg_loss_weight), tf.constant(dec_loss_weight)

    def reload_model(self, model, load_final=False):
        if self.cfg.USE_BEST_MODEL:
            path = os.path.join(self.model_path, "best_state_dict.tf")# .h5
            model.load_weights(path)
        elif load_final:
            path = os.path.join(self.model_path, "final_state_dict.tf")# .h5
            model.load_weights(path)

    def _save_params(self):
        params = self.cfg.get_as_dict()
        params_lines = sorted(map(lambda e: e[0] + ":" + str(e[1]) + "\n", params.items()))
        fname = os.path.join(self.run_path, "run_params.txt")
        with open(fname, "w+") as f:
            f.writelines(params_lines)

    def _save_train_results(self, results):
        losses, validation_data = results
        ls, ld, l, le = map(list, zip(*losses))
        plt.plot(le, l, label="Loss", color="red")
        plt.plot(le, ls, label="Loss seg")
        plt.plot(le, ld, label="Loss dec")
        plt.ylim(bottom=0)
        plt.grid()
        plt.xlabel("Epochs")
        if self.cfg.VALIDATE:
            v, ve = map(list, zip(*validation_data))
            plt.twinx()
            plt.plot(ve, v, label="Validation AP", color="Green")
            plt.ylim((0, 1))
        plt.legend()
        plt.savefig(os.path.join(self.run_path, "loss_val"), dpi=200)

        df_loss = pd.DataFrame(data={"loss_seg": ls, "loss_dec": ld, "loss": l, "epoch": le})
        df_loss.to_csv(os.path.join(self.run_path, "losses.csv"), index=False)

        if self.cfg.VALIDATE:
            df_loss = pd.DataFrame(data={"validation_data": v, "epoch": ve})
            df_loss.to_csv(os.path.join(self.run_path, "validation.csv"), index=False)

    def _save_model(self, model, output_name):
        print(output_name)
        if os.path.isfile(output_name):
            os.remove(output_name)

        model.save_weights(output_name)

    def _get_optimizer(self, model):
        return tf.keras.optimizers.SGD(learning_rate=self.cfg.LEARNING_RATE)

    def _get_loss(self, is_seg):
        reduction = tf.keras.losses.Reduction.NONE if self.cfg.WEIGHTED_SEG_LOSS and is_seg else tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
        return tf.keras.losses.BinaryCrossentropy(reduction=reduction)


    def _get_device(self):
        return tf.device(f"/GPU:{self.cfg.GPU}")

    def _set_results_path(self):
        self.run_name = f"{self.cfg.RUN_NAME}_FOLD_{self.cfg.FOLD}" if self.cfg.DATASET in ["KSDD", "DAGM"] else self.cfg.RUN_NAME

        results_path = os.path.join(self.cfg.RESULTS_PATH, self.cfg.DATASET)
        self.tensorboard_path = os.path.join(results_path, "tensorboard", self.run_name)

        run_path = os.path.join(results_path, self.cfg.RUN_NAME)
        if self.cfg.DATASET in ["KSDD", "DAGM"]:
            run_path = os.path.join(run_path, f"FOLD_{self.cfg.FOLD}")

        self._log(f"Executing run with path {run_path}")

        self.run_path = run_path
        self.model_path = os.path.join(run_path, "models")
        self.outputs_path = os.path.join(run_path, "test_outputs")

    def _create_results_dirs(self):
        list(map(utils.create_folder, [self.run_path, self.model_path, self.outputs_path, ]))

    def _get_model(self):
        seg_net = SegDecNet(self._get_device(), self.cfg.INPUT_WIDTH, self.cfg.INPUT_HEIGHT, self.cfg.INPUT_CHANNELS)
        return seg_net

    def print_run_params(self):
        for l in sorted(map(lambda e: e[0] + ":" + str(e[1]) + "\n", self.cfg.get_as_dict().items())):
            k, v = l.split(":")
            print(f"{k:25s} : {str(v.strip())}")

    def print_epoch(self, epoch):
        t_now = datetime.datetime.now().time()
        print(f'epoch:{epoch}>{t_now}')

    def print_eval(self, ground_truths, predictions):
        print(f'predictions:{predictions}')

    def print_loss(self, epoch_loss_seg, epoch_loss_dec):
        print(f'epoch_loss_seg:{epoch_loss_seg}, epoch_loss_dec:{epoch_loss_dec}')



以下はpyファイルの方を用いる

## train_net.py

In [4]:
# Colab用args
import argparse
from config import Config

class args():
    def __init__(self):
        self.GPU=0
        self.RUN_NAME='RUN_NAME'
        self.DATASET='KSDD'
        self.DATASET_PATH='/content/drive/MyDrive/14_ブラザー工業様/tensorflow_ver_mixed-segdec-net/datasets/KSDD'
        self.EPOCHS=50
        self.LEARNING_RATE=1.0
        self.DELTA_CLS_LOSS = 0.01
        self.BATCH_SIZE=1
        self.WEIGHTED_SEG_LOSS=True
        self.WEIGHTED_SEG_LOSS_P=2
        self.WEIGHTED_SEG_LOSS_MAX=1
        self.DYN_BALANCED_LOSS=True
        self.GRADIENT_ADJUSTMENT=False # Trueを変更
        self.FREQUENCY_SAMPLING=True
        self.FOLD=0
        self.TRAIN_NUM=33
        self.NUM_SEGMENTED=33
        self.RESULTS_PATH = "/content/drive/MyDrive/14_ブラザー工業様/tensorflow_ver_mixed-segdec-net/datasets/RESULTS"
        
        self.VALIDATE = True
        self.VALIDATE_ON_TEST = True
        self.VALIDATION_N_EPOCHS = 5
        self.USE_BEST_MODEL = False

        self.ON_DEMAND_READ = False
        self.REPRODUCIBLE_RUN = False
        self.MEMORY_FIT = 1
        self.SAVE_IMAGES = True
        self.DILATE = 7

args=args()
configuration = Config()
configuration.merge_from_args(args)
configuration.init_extra()

end2end = End2End(cfg=configuration)

# ここを実行する
end2end.train()

RUN_NAME_FOLD_0 Executing run with path /content/drive/MyDrive/14_ブラザー工業様/tensorflow_ver_mixed-segdec-net/datasets/RESULTS/KSDD/RUN_NAME/FOLD_0
BATCH_SIZE                : 1
DATASET                   : KSDD
DATASET_PATH              : /content/drive/MyDrive/14_ブラザー工業様/tensorflow_ver_mixed-segdec-net/datasets/KSDD
DELTA_CLS_LOSS            : 0.01
DILATE                    : 7
DYN_BALANCED_LOSS         : True
EPOCHS                    : 50
FOLD                      : 0
FREQUENCY_SAMPLING        : True
GPU                       : 0
GRADIENT_ADJUSTMENT       : False
INPUT_CHANNELS            : 1
INPUT_HEIGHT              : 1408
INPUT_WIDTH               : 512
LEARNING_RATE             : 1.0
MEMORY_FIT                : 1
NUM_SEGMENTED             : 33
ON_DEMAND_READ            : False
REPRODUCIBLE_RUN          : False
RESULTS_PATH              : /content/drive/MyDrive/14_ブラザー工業様/tensorflow_ver_mixed-segdec-net/datasets/RESULTS
SAVE_IMAGES               : True
TRAIN_NUM                



epoch_loss_seg:-1.6933138224928257e-10, epoch_loss_dec:-4.214300155639648
epoch_loss_seg:-237.92904663085938, epoch_loss_dec:37.19872283935547
epoch_loss_seg:-713.7871704101562, epoch_loss_dec:124.48521423339844
epoch_loss_seg:-1427.5743408203125, epoch_loss_dec:255.6753387451172
epoch_loss_seg:-2379.29052734375, epoch_loss_dec:430.4193115234375
epoch_loss_seg:-3568.935791015625, epoch_loss_dec:649.6659545898438
epoch_loss_seg:-4996.51025390625, epoch_loss_dec:913.7020263671875
epoch_loss_seg:-6571.693359375, epoch_loss_dec:621.6930541992188
epoch_loss_seg:-8462.240234375, epoch_loss_dec:829.6842651367188
epoch_loss_seg:-10590.716796875, epoch_loss_dec:1082.5855712890625
epoch_loss_seg:-12957.1220703125, epoch_loss_dec:1379.56787109375
epoch_loss_seg:-15561.45703125, epoch_loss_dec:1722.5341796875
epoch_loss_seg:-18253.78125, epoch_loss_dec:1352.767822265625
epoch_loss_seg:-21321.32421875, epoch_loss_dec:1721.66455078125
epoch_loss_seg:-24626.796875, epoch_loss_dec:2144.6728515625
epoc

AttributeError: ignored

In [None]:
import os
output_name='/content/drive/MyDrive/14_ブラザー工業様/tensorflow_ver_mixed-segdec-net/datasets/RESULTS/KSDD/RUN_NAME/FOLD_0/models/best_state_dict.h5'

os.remove(output_name)

In [None]:
def downsize(image: np.ndarray, downsize_factor: tuple = (8, 8)) -> np.ndarray:
        img_t = tf.convert_to_tensor(np.expand_dims(image, 0 if len(image.shape) == 3 else (0, 1)).astype(np.float32))
        pad_size = min(downsize_factor[0], img_t.shape[1] - 1, img_t.shape[2] - 1)
        print(img_t.shape)
        img_t = tf.pad(img_t, [[0, 0], [pad_size, pad_size], [pad_size, pad_size], [0, 0]], mode='REFLECT')
        print(img_t.shape)
        image_np = tf.nn.avg_pool(img_t, ksize=(2 * pad_size + 1, 2 * pad_size + 1), strides=downsize_factor, padding='VALID')
        print(image_np[0,0].shape)
        return image_np[0] if len(image.shape) == 3 else image_np[0, 0]

def to_tensor(x):
        if x.dtype != np.float32:
            x = (x / 255.0).astype(np.float32)   

        if len(x.shape) != 3:
            x = np.expand_dims(x, axis=-1)

        x = tf.convert_to_tensor(x)
        return x

im = np.arange((720896))
im = np.reshape(im, (1408, 512))
 
seg_im = downsize(to_tensor(im), downsize_factor=(8,8))[np.newaxis,:,:,:]
seg_im.shape

In [None]:
tf.reshape(tf.reduce_max(seg_im), dtype=tf.bool)

In [None]:
is_p = tf.cast(tf.reshape(tf.reduce_max(seg_im), (1, 1)), dtype=tf.bool)
is_p

In [None]:
is_p = tf.reshape(tf.reduce_max(seg_im), (1, 1))
is_p

In [None]:
decision = [[[[0.]]]]
tf.equal(decision > 0.0, is_p)

In [None]:
decision = 1
tf.reduce_sum(tf.cast(tf.equal(decision > 0.0, is_p), tf.float32)).numpy()

In [None]:
def downsize(image: np.ndarray, downsize_factor: int = 8) -> np.ndarray:
    img_t = tf.convert_to_tensor(np.expand_dims(image, 0 if len(image.shape) == 3 else (0, -1)).astype(np.float32))
    print(img_t.shape)
    pad_size = min(downsize_factor, img_t.shape[1] - 1, img_t.shape[2] - 1)
    img_t = tf.pad(img_t, [[0, 0], [pad_size, pad_size], [pad_size, pad_size], [0, 0]], mode='REFLECT')
    print(img_t.shape)
    image_np = tf.nn.avg_pool(img_t, ksize=2 * pad_size + 1, strides=downsize_factor, padding='VALID')
    print(image_np[0,0].shape)
    return image_np[0] if len(image.shape) == 3 else image_np[0, 0]

im = np.zeros((720896))
im = np.reshape(im, (1408, 512))
 
to_tensor(downsize(im, downsize_factor=8)).shape

In [None]:
im = np.zeros((720896))
im = np.reshape(im, (1408, 512))

img_t = tf.convert_to_tensor(np.expand_dims(im, -1 if len(im.shape) == 3 else (0, -1)))

img_t.shape

In [None]:
# ここからテスト--------------------------
from data.input_ksdd import KSDDDataset
from data.dataset import Dataset
import pickle

kind = "TRAIN"
ds = KSDDDataset(configuration.DATASET_PATH, configuration, kind)
dtst = Dataset(configuration.DATASET_PATH, configuration, kind)

print(type(ds))

shuffle = kind == "TRAIN"
batch_size = configuration.BATCH_SIZE if kind == "TRAIN" else 1

dataset = ds._data.shuffle(buffer_size=ds.length)
dataset = dataset.batch(batch_size)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

train_num =33
num_segmented =33
fold = 0

fn = f"splits/KSDD/split_{train_num}_{num_segmented}.pyb"
with open(f"{fn}", "rb") as f:
    train_samples, test_samples = pickle.load(f)
    if kind == 'TRAIN':
        folders = train_samples[fold]
    elif kind == 'TEST':
        folders = test_samples[fold]
    else:
        raise Exception('Unknown')

pos_samples, neg_samples = [], []
for f, is_segmented in folders:
    for sample in sorted(os.listdir(os.path.join(configuration.DATASET_PATH, f))):
        if not sample.__contains__('label'):
            image_path = configuration.DATASET_PATH + '/' + f + '/' + sample
            seg_mask_path = f"{image_path[:-4]}_label.bmp"
            image = dtst.read_img_resize(image_path, configuration.INPUT_CHANNELS, (configuration.INPUT_WIDTH, configuration.INPUT_HEIGHT))
            seg_mask, positive = dtst.read_label_resize(seg_mask_path, (configuration.INPUT_WIDTH, configuration.INPUT_HEIGHT), dilate=configuration.DILATE)
            sample_name = f"{f}_{sample}"[:-4]
            if sample_name == 'kos21_Part7':
                continue
            if positive:
                image = dtst.to_tensor(image)
                seg_loss_mask = dtst.distance_transform(seg_mask, configuration.WEIGHTED_SEG_LOSS_MAX, configuration.WEIGHTED_SEG_LOSS_P)
                seg_loss_mask = dtst.to_tensor(dtst.downsize(seg_loss_mask))
                seg_mask = dtst.to_tensor(dtst.downsize(seg_mask))
                pos_samples.append((image, seg_mask, seg_loss_mask, is_segmented, image_path, seg_mask_path, sample_name))
            else:
                image = dtst.to_tensor(image)
                seg_loss_mask = dtst.to_tensor(dtst.downsize(np.ones_like(seg_mask)))
                seg_mask = dtst.to_tensor(dtst.downsize(seg_mask))
                neg_samples.append((image, seg_mask, seg_loss_mask, True, image_path, seg_mask_path, sample_name))

# 長さ
print(len(pos_samples)) # 34
print(len(neg_samples)) # 230
_data = pos_samples + neg_samples
print(len(_data)) # 264

# len(_data)の長さ取れてる！！！！！

# ここまでテスト--------------------------

In [None]:
ds._data[0]

In [None]:
# データの準備
data_1 = tf.constant([1.0, 2.0, 3.0])  # テンソルに変換可能なデータ
data_2 = tf.constant([0.1, 0.2, 0.3])  # テンソルに変換可能なデータ
data_3 = [True, False, True]  # ブール型のリスト
data_4 = ['a', 'b', 'c']  # 文字列のリスト

# データセットの作成
dataset_1 = tf.data.Dataset.from_tensor_slices(data_1)
dataset_2 = tf.data.Dataset.from_tensor_slices(data_2)
dataset_3 = tf.data.Dataset.from_tensor_slices(data_3)
dataset_4 = tf.data.Dataset.from_tensor_slices(data_4)

# データセットを結合
dataset = tf.data.Dataset.zip((dataset_1, dataset_2, dataset_3, dataset_4))

# 結果を確認
for data in dataset:
    print(data)


In [None]:
print(dataset_1)
print(dataset_1+dataset_2)

In [None]:
dataset = dataset.shuffle(buffer_size=100)
dataset = dataset.batch(33)
dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

In [None]:
for data in dataset:

    # 4つ
    print(data)

In [None]:
tf.data.Dataset.prefetch 

# 以下はpyファイル

## evaluation.py

In [None]:
import os
import numpy as np
import pandas as pd
import utils
from sklearn.metrics import precision_recall_curve, roc_curve, auc, average_precision_score
import shutil
import pickle


def get_performance_eval(P, Y, names, data_dir, output_dir, folds, prefix='', thresholds_tups=None, save=True):
    metrics = {}
    precision_, recall_, thresholds = precision_recall_curve(Y.astype(np.int32), P)
    metrics['precision'] = precision_
    metrics['recall'] = recall_
    metrics['thresholds'] = thresholds

    FPR, TPR, _ = roc_curve(Y.astype(np.int32), P)
    AUC = auc(FPR, TPR)
    AP = average_precision_score(Y.astype(np.int32), P)
    metrics['FPR'] = FPR
    metrics['TPR'] = TPR
    metrics['AUC'] = AUC

    f_measures = 2 * (precision_ * recall_) / (precision_ + recall_ + 0.0000000001)
    metrics['f_measures'] = f_measures

    thresholds_metrics = {}
    ix_best = np.argmax(f_measures)
    if ix_best > 0:
        best_threshold = (thresholds[ix_best] + thresholds[ix_best - 1]) / 2
    else:
        best_threshold = thresholds[ix_best]
    fn0_threshold = thresholds[np.where(recall_ >= 1)][0]

    for thresh, name, dir in zip([best_threshold, fn0_threshold, 0.5], ['best', 'fn0', '50_perc'], ['best', 'fn0', '50_perc']):
        FN, FP, TN, TP = get_and_copy_falses(P, Y, thresh, data_dir, folds, names, os.path.join(output_dir, dir), prefix, save=save)
        F_measure = (2 * TP.sum()) / float(2 * TP.sum() + FP.sum() + FN.sum())
        thresholds_metrics[name] = {}
        thresholds_metrics[name]['value'] = thresh
        thresholds_metrics[name]['TP'] = TP.sum()
        thresholds_metrics[name]['TN'] = TN.sum()
        thresholds_metrics[name]['FP'] = FP.sum()
        thresholds_metrics[name]['FN'] = FN.sum()
        thresholds_metrics[name]['F_measure'] = F_measure

    metrics['AP'] = AP
    metrics['thresholds'] = thresholds_metrics

    if save:
        for thr, d in metrics['thresholds'].items():
            print(f'THRESHOLD {prefix} {thr:>15} => VALUE={d["value"]:.4f}, FP={d["FP"]} FN={d["FN"]}, AP={AP}')

    return metrics


def get_and_copy_falses(P, Y, best_threshold, data_dir, folds, names, output_dir, prefix, save=True):
    FP, FN, TN, TP = utils.calc_confusion_mat(P >= best_threshold, Y)
    # find FN and FP examples and copy them to folders
    if save:
        if not os.path.exists(output_dir):
            utils.create_folder(output_dir)
        FN_names = [(folds[i], names[i] + ".jpg") for i in range(0, len(names)) if FN[i]]
        FP_names = [(folds[i], names[i] + ".jpg") for i in range(0, len(names)) if FP[i]]
        copy_falses(FN_names, data_dir, output_dir, prefix)
        copy_falses(FP_names, data_dir, output_dir, prefix, is_FN=False)
    return FN, FP, TN, TP


def copy_falses(names, data_dir, output_dir, prefix, is_FN=True):
    for fold, n in names:
        outputs_folder = os.path.join(data_dir, f'FOLD_{fold}', prefix + 'outputs')
        f_name = list(filter(lambda s: s.endswith(n), os.listdir(outputs_folder)))
        if len(f_name) > 0:
            f_name = f_name[0]
            acc = f_name[:5]
            src_file = os.path.join(outputs_folder, f_name)
            dst_file = os.path.join(output_dir, f'{"FN" if is_FN else "FP"}_{acc}_{f_name[6:]}')

            try:
                shutil.copy(src_file, dst_file)
            except:
                print(f"error: cannot copy file {n}")


def evaluate_decision(run_dir, folds, ground_truth, img_names, predictions, prefix='', output_dir=None, thresholds=None, save=True):
    if output_dir is None:
        output_dir = run_dir

    metrics = get_performance_eval(predictions, ground_truth, img_names, run_dir, output_dir, folds, prefix=prefix, thresholds_tups=thresholds, save=save)

    best_tr_metrics = metrics['thresholds']['best']
    tp_sum = best_tr_metrics['TP']
    fp_sum = best_tr_metrics['FP']
    fn_sum = best_tr_metrics['FN']
    tn_sum = best_tr_metrics['TN']
    AP = metrics['AP']
    fp_0fn_sum = metrics['thresholds']['fn0']['FP']

    if save:
        print(f"AP: {AP:.03f}, FP/FN: {fp_sum:d}/{fn_sum:d}, FP@FN=0: {fp_0fn_sum:d}")

        with open(os.path.join(output_dir, prefix + 'accuracy.txt'), 'w') as f:
            f.write(f"TP= {tp_sum}\tFP={fp_sum}\n")
            f.write(f"FN= {fn_sum}\tTN={tn_sum}")

        with open(os.path.join(output_dir, f'{prefix}metrics.pkl'), 'wb') as f:
            pickle.dump(metrics, f)
            f.close()

    return metrics


def evaluate_fold(results_folder, t_folds, t_gt, t_img_names, t_preds):
    m_test = evaluate_decision(results_folder, t_folds, t_gt, t_img_names, t_preds, prefix='', output_dir=os.path.join(results_folder), save=False)

    thresholds = m_test['thresholds']
    best = thresholds['best']
    t50 = thresholds['50_perc']
    fn0 = thresholds['fn0']

    cls_acc = (best["TP"] + best["TN"]) / (best["TP"] + best["TN"] + best["FP"] + best["FN"])
    cls_acc_50 = (t50["TP"] + t50["TN"]) / (t50["TP"] + t50["TN"] + t50["FP"] + t50["FN"])

    tpr = best["TP"] / (best["TP"] + best["FN"])
    tnr = best["TN"] / (best["TN"] + best["FP"])

    eval_res = {"ap": (m_test['AP']),
                "auc": (m_test['AUC']),
                "fps": (best['FP']),
                "fns": (best['FN']),
                "best_t": (best['value']),
                "t50_fps": (t50['FP']),
                "t50_fns": (t50['FN']),
                "fn0s": (fn0['FP']),
                "fn0_t": (fn0['value']),
                "f_measure": (best["F_measure"]),
                "cls_acc": cls_acc,
                "f_measure_50": t50["F_measure"],
                "cls_acc_50": cls_acc_50,
                "tpr": tpr,
                "tnr": tnr}
    return eval_res


def read_predictions(fold, prefix, run_dir):
    predictions, decisions, ground_truth, img_names, folds = [], [], [], [], []
    if fold is not None:
        fold_path = os.path.join(run_dir, 'FOLD_{}'.format(fold), prefix + 'results.csv')
        decisions, folds, ground_truth, img_names, predictions = read_directory(decisions, fold, fold_path, folds, ground_truth, img_names, predictions)
    else:
        results_path = os.path.join(run_dir, prefix + 'results.csv')
        decisions, folds, ground_truth, img_names, predictions = read_directory(decisions, 0, results_path, folds, ground_truth, img_names, predictions)
    img_names = list(map(str, img_names))
    predictions, decisions, ground_truth, img_names, folds = list(map(np.array, [predictions, decisions, ground_truth, img_names, folds]))

    valid_idx = (img_names != 'kos21_Part7')
    predictions = predictions[valid_idx]
    decisions = decisions[valid_idx]
    ground_truth = ground_truth[valid_idx]
    img_names = img_names[valid_idx]
    folds = folds[valid_idx]

    return decisions, folds, ground_truth, img_names, predictions


def read_directory(decisions, f, fold_path, folds, ground_truth, img_names, predictions):
    csv = pd.read_csv(fold_path)
    n_samples_in_fold = len(list(csv['prediction']))
    predictions = predictions + list(csv['prediction'])
    decisions = decisions + list(csv['decision'])
    ground_truth = ground_truth + list(csv['ground_truth'])
    img_names = img_names + list(csv['img_name'])
    folds = folds + ([f] * n_samples_in_fold)
    return decisions, folds, ground_truth, img_names, predictions

## config.py

In [None]:
class Config:
    GPU = None

    RUN_NAME = None

    DATASET = None  # KSDD, DAGM, STEEL, KSDD2
    DATASET_PATH = None

    EPOCHS = None

    LEARNING_RATE = None
    DELTA_CLS_LOSS = None

    BATCH_SIZE = None

    WEIGHTED_SEG_LOSS = None
    WEIGHTED_SEG_LOSS_P = None
    WEIGHTED_SEG_LOSS_MAX = None
    DYN_BALANCED_LOSS = None
    GRADIENT_ADJUSTMENT = None
    FREQUENCY_SAMPLING = True

    # Default values
    FOLD = None
    TRAIN_NUM = None
    NUM_SEGMENTED = None
    RESULTS_PATH = "./RESULTS" # TODO use when releasing
    # RESULTS_PATH = "/home/jakob/outputs/WEAKLY_LABELED/PC_DEBUG" if "CONTAINER_NODE" in os.environ else "/opt/workspace/host_storage_hdd/REWRITE/v2"
    SPLITS_PATH = None

    VALIDATE = True
    VALIDATE_ON_TEST = True
    VALIDATION_N_EPOCHS = 5
    USE_BEST_MODEL = False

    ON_DEMAND_READ = False
    REPRODUCIBLE_RUN = False
    MEMORY_FIT = 1
    SAVE_IMAGES = True
    DILATE = 1

    # Auto filled
    INPUT_WIDTH = None
    INPUT_HEIGHT = None
    INPUT_CHANNELS = None

    def init_extra(self):
        if self.WEIGHTED_SEG_LOSS and (self.WEIGHTED_SEG_LOSS_P is None or self.WEIGHTED_SEG_LOSS_MAX is None):
            raise Exception("You also need to specify p and scaling factor for weighted segmentation loss!")
        if self.NUM_SEGMENTED is None:
            raise Exception("Missing NUM_SEGMENTED!")
        if self.DATASET == 'KSDD':
            self.INPUT_WIDTH = 512
            self.INPUT_HEIGHT = 1408
            self.INPUT_CHANNELS = 1

            if self.TRAIN_NUM is None:
                raise Exception("Missing TRAIN_NUM for KSDD dataset!")
            if self.NUM_SEGMENTED is None:
                raise Exception("Missing NUM_SEGMENTED for KSDD dataset!")
            if self.FOLD is None:
                raise Exception("Missing FOLD for KSDD dataset!")

        elif self.DATASET == 'DAGM':
            self.INPUT_WIDTH = 512
            self.INPUT_HEIGHT = 512
            self.INPUT_CHANNELS = 1
            if self.NUM_SEGMENTED is None:
                raise Exception("Missing NUM_SEGMENTED for DAGM dataset!")
            if self.FOLD is None:
                raise Exception("Missing FOLD for DAGM dataset!")
        elif self.DATASET == 'STEEL':
            self.INPUT_WIDTH = 1600
            self.INPUT_HEIGHT = 256
            self.INPUT_CHANNELS = 1

            self.VALIDATE_ON_TEST = False
            self.USE_BEST_MODEL = True
            print("Will use best model according to validation loss, validation is not performed on test set!")
            if not self.ON_DEMAND_READ:
                print("Will use ON_DEMAND_READ even though it is set on False!")
                self.ON_DEMAND_READ = True
            if self.TRAIN_NUM is None:
                raise Exception("Missing TRAIN_NUM for STEEL dataset!")
            if self.NUM_SEGMENTED is None:
                raise Exception("Missing NUM_SEGMENTED for STEEL dataset!")
        elif self.DATASET == 'KSDD2':
            self.INPUT_WIDTH = 232
            self.INPUT_HEIGHT = 640
            self.INPUT_CHANNELS = 3
            if self.NUM_SEGMENTED is None:
                raise Exception("Missing NUM_SEGMENTED for KSDD2 dataset!")
        else:
            raise Exception('Unknown dataset {}'.format(self.DATASET))

    def merge_from_args(self, args):
        self.GPU = args.GPU
        self.RUN_NAME = args.RUN_NAME
        self.DATASET = args.DATASET
        self.DATASET_PATH = args.DATASET_PATH
        self.EPOCHS = args.EPOCHS
        self.LEARNING_RATE = args.LEARNING_RATE
        self.DELTA_CLS_LOSS = args.DELTA_CLS_LOSS
        self.BATCH_SIZE = args.BATCH_SIZE
        self.WEIGHTED_SEG_LOSS = args.WEIGHTED_SEG_LOSS
        self.WEIGHTED_SEG_LOSS_P = args.WEIGHTED_SEG_LOSS_P
        self.WEIGHTED_SEG_LOSS_MAX = args.WEIGHTED_SEG_LOSS_MAX
        self.DYN_BALANCED_LOSS = args.DYN_BALANCED_LOSS
        self.GRADIENT_ADJUSTMENT = args.GRADIENT_ADJUSTMENT
        self.FREQUENCY_SAMPLING = args.FREQUENCY_SAMPLING
        self.NUM_SEGMENTED = args.NUM_SEGMENTED

        if args.FOLD is not None: self.FOLD = args.FOLD
        if args.TRAIN_NUM is not None: self.TRAIN_NUM = args.TRAIN_NUM
        if args.RESULTS_PATH is not None: self.RESULTS_PATH = args.RESULTS_PATH
        if args.VALIDATE is not None: self.VALIDATE = args.VALIDATE
        if args.VALIDATE_ON_TEST is not None: self.VALIDATE_ON_TEST = args.VALIDATE_ON_TEST
        if args.VALIDATION_N_EPOCHS is not None: self.VALIDATION_N_EPOCHS = args.VALIDATION_N_EPOCHS
        if args.USE_BEST_MODEL is not None: self.USE_BEST_MODEL = args.USE_BEST_MODEL
        if args.ON_DEMAND_READ is not None: self.ON_DEMAND_READ = args.ON_DEMAND_READ
        if args.REPRODUCIBLE_RUN is not None: self.REPRODUCIBLE_RUN = args.REPRODUCIBLE_RUN
        if args.MEMORY_FIT is not None: self.MEMORY_FIT = args.MEMORY_FIT
        if args.SAVE_IMAGES is not None: self.SAVE_IMAGES = args.SAVE_IMAGES
        if args.DILATE is not None: self.DILATE = args.DILATE

    def get_as_dict(self):
        params = {
            "GPU": self.GPU,
            "DATASET": self.DATASET,
            "DATASET_PATH": self.DATASET_PATH,
            "EPOCHS": self.EPOCHS,
            "LEARNING_RATE": self.LEARNING_RATE,
            "DELTA_CLS_LOSS": self.DELTA_CLS_LOSS,
            "BATCH_SIZE": self.BATCH_SIZE,
            "WEIGHTED_SEG_LOSS": self.WEIGHTED_SEG_LOSS,
            "WEIGHTED_SEG_LOSS_P": self.WEIGHTED_SEG_LOSS_P,
            "WEIGHTED_SEG_LOSS_MAX": self.WEIGHTED_SEG_LOSS_MAX,
            "DYN_BALANCED_LOSS": self.DYN_BALANCED_LOSS,
            "GRADIENT_ADJUSTMENT": self.GRADIENT_ADJUSTMENT,
            "FREQUENCY_SAMPLING": self.FREQUENCY_SAMPLING,
            "FOLD": self.FOLD,
            "TRAIN_NUM": self.TRAIN_NUM,
            "NUM_SEGMENTED": self.NUM_SEGMENTED,
            "RESULTS_PATH": self.RESULTS_PATH,
            "VALIDATE": self.VALIDATE,
            "VALIDATE_ON_TEST": self.VALIDATE_ON_TEST,
            "VALIDATION_N_EPOCHS": self.VALIDATION_N_EPOCHS,
            "USE_BEST_MODEL": self.USE_BEST_MODEL,
            "ON_DEMAND_READ": self.ON_DEMAND_READ,
            "REPRODUCIBLE_RUN": self.REPRODUCIBLE_RUN,
            "MEMORY_FIT": self.MEMORY_FIT,
            "INPUT_WIDTH": self.INPUT_WIDTH,
            "INPUT_HEIGHT": self.INPUT_HEIGHT,
            "INPUT_CHANNELS": self.INPUT_CHANNELS,
            "SAVE_IMAGES": self.SAVE_IMAGES,
            "DILATE": self.DILATE,
        }
        return params


def load_from_dict(dictionary):
    cfg = Config()

    cfg.GPU = dictionary.get("GPU", None)
    cfg.DATASET = dictionary.get("DATASET", None)
    cfg.DATASET_PATH = dictionary.get("DATASET_PATH", None)
    cfg.EPOCHS = dictionary.get("EPOCHS", None)
    cfg.LEARNING_RATE = dictionary.get("LEARNING_RATE", None)
    cfg.DELTA_CLS_LOSS = dictionary.get("DELTA_CLS_LOSS", None)
    cfg.BATCH_SIZE = dictionary.get("BATCH_SIZE", None)
    cfg.WEIGHTED_SEG_LOSS = dictionary.get("WEIGHTED_SEG_LOSS", None)
    cfg.WEIGHTED_SEG_LOSS_P = dictionary.get("WEIGHTED_SEG_LOSS_P", None)
    cfg.WEIGHTED_SEG_LOSS_MAX = dictionary.get("WEIGHTED_SEG_LOSS_MAX", None)
    cfg.DYN_BALANCED_LOSS = dictionary.get("DYN_BALANCED_LOSS", None)
    cfg.GRADIENT_ADJUSTMENT = dictionary.get("GRADIENT_ADJUSTMENT", None)
    cfg.FREQUENCY_SAMPLING = dictionary.get("FREQUENCY_SAMPLING", None)
    cfg.FOLD = dictionary.get("FOLD", None)
    cfg.TRAIN_NUM = dictionary.get("TRAIN_NUM", None)
    cfg.NUM_SEGMENTED = dictionary.get("NUM_SEGMENTED", None)
    cfg.RESULTS_PATH = dictionary.get("RESULTS_PATH", None)
    cfg.VALIDATE = dictionary.get("VALIDATE", None)
    cfg.VALIDATE_ON_TEST = dictionary.get("VALIDATE_ON_TEST", None)
    cfg.VALIDATION_N_EPOCHS = dictionary.get("VALIDATION_N_EPOCHS", None)
    cfg.USE_BEST_MODEL = dictionary.get("USE_BEST_MODEL", None)
    cfg.ON_DEMAND_READ = dictionary.get("ON_DEMAND_READ", None)
    cfg.REPRODUCIBLE_RUN = dictionary.get("REPRODUCIBLE_RUN", None)
    cfg.MEMORY_FIT = dictionary.get("MEMORY_FIT", None)
    cfg.INPUT_WIDTH = dictionary.get("INPUT_WIDTH", None)
    cfg.INPUT_HEIGHT = dictionary.get("INPUT_HEIGHT", None)
    cfg.INPUT_CHANNELS = dictionary.get("INPUT_CHANNELS", None)
    cfg.SAVE_IMAGES = dictionary.get("SAVE_IMAGES", None)
    cfg.DILATE = dictionary.get("DILATE", None)

    return cfg

## utils.py

In [None]:
import matplotlib

matplotlib.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import precision_recall_curve, roc_curve, auc
import pandas as pd
import os
import errno
import pickle
import cv2


def create_folder(folder, exist_ok=True):
    try:
        os.makedirs(folder)
    except OSError as e:
        if e.errno != errno.EEXIST or not exist_ok:
            raise


def calc_confusion_mat(D, Y):
    FP = (D != Y) & (Y.astype(np.bool) == False)
    FN = (D != Y) & (Y.astype(np.bool) == True)
    TN = (D == Y) & (Y.astype(np.bool) == False)
    TP = (D == Y) & (Y.astype(np.bool) == True)

    return FP, FN, TN, TP


def plot_sample(image_name, image, segmentation, label, save_dir, decision=None, blur=True, plot_seg=False):
    plt.figure()
    plt.clf()
    plt.subplot(1, 4, 1)
    plt.xticks([])
    plt.yticks([])
    plt.title('Input image')
    if image.shape[0] < image.shape[1]:
        image = np.transpose(image, axes=[1, 0, 2])
        segmentation = np.transpose(segmentation)
        label = np.transpose(label)
    if image.shape[2] == 1:
        plt.imshow(image, cmap="gray")
    else:
        plt.imshow(image)

    plt.subplot(1, 4, 2)
    plt.xticks([])
    plt.yticks([])
    plt.title('Groundtruth')
    plt.imshow(label, cmap="gray")

    plt.subplot(1, 4, 3)
    plt.xticks([])
    plt.yticks([])
    if decision is None:
        plt.title('Output')
    else:
        plt.title(f"Output: {decision:.5f}")
    # display max
    vmax_value = max(1, np.max(segmentation))
    plt.imshow(segmentation, cmap="jet", vmax=vmax_value)

    plt.subplot(1, 4, 4)
    plt.xticks([])
    plt.yticks([])
    plt.title('Output scaled')
    if blur:
        normed = segmentation / segmentation.max()
        blured = cv2.blur(normed, (32, 32))
        plt.imshow((blured / blured.max() * 255).astype(np.uint8), cmap="jet")
    else:
        plt.imshow((segmentation / segmentation.max() * 255).astype(np.uint8), cmap="jet")

    out_prefix = '{:.3f}_'.format(decision) if decision is not None else ''

    plt.savefig(f"{save_dir}/{out_prefix}result_{image_name}.jpg", bbox_inches='tight', dpi=300)
    plt.close()

    if plot_seg:
        jet_seg = cv2.applyColorMap((segmentation * 255).astype(np.uint8), cv2.COLORMAP_JET)
        cv2.imwrite(f"{save_dir}/{out_prefix}_segmentation_{image_name}.png", jet_seg)


def evaluate_metrics(samples, results_path, run_name):
    samples = np.array(samples)

    img_names = samples[:, 4]
    predictions = samples[:, 0]
    labels = samples[:, 3].astype(np.float32)

    metrics = get_metrics(labels, predictions)

    df = pd.DataFrame(
        data={'prediction': predictions,
              'decision': metrics['decisions'],
              'ground_truth': labels,
              'img_name': img_names})
    df.to_csv(os.path.join(results_path, 'results.csv'), index=False)

    print(
        f'{run_name} EVAL AUC={metrics["AUC"]:f}, and AP={metrics["AP"]:f}, w/ best thr={metrics["best_thr"]:f} at f-m={metrics["best_f_measure"]:.3f} and FP={sum(metrics["FP"]):d}, FN={sum(metrics["FN"]):d}')

    with open(os.path.join(results_path, 'metrics.pkl'), 'wb') as f:
        pickle.dump(metrics, f)
        f.close()

    plt.figure(1)
    plt.clf()
    plt.plot(metrics['recall'], metrics['precision'])
    plt.title('Average Precision=%.4f' % metrics['AP'])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.savefig(f"{results_path}/precision-recall.pdf", bbox_inches='tight')

    plt.figure(1)
    plt.clf()
    plt.plot(metrics['FPR'], metrics['TPR'])
    plt.title('AUC=%.4f' % metrics['AUC'])
    plt.xlabel('False positive rate')
    plt.ylabel('True positive rate')
    plt.savefig(f"{results_path}/ROC.pdf", bbox_inches='tight')


def get_metrics(labels, predictions):
    metrics = {}
    precision, recall, thresholds = precision_recall_curve(labels, predictions)
    metrics['precision'] = precision
    metrics['recall'] = recall
    metrics['thresholds'] = thresholds
    f_measures = 2 * np.multiply(recall, precision) / (recall + precision + 1e-8)
    metrics['f_measures'] = f_measures
    ix_best = np.argmax(f_measures)
    metrics['ix_best'] = ix_best
    best_f_measure = f_measures[ix_best]
    metrics['best_f_measure'] = best_f_measure
    best_thr = thresholds[ix_best]
    metrics['best_thr'] = best_thr
    FPR, TPR, _ = roc_curve(labels, predictions)
    metrics['FPR'] = FPR
    metrics['TPR'] = TPR
    AUC = auc(FPR, TPR)
    metrics['AUC'] = AUC
    AP = auc(recall, precision)
    metrics['AP'] = AP
    decisions = predictions >= best_thr
    metrics['decisions'] = decisions
    FP, FN, TN, TP = calc_confusion_mat(decisions, labels)
    metrics['FP'] = FP
    metrics['FN'] = FN
    metrics['TN'] = TN
    metrics['TP'] = TP
    metrics['accuracy'] = (sum(TP) + sum(TN)) / (sum(TP) + sum(TN) + sum(FP) + sum(FN))
    return metrics

## read_results.py

In [None]:
import os
import evaluation
import pandas as pd
from operator import itemgetter
from config import load_from_dict

def get_params(cfg):
    params_lines = []
    for k, v in sorted(cfg.get_as_dict().items(), key=lambda x: x[0]):
        params_lines.append(f'{k}:{v}')
    return ','.join(params_lines)


def get_run_config(run_path, fold):
    params = {}
    if fold is not None:
        params_file = os.path.join(run_path, f'FOLD_{fold}', 'run_params.txt')
    else:
        params_file = os.path.join(run_path, 'run_params.txt')
    with open(params_file, 'r') as f:
        for l in f.readlines():
            k, v = l.split(':')
            params[k] = v.strip()
    return load_from_dict(params)


def read_results(results_path, dataset, folds=None, dagm_join=False, sortkey=itemgetter(0)):
    results = []
    results_columns = ['RUN_NAME',
                       'TN', "N_SEG",
                       'W_SEG_LOSS', 'W_P', 'W_MAX',
                       'FRQ_SMP', 'DYN_B_L', 'DELTA',
                       'EPS', 'LR',
                       'AUC', 'AP',
                       'FP', 'FN', 'FALSES', 'THRESH',
                       "F_MSR", "CLS_ACC", "TPR", "TNR",
                       '50_FP', '50_FN', '50_FALSES', '50_FMS', '50_CA',
                       'FP@FN=0', 'THRESH@FN=0',
                       'PATH', 'CONFIGURATION'
                       ]
    if dataset == "DAGM" and not dagm_join:
        for f in folds:
            process_dataset(results_path, dataset, [f], results, dagm_join)
    else:
        process_dataset(results_path, dataset, folds, results, dagm_join)

    results = sorted(results, key=sortkey)
    df = pd.DataFrame(results, columns=results_columns)
    return df


def process_dataset(results_path, dataset, folds, results, dagm_join):
    for run_name in os.listdir(os.path.join(results_path, dataset)):
        run_path = os.path.join(results_path, dataset, run_name)
        try:
            print(f"Processing run_path: {run_path}")
            cfg = get_run_config(run_path, None if folds is None else folds[0])
            ap, auc, fps, fns, t50_fps, t50_fns, fn0s, f_measure, cls_acc, f_measure_50, cls_acc_50, tpr, tnr = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
            best_t, fn0_t = -1, -1
            for f in folds:
                t_dec, t_folds, t_gt, t_img_names, t_preds = evaluation.read_predictions(f, '', run_path)
                fold_eval_res = evaluation.evaluate_fold(run_path, t_folds, t_gt, t_img_names, t_preds)
                if len(folds) == 1:
                    best_t = fold_eval_res["best_t"]
                    fn0_t = fold_eval_res["fn0_t"]
                ap += fold_eval_res["ap"]
                auc += fold_eval_res["auc"]
                fps += fold_eval_res["fps"]
                fns += fold_eval_res["fns"]
                t50_fps += fold_eval_res["t50_fps"]
                t50_fns += fold_eval_res["t50_fns"]
                fn0s += fold_eval_res["fn0s"]
                f_measure += fold_eval_res["f_measure"]
                cls_acc += fold_eval_res["cls_acc"]
                f_measure_50 += fold_eval_res["f_measure_50"]
                cls_acc_50 += fold_eval_res["cls_acc_50"]
                tpr += fold_eval_res["tpr"]
                tnr += fold_eval_res["tnr"]
            ap /= len(folds)
            auc /= len(folds)
            f_measure /= len(folds)
            cls_acc /= len(folds)
            f_measure_50 /= len(folds)
            cls_acc_50 /= len(folds)
            tpr /= len(folds)
            tnr /= len(folds)

            if dataset == "DAGM" and not dagm_join:
                run_name = f"{run_name}_FOLD_{folds[0]}"

            results.append(
                [run_name,
                 cfg.TRAIN_NUM, cfg.NUM_SEGMENTED,
                 cfg.WEIGHTED_SEG_LOSS, cfg.WEIGHTED_SEG_LOSS_P, cfg.WEIGHTED_SEG_LOSS_MAX,
                 cfg.FREQUENCY_SAMPLING, cfg.DYN_BALANCED_LOSS, cfg.DELTA_CLS_LOSS,
                 cfg.EPOCHS, cfg.LEARNING_RATE,
                 f"{auc:.5f}", f"{ap:.5f}",
                 fps, fns, fps + fns, f"{best_t:.5f}",
                 f"{f_measure:.5f}", f"{cls_acc:.5f}", f"{tpr:.5f}", f"{tnr:.5f}",
                 t50_fps, t50_fns, t50_fps + t50_fns, f"{f_measure_50:.5f}", f"{cls_acc_50:.5f}",
                 fn0s, f"{fn0_t:.5f}",
                 run_path, get_params(cfg)]
            )

        except Exception as f:
            print(f'Error reading RUN {run_path} with Exception {f} ')


def main():
    # dataset,results_folder = "STEEL", '/home/jakob/outputs/WEAKLY_LABELED/STEEL/GRADIENT'
    # dataset, results_folder = "KSDD2", '/home/jakob/outputs/WEAKLY_LABELED/KSDD2/GRADIENT'
    # dataset, results_folder = "DAGM", '/home/jakob/outputs/WEAKLY_LABELED/DAGM/GS'
    dataset, results_folder = "KSDD", '/home/jakob/outputs/WEAKLY_LABELED/RELEASE/'

    dagm_join = False # If True will join(average) results for all classes

    folds_dict = {
        'KSDD': [0, 1, 2],
        'KSDD2': [None],
        'STEEL': [None],
        'DAGM': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
    }
    results = read_results(results_folder, dataset, folds_dict[dataset], dagm_join, sortkey=itemgetter(0))
    results.to_csv(os.path.join(results_folder, f'{dataset}_summary{f"_joined" if dataset == "DAGM" and dagm_join else ""}.csv'), index=False)


if __name__ == '__main__':
    main()

## demo_eval_single_image

In [None]:
from models import SegDecNet
import cv2
import numpy as np
import tensorflow as tf

INPUT_WIDTH = 512  # must be the same as it was during training
INPUT_HEIGHT = 1408  # must be the same as it was during training
INPUT_CHANNELS = 1  # must be the same as it was during training

model = SegDecNet(INPUT_WIDTH, INPUT_HEIGHT, INPUT_CHANNELS)

model_path = "path_to_your_model"
model.load_weights(model_path)

# %%
img_path = "path_to_the_test_image"
img = cv2.imread(img_path) if INPUT_CHANNELS == 3 else cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (INPUT_WIDTH, INPUT_HEIGHT))
img = np.transpose(img, (2, 0, 1)) if INPUT_CHANNELS == 3 else img[np.newaxis]
img_t = tf.convert_to_tensor(img)[np.newaxis, ...] / 255.0  # must be [BATCH_SIZE x CHANNELS x HEIGHT x WIDTH]

dec_out, seg_out = model(img_t)
img_score = tf.keras.activations.sigmoid(dec_out)
print(img_score)

## dataset.py


In [None]:
import cv2
import numpy as np
import tensorflow as tf
from scipy.ndimage.morphology import distance_transform_edt
from scipy.signal import convolve2d
from config import Config


class Dataset(tf.data.Dataset):
    def _generator(self):
        for index in range(self.len):
            if self.counter >= self.len:
                self.counter = 0
                if self.frequency_sampling:
                    sample_probability = 1 - (self.neg_retrieval_freq / np.max(self.neg_retrieval_freq))
                    sample_probability = sample_probability - np.median(sample_probability) + 1
                    sample_probability = sample_probability ** (np.log(len(sample_probability)) * 4)
                    sample_probability = sample_probability / np.sum(sample_probability)

                    # use replace=False for to get only unique values
                    self.neg_imgs_permutation = np.random.choice(range(self.num_neg),
                                                                 size=self.num_negatives_per_one_positive * self.num_pos,
                                                                 p=sample_probability,
                                                                 replace=False)
                else:
                    self.neg_imgs_permutation = np.random.permutation(self.num_neg)

            if self.kind == 'TRAIN':
                if index >= self.num_pos:
                    ix = index % self.num_pos
                    ix = self.neg_imgs_permutation[ix]
                    item = self.neg_samples[ix]
                    self.neg_retrieval_freq[ix] = self.neg_retrieval_freq[ix] + 1

                else:
                    ix = index
                    item = self.pos_samples[ix]
            else:
                if index < self.num_neg:
                    ix = index
                    item = self.neg_samples[ix]
                else:
                    ix = index - self.num_neg
                    item = self.pos_samples[ix]

            image, seg_mask, seg_loss_mask, is_segmented, image_path, seg_mask_path, sample_name = item

            if self.cfg.ON_DEMAND_READ:  # STEEL only so far
                if image_path == -1 or seg_mask_path == -1:
                    raise Exception('For ON_DEMAND_READ image and seg_mask paths must be set in read_contents')
                img = self.read_img_resize(image_path, self.grayscale, self.image_size)
                if seg_mask_path is None:  # good sample
                    seg_mask = np.zeros_like(img)
                elif isinstance(seg_mask_path, list):
                    seg_mask = self.rle_to_mask(seg_mask_path, self.image_size)
                else:
                    seg_mask, _ = self.self.read_label_resize(seg_mask_path, self.image_size)

                if np.max(seg_mask) == np.min(seg_mask):  # good sample
                    seg_loss_mask = np.ones_like(seg_mask)
                else:
                    seg_loss_mask = self.distance_transform(seg_mask, self.cfg.WEIGHTED_SEG_LOSS_MAX, self.cfg.WEIGHTED_SEG_LOSS_P)

                image = self.to_tensor(img)
                seg_mask = self.to_tensor(self.downsize(seg_mask))
                seg_loss_mask = self.to_tensor(self.downsize(seg_loss_mask))

            self.counter = self.counter + 1

            yield image, seg_mask, seg_loss_mask, is_segmented,sample_name

           　

    def __new__(cls, path: str, cfg: Config, kind: str):
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature=(
                tf.TensorSpec(shape=(None, None, None), dtype=tf.float32),
                tf.TensorSpec(shape=(None, None), dtype=tf.float32),
                tf.TensorSpec(shape=(None, None), dtype=tf.float32),
                tf.TensorSpec(shape=(), dtype=tf.bool),
                tf.TensorSpec(shape=(), dtype=tf.string)
            )
        )

    def __init__(self, path: str, cfg: Config, kind: str):
        super(Dataset, self).__init__()
        self.path: str = path
        self.cfg: Config = cfg
        self.kind: str = kind
        self.image_size: (int, int) = (self.cfg.INPUT_WIDTH, self.cfg.INPUT_HEIGHT)
        self.grayscale: bool = self.cfg.INPUT_CHANNELS == 1

        self.num_negatives_per_one_positive: int = 1
        self.frequency_sampling: bool = self.cfg.FREQUENCY_SAMPLING and self.kind == 'TRAIN'

   
   1def __len__(self):
        return self.len

    def read_contents(self):
        pass

    def read_img_resize(self, path, grayscale, resize_dim) -> np.ndarray:
        img = cv2.imread(path, cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR)
        if resize_dim is not None:
            img = cv2.resize(img, dsize=resize_dim)
        return np.array(img, dtype=np.float32) / 255.0

    def read_label_resize(self, path, resize_dim, dilate=None) -> (np.ndarray, bool):
        lbl = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if dilate is not None and dilate > 1:
            lbl = cv2.dilate(lbl, np.ones((dilate, dilate)))
        if resize_dim is not None:
            lbl = cv2.resize(lbl, dsize=resize_dim)
        return np.array((lbl / 255.0), dtype=np.float32), np.max(lbl) > 0

    # Tensorflowへ書き換え
    def to_tensor(self, x):
        if x.dtype != np.float32:
            x = (x / 255.0).astype(np.float32)

        if len(x.shape) == 3:
            x = np.transpose(x, axes=(2, 0, 1))
        else:
            x = np.expand_dims(x, axis=0)

        x = tf.convert_to_tensor(x)
        return x


    def distance_transform(self, mask: np.ndarray, max_val: float, p: float) -> np.ndarray:
        h, w = mask.shape[:2]
        dst_trf = np.zeros((h, w))
        
        num_labels, labels = cv2.connectedComponents((mask * 255.0).astype(np.uint8), connectivity=8)
        for idx in range(1, num_labels):
            mask_roi= np.zeros((h, w))
            k = labels == idx
            mask_roi[k] = 255
            dst_trf_roi = distance_transform_edt(mask_roi)
            if dst_trf_roi.max() > 0:
                dst_trf_roi = (dst_trf_roi / dst_trf_roi.max())
                dst_trf_roi = (dst_trf_roi ** p) * max_val
            dst_trf += dst_trf_roi

        dst_trf[mask == 0] = 1
        return np.array(dst_trf, dtype=np.float32)

    # Tensorflowへ書き換え
    def downsize(self, image: np.ndarray, downsize_factor: int = 8) -> np.ndarray:
        img_t = tf.convert_to_tensor(np.expand_dims(image, 0 if len(image.shape) == 3 else (0, 1)).astype(np.float32))
        img_t = tf.pad(img_t, [[0, 0], [downsize_factor, downsize_factor], [downsize_factor, downsize_factor], [0, 0]], mode='REFLECT')
        image_np = tf.nn.avg_pool(img_t, ksize=2 * downsize_factor + 1, strides=downsize_factor, padding='VALID')
        return image_np[0] if len(image.shape) == 3 else image_np[0, 0]


    def rle_to_mask(self, rle, image_size):
        if len(rle) % 2 != 0:
            raise Exception('Suspicious')

        w, h = image_size
        mask_label = np.zeros(w * h, dtype=np.float32)

        positions = rle[0::2]
        length = rle[1::2]
        for pos, le in zip(positions, length):
            mask_label[pos - 1:pos + le - 1] = 1
        mask = np.reshape(mask_label, (h, w), order='F').astype(np.uint8)
        return mask


## train_net.py（x）

In [None]:
#Colabではうまく動かない
#from end2end import End2End
import argparse
from config import Config


def str2bool(v):
    return v.lower() in ("yes", "true", "t", "1")


def parse_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--GPU', type=int, required=False, help="ID of GPU used for training/evaluation.")
    parser.add_argument('--RUN_NAME', type=str, required=False, help="Name of the run, used as directory name for storing results.")
    parser.add_argument('--DATASET', type=str, required=False, help="Which dataset to use.")
    parser.add_argument('--DATASET_PATH', type=str, required=False, help="Path to the dataset.")

    parser.add_argument('--EPOCHS', type=int, required=False, help="Number of training epochs.")

    parser.add_argument('--LEARNING_RATE', type=float, required=False, help="Learning rate.")
    parser.add_argument('--DELTA_CLS_LOSS', type=float, required=False, help="Weight delta for classification loss.")

    parser.add_argument('--BATCH_SIZE', type=int, required=False, help="Batch size for training.")

    parser.add_argument('--WEIGHTED_SEG_LOSS', type=str2bool, required=False, help="Whether to use weighted segmentation loss.")
    parser.add_argument('--WEIGHTED_SEG_LOSS_P', type=float, required=False, default=None, help="Degree of polynomial for weighted segmentation loss.")
    parser.add_argument('--WEIGHTED_SEG_LOSS_MAX', type=float, required=False, default=None, help="Scaling factor for weighted segmentation loss.")
    parser.add_argument('--DYN_BALANCED_LOSS', type=str2bool, required=False, help="Whether to use dynamically balanced loss.")
    parser.add_argument('--GRADIENT_ADJUSTMENT', type=str2bool, required=False, help="Whether to use gradient adjustment.")
    parser.add_argument('--FREQUENCY_SAMPLING', type=str2bool, required=False, help="Whether to use frequency-of-use based sampling.")

    parser.add_argument('--DILATE', type=int, required=False, default=None, help="Size of dilation kernel for labels")

    parser.add_argument('--FOLD', type=int, default=None, help="Which fold (KSDD) or class (DAGM) to train.")
    parser.add_argument('--TRAIN_NUM', type=int, default=None, help="Number of positive training samples for KSDD or STEEL.")
    parser.add_argument('--NUM_SEGMENTED', type=int, required=False, default=None, help="Number of segmented positive  samples.")
    parser.add_argument('--RESULTS_PATH', type=str, default=None, help="Directory to which results are saved.")

    parser.add_argument('--VALIDATE', type=str2bool, default=None, help="Whether to validate during training.")
    parser.add_argument('--VALIDATE_ON_TEST', type=str2bool, default=None, help="Whether to validate on test set.")
    parser.add_argument('--VALIDATION_N_EPOCHS', type=int, default=None, help="Number of epochs between consecutive validation runs.")
    parser.add_argument('--USE_BEST_MODEL', type=str2bool, default=None, help="Whether to use the best model according to validation metrics for evaluation.")

    parser.add_argument('--ON_DEMAND_READ', type=str2bool, default=None, help="Whether to use on-demand read of data from disk instead of storing it in memory.")
    parser.add_argument('--REPRODUCIBLE_RUN', type=str2bool, default=None, help="Whether to fix seeds and disable CUDA benchmark mode.")

    parser.add_argument('--MEMORY_FIT', type=int, default=None, help="How many images can be fitted in GPU memory.")
    parser.add_argument('--SAVE_IMAGES', type=str2bool, default=None, help="Save test images or not.")
    #parser.add_argument("-f", required=False)

    #args = parser.parse_args()
    args = parser.parse_args(args=[])

    return args

if __name__ == '__main__':
    args = parse_args()

    configuration = Config()
    configuration.merge_from_args(args)
    configuration.init_extra()

    end2end = End2End(cfg=configuration)
    end2end.train()