In [1]:
%cd "/content/drive/MyDrive/PlotNeuralNet"
# sync python module
%load_ext autoreload
%autoreload 2

/content/drive/MyDrive/PlotNeuralNet


In [2]:
from pycore.tikzeng import *

In [9]:
class ImgLayer:
  r"""An image drawn on the zy plane next to a flat, invisible ``Box`` pic.

  The Box has ``width={0}`` and ``opacity=0.0``; it only provides the named
  anchors (``<name>-east`` etc.) that connections attach to.

  Args:
    name: TikZ name of the Box pic.
    file_name: image path passed to ``\includegraphics``.
    dim: height and depth of the Box.
    offset: shift applied to the Box; only used when ``is_inp`` is False
      (the input variant always uses ``shift={(0, 0, 0)}``).
    to: TikZ coordinate the Box (and, for inputs, the image) is placed at.
    is_inp: if True, the image node is emitted first at ``to`` and the Box
      after it; otherwise the Box is emitted first and the image is placed at
      the Box's east anchor.
  """

  def __init__(
        self,
        name='',
        file_name='',
        dim='60',
        offset='(0,0,0)',
        to='(0,0,0)',
        is_inp=True
      ):
    self.name = name
    self.file_name = file_name
    self.dim = dim
    self.offset = offset
    self.to = to
    self.is_inp = is_inp

  def __str__(self):
    # Raw f-strings: the TikZ macros (\node, \pic, \includegraphics, \convA)
    # must reach the output with a single backslash, and a non-raw literal
    # would contain invalid escape sequences (SyntaxWarning on Python 3.12+).
    # NOTE(review): both image variants name their node `temp`; nothing here
    # references it, but the names collide if both images are drawn.
    if self.is_inp:
      return rf"""
      \node[canvas is zy plane at x=0] (temp) at {self.to} {{\includegraphics[width=13cm,height=13cm]{{{self.file_name}}}}};
      \pic[shift={{(0, 0, 0)}}] at {self.to} {{
          Box={{
              name={self.name},
              opacity=0.0,
              fill=\convA,
              height={self.dim},
              width={{0}},
              depth={self.dim}
          }}
      }};
      """
    return rf"""
      \pic[shift={{{self.offset}}}] at {self.to} {{
          Box={{
              name={self.name},
              opacity=0.0,
              fill=\convA,
              height={self.dim},
              width={{0}},
              depth={self.dim}
          }}
      }};
      \node[canvas is zy plane at x=0] (temp) at ({self.name}-east) {{\includegraphics[width=13cm,height=13cm]{{{self.file_name}}}}};
      """


class ConcatLayer:
  r"""A ``Ball`` pic (a circle with a math-mode logo), e.g. a concatenation node.

  Args:
    name: TikZ name of the pic.
    label: text placed between ``$...$`` as the ball's logo.
    width: ball radius.
    offset: shift applied to the pic.
    to: TikZ coordinate the pic is placed at.
    fill: fill colour macro (default ``\concat``).
  """

  def __init__(
        self,
        name='',
        label='||',
        width=2.5,
        offset='(0,0,0)',
        to='(0,0,0)',
        fill=r'\concat',
      ):
    self.name = name
    self.label = label
    self.offset = offset
    self.to = to
    self.width = width
    self.fill = fill

  def __str__(self):
    # Raw f-string so `\pic` is emitted verbatim without an invalid-escape warning.
    return rf"""
    \pic[shift={{{self.offset}}}] at {self.to} {{
        Ball={{
            name={self.name},
            fill={self.fill},
            radius={self.width},
            logo=${self.label}$
        }}
    }};
    """


