<a href="https://colab.research.google.com/github/Qiuyan918/Unet_Implementation_PyTorch/blob/master/Unet_Implementation_PyTorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
import pandas as pd
import cv2
from tqdm import tqdm
from copy import deepcopy
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.autograd import Variable
from torch.utils.data.sampler import RandomSampler

# Configs

In [None]:
n_fold = 5
pad_left = 27
pad_right = 27
fine_size = 202
batch_size = 18
epoch = 300 
snapshot = 6 
max_lr = 0.012 
min_lr = 0.001 
momentum = 0.9 
weight_decay = 1e-4 
n_fold = 5
device = torch.device('cuda')
save_weight = 'weights/'
if not os.path.isdir(save_weight):
  os.mkdir(save_weight)
weight_name = 'model_' + str(fine_size+pad_left+pad_right) + '_res18' 

train_image_dir = 'tgs-salt-identification-challenge/train/images'
train_mask_dir = 'tgs-salt-identification-challenge/train/masks'
test_image_dir = 'tgs-salt-identification-challenge/test/images'

# Split

In [None]:
depths = pd.read_csv('tgs-salt-identification-challenge/depths.csv')
depths.sort_values('z', inplace=True)
depths.drop('z', axis=1, inplace=True)
depths['fold'] = (list(range(0,5)) * depths.shape[0])[:depths.shape[0]]

train_df = pd.read_csv('tgs-salt-identification-challenge/train.csv')
train_df = train_df.merge(depths)
dist = []
for id in train_df.id.values:
  img = cv2.imread(f'tgs-salt-identification-challenge/train/images/{id}.png', cv2.IMREAD_GRAYSCALE)
  dist.append(np.unique(img).shape[0])
train_df['unique_pixels'] = dist

# Dataset

In [None]:
def trainImageFetch(images_id):
  image_train = np.zeros((images_id.shape[0], 101, 101), dtype=np.float32)
  mask_train = np.zeros((images_id.shape[0], 101, 101), dtype=np.float32)

  for idx, image_id in tqdm(enumerate(images_id), total=images_id.shape[0]):
    image_path = os.path.join(train_image_dir, image_id+'.png')
    mask_path = os.path.join(train_mask_dir, image_id+'.png')

    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255

    image_train[idx] = image
    mask_train[idx] = mask
  
  return image_train, mask_train

def testImageFetch(test_id):
  image_test = np.zeros((len(test_id), 101, 101), dtype=np.float32)

  for idx, image_id in tqdm(enumerate(test_id), total=len(test_id)):
    image_path = os.path.join(test_image_dir, image_id+'.png')
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE).astype(np.float32) / 255
    image_test[idx] = image

  return image_test

def do_resize2(image, mask, H, W):
  image = cv2.resize(image, dsize=(W,H))
  mask = cv2.resize(mask, dsize=(W,H))
  return image, mask

def do_center_pad(image, pad_left, pad_right):
  return np.pad(image, (pad_left, pad_right), 'edge')

def do_center_pad2(image, mask, pad_left, pad_right):
  image = do_center_pad(image, pad_left, pad_right)
  mask = do_center_pad(mask, pad_left, pad_right)
  return image, mask

