In [1]:
%load_ext tensorboard

In [2]:
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import tensorflow as tf

from tensorflow.random import set_seed
from tensorflow.keras.preprocessing import image_dataset_from_directory as im_ds_from_dir
from tensorflow.keras.metrics import MeanSquaredError as MSE
from tensorflow import boolean_mask
from sklearn.model_selection import train_test_split

import numpy as np
import cv2 as cv
import sklearn.neighbors as sk

from active_thresh_segseg import segseg
from microscope_image_preprocessing import load_image

#PATH TO DIR OR FILE
PATH = '2024-07-29'

num_bacteria = 3

In [3]:
image_dir='images/minis'
target_dir='images/minis'
target_size=(32,32) # set this to the target size you want
channels=3 # for color images
color_mode='grayscale'
shuffle=True,
seed=123
class_mode=None
batch_size=32 # set this to desired batch size
vsplit=.2 # set this to the validation split you want
image_gen=im_ds_from_dir(image_dir, color_mode=color_mode, batch_size=batch_size,image_size=target_size,seed=seed,subset='training',
                         validation_split=vsplit, labels=None,data_format='channels_last')
val_gen=im_ds_from_dir(image_dir, color_mode=color_mode, batch_size=batch_size,image_size=target_size,seed=seed,subset='validation',
                         validation_split=vsplit, labels=None,data_format='channels_last')
norm_layer = tf.keras.layers.Rescaling(1./255)
image_gen = image_gen.map(lambda x: (norm_layer(x) - 1e-6,norm_layer(x) - 1e-6))
val_gen = val_gen.map(lambda x: (norm_layer(x) - 1e-6,norm_layer(x) - 1e-6))

for x in image_gen.unbatch().as_numpy_iterator():
    print(x[1])
    break

Found 182 files.
Using 146 files for training.
Found 182 files.
Using 36 files for validation.
[[[0.8470579 ]
  [0.85097945]
  [0.8431363 ]
  ...
  [0.8313716 ]
  [0.8431363 ]
  [0.83921474]]

 [[0.8588226 ]
  [0.8588226 ]
  [0.85097945]
  ...
  [0.82352847]
  [0.8352932 ]
  [0.8313716 ]]

 [[0.8666657 ]
  [0.8666657 ]
  [0.8588226 ]
  ...
  [0.8196069 ]
  [0.8313716 ]
  [0.8313716 ]]

 ...

 [[0.8039206 ]
  [0.79999906]
  [0.78823435]
  ...
  [0.7843128 ]
  [0.78823435]
  [0.7843128 ]]

 [[0.8196069 ]
  [0.81568533]
  [0.8078422 ]
  ...
  [0.7921559 ]
  [0.7960775 ]
  [0.7921559 ]]

 [[0.8313716 ]
  [0.82745004]
  [0.82352847]
  ...
  [0.8039206 ]
  [0.8078422 ]
  [0.8078422 ]]]


In [4]:
def conv(layer, channel_num):
    layer = layers.Conv2D(channel_num, 3, padding='same', activation='tanh')(layer)
    layer = layers.Conv2D(channel_num, 3, padding='same', activation='tanh')(layer)
    return layer

def downconv(layer, channel_num, dropout):
    c = conv(layer, channel_num)
    mp = layers.MaxPooling2D(2)(c)
    layer = layers.Dropout(dropout)(mp)
    return c, layer

def upconv(layer, skip, channel_num, dropout):
    layer = layers.Conv2DTranspose(channel_num, 3, 2, padding='same', activation='tanh')(layer)
    layer = layers.concatenate([layer, skip])
    layer = layers.Dropout(dropout)(layer)
    layer = conv(layer, channel_num)
    return layer

In [5]:
input_layer = tf.keras.Input(shape=(32,32,1))

skip1, encode2 = downconv(input_layer,32,0.1)
skip2, encode2 = downconv(encode2, 64, 0.2)
skip3, encode2 = downconv(encode2, 128, 0.25)
bottleneck = conv(encode2, 256)
encode2 = upconv(bottleneck, skip3, 128, 0.25)
encode2 = upconv(encode2, skip2, 64, 0.2)
encode2 = upconv(encode2, skip1, 32, 0.1)
autoencode = keras.layers.Conv2D(8, 3, padding='same', activation='relu')(encode2)
autoencode = keras.layers.Conv2D(1, 3, padding='same', activation='sigmoid')(autoencode)

