In [3]:

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import os

import tensorflow as tf
from tensorflow.python.client import device_lib

from mymodel_tf import save_as_tflite, GenerationTF
from myutils_tf import *

In [4]:

def get_available_gpus():
    """Return the names of the local devices whose ``device_type`` is ``'GPU'``.

    The devices are enumerated with ``device_lib.list_local_devices()``; the
    result is a (possibly empty) list of device-name strings in the order the
    devices were reported.
    """
    gpu_names = []
    for device in device_lib.list_local_devices():
        if device.device_type == 'GPU':
            gpu_names.append(device.name)
    return gpu_names

# Number of GPU devices TensorFlow reports on this machine.
# The count can be 0 on a CPU-only host; the fallback that forced it to 1
# is intentionally left disabled here, so downstream cells see the raw count.
gpu_device_names = get_available_gpus()
NGPU = len(gpu_device_names)

print(NGPU)

1


2024-01-24 18:45:17.508535: I metal_plugin/src/device/metal_device.cc:1154] Metal device set to: Apple M1
2024-01-24 18:45:17.508558: I metal_plugin/src/device/metal_device.cc:296] systemMemory: 16.00 GB
2024-01-24 18:45:17.508563: I metal_plugin/src/device/metal_device.cc:313] maxCacheSize: 5.33 GB
2024-01-24 18:45:17.508793: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-01-24 18:45:17.508989: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


# Hyper params

In [None]:
# Optimizer hyper-parameters (Adam).
# LEARNING_RATE is passed to the Adam optimizer and to the LR scheduler below.
# NOTE(review): WEIGHT_DECAY is not referenced by any cell visible in this
# notebook — confirm whether it is still needed.
LEARNING_RATE = 0.001
WEIGHT_DECAY = 0.00000001


# INPUT PARAMS

In [6]:
# Run configuration (hard-coded here instead of being parsed from arguments).

# Data / patch settings. `batch_size` is scaled by the GPU count further down.
input_type = 'rgb'
patch_size = 128
batch_size = 512
input_max = 1

# Min/max value pair kept together as a dict.
constraint_max = 6
constraint = dict(min_value=0, max_value=constraint_max)

# Architecture selection; `model_sig` is appended to `model_name` to form the
# checkpoint directory name.
model_name = 'unetv2'  # resnet_flat, resnet_ed, bwunet, unet, unetv2
model_sig = '_noise'
myepoch = 600

# Set to True for a test run (this turns dataset caching off below).
test = False
# test = True

## data path

In [17]:
def resolve_data_path(cwd, environ=None):
    """Pick the directory that holds the training tfrecords.

    Resolution order:
      1. the ``MIPI_DATA_PATH`` environment variable, when set and non-empty
         (lets other machines run the notebook without editing this cell);
      2. the Google Drive dataset folder, when ``cwd`` is inside
         ``/content/drive/MyDrive``;
      3. the original local default path.

    Args:
        cwd: current working directory as a string.
        environ: mapping to read the override from; defaults to ``os.environ``.

    Returns:
        The dataset directory as a string.
    """
    if environ is None:
        environ = os.environ

    override = environ.get('MIPI_DATA_PATH')
    if override:
        return override

    # NOTE(review): the two built-in folders are spelled differently
    # ('MIPI_demosaic_hybridevs' vs 'MIPI_tetra_hybridenvs') — confirm both exist.
    if '/content/drive/MyDrive' in cwd:
        return '/content/drive/MyDrive/Datasets/MIPI_tetra_hybridenvs/train/tfrecords'
    return '/Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords'


cwd = os.getcwd()
print(cwd)

data_path = resolve_data_path(cwd)

print(data_path)

/Users/bw/Code/lightning/src
/Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords


# UTILS

In [18]:
# Loss terms handed to bwutils; alternatives tried so far are kept below.
loss_type = ['rgb', 'yuv', 'ssim']  # e.g. 'rgb', 'yuv', 'ssim', 'ploss'
# loss_type = ['rgb']
# loss_type = ['yuv']

# Caching is enabled for normal runs and disabled for test runs.
cache_enable = not test

# Shared helper object: provides the dataset input function and the loss
# function used by the cells below.
utils = bwutils(
    input_type,
    cfa_pattern='tetra',
    patch_size=patch_size,
    crop_size=patch_size,  # crop size is kept equal to the patch size
    input_max=input_max,
    loss_type=loss_type,
    loss_mode='2norm',
    loss_scale=1e4,
    cache_enable=cache_enable,
)


[bwutils] input_type rgb
[bwutils] output_type data_only
[bwutils] cfa_pattern 2
[bwutils] patch_size 128
[bwutils] crop_size 128
[bwutils] upscaling_factor None
[bwutils] input_max 1
[bwutils] loss_type ['rgb', 'yuv', 'ssim']
[bwutils] loss_mode 2norm <function square at 0x168d38c10>
[bwutils] loss_scale 10000.0
[bwutils] cache_enable True


