Skip to content

Commit

Permalink
update README; add "scale_height"; fix resolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
junyanz committed Apr 5, 2017
1 parent 4802933 commit f1f7049
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 9 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ bash ./pretrained_models/download_model.sh style_cezanne
```
- Now, let's generate Paul Cézanne style images:
```
DATA_ROOT=./datasets/ae_photos name=style_cezanne_pretrained model=one_direction_test phase=test th test.lua
DATA_ROOT=./datasets/ae_photos name=style_cezanne_pretrained model=one_direction_test phase=test loadSize=256 fineSize=256 resize_or_crop=``scale_width`` th test.lua
```
The test results will be saved to `./results/style_cezanne_pretrained/latest_test/index.html`.
Please refer to [Model Zoo](#Pre-trained-models) for more pre-trained models.
Expand Down
4 changes: 2 additions & 2 deletions data/dataset.lua
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ end
-- converts a table of samples (and corresponding labels) to a clean tensor
local function tableToOutput(self, dataTable, scalarTable)
local data, scalarLabels, labels
if opt.resize_or_crop == 'crop' or opt.resize_or_crop == 'scale_width' then
assert(#scalarTable == 1)
if opt.resize_or_crop == 'crop' or opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
assert(#scalarTable == 1)
data = torch.Tensor(1,
dataTable[1]:size(1), dataTable[1]:size(2), dataTable[1]:size(3))
data[1]:copy(dataTable[1])
Expand Down
10 changes: 7 additions & 3 deletions data/donkey_folder.lua
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ local function loadSingleImage(path)
end
if iW~=oW then
w1 = math.ceil(torch.uniform(1e-2, iW-oW))
end
end
if iH ~= oH or iW ~= oW then
im = image.crop(im, w1, h1, w1 + oW, h1 + oH)
end
Expand All @@ -136,11 +136,15 @@ local function loadSingleImage(path)
w = math.floor(w/4)*4
local x = math.floor(torch.uniform(0, iW - w))
local y = math.floor(torch.uniform(0, iH - w))
im = image.crop(im, x, y, x+w, y+w)
elseif (opt.resize_or_crop =='scale_width') then
im = image.crop(im, x, y, x+w, y+w)
elseif (opt.resize_or_crop == 'scale_width') then
w = oW
h = torch.floor(iH * oW/iW)
im = image.scale(im, w, h)
elseif (opt.resize_or_crop == 'scale_height') then
h = oH
w = torch.floor(iW * oH / iH)
im = image.scale(im, 448, h)
end

if opt.flip == 1 and torch.uniform() > 0.5 then
Expand Down
4 changes: 2 additions & 2 deletions options.lua
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ local opt_train = {
-- options for test
opt_test = {
DATA_ROOT = '', -- path to images (should have subfolders 'train', 'val', etc)
loadSize = 256, -- scale images to this size
fineSize = 256, -- then crop to this size
loadSize = 128, -- scale images to this size
fineSize = 128, -- then crop to this size
flip = 0, -- horizontal mirroring data augmentation
display = 1, -- display samples while training. 0 = false
display_id = 200, -- display window id.
Expand Down
2 changes: 1 addition & 1 deletion test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ for n = 1, math.floor(opt.how_many) do
visuals = model:GetCurrentVisuals(opt, opt.fineSize)

for i,visual in ipairs(visuals) do
if opt.resize_or_crop == 'scale_width' then
if opt.resize_or_crop == 'scale_width' or opt.resize_or_crop == 'scale_height' then
s1 = nil
s2 = nil
end
Expand Down

0 comments on commit f1f7049

Please sign in to comment.