In [6]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import glob
import os
from PIL import Image
import numpy
import matplotlib.pyplot as plt
from torch import optim

# Części do sieci neuronowej


*   https://github.com/milesial/Pytorch-UNet/



In [7]:
class Conv(nn.Module):

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class FirstHalfEncoder(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            Conv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)

class Decoder(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2, padding=2)
    self.conv1 = Conv(in_channels, in_channels)
    self.conv2 = Conv(2*in_channels, in_channels)


  def forward(self, x1, x2):

    x1 = self.conv1(x1)

    diffY = x2.size()[2] - x1.size()[2]
    diffX = x2.size()[3] - x1.size()[3]
    x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
    x = torch.cat([x2, x1], dim=1)

    x = self.conv2(x)
    return self.up(x)

class OutConv(nn.Module):
  def __init__(self, in_channels, out_channels):
    super().__init__()
    self.outlayer = nn.Sequential( 
        nn.Conv2d(in_channels, out_channels, kernel_size=(2,1), padding=(4,3)),
        nn.Sigmoid())

  def forward(self, x):
    return self.outlayer(x)


class SpatialAttention(nn.Module):
  def __init__(self, n_channels):
    super().__init__()
    n_channels *=2
    self.r1 = nn.Sequential(
        Conv(n_channels, n_channels),
        nn.Conv2d(n_channels, n_channels, kernel_size=1, bias=False)
    )
    self.r3 = nn.Sequential(
        Conv(n_channels, n_channels),
        nn.Conv2d(n_channels, n_channels, kernel_size=1, bias=False)
    )
    self.r5 = nn.Sequential(
        Conv(n_channels, n_channels),
        nn.Conv2d(n_channels, n_channels, kernel_size=1, bias=False)
    )
    self.r7 = nn.Sequential(
        Conv(n_channels, n_channels),
        nn.Conv2d(n_channels, n_channels, kernel_size=1, bias=False)
    )

  def forward(self, x):

    x1 = self.r1(x)
    x3 = self.r3(x)
    x5 = self.r5(x)
    x7 = self.r7(x)

    x_out = torch.add(torch.add(torch.add(x1, x3), x5), x7)

    return torch.mul(x, x_out)

class ChannelAttetnion(nn.Module):

  def __init__(self, n_channels):
    super().__init__()

    self.blocks = nn.Sequential(
        nn.AvgPool2d(1),
        nn.Conv2d(n_channels, n_channels, kernel_size=1, bias=False),
        nn.ReLU(inplace=True),
        nn.Conv2d(n_channels, n_channels, kernel_size=1, bias=False),
        nn.Sigmoid()
    )
  def forward(self, x):
    x_out = self.blocks(x)
    return torch.mul(x, x_out)

class FusionBlock(nn.Module):
  def __init__(self, n_channels):
    super().__init__()

    self.channelatt1 = ChannelAttetnion(n_channels)
    self.channelatt2 = ChannelAttetnion(n_channels)
    self.spatial = SpatialAttention(n_channels)
    self.outlayers = Conv(2*n_channels, n_channels)

  def forward(self, x1, x2):

    x1 = self.channelatt1(x1)
    x2 = self.channelatt2(x2)
    

    x = torch.cat([x1, x2], dim=1)


    x = self.spatial(x)


    x = self.outlayers(x)



    return x



