From 24e6a94814bbb546d70cc1c686a05bb74c86106f Mon Sep 17 00:00:00 2001 From: Zaeem Ansari <99063526+zaeemansari70@users.noreply.github.com> Date: Sun, 28 May 2023 18:43:30 +0500 Subject: [PATCH] UNet's complete architecture --- ivy_models/UNET/UNET.py | 63 +++++++++++++++++++++++++++++++++++++---- 1 file changed, 57 insertions(+), 6 deletions(-) diff --git a/ivy_models/UNET/UNET.py b/ivy_models/UNET/UNET.py index 5224bb57..9bb218ea 100644 --- a/ivy_models/UNET/UNET.py +++ b/ivy_models/UNET/UNET.py @@ -1,7 +1,5 @@ import ivy -ivy.set_backend("tensorflow") - def double_conv(in_c, out_c): conv = ivy.Sequential( @@ -12,21 +10,60 @@ def double_conv(in_c, out_c): ) return conv +def crop_img(tensor, target_tensor): + target_size = target_tensor.shape[2] + tensor_size = tensor.shape[2] + delta = tensor_size - target_size + delta = delta // 2 + return tensor[:, delta:tensor_size-delta, delta:tensor_size-delta, :] class UNet(ivy.Module): def __init__(self): super(UNet, self).__init__() - self.pool = ivy.MaxPool2D(2, 2, 0) self.down_conv_1 = double_conv(1, 64) self.down_conv_2 = double_conv(64, 128) self.down_conv_3 = double_conv(128, 256) self.down_conv_4 = double_conv(256, 512) self.down_conv_5 = double_conv(512, 1024) + self.up_trans_1 = ivy.Conv2DTranspose( + 1024, + 512, + [2, 2], + 2, + "VALID" + ) + self.up_conv_1 = double_conv(1024, 512) + self.up_trans_2 = ivy.Conv2DTranspose( + 512, + 256, + [2, 2], + 2, + "VALID" + ) + self.up_conv_2 = double_conv(512, 256) + self.up_trans_3 = ivy.Conv2DTranspose( + 256, + 128, + [2, 2], + 2, + "VALID" + ) + self.up_conv_3 = double_conv(256, 128) + self.up_trans_4 = ivy.Conv2DTranspose( + 128, + 64, + [2, 2], + 2, + "VALID" + ) + self.up_conv_4 = double_conv(128, 64) + self.out = ivy.Conv2D(64, 2, [1, 1], 1, 0) def _forward(self, image): + # B, H, W, C + #encoder x1 = self.down_conv_1(image) - print(x1.shape) x2 = self.pool(x1) x3 = self.down_conv_2(x2) x4 = self.pool(x3) @@ -35,9 +72,23 @@ def _forward(self, image): x7 = self.down_conv_4(x6) x8 = self.pool(x7) x9 = self.down_conv_5(x8) - print(x9.shape) + #decoder + x = self.up_trans_1(x9) + y = crop_img(x7, x) + x = self.up_conv_1(ivy.concat([x, y], axis=-1)) + x = self.up_trans_2(x) + y = crop_img(x5, x) + x = self.up_conv_2(ivy.concat([x, y], axis=-1)) + x = self.up_trans_3(x) + y = crop_img(x3, x) + x = self.up_conv_3(ivy.concat([x, y], axis=-1)) + x = self.up_trans_4(x) + y = crop_img(x1, x) + x = self.up_conv_4(ivy.concat([x, y], axis=-1)) + x = self.out(x) + return x image = ivy.random_normal(shape=(1, 572, 572, 1)) model = UNet() -print(model(image)) +model(image)