In [1]:
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import time
import torch.nn.init as ini
import h5py
import random

class Unet(nn.Module):
    def __init__(self):
        super(Unet, self).__init__()
        self.first_layer_down_conv1 = nn.Conv3d(4, 8, 3, padding = 1)
        self.first_layer_down_bn1 = nn.BatchNorm3d(8)
        self.first_layer_down_pre1 = nn.PReLU()
        self.second_layer_down_conv1 = nn.Conv3d(8, 16, 3, padding = 1, stride = 2)
        self.second_layer_down_bn1 = nn.BatchNorm3d(16)
        self.second_layer_down_pre1 = nn.PReLU()
        self.second_layer_down_conv2 = nn.Conv3d(16, 16, 3, padding = 1)
        self.second_layer_down_bn2 = nn.BatchNorm3d(16)
        self.second_layer_down_pre2 = nn.PReLU()
        self.third_layer_down_conv1 = nn.Conv3d(16, 32, 3, padding = 1, stride = 2)
        self.third_layer_down_bn1 = nn.BatchNorm3d(32)
        self.third_layer_down_pre1 = nn.PReLU()
        self.third_layer_down_conv2 = nn.Conv3d(32, 32, 3, padding = 1)
        self.third_layer_down_bn2 = nn.BatchNorm3d(32)
        self.third_layer_down_pre2 = nn.PReLU()
        self.fourth_layer_down_conv1 = nn.Conv3d(32, 64, 3, padding = 1, stride = 2)
        self.fourth_layer_down_bn1 = nn.BatchNorm3d(64)
        self.fourth_layer_down_pre1 = nn.PReLU()
        self.fourth_layer_down_conv2 = nn.Conv3d(64, 64, 3, padding = 1)
        self.fourth_layer_up_conv1 = nn.Conv3d(64, 64, 1)
        self.fourth_layer_up_bn1 = nn.BatchNorm3d(64)
        self.fourth_layer_up_pre1 = nn.PReLU()
        self.fourth_layer_up_deconv = nn.ConvTranspose3d(64, 32, 3, padding = 1, output_padding = 1, stride = 2)
        self.fourth_layer_up_bn2 = nn.BatchNorm3d(32)
        self.fourth_layer_up_pre2 = nn.PReLU()
        self.third_layer_up_conv1 = nn.Conv3d(64, 64, 3, padding = 1)
        self.third_layer_up_bn1 = nn.BatchNorm3d(64)
        self.third_layer_up_pre1 = nn.PReLU()
        self.third_layer_up_conv2 = nn.Conv3d(64, 32, 1)
        self.third_layer_up_bn2 = nn.BatchNorm3d(32)
        self.third_layer_up_pre2 = nn.PReLU()
        self.third_layer_up_deconv = nn.ConvTranspose3d(32, 16, 3, padding = 1, output_padding = 1, stride = 2)
        self.third_layer_up_bn3 = nn.BatchNorm3d(16)
        self.third_layer_up_pre3 = nn.PReLU()
        self.second_layer_up_conv1 = nn.Conv3d(32, 32, 3, padding = 1)
        self.second_layer_up_bn1 = nn.BatchNorm3d(32)
        self.second_layer_up_pre1 = nn.PReLU()
        self.second_layer_up_conv2 = nn.Conv3d(32, 16, 1)
        self.second_layer_up_bn2 = nn.BatchNorm3d(16)
        self.second_layer_up_pre2 = nn.PReLU()
        self.second_layer_up_deconv = nn.ConvTranspose3d(16, 8, 3, padding = 1, output_padding = 1, stride = 2)
        self.second_layer_up_bn3 = nn.BatchNorm3d(8)
        self.second_layer_up_pre3 = nn.PReLU()
        self.first_layer_up_conv1 = nn.Conv3d(16, 16, 3, padding = 1)
        self.first_layer_up_bn1 = nn.BatchNorm3d(16)
        self.first_layer_up_pre1 = nn.PReLU()
        self.third_seg = nn.Conv3d(64, 5, 1)
        self.second_seg = nn.Conv3d(32, 5, 1)
        self.first_seg = nn.Conv3d(16, 5, 1)
        self.upsample_layer = nn.Upsample(scale_factor = 2, mode = 'trilinear')

    def forward(self, x):
        x = self.first_layer_down_conv1(x)
        x = self.first_layer_down_bn1(x)
        x = self.first_layer_down_pre1(x)
        first_layer_feature = x
        
        x = self.second_layer_down_conv1(x)
        temp = x
        x = self.second_layer_down_bn1(x)
        x = self.second_layer_down_pre1(x)
        x = self.second_layer_down_conv2(x)
        x = torch.add(x, temp)
        x = self.second_layer_down_bn2(x)
        x = self.second_layer_down_pre2(x)
        second_layer_feature = x
        
        x = self.third_layer_down_conv1(x)
        temp = x
        x = self.third_layer_down_bn1(x)
        x = self.third_layer_down_pre1(x)
        x = self.third_layer_down_conv2(x)
        x = torch.add(x, temp)
        x = self.third_layer_down_bn2(x)
        x = self.third_layer_down_pre2(x)
        third_layer_feature = x
        
        x = self.fourth_layer_down_conv1(x)
        temp = x
        x = self.fourth_layer_down_bn1(x)
        x = self.fourth_layer_down_pre1(x)
        x = self.fourth_layer_down_conv2(x)
        x = torch.add(x, temp)
        
        x = self.fourth_layer_up_conv1(x)
        x = self.fourth_layer_up_bn1(x)
        x = self.fourth_layer_up_pre1(x)
        x = self.fourth_layer_up_deconv(x)
        x = self.fourth_layer_up_bn2(x)
        x = self.fourth_layer_up_pre2(x)
        
        x = torch.cat((x, third_layer_feature), 1)
        x = self.third_layer_up_conv1(x)
        x = self.third_layer_up_bn1(x)
        x = self.third_layer_up_pre1(x)
        third_seg_map = self.third_seg(x)
        x = self.third_layer_up_conv2(x)
        x = self.third_layer_up_bn2(x)
        x = self.third_layer_up_pre2(x)
        x = self.third_layer_up_deconv(x)
        x = self.third_layer_up_bn3(x)
        x = self.third_layer_up_pre3(x)
        
        x = torch.cat((x, second_layer_feature), 1)
        x = self.second_layer_up_conv1(x)
        x = self.second_layer_up_bn1(x)
        x = self.second_layer_up_pre1(x)
        second_seg_map = self.second_seg(x)
        x = self.second_layer_up_conv2(x)
        x = self.second_layer_up_bn2(x)
        x = self.second_layer_up_pre2(x)
        x = self.second_layer_up_deconv(x)
        x = self.second_layer_up_bn3(x)
        x = self.second_layer_up_pre3(x)
        
        x = torch.cat((x, first_layer_feature), 1)
        x = self.first_layer_up_conv1(x)
        x = self.first_layer_up_bn1(x)
        x = self.first_layer_up_pre1(x)
        first_seg_map = self.first_seg(x)
        
        third_seg_map = self.upsample_layer(third_seg_map)
        second_seg_map = torch.add(third_seg_map, second_seg_map)
        second_seg_map = self.upsample_layer(second_seg_map)
        x = torch.add(first_seg_map, second_seg_map)
        return x
        
