-
Notifications
You must be signed in to change notification settings - Fork 2.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Created a basic train loop + changed a bit loss and utils
- Loading branch information
Showing
8 changed files
with
195 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,6 @@ | ||
*.pyc | ||
data/ | ||
__pycache__/ | ||
checkpoints/ | ||
*.pth | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
def plot_img_mask(img, mask): | ||
fig = plt.figure() | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,42 @@ | ||
|
||
# | ||
# load.py : utils on generators / lists of ids to transform from strings to | ||
# cropped images and masks | ||
|
||
import os | ||
import random | ||
import numpy as np | ||
|
||
from PIL import Image | ||
from functools import partial | ||
from utils import resize_and_crop, get_square | ||
from utils import resize_and_crop, get_square, normalize | ||
|
||
|
||
def get_ids(dir): | ||
"""Returns a list of the ids in the directory""" | ||
return (f[:-4] for f in os.listdir(dir)) | ||
|
||
|
||
def split_ids(ids, n=2): | ||
"""Split each id in n, creating n tuples (id, k) for each id""" | ||
return ((id, i) for i in range(n) for id in ids) | ||
|
||
def shuffle_ids(ids): | ||
"""Returns a shuffle list od the ids""" | ||
lst = list(ids) | ||
random.shuffle(lst) | ||
return lst | ||
|
||
def to_cropped_imgs(ids, dir, suffix): | ||
"""From a list of tuples, returns the correct cropped img (left or right)""" | ||
"""From a list of tuples, returns the correct cropped img""" | ||
for id, pos in ids: | ||
im = resize_and_crop(Image.open(dir + id + suffix)) | ||
yield get_square(im, pos) | ||
|
||
|
||
|
||
def get_imgs_and_masks(): | ||
"""From the list of ids, return the couples (img, mask)""" | ||
dir_img = 'data/train/' | ||
dir_mask = 'data/train_masks/' | ||
|
||
ids = get_ids(dir_img) | ||
ids = split_ids(ids) | ||
ids = shuffle_ids(ids) | ||
def get_imgs_and_masks(ids, dir_img, dir_mask): | ||
"""Return all the couples (img, mask)""" | ||
|
||
imgs = to_cropped_imgs(ids, dir_img, '.jpg') | ||
|
||
# need to transform from HWC to CHW | ||
imgs_switched = map(partial(np.transpose, axes=[2, 0, 1]), imgs) | ||
imgs_normalized = map(normalize, imgs_switched) | ||
|
||
masks = to_cropped_imgs(ids, dir_mask, '_mask.gif') | ||
|
||
return zip(imgs_switched, masks) | ||
return zip(imgs_normalized, masks) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import torch | ||
|
||
from load import * | ||
from data_vis import * | ||
from utils import split_train_val, batch | ||
from myloss import DiceLoss | ||
from unet_model import UNet | ||
from torch.autograd import Variable | ||
from torch import optim | ||
from optparse import OptionParser | ||
|
||
|
||
def train_net(net, epochs=5, batch_size=2, lr=0.1, val_percent=0.05, | ||
cp=True, gpu=False): | ||
dir_img = 'data/train/' | ||
dir_mask = 'data/train_masks/' | ||
dir_checkpoint = 'checkpoints/' | ||
|
||
# get ids | ||
ids = get_ids(dir_img) | ||
ids = split_ids(ids) | ||
|
||
iddataset = split_train_val(ids, val_percent) | ||
|
||
print(''' | ||
Starting training: | ||
Epochs: {} | ||
Batch size: {} | ||
Learning rate: {} | ||
Training size: {} | ||
Validation size: {} | ||
Checkpoints: {} | ||
CUDA: {} | ||
'''.format(epochs, batch_size, lr, len(iddataset['train']), | ||
len(iddataset['val']), str(cp), str(gpu))) | ||
|
||
N_train = len(iddataset['train']) | ||
|
||
train = get_imgs_and_masks(iddataset['train'], dir_img, dir_mask) | ||
val = get_imgs_and_masks(iddataset['val'], dir_img, dir_mask) | ||
|
||
optimizer = optim.Adam(net.parameters(), lr=lr) | ||
criterion = DiceLoss() | ||
|
||
for epoch in range(epochs): | ||
print('Starting epoch {}/{}.'.format(epoch+1, epochs)) | ||
|
||
epoch_loss = 0 | ||
|
||
for i, b in enumerate(batch(train, batch_size)): | ||
X = np.array([i[0] for i in b]) | ||
y = np.array([i[1] for i in b]) | ||
|
||
X = torch.FloatTensor(X) | ||
y = torch.ByteTensor(y) | ||
|
||
if gpu: | ||
X = Variable(X).cuda() | ||
y = Variable(y).cuda() | ||
else: | ||
X = Variable(X) | ||
y = Variable(y) | ||
|
||
optimizer.zero_grad() | ||
|
||
y_pred = net(X) | ||
|
||
loss = criterion(y_pred, y.float()) | ||
epoch_loss += loss.data[0] | ||
|
||
print('{0:.4f} --- loss: {1:.6f}'.format(i*batch_size/N_train, | ||
loss.data[0])) | ||
|
||
loss.backward() | ||
optimizer.step() | ||
|
||
print('Epoch finished ! Loss: {}'.format(epoch_loss/i)) | ||
|
||
if cp: | ||
torch.save(net.state_dict(), | ||
dir_checkpoint + 'CP{}.pth'.format(epoch+1)) | ||
|
||
print('Checkpoint {} saved !'.format(epoch+1)) | ||
|
||
|
||
parser = OptionParser() | ||
parser.add_option("-e", "--epochs", dest="epochs", default=5, type="int", | ||
help="number of epochs") | ||
parser.add_option("-b", "--batch-size", dest="batchsize", default=10, | ||
type="int", help="batch size") | ||
parser.add_option("-l", "--learning-rate", dest="lr", default=0.1, | ||
type="int", help="learning rate") | ||
parser.add_option("-g", "--gpu", action="store_true", dest="gpu", | ||
default=False, help="use cuda") | ||
parser.add_option("-n", "--ngpu", action="store_false", dest="gpu", | ||
default=False, help="use cuda") | ||
|
||
|
||
(options, args) = parser.parse_args() | ||
|
||
net = UNet(3, 1) | ||
if options.gpu: | ||
net.cuda() | ||
|
||
train_net(net, options.epochs, options.batchsize, options.lr, gpu=options.gpu) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,26 +1,54 @@ | ||
import PIL | ||
import numpy as np | ||
import random | ||
|
||
|
||
def get_square(img, pos): | ||
"""Extract a left or a right square from PILimg""" | ||
"""shape : (H, W, C))""" | ||
"""Extract a left or a right square from PILimg shape : (H, W, C))""" | ||
img = np.array(img) | ||
|
||
h = img.shape[0] | ||
w = img.shape[1] | ||
|
||
if pos == 0: | ||
return img[:, :h] | ||
else: | ||
return img[:, -h:] | ||
|
||
def resize_and_crop(pilimg, scale=0.5, final_height=640): | ||
|
||
def resize_and_crop(pilimg, scale=0.2, final_height=None): | ||
w = pilimg.size[0] | ||
h = pilimg.size[1] | ||
newW = int(w * scale) | ||
newH = int(h * scale) | ||
diff = newH - final_height | ||
|
||
if not final_height: | ||
diff = 0 | ||
else: | ||
diff = newH - final_height | ||
|
||
img = pilimg.resize((newW, newH)) | ||
img = img.crop((0, diff // 2, newW, newH - diff // 2)) | ||
return img | ||
|
||
|
||
def batch(iterable, batch_size): | ||
"""Yields lists by batch""" | ||
b = [] | ||
for i, t in enumerate(iterable): | ||
b.append(t) | ||
if (i+1) % batch_size == 0: | ||
yield b | ||
b = [] | ||
|
||
if len(b) > 0: | ||
yield b | ||
|
||
|
||
def split_train_val(dataset, val_percent=0.05): | ||
dataset = list(dataset) | ||
length = len(dataset) | ||
n = int(length * val_percent) | ||
random.shuffle(dataset) | ||
return {'train': dataset[:-n], 'val': dataset[-n:]} | ||
|
||
|
||
def normalize(x): | ||
return x / 255 |