class Layer:
  r"""A single 3-D ``Box`` pic (conv, norm, activation, ...) in the diagram.

  Args:
    name: TikZ name of the pic; connections use ``<name>-east``/``<name>-west``.
    feature: if not None, drawn as the box's ``xlabel``. Unlabelled boxes are
      drawn with ``opacity=0.6`` instead.
    width: box width.
    dim: used for both the box height and depth.
    offset: shift applied to the pic.
    to: TikZ coordinate the pic is placed at.
    fill: fill colour macro (default ``\conv``).
  """

  def __init__(
        self,
        name='',
        feature=None,
        width=2,
        dim=64,
        offset='(0,0,0)',
        to='(0,0,0)',
        fill=r'\conv',
      ):
    self.name = name
    self.feature = feature
    self.offset = offset
    self.to = to
    self.width = width
    self.height = dim
    self.depth = dim
    self.fill = fill

  def __str__(self):
    # Raw f-strings keep `\pic` verbatim without invalid-escape warnings.
    # `{{{{{x}}}}}` renders as `{{x}}`, i.e. a doubly-braced value in the TikZ.
    if self.feature is not None:
      return rf"""
        \pic[shift={{{self.offset}}}] at {self.to} {{
              Box={{
                  name={self.name},
                  xlabel={{{{{self.feature},}}}},
                  fill={self.fill},
                  height={self.height},
                  width={{{{{self.width}}}}},
                  depth={self.depth}
              }}
          }};
        """
    return rf"""
        \pic[shift={{{self.offset}}}] at {self.to} {{
              Box={{
                  name={self.name},
                  fill={self.fill},
                  opacity=0.6,
                  height={self.height},
                  width={{{{{self.width}}}}},
                  depth={self.depth}
              }}
          }};
        """

def build_layers(ls, name, offset, to):
  """Name and chain a list of layers so each one sits east of the previous.

  Mutates the layers in place:
    * the first layer gets ``to`` and ``offset``;
    * every layer is renamed ``f'{name}_l{i}'``;
    * every later layer's ``to`` becomes the previous layer's east anchor.
  Later layers keep their own ``offset``.

  Args:
    ls: list of layer objects with writable ``name``/``to``/``offset``.
    name: prefix for the generated layer names.
    offset: shift for the first layer.
    to: placement coordinate for the first layer.

  Returns:
    The same list ``ls``. An empty list is returned unchanged instead of
    raising IndexError.
  """
  if not ls:
    return ls

  ls[0].to = to
  ls[0].offset = offset

  for i, layer in enumerate(ls):
    layer.name = f'{name}_l{i}'

  # Anchor each layer to its predecessor's east side.
  for prev, curr in zip(ls, ls[1:]):
    curr.to = f'({prev.name}-east)'
  return ls


class BaseBlock:
  """Container for a named, ordered group of diagram layers.

  Subclasses populate ``self.ls``; rendering the block concatenates the
  rendering of every layer in order.

  Args:
    name: block name (layers are usually named ``<name>_l<i>``).
    feature: feature label for the block's labelled layers.
    width: base layer width.
    dim: stored as ``dim`` and used for both ``height`` and ``depth``.
    offset: shift for the block's first layer.
    to: placement coordinate for the block's first layer.
  """
  def __init__(
      self,
      name,
      feature=96,
      width=2,
      dim=64,
      offset='(2,0,0)',
      to='(0,0,0)',
    ):
    self.name, self.feature = name, feature
    self.offset, self.to = offset, to
    self.width = width
    self.dim = self.height = self.depth = dim
    self.ls = []

  def __str__(self):
    return ''.join(str(layer) for layer in self.ls)


class StemBlock(BaseBlock):
  r"""Stem block: one labelled conv box followed by a thin norm box.

  Layers, in drawing order (named ``<name>_l0``, ``<name>_l1``):
    0. conv box, fill ``\convS``, labelled ``feature``, width ``width``
    1. norm box, fill ``\ln``, width 1
  NOTE(review): the colour macros are not defined in this notebook —
  presumably in the LaTeX preamble; confirm.
  """
  def __init__(
      self,
      name,
      feature=96,
      width=2,
      dim=64,
      offset='(2,0,0)',
      to='(0,0,0)',
    ):
    super().__init__(name, feature, width, dim, offset, to)

    # Width used for the thin, unlabelled boxes.
    self.l_width = 1

    # Raw strings: '\convS'/'\ln' are TikZ macros, not Python escapes.
    self.conv1 = Layer(
        feature=self.feature,
        width=self.width,
        dim=self.dim,
        fill=r'\convS'
    )

    self.norm1 = Layer(
        width=self.l_width,
        dim=self.dim,
        fill=r'\ln'
    )

    self.ls = build_layers(
        [self.conv1, self.norm1],
        self.name, self.offset, self.to
    )