class SaltDataset(Dataset):
  def __init__(self, image_list, mode, mask_list=None, fine_size=202, pad_left=0, pad_right=0):
    self.imagelist = image_list
    self.mode = mode
    self.masklist = mask_list
    self.fine_size = fine_size
    self.pad_left = pad_left
    self.pad_right = pad_right

  def __len__(self):
    return len(self.imagelist)

  def __getitem__(self, idx):
    image = deepcopy(self.imagelist[idx])

    if self.mode == 'train':
      mask = deepcopy(self.masklist[idx])
      label = np.where(mask.sum() == 0, 1.0, 0.0).astype(np.float32)

      if self.fine_size != image.shape[0]:
        image, mask = do_resize2(image, mask, self.fine_size, self.fine_size)

      if self.pad_left != 0:
        image, mask = do_center_pad2(image, mask, self.pad_left, self.pad_right)

      image = image.reshape(1, image.shape[0], image.shape[1])
      mask = mask.reshape(1, mask.shape[0], mask.shape[1])    

      return image, mask, label

    elif self.mode == 'val':
      mask = deepcopy(self.masklist[idx])

      if self.fine_size != image.shape[0]:
        image, mask = do_resize2(image, mask, self.fine_size, self.fine_size)

      if self.pad_left != 0:
        image = do_center_pad(image, self.pad_left, self.pad_right)

      image = image.reshape(1, image.shape[0], image.shape[1])
      mask = mask.reshape(1, mask.shape[0], mask.shape[1])  

      return image, mask

    elif self.mode == 'test':
      if self.fine_size != image.shape[0]:
        image = cv2.resize(image, dsize=(self.fine_size, self.fine_size))

      if self.pad_left != 0:
        image = do_center_pad(image, self.pad_left, self.pad_right)

      image = image.reshape(1, image.shape[0], image.shape[1])

      return image     

# Model

In [None]:
class Decoder(nn.Module):
  def __init__(self, in_channels, middle_channels, out_channels):
    super(Decoder, self).__init__()
    self.up = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2)
    self.conv_relu = nn.Sequential(
        nn.Conv2d(middle_channels, out_channels, kernel_size=3, padding=1),
        nn.ReLU(inplace=True)
        )
  def forward(self, x1, x2):
    x1 = self.up(x1)
    x1 = torch.cat((x1, x2), dim=1)
    x1 = self.conv_relu(x1)
    return x1

class UNet(nn.Module):
    def __init__(self, n_class):
        super().__init__()
        
        self.base_model = torchvision.models.resnet18(True)
        self.base_layers = list(self.base_model.children())
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
            self.base_layers[1],
            self.base_layers[2])
        self.layer2 = nn.Sequential(*self.base_layers[3:5])
        self.layer3 = self.base_layers[5]
        self.layer4 = self.base_layers[6]
        self.layer5 = self.base_layers[7]
        self.decode4 = Decoder(512, 256+256, 256)
        self.decode3 = Decoder(256, 256+128, 256)
        self.decode2 = Decoder(256, 128+64, 128)
        self.decode1 = Decoder(128, 64+64, 64)
        self.decode0 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, 32, kernel_size=3, padding=1, bias=False),
            nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=False)
            )
        self.conv_last = nn.Conv2d(64, n_class, 1)

    def forward(self, input):
        e1 = self.layer1(input) # 64,128,128
        e2 = self.layer2(e1) # 64,64,64
        e3 = self.layer3(e2) # 128,32,32
        e4 = self.layer4(e3) # 256,16,16
        f = self.layer5(e4) # 512,8,8
        d4 = self.decode4(f, e4) # 256,16,16
        d3 = self.decode3(d4, e3) # 256,32,32
        d2 = self.decode2(d3, e2) # 128,64,64
        d1 = self.decode1(d2, e1) # 64,128,128
        d0 = self.decode0(d1) # 64,256,256
        out = self.conv_last(d0) # 1,256,256
        return out

# Helper Functions

In [None]:
def train(train_loader, model):
  running_loss = 0.0
  data_size = len(train_data)

  model.train()

  for inputs, masks, labels in train_loader:
    inputs, masks, labels = inputs.to(device), masks.to(device), labels.to(device)
    optimizer.zero_grad()

    with torch.set_grad_enabled(True):
      logit = model(inputs)
      loss = nn.BCEWithLogitsLoss()(logit.squeeze(1), masks.squeeze(1))
      loss.backward()
      optimizer.step()

    running_loss += loss.item() * inputs.size(0)

  epoch_loss = running_loss / data_size
  return epoch_loss

