# Prepare a training dataset
- We use [the test set of VGGFace2](https://github.com/ox-vgg/vgg_face2).
- You may download the VGGFace2 test set or prepare another dataset; set `DATA_PATH` below to its location.

In [1]:
import os

# Root directory of the raw training images.
# Defaults to the original location; set the CLEANIR_DATA_PATH environment
# variable to point the notebook at a different directory without editing it.
DATA_PATH = os.environ.get('CLEANIR_DATA_PATH', '/workspace/data/vgg2_test')

# Preprocess the dataset
- Before training, we have to crop and/or align the faces, and extract face embeddings.
- See the `cleanir.tools.crop_vgg2_data` module if you want to adapt the preprocessing to another dataset.

In [1]:
import os

# Directory that receives the preprocessed data.
# Defaults to the original location; set the CLEANIR_PROCESSED_DATA_PATH
# environment variable to write the processed data somewhere else.
PROCESSED_DATA_PATH = os.environ.get('CLEANIR_PROCESSED_DATA_PATH',
                                     '/workspace/data/vgg2_test_processed')

In [5]:
import os

# Import only the helpers this cell uses rather than `import *`, so it is
# obvious where each name comes from and nothing else leaks into the
# notebook namespace.
from cleanir.tools.crop_vgg2_data import crop_and_align, save_face_encoding

# NOTE: this step is slow. In the recorded run the two progress bars took
# roughly 2h14m and 1h22m respectively.
if not os.path.exists(DATA_PATH):
    # Nothing to preprocess; report it and leave PROCESSED_DATA_PATH untouched.
    print("training data path does not exist")
else:
    os.makedirs(PROCESSED_DATA_PATH, exist_ok=True)
    # Write 64x64 crops (alignment disabled) into PROCESSED_DATA_PATH.
    # 64x64 matches the (None, 64, 64, 3) input layer in the network summary.
    crop_and_align(DATA_PATH, PROCESSED_DATA_PATH, (64, 64), align=False)
    # Presumably computes and stores the face embeddings mentioned in the
    # section text above -- confirm against crop_vgg2_data.
    save_face_encoding(PROCESSED_DATA_PATH)

100%|██████████| 169396/169396 [2:13:50<00:00, 21.09it/s]


Instructions for updating:
Colocations handled automatically by placer.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.


100%|██████████| 152507/152507 [1:22:17<00:00, 30.89it/s]


# Build a CLEANIR network

In [2]:
from cleanir.cleanir import Cleanir

# Network hyperparameters, named so they are easy to locate and tune.
LATENT_DIM = 1024
N_BLOCKS = 4
RECON_WEIGHT = 0.3

# The instance keeps the name `cleanir` because the summary and training
# cells call methods on it.
cleanir = Cleanir(latent_dim=LATENT_DIM)
cleanir.build_network(n_blocks=N_BLOCKS, recon_weight=RECON_WEIGHT)

Using TensorFlow backend.


Instructions for updating:
Colocations handled automatically by placer.




## (Optional) Check the network summaries

In [4]:
# Optional: print the layer-by-layer summary table(s) (layer type, output
# shape, parameter count, connections) of the network built above.
cleanir.print_network_summary()

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 64, 64, 3)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 32, 32, 64)   9472        input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 32, 32, 64)   256         conv2d_1[0][0]                   
__________________________________________________________________________________________________
leaky_re_lu_1 (LeakyReLU)       (None, 32, 32, 64)   0           batch_normalization_1[0][0]      
__________________________________________________________________________________________________
res2conv_b

# Train the network

In [None]:
from cleanir.data_loader import VGG2DataGenerator

# Training configuration.
batch_size = 32
res_path = '/workspace/data/Results/pubtestver'

# Keyword options forwarded unchanged to the data generator.
generator_options = dict(batch_size=batch_size, pair=False, has_label_files=True)
vgg2_gen = VGG2DataGenerator(PROCESSED_DATA_PATH, **generator_options)

# Runs training. In the recorded output each epoch prints three loss values,
# the first being the sum of the other two.
cleanir.train(res_path, vgg2_gen)

HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))

Instructions for updating:
Use tf.cast instead.

epoch 1 loss [90455.44, 69915.34, 20540.096]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 2 loss [85346.7, 71278.3, 14068.407]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 3 loss [82154.82, 71155.945, 10998.873]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 4 loss [79609.945, 70262.305, 9347.641]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 5 loss [77953.125, 70060.11, 7893.0137]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 6 loss [77212.85, 69535.16, 7677.6943]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 7 loss [74423.984, 67489.89, 6934.096]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 8 loss [75130.35, 69486.625, 5643.727]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 9 loss [74064.555, 68155.36, 5909.194]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 10 loss [75437.62, 70111.86, 5325.7554]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 11 loss [73677.47, 67994.36, 5683.1123]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 12 loss [74152.86, 69586.22, 4566.638]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 13 loss [73684.57, 69268.48, 4416.095]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 14 loss [74099.805, 69003.61, 5096.1943]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 15 loss [72237.555, 67773.86, 4463.697]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 16 loss [74305.36, 70074.836, 4230.524]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 17 loss [73038.05, 69431.06, 3606.9863]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 18 loss [73152.72, 68698.08, 4454.6387]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 19 loss [72786.11, 68652.84, 4133.2676]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 20 loss [70348.984, 66683.84, 3665.1445]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 21 loss [73169.3, 69600.39, 3568.9053]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 22 loss [73382.4, 70317.06, 3065.3325]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 23 loss [73278.625, 70058.65, 3219.9788]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 24 loss [72815.87, 69369.3, 3446.5679]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 25 loss [72391.695, 68728.64, 3663.0537]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 26 loss [69726.945, 66400.945, 3326.0015]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 27 loss [73281.96, 69984.914, 3297.0447]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 28 loss [72097.13, 69097.75, 2999.3813]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 29 loss [71920.15, 68768.22, 3151.929]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 30 loss [70425.125, 67511.12, 2914.0083]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 31 loss [71563.47, 68497.77, 3065.6978]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 32 loss [71339.07, 68791.086, 2547.9807]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 33 loss [71866.414, 69062.21, 2804.2053]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 34 loss [73060.92, 69897.625, 3163.2942]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 35 loss [70918.56, 68216.22, 2702.3445]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 36 loss [72670.81, 69992.2, 2678.6064]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 37 loss [68897.49, 66177.5, 2719.9956]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 38 loss [71083.61, 68421.516, 2662.0898]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 39 loss [71478.74, 68977.484, 2501.2595]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 40 loss [73071.9, 70412.625, 2659.2732]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 41 loss [74007.52, 71230.26, 2777.2688]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))


epoch 42 loss [71858.25, 68965.01, 2893.241]


HBox(children=(IntProgress(value=0, max=4765), HTML(value='')))