net = Unet()
prev_time = time.clock()

for param in net.parameters():
    try:
        nout = param.size()[0]
        nin = param.size()[1]
        ini.normal(param.data, mean = 0, std = 0.01)
        param = param / ((2/(nin+nout))**0.5)
    except:
        pass



In [2]:
f = h5py.File('Unet-training.h5')
SAMPLE = [ "LG/0001", "LG/0002", "LG/0004", "LG/0006", "LG/0008", "LG/0011",
          "LG/0012", "LG/0013", "LG/0014", "LG/0015", "HG/0001", "HG/0002",
          "HG/0003", "HG/0004", "HG/0005", "HG/0006", "HG/0007", "HG/0008",
          "HG/0009", "HG/0010", "HG/0011", "HG/0012", "HG/0013", "HG/0014",
          "HG/0015", "HG/0022", "HG/0024", "HG/0025", "HG/0026"]

def create_train_batch(img = 0):
    case = SAMPLE[img]
    key0 = case[:2]
    key1 = case[3:]
    _, X, Y, Z = f[key0][key1].shape
    train_batch = [];
    train_label = [];
    x = random.randint(64, X-64)
    y = random.randint(64, Y-64)
    z = random.randint(48, Z-48)
    train_batch.append(f[key0][key1][0:4,x-64:x+64,y-64:y+64,z-48:z+48])
    train_label.append(f[key0][key1][4,x-64:x+64,y-64:y+64,z-48:z+48])
    train_batch = np.array(train_batch)
    train_label = np.array(train_label)
    train_batch = torch.from_numpy(train_batch)
    train_label = torch.from_numpy(train_label)
    train_label = torch.Tensor.long(train_label)
    return train_batch, train_label

