Skip to content

Commit

Permalink
Created a basic train loop + changed a bit loss and utils
Browse files Browse the repository at this point in the history
  • Loading branch information
milesial committed Aug 17, 2017
1 parent 8332f89 commit 4063565
Show file tree
Hide file tree
Showing 8 changed files with 195 additions and 40 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
*.pyc
data/
__pycache__/
checkpoints/
*.pth

1 change: 1 addition & 0 deletions data_vis.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import matplotlib.pyplot as plt


def plot_img_mask(img, mask):
fig = plt.figure()

Expand Down
31 changes: 13 additions & 18 deletions load.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,42 @@

#
# load.py : utils on generators / lists of ids to transform from strings to
# cropped images and masks

import os
import random
import numpy as np

from PIL import Image
from functools import partial
from utils import resize_and_crop, get_square
from utils import resize_and_crop, get_square, normalize


def get_ids(dir):
    """Return a generator over the ids found in the directory.

    An id is an entry name with its last four characters removed (for
    example a dot followed by a three-letter extension). The directory is
    listed when this function is called; the ids are then produced lazily,
    and the returned generator can only be consumed once.
    """
    names = os.listdir(dir)
    return (name[:-4] for name in names)


def split_ids(ids, n=2):
    """Split each id in n, creating n tuples (id, k) for each id.

    The tuples are produced lazily, grouped by k: first every (id, 0), then
    every (id, 1), and so on up to k = n - 1.

    ``ids`` may be any iterable, including a one-shot generator.
    """
    # The ids are walked once per value of k, so a one-shot iterator must be
    # materialised first; otherwise only the k == 0 tuples would be produced.
    ids = list(ids)
    return ((id_, k) for k in range(n) for id_ in ids)

def shuffle_ids(ids):
    """Return a new list holding the ids in random order.

    The input iterable is copied into a list first, so it is not modified.
    Shuffling uses the module-level ``random`` generator.
    """
    shuffled = [*ids]
    random.shuffle(shuffled)
    return shuffled

def to_cropped_imgs(ids, dir, suffix):
    """Yield the cropped image for each (id, pos) tuple in ``ids``.

    For every tuple, the file at ``dir + id + suffix`` is opened, passed
    through ``resize_and_crop`` and then through ``get_square`` with ``pos``.
    The path is built by plain string concatenation, so ``dir`` must already
    end with a path separator.
    """
    for name, pos in ids:
        path = dir + name + suffix
        resized = resize_and_crop(Image.open(path))
        yield get_square(resized, pos)



def get_imgs_and_masks(ids, dir_img, dir_mask):
    """Return all the couples (img, mask).

    Args:
        ids: iterable of (id, pos) tuples, as consumed by ``to_cropped_imgs``.
        dir_img: directory prefix of the ``.jpg`` images.
        dir_mask: directory prefix of the ``_mask.gif`` masks.

    Returns:
        A lazy, single-pass ``zip`` iterator of (img, mask) couples. Each img
        is transposed from HWC to CHW and passed through ``normalize``; masks
        are returned as produced by ``to_cropped_imgs``.
    """
    # The image and mask pipelines each walk the ids on their own. Materialise
    # them so a one-shot iterator is not shared between the two, which would
    # pair every image with the mask of a different id.
    ids = list(ids)

    imgs = to_cropped_imgs(ids, dir_img, '.jpg')

    # need to transform from HWC to CHW
    imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs)
    imgs_normalized = map(normalize, imgs_switched)

    masks = to_cropped_imgs(ids, dir_mask, '_mask.gif')

    return zip(imgs_normalized, masks)
44 changes: 31 additions & 13 deletions myloss.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,52 @@

#
# myloss.py : implementation of the Dice coeff and the associated loss
#

import torch
from torch.nn.modules.loss import _Loss
from torch.autograd import Function
import torch.nn.functional as F

class DiceCoeff(Function):
from torch.nn.modules.loss import _Loss
from torch.autograd import Function, Variable


def forward(ctx, input, target):
ctx.save_for_backward(input, target)
ctx.inter = torch.dot(input, target) + 0.0001
ctx.union = torch.sum(input) + torch.sum(target) + 0.0001
class DiceCoeff(Function):
    """Dice coeff for individual examples.

    Written in the instance-method ``Function`` style used by the rest of
    this module (callers invoke ``DiceCoeff().forward(input, target)``):
    intermediate values are stored on ``self`` between ``forward`` and
    ``backward``.
    """

    def forward(self, input, target):
        """Return 2 * (input . target + eps) / (sum(input) + sum(target) + eps).

        ``input`` and ``target`` must hold the same number of elements;
        eps is 0.0001 and keeps the ratio defined when both are all zeros.
        """
        self.save_for_backward(input, target)
        # torch.dot works on 1-D operands, so flatten both tensors first.
        self.inter = torch.dot(input.contiguous().view(-1),
                               target.contiguous().view(-1)) + 0.0001
        self.union = torch.sum(input) + torch.sum(target) + 0.0001

        t = 2*self.inter.float()/self.union.float()
        return t

    # This function has only a single output, so it gets only one gradient
    def backward(self, grad_output):
        """Return (grad_input, grad_target); grad_target is always None."""
        input, target = self.saved_variables
        grad_input = grad_target = None

        if self.needs_input_grad[0]:
            # Quotient rule on 2 * inter / union, with d(inter)/d(input) =
            # target and d(union)/d(input) = 1:
            #   2 * (target * union - inter) / union ** 2
            grad_input = grad_output * 2 \
                * (target * self.union - self.inter) \
                / (self.union * self.union)

        # No gradient is propagated to the target.
        return grad_input, grad_target


