<a href="https://colab.research.google.com/github/fregean/Tensorflow-ver-mixed-segdec-net/blob/main/Tensorflow_version_20230602.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Tensorflowへ変換

KSDD

画像データ：1408*512

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
import os

os.chdir('/content/drive/MyDrive/研修用/tensorflow_ver_mixed-segdec-net')


## models.py

In [3]:
# @title models.py

import tensorflow as tf
from tensorflow.keras import layers, initializers, Model
import math

BATCHNORM_TRACK_RUNNING_STATS = False
BATCHNORM_MOVING_AVERAGE_DECAY = 0.9997

class Conv2D_init(layers.Conv2D):
    def __init__(self, filters, kernel_size, **kwargs):
        super(Conv2D_init, self).__init__(filters, 
                                          kernel_size, 
                                          kernel_initializer=initializers.GlorotNormal(), 
                                          bias_initializer=initializers.Zeros(),
                                          **kwargs)

class FeatureNorm(layers.Layer):
    # PyTorch版のfeature_index=1をfeature_index=-1に変更（channel first→channel last）
    def __init__(self, num_features, feature_index=-1, rank=4, reduce_dims=(2, 3), eps=0.001, include_bias=True):
        super(FeatureNorm, self).__init__()
        self.shape = [1] * rank
        self.shape[feature_index] = num_features
        self.reduce_dims = reduce_dims
        self.scale = self.add_weight(name='scale', shape=self.shape, initializer='ones', trainable=True)
        self.eps = eps

        if include_bias:
            self.bias = self.add_weight(name='bias', shape=self.shape, initializer='zeros', trainable=True)
        else:
            self.bias = 0

    def call(self, features):
        f_std = tf.math.reduce_std(features, axis=self.reduce_dims, keepdims=True)
        f_mean = tf.math.reduce_mean(features, axis=self.reduce_dims, keepdims=True)
        return self.scale * ((features - f_mean) / tf.sqrt(f_std + self.eps)) + self.bias

def _conv_block(out_channels, kernel_size, padding):
    return tf.keras.Sequential([
        Conv2D_init(filters=out_channels, kernel_size=kernel_size, padding=padding, use_bias=False),
        FeatureNorm(num_features=out_channels, eps=0.001),
        layers.ReLU()
    ])