class ConvNeXtBlock(BaseBlock):
  r"""One ConvNeXt-style block drawn as five consecutive boxes.

  Layers, in drawing order (named ``<name>_l0`` ... ``<name>_l4``):
    0. conv, fill ``\convA``, labelled ``feature``, width ``width``
    1. norm, fill ``\ln``, width 1
    2. conv, fill ``\convB``, labelled ``feature*4``, width ``width*2``
    3. activation, fill ``\gelu``, width 1
    4. conv, fill ``\convB``, labelled ``feature``, width ``width``
  NOTE(review): this reads like the ConvNeXt inverted bottleneck (expand x4,
  then project back); confirm the layer semantics against the model code.
  """
  def __init__(
      self,
      name,
      feature=96,
      width=2,
      dim=64,
      offset='(2,0,0)',
      to='(0,0,0)',
    ):
    super().__init__(name, feature, width, dim, offset, to)

    # Width used for the thin, unlabelled norm/activation boxes.
    self.l_width = 1

    # Raw strings: fills are TikZ colour macros, not Python escapes.
    self.conv1 = Layer(
        feature=self.feature,
        width=self.width,
        dim=self.dim,
        fill=r'\convA'
    )

    # Expanded layer: 4x feature label, drawn twice as wide.
    self.conv2 = Layer(
        feature=self.feature*4,
        width=self.width*2,
        dim=self.dim,
        fill=r'\convB'
    )

    self.conv3 = Layer(
        feature=self.feature,
        width=self.width,
        dim=self.dim,
        fill=r'\convB'
    )

    self.norm1 = Layer(
        width=self.l_width,
        dim=self.dim,
        fill=r'\ln'
    )

    self.ac1 = Layer(
        width=self.l_width,
        dim=self.dim,
        fill=r'\gelu'
    )

    self.ls = build_layers(
        [self.conv1, self.norm1, self.conv2, self.ac1, self.conv3],
        self.name, self.offset, self.to
    )


class ResBlock(BaseBlock):
  r"""Residual-style block drawn as two (conv, norm, activation) triples.

  Layers, in drawing order (named ``<name>_l0`` ... ``<name>_l5``):
    0. conv, fill ``\convC``, labelled ``feature``, width ``width``
    1. norm, fill ``\in``, width 1
    2. activation, fill ``\lrelu``, width 1
    3. conv, fill ``\convC``, labelled ``feature``, width ``width``
    4. norm, fill ``\in``, width 1
    5. activation, fill ``\lrelu``, width 1
  NOTE(review): the macro names suggest InstanceNorm and LeakyReLU; confirm.
  """
  def __init__(
      self,
      name,
      feature=96,
      width=2,
      dim=64,
      offset='(2,0,0)',
      to='(0,0,0)',
    ):
    super().__init__(name, feature, width, dim, offset, to)

    # Width used for the thin, unlabelled norm/activation boxes.
    self.l_width = 1

    # Raw strings: fills are TikZ colour macros, not Python escapes.
    self.conv1 = Layer(
        feature=self.feature,
        width=self.width,
        dim=self.dim,
        fill=r'\convC'
    )

    self.conv2 = Layer(
        feature=self.feature,
        width=self.width,
        dim=self.dim,
        fill=r'\convC'
    )

    self.norm1 = Layer(
        width=self.l_width,
        dim=self.dim,
        fill=r'\in'
    )

    self.norm2 = Layer(
        width=self.l_width,
        dim=self.dim,
        fill=r'\in'
    )

    self.ac1 = Layer(
        width=self.l_width,
        dim=self.dim,
        fill=r'\lrelu'
    )

    self.ac2 = Layer(
        width=self.l_width,
        dim=self.dim,
        fill=r'\lrelu'
    )

    self.ls = build_layers(
        [self.conv1, self.norm1, self.ac1, self.conv2, self.norm2, self.ac2],
        self.name, self.offset, self.to
    )
    

class ConcatBlock(BaseBlock):
  r"""Block holding a single concatenation ball (``ConcatLayer``).

  Args:
    name: block name; the ball is renamed ``<name>_l0``.
    label: math-mode logo drawn on the ball. Previously ignored (always
      ``'||'``); it is now forwarded, and the default keeps the old output.
    width: ball radius.
    offset: shift applied to the ball.
    to: placement coordinate of the ball.
  """
  def __init__(
      self,
      name,
      label='||',
      width=2,
      offset='(2,0,0)',
      to='(0,0,0)',
    ):
    super().__init__(name, width=width, offset=offset, to=to)
    self.label = label
    # Raw string: '\convC' is a TikZ colour macro, not a Python escape.
    self.concat = ConcatLayer(
        label=self.label,
        width=width,
        fill=r'\convC'
    )
    self.ls = build_layers([self.concat], self.name, self.offset, self.to)