def test(test_loader, model):
  running_loss = 0.0
  data_size = len(test_loader)
  predicts = []
  truths = []

  model.eval()

  for inputs, masks in test_loader:
    inputs, masks = inputs.to(device), masks.to(device)

    with torch.set_grad_enabled(False):
      outputs = model(inputs)
      # 深拷贝 contiguous
      outputs = outputs[:, :, pad_left:pad_left + fine_size, pad_left:pad_left + fine_size].contiguous()
      loss = nn.BCEWithLogitsLoss()(outputs.squeeze(1), masks.squeeze(1))

    predicts.append(torch.sigmoid(outputs).detach().cpu().numpy()) 
    truths.append(masks.detach().cpu().numpy())
    running_loss += loss.item() * inputs.size(0)

  predicts = np.concatenate(predicts).squeeze()
  truths = np.concatenate(truths).squeeze()
  precision, _, _ = do_kaggle_metric(predicts, truths, 0.5)
  precision = precision.mean()
  epoch_loss = running_loss / data_size
  return epoch_loss, precision

def rle_encode(im):
    '''
    im: numpy array, 1 - mask, 0 - background
    Returns run length as string formated
    '''
    pixels = im.flatten(order = 'F')
    pixels = np.concatenate([[0], pixels, [0]])
    runs = np.where(pixels[1:] != pixels[:-1])[0] + 1
    runs[1::2] -= runs[::2]
    return ' '.join(str(x) for x in runs)

# Train

In [8]:
all_id = train_df['id'].values
fold = []
for i in range(5):
  fold.append(train_df.loc[train_df['fold']==i, 'id'].values)

salt = UNet(1)
salt.to(device)

UNet(
  (base_model): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_run

In [14]:
for idx in range(5):

  if idx == 1:
    break

  # Setup optimizer
  scheduler_step = epoch // snapshot # 300//6
  optimizer = torch.optim.SGD(salt.parameters(), lr=max_lr, momentum=momentum, weight_decay=weight_decay)
  lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, scheduler_step, min_lr)

  # Load data
  # 取不同的元素
  train_id = np.setdiff1d(all_id, fold[idx])
  val_id = fold[idx]
  # 取出数据
  X_train, y_train = trainImageFetch(train_id)
  X_val, y_val = trainImageFetch(val_id)
  # 制作数据集
  train_data = SaltDataset(X_train, 'train', y_train, pad_left=27, pad_right=27)
  val_data = SaltDataset(X_val, 'val', y_val, pad_left=27, pad_right=27)
  # 打乱，制作可迭代数据集
  train_loader = DataLoader(train_data,
                            shuffle=RandomSampler(train_data), 
                            batch_size=batch_size) 

  val_loader = DataLoader(val_data,
                            shuffle=False, 
                            batch_size=batch_size) 

  num_snapshot = 0
  best_acc = 0
# 训练
  for epoch_ in range(epoch): # 300
    train_loss = train(train_loader, salt)
    val_loss, accuracy = test(val_loader, salt)
    lr_scheduler.step()

    if accuracy > best_acc:
      best_acc = accuracy
      best_param = salt.state_dict()

    if (epoch_ + 1) % scheduler_step == 0:
      torch.save(best_param, save_weight + weight_name + str(idx) + str(num_snapshot) + '.pth')
      optimizer = torch.optim.SGD(salt.parameters(), lr=max_lr, momentum=momentum, weight_decay=weight_decay)
      lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, scheduler_step, min_lr)
      num_snapshot += 1
      best_acc = 0

    print('epoch: {} train_loss: {:.3f} val_loss: {:.3f} val_accuracy: {:.3f}'.format(epoch_ + 1, train_loss, val_loss, accuracy))

100%|██████████| 3190/3190 [00:01<00:00, 2349.81it/s]
100%|██████████| 810/810 [00:00<00:00, 2232.46it/s]


