Skip to content

Commit

Permalink
fixed patch_train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Yazeed Alaudah committed Feb 13, 2019
1 parent 199b007 commit f2f16aa
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 11 deletions.
20 changes: 20 additions & 0 deletions core/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import numpy as np
import torch
import torchvision.utils as vutils

def np_to_tb(array):
# if 2D :
if array.ndim == 2:
# HW => CHW
array = np.expand_dims(array,axis=0)
# CHW => NCHW
array = np.expand_dims(array,axis=0)
elif array.ndim == 3:
# HWC => CHW
array = array.transpose(2, 0, 1)
# CHW => NCHW
array = np.expand_dims(array,axis=0)

array = torch.from_numpy(array)
array = vutils.make_grid(array, normalize=True, scale_each=True)
return array
19 changes: 8 additions & 11 deletions patch_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from core.loader.data_loader import *
from core.metrics import runningScore
from core.models import get_model

from core.utils import np_to_tb

def split_train_val(args, per_val=0.1):
# create inline and crossline pacthes for training and validation:
Expand Down Expand Up @@ -205,10 +205,8 @@ def train(args):
tb_original_image, epoch + 1)

labels_original = labels_original.numpy()[0]
correct_label_decoded = train_set.decode_segmap(
np.squeeze(labels_original))
writer.add_image('train/original_label',
correct_label_decoded, epoch + 1)
correct_label_decoded = train_set.decode_segmap(np.squeeze(labels_original))
writer.add_image('train/original_label',np_to_tb(correct_label_decoded), epoch + 1)
out = F.softmax(outputs, dim=1)

# this returns the max. channel number:
Expand All @@ -219,7 +217,7 @@ def train(args):
confidence, normalize=True, scale_each=True)

decoded = train_set.decode_segmap(np.squeeze(prediction))
writer.add_image('train/predicted', decoded, epoch + 1)
writer.add_image('train/predicted', np_to_tb(decoded), epoch + 1)
writer.add_image('train/confidence', tb_confidence, epoch + 1)

unary = outputs.cpu().detach()
Expand All @@ -232,8 +230,7 @@ def train(args):
decoded_channel = unary[0][channel]
tb_channel = vutils.make_grid(
decoded_channel, normalize=True, scale_each=True)
writer.add_image(
f'train_classes/_{class_names[channel]}', tb_channel, epoch + 1)
writer.add_image(f'train_classes/_{class_names[channel]}', tb_channel, epoch + 1)

# Average metrics, and save in writer()
loss_train /= total_iteration
Expand Down Expand Up @@ -283,7 +280,7 @@ def train(args):
correct_label_decoded = train_set.decode_segmap(
np.squeeze(labels_original))
writer.add_image('val/original_label',
correct_label_decoded, epoch + 1)
np_to_tb(correct_label_decoded), epoch + 1)

out = F.softmax(outputs_val, dim=1)

Expand All @@ -296,7 +293,7 @@ def train(args):

decoded = train_set.decode_segmap(
np.squeeze(prediction))
writer.add_image('val/predicted', decoded, epoch + 1)
writer.add_image('val/predicted', np_to_tb(decoded), epoch + 1)
writer.add_image('val/confidence',
tb_confidence, epoch + 1)

Expand Down Expand Up @@ -357,7 +354,7 @@ def train(args):
help='Path to previous saved model to restart from')
parser.add_argument('--clip', nargs='?', type=float, default=0.1,
help='Max norm of the gradients if clipping. Set to zero to disable. ')
parser.add_argument('--per_val', nargs='?', type=float, default=0,
parser.add_argument('--per_val', nargs='?', type=float, default=0.2,
help='percentage of the training data for validation')
parser.add_argument('--stride', nargs='?', type=int, default=50,
help='The vertical and horizontal stride when we are sampling patches from the volume.' +
Expand Down

0 comments on commit f2f16aa

Please sign in to comment.