# MODEL PATH

In [19]:
# Root folder for everything this run writes; created if it does not exist.
base_path = 'model_dir'
os.makedirs(base_path, exist_ok=True)

# Checkpoints for this run live under <base_path>/checkpoint/<model_name><model_sig>.
checkpoint_subdir = f'{model_name}{model_sig}'
model_dir = os.path.join(base_path, 'checkpoint', checkpoint_subdir)

# TFRECORD READ

In [20]:
def get_tfrecords(path, keyword):
    """Return the sorted list of tfrecord files in ``path`` matching ``keyword``.

    Files are matched with the glob pattern ``*<keyword>*tfrecords`` inside
    ``path`` via ``tf.io.gfile.glob``.
    """
    pattern = os.path.join(path, f'*{keyword}*tfrecords')
    return sorted(tf.io.gfile.glob(pattern))
# One file list per split, selected by the keyword embedded in the file names.
train_files, eval_files, viz_files = (
    get_tfrecords(data_path, split_keyword)
    for split_keyword in ('train', 'valid', 'viz')
)

In [21]:
# Show the dataset folder followed by every file found for each split.
print('data_path, ', data_path)
for split_files in (train_files, eval_files, viz_files):
    print('\n'.join(split_files))

data_path,  /Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords
/Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords/div2k_tetra_train_000.tfrecords
/Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords/div2k_tetra_train_001.tfrecords
/Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords/div2k_tetra_train_002.tfrecords
/Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords/div2k_tetra_train_003.tfrecords
/Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords/div2k_tetra_train_004.tfrecords
/Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords/div2k_tetra_train_005.tfrecords
/Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords/div2k_tetra_train_006.tfrecords
/Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords/div2k_tetra_train_007.tfrecords
/Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords/div2k_tetra_train_008.tfrecords
/Users/bw/Dataset/MIPI_demosaic_hybridevs/train/tfrecords/div2k_tetra_train_009.tfrecords
/Users/bw/Dataset/MIPI_demosai

# Training params setup


In [22]:
# Visual separator in the log, ending with the detected GPU count.
banner_line = '========================================================='
print('\n'.join([banner_line] * 5))
print('========================================================= NGPU', NGPU)



In [23]:
# Turn the per-device batch size into global batch sizes.
#
# NGPU is 0 on a CPU-only host; treat that as a single replica so the batch
# sizes never collapse to 0 (which would also make the later
# `cnt_valid // batch_size_eval` divide by zero).
num_replicas = max(NGPU, 1)
per_replica_batch_size = batch_size

# NOTE: this cell rebinds `batch_size`; re-running it without first re-running
# the INPUT PARAMS cell scales the value again.
batch_size      = per_replica_batch_size * num_replicas
# Derived from the per-replica size, not from the already-scaled `batch_size`,
# so the evaluation batch grows linearly (not quadratically) with the GPU count.
batch_size_eval = per_replica_batch_size * num_replicas
batch_size_viz  = 4
print(batch_size, batch_size_eval, batch_size_viz)

512 512 4


In [24]:

# Input-pipeline settings for each split. The three dicts share the same keys
# and differ only in file list, mode and batch size.
train_params = dict(
    filenames=train_files,
    mode=tf.estimator.ModeKeys.TRAIN,
    threads=2,
    shuffle_buff=128,
    batch=batch_size,
    input_type=input_type,
)

eval_params = dict(
    filenames=eval_files,
    mode=tf.estimator.ModeKeys.EVAL,
    threads=2,
    shuffle_buff=128,
    batch=batch_size_eval,
    input_type=input_type,
)

# The visualisation split is read in EVAL mode with its own small batch.
viz_params = dict(
    filenames=viz_files,
    mode=tf.estimator.ModeKeys.EVAL,
    threads=2,
    shuffle_buff=128,
    batch=batch_size_viz,
    input_type=input_type,
)

# Build the datasets in the same order as before: train, eval, viz.
dataset_train = utils.dataset_input_fn(train_params)
dataset_eval = utils.dataset_input_fn(eval_params)
dataset_viz = utils.dataset_input_fn(viz_params)

2024-01-24 18:48:48.112587: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:303] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2024-01-24 18:48:48.112606: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:269] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


>>>>>>> raw.shape (128, 128)
<<<<<<< raw.shape (128, 128, 3)
<<<<<<< gt.shape (128, 128, 3)
>>>>>>> raw.shape (128, 128)
<<<<<<< raw.shape (128, 128, 3)
<<<<<<< gt.shape (128, 128, 3)
>>>>>>> raw.shape (128, 128)
<<<<<<< raw.shape (128, 128, 3)
<<<<<<< gt.shape (128, 128, 3)