class OutBlock(BaseBlock):
  r"""Output block: a single labelled conv box with fill ``\convO``.

  The box is named ``<name>_l0``, labelled ``feature`` and drawn ``width`` wide.
  """
  def __init__(
      self,
      name,
      feature=96,
      width=2,
      dim=64,
      offset='(2,0,0)',
      to='(0,0,0)',
    ):
    super().__init__(name, feature, width, dim, offset, to)

    # Raw string: '\convO' is a TikZ colour macro, not a Python escape.
    self.conv1 = Layer(
        feature=self.feature,
        width=self.width,
        dim=self.dim,
        fill=r'\convO'
    )

    self.ls = build_layers([self.conv1], self.name, self.offset, self.to)


class ImgBlock(BaseBlock):
  """Block wrapping a single ``ImgLayer`` (an image beside an invisible box).

  Args:
    name: block name; the image layer is renamed ``<name>_l0``.
    file_name: image path forwarded to the layer.
    dim: height/depth of the layer's invisible box.
    offset: shift for the layer (see ``ImgLayer`` for when it is used).
    to: placement coordinate for the layer.
    is_inp: selects the input or output drawing variant of ``ImgLayer``.
  """
  def __init__(
      self,
      name,
      file_name,
      dim,
      offset='(0,0,0)',
      to='(0,0,0)',
      is_inp=True
    ):
    super().__init__(name, width=0, dim=dim, offset=offset, to=to)
    self.file_name, self.is_inp = file_name, is_inp
    image_layer = ImgLayer(
        name=self.name,
        file_name=self.file_name,
        dim=self.dim,
        offset=self.offset,
        to=self.to,
        is_inp=self.is_inp,
    )
    self.ls = build_layers([image_layer], self.name, self.offset, self.to)

In [4]:
class UConn:
  """An arrow routed horizontally, vertically, then horizontally again.

  The arrow runs from ``<p1>-east`` to ``<p2>-west``. It also defines a helper
  coordinate ``<p1>-mid``, halfway horizontally at p1's height.

  Args:
    p1: name of the source pic.
    p2: name of the target pic.
  """
  def __init__(self, p1, p2):
    self.p1 = p1
    self.p2 = p2

  def __str__(self):
    # Raw f-string so \path, \draw and \midarrow are emitted verbatim
    # without invalid-escape warnings.
    return rf"""
    \path ({self.p1}-east) -- ({self.p2}-west|-{self.p1}-west) coordinate[pos=0.5] ({self.p1}-mid);
    \draw[connection]({self.p1}-east)--node{{\midarrow}}({self.p1}-mid)--node{{\midarrow}}({self.p2}-west-|{self.p1}-mid)--node{{\midarrow}}({self.p2}-west);
    """

class Conn:
  """A straight arrow from ``<p1>-east`` to ``<p2>-west``.

  Args:
    p1: name of the source pic.
    p2: name of the target pic.
  """
  def __init__(self, p1, p2):
    self.p1 = p1
    self.p2 = p2

  def __str__(self):
    # Raw f-string so \draw and \midarrow are emitted verbatim.
    return rf"""
    \draw [connection]({self.p1}-east)--node{{\midarrow}}({self.p2}-west);
    """

In [5]:
def get_en_stage(stage_name, block_nums, feature, width, dim, offset):
  """Build one encoder stage as a chain of ``ConvNeXtBlock``s.

  Blocks are named ``<stage_name>_b<k>``. Only the first block gets
  ``offset``. Every later block is anchored to the east side of the previous
  block's last layer and joined to it with a ``Conn``.

  Returns:
    ``[blocks, conns]``, with ``block_nums - 1`` intra-stage connections.
  """
  blocks, conns = [], []
  for idx in range(block_nums):
    block_name = f"{stage_name}_b{idx}"
    if idx == 0:
      blocks.append(ConvNeXtBlock(block_name, feature=feature, width=width, dim=dim, offset=offset))
      continue
    prev_tail = blocks[-1].ls[-1].name
    blocks.append(ConvNeXtBlock(block_name, feature=feature, width=width, dim=dim, to=f'({prev_tail}-east)'))
    conns.append(Conn(prev_tail, blocks[-1].ls[0].name))
  return [blocks, conns]

