From ee6229503e1fcd1f57bb2da16b98305b4779c0dc Mon Sep 17 00:00:00 2001
From: Rishab Mallick
Date: Mon, 5 Jun 2023 18:04:44 +0530
Subject: [PATCH] feat: adds efficientnetv1 model arch

---
 ivy_models/efficientnet/efficientnetv1.py    | 317 ++++
 ivy_models/efficientnet/variant_configs.json | 602 +++++++++++++++++++
 2 files changed, 919 insertions(+)
 create mode 100644 ivy_models/efficientnet/efficientnetv1.py
 create mode 100644 ivy_models/efficientnet/variant_configs.json

diff --git a/ivy_models/efficientnet/efficientnetv1.py b/ivy_models/efficientnet/efficientnetv1.py
new file mode 100644
index 00000000..e687c510
--- /dev/null
+++ b/ivy_models/efficientnet/efficientnetv1.py
@@ -0,0 +1,317 @@
+import ivy
+import math
+
+
+class CNNBlock(ivy.Module):
+    def __init__(
+        self,
+        input_channels,
+        output_channels,
+        kernel_size,
+        stride,
+        padding,
+        training: bool = False,
+    ):
+        """
+        Helper module used in the MBConv and FusedMBConv blocks. Basic CNN
+        block with batch norm and a SiLU activation layer.
+
+        Parameters
+        ----------
+        input_channels
+            Number of input channels for the layer.
+        output_channels
+            Number of output channels for the layer.
+        kernel_size
+            Size of the convolutional filter.
+        stride
+            The stride of the sliding window for each dimension of input.
+        padding
+            "SAME" or "VALID" indicating the padding algorithm, or a
+            list indicating the per-dimension paddings.
+        training
+            Whether the block is in training mode. Default is ``False``.
+        """
+        self.training = training
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+        super(CNNBlock, self).__init__()
+
+    def _build(self, *args, **kwargs):
+        self.v0 = ivy.Conv2D(
+            self.input_channels,
+            self.output_channels,
+            [self.kernel_size, self.kernel_size],
+            self.stride,
+            self.padding,
+            with_bias=False,
+        )
+        self.v1 = ivy.BatchNorm2D(self.output_channels)
+        self.silu = ivy.SiLU()
+
+    def _forward(self, x):
+        # conv -> batch norm -> SiLU
+        return self.silu(self.v1(self.v0(x)))
+
+
+class SqueezeExcitation(ivy.Module):
+    def __init__(self, input_channels, reduced_dim):
+        """
+        Squeeze-and-excitation helper module used in the MBConv and
+        FusedMBConv blocks.
+
+        Parameters
+        ----------
+        input_channels
+            Number of input channels for the layer.
+        reduced_dim
+            Reduced channel dimensionality used during the squeeze phase.
+        """
+        self.input_channels = input_channels
+        self.reduced_dim = reduced_dim
+        super(SqueezeExcitation, self).__init__()
+
+    def _build(self, *args, **kwargs):
+        self.fc1 = ivy.Conv2D(
+            self.input_channels, self.reduced_dim, [1, 1], 1, "VALID"
+        )
+        self.fc2 = ivy.Conv2D(
+            self.reduced_dim, self.input_channels, [1, 1], 1, "VALID"
+        )
+        self.silu = ivy.SiLU()
+
+    def _forward(self, x):
+        # squeeze: global average over the spatial dims (inputs are NHWC);
+        # permute_dims is used since a reshape would scramble the data
+        scale = ivy.permute_dims(x, axes=(0, 3, 1, 2))  # N x H x W x C -> N x C x H x W
+        scale = ivy.adaptive_avg_pool2d(scale, 1)  # N x C x H x W -> N x C x 1 x 1
+        scale = ivy.permute_dims(scale, axes=(0, 2, 3, 1))  # -> N x 1 x 1 x C
+        # excite: bottleneck MLP producing one gate per channel
+        scale = self.fc2(self.silu(self.fc1(scale)))
+        # rescale the input by the learned per-channel sigmoid gates
+        return x * ivy.sigmoid(scale)
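+
+
+# Example (illustrative sketch, assuming an ivy backend has been set):
+# SqueezeExcitation learns one gate per channel and rescales its NHWC
+# input, so the output shape matches the input shape.
+#
+#     se = SqueezeExcitation(input_channels=64, reduced_dim=16)
+#     x = ivy.random_normal(shape=(1, 32, 32, 64))
+#     se(x).shape  # (1, 32, 32, 64)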
+
+
+class MBConvBlock(ivy.Module):
+    def __init__(
+        self,
+        input_channels,
+        output_channels,
+        kernel_size,
+        stride,
+        expand_ratio,
+        padding="VALID",
+        reduction_ratio=4,
+        survival_prob=0.8,
+        training: bool = False,
+    ):
+        """
+        Instantiates the Mobile Inverted Residual Bottleneck (MBConv) block.
+
+        Parameters
+        ----------
+        input_channels
+            Number of input channels for the layer.
+        output_channels
+            Number of output channels for the layer.
+        kernel_size
+            Size of the convolutional filter.
+        stride
+            The stride of the sliding window for each dimension of input.
+        expand_ratio
+            Degree of input channel expansion.
+        padding
+            "SAME" or "VALID" indicating the padding algorithm, or a
+            list indicating the per-dimension paddings.
+        reduction_ratio
+            Dimensionality reduction ratio used in squeeze-excitation.
+        survival_prob
+            Survival probability hyperparameter for stochastic depth.
+        training
+            Whether the block is in training mode. Default is ``False``.
+        """
+        self.training = training
+        self.input_channels = input_channels
+        self.output_channels = output_channels
+        self.kernel_size = kernel_size
+        self.stride = stride
+        self.padding = padding
+
+        self.use_residual = input_channels == output_channels and stride == 1
+        self.hidden_dim = input_channels * expand_ratio
+        self.expand = input_channels != self.hidden_dim
+        self.reduced_dim = int(self.input_channels / reduction_ratio)
+        self.reduction_ratio = reduction_ratio
+        self.survival_prob = survival_prob
+        super(MBConvBlock, self).__init__()
+
+    def _build(self, *args, **kwargs):
+        self.block = []
+        if self.expand:
+            # 1x1 expansion conv, only used when expand_ratio > 1
+            self.block.append(
+                CNNBlock(
+                    self.input_channels,
+                    self.hidden_dim,
+                    kernel_size=1,
+                    stride=1,
+                    padding=self.padding,
+                    training=self.training,
+                )
+            )
+
+        self.block += [
+            ivy.DepthwiseConv2D(
+                self.hidden_dim,
+                [self.kernel_size, self.kernel_size],
+                self.stride,
+                self.padding,
+                with_bias=False,
+            ),
+            ivy.BatchNorm2D(self.hidden_dim),
+            ivy.SiLU(),
+            SqueezeExcitation(self.hidden_dim, self.reduced_dim),
+            # 1x1 projection conv back down to output_channels
+            CNNBlock(
+                self.hidden_dim,
+                self.output_channels,
+                kernel_size=1,
+                stride=1,
+                padding=self.padding,
+                training=self.training,
+            ),
+        ]
+
+    def stochastic_depth(self, x):
+        # randomly drop the whole residual branch per sample, training only
+        if not self.training:
+            return x
+        binary_tensor = (
+            ivy.random_uniform(
+                shape=(x.shape[0], 1, 1, 1), low=0, high=1, device=x.device
+            )
+            < self.survival_prob
+        )
+        return ivy.divide(x, self.survival_prob) * binary_tensor
+
+    def _forward(self, inputs):
+        x = inputs
+        for layer in self.block:
+            x = layer(x)
+        if self.use_residual:
+            return self.stochastic_depth(x) + inputs
+        return x
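+
+
+# Example (illustrative sketch, assuming an ivy backend has been set):
+# with expand_ratio=6 the block widens 16 input channels to a 96-channel
+# hidden dim, applies a 3x3 depthwise conv plus squeeze-excitation, then
+# projects down to 24 output channels. Since input and output channels
+# differ, no residual connection (and no stochastic depth) is used.
+#
+#     block = MBConvBlock(16, 24, kernel_size=3, stride=1, expand_ratio=6,
+#                         padding="SAME")
+#     y = block(ivy.random_normal(shape=(1, 56, 56, 16)))
+#     y.shape  # (1, 56, 56, 24)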
+ """ + self.training = training + self.base_model = base_model + alpha = 1.2 + beta = 1.1 + self.dropout_rate = phi_values["dropout_rate"] + self.depth_factor = alpha ** phi_values["phi"] + self.width_factor = beta ** phi_values["phi"] + self.num_classes = num_classes + self.last_channels = math.ceil(1280 * self.width_factor) + self.se_reduction_ratio = 4 + super(EfficientNetV1, self).__init__(device=device, training=self.training) + if v is not None: + self.v = v + + def _build(self, *args, **kwrgs): + channels = int(32 * self.width_factor) + features = [ + CNNBlock(3, channels, 3, stride=2, padding=1, training=self.training) + ] + in_channels = channels + + for args in self.base_model: + out_channels = self.se_reduction_ratio * int(ivy.ceil( + int(args["channels"] * self.width_factor) / self.se_reduction_ratio + ).item()) + layers_repeats = int(ivy.ceil(args["repeats"] * self.depth_factor).item()) + + for layer in range(layers_repeats): + features.append( + MBConvBlock( + in_channels, + out_channels, + expand_ratio=args["expand_ratio"], + stride=args["stride"] if layer == 0 else 1, + kernel_size=args["kernel_size"], + padding="SAME", + reduction_ratio=self.se_reduction_ratio, + training=self.training, + ) + ) + in_channels = out_channels + features.append( + CNNBlock( + in_channels, + self.last_channels, + kernel_size=1, + stride=1, + padding="SAME", + training=self.training, + ) + ) + self.features = features + + self.classifier = [ + ivy.Dropout(self.dropout_rate, training=self.training), + ivy.Linear(self.last_channels, self.num_classes), + ] + + def _forward(self, x): + # x = self.features(x) + for layer in self.features: + x = layer(x) + # N x H x W x C -> N x C x H x W + x = ivy.reshape(x, shape=(x.shape[0], x.shape[3], x.shape[1], x.shape[2])) + x = ivy.adaptive_avg_pool2d(x, 1) + x = ivy.reshape(x, shape=(x.shape[0], -1)) + for layer in self.classifier: + x = layer(x) + return x + + +if __name__ == "__main__": + import json + + # ivy.set_tensorflow_backend() + ivy.set_jax_backend() + import jax + + jax.config.update("jax_enable_x64", True) + + with open("variant_configs.json") as json_file: + configs = json.load(json_file) + + configs = configs["v1"] + base_model = configs["base_args"] + phi_values = configs["phi_values"]["b0"] + + model = EfficientNetV1( + base_model, + phi_values, + 1000, + ) + # print(model.v) + + res = phi_values["resolution"] + x = ivy.random_normal(shape=(16, res, res, 3)) + print(model(x).shape) diff --git a/ivy_models/efficientnet/variant_configs.json b/ivy_models/efficientnet/variant_configs.json new file mode 100644 index 00000000..ff683bf1 --- /dev/null +++ b/ivy_models/efficientnet/variant_configs.json @@ -0,0 +1,602 @@ +{ + "v1": { + "base_args": [ + { + "expand_ratio": 1, + "channels": 16, + "repeats": 1, + "stride": 1, + "kernel_size": 3 + }, + { + "expand_ratio": 6, + "channels": 24, + "repeats": 2, + "stride": 2, + "kernel_size": 3 + }, + { + "expand_ratio": 6, + "channels": 40, + "repeats": 2, + "stride": 2, + "kernel_size": 5 + }, + { + "expand_ratio": 6, + "channels": 80, + "repeats": 3, + "stride": 2, + "kernel_size": 3 + }, + { + "expand_ratio": 6, + "channels": 112, + "repeats": 3, + "stride": 1, + "kernel_size": 5 + }, + { + "expand_ratio": 6, + "channels": 192, + "repeats": 4, + "stride": 2, + "kernel_size": 5 + }, + { + "expand_ratio": 6, + "channels": 320, + "repeats": 1, + "stride": 1, + "kernel_size": 3 + } + ], + "phi_values": { + "b0": { + "phi": 0, + "resolution": 224, + "dropout_rate": 0.2 + }, + "b1": { + "phi": 0.5, + 
"resolution": 240, + "dropout_rate": 0.2 + }, + "b2": { + "phi": 1, + "resolution": 260, + "dropout_rate": 0.3 + }, + "b3": { + "phi": 2, + "resolution": 300, + "dropout_rate": 0.3 + }, + "b4": { + "phi": 3, + "resolution": 380, + "dropout_rate": 0.4 + }, + "b5": { + "phi": 4, + "resolution": 456, + "dropout_rate": 0.4 + }, + "b6": { + "phi": 5, + "resolution": 528, + "dropout_rate": 0.5 + }, + "b7": { + "phi": 6, + "resolution": 600, + "dropout_rate": 0.5 + } + } + }, + "v2": { + "efficientnetv2-s": { + "phi_values": { + "width_coefficient": 1.0, + "depth_coefficient": 1.0, + "resolution": 384 + }, + "blocks": [ + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 24, + "output_filters": 24, + "expand_ratio": 1, + "se_ratio": 0.0, + "strides": 1, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 4, + "input_filters": 24, + "output_filters": 48, + "expand_ratio": 4, + "se_ratio": 0.0, + "strides": 2, + "conv_type": 1 + }, + { + "conv_type": 1, + "expand_ratio": 4, + "input_filters": 48, + "kernel_size": 3, + "num_repeat": 4, + "output_filters": 64, + "se_ratio": 0, + "strides": 2 + }, + { + "conv_type": 0, + "expand_ratio": 4, + "input_filters": 64, + "kernel_size": 3, + "num_repeat": 6, + "output_filters": 128, + "se_ratio": 0.25, + "strides": 2 + }, + { + "conv_type": 0, + "expand_ratio": 6, + "input_filters": 128, + "kernel_size": 3, + "num_repeat": 9, + "output_filters": 160, + "se_ratio": 0.25, + "strides": 1 + }, + { + "conv_type": 0, + "expand_ratio": 6, + "input_filters": 160, + "kernel_size": 3, + "num_repeat": 15, + "output_filters": 256, + "se_ratio": 0.25, + "strides": 2 + } + ] + }, + "efficientnetv2-m": { + "phi_values": { + "width_coefficient": 1.0, + "depth_coefficient": 1.0, + "resolution": 480 + }, + "blocks": [ + { + "kernel_size": 3, + "num_repeat": 3, + "input_filters": 24, + "output_filters": 24, + "expand_ratio": 1, + "se_ratio": 0, + "strides": 1, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 24, + "output_filters": 48, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 48, + "output_filters": 80, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 7, + "input_filters": 80, + "output_filters": 160, + "expand_ratio": 4, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 14, + "input_filters": 160, + "output_filters": 176, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 18, + "input_filters": 176, + "output_filters": 304, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 304, + "output_filters": 512, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0 + } + ] + }, + "efficientnetv2-l": { + "phi_values": { + "width_coefficient": 1.0, + "depth_coefficient": 1.0, + "resolution": 480 + }, + "blocks": [ + { + "kernel_size": 3, + "num_repeat": 4, + "input_filters": 32, + "output_filters": 32, + "expand_ratio": 1, + "se_ratio": 0, + "strides": 1, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 7, + "input_filters": 32, + "output_filters": 64, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 7, + "input_filters": 64, + "output_filters": 96, + "expand_ratio": 4, + 
"se_ratio": 0, + "strides": 2, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 10, + "input_filters": 96, + "output_filters": 192, + "expand_ratio": 4, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 19, + "input_filters": 192, + "output_filters": 224, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 25, + "input_filters": 224, + "output_filters": 384, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 7, + "input_filters": 384, + "output_filters": 640, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0 + } + ] + }, + "efficientnetv2-b0": { + "phi_values": { + "width_coefficient": 1.0, + "depth_coefficient": 1.0, + "resolution": 224 + }, + "blocks": [ + { + "kernel_size": 3, + "num_repeat": 1, + "input_filters": 32, + "output_filters": 16, + "expand_ratio": 1, + "se_ratio": 0, + "strides": 1, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 16, + "output_filters": 32, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 32, + "output_filters": 48, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 3, + "input_filters": 48, + "output_filters": 96, + "expand_ratio": 4, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 96, + "output_filters": 112, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 8, + "input_filters": 112, + "output_filters": 192, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0 + } + ] + }, + "efficientnetv2-b1": { + "phi_values": { + "width_coefficient": 1.0, + "depth_coefficient": 1.1, + "resolution": 240 + }, + "blocks": [ + { + "kernel_size": 3, + "num_repeat": 1, + "input_filters": 32, + "output_filters": 16, + "expand_ratio": 1, + "se_ratio": 0, + "strides": 1, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 16, + "output_filters": 32, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 32, + "output_filters": 48, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 3, + "input_filters": 48, + "output_filters": 96, + "expand_ratio": 4, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 96, + "output_filters": 112, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 8, + "input_filters": 112, + "output_filters": 192, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0 + } + ] + }, + "efficientnetv2-b2": { + "phi_values": { + "width_coefficient": 1.1, + "depth_coefficient": 1.2, + "resolution": 260 + }, + "blocks": [ + { + "kernel_size": 3, + "num_repeat": 1, + "input_filters": 32, + "output_filters": 16, + "expand_ratio": 1, + "se_ratio": 0, + "strides": 1, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 16, + "output_filters": 32, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1 + }, + { + "kernel_size": 3, + 
"num_repeat": 2, + "input_filters": 32, + "output_filters": 48, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 3, + "input_filters": 48, + "output_filters": 96, + "expand_ratio": 4, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 96, + "output_filters": 112, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 8, + "input_filters": 112, + "output_filters": 192, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0 + } + ] + }, + "efficientnetv2-b3": { + "phi_values": { + "width_coefficient": 1.2, + "depth_coefficient": 1.4, + "resolution": 300 + }, + "blocks": [ + { + "kernel_size": 3, + "num_repeat": 1, + "input_filters": 32, + "output_filters": 16, + "expand_ratio": 1, + "se_ratio": 0, + "strides": 1, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 16, + "output_filters": 32, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 2, + "input_filters": 32, + "output_filters": 48, + "expand_ratio": 4, + "se_ratio": 0, + "strides": 2, + "conv_type": 1 + }, + { + "kernel_size": 3, + "num_repeat": 3, + "input_filters": 48, + "output_filters": 96, + "expand_ratio": 4, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 5, + "input_filters": 96, + "output_filters": 112, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 1, + "conv_type": 0 + }, + { + "kernel_size": 3, + "num_repeat": 8, + "input_filters": 112, + "output_filters": 192, + "expand_ratio": 6, + "se_ratio": 0.25, + "strides": 2, + "conv_type": 0 + } + ] + } + } +} \ No newline at end of file