def create_val():
    case = "HG/0001"
    key0 = case[:2]
    key1 = case[3:]
    _, X, Y, Z = f[key0][key1].shape
    val_batch = [];
    val_label = [];
    x = X//2
    y = Y//2
    z = Z//2
    val_batch.append(f[key0][key1][0:4,x-64:x+64,y-64:y+64,z-48:z+48])
    val_label.append(f[key0][key1][4,x-64:x+64,y-64:y+64,z-48:z+48])
    val_batch = np.array(val_batch)
    val_label = np.array(val_label)
    val_batch = torch.from_numpy(val_batch)
    val_label = torch.from_numpy(val_label)
    val_label = torch.Tensor.long(val_label)
    return val_batch, val_label

val_x, val_y = create_val()
val_x = Variable(val_x, requires_grad=True)
val_y = val_y.view(-1)
y_pred = net.forward(val_x)

In [3]:
loss=Variable(torch.Tensor([0]), requires_grad=True)
p0=Variable(torch.Tensor([0,1,0,1,0]), requires_grad=True)
t0=Variable(torch.Tensor([0,1,0,1,0]), requires_grad=True)
loss=1-torch.norm(torch.mul(p0,t0))/(torch.norm(p0)**2+torch.norm(t0)**2-torch.norm(torch.mul(p0,t0)))
loss.backward()
print(p0.grad)

Variable containing:
1.00000e-08 *
  0.0000
  2.9802
  0.0000
  2.9802
  0.0000
[torch.FloatTensor of size 5]



In [4]:
smooth=1e-100

def dice_coef(y_true, y_pred):
    y_true_f = y_true.view(-1)
    y_pred_f = y_pred.view(-1)
    intersection = torch.sum(y_true_f * y_pred_f)
    return (2.0 * intersection + smooth) / (torch.sum(y_true_f*y_true_f) + torch.sum(y_pred_f*y_pred_f) + smooth)


def dice_coef_loss(y_true, y_pred):
    return 1.-dice_coef(y_true, y_pred)

def jacc_loss(out, val_batch):
    for i in range(5):
    #transform for the prediction
        pp=(out-(i-1))
        pp=F.relu(pp)
        pn=(-1*out+(i+1))
        pn = F.relu(pn)
        p=torch.mul(pp, pn)+1e-10
        #transfrom for the truth
        tp=(val_batch-(i-1))
        tp=F.relu(tp)
        tn=(-1*val_batch+(i+1))
        tn=F.relu(tn)
        t=torch.mul(tp, tn)+1e-10
        up=torch.dot(p, t)
        p=torch.norm(p)
        p=torch.pow(p, 2)
        t=torch.norm(t)
        t=torch.pow(t, 2)
        #compute Jaccard
        jacc=up/(p+t-up)
        print(jacc)
        if i==0:
            loss=1-jacc
        else:
            loss=loss+1-jacc
    return loss

In [5]:
out=Variable(torch.Tensor([0,2,0,4,1]), requires_grad=True).float()
val_batch=Variable(torch.Tensor([0,2,0,0,1]), requires_grad=True).float()
loss=jacc_loss(out, val_batch)
loss.backward()
print(out.grad)

Variable containing:
 0.6667
[torch.FloatTensor of size 1]

Variable containing:
 1
[torch.FloatTensor of size 1]

Variable containing:
 1
[torch.FloatTensor of size 1]

Variable containing:
 1
[torch.FloatTensor of size 1]

Variable containing:
1.00000e-10 *
  1.0000
[torch.FloatTensor of size 1]

Variable containing:
 0
 0
 0
 0
 0
[torch.FloatTensor of size 5]