In [8]:
class MSCNet(nn.Module):
  def __init__(self, n_channels, n_classes = 1):
    super().__init__()
    self.n_channels = n_channels
    self.n_classes = n_classes

    ## unets
    # unet1
    self.beginning_conv_top = nn.Conv2d(self.n_channels, 64, kernel_size=3)
    
    # block 1 down
    self.enc1a_top = FirstHalfEncoder(64,128)
    self.enc1b_top = Conv(128,128)

    # block 2 down
    self.enc2a_top = FirstHalfEncoder(128,256)
    self.enc2b_top = Conv(256,256)

    # block 3 down
    self.enc3a_top = FirstHalfEncoder(256,512)
    self.enc3b_top = Conv(512,512)

    # block 4 down
    self.enc4a_top = FirstHalfEncoder(512,1024)
    self.enc4b_top = Conv(1024,1024)

    # block 4 up
    self.dec4_top = Decoder(1024, 512)
    self.dec3_top = Decoder(512, 256)
    self.dec2_top = Decoder(256, 128)
    self.dec1_top = Decoder(128, 64)

    # output

    self.out_top = OutConv(64, n_classes)

    # unet2
    
    self.beginning_conv_bot = nn.Conv2d(self.n_channels, 64, kernel_size=3)
    
    # block 1 down
    self.enc1a_bot = FirstHalfEncoder(64,128)
    self.enc1b_bot = Conv(128,128)

    # block 2 down
    self.enc2a_bot = FirstHalfEncoder(128,256)
    self.enc2b_bot = Conv(256,256)

    # block 3 down
    self.enc3a_bot = FirstHalfEncoder(256,512)
    self.enc3b_bot = Conv(512,512)

    # block 4 down
    self.enc4a_bot = FirstHalfEncoder(512,1024)
    self.enc4b_bot = Conv(1024,1024)

    # block 4 up
    self.dec4_bot = Decoder(1024, 512)
    self.dec3_bot = Decoder(512, 256)
    self.dec2_bot = Decoder(256, 128)
    self.dec1_bot = Decoder(128, 64)

    # output

    self.out_bot = OutConv(64, n_classes)

    ## fusion

    self.fusion4 = FusionBlock(1024)
    self.fusion3 = FusionBlock(512)
    self.fusion2 = FusionBlock(256)
    self.fusion1 = FusionBlock(128)

    self.sigmoid_top = nn.Sigmoid()
    self.sigmoid_bot = nn.Sigmoid()

  def forward(self, x):
    
    x_init_top = self.beginning_conv_top(x)
    x1a_top = self.enc1a_top(x_init_top)
    x1b_top = self.enc1b_top(x1a_top)

    x2a_top = self.enc2a_top(x1b_top)
    x2b_top = self.enc2b_top(x2a_top)

    x3a_top = self.enc3a_top(x2b_top)
    x3b_top = self.enc3b_top(x3a_top)

    x4a_top = self.enc4a_top(x3b_top)
    x4b_top = self.enc4b_top(x4a_top)

    x_init_bot = self.beginning_conv_bot(x)    
    x1a_bot = self.enc1a_bot(x_init_bot)
    x1b_bot = self.enc1b_bot(x1a_bot)

    x2a_bot = self.enc2a_bot(x1b_bot)
    x2b_bot = self.enc2b_bot(x2a_bot)

    x3a_bot = self.enc3a_bot(x2b_bot)
    x3b_bot = self.enc3b_bot(x3a_bot)

    x4a_bot = self.enc4a_bot(x3b_bot)
    x4b_bot = self.enc4b_bot(x4a_bot)

    fusion4 = self.fusion4(x4b_top, x4b_bot)

    x4_top = torch.add(x4b_top, fusion4)
    x4_bot = torch.add(x4b_bot, fusion4)

    x4_top = self.dec4_top(x4_top, x4a_top)
    x4_bot = self.dec4_bot(x4_bot, x4a_bot)

    fusion3 = self.fusion3(x4_top, x4_bot)

    x3_top = torch.add(x4_top, fusion3)        
    x3_bot = torch.add(x4_bot, fusion3)        

    x3_top = self.dec3_top(x3_top, x3a_top)
    x3_bot = self.dec3_bot(x3_bot, x3a_bot)

    fusion2 = self.fusion2(x3_top, x3_bot)

    x2_top = torch.add(x3_top, fusion2)
    x2_bot = torch.add(x3_bot, fusion2)

    x2_top = self.dec2_top(x2_top, x2a_top)
    x2_bot = self.dec2_bot(x2_bot, x2a_bot)

    fusion1 = self.fusion1(x2_top, x2_bot)

    x1_top = torch.add(x2_top, fusion1)
    x1_bot = torch.add(x2_bot, fusion1)

    x1_top = self.dec1_top(x1_top, x1a_top)
    x1_bot = self.dec1_bot(x1_bot, x1a_bot)

    x_top = self.out_top(x1_top)
    x_top = self.sigmoid_top(x_top)

    x_bot = self.out_top(x1_bot)
    x_bot = self.sigmoid_bot(x_bot)

    return x_top, x_bot