In [6]:
def get_de_stage(stage_name, feature, width, concat_width, dim, offset, res_offset):
  """Build one decoder stage: a concat ball followed by a ``ResBlock``.

  The blocks are named ``<stage_name>_b0`` (concat) and ``<stage_name>_b1``
  (residual). The ResBlock is anchored to the east side of the concat ball,
  shifted by ``res_offset``, and joined to it with a ``Conn``.

  Returns:
    ``[[concat_block, res_block], [conn]]``.
  """
  concat_name, res_name = (f"{stage_name}_b{k}" for k in range(2))
  concat_block = ConcatBlock(concat_name, width=concat_width, offset=offset)
  concat_tail = concat_block.ls[-1].name
  res_block = ResBlock(res_name, feature=feature, width=width, dim=dim, offset=res_offset, to=f'({concat_tail}-east)')
  return [[concat_block, res_block], [Conn(concat_tail, res_block.ls[0].name)]]

In [12]:
# ---------------------------------------------------------------------------
# Encoder configuration: one entry per ConvNeXt stage.
# ---------------------------------------------------------------------------
block_nums = [3,3,9,3]            # ConvNeXt blocks per encoder stage
features = [96,192,384,768]       # feature label drawn on each stage's boxes
scale = 1                         # global multiplier for box height/depth
width = 8                         # conv box width of the first encoder stage
dims = [scale * e for e in [45,30,15,10]]
offsets_ys = [12,10,8,4]          # vertical shift (negated, i.e. downward) of each stage's first block
offsets = [f'(2, {-y}, 0)' for y in offsets_ys]
en_stages = []
en_stage_conns = []
en_block_conns = []

# ---------------------------------------------------------------------------
# Decoder configuration: one entry per (concat ball + ResBlock) stage.
# ---------------------------------------------------------------------------
concat_width = 2
de_width = 32                     # ResBlock conv width of the first decoder stage; halved per stage
de_features = [384,192,96,48,24]
de_dims = [scale * e for e in [15,30,45,60,65]]
de_offsets_ys = [4,8,10,12,16]    # vertical shift of each decoder stage's concat ball
de_offsets = [f'(2, {y}, 0)' for y in de_offsets_ys]
de_res_offsets_xs = [2,4,6,6,6]   # gap between concat ball and ResBlock, one per decoder stage
de_res_offsets = [f'({x}, 0, 0)' for x in de_res_offsets_xs]
de_stages = []
de_stage_conns = []
de_block_conns = []

# The per-stage lists are indexed in lock-step below; catch mismatched edits early.
assert len(block_nums) == len(features) == len(dims) == len(offsets)
assert len(de_features) == len(de_dims) == len(de_offsets) == len(de_res_offsets)

# Input image, drawn at the origin.
inp_img = [
    [ImgBlock('inp_img', 'cats.jpg', dim=de_dims[-1])]
]

# Output image; its anchor (`to`) is set below, once the out block exists.
out_img = [
    [ImgBlock('out_img', 'cats.jpg', dim=de_dims[-1], offset='(4, 0, 0)', is_inp=False)]
]

# Residual stage fed by the input image. It must be built before the encoder
# loop, which doubles `width`.
res_stages = [
    [ResBlock('rs0', feature=de_features[-1], width=width/4, dim=de_dims[-1], offset='(4, 0, 0)')]
]

# Stem stage, shifted down from the residual stage; feeds the first encoder stage.
stem_stages = [
    [StemBlock('ss0', feature=de_features[-2], width=width/2, dim=de_dims[-2], offset=f'(0, {-de_offsets_ys[-1]}, 0)')]
]


# Encoder: every stage after the first is anchored to the previous stage's
# last layer and joined to it with a UConn arrow.
for i in range(len(features)):
  blocks, conns = get_en_stage(f'es{i}', block_nums[i], features[i], width, dims[i], offsets[i])
  if i != 0:
    prev_last = en_stages[-1][-1].ls[-1].name
    blocks[0].ls[0].to = f'({prev_last}-east)'
    en_stage_conns.append(UConn(prev_last, blocks[0].ls[0].name))
  en_stages.append(blocks)
  en_block_conns.append(conns)

  # Only stage 0 uses the base width; all later stages are twice as wide.
  if i == 0:
    width = width*2


# Decoder: same chaining as the encoder; the ResBlock width halves per stage.
for i in range(len(de_features)):
  blocks, conns = get_de_stage(f'ds{i}', de_features[i], de_width, concat_width, de_dims[i], de_offsets[i], de_res_offsets[i])
  if i != 0:
    prev_last = de_stages[-1][-1].ls[-1].name
    blocks[0].ls[0].to = f'({prev_last}-east)'
    de_stage_conns.append(UConn(prev_last, blocks[0].ls[0].name))
  de_stages.append(blocks)
  de_block_conns.append(conns)
  de_width = de_width//2