In [2]:
smooth=1e-3
class JaccLoss(nn.Module):
    def __init__(self, size_average=True):
        super(JaccLoss, self).__init__()
        self.size_average=size_average
    def forward(self, out, val_batch):
        pp00=out-(0.0-1.0)
        pp0=F.relu(pp00)
        pn00=(-1.0*out+(0.0+1.0))
        pn0=F.relu(pn00)
        p0=torch.mul(pp0,pn0)
        pp00.register_hook(lambda g: print(g, "pp00"))
        pp0.register_hook(lambda g: print(g, "pp0"))
        pn00.register_hook(lambda g: print(g, "pn00"))
        pn0.register_hook(lambda g: print(g, "pn0"))
        tp0=val_batch-(0.0-1.0)
        tp0=F.relu(tp0)
        tn0=-1.0*val_batch+(0.0+1.0)
        tn0=F.relu(tn0)
        t0=torch.mul(tp0, tn0)
        intersection0 = torch.sum(t0 * p0)
        jacc0=(intersection0) / (torch.sum(t0*t0) + torch.sum(p0*p0) - intersection0)
        loss=1.0-jacc0
        
        pp11=out-(1.0-1.0)
        pp1=F.relu(pp11)
        pn11=(-1.0*out+(1.0+1.0))
        pn1=F.relu(pn11)
        p1=torch.mul(pp1,pn1)
        pp11.register_hook(lambda g: print(g, "pp11"))
        pp1.register_hook(lambda g: print(g, "pp1"))
        pn11.register_hook(lambda g: print(g, "pn11"))
        pn1.register_hook(lambda g: print(g,"pn1"))
        tp1=val_batch-(1.0-1.0)
        tp1=F.relu(tp1)
        tn1=-1.0*val_batch+(1.0+1.0)
        tn1=F.relu(tn1)
        t1=torch.mul(tp1, tn1)
        intersection1 = torch.sum(t1 * p1)
        jacc1=(intersection1) / (torch.sum(t1*t1) + torch.sum(p1*p1) - intersection1)
        loss=loss+1.0-jacc1
        
        pp22=out-(2.0-1.0)
        pp2=F.relu(pp22)
        pn22=(-1.0*out+(2.0+1.0))
        pn2=F.relu(pn22)
        p2=torch.mul(pp2,pn2)
        pp22.register_hook(lambda g: print(g, "pp22"))
        pp2.register_hook(lambda g: print(g, "pp2"))
        pn22.register_hook(lambda g: print(g, "pn22"))
        pn2.register_hook(lambda g: print(g, "pn2"))
        tp2=val_batch-(2.0-1.0)
        tp2=F.relu(tp2)
        tn2=-1.0*val_batch+(2.0+1.0)
        tn2=F.relu(tn2)
        t2=torch.mul(tp2, tn2)
        intersection2 = torch.sum(t2 * p2)
        jacc2=(intersection2) / (torch.sum(t2*t2) + torch.sum(p2*p2) -intersection2)
        loss=loss+1.0-jacc2
        
        pp33=out-(3.0-1.0)
        pp3=F.relu(pp33)
        pn33=(-1.0*out+(3.0+1.0))
        pn3=F.relu(pn33)
        p3=torch.mul(pp3,pn3)
        pp33.register_hook(lambda g: print(g, "pp33"))
        pp3.register_hook(lambda g: print(g, "pp3"))
        pn33.register_hook(lambda g: print(g, "pn33"))
        pn3.register_hook(lambda g: print(g, "pn3"))
        tp3=val_batch-(3.0-1.0)
        tp3=F.relu(tp3)
        tn3=-1.0*val_batch+(3.0+1.0)
        tn3=F.relu(tn3)
        t3=torch.mul(tp3, tn3)
        intersection3 = torch.sum(t3 * p3)
        jacc3=(intersection3) / (torch.sum(t3*t3) + torch.sum(p3*p3) - intersection3)
        loss=loss+1.0-jacc3
        
        pp44=out-(4.0-1)
        pp4=F.relu(pp44)
        pn44=(-1.0*out+(4.0+1.0))
        pn4=F.relu(pn44)
        p4=torch.mul(pp4,pn4)
        pp44.register_hook(lambda g: print(g, "pp44"))
        pp4.register_hook(lambda g: print(g, "pp4"))
        pn44.register_hook(lambda g: print(g, "pn44"))
        pn4.register_hook(lambda g: print(g, "pn4"))
        tp4=val_batch-(4.0-1.0)
        tp4=F.relu(tp4)
        tn4=-1.0*val_batch+(4.0+1.0)
        tn4=F.relu(tn4)
        t4=torch.mul(tp4, tn4)
        intersection4 = torch.sum(t4 * p4)
        jacc4=(intersection4) / (torch.sum(t4*t4) + torch.sum(p4*p4) -intersection4)
        loss=loss+1.0-jacc4
        
        return loss

