Skip to content

Commit

Permalink
UNet's complete architecture
Browse files Browse the repository at this point in the history
  • Loading branch information
zaeemansari70 committed May 28, 2023
1 parent 9eef35e commit 24e6a94
Showing 1 changed file with 57 additions and 6 deletions.
63 changes: 57 additions & 6 deletions ivy_models/UNET/UNET.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import ivy

ivy.set_backend("tensorflow")


def double_conv(in_c, out_c):
conv = ivy.Sequential(
Expand All @@ -12,21 +10,60 @@ def double_conv(in_c, out_c):
)
return conv

def crop_img(tensor, target_tensor):
target_size = target_tensor.shape[2]
tensor_size = tensor.shape[2]
delta = tensor_size - target_size
delta = delta // 2
return tensor[:, delta:tensor_size-delta, delta:tensor_size-delta, :]

class UNet(ivy.Module):
def __init__(self):
super(UNet, self).__init__()

self.pool = ivy.MaxPool2D(2, 2, 0)
self.down_conv_1 = double_conv(1, 64)
self.down_conv_2 = double_conv(64, 128)
self.down_conv_3 = double_conv(128, 256)
self.down_conv_4 = double_conv(256, 512)
self.down_conv_5 = double_conv(512, 1024)
self.up_trans_1 = ivy.Conv2DTranspose(
1024,
512,
[2, 2],
2,
"VALID"
)
self.up_conv_1 = double_conv(1024, 512)
self.up_trans_2 = ivy.Conv2DTranspose(
512,
256,
[2, 2],
2,
"VALID"
)
self.up_conv_2 = double_conv(512, 256)
self.up_trans_3 = ivy.Conv2DTranspose(
256,
128,
[2, 2],
2,
"VALID"
)
self.up_conv_3 = double_conv(256, 128)
self.up_trans_4 = ivy.Conv2DTranspose(
128,
64,
[2, 2],
2,
"VALID"
)
self.up_conv_4 = double_conv(128, 64)
self.out = ivy.Conv2D(64, 2, [1, 1], 1, 0)

def _forward(self, image):
# B, H, W, C
#encoder
x1 = self.down_conv_1(image)
print(x1.shape)
x2 = self.pool(x1)
x3 = self.down_conv_2(x2)
x4 = self.pool(x3)
Expand All @@ -35,9 +72,23 @@ def _forward(self, image):
x7 = self.down_conv_4(x6)
x8 = self.pool(x7)
x9 = self.down_conv_5(x8)
print(x9.shape)

#decoder
x = self.up_trans_1(x9)
y = crop_img(x7, x)
x = self.up_conv_1(ivy.concat([x, y], axis=-1))
x = self.up_trans_2(x)
y = crop_img(x5, x)
x = self.up_conv_2(ivy.concat([x, y], axis=-1))
x = self.up_trans_3(x)
y = crop_img(x3, x)
x = self.up_conv_3(ivy.concat([x, y], axis=-1))
x = self.up_trans_4(x)
y = crop_img(x1, x)
x = self.up_conv_4(ivy.concat([x, y], axis=-1))
x = self.out(x)
return x

image = ivy.random_normal(shape=(1, 572, 572, 1))
model = UNet()
print(model(image))
model(image)

0 comments on commit 24e6a94

Please sign in to comment.