In [3]:
# !pip install torchview
# !pip install graphviz
# !pip install pytorch_lightning

In [1]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from torchview import draw_graph
import graphviz

graphviz.set_jupyter_format('png')


'svg'

In [2]:
class UNetSegmenterConv(nn.Module):
  def __init__(self, in_channels_conv, out_channels_conv):
    super().__init__()
    self.conv2d_01 = nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels_conv, out_channels=out_channels_conv, padding='same')
    self.conv2d_02 = nn.Conv2d(kernel_size=(3, 3), in_channels=out_channels_conv, out_channels=out_channels_conv, padding='same')

  def forward(self, x):
    x = self.conv2d_01(x)
    x = self.conv2d_02(x)
    return x


class UNetSegmenterConvMaxPool(nn.Module):
  def __init__(self, in_channels_conv, out_channels_conv):
    super().__init__()
    self.conv2d_01 = nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels_conv, out_channels=out_channels_conv, padding='same')
    self.conv2d_02 = nn.Conv2d(kernel_size=(3, 3), in_channels=out_channels_conv, out_channels=out_channels_conv, padding='same')
    self.maxpool = nn.MaxPool2d(kernel_size=2, stride=(2, 2))

  def forward(self, x):
    x = self.conv2d_01(x)
    x = self.conv2d_02(x)
    return x, self.maxpool(x)


class UNetSegmenterUpsampleMaxPool(nn.Module):
  def __init__(self, in_channels_conv, out_channels_conv):
    super().__init__()
    self.upsample = nn.ConvTranspose2d(in_channels=in_channels_conv, out_channels=out_channels_conv, kernel_size=2, stride=2)
    self.conv2d_01 = nn.Conv2d(kernel_size=(3, 3), in_channels=in_channels_conv, out_channels=out_channels_conv, padding='same')
    self.conv2d_02 = nn.Conv2d(kernel_size=(3, 3), in_channels=out_channels_conv, out_channels=out_channels_conv, padding='same')


  def forward(self, x):
    x_u = self.upsample(x[1])
    x = torch.concat([x[0], x_u], dim=1)
    x = self.conv2d_01(x)
    x = self.conv2d_02(x)
    return x

class UNetSegmenter(nn.Module):
  def __init__(self) -> None:
    super().__init__()
    self.conv_block_01 = UNetSegmenterConvMaxPool(in_channels_conv=3, out_channels_conv=16)
    self.conv_block_02 = UNetSegmenterConvMaxPool(in_channels_conv=16, out_channels_conv=32)
    self.conv_block_03 = UNetSegmenterConvMaxPool(in_channels_conv=32, out_channels_conv=64)
    self.conv_block_04 = UNetSegmenterConvMaxPool(in_channels_conv=64, out_channels_conv=128)
    self.conv_block_05 = UNetSegmenterConv(in_channels_conv=128, out_channels_conv=256)
    self.upsample_block_01 = UNetSegmenterUpsampleMaxPool(in_channels_conv=256, out_channels_conv=128)
    self.upsample_block_02 = UNetSegmenterUpsampleMaxPool(in_channels_conv=128, out_channels_conv=64)
    self.upsample_block_03 = UNetSegmenterUpsampleMaxPool(in_channels_conv=64, out_channels_conv=32)
    self.upsample_block_04 = UNetSegmenterUpsampleMaxPool(in_channels_conv=32, out_channels_conv=16)
    self.conv_block_final_01 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding='same')
    self.conv_block_final_02 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=1, padding='same')





  def forward(self, x):
    x_conv_01 = self.conv_block_01(x)
    x_conv_02 = self.conv_block_02(x_conv_01[1])
    x_conv_03 = self.conv_block_03(x_conv_02[1])
    x_conv_04 = self.conv_block_04(x_conv_03[1])
    x_conv_05 = self.conv_block_05(x_conv_04[1])
    x_upsample_01 = self.upsample_block_01([x_conv_04[0], x_conv_05])
    x_upsample_02 = self.upsample_block_02([x_conv_03[0], x_upsample_01])
    x_upsample_03 = self.upsample_block_03([x_conv_02[0], x_upsample_02])
    x_upsample_04 = self.upsample_block_04([x_conv_01[0], x_upsample_03])
    x_out = self.conv_block_final_01(x_upsample_04)
    x_out = self.conv_block_final_01(x_out)
    x_out = self.conv_block_final_02(x_out)
    return x_out

segmenter = UNetSegmenter()

segmenter(torch.rand(size=(1, 3, 128, 128))).shape

model_graph = draw_graph(segmenter
                         , input_size=(1, 3, 128, 128)
                         , expand_nested=True, roll=True
                         , hide_inner_tensors=True
                         , hide_module_functions=True)
model_graph.resize_graph(scale=1)
model_graph.visual_graph



ExecutableNotFound: failed to execute PosixPath('dot'), make sure the Graphviz executables are on your systems' PATH

<graphviz.graphs.Digraph at 0x13fb5cd30>

In [3]:
!pip install graphviz

You should consider upgrading via the '/Users/fermibot/PycharmProjects/pythonProject/venv/bin/python -m pip install --upgrade pip' command.[0m