In [3]:
out=Variable(torch.Tensor([0,0,1,0,4,1,0]).cuda(3), requires_grad=True).float()
val_batch=Variable(torch.Tensor([0,0,1,2,3,4,0]).cuda(3), requires_grad=True).float()
cri = JaccLoss()
loss=cri(out, val_batch)
print(loss)
loss.backward()

Variable containing:
 3.7500
[torch.cuda.FloatTensor of size 1 (GPU 3)]

Variable containing:
 0
 0
 0
 0
 0
-2
 0
[torch.cuda.FloatTensor of size 7 (GPU 3)]
 pp4
Variable containing:
 0
 0
 0
 0
 0
-0
 0
[torch.cuda.FloatTensor of size 7 (GPU 3)]
 pn4
Variable containing:
 0
 0
 0
 0
 0
 0
 0
[torch.cuda.FloatTensor of size 7 (GPU 3)]
 pp44
Variable containing:
 0
 0
 0
 0
 0
-0
 0
[torch.cuda.FloatTensor of size 7 (GPU 3)]
 pn44
Variable containing:
 0
 0
 0
 0
-0
 0
 0
[torch.cuda.FloatTensor of size 7 (GPU 3)]
 pp3
Variable containing:
 0
 0
 0
 0
-2
 0
 0
[torch.cuda.FloatTensor of size 7 (GPU 3)]
 pn3
Variable containing:
 0
 0
 0
 0
-0
 0
 0
[torch.cuda.FloatTensor of size 7 (GPU 3)]
 pp33
Variable containing:
 0
 0
 0
 0
 0
 0
 0
[torch.cuda.FloatTensor of size 7 (GPU 3)]
 pn33
Variable containing:
 0
 0
 0
-3
 0
 0
 0
[torch.cuda.FloatTensor of size 7 (GPU 3)]
 pp2
Variable containing:
 0
 0
 0
-0
 0
 0
 0
[torch.cuda.FloatTensor of size 7 (GPU 3)]
 pn2
Variable containing:
 0

In [None]:
net = Unet()
num_epoch = 200;
from torch.optim.lr_scheduler import StepLR
import torch.optim as optim
optimizer = optim.Adam(net.parameters(), lr = 5e-3, weight_decay = 5e-10)
for i in range(num_epoch):
    random.shuffle(SAMPLE)
    for j in range(len(SAMPLE)):
        optimizer.zero_grad()
        train_batch, val_batch = create_train_batch(j)
        train_batch = Variable(train_batch, requires_grad=True)
        val_batch = Variable(val_batch, requires_grad=True)
        out = net.forward(train_batch)
        out = torch.transpose(out, 0, 1)
        out.contiguous()
        out = out.view(5, -1)
        out = torch.transpose(out, 0, 1)
        out.contiguous()
        _,out=torch.max(out, 1)
        out=out.view(-1).float()
        val_batch = val_batch.view(-1).float()
        criterion=dice_coef_loss()
        loss=criterion(out, val_batch)
        print(loss)
        loss.backward()
        optimizer.step()
        y_pred = net.forward(val_x)
        y_pred = y_pred.view(5, -1)
        y_pred = torch.transpose(y_pred, 0, 1)
        out.contiguous()
        _, y_pred = torch.max(y_pred.data, 1)
        correct = (y_pred == val_y).sum()
        print('Validation accuracy:', float(correct) / 128 / 128 / 96)
        print('time used:%.3f'% (time.clock() - prev_time))

In [None]:
torch.save(net.state_dict(), "unet.txt")
print ("successfully saved!")