Neuron Segmentation Using U-Net

In [17]:
import torch
import torch.nn as nn
from torchvision import models, transforms
from torch.nn.functional import relu
import test

In [14]:
class UNET(nn.Module):
    """U-Net segmentation network built from unpadded ("valid") 3x3 convolutions.

    Every 3x3 convolution has no padding, so it trims 2 pixels from each
    spatial dimension; the output map is therefore smaller than the input
    (e.g. a 572x572 input yields a 388x388 output).  To make the skip
    connections line up, each encoder feature map is center-cropped to the
    size of the upsampled decoder map before the two are concatenated.

    Args:
        n_class: Number of output channels (one score map per class).
        in_channels: Channels of the input image.  Defaults to 1, the value
            that used to be hard-coded.
        base_channels: Width of the first level; each deeper level doubles
            it.  Defaults to 64, giving the 64/128/256/512/1024 layout.
    """

    def __init__(self, n_class, in_channels=1, base_channels=64):
        super().__init__()
        c1, c2, c3, c4, c5 = (base_channels * 2 ** i for i in range(5))

        # Contracting path
        # First level of UNET
        self.en11 = nn.Conv2d(in_channels=in_channels, out_channels=c1, kernel_size=3)
        self.en12 = nn.Conv2d(in_channels=c1, out_channels=c1, kernel_size=3)

        # Second level of UNET
        self.en21 = nn.Conv2d(in_channels=c1, out_channels=c2, kernel_size=3)
        self.en22 = nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=3)

        # Third level of UNET
        self.en31 = nn.Conv2d(in_channels=c2, out_channels=c3, kernel_size=3)
        self.en32 = nn.Conv2d(in_channels=c3, out_channels=c3, kernel_size=3)

        # Fourth level of UNET
        self.en41 = nn.Conv2d(in_channels=c3, out_channels=c4, kernel_size=3)
        self.en42 = nn.Conv2d(in_channels=c4, out_channels=c4, kernel_size=3)

        # Fifth level of UNET (bottleneck)
        self.en51 = nn.Conv2d(in_channels=c4, out_channels=c5, kernel_size=3)
        self.en52 = nn.Conv2d(in_channels=c5, out_channels=c5, kernel_size=3)

        # Max pooling halves the spatial size between encoder levels
        self.max_pool_2x2 = nn.MaxPool2d(kernel_size=2, stride=2)

        # Expansive path: each transposed conv doubles the spatial size and
        # halves the channels; the following conv sees twice the channels
        # because the cropped skip connection is concatenated in.
        self.up_conv_1 = nn.ConvTranspose2d(c5, c4, kernel_size=2, stride=2)
        self.de11 = nn.Conv2d(in_channels=c5, out_channels=c4, kernel_size=3)
        self.de12 = nn.Conv2d(in_channels=c4, out_channels=c4, kernel_size=3)

        self.up_conv_2 = nn.ConvTranspose2d(c4, c3, kernel_size=2, stride=2)
        self.de21 = nn.Conv2d(in_channels=c4, out_channels=c3, kernel_size=3)
        self.de22 = nn.Conv2d(in_channels=c3, out_channels=c3, kernel_size=3)

        self.up_conv_3 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.de31 = nn.Conv2d(in_channels=c3, out_channels=c2, kernel_size=3)
        self.de32 = nn.Conv2d(in_channels=c2, out_channels=c2, kernel_size=3)

        self.up_conv_4 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.de41 = nn.Conv2d(in_channels=c2, out_channels=c1, kernel_size=3)
        self.de42 = nn.Conv2d(in_channels=c1, out_channels=c1, kernel_size=3)

        # Output layer: 1x1 conv mapping features to per-class scores
        self.output = nn.Conv2d(c1, n_class, kernel_size=1)

    @staticmethod
    def _center_crop(skip, target):
        """Return the central window of ``skip`` matching ``target``'s last two dims."""
        height, width = target.shape[-2:]
        top = (skip.shape[-2] - height) // 2
        left = (skip.shape[-1] - width) // 2
        return skip[..., top:top + height, left:left + width]

    @staticmethod
    def _double_conv(x, conv_a, conv_b):
        """Apply conv -> ReLU -> conv -> ReLU."""
        return relu(conv_b(relu(conv_a(x))))

    def _up_block(self, x, skip, up_conv, conv_a, conv_b):
        """Upsample ``x``, concatenate the cropped ``skip`` on channels, then double-conv."""
        upsampled = up_conv(x)
        # dim=1 is the channel axis of a batched (N, C, H, W) tensor.
        merged = torch.cat([upsampled, self._center_crop(skip, upsampled)], dim=1)
        return self._double_conv(merged, conv_a, conv_b)

    def forward(self, image):
        """Compute per-class score maps.

        Args:
            image: Batched input of shape (N, in_channels, H, W).

        Returns:
            Un-normalized class scores of shape (N, n_class, H', W') with
            H' < H and W' < W, because the convolutions are unpadded.
        """
        # Contracting path
        ex12 = self._double_conv(image, self.en11, self.en12)
        ex22 = self._double_conv(self.max_pool_2x2(ex12), self.en21, self.en22)
        ex32 = self._double_conv(self.max_pool_2x2(ex22), self.en31, self.en32)
        ex42 = self._double_conv(self.max_pool_2x2(ex32), self.en41, self.en42)
        ex52 = self._double_conv(self.max_pool_2x2(ex42), self.en51, self.en52)

        # Expansive path, consuming the skip connections deepest-first
        dec12 = self._up_block(ex52, ex42, self.up_conv_1, self.de11, self.de12)
        dec22 = self._up_block(dec12, ex32, self.up_conv_2, self.de21, self.de22)
        dec32 = self._up_block(dec22, ex22, self.up_conv_3, self.de31, self.de32)
        dec42 = self._up_block(dec32, ex12, self.up_conv_4, self.de41, self.de42)

        return self.output(dec42)

In [18]:
# Instantiate the network with four output channels.
# NOTE(review): the original note named them "foreground, background, label,
# neurons" — confirm this class list against the actual label data.
model = UNET(n_class=4)

In [24]:
# Commit the notebook and push it to the `neuron-segmentation` remote's master branch.
# NOTE(review): `git co` is not a built-in git subcommand (it only works if an
# alias is configured); it was probably meant to be `git config ...` or
# `git checkout` — confirm.
# NOTE(review): the captured output of this cell shows the commit failed with
# "Author identity unknown" (user.email / user.name not configured), so the
# push reported "Everything up-to-date" and nothing new was pushed.
!git co
!git add .
!git commit -m "finished model, need test"
!git push neuron-segmentation master

Author identity unknown

*** Please tell me who you are.

Run

  git config --global user.email "you@example.com"
  git config --global user.name "Your Name"

to set your account's default identity.
Omit --global to set the identity only in this repository.

fatal: unable to auto-detect email address (got 'root@848e4341938d.(none)')
Everything up-to-date
