In [1]:
import numpy as np
import pandas as pd
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from IPython.core.debugger import set_trace
import matplotlib.pyplot as plt

import lovasz_losses as L
from metrics import iou_pytorch

from IPython.core.debugger import set_trace

# Hard-coded to the CUDA device with index 1; the cells below move the batches
# and the model onto this device, so the notebook needs at least two GPUs as written.
device = torch.device("cuda:1")

## sample data

In [2]:
def prepare_data():
    """Load the pickled train/validation split and wrap it in DataLoaders.

    Reads the module-level settings ``SEED`` (selects the pickle file),
    ``debug`` (truncates the data to 400 train / 50 validation samples),
    ``BATCH_SIZE`` and ``NUM_WORKERS``.

    Returns:
        tuple: ``(train_dl, val_dl)``; the training loader shuffles, the
        validation loader does not.
    """
    # NOTE(security): pickle.load can execute arbitrary code stored in the
    # file -- only point this at datasets produced locally / by a trusted source.
    dataset_path = '../data/processed/dataset_%d.pkl' % SEED
    with open(dataset_path, 'rb') as f:
        payload = pickle.load(f)

    # The pickle stores ten entries; only the x_* / y_* entries are used here.
    (ids_train, ids_valid, x_train, x_valid, y_train, y_valid,
     cov_train, cov_test, depth_train, depth_test) = payload

    if debug:
        x_train, y_train = x_train[:400], y_train[:400]
        x_valid, y_valid = x_valid[:50], y_valid[:50]

    # Wrap the arrays in the project Dataset, then in DataLoaders that share
    # everything except the shuffle flag.
    train_ds, val_ds = TgsDataSet(x_train, y_train), TgsDataSet(x_valid, y_valid)
    loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
    train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    return train_dl, val_dl

from dataset import TgsDataSet
from torch.utils.data import DataLoader

# Run configuration; prepare_data() reads these as globals.
SEED = 1234
debug = True
BATCH_SIZE = 16  # 32
NUM_WORKERS = 20
train_dl, val_dl = prepare_data()


# Pull exactly one training batch to experiment with below.
for i, batch in enumerate(train_dl):
    x, y = (t.to(device=device, dtype=torch.float) for t in batch)
    # Target for the zero-mask classifier: 1 when the mask sums to zero, 0 otherwise.
    y = y.reshape(-1, 256*256).sum(dim=1, keepdim=True) == 0
    break

## experiment with net architecture

In [3]:
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from IPython.core.debugger import set_trace
import matplotlib.pyplot as plt

import lovasz_losses as L
#from metrics import iou_pytorch
from sklearn.metrics import roc_auc_score, confusion_matrix


class ConvBn2d(nn.Module):
    """Bias-free 2-D convolution followed by batch normalisation.

    No activation is applied; callers add their own.
    """

    def __init__(self, in_channels, out_channels, kernel_size=(3,3), stride=(1,1), padding=(1,1)):
        super().__init__()
        # bias=False: the BatchNorm2d that follows (affine by default) supplies the shift.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_channels)

    def forward(self, z):
        """Return ``bn(conv(z))``."""
        return self.bn(self.conv(z))

class Decoder(nn.Module):
    """Decoder stage: 2x bilinear upsample, optional skip concat, two conv-BN-ReLU layers, SCSE gate.

    Args:
        in_channels: channels entering ``conv1`` (after the optional concat with ``e``).
        channels: channels between the two convolutions.
        out_channels: channels produced by the block.
    """

    def __init__(self, in_channels, channels, out_channels):
        super(Decoder, self).__init__()
        self.conv1 =  ConvBn2d(in_channels,  channels, kernel_size=3, padding=1)
        self.conv2 =  ConvBn2d(channels, out_channels, kernel_size=3, padding=1)
        self.spa_cha_gate = SCSE(out_channels)

    def forward(self, x, e=None):
        """Upsample ``x`` by 2, concatenate ``e`` on the channel axis when given, then refine.

        Args:
            x: feature map to upsample.
            e: optional skip feature map; it is concatenated after the upsample,
               so it must match the upsampled ``x`` in every non-channel dimension.
        """
        # F.upsample is deprecated; F.interpolate is its replacement and takes the same arguments.
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)#False
        if e is not None:
            x = torch.cat([x, e], 1)
        x = F.relu(self.conv1(x),inplace=True)
        x = F.relu(self.conv2(x),inplace=True)
        x = self.spa_cha_gate(x)
        return x