In [10]:
class DataLoaderSegmentation(torch.utils.data.Dataset):
    def __init__(self, folder_path, transform=None):
        super(DataLoaderSegmentation, self).__init__()
        self.img_files = glob.glob(os.path.join(folder_path,'images','*.ppm'))
        self.mask1_files = []
        self.mask2_files = []
        self.transform = transform
        for img_path in self.img_files:
             self.mask1_files.append(os.path.join(folder_path,'mask1',os.path.basename(img_path)))
        for img_path in self.img_files:
             self.mask2_files.append(os.path.join(folder_path,'mask2',os.path.basename(img_path)))

    def __getitem__(self, index):
            img_path = self.img_files[index]
            mask1_path = self.mask1_files[index]
            data =Image.open(img_path)
            mask1 = Image.open(mask1_path)
            mask2_path = self.mask2_files[index]
            mask2 = Image.open(mask2_path)

            if self.transform:
              img = self.transform(torch.from_numpy(numpy.asarray(data)/255).float())
              m1 = self.transform(torch.from_numpy(numpy.asarray(mask1)/255).float())
              m2 = self.transform(torch.from_numpy(numpy.asarray(mask2)/255).float())
            else:
              img = torch.from_numpy(numpy.asarray(data)/255).float()
              m1 = torch.from_numpy(numpy.asarray(mask1)/255).float()
              m2 = torch.from_numpy(numpy.asarray(mask2)/255).float()

            img = img.permute(2, 0, 1)
            return img, m1, m2

    def __len__(self):
        return len(self.img_files)

In [None]:
n_epoch = 7

net = MSCNet(n_channels = 3, n_classes = 1)

dataset = DataLoaderSegmentation(".")

train, test = torch.utils.data.random_split(dataset, [15,5], generator = torch.Generator().manual_seed(123))

train_loader = torch.utils.data.DataLoader(train, shuffle=True, batch_size=1, num_workers=0, pin_memory=True)
test_loader = torch.utils.data.DataLoader(test, shuffle=False, batch_size=1, num_workers=0, pin_memory=True, drop_last=True)

optimizer = optim.RMSprop(net.parameters(), lr = 1e-3, weight_decay=1e-8, momentum=0.9)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=2)

for epoch in range(n_epoch):
  running_loss = 0.0
  net.train()
  for i, (image, mask1, mask2) in enumerate(train_loader):
    
    optimizer.zero_grad()

    output_top, output_bot = net(image)
    
    loss1 = nn.BCEWithLogitsLoss()

    loss = loss1(torch.squeeze(output_top, dim=0), mask1) + loss1(torch.squeeze(output_bot, dim=0), mask2)
    loss.backward()
    optimizer.step()

    running_loss += loss.item()
    print(f'[Epoch: {epoch + 1},sample: {i + 1:5d}] loss: {running_loss/ (i+1)}')
  
  
#   net.eval()
#   loss = 0
#   with torch.no_grad():
#       for i, (image, mask1, mask2) in enumerate(test_loader):
#           output_top, output_bot = net(image)  
#           loss1 = nn.BCEWithLogitsLoss()
#           loss += loss1(torch.squeeze(output_top, dim=0), mask1) + loss1(torch.squeeze(output_bot, dim=0), mask2)
  
#   val_loss = loss / len(test_loader)
#   print("Epoch {}, val_loss=".format(epoch), val_loss)


[Epoch: 1,sample:     1] loss: 1.99478018283844
[Epoch: 1,sample:     2] loss: 1.9960351586341858
[Epoch: 1,sample:     3] loss: 1.951777458190918
[Epoch: 1,sample:     4] loss: 1.9273979365825653
[Epoch: 1,sample:     5] loss: 1.9066287755966187
[Epoch: 1,sample:     6] loss: 1.8963579138120015
[Epoch: 1,sample:     7] loss: 1.8877010175159998
[Epoch: 1,sample:     8] loss: 1.888973891735077
[Epoch: 1,sample:     9] loss: 1.882885668012831
[Epoch: 1,sample:    10] loss: 1.8786437273025514
[Epoch: 1,sample:    11] loss: 1.8755618658932773
[Epoch: 1,sample:    12] loss: 1.8741651972134907
[Epoch: 1,sample:    13] loss: 1.8755379640139067
[Epoch: 1,sample:    14] loss: 1.8737242903028215
[Epoch: 1,sample:    15] loss: 1.8733115673065186
[Epoch: 2,sample:     1] loss: 1.8404101133346558
[Epoch: 2,sample:     2] loss: 1.8539645671844482
[Epoch: 2,sample:     3] loss: 1.8437788089116414
[Epoch: 2,sample:     4] loss: 1.8453617990016937
[Epoch: 2,sample:     5] loss: 1.8452625036239625
[Epoc