epoch: 1 train_loss: 0.552 val_loss: 12.026 val_accuracy: 0.385
epoch: 2 train_loss: 0.469 val_loss: 7.533 val_accuracy: 0.216
epoch: 3 train_loss: 0.424 val_loss: 8.726 val_accuracy: 0.213
epoch: 4 train_loss: 0.407 val_loss: 6.621 val_accuracy: 0.352
epoch: 5 train_loss: 0.383 val_loss: 6.971 val_accuracy: 0.272
epoch: 6 train_loss: 0.388 val_loss: 6.317 val_accuracy: 0.384
epoch: 7 train_loss: 0.359 val_loss: 10.014 val_accuracy: 0.365
epoch: 8 train_loss: 0.326 val_loss: 7.225 val_accuracy: 0.331
epoch: 9 train_loss: 0.300 val_loss: 5.103 val_accuracy: 0.570
epoch: 10 train_loss: 0.289 val_loss: 5.256 val_accuracy: 0.468
epoch: 11 train_loss: 0.250 val_loss: 3.828 val_accuracy: 0.625
epoch: 12 train_loss: 0.217 val_loss: 3.834 val_accuracy: 0.600
epoch: 13 train_loss: 0.184 val_loss: 3.655 val_accuracy: 0.666
epoch: 14 train_loss: 0.168 val_loss: 2.817 val_accuracy: 0.708
epoch: 15 train_loss: 0.153 val_loss: 3.231 val_accuracy: 0.628
epoch: 16 train_loss: 0.140 val_loss: 3.385 val

# Test

In [11]:
test_id = [x[:-4] for x in os.listdir(test_image_dir) if x[-4:] == '.png']
image_test = testImageFetch(test_id)
overall_pred_101 = np.zeros((len(test_id), 101, 101), dtype=np.float32)

for step in range(1, 6):

  print('Predicting Snapshot', step)
  pred_null = []

  # Load weight
  param = torch.load(save_weight + weight_name + '0' + str(step) + '.pth')
  salt.load_state_dict(param)

  # Dataloader
  test_data = SaltDataset(image_test, mode='test', fine_size=fine_size, pad_left=pad_left, pad_right=pad_right)
  test_loader = DataLoader(test_data,
                            shuffle=False,
                            batch_size=batch_size)
  
  # Prediction
  salt.eval()
  for images in tqdm(test_loader, total=len(test_loader)):
    images = images.to(device)
    with torch.set_grad_enabled(False):
      pred = salt(images)
      pred = torch.sigmoid(pred).squeeze(1).cpu().numpy()
      pred = pred[:, pad_left:pad_left + fine_size, pad_left:pad_left + fine_size]
      pred_null.append(pred)
  
  idx = 0
  for i in range(len(pred_null)):
    for j in range(batch_size):
      overall_pred_101[idx] += cv2.resize(pred_null[i][j], dsize=(101, 101))
      idx += 1

100%|██████████| 18000/18000 [00:13<00:00, 1320.24it/s]


Predicting Snapshot 1


100%|██████████| 1000/1000 [01:34<00:00, 10.38it/s]


Predicting Snapshot 2


100%|██████████| 1000/1000 [01:35<00:00, 10.47it/s]


Predicting Snapshot 3


100%|██████████| 1000/1000 [01:35<00:00, 10.46it/s]


Predicting Snapshot 4


100%|██████████| 1000/1000 [01:35<00:00, 10.42it/s]


Predicting Snapshot 5


100%|██████████| 1000/1000 [01:35<00:00, 10.55it/s]


In [None]:
submission = pd.DataFrame({'id':test_id, 'rle_mask':list(overall_pred_101)})
submission['rle_mask'] = submission['rle_mask'].map(lambda x: rle_encode(x>5*0.5))
submission.set_index('id', inplace=True)

sample_submission = pd.read_csv('tgs-salt-identification-challenge/sample_submission.csv')
sample_submission.set_index('id', inplace=True)
submission = submission.reindex(sample_submission.index)
submission.reset_index(inplace=True)
submission.to_csv('submission.csv', index=False)