class SegDecNet(Model):
    def __init__(self, device, input_width, input_height, input_channels):
        super(SegDecNet, self).__init__()
        if input_width % 8 != 0 or input_height % 8 != 0:
            raise Exception(f"Input size must be divisible by 8! width={input_width}, height={input_height}")
        self.input_width = input_width
        self.input_height = input_height
        self.input_channels = input_channels
        self.volume = tf.keras.Sequential([
            _conv_block(32, 5, 'same'), 
            layers.MaxPool2D(2),
            _conv_block(64, 5, 'same'),
            _conv_block(64, 5, 'same'),
            _conv_block(64, 5, 'same'),
            layers.MaxPool2D(2),
            _conv_block(64, 5, 'same'),
            _conv_block(64, 5, 'same'),
            _conv_block(64, 5, 'same'),
            _conv_block(64, 5, 'same'),
            layers.MaxPool2D(2),
            _conv_block(1024, 15, 'same')
        ])

        self.seg_mask = tf.keras.Sequential([
            Conv2D_init(filters=1, kernel_size=1, use_bias=False),
            FeatureNorm(num_features=1, eps=0.001)
        ])

        self.extractor = tf.keras.Sequential([
            layers.MaxPool2D(pool_size=2),
            _conv_block(8, 5, padding='same'),
            layers.MaxPool2D(pool_size=2),
            _conv_block(16, 5, padding='same'),
            layers.MaxPool2D(pool_size=2),
            _conv_block(32, 5, padding='same')
        ])

        self.global_max_pool_feat = layers.GlobalMaxPooling2D(keepdims=True)
        self.global_avg_pool_feat = layers.GlobalAveragePooling2D(keepdims=True)
        self.global_max_pool_seg = layers.MaxPooling2D(pool_size=(self.input_height // 8, self.input_width // 8))
        self.global_avg_pool_seg = layers.AveragePooling2D(pool_size=(self.input_height // 8, self.input_width // 8))

        self.fc = layers.Dense(units=1)

        self.volume_lr_multiplier_layer = GradientMultiplyLayer()
        self.glob_max_lr_multiplier_layer = GradientMultiplyLayer()
        self.glob_avg_lr_multiplier_layer = GradientMultiplyLayer()

    def set_gradient_multipliers(self, multiplier):
        self.volume_lr_multiplier_mask = tf.ones((1,)) * multiplier
        self.glob_max_lr_multiplier_mask = tf.ones((1,)) * multiplier
        self.glob_avg_lr_multiplier_mask = tf.ones((1,)) * multiplier

    def call(self, input, training=False):
        volume = self.volume(input, training=training)
        seg_mask = self.seg_mask(volume, training=training)


        cat = tf.concat([volume, seg_mask], axis=-1)
        #print(f'cat_concat: {cat, cat.shape}')

        cat = self.volume_lr_multiplier_layer(cat, self.volume_lr_multiplier_mask)
        #print(f'cat_volume_lr_multiplier_layer: {cat, cat.shape}')

        features = self.extractor(cat, training=training)
        #print(f'features: {features, features.shape}')
        global_max_feat = self.global_max_pool_feat(features)
        global_avg_feat = self.global_avg_pool_feat(features)
        global_max_seg = self.global_max_pool_seg(seg_mask)
        global_avg_seg = self.global_avg_pool_seg(seg_mask)

        global_max_seg = self.glob_max_lr_multiplier_layer(global_max_seg, self.glob_max_lr_multiplier_mask)
        global_avg_seg = self.glob_avg_lr_multiplier_layer(global_avg_seg, self.glob_avg_lr_multiplier_mask)
        

        fc_in = tf.concat([global_max_feat, global_avg_feat, global_max_seg, global_avg_seg], axis=-1)
        fc_in = tf.reshape(fc_in, [fc_in.shape[0], -1])
        prediction = self.fc(fc_in)

        return prediction, seg_mask

class GradientMultiplyLayer(layers.Layer):
    def call(self, input, mask_bw):
        return input


## End2End

In [4]:
# @title end2end.py 

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
#from models import SegDecNet
import numpy as np
import os
import tensorflow as tf
import utils
import time
import datetime
import pandas as pd
from data.dataset_catalog import get_dataset
import random
import cv2
from config import Config
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.keras import backend as K

LVL_ERROR = 10
LVL_INFO = 5
LVL_DEBUG = 1

LOG = 1  # Will log all mesages with lvl greater than this
SAVE_LOG = True

WRITE_TENSORBOARD = False

class End2End:
    def __init__(self, cfg: Config):
        self.cfg: Config = cfg
        self.storage_path: str = os.path.join(self.cfg.RESULTS_PATH, self.cfg.DATASET)

    def _log(self, message, lvl=LVL_INFO):
        n_msg = f"{self.run_name} {message}"
        if lvl >= LOG:
            print(n_msg)

    def train(self):
        self._set_results_path()
        self._create_results_dirs()
        self.print_run_params()
        if self.cfg.REPRODUCIBLE_RUN:
            self._log("Reproducible run, fixing all seeds to:1337", LVL_DEBUG)
            np.random.seed(1337)
            tf.random.set_seed(1337)
            random.seed(1337)

        device = self._get_device()
        model = self._get_model()
        optimizer = self._get_optimizer(model)
        loss_seg, loss_dec = self._get_loss(True), self._get_loss(False)

        train_loader = get_dataset("TRAIN", self.cfg)
        validation_loader = get_dataset("VAL", self.cfg)

        # get_dataset関数でKSDDDatasetのインスタンスを取得
        # train_dataset = get_dataset("TRAIN", self.cfg)
        # validation_dataset = get_dataset("VAL", self.cfg)

        # tf.data.Datasetのインスタンスを作成
        # train_loader = train_dataset.prefetch(tf.data.experimental.AUTOTUNE)
        # validation_loader = validation_dataset.prefetch(tf.data.experimental.AUTOTUNE)

        tensorboard_writer = summary_ops_v2.create_file_writer(self.tensorboard_path) if WRITE_TENSORBOARD else None

        train_results = self._train_model(device, model, train_loader, loss_seg, loss_dec, optimizer, validation_loader, tensorboard_writer)
        self._save_train_results(train_results)
        self._save_model(model, os.path.join(self.model_path, "final_state.tf"))

        self.eval(model, device, self.cfg.SAVE_IMAGES, False, False)

        self._save_params()

    def eval(self, model, save_images, plot_seg, reload_final):
        if reload_final:
            model.load_weights(os.path.join(self.model_path, "final_state_dict.tf"))
        test_loader = get_dataset("TEST", self.cfg)
        self.eval_model(model, test_loader, save_folder=self.outputs_path, save_images=save_images, is_validation=False, plot_seg=plot_seg)
    

    

    def training_iteration(self, data, model, criterion_seg, criterion_dec, optimizer, weight_loss_seg, weight_loss_dec, iter_index):
        images, seg_masks, seg_loss_masks, is_segmented, _, _, _ = data

        batch_size = self.cfg.BATCH_SIZE
        memory_fit = self.cfg.MEMORY_FIT  # Not supported yet for >1

        num_subiters = int(batch_size / memory_fit)
        
        total_loss = 0
        total_correct = 0

        total_loss_seg = 0
        total_loss_dec = 0
        is_segmented_tensor = tf.constant(is_segmented, dtype=tf.bool)

        for sub_iter in range(num_subiters):
            with tf.GradientTape() as tape:    
                images_ = images[sub_iter * memory_fit:(sub_iter + 1) * memory_fit, :, :, :]
                seg_masks_ = seg_masks[sub_iter * memory_fit:(sub_iter + 1) * memory_fit, :, :, :]
                seg_loss_masks_ = seg_loss_masks[sub_iter * memory_fit:(sub_iter + 1) * memory_fit, :, :, :]
                is_pos_ = tf.reshape(tf.reduce_max(seg_masks_), (memory_fit, 1))
                print(f'seg_masks_:{seg_masks_},is_pos_:{is_pos_}')

                decision, output_seg_mask = model(images_, training=True)

                if is_segmented_tensor[sub_iter]:
                    if self.cfg.WEIGHTED_SEG_LOSS:

                        loss_seg = tf.reduce_mean(criterion_seg(seg_masks_, output_seg_mask)[:,:,:,np.newaxis] * seg_loss_masks_)
                         
                    else:
                        loss_seg = criterion_seg(seg_masks_, output_seg_mask)[:,:,:,np.newaxis]

                    

                    loss_dec = criterion_dec(is_pos_, decision)
                    # 表示
                    print(f'output_seg_mask:{output_seg_mask}, seg_masks_:{seg_masks_}, loss_seg:{loss_seg}, decision:{decision}, is_pos_:{is_pos_}, loss_dec:{loss_dec}')
                    

                    total_loss_seg += loss_seg
                    total_loss_dec += loss_dec
                    
                    total_correct += tf.reduce_sum(tf.cast(tf.equal(decision.numpy() > 0.0, tf.cast(tf.cast(is_pos_.numpy(), dtype=np.int32), tf.bool)), dtype=np.float32))
                    loss = weight_loss_seg * loss_seg + weight_loss_dec * loss_dec
                    
                else:
                    loss_dec = criterion_dec(is_pos_, decision)
                    total_loss_dec += loss_dec

                    total_correct += tf.reduce_sum(tf.cast(tf.equal(decision.numpy() > 0.0, tf.cast(tf.cast(is_pos_.numpy(), dtype=np.int32), tf.bool)), dtype=np.float32))
                    loss = weight_loss_dec * loss_dec

                total_loss += loss

            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))    
        

        return total_loss_seg, total_loss_dec, total_loss, total_correct


    def _train_model(self, device, model, train_loader, criterion_seg, criterion_dec, optimizer, validation_set, tensorboard_writer):
        losses = []
        validation_data = []
        max_validation = -1
        validation_step = self.cfg.VALIDATION_N_EPOCHS

        num_epochs = self.cfg.EPOCHS
        
        samples_per_epoch = tf.data.experimental.cardinality(train_loader).numpy() * self.cfg.BATCH_SIZE   

        self.set_dec_gradient_multiplier(model, 0.0)

        for epoch in range(num_epochs):

            self.print_epoch(epoch)

            weight_loss_seg, weight_loss_dec = self.get_loss_weights(epoch)

            dec_gradient_multiplier = self.get_dec_gradient_multiplier()
            self.set_dec_gradient_multiplier(model, dec_gradient_multiplier)
            
            epoch_loss_seg, epoch_loss_dec, epoch_loss = 0, 0, 0
            epoch_correct = 0

            for iter_index, (data) in enumerate(train_loader):

                curr_loss_seg, curr_loss_dec, curr_loss, correct = self.training_iteration(data, model,
                                                                                        criterion_seg,
                                                                                        criterion_dec,
                                                                                        optimizer, weight_loss_seg,
                                                                                        weight_loss_dec, iter_index)

                epoch_loss_seg += curr_loss_seg
                epoch_loss_dec += curr_loss_dec
                epoch_loss += curr_loss

                # loss表示
                self.print_loss(epoch_loss_seg, epoch_loss_dec)

                epoch_correct += correct

            if epoch % 5 == 0:

                self._save_model(model, os.path.join(self.model_path, f"ep_{epoch:02}.tf"))
    
            epoch_loss_seg = epoch_loss_seg / samples_per_epoch
            epoch_loss_dec = epoch_loss_dec / samples_per_epoch
            epoch_loss = epoch_loss / samples_per_epoch
            losses.append((epoch_loss_seg, epoch_loss_dec, epoch_loss, epoch))

            if self.cfg.VALIDATE and (epoch % validation_step == 0 or epoch == num_epochs - 1):
                validation_ap, validation_accuracy = self.eval_model(model, validation_set, None, False, True, False)
                validation_data.append((validation_ap, epoch))

                if validation_ap > max_validation:
                    max_validation = validation_ap
                    self._save_model(model, os.path.join(self.model_path, "best_state_dict.tf")) # .h5

        return losses, validation_data


    def eval_model(self, model, eval_loader, save_folder, save_images, is_validation, plot_seg):

        res = []
        predictions, ground_truths = [], []

        for data_point in eval_loader:

            image, seg_mask, seg_loss_mask, _, sample_name, _, _ = data_point
            is_pos = tf.reduce_max(seg_mask) > 0
            prediction, pred_seg = model(image, training=False)

            pred_seg = tf.nn.sigmoid(pred_seg)
            prediction = tf.nn.sigmoid(prediction)

            prediction = prediction.numpy().item()
            image = image.numpy()
            pred_seg = pred_seg.numpy()
            seg_mask = seg_mask.numpy()

            predictions.append(prediction)
            ground_truths.append(is_pos)
            res.append((prediction, None, None, is_pos, sample_name[0]))
            if not is_validation and save_images:
                # image saving code here
                plt.imsave(f"{save_folder}/{sample_name[0]}_prediction.png", prediction[0, :, :, 0], cmap='gray')
                plt.imsave(f"{save_folder}/{sample_name[0]}_image.png", image[0, :, :, 0], cmap='gray')
                plt.imsave(f"{save_folder}/{sample_name[0]}_pred_seg.png", pred_seg[0, :, :, 0], cmap='gray')
                plt.imsave(f"{save_folder}/{sample_name[0]}_seg_mask.png", seg_mask[0, :, :], cmap='gray')

        if is_validation:
            self.print_eval(ground_truths, predictions)
            metrics = utils.get_metrics(np.array(ground_truths), np.array(predictions))
            return metrics["AP"], metrics["accuracy"]
            
        else:
            utils.evaluate_metrics(res, self.run_path, self.run_name)


    def get_dec_gradient_multiplier(self):
        if self.cfg.GRADIENT_ADJUSTMENT:
            grad_m = 0
        else:
            grad_m = 1

        return grad_m

    def set_dec_gradient_multiplier(self, model, multiplier):
        # This function is not applicable in TensorFlow as it is specific to PyTorch.
        model.set_gradient_multipliers(multiplier)

    def get_loss_weights(self, epoch):
        total_epochs = float(self.cfg.EPOCHS)

        if self.cfg.DYN_BALANCED_LOSS:
            seg_loss_weight = 1 - (epoch / total_epochs)
            dec_loss_weight = self.cfg.DELTA_CLS_LOSS * (epoch / total_epochs)
            # seg_loss_weight = self.cfg.DELTA_CLS_LOSS *  (epoch / total_epochs)
            # dec_loss_weight = self.cfg.DELTA_CLS_LOSS * (epoch / total_epochs)
        else:
            seg_loss_weight = 1
            dec_loss_weight = self.cfg.DELTA_CLS_LOSS

        return tf.constant(seg_loss_weight), tf.constant(dec_loss_weight)

    def reload_model(self, model, load_final=False):
        if self.cfg.USE_BEST_MODEL:
            path = os.path.join(self.model_path, "best_state_dict.tf")
            model.load_weights(path)
        elif load_final:
            path = os.path.join(self.model_path, "final_state_dict.tf")
            model.load_weights(path)

    def _save_params(self):
        params = self.cfg.get_as_dict()
        params_lines = sorted(map(lambda e: e[0] + ":" + str(e[1]) + "\n", params.items()))
        fname = os.path.join(self.run_path, "run_params.txt")
        with open(fname, "w+") as f:
            f.writelines(params_lines)

    def _save_train_results(self, results):
        losses, validation_data = results
        ls, ld, l, le = map(list, zip(*losses))
        plt.plot(le, l, label="Loss", color="red")
        plt.plot(le, ls, label="Loss seg")
        plt.plot(le, ld, label="Loss dec")
        plt.ylim(bottom=0)
        plt.grid()
        plt.xlabel("Epochs")
        if self.cfg.VALIDATE:
            v, ve = map(list, zip(*validation_data))
            plt.twinx()
            plt.plot(ve, v, label="Validation AP", color="Green")
            plt.ylim((0, 1))
        plt.legend()
        plt.savefig(os.path.join(self.run_path, "loss_val"), dpi=200)

        df_loss = pd.DataFrame(data={"loss_seg": ls, "loss_dec": ld, "loss": l, "epoch": le})
        df_loss.to_csv(os.path.join(self.run_path, "losses.csv"), index=False)

        if self.cfg.VALIDATE:
            df_loss = pd.DataFrame(data={"validation_data": v, "epoch": ve})
            df_loss.to_csv(os.path.join(self.run_path, "validation.csv"), index=False)

    def _save_model(self, model, output_name):
        print(output_name)
        if os.path.isfile(output_name):
            os.remove(output_name)

        model.save_weights(output_name)

    def _get_optimizer(self, model):
        return tf.keras.optimizers.SGD(learning_rate=self.cfg.LEARNING_RATE)
        #return tf.keras.optimizers.Adam(learning_rate=self.cfg.LEARNING_RATE)

    def _get_loss(self, is_seg):
        reduction = tf.keras.losses.Reduction.NONE if self.cfg.WEIGHTED_SEG_LOSS and is_seg else tf.keras.losses.Reduction.AUTO
        return tf.keras.losses.BinaryCrossentropy(reduction=reduction)


    def _get_device(self):
        return tf.device(f"/GPU:{self.cfg.GPU}")

    def _set_results_path(self):
        self.run_name = f"{self.cfg.RUN_NAME}_FOLD_{self.cfg.FOLD}" if self.cfg.DATASET in ["KSDD", "DAGM"] else self.cfg.RUN_NAME

        results_path = os.path.join(self.cfg.RESULTS_PATH, self.cfg.DATASET)
        self.tensorboard_path = os.path.join(results_path, "tensorboard", self.run_name)

        run_path = os.path.join(results_path, self.cfg.RUN_NAME)
        if self.cfg.DATASET in ["KSDD", "DAGM"]:
            run_path = os.path.join(run_path, f"FOLD_{self.cfg.FOLD}")

        self._log(f"Executing run with path {run_path}")

        self.run_path = run_path
        self.model_path = os.path.join(run_path, "models")
        self.outputs_path = os.path.join(run_path, "test_outputs")

    def _create_results_dirs(self):
        list(map(utils.create_folder, [self.run_path, self.model_path, self.outputs_path, ]))

    def _get_model(self):
        seg_net = SegDecNet(self._get_device(), self.cfg.INPUT_WIDTH, self.cfg.INPUT_HEIGHT, self.cfg.INPUT_CHANNELS)
        return seg_net

    def print_run_params(self):
        for l in sorted(map(lambda e: e[0] + ":" + str(e[1]) + "\n", self.cfg.get_as_dict().items())):
            k, v = l.split(":")
            print(f"{k:25s} : {str(v.strip())}")

    def print_epoch(self, epoch):
        t_now = datetime.datetime.now().time()
        print(f'epoch:{epoch}>{t_now}')

    def print_eval(self, ground_truths, predictions):
        print(f'predictions:{predictions}')

    def print_loss(self, epoch_loss_seg, epoch_loss_dec):
        print(f'epoch_loss_seg:{epoch_loss_seg}, epoch_loss_dec:{epoch_loss_dec}')

    def print_criterion(self, output_seg_mask, seg_masks_, loss_seg, decision, is_pos_, loss_dec):
        print(f'output_seg_mask:{output_seg_mask}, seg_masks_:{seg_masks_}, loss_seg: {loss_seg}')
        print(f'decision:{decision}, is_pos_:{is_pos_}, loss_dec: {loss_dec}')

以下はpyファイルの方を用いる

## train_net.py

In [5]:
# Colab用args
import argparse
from config import Config

class args():
    def __init__(self):
        self.GPU=0
        self.RUN_NAME='RUN_NAME'
        self.DATASET='KSDD'
        self.DATASET_PATH='/content/drive/MyDrive/研修用/tensorflow_ver_mixed-segdec-net/datasets/KSDD'
        self.EPOCHS=5 # 50
        self.LEARNING_RATE=0.001 # 1.0から変更
        self.DELTA_CLS_LOSS = 0.01
        self.BATCH_SIZE=1
        self.WEIGHTED_SEG_LOSS=True
        self.WEIGHTED_SEG_LOSS_P=2
        self.WEIGHTED_SEG_LOSS_MAX=1
        self.DYN_BALANCED_LOSS=True
        self.GRADIENT_ADJUSTMENT=True # PyTorchに合わせる
        self.FREQUENCY_SAMPLING=True
        self.FOLD=0
        self.TRAIN_NUM=33
        self.NUM_SEGMENTED=33
        self.RESULTS_PATH = "/content/drive/MyDrive/研修用/tensorflow_ver_mixed-segdec-net/datasets/RESULTS"
        
        self.VALIDATE = True
        self.VALIDATE_ON_TEST = True
        self.VALIDATION_N_EPOCHS = 5
        self.USE_BEST_MODEL = False

        self.ON_DEMAND_READ = False
        self.REPRODUCIBLE_RUN = False
        self.MEMORY_FIT = 1
        self.SAVE_IMAGES = True
        self.DILATE = 7

args=args()
cfg = Config()
cfg.merge_from_args(args)
cfg.init_extra()

end2end = End2End(cfg=cfg)

# ここを実行する
#end2end.train()

In [6]:
from data.input_ksdd import KSDDDataset
from data.dataset import Dataset
import pickle
from config import Config

kind = "TRAIN"

ds = KSDDDataset(cfg.DATASET_PATH, cfg, kind)

shuffle = kind == "TRAIN"
batch_size = cfg.BATCH_SIZE if kind == "TRAIN" else 1

dataset = ds._data.batch(batch_size)
if shuffle:
    dataset = dataset.shuffle(buffer_size=ds.length)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)

device = end2end._get_device()
model = end2end._get_model()
optimizer = end2end._get_optimizer(model)
criterion_seg, criterion_dec = end2end._get_loss(True), end2end._get_loss(False)

train_loader = get_dataset("TRAIN", cfg)
validation_loader = get_dataset("VAL", cfg)



In [7]:
losses = []
validation_data = []
max_validation = -1
validation_step = cfg.VALIDATION_N_EPOCHS

num_epochs = cfg.EPOCHS

samples_per_epoch = tf.data.experimental.cardinality(train_loader).numpy() * cfg.BATCH_SIZE   

end2end.set_dec_gradient_multiplier(model, 0.0)

In [9]:
for epoch in range(num_epochs):
    print('-----------------------------------------------------------------------------------')
    print(f'epoch:{epoch}')

    weight_loss_seg, weight_loss_dec = end2end.get_loss_weights(epoch)

    dec_gradient_multiplier = end2end.get_dec_gradient_multiplier()
    end2end.set_dec_gradient_multiplier(model, dec_gradient_multiplier)
    
    epoch_loss_seg, epoch_loss_dec, epoch_loss = 0, 0, 0
    epoch_correct = 0

    for iter_index, (data) in enumerate(train_loader):

        images, seg_masks, seg_loss_masks, is_segmented, _, _, _ = data

        batch_size = cfg.BATCH_SIZE
        memory_fit = cfg.MEMORY_FIT

        num_subiters = int(batch_size / memory_fit)
        print(f'num_subiters:{num_subiters}')
        
        total_loss = 0
        total_correct = 0

        total_loss_seg = 0
        total_loss_dec = 0
        is_segmented_tensor = tf.constant(is_segmented, dtype=tf.bool)

        for sub_iter in range(num_subiters):
            with tf.GradientTape() as tape:    
                images_ = images[sub_iter * memory_fit:(sub_iter + 1) * memory_fit, :, :, :]
                seg_masks_ = seg_masks[sub_iter * memory_fit:(sub_iter + 1) * memory_fit, :, :, :]
                seg_loss_masks_ = seg_loss_masks[sub_iter * memory_fit:(sub_iter + 1) * memory_fit, :, :, :]
                is_pos_ = tf.reshape(tf.reduce_max(seg_masks_), (memory_fit, 1))

                decision, output_seg_mask = model(images_, training=True)

                print(f'decision:{decision}')

                if is_segmented_tensor[sub_iter]:
                    if cfg.WEIGHTED_SEG_LOSS:

                        loss_seg = tf.reduce_mean(criterion_seg(seg_masks_, output_seg_mask)[:,:,:,np.newaxis] * seg_loss_masks_)
                         
                    else:
                        loss_seg = criterion_seg(seg_masks_, output_seg_mask)[:,:,:,np.newaxis]

                    loss_dec = criterion_dec(is_pos_, decision)

                    total_loss_seg += loss_seg
                    total_loss_dec += loss_dec
                    
                    total_correct += tf.reduce_sum(tf.cast(tf.equal(decision.numpy() > 0.0, tf.cast(tf.cast(is_pos_.numpy(), dtype=np.int32), tf.bool)), dtype=np.float32))
                    loss = weight_loss_seg * loss_seg + weight_loss_dec * loss_dec
                    
                    print(f'is_pos_:{is_pos_}, decision:{decision}')
                    print(f'loss:{loss}, loss_dec:{loss_dec}, weight_loss_dec:{weight_loss_dec}, loss_seg:{loss_seg}, weight_loss_seg:{weight_loss_seg}')
                    
                else:
                    loss_dec = criterion_dec(is_pos_, decision)
                    total_loss_dec += loss_dec

                    total_correct += tf.reduce_sum(tf.cast(tf.equal(decision.numpy() > 0.0, tf.cast(tf.cast(is_pos_.numpy(), dtype=np.int32), tf.bool)), dtype=np.float32))
                    loss = weight_loss_dec * loss_dec
                    
                    print(f'is_pos_:{is_pos_}, decision:{decision}')
                    print(f'loss:{loss}, loss_dec:{loss_dec}, weight_loss_dec:{weight_loss_dec}, loss_seg:{loss_seg}, weight_loss_seg:{weight_loss_seg}')

                total_loss += loss
                

            print(f'loss:{loss}')
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))  


        epoch_loss_seg += total_loss_seg
        epoch_loss_dec += total_loss_dec
        epoch_loss += total_loss

        # loss表示
        print('===============================================')
        print(f'epoch_loss_seg:{epoch_loss_seg}, epoch_loss_dec:{epoch_loss_dec}')

        epoch_correct += total_correct

    # if epoch % 5 == 0:

    #     end2end._save_model(model, os.path.join(end2end.model_path, f"ep_{epoch:02}.tf"))

    epoch_loss_seg = epoch_loss_seg / samples_per_epoch
    epoch_loss_dec = epoch_loss_dec / samples_per_epoch
    epoch_loss = epoch_loss / samples_per_epoch

    print(f'epoch_loss:{epoch_loss}')
    losses.append((epoch_loss_seg, epoch_loss_dec, epoch_loss, epoch))

    if cfg.VALIDATE and (epoch % validation_step == 0 or epoch == num_epochs - 1):
        validation_ap, validation_accuracy = end2end.eval_model(model, validation_loader, None, False, True, False)
        print(f'validation_ap:{validation_ap},validation_accuracy:{validation_accuracy}')
        validation_data.append((validation_ap, epoch))

        # if validation_ap > max_validation:
        #     max_validation = validation_ap
        #     end2end._save_model(model, os.path.join(end2end.model_path, "best_state_dict.tf")) # .h5


