Skip to content

Commit

Permalink
updated mini-imagenet dataloader so that data can be in train/val/test subfolders
Browse files Browse the repository at this point in the history
  • Loading branch information
lmzintgraf committed Jul 8, 2019
1 parent aad1896 commit 32c0620
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 27 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -17,4 +17,5 @@ saves/
logs/
plots/
results/
*result_plots
*result_files
5 changes: 3 additions & 2 deletions classification/arguments.py
Expand Up @@ -38,13 +38,14 @@ def parse_args():

#

parser.add_argument('--data_path', type=str, default='./data', help='folder which contains image data')
parser.add_argument('--rerun', action='store_true',
parser.add_argument('--data_path', type=str, default='./data/miniimagenet/', help='folder which contains image data')
parser.add_argument('--rerun', action='store_true', default=False,
help='Re-run experiment (will override previously saved results)')

args = parser.parse_args()

# use the GPU if available
args.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print('Running on device: {}'.format(args.device))

return args
80 changes: 57 additions & 23 deletions classification/dataset_miniimagenet.py
Expand Up @@ -15,7 +15,9 @@ class MiniImagenet(Dataset):
"""
put mini-imagenet files as :
root :
|- images/*.jpg includes all images
|- train/*.jpg
|- test/*.jpg
|- val/*.jpg
|- train.csv
|- test.csv
|- val.csv
Expand Down Expand Up @@ -55,18 +57,38 @@ def __init__(self, mode, batchsz, n_way, k_shot, k_query, imsize, data_path, sta
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

self.path_images = os.path.join(data_path, 'miniimagenet', 'images') # image path
self.path_preprocessed = os.path.join(data_path, 'miniimagenet',
'images_preprocessed') # preprocessed image path
if not os.path.exists(self.path_preprocessed):
os.mkdir(self.path_preprocessed)

csvdata = [self.loadCSV(os.path.join(data_path, 'miniimagenet', mode + '.csv'))] # csv path
# check if images are all in one folder or separated into train/val/test folders
if os.path.exists(os.path.join(data_path, 'images')):
self.subfolder_split = False
self.path_images = os.path.join(data_path, 'images') # image path
self.path_preprocessed = os.path.join(data_path, 'images_preprocessed') # preprocessed image path
elif os.path.exists(os.path.join(data_path, 'train')):
self.subfolder_split = True
self.path_images = os.path.join(data_path, mode)
self.path_preprocessed = os.path.join(data_path, 'images_preprocessed')
if not os.path.exists(self.path_preprocessed):
os.mkdir(self.path_preprocessed)
self.path_preprocessed = os.path.join(data_path, 'images_preprocessed', mode)
if not os.path.exists(self.path_preprocessed):
os.mkdir(self.path_preprocessed)
else:
raise FileNotFoundError('Mini-Imagenet data not found. '
'Please add images in one of the following folder structures:'
'./data/miniimagenet/images'
'./data/miniimagenet/{train}{test}{val}'
'or specify --data_path in the arguments.'
)

csvdata = [self.loadCSV(os.path.join(data_path, mode + '.csv'))] # csv path

# check if we have the images
if not os.listdir(self.path_images):
raise FileNotFoundError('Mini-Imagenet data not found. Please put the images in the folder '
'./data/miniimagenet/images or specify --data_path in the arguments.')
raise FileNotFoundError('Mini-Imagenet data not found. '
'Please add images in one of the following folder structures:'
'./data/miniimagenet/images'
'./data/miniimagenet/{train}{test}{val}'
'or specify --data_path in the arguments.'
)

self.data = []
self.img2label = {}
Expand Down Expand Up @@ -158,22 +180,34 @@ def __getitem__(self, index):
query_y_relative[query_y == l] = idx

# pre-process the images and save as numpy arrays (makes the code run much faster afterwards)
# - for the support set
for i, filename in enumerate(filenames_support_x):
filename_preprocesses = filename[:-4] + '_preprocesses_{}'.format(self.imsize)
path_preprocesses = os.path.join(self.path_preprocessed, filename_preprocesses)
if not os.path.exists(path_preprocesses + '.npy'):
support_x[i] = self.transform(os.path.join(self.path_images, filename))
np.save(path_preprocesses, support_x[i].numpy())
filename_preprocessed = filename[:-4] + '_preprocessed_{}'.format(self.imsize)
path_preprocessed = os.path.join(self.path_preprocessed, filename[:9], filename_preprocessed)
if not os.path.exists(path_preprocessed + '.npy'):
if not os.path.exists(os.path.join(self.path_preprocessed, filename[:9])):
os.mkdir(os.path.join(self.path_preprocessed, filename[:9]))
if self.subfolder_split:
support_x[i] = self.transform(os.path.join(self.path_images, filename[:9], filename))
else:
support_x[i] = self.transform(os.path.join(self.path_images, filename))
np.save(path_preprocessed, support_x[i].numpy())
else:
support_x[i] = torch.from_numpy(np.load(path_preprocesses + '.npy'))
for i, path in enumerate(filenames_query_x):
filename_preprocesses = path[:-4] + '_preprocesses_{}'.format(self.imsize)
path_preprocesses = os.path.join(self.path_preprocessed, filename_preprocesses)
if not os.path.exists(path_preprocesses + '.npy'):
query_x[i] = self.transform(os.path.join(self.path_images, path))
np.save(path_preprocesses, query_x[i].numpy())
support_x[i] = torch.from_numpy(np.load(path_preprocessed + '.npy'))
# - same thing for the query set
for i, filename in enumerate(filenames_query_x):
filename_preprocessed = filename[:-4] + '_preprocessed_{}'.format(self.imsize)
path_preprocessed = os.path.join(self.path_preprocessed, filename[:9], filename_preprocessed)
if not os.path.exists(path_preprocessed + '.npy'):
if not os.path.exists(os.path.join(self.path_preprocessed, filename[:9])):
os.mkdir(os.path.join(self.path_preprocessed, filename[:9]))
if self.subfolder_split:
query_x[i] = self.transform(os.path.join(self.path_images, filename[:9], filename))
else:
query_x[i] = self.transform(os.path.join(self.path_images, filename))
np.save(path_preprocessed, query_x[i].numpy())
else:
query_x[i] = torch.from_numpy(np.load(path_preprocesses + '.npy'))
query_x[i] = torch.from_numpy(np.load(path_preprocessed + '.npy'))

return support_x, torch.LongTensor(support_y_relative), query_x, torch.LongTensor(query_y_relative)

Expand Down
4 changes: 2 additions & 2 deletions classification/main.py
Expand Up @@ -239,7 +239,8 @@ def evaluate(iter_counter, args, model, logger, dataloader, save_path):
path = os.path.join(utils.get_base_path(), 'result_files', utils.get_path_from_args(args))
log_interval = 100

if not os.path.exists(path + '.npy') or args.rerun:
if (not os.path.exists(path + '.npy')) or args.rerun:
print('Starting experiment. Logging under filename {}'.format(path + '.npy'))
run(args, num_workers=1, log_interval=log_interval, save_path=path)
else:
print('Found results in {}. If you want to re-run, use the argument --rerun'.format(path))
Expand Down Expand Up @@ -274,7 +275,6 @@ def evaluate(iter_counter, args, model, logger, dataloader, save_path):
plt.ylim([0, 1.01])
plt.xlim([0, 60000])


title = 'k={}, cfilt={}, init={}, #t={}, lr={}-{}, ' \
'grad={}-{} phi={} ({}) #f={} i={} seed={}'.format(args.k_shot,
args.num_filters,
Expand Down

0 comments on commit 32c0620

Please sign in to comment.