# Output conv block.
out_block = [
  [OutBlock('os0', feature=de_features[-2], width=1, dim=de_dims[-1], offset='(2,0,0)')]
]


# Input image -> residual stage (straight arrow) and -> stem (UConn arrow).
res_stages[0][0].ls[0].to = f'({inp_img[0][0].ls[0].name}-east)'
en_stage_conns.append(Conn(inp_img[-1][-1].ls[-1].name, res_stages[0][0].ls[0].name))
en_stage_conns.append(UConn(inp_img[-1][-1].ls[-1].name, stem_stages[0][0].ls[0].name))

# The stem is positioned relative to the residual stage.
stem_stages[0][0].ls[0].to = f'({res_stages[0][0].ls[0].name}-east)'

# Stem -> first encoder stage. The arrow is stored in de_stage_conns; which
# list holds it only affects the order in which connections are emitted.
en_stages[0][0].ls[0].to = f'({stem_stages[-1][-1].ls[-1].name}-east)'
de_stage_conns.append(UConn(stem_stages[-1][-1].ls[-1].name, en_stages[0][0].ls[0].name))

# Last encoder stage -> first decoder stage.
de_stages[0][0].ls[0].to = f'({en_stages[-1][-1].ls[-1].name}-east)'
de_stage_conns.append(UConn(en_stages[-1][-1].ls[-1].name, de_stages[0][0].ls[0].name))

# Last decoder stage -> out block.
out_block[0][0].ls[0].to = f'({de_stages[-1][-1].ls[-1].name}-east)'
de_stage_conns.append(Conn(de_stages[-1][-1].ls[-1].name, out_block[0][0].ls[0].name))

# Out block -> output image.
out_img[0][0].ls[0].to = f'({out_block[-1][-1].ls[-1].name}-east)'
de_stage_conns.append(Conn(out_block[-1][-1].ls[-1].name, out_img[0][0].ls[0].name))


# Skip connections: encoder / stem / residual outputs -> the concat ball that
# starts the matching decoder stage.
sk_conns = [
    Conn(en_stages[2][-1].ls[-1].name, de_stages[0][0].ls[0].name),
    Conn(en_stages[1][-1].ls[-1].name, de_stages[1][0].ls[0].name),
    Conn(en_stages[0][-1].ls[-1].name, de_stages[2][0].ls[0].name),
    Conn(stem_stages[0][-1].ls[-1].name, de_stages[3][0].ls[0].name),
    Conn(res_stages[0][-1].ls[-1].name, de_stages[4][0].ls[0].name),
]


# Emit TikZ: all pics first, then the connections that reference their
# anchors by name. The order is the same as emitting each group in sequence.
node_groups = [inp_img, res_stages, stem_stages, en_stages, de_stages, out_block, out_img]
nested_conn_groups = [en_block_conns, de_block_conns]
flat_conn_groups = [en_stage_conns, de_stage_conns, sk_conns]

out_str = ''.join(
    str(item)
    for group in node_groups + nested_conn_groups
    for stage in group
    for item in stage
)
out_str += ''.join(str(conn) for group in flat_conn_groups for conn in group)

# print (not a bare expression) so the TikZ shows unescaped, ready to paste.
print(out_str)


      \node[canvas is zy plane at x=0] (temp) at (0,0,0) {\includegraphics[width=13cm,height=13cm]{cats.jpg}};
      \pic[shift={(0, 0, 0)}] at (0,0,0) {
          Box={
              name=inp_img_l0,
              opacity=0.0,
              fill=\convA,
              height=65,
              width={0},
              depth=65
          }
      };
      
        \pic[shift={(4, 0, 0)}] at (inp_img_l0-east) {
              Box={
                  name=rs0_l0,
                  xlabel={{24,}},
                  fill=\convC,
                  height=65,
                  width={{2.0}},
                  depth=65
              }
          };
        
        \pic[shift={(0,0,0)}] at (rs0_l0-east) {
              Box={
                  name=rs0_l1,
                  fill=\in,
                  opacity=0.6,
                  height=65,
                  width={{1}},
                  depth=65
              }
          };
        
        \pic[shift={(0,0,0)}] at (rs0_l1-east) {
            