def dice_coeff(input, target):
    """Dice coeff for batches.

    Iterates over ``input`` and ``target`` in lockstep along their first
    dimension, computes ``DiceCoeff`` for every pair and returns the mean.

    Raises:
        ValueError: if the batch contains no examples.
    """
    # Accumulator lives on the same kind of device as the input.
    if input.is_cuda:
        total = Variable(torch.FloatTensor(1).cuda().zero_())
    else:
        total = Variable(torch.FloatTensor(1).zero_())

    count = 0
    for sample, truth in zip(input, target):
        total = total + DiceCoeff().forward(sample, truth)
        count += 1

    # Count explicitly rather than reusing the loop index, which is undefined
    # when the batch is empty.
    if count == 0:
        raise ValueError('dice_coeff needs at least one example')

    return total / count


class DiceLoss(_Loss):
def forward(self, input, target):
Expand Down
105 changes: 105 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import torch

from load import *
from data_vis import *
from utils import split_train_val, batch
from myloss import DiceLoss
from unet_model import UNet
from torch.autograd import Variable
from torch import optim
from optparse import OptionParser


def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05,
              cp=True, gpu=False):
    """Train ``net`` on the images and masks under ``data/``.

    Args:
        net: the network to train; its parameters are optimised with Adam.
        epochs: number of passes over the training ids.
        batch_size: number of (img, mask) couples per optimisation step.
        lr: learning rate given to the optimizer.
        val_percent: fraction of the ids set aside by ``split_train_val``.
        cp: when true, save ``net.state_dict()`` under ``checkpoints/`` at the
            end of every epoch.
        gpu: when true, move every batch to CUDA before the forward pass.
    """
    import os  # local import: only needed to create the checkpoint folder

    dir_img = 'data/train/'
    dir_mask = 'data/train_masks/'
    dir_checkpoint = 'checkpoints/'

    # get ids
    ids = get_ids(dir_img)
    ids = split_ids(ids)

    # NOTE: the 'val' ids are split off here but no validation pass is run yet.
    iddataset = split_train_val(ids, val_percent)

    print('''
Starting training:
Epochs: {}
Batch size: {}
Learning rate: {}
Training size: {}
Validation size: {}
Checkpoints: {}
CUDA: {}
'''.format(epochs, batch_size, lr, len(iddataset['train']),
           len(iddataset['val']), str(cp), str(gpu)))

    N_train = len(iddataset['train'])

    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = DiceLoss()

    if cp:
        # Create the folder up front so a missing directory does not make
        # torch.save fail after a whole epoch of training.
        os.makedirs(dir_checkpoint, exist_ok=True)

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch+1, epochs))

        # get_imgs_and_masks returns a single-pass iterator, so it has to be
        # rebuilt for every epoch; otherwise only the first epoch sees data.
        train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask)

        epoch_loss = 0
        n_batches = 0

        for i, b in enumerate(batch(train, batch_size)):
            X = np.array([couple[0] for couple in b])
            y = np.array([couple[1] for couple in b])

            X = torch.FloatTensor(X)
            y = torch.ByteTensor(y)

            if gpu:
                X = Variable(X).cuda()
                y = Variable(y).cuda()
            else:
                X = Variable(X)
                y = Variable(y)

            optimizer.zero_grad()

            y_pred = net(X)

            loss = criterion(y_pred, y.float())
            epoch_loss += loss.data[0]
            n_batches += 1

            print('{0:.4f} --- loss: {1:.6f}'.format(i*batch_size/N_train,
                                                     loss.data[0]))

            loss.backward()
            optimizer.step()

        # Average over the number of batches actually processed (the last
        # loop index is one less than that, and undefined with no batches).
        print('Epoch finished ! Loss: {}'.format(
            epoch_loss / max(n_batches, 1)))

        if cp:
            torch.save(net.state_dict(),
                       dir_checkpoint + 'CP{}.pth'.format(epoch+1))

            print('Checkpoint {} saved !'.format(epoch+1))


# Command-line entry point: parse the options, build the network and train it.
parser = OptionParser()
parser.add_option("-e", "--epochs", dest="epochs", default=5, type="int",
                  help="number of epochs")
parser.add_option("-b", "--batch-size", dest="batchsize", default=10,
                  type="int", help="batch size")
# The learning rate is fractional: with type="int", optparse rejects values
# such as "-l 0.01".
parser.add_option("-l", "--learning-rate", dest="lr", default=0.1,
                  type="float", help="learning rate")