class SCSE(nn.Module):
    """Sum of two gated copies of the input.

    One copy is produced by ``SpatialGate2d`` (built with reduction 16) and
    the other by ``ChannelGate2d``; the two results are added element-wise.
    """

    def __init__(self, in_ch):
        super().__init__()
        self.spatial_gate = SpatialGate2d(in_ch, 16)
        self.channel_gate = ChannelGate2d(in_ch)

    def forward(self, x):
        """Return ``spatial_gate(x) + channel_gate(x)``."""
        # Each gate already multiplies its mask into x, so this is g1*x + g2*x.
        return self.spatial_gate(x) + self.channel_gate(x)

class SpatialGate2d(nn.Module):
    """Per-channel gate computed from the spatial mean of each channel.

    The last two dimensions are averaged away, the result goes through
    Linear -> ReLU -> Linear -> sigmoid, and the resulting one-weight-per-channel
    mask is multiplied back into the input. (Despite the class name, the mask
    varies across channels, not across spatial positions.)

    Args:
        in_ch: number of input channels.
        r: reduction factor; the hidden layer has ``in_ch // r`` units.
    """

    def __init__(self, in_ch, r=16):
        super(SpatialGate2d, self).__init__()

        self.linear_1 = nn.Linear(in_ch, in_ch//r)
        self.linear_2 = nn.Linear(in_ch//r, in_ch)

    def forward(self, x):
        """Return ``x`` scaled channel-wise by the learned gate (same shape as ``x``)."""
        input_x = x

        # reshape (rather than view) so non-contiguous inputs, e.g. the result
        # of a transpose/permute, are accepted instead of raising RuntimeError.
        x = x.reshape(*(x.shape[:-2]), -1).mean(-1)
        x = F.relu(self.linear_1(x), inplace=True)
        x = self.linear_2(x)
        # Trailing singleton dims let the gate broadcast over the last two dims.
        x = torch.sigmoid(x).unsqueeze(-1).unsqueeze(-1)

        return input_x * x

class ChannelGate2d(nn.Module):
    """Per-position gate from a 1x1 convolution that collapses the channels to one map.

    The single-channel map is passed through a sigmoid and multiplied into
    the input, broadcasting across channels. (Despite the class name, the
    mask varies across spatial positions, not across channels.)
    """

    def __init__(self, in_ch):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, 1, kernel_size=1, stride=1)

    def forward(self, x):
        """Return ``x * sigmoid(conv(x))`` (same shape as ``x``)."""
        mask = torch.sigmoid(self.conv(x))
        return x * mask

class UNetResNet34(nn.Module):
    """ResNet-34 backbone with SCSE gates and a single-logit classification head.

    Despite the "UNet" name, this variant builds no decoder: ``forward``
    returns one logit per sample.
    """

    def load_pretrain(self, pretrain_file):
        """Load a state dict from ``pretrain_file`` into the ResNet backbone.

        Tensors are mapped to CPU storage while loading.
        """
        # Fix: the backbone is stored as ``self.resnet``; this class never
        # defines ``self.encoder``, so the previous lookup raised AttributeError.
        state_dict = torch.load(pretrain_file, map_location=lambda storage, loc: storage)
        self.resnet.load_state_dict(state_dict)

    def __init__(self, pretrained=True, debug=False):
        """Build the backbone and head.

        Args:
            pretrained: forwarded to ``torchvision.models.resnet34``.
            debug: when true, ``forward`` prints the size of each intermediate tensor.
        """
        super().__init__()
        self.resnet = torchvision.models.resnet34(pretrained=pretrained)
        self.debug = debug

        # ResNet stem; the max-pool is deliberately left out.
        self.conv1 = nn.Sequential(
            self.resnet.conv1,
            self.resnet.bn1,
            self.resnet.relu,
            #self.resnet.maxpool,
        )# 64
        self.encoder2 = nn.Sequential(self.resnet.layer1, SCSE(64))
        self.encoder3 = nn.Sequential(self.resnet.layer2, SCSE(128))
        self.encoder4 = nn.Sequential(self.resnet.layer3, SCSE(256))
        self.encoder5 = nn.Sequential(self.resnet.layer4, SCSE(512))

        self.avgpool = nn.AvgPool2d(7, stride=1)
        # 51200 = 512 * 10 * 10: the flattened size is fixed, so the head only
        # fits inputs whose pooled feature map is 512 x 10 x 10.
        self.fc = nn.Linear(51200, 1)

    def _print_size(self, tag, tensor):
        """Print ``tag`` and ``tensor.size()`` when debug output is enabled."""
        if self.debug:
            print(tag, tensor.size())

    def forward(self, x):
        """Return the classification logit for each sample, shape ``(batch, 1)``."""
        # Three differently normalised copies of x are stacked on the channel
        # axis using the ImageNet mean/std.
        # NOTE(review): this presumably expects a single-channel (N, 1, H, W) input -- confirm.
        mean=[0.485, 0.456, 0.406]
        std =[0.229, 0.224, 0.225]
        x = torch.cat([
            (x-mean[0])/std[0],
            (x-mean[1])/std[1],
            (x-mean[2])/std[2],
        ],1)
        #x = add_depth_channels(x)
        self._print_size('input: ', x)

        x = self.conv1(x)
        self._print_size('e1', x)
        e2 = self.encoder2(x)
        self._print_size('e2', e2)
        e3 = self.encoder3(e2)
        self._print_size('e3', e3)
        e4 = self.encoder4(e3)
        self._print_size('e4', e4)
        e5 = self.encoder5(e4)
        self._print_size('e5', e5)

        f = self.avgpool(e5)
        self._print_size('avgpool: ', f)
        # Fix: F.dropout defaults to training=True, which kept dropout active in
        # eval/test mode and made predictions random. Follow the module's mode.
        f = F.dropout(f, p=0.40, training=self.training)

        f = f.view(f.size(0), -1)
        logit = self.fc(f)
        self._print_size('fc: ', logit)
        return logit

    def criterion(self, logit, truth):
        """Define the (customized) loss function here."""
        loss = L.binary_xloss(logit, truth, ignore=255)
        return loss

    def metric(self, logit, truth):
        """Define metrics for evaluation especially for early stopping.

        Returns the ROC AUC of ``logit`` against ``truth``.
        """
        # Fix: sklearn needs host-side arrays; move tensors off the GPU first
        # so CUDA inputs no longer fail.
        y_true = truth.detach().cpu().numpy()
        y_score = logit.detach().cpu().numpy()
        auc = roc_auc_score(y_true, y_score)
        return auc

    def set_mode(self, mode):
        """Record ``mode`` and switch between eval and train behaviour.

        Raises:
            NotImplementedError: if ``mode`` is not one of
                'eval', 'valid', 'test' or 'train'.
        """
        self.mode = mode
        if mode in ['eval', 'valid', 'test']:
            self.eval()
        elif mode in ['train']:
            self.train()
        else:
            raise NotImplementedError


def predict_proba(net, test_dl, device):
    """Run ``net`` over every batch of ``test_dl`` and stack the outputs.

    Note that the values returned are whatever ``net`` outputs (no sigmoid is
    applied here), so with a logit-producing model these are raw logits.

    Args:
        net: model exposing ``set_mode``; it is switched to 'test' mode.
        test_dl: iterable yielding ``(input, truth)`` pairs; ``truth`` is ignored.
        device: device the inputs are moved to before the forward pass.

    Returns:
        numpy.ndarray of all batch outputs concatenated along axis 0, or
        ``None`` when ``test_dl`` yields no batches.
    """
    net.set_mode('test')
    batch_outputs = []
    with torch.no_grad():
        # The second element of each pair is not needed for prediction, so it
        # is no longer copied to the device.
        for input_data, _truth in test_dl:
            input_data = input_data.to(device=device, dtype=torch.float)
            batch_outputs.append(net(input_data).cpu().numpy())
    if not batch_outputs:
        return None
    # One concatenation at the end instead of re-copying the growing array per batch.
    return np.concatenate(batch_outputs, axis=0)


def add_depth_channels(image_tensor):
    """Append a depth tensor and ``image * depth`` along the channel axis.

    A tensor of ones with the same size as ``image_tensor`` is created and its
    channel 0 is overwritten with a top-to-bottom ramp from 0 to 1 (one value
    per row, constant across columns). Any further channels stay at 1.

    Args:
        image_tensor: 4-D tensor; dimension 2 is treated as the row axis.

    Returns:
        ``cat([image, depth, image * depth], dim=1)`` -- three times as many
        channels as the input.
    """
    _, _, height, _ = image_tensor.size()
    target = image_tensor.device

    depth = torch.ones(image_tensor.size(), dtype=torch.float64, device=target)
    # One broadcast assignment replaces the per-row Python loop; the ramp is
    # still computed by numpy in float64.
    ramp = torch.from_numpy(np.linspace(0, 1, height)).to(target)
    depth[:, 0] = ramp.view(1, height, 1)
    depth = depth.float()

    return torch.cat([image_tensor, depth, image_tensor * depth], 1)



In [4]:
# Build the model with debug=True (forward() then prints every intermediate
# tensor size) and move it onto the configured CUDA device.
net = UNetResNet34(debug=True).cuda(device)

In [5]:
# new decoder version
# Smoke test: one forward pass over the sampled batch; the sizes printed below
# come from the model's debug output.
output = net(x)

input:  torch.Size([16, 3, 256, 256])
e1 torch.Size([16, 64, 128, 128])
e2 torch.Size([16, 64, 128, 128])
e3 torch.Size([16, 128, 64, 64])
e4 torch.Size([16, 256, 32, 32])
e5 torch.Size([16, 512, 16, 16])
avgpool:  torch.Size([16, 512, 10, 10])
fc:  torch.Size([16, 1])


In [6]:
# Loss of the untrained head on the sampled batch (y is the zero-mask label built earlier).
net.criterion(output, y)

tensor(0.6579, device='cuda:1')

In [6]:
#layer_name = 'decoder'
#layer_name = 'resnet'
#layer_name = 'center'
#layer_name = '.'

#layer_name = '.'

# Parameter counts, grouped by a substring of the parameter name.
# '.' is used for the total: parameters registered through submodules always
# carry a dotted name, so it is expected to match all of them here.
for label, needle in (
    ('Total parameters: ', '.'),
    ('Encoder parameters: ', 'resnet'),
    ('Center parameters: ', 'center'),
    ('Decoder parameters: ', 'decoder'),
):
    print(label, sum(tensor.numel() for name, tensor in net.named_parameters() if needle in name))

#print('Trainable parameters: ', sum(p[1].numel() for p in net.named_parameters() \
#                                    if p[1].requires_grad and layer_name in p[0]))

Total parameters:  21894377
Encoder parameters:  21797672
Center parameters:  0
Decoder parameters:  0


In [8]:
# Print the name of every registered parameter, one per line.
for name, _tensor in net.named_parameters():
    print(name)

resnet.conv1.weight
resnet.bn1.weight
resnet.bn1.bias
resnet.layer1.0.conv1.weight
resnet.layer1.0.bn1.weight
resnet.layer1.0.bn1.bias
resnet.layer1.0.conv2.weight
resnet.layer1.0.bn2.weight
resnet.layer1.0.bn2.bias
resnet.layer1.1.conv1.weight
resnet.layer1.1.bn1.weight
resnet.layer1.1.bn1.bias
resnet.layer1.1.conv2.weight
resnet.layer1.1.bn2.weight
resnet.layer1.1.bn2.bias
resnet.layer1.2.conv1.weight
resnet.layer1.2.bn1.weight
resnet.layer1.2.bn1.bias
resnet.layer1.2.conv2.weight
resnet.layer1.2.bn2.weight
resnet.layer1.2.bn2.bias
resnet.layer2.0.conv1.weight
resnet.layer2.0.bn1.weight
resnet.layer2.0.bn1.bias
resnet.layer2.0.conv2.weight
resnet.layer2.0.bn2.weight
resnet.layer2.0.bn2.bias
resnet.layer2.0.downsample.0.weight
resnet.layer2.0.downsample.1.weight
resnet.layer2.0.downsample.1.bias
resnet.layer2.1.conv1.weight
resnet.layer2.1.bn1.weight
resnet.layer2.1.bn1.bias
resnet.layer2.1.conv2.weight
resnet.layer2.1.bn2.weight
resnet.layer2.1.bn2.bias
resnet.layer2.2.conv1.weight
r

In [8]:
# `layer_name` was only ever assigned in commented-out lines, so this cell
# raised NameError on a fresh kernel. Define it here; '.' matches every dotted
# parameter name, so the first entry stays the first registered parameter.
layer_name = '.'
layers = [p for p in net.named_parameters() if layer_name in p[0]]

In [9]:
# Each entry is a (name, parameter) pair; index 1 is the parameter tensor itself.
layer = layers[0]
layer[1].shape

torch.Size([64, 3, 7, 7])

In [None]:
# Grab the first training batch again; here y is kept as the float mask
# (no zero-mask label conversion is applied in this cell).
for i, batch in enumerate(train_dl):
    x, y = (t.to(device=device, dtype=torch.float) for t in batch)
    break

In [None]:
# Forward pass on the freshly fetched batch.
output = net(x)

In [None]:
## freeze layers parameters

# Freeze every parameter whose name starts with one of these 8-character prefixes.
frozen_prefixes = ('decoder5',)  # 'decoder5', 'decoder4', 'decoder3', 'decoder2'
for name, tensor in net.named_parameters():
    if name[:8] in frozen_prefixes:
        tensor.requires_grad = False

In [None]:
## different version : nn.AdaptiveAvgPool2d layer
class SCSE(nn.Module):
    """Squeeze-and-excitation block with a channel branch and a spatial branch.

    Channel branch: global average pool -> Linear -> ReLU -> Linear -> Sigmoid,
    giving one weight per channel. Spatial branch: bias-free 1x1 convolution
    to a single map -> Sigmoid, giving one weight per position. Each mask is
    multiplied into the input and the two results are summed.

    Args:
        channel: number of input channels.
        reduction: the channel branch's hidden width is ``channel // reduction``.
    """

    def __init__(self, channel, reduction=16):
        super(SCSE, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)

        self.channel_excitation = nn.Sequential(nn.Linear(channel, int(channel//reduction)),
                                                nn.ReLU(inplace=True),
                                                nn.Linear(int(channel//reduction), channel),
                                                nn.Sigmoid())

        self.spatial_se = nn.Sequential(nn.Conv2d(channel, 1, kernel_size=1,
                                                  stride=1, padding=0, bias=False),
                                        nn.Sigmoid())

    def forward(self, x):
        """Return the sum of the channel-gated and spatially-gated input (same shape as ``x``)."""
        bahs, chs, _, _ = x.size()

        # Squeeze to (batch, channels), excite, then restore two singleton
        # dims so the weights broadcast over the spatial positions.
        chn_se = self.avg_pool(x).view(bahs, chs)
        chn_se = self.channel_excitation(chn_se).view(bahs, chs, 1, 1)
        chn_se = torch.mul(x, chn_se)

        spa_se = self.spatial_se(x)
        spa_se = torch.mul(x, spa_se)
        # torch.add(a, 1, b) is the deprecated positional-alpha overload;
        # with alpha == 1 it is just a + b.
        return chn_se + spa_se

In [44]:
## neptune's open solutions

class ConvBnRelu(nn.Module):
    """3x3 convolution (padding 1) -> BatchNorm2d -> in-place ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        # Kept under the attribute name ``conv`` so state-dict keys stay conv.0.*, conv.1.*.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the three stages in order."""
        return self.conv(x)

class DecoderBlockV2(nn.Module):
    """2x upsampling decoder block with two interchangeable branches.

    Both branches are always constructed; ``is_deconv`` only selects which one
    ``forward`` runs:

    * deconv: ConvBnRelu(in -> middle), then a stride-2 transposed convolution
      (middle -> out) with BatchNorm and ReLU.
    * upsample: ConvBnRelu(in -> out), then bilinear ``nn.Upsample`` by 2
      (``middle_channels`` is not used on this path).
    """

    def __init__(self, in_channels, middle_channels, out_channels, is_deconv=True):
        super().__init__()
        self.is_deconv = is_deconv

        deconv_stages = [
            ConvBnRelu(in_channels, middle_channels),
            nn.ConvTranspose2d(middle_channels, out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.deconv = nn.Sequential(*deconv_stages)

        upsample_stages = [
            ConvBnRelu(in_channels, out_channels),
            nn.Upsample(scale_factor=2, mode='bilinear'),
        ]
        self.upsample = nn.Sequential(*upsample_stages)

    def forward(self, x):
        """Run the branch selected by ``is_deconv``."""
        branch = self.deconv if self.is_deconv else self.upsample
        return branch(x)

class SaltLinkNet(nn.Module):
    """Shallow LinkNet-style segmentation net on the first two ResNet-34 stages.

    Block outputs within each stage are summed, decoded by two
    ``DecoderBlockV2`` stages and projected to ``num_classes`` channels by a
    1x1 convolution.

    Args:
        num_classes: output channels of the final 1x1 convolution.
        dropout_2d: channel-dropout probability applied before the final convolution.
        pretrained: forwarded to ``torchvision.models.resnet34``.
        is_deconv: forwarded to both decoder blocks.
    """

    def __init__(self, num_classes, dropout_2d=0.2, pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d

        self.encoder = torchvision.models.resnet34(pretrained=pretrained)

        self.relu = nn.ReLU(inplace=True)

        # ResNet stem without the max-pool.
        self.input_adjust = nn.Sequential(self.encoder.conv1,
                                          self.encoder.bn1,
                                          self.encoder.relu)

        layer1_blocks = list(self.encoder.layer1.children())
        layer2_blocks = list(self.encoder.layer2.children())

        # NOTE(review): layer1's block 0 is not used in forward -- only blocks
        # 1 and 2 are wired in. Confirm this is intentional.
        self.conv1_1 = layer1_blocks[1]
        self.conv1_2 = layer1_blocks[2]

        self.conv2_0 = layer2_blocks[0]
        self.conv2_1 = layer2_blocks[1]
        self.conv2_2 = layer2_blocks[2]
        self.conv2_3 = layer2_blocks[3]

        self.dec2 = DecoderBlockV2(128, 256, 256, is_deconv=is_deconv)
        # 256 decoder channels + 64 skip channels from the layer1 sum.
        self.dec1 = DecoderBlockV2(256 + 64, 512, 256, is_deconv=is_deconv)
        self.final = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, x):
        """Return the ``num_classes``-channel output map for ``x``."""
        # Three differently normalised copies of x are stacked on the channel
        # axis using the ImageNet mean/std.
        # NOTE(review): this presumably expects a single-channel input -- confirm.
        mean=[0.485, 0.456, 0.406]
        std =[0.229, 0.224, 0.225]
        x = torch.cat([
            (x-mean[0])/std[0],
            (x-mean[1])/std[1],
            (x-mean[2])/std[2],
        ],1)
        print('input shape: ', x.size())

        input_adjust = self.input_adjust(x)
        conv1_1 = self.conv1_1(input_adjust)
        conv1_2 = self.conv1_2(conv1_1)
        conv2_0 = self.conv2_0(conv1_2)
        conv2_1 = self.conv2_1(conv2_0)
        conv2_2 = self.conv2_2(conv2_1)
        conv2_3 = self.conv2_3(conv2_2)

        # Sum the block outputs within each stage.
        conv1_sum = conv1_1 + conv1_2
        conv2_sum = conv2_0 + conv2_1 + conv2_2 + conv2_3

        dec2 = self.dec2(conv2_sum)
        dec1 = self.dec1(torch.cat([dec2, conv1_sum], 1))

        # Fix: F.dropout2d defaults to training=True, which kept dropout active
        # in eval mode and made inference random. Follow the module's mode.
        dropped = F.dropout2d(dec1, p=self.dropout_2d, training=self.training)
        return self.final(dropped)