# encoder.summary()
encenc = keras.Model(inputs=input_layer,outputs=autoencode)
encenc.summary()

In [6]:
import datetime
SEED = 1
set_seed(SEED)
log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

callbacks = [tf.keras.callbacks.EarlyStopping(monitor='loss', restore_best_weights=True, patience=20, start_from_epoch=15),
             tf.keras.callbacks.TensorBoard(log_dir=log_dir, histogram_freq=20, update_freq=10)]  
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
epochs = 10


encenc.compile(optimizer=optimizer, loss="mse")
encenc.fit(image_gen, epochs=epochs,callbacks=callbacks,batch_size=batch_size, validation_data=val_gen)

Epoch 1/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 201ms/step - loss: 0.0834 - val_loss: 0.0105
Epoch 2/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 88ms/step - loss: 0.0117 - val_loss: 0.0109
Epoch 3/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 86ms/step - loss: 0.0094 - val_loss: 0.0084
Epoch 4/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 86ms/step - loss: 0.0075 - val_loss: 0.0060
Epoch 5/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 82ms/step - loss: 0.0061 - val_loss: 0.0052
Epoch 6/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 86ms/step - loss: 0.0053 - val_loss: 0.0050
Epoch 7/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 85ms/step - loss: 0.0047 - val_loss: 0.0044
Epoch 8/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 85ms/step - loss: 0.0045 - val_loss: 0.0041
Epoch 9/10
[1m5/5[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m 

<keras.src.callbacks.history.History at 0x1baa947a850>

In [7]:
encenc.predict(next(iter(image_gen))[0])

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 258ms/step


array([[[[0.5630633 ],
         [0.7216988 ],
         [0.7576585 ],
         ...,
         [0.80446875],
         [0.8356439 ],
         [0.81146437]],

        [[0.7389389 ],
         [0.79758173],
         [0.8168979 ],
         ...,
         [0.8240038 ],
         [0.8228979 ],
         [0.8210819 ]],

        [[0.79236686],
         [0.78667945],
         [0.79610157],
         ...,
         [0.79868335],
         [0.8300126 ],
         [0.77938855]],

        ...,

        [[0.80102974],
         [0.8287561 ],
         [0.79688525],
         ...,
         [0.8009677 ],
         [0.79269516],
         [0.81127423]],

        [[0.80124867],
         [0.77207255],
         [0.7737455 ],
         ...,
         [0.77100015],
         [0.7960275 ],
         [0.74124783]],

        [[0.7421239 ],
         [0.81760335],
         [0.7896765 ],
         ...,
         [0.81809986],
         [0.7441114 ],
         [0.7160452 ]]],


       [[[0.560402  ],
         [0.7222151 ],
         [0.75

In [8]:
print(encenc.trainable)
print(next(iter(image_gen.unbatch()))[0])
encenc.summary()

True
tf.Tensor(
[[[0.8431363 ]
  [0.8470579 ]
  [0.854901  ]
  ...
  [0.82352847]
  [0.8196069 ]
  [0.81176376]]

 [[0.8352932 ]
  [0.8352932 ]
  [0.83921474]
  ...
  [0.81568533]
  [0.81568533]
  [0.81176376]]

 [[0.81176376]
  [0.81176376]
  [0.81176376]
  ...
  [0.82745004]
  [0.8352932 ]
  [0.8352932 ]]

 ...

 [[0.8941167 ]
  [0.89803827]
  [0.90195984]
  ...
  [0.8666657 ]
  [0.882352  ]
  [0.882352  ]]

 [[0.8941167 ]
  [0.89803827]
  [0.90195984]
  ...
  [0.8705873 ]
  [0.88627356]
  [0.88627356]]

 [[0.8941167 ]
  [0.89803827]
  [0.89803827]
  ...
  [0.87450886]
  [0.88627356]
  [0.88627356]]], shape=(32, 32, 1), dtype=float32)


# tensorboard commands 
!rd /s /q "./logs/" 

tensorboard --logdir logs/fit --host localhost --port=8080