parser.add_option("-g", "--gpu", action="store_true", dest="gpu",
                  default=False, help="use cuda")
parser.add_option("-n", "--ngpu", action="store_false", dest="gpu",
                  default=False, help="do not use cuda")


(options, args) = parser.parse_args()

# 3 input channels, 1 output class
net = UNet(3, 1)
if options.gpu:
    net.cuda()

train_net(net, options.epochs, options.batchsize, options.lr, gpu=options.gpu)
1 change: 1 addition & 0 deletions unet_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from unet_parts import *


class UNet(nn.Module):
def __init__(self, n_channels, n_classes):
super(UNet, self).__init__()
Expand Down
9 changes: 7 additions & 2 deletions unet_parts.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch.nn as nn
import torch.nn.functional as F


class double_conv(nn.Module):
def __init__(self, in_ch, out_ch):
super(double_conv, self).__init__()
Expand All @@ -13,10 +14,12 @@ def __init__(self, in_ch, out_ch):
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.ReLU()
)

def forward(self, x):
x = self.conv(x)
return x


class inconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(inconv, self).__init__()
Expand All @@ -26,6 +29,7 @@ def forward(self, x):
x = self.conv(x)
return x


class down(nn.Module):
def __init__(self, in_ch, out_ch):
super(down, self).__init__()
Expand All @@ -38,15 +42,15 @@ def forward(self, x):
x = self.mpconv(x)
return x


class up(nn.Module):
def __init__(self, in_ch, out_ch):
super(up, self).__init__()
self.up = nn.UpsamplingBilinear2d(scale_factor=2)
#self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
# self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
self.conv = double_conv(in_ch, out_ch)

def forward(self, x1, x2):

x1 = self.up(x1)
diffX = x1.size()[2] - x2.size()[2]
diffY = x1.size()[3] - x2.size()[3]
Expand All @@ -56,6 +60,7 @@ def forward(self, x1, x2):
x = self.conv(x)
return x


class outconv(nn.Module):
def __init__(self, in_ch, out_ch):
super(outconv, self).__init__()
Expand Down
42 changes: 35 additions & 7 deletions utils.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,54 @@
import PIL
import numpy as np
import random


def get_square(img, pos):
    """Extract a left or a right square from PILimg shape : (H, W, C).

    The input is converted to a numpy array. With H the size of its first
    axis, the first H columns are returned when ``pos`` is 0 and the last H
    columns for any other ``pos``.
    """
    img = np.array(img)

    h = img.shape[0]

    if pos == 0:
        return img[:, :h]
    else:
        return img[:, -h:]

def resize_and_crop(pilimg, scale=0.2, final_height=None):
    """Resize ``pilimg`` by ``scale``, then crop it vertically around its centre.

    Args:
        pilimg: image exposing ``size`` as (width, height), plus ``resize``
            and ``crop``.
        scale: factor applied to both the width and the height.
        final_height: height of the result; when falsy (``None`` or 0) the
            resized image is returned at its full height.

    Returns:
        The resized image, cropped to ``final_height`` rows when requested.
    """
    w = pilimg.size[0]
    h = pilimg.size[1]
    newW = int(w * scale)
    newH = int(h * scale)

    img = pilimg.resize((newW, newH))

    if not final_height:
        top, bottom = 0, newH
    else:
        # Derive the bottom edge from the top edge, so the crop is exactly
        # final_height rows even when the height difference is odd.
        top = (newH - final_height) // 2
        bottom = top + final_height

    img = img.crop((0, top, newW, bottom))
    return img


def batch(iterable, batch_size):
    """Yield consecutive lists of ``batch_size`` items taken from ``iterable``.

    The final list is shorter when the number of items is not a multiple of
    ``batch_size``; nothing is yielded for an empty iterable.
    """
    chunk = []
    for count, item in enumerate(iterable, start=1):
        chunk.append(item)
        if count % batch_size == 0:
            yield chunk
            chunk = []

    # leftover items that did not fill a whole batch
    if chunk:
        yield chunk


def split_train_val(dataset, val_percent=0.05):
    """Shuffle ``dataset`` and split it into a training and a validation part.

    Args:
        dataset: any iterable; it is copied into a list before shuffling.
        val_percent: fraction of the items placed in the validation part
            (the count is rounded down).

    Returns:
        A dict ``{'train': [...], 'val': [...]}`` whose two lists together
        hold every item exactly once.
    """
    dataset = list(dataset)
    length = len(dataset)
    # Clamp so that out-of-range fractions still give a valid split.
    n = min(max(int(length * val_percent), 0), length)
    random.shuffle(dataset)
    # Slice from an explicit index: with n == 0, dataset[:-n] would be empty
    # and dataset[-n:] the whole list, i.e. the two parts would be swapped.
    split = length - n
    return {'train': dataset[:split], 'val': dataset[split:]}


def normalize(x):
    """Return ``x`` divided by 255."""
    scaled = x / 255
    return scaled

0 comments on commit 4063565

Please sign in to comment.