[1;30;43mストリーミング出力は最後の 5000 行に切り捨てられました。[0m
num_subiters:1
decision:[[-0.39074126]]
is_pos_:[[0.]], decision:[[-0.39074126]]
loss:1.0062440633773804, loss_dec:-0.0, weight_loss_dec:0.004000000189989805, loss_seg:1.677073359489441, weight_loss_seg:0.6000000238418579
loss:1.0062440633773804
epoch_loss_seg:134.77569580078125, epoch_loss_dec:200.52430725097656
num_subiters:1
decision:[[-0.39374447]]
is_pos_:[[0.]], decision:[[-0.39374447]]
loss:1.0062438249588013, loss_dec:-0.0, weight_loss_dec:0.004000000189989805, loss_seg:1.6770730018615723, weight_loss_seg:0.6000000238418579
loss:1.0062438249588013
epoch_loss_seg:136.45277404785156, epoch_loss_dec:200.52430725097656
num_subiters:1
decision:[[-0.38995868]]
is_pos_:[[0.]], decision:[[-0.38995868]]
loss:1.0062438249588013, loss_dec:-0.0, weight_loss_dec:0.004000000189989805, loss_seg:1.6770730018615723, weight_loss_seg:0.6000000238418579
loss:1.0062438249588013
epoch_loss_seg:138.12985229492188, epoch_loss_dec:200.52430725097656
num_sub