## train/val/viz counts

In [25]:
# Hard-coded sample counts for the noisy dataset; they are used below to
# derive steps_per_epoch and validation_steps.
# NOTE(review): presumably 7080 samples in each of 19 train files — confirm.
cnt_train = 7080 * 19
cnt_valid = 7096
cnt_viz = 4

## Get Model gogo

In [26]:

# Build the network selected by `model_name`, with kernel regularizer and
# kernel constraint switched on.
bw = GenerationTF(
    model_name=model_name,
    kernel_regularizer=True,
    kernel_constraint=True,
)
model = bw.model

# Disabled experiment (never executes): pin the input shape to freeze the model.
if False:
    model.input.set_shape(1 + model.input.shape[1:]) # to freeze model

# Save the untrained architecture (without optimizer state) next to the
# checkpoints, then print the layer summary.
structure_file = os.path.join(base_path, 'checkpoint', f'{model_name}_model_structure.h5')
model.save(structure_file, include_optimizer=False)
model.summary()

-----------> unetv2, init done <-------------
Model: "unet"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 unet_input (InputLayer)     [(None, 128, 128, 3)]        0         []                            
                                                                                                  
 sequential (Sequential)     (None, 64, 64, 64)           3136      ['unet_input[0][0]']          
                                                                                                  
 sequential_1 (Sequential)   (None, 32, 32, 128)          131712    ['sequential[0][0]']          
                                                                                                  
 sequential_2 (Sequential)   (None, 16, 16, 256)          525568    ['sequential_1[0][0]']        
                                                 

  saving_api.save_model(


## Model Compile

In [27]:
# Adam optimizer at the configured learning rate.
optimizer = tf.keras.optimizers.Adam(name='Adam', learning_rate=LEARNING_RATE)

# The same utils.loss_fn is used both as the training loss and as the
# reported metric.
compile_options = dict(
    optimizer=optimizer,
    loss=utils.loss_fn,
    metrics=[utils.loss_fn],
)
model.compile(**compile_options)



## load pre-trained model

In [28]:
# Resume from a previous checkpoint when one is available. The helper hands
# back the model plus the epoch and loss values that are forwarded to the
# callbacks and to model.fit below.
# NOTE(review): the file name mentions 'resnet_flat' while model_name is set
# to 'unetv2' in this notebook — confirm this is the intended checkpoint.
trained_model_file_name = '00003_resnet_flat_2.89940e-09.h5'
restored_state = load_checkpoint_if_exists(model, model_dir, model_name, trained_model_file_name)
model, prev_epoch, prev_loss = restored_state


--=-=-=->  []


## callbacks for training loop

In [29]:

# Training callbacks requested by name from the project helper.
# NOTE(review): 'ckeckpoint' looks like a misspelling of 'checkpoint' — confirm
# which spelling get_training_callbacks expects before changing the string.
callbacks = get_training_callbacks(
    ['ckeckpoint', 'tensorboard', 'image'],
    base_path=base_path,
    model_name=model_name + model_sig,
    dataloader=dataset_viz,
    cnt_viz=cnt_viz,
    initial_value_threshold=prev_loss,
)

# Cosine learning-rate schedule over `myepoch` steps, starting from LEARNING_RATE.
callback_lr = get_scheduler(type='cosine', lr_init=LEARNING_RATE, steps=myepoch)
callbacks.append(callback_lr)

-------------------> scheduler type,  cosine
-------------------> GOOD scheduler type2,  cosine


# tensorboard

In [30]:
%load_ext tensorboard

In [31]:
%tensorboard --logdir=model_dir

## TRAIN GOGO

In [None]:
# `more_ckpt_ratio` > 1 divides each epoch into that many shorter epochs (and
# multiplies the epoch count by the same factor); 1 keeps one pass per epoch.
more_ckpt_ratio = 1

total_epochs = myepoch * more_ckpt_ratio
train_steps_per_epoch = (cnt_train // (batch_size * more_ckpt_ratio)) + 1
eval_steps = cnt_valid // batch_size_eval

# Training resumes at `prev_epoch` and validates after every epoch.
model.fit(
    dataset_train,
    epochs=total_epochs,
    steps_per_epoch=train_steps_per_epoch,
    initial_epoch=prev_epoch,
    validation_data=dataset_eval,
    validation_steps=eval_steps,
    validation_freq=1,
    callbacks=callbacks,
)

x.shape (4, 128, 128, 3)
y.shape (4, 128, 128, 3)
y.shape (4, 128, 128, 3)
y.shape (4, 128, 128, 3)
x: 0.00, 0.99, y: 0.00, 1.00, pred: 0.47, 0.53, diff: 0.00, 0.53

Epoch 1: LearningRateScheduler setting learning rate to 9.999999747378752e-05.
Epoch 1/600


2024-01-24 18:50:38.175037: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
