Merge in some of the TF2 stuff
geometrikal committed Jan 13, 2024
1 parent c711ba0 commit 8554123
Showing 12 changed files with 113 additions and 80 deletions.
6 changes: 3 additions & 3 deletions examples/deep_weeds.py
@@ -2,10 +2,10 @@
Train an image classifier on the deep weeds dataset
"""

from miso.training.parameters import MisoConfig
from miso.training.parameters import MisoParameters
from miso.training.trainer import train_image_classification_model

tp = MisoConfig()
tp = MisoParameters()

# -----------------------------------------------------------------------------
# Dataset
@@ -38,7 +38,7 @@
# - resnet[18,34,50]
# - vgg[16,19]
# - efficientnetB[0-7]
tp.cnn.type = "resnet50_tl"
tp.cnn.id = "resnet50_tl"
# Input image shape, set to None to use default size ([128, 128, 1] for custom, [224, 224, 3] for others)
tp.cnn.img_shape = [224, 224, 3]
# Input image colour space [greyscale/rgb]
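
For context, a minimal sketch of the pattern these example scripts follow after this commit's renames (MisoConfig becomes MisoParameters, tp.cnn.type becomes tp.cnn.id). The dataset path is a placeholder and the exact trainer call is an assumption based on the imports shown above.

```python
# Minimal sketch, not part of the diff: configure and train with the renamed
# parameters object. The dataset path is a placeholder.
from miso.training.parameters import MisoParameters
from miso.training.trainer import train_image_classification_model

tp = MisoParameters()
tp.dataset.source = "/path/to/deep_weeds"   # placeholder dataset location
tp.cnn.id = "resnet50_tl"                   # transfer-learning ResNet50, as above
tp.cnn.img_shape = [224, 224, 3]
train_image_classification_model(tp)        # assumed to take the parameters object
```
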
8 changes: 4 additions & 4 deletions examples/endless_forams.py
@@ -1,10 +1,10 @@
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from miso.training.parameters import MisoConfig
from miso.training.parameters import MisoParameters
from miso.training.trainer import train_image_classification_model

tp = MisoConfig()
tp = MisoParameters()

# -----------------------------------------------------------------------------
# Name
@@ -44,8 +44,8 @@
# - resnet[18,34,50]
# - vgg[16,19]
# - efficientnetB[0-7]
tp.cnn.type = r"base_cyclic"
tp.cnn.type = r"resnet50_tl"
tp.cnn.id = r"base_cyclic"
tp.cnn.id = r"resnet50_tl"
# Input image shape, set to None to use default size ([128, 128, 1] for custom, [224, 224, 3] for others)
tp.cnn.img_shape = [128, 128, 1]
tp.cnn.img_shape = [224, 224, 3]
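
As the comment above notes, img_shape can be left as None; a sketch of how sanitise() (from miso/training/parameters.py later in this diff) resolves the default size. It assumes "resnet50" is registered in KERAS_MODEL_PARAMETERS and uses a placeholder dataset path.

```python
# Sketch: let sanitise() pick the network's default input size.
from miso.training.parameters import MisoParameters

tp = MisoParameters()
tp.dataset.source = "/path/to/endless_forams"  # placeholder; sanitise() uses it to build a name
tp.cnn.id = "resnet50_tl"
tp.cnn.img_type = "rgb"
tp.cnn.img_shape = None
tp.sanitise()                # fills img_shape from the registry defaults
print(tp.cnn.img_shape)      # expected to be the ResNet50 default, e.g. [224, 224, 3]
```
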
9 changes: 3 additions & 6 deletions examples/olzo_mini.py
@@ -7,13 +7,10 @@

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

from miso.training.parameters import MisoConfig
from miso.training.parameters import MisoParameters
from miso.training.trainer import train_image_classification_model

tp = MisoConfig()

import tensorflow as tf
print(tf.config.list_physical_devices('GPU'))
tp = MisoParameters()

# -----------------------------------------------------------------------------
# Dataset
@@ -47,7 +44,7 @@
# - resnet[18,34,50]
# - vgg[16,19]
# - efficientnetB[0-7]
tp.cnn.type = "base_cyclic"
tp.cnn.id = "convnexttiny"
# Input image shape, set to None to use default size ([128, 128, 1] for custom, [224, 224, 3] for others)
tp.cnn.img_shape = [128, 128, 1]
# Input image colour space [greyscale/rgb]
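
The GPU check removed above can still be run standalone; a small sketch (TF_CPP_MIN_LOG_LEVEL set to "3" suppresses TensorFlow's INFO, WARNING and ERROR messages, as in the script).

```python
# Sketch: verify TensorFlow can see a GPU before training.
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"   # silence TF INFO/WARNING/ERROR output

import tensorflow as tf
print(tf.config.list_physical_devices("GPU"))   # an empty list means no GPU is visible
```
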
2 changes: 1 addition & 1 deletion miso/__main__.py
@@ -16,7 +16,7 @@

tp = MisoParameters()
tp.source = args.input
tp.output_dir = args.output
tp.save_dir = args.output
tp.cnn_type = args.type
tp.filters = args.filters
tp.min_count = args.min_count
10 changes: 5 additions & 5 deletions miso/deploy/model_info.py
@@ -5,7 +5,7 @@
import numpy as np
from collections import OrderedDict

from miso.training.parameters import MisoConfig
from miso.training.parameters import MisoParameters


class ModelInfo:
@@ -15,7 +15,7 @@ def __init__(self,
type: str,
date: datetime.datetime,
protobuf: str,
params: MisoConfig,
params: MisoParameters,
inputs: OrderedDict,
outputs: OrderedDict,
data_source_name: str,
@@ -78,9 +78,9 @@ def to_xml(self):
ET.SubElement(root, "date").text = "{0:%Y-%m-%d_%H%M%S}".format(self.date)
ET.SubElement(root, "protobuf").text = self.protobuf

parent_node = ET.SubElement(root, "params")
for key, value in self.params.asdict().items():
ET.SubElement(parent_node, key).text = str(value)
# parent_node = ET.SubElement(root, "params")
# for key, value in self.params.asdict().items():
# ET.SubElement(parent_node, key).text = str(value)

parent_node = ET.SubElement(root, "inputs")
for name, tensor in self.inputs.items():
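
The params serialization is commented out above; for reference, a sketch of what that block produced. The asdict() call mirrors the original loop, and whether MisoParameters exposes it is an assumption here.

```python
# Sketch of the disabled block: flatten the parameters into <params> children.
import xml.etree.ElementTree as ET
from miso.training.parameters import MisoParameters

params = MisoParameters()                   # assumes asdict() exists, as the original loop used
root = ET.Element("model_info")             # stand-in for the element built in to_xml()
parent_node = ET.SubElement(root, "params")
for key, value in params.asdict().items():
    ET.SubElement(parent_node, key).text = str(value)
print(ET.tostring(root, encoding="unicode"))
```
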
69 changes: 44 additions & 25 deletions miso/models/factory.py
@@ -1,21 +1,35 @@
import math

from keras import Model
from keras.optimizers import SGD, Adam

from miso.models.keras_models import head, tail, KERAS_MODEL_PARAMETERS
from miso.models.base_cyclic import *
from miso.models.resnet_cyclic import *
from miso.training.parameters import MisoConfig
from miso.training.parameters import MisoParameters

try:
from tensorflow.keras.applications.efficientnet import *
except ImportError:
pass


def generate(tp: MisoConfig):
def create_optimizer(tp: MisoParameters):
if tp.optimizer.name == "sgd":
opt = SGD(learning_rate=tp.optimizer.learning_rate,
decay=tp.optimizer.decay,
momentum=tp.optimizer.momentum,
nesterov=tp.optimizer.nesterov)
elif tp.optimizer.name == "adam":
opt = Adam(learning_rate=tp.optimizer.learning_rate, decay=tp.optimizer.decay)
else:
raise ValueError(f"The optimizer {tp.optimizer.name} is not supported, valid optimizers are: sgd, adam")
return opt


def generate(tp: MisoParameters):
# Base Cyclic - custom network created at CEREGE specifically for foraminifera by adding cyclic layers
if tp.cnn.type.startswith("base_cyclic"):
if tp.cnn.id.startswith("base_cyclic"):
model = base_cyclic(input_shape=tp.cnn.img_shape,
nb_classes=tp.dataset.num_classes,
filters=tp.cnn.filters,
@@ -29,7 +43,7 @@ def generate(tp: MisoConfig):
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

# ResNet Cyclic - custom network created at CEREGE specifically for foraminifera by adding cyclic layers
elif tp.cnn.type.startswith("resnet_cyclic"):
elif tp.cnn.id.startswith("resnet_cyclic"):
blocks = int(math.log2(tp.cnn.img_shape[0]) - 2)
blocks -= 1 # Resnet has one block to start with already
resnet_params = ResnetModelParameters('resnet_cyclic',
@@ -40,21 +54,21 @@
use_cyclic=True,
global_pooling=tp.cnn.global_pooling)
model = ResNetCyclic(resnet_params, tp.cnn.img_shape, None, True, tp.dataset.num_classes)
opt = Adam(lr=0.001)
opt = create_optimizer(tp)
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

# Standard networks in keras applications (with options for transfer learning or cyclic gain)
else:
# Legacy support where transfer learning, cyclic and gain were part of the cnn type
parts = tp.cnn.type.split("_")
parts = tp.cnn.id.split("_")
if len(parts) > 0:
tp.cnn.type = parts[0]
tp.cnn.id = parts[0]
if "cyclic" in parts:
tp.cnn.use_cyclic = True
if "cyclicgain" in parts:
tp.cnn.use_cyclic_gain = True
if tp.cnn.type in KERAS_MODEL_PARAMETERS.keys():
model_head = head(tp.cnn.type,
if "gain" in parts:
tp.cnn.use_cyclic_gain = True
if tp.cnn.id in KERAS_MODEL_PARAMETERS.keys():
model_head = head(tp.cnn.id,
use_cyclic=tp.cnn.use_cyclic,
use_gain=tp.cnn.use_cyclic_gain,
input_shape=tp.cnn.img_shape,
@@ -65,24 +79,24 @@
model = combine_head_and_tail(model_head, model_tail)
else:
raise ValueError(
"The CNN type {} is not supported, valid CNNs are {}".format(tp.cnn.type, KERAS_MODEL_PARAMETERS.keys()))
"The CNN type {} is not supported, valid CNNs are {}".format(tp.cnn.id, KERAS_MODEL_PARAMETERS.keys()))
# opt = SGD(lr=0.01, decay=1e-6, momentum=0.9, nesterov=True)
opt = Adam(lr=0.001)
opt = create_optimizer(tp)
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
return model


def generate_tl(tp: MisoConfig):
def generate_tl(tp: MisoParameters):
# Legacy support where transfer learning, cyclic and gain were part of the cnn type
parts = tp.cnn.type.split("_")
parts = tp.cnn.id.split("_")
if len(parts) > 0:
tp.cnn.type = parts[0]
tp.cnn.id = parts[0]
if "cyclic" in parts:
tp.cnn.use_cyclic = True
if "cyclicgain" in parts:
tp.cnn.use_cyclic_gain = True
if tp.cnn.type in KERAS_MODEL_PARAMETERS.keys():
model_head = head(tp.cnn.type,
if "gain" in parts:
tp.cnn.use_cyclic_gain = True
if tp.cnn.id in KERAS_MODEL_PARAMETERS.keys():
model_head = head(tp.cnn.id,
use_cyclic=tp.cnn.use_cyclic,
use_gain=tp.cnn.use_cyclic_gain,
input_shape=tp.cnn.img_shape,
@@ -92,8 +106,8 @@ def generate_tl(tp: MisoConfig):
model_tail = tail(tp.dataset.num_classes, [model_head.layers[-1].output.shape[-1], ])
else:
raise ValueError(
"The CNN type {} is not supported, valid CNNs are {}".format(tp.cnn.type, KERAS_MODEL_PARAMETERS.keys()))
opt = Adam(lr=0.001)
"The CNN type {} is not supported, valid CNNs are {}".format(tp.cnn.id, KERAS_MODEL_PARAMETERS.keys()))
opt = create_optimizer(tp)
model_tail.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])
return model_head, model_tail

@@ -102,7 +116,7 @@ def combine_head_and_tail(model_head, model_tail):
return Model(inputs=model_head.input, outputs=model_tail.call(model_head.output))


def generate_vector_from_model(model, cnn_type):
vector_tensor = model.get_layer(index=-2).get_output_at(1)
vector_model = Model(model.inputs, vector_tensor)
def generate_vector_from_model(model, tp):
try:
vector_model = Model(model.inputs, model.get_layer(index=-2).get_output_at(1))
return vector_model
except:
pass

vector_model = Model(model.inputs, model.get_layer(index=-2).get_output_at(0))
return vector_model
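
With the new create_optimizer() above, the optimizer is driven by the OptimizerParameters block added later in this diff; a short usage sketch under those assumptions.

```python
# Sketch: switch optimizers through configuration rather than code.
from miso.models.factory import create_optimizer
from miso.training.parameters import MisoParameters

tp = MisoParameters()
tp.optimizer.name = "sgd"           # "sgd" and "adam" are the supported names
tp.optimizer.learning_rate = 0.01
tp.optimizer.momentum = 0.9
tp.optimizer.nesterov = True
opt = create_optimizer(tp)          # returns a Keras SGD configured from tp.optimizer
```
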
6 changes: 5 additions & 1 deletion miso/models/keras_models.py
@@ -133,5 +133,9 @@ def tail_vector(num_classes, input_shape, dropout=(0.5, 0.5)):
'efficientnetb5': KerasModelParameters(ka.efficientnet.EfficientNetB5, no_prepro, [456, 456, 3]),
'efficientnetb6': KerasModelParameters(ka.efficientnet.EfficientNetB6, no_prepro, [528, 528, 3]),
'efficientnetb7': KerasModelParameters(ka.efficientnet.EfficientNetB7, no_prepro, [600, 600, 3]),
'convnexttiny': KerasModelParameters(ka.convnext.ConvNeXtTiny, no_prepro, [600, 600, 3]),
'convnexttiny': KerasModelParameters(ka.convnext.ConvNeXtTiny, no_prepro, [224, 224, 3]),
'convnextsmall': KerasModelParameters(ka.convnext.ConvNeXtSmall, no_prepro, [224, 224, 3]),
'convnextbase': KerasModelParameters(ka.convnext.ConvNeXtBase, no_prepro, [224, 224, 3]),
'convnextlarge': KerasModelParameters(ka.convnext.ConvNeXtLarge, no_prepro, [224, 224, 3]),
'convnextxlarge': KerasModelParameters(ka.convnext.ConvNeXtXLarge, no_prepro, [224, 224, 3]),
}
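
The new ConvNeXt entries slot into the same registry the factory consults; a sketch of the lookup, where anything beyond the default_input_shape attribute (used in parameters.py below) is an assumption.

```python
# Sketch: resolve a cnn.id against the registry of Keras application models.
from miso.models.keras_models import KERAS_MODEL_PARAMETERS

entry = KERAS_MODEL_PARAMETERS["convnexttiny"]
print(entry.default_input_shape)    # presumably [224, 224, 3], per the corrected entry above
```
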
50 changes: 30 additions & 20 deletions miso/training/parameters.py
@@ -17,7 +17,7 @@ def default_field(obj):
return field(default_factory=lambda: copy.copy(obj))

@dataclass
class BaseConfig(object):
class BaseParameters(object):
class Meta:
ordered = True

@@ -42,9 +42,9 @@ def load(cls, path: Union[Path, str]):


@dataclass
class ModelConfig(BaseConfig):
class ModelParameters(BaseParameters):
# Common values
type: str = "base_cyclic"
id: str = "base_cyclic"
img_shape: List[int] = field(default=(128, 128, 1))
img_type: str = "greyscale"

@@ -63,7 +63,7 @@ class ModelConfig(BaseConfig):


@dataclass
class TrainingConfig(BaseConfig):
class TrainingParameters(BaseParameters):
batch_size: int = 64
max_epochs: int = 10000
alr_epochs: int = 10
@@ -77,7 +77,16 @@


@dataclass
class DatasetConfig(BaseConfig):
class OptimizerParameters(BaseParameters):
name: str = "adam"
learning_rate: float = 0.001
momentum: float = 0.9
decay: float = 0.0
nesterov: bool = False


@dataclass
class DatasetParameters(BaseParameters):
num_classes: Optional[int] = None
source: Optional[str] = None
min_count: Optional[int] = 10
@@ -89,7 +98,7 @@ class DatasetConfig(BaseConfig):


@dataclass
class AugmentationConfig(BaseConfig):
class AugmentationParameters(BaseParameters):
rotation: List[float] = field(default=(0, 360))
gain: Optional[List[float]] = field(default=(0.8, 1, 1.2))
gamma: Optional[List[float]] = field(default=(0.5, 1, 2))
@@ -101,31 +110,32 @@


@dataclass
class OutputConfig(BaseConfig):
output_dir: str = None
class OutputParameters(BaseParameters):
save_dir: str = None
save_model: bool = True
save_mislabeled: bool = False


@dataclass
class MisoConfig(BaseConfig):
class MisoParameters(BaseParameters):
name: str = ""
description: str = ""
cnn: ModelConfig = default_field(ModelConfig())
dataset: DatasetConfig = default_field(DatasetConfig())
training: TrainingConfig = default_field(TrainingConfig())
augmentation: AugmentationConfig = default_field(AugmentationConfig())
output: OutputConfig = default_field(OutputConfig())
cnn: ModelParameters = default_field(ModelParameters())
dataset: DatasetParameters = default_field(DatasetParameters())
training: TrainingParameters = default_field(TrainingParameters())
augmentation: AugmentationParameters = default_field(AugmentationParameters())
output: OutputParameters = default_field(OutputParameters())
optimizer: OptimizerParameters = default_field(OptimizerParameters())

def sanitise(self):
if self.name == "":
self.name = self.dataset.source[:64] + "_" + self.cnn.type
self.name = self.dataset.source[:64] + "_" + self.cnn.id
self.name = re.sub('[^A-Za-z0-9]+', '-', self.name)
if self.cnn.img_shape is None:
if self.cnn.type.endswith("_tl"):
shape = KERAS_MODEL_PARAMETERS[self.cnn.type.split('_')[0]].default_input_shape
if self.cnn.id.endswith("_tl"):
shape = KERAS_MODEL_PARAMETERS[self.cnn.id.split('_')[0]].default_input_shape
else:
if self.cnn.type.startswith("base_cyclic") or self.cnn.type.startswith("resnet_cyclic"):
if self.cnn.id.startswith("base_cyclic") or self.cnn.id.startswith("resnet_cyclic"):
shape = [128, 128, 3]
else:
shape = [224, 224, 3]
@@ -136,7 +146,7 @@ def sanitise(self):
shape[2] = 1
self.augmentation.orig_img_shape[2] = 3
self.cnn.img_shape = shape
elif self.cnn.type.startswith("base_cyclic") or self.cnn.type.startswith("resnet_cyclic"):
elif self.cnn.id.startswith("base_cyclic") or self.cnn.id.startswith("resnet_cyclic"):
pass
else:
self.cnn.img_shape[2] = 3
@@ -150,7 +160,7 @@ def get_default_shape(cnn_type):


if __name__ == "__main__":
m = MisoConfig()
m = MisoParameters()
print(m.dumps())


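The renamed dataclasses keep the BaseParameters save/load interface (load() appears in a hunk header above, dumps() in the __main__ block); a sketch, with the serialization format and file path as assumptions.

```python
# Sketch: round-trip a configuration through text.
from miso.training.parameters import MisoParameters

tp = MisoParameters()
tp.cnn.id = "resnet50_tl"
text = tp.dumps()                       # serialize the nested config, as in __main__ above
print(text)
# tp2 = MisoParameters.load("config.yml")  # hypothetical path; load() is shown in the hunk header
```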