In [1]:
import torch
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms

import os
from PIL import Image
import math
import numpy as np

import tqdm

In [2]:
resnet34 = torchvision.models.resnet34(pretrained=True)
print(resnet34)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace)
      (conv2): Co

In [3]:
class PreEncoder(nn.Module):
    
    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=7, stride=2, padding=3)
        self.bn = nn.MaxPool2d(kernel_size=2, stride=2)
        
    def forward(self, t):
        out = self.conv(t)
        out = self.bn(out)
        return out
        
        
class Encoder(nn.Module):
    
    def __init__(self, pretrained=True):
        super().__init__()
        self.resnet34 = torchvision.models.resnet34(pretrained=pretrained)
        
        self.encoder1 = self.resnet34.layer1
        self.encoder2 = self.resnet34.layer2
        self.encoder3 = self.resnet34.layer3
        self.encoder4 = self.resnet34.layer4
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        
    def forward(self, t):
        self.enc1 = out = self.encoder1(t)
        self.enc2 = out = self.encoder2(out)
        self.enc3 = out = self.encoder3(out)
        self.enc4 = out = self.encoder4(out)
        out = self.pool(out)
        return out
        
    
class DecoderBlock(nn.Module):
    def __init__(self, in_channel, n_filters):
        super().__init__()
        
        self.conv1 = nn.Conv2d(in_channel, in_channel // 4, 1)
        self.bn1 = nn.BatchNorm2d(in_channel // 4)
        self.relu1 = nn.ReLU()
        
        self.deconv2 = nn.ConvTranspose2d(in_channel // 4, in_channel // 4, 3, stride=2, padding=1, output_padding=1)
        self.bn2 = nn.BatchNorm2d(in_channel // 4)
        self.relu2 = nn.ReLU()
        
        self.conv3 = nn.Conv2d(in_channel // 4, n_filters, 1)
        self.bn3 = nn.BatchNorm2d(n_filters)
        self.relu3 = nn.ReLU()
    
    def forward(self, t):
        out = self.conv1(t)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.deconv2(out)
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv3(out)
        out = self.bn3(out)
        out = self.relu3(out)
        return out

class Decoder(nn.Module):
    
    def __init__(self, encoder):
        super().__init__()
        
        self.encoder = encoder
        self.decoder1 = DecoderBlock(512, 256)
        self.decoder2 = DecoderBlock(512 + 256, 256)
        self.decoder3 = DecoderBlock(256 + 256, 256)
        self.decoder4 = DecoderBlock(256 + 128, 64)
        self.decoder5 = DecoderBlock(64 + 64, 128)
        self.decoder6 = DecoderBlock(128, 32)
        self.decoder7 = DecoderBlock(32, 32)
        self.conv1 = nn.Conv2d(32, 32, kernel_size=3, padding=1, stride=2)
        self.conv2 = nn.Conv2d(32, 1, kernel_size=1)
    
    def forward(self, t):
        out = torch.cat((self.decoder1(t), self.encoder.enc4), dim=1)
        out = torch.cat((self.decoder2(out), self.encoder.enc3), dim=1)
        out = torch.cat((self.decoder3(out), self.encoder.enc2), dim=1)
        out = torch.cat((self.decoder4(out), self.encoder.enc1), dim=1)
        out = self.decoder5(out)
        out = self.decoder6(out)
        out = self.decoder7(out)
        out = self.conv1(out)
        out = self.conv2(out)
        return out
        
class FCN(nn.Module):
    
    def __init__(self, pretrained=True):
        super().__init__()
        self.pre_encoder = PreEncoder(3, 64)
        self.encoder = Encoder(pretrained)
        self.decoder = Decoder(self.encoder)
        self.softmax = nn.Softmax2d()
    
    def forward(self, t):
        out = self.pre_encoder(t)
        out = self.encoder(out)
        out = self.decoder(out)
        out = self.softmax(out)
        return out
        
        

In [4]:
class Dataset(torch.utils.data.Dataset):
    
    def __init__(self, train_dir, image_filenames):
        self.image_filenames = image_filenames
        self.train_dir = train_dir
        self.batch_size=batch_size
        self.trf = transforms.Compose([transforms.ToTensor()])
    
    def __getitem__(self, index):
        image_path = os.path.join(self.train_dir, image_filenames[index])
        image = Image.open(image_path)
        image = transforms.Compose([transforms.ToTensor()])(image)
        mask_path = os.path.join(train_dir, image_filenames[index].split('_')[0] + "_mask.png")
        mask = self.trf(Image.open(mask_path).convert('L')) / 255
        return image, mask
    #torch.tensor(torch.from_numpy(np.rollaxis(image, 2, 0)), dtype=torch.int64), torch.tensor(torch.from_numpy(np.rollaxis(mask, 2, 0)), dtype=torch.int64)
    
    def __len__(self):
        return len(self.image_filenames)
        

In [5]:
def jaccard_index(outputs, labels):
    outputs = outputs.round().reshape(-1)
    labels = outputs.reshape(-1)
    summation = (labels * outputs) / (outputs + labels - labels * outputs)
    return summation.sum() / len(outputs)

def standard_jaccard(outputs, labels):
    outputs = outputs.round().int()
    labels = labels.round().int()
    intersection = outputs & labels
    union = outputs | labels
    return intersection.float().sum() / union.float().sum()

def fcn_loss(output, label):
    a = 0.7
    bceloss = nn.BCELoss()
    j = math.log(jaccard_index(output, label))
    loss = a * bceloss(output, label) - (1 - a) * j
    return loss

In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = 8
train_dir = 'train'
val_dir = 'valid'
shuffle = True
epochs = 5
network = FCN().cuda()
criterion = fcn_loss
optimizer = optim.Adam(network.parameters(), lr=0.01)

In [7]:
image_filenames = []
for filename in os.listdir(train_dir):
    image_filenames.append(filename)

In [8]:
#dataset = torchvision.datasets.ImageFolder(root=train_dir, transform=torchvision.transforms.ToTensor())
dataset = Dataset(train_dir, image_filenames)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=shuffle)
# image, label = next(iter(dataloader))
# print(image.shape, label.shape)

In [None]:
for epoch in range(1, epochs+1):
    total_loss = 0.0
    total_accuracy = 0.0
    counter = 0
    for i, data in tqdm.tqdm(enumerate(dataloader)):
        image, mask = data
        image = image.to(device=device, dtype=torch.float32)
        mask = mask.to(device=device, dtype=torch.float32)
        optimizer.zero_grad()
        output = network(image)
        loss = criterion(output.to(device=device, dtype=torch.float32), mask.to(device=device, dtype=torch.float32))
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        total_accuracy += standard_jaccard(output, mask)
        counter += 1
        if counter % 20 == 0:
            print("epoch {}/{}\tbatch {}/{}\tloss {:.5f}\taccuracy {:.5f} %".format(epoch, epochs, i+1, len(dataloader), total_loss / counter,total_accuracy/counter*100))
        

20it [00:23,  1.18s/it]

epoch 1/5	batch 20/2242	loss 19.33928	accuracy 0.00000 %


40it [00:47,  1.22s/it]

epoch 1/5	batch 40/2242	loss 19.33910	accuracy 0.00000 %


60it [01:10,  1.15s/it]

epoch 1/5	batch 60/2242	loss 19.33903	accuracy 0.00000 %


74it [01:26,  1.15s/it]