5 changes: 3 additions & 2 deletions MANIFEST.in
@@ -1,2 +1,3 @@
-include timm/models/pruned/*.txt
-
+include timm/models/_pruned/*.txt
+include timm/data/_info/*.txt
+include timm/data/_info/*.json
56 changes: 47 additions & 9 deletions inference.py
@@ -17,7 +17,7 @@
 import pandas as pd
 import torch

-from timm.data import create_dataset, create_loader, resolve_data_config
+from timm.data import create_dataset, create_loader, resolve_data_config, ImageNetInfo, infer_imagenet_subset
 from timm.layers import apply_test_time_pool
 from timm.models import create_model
 from timm.utils import AverageMeter, setup_default_logging, set_jit_fuser, ParseKwargs
@@ -46,6 +46,7 @@

 _FMT_EXT = {
     'json': '.json',
+    'json-records': '.json',
     'json-split': '.json',
     'parquet': '.parquet',
     'csv': '.csv',
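Note on the format key: the string used here must match both this extension lookup and the corresponding branch in save_results further down, since the extension is resolved before the writer is chosen. A minimal sketch of the lookup ('predictions' is an illustrative filename):

fmt = 'json-records'
results_filename = 'predictions' + _FMT_EXT[fmt]   # -> 'predictions.json'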
@@ -122,7 +123,7 @@
 scripting_group.add_argument('--aot-autograd', default=False, action='store_true',
                              help="Enable AOT Autograd support.")

-parser.add_argument('--results-dir',type=str, default=None,
+parser.add_argument('--results-dir', type=str, default=None,
                     help='folder for output results')
 parser.add_argument('--results-file', type=str, default=None,
                     help='results filename (relative to results-dir)')
@@ -134,14 +135,20 @@
                     metavar='N', help='Top-k to output to CSV')
 parser.add_argument('--fullname', action='store_true', default=False,
                     help='use full sample name in output (not just basename).')
-parser.add_argument('--filename-col', default='filename',
+parser.add_argument('--filename-col', type=str, default='filename',
                     help='name for filename / sample name column')
-parser.add_argument('--index-col', default='index',
+parser.add_argument('--index-col', type=str, default='index',
                     help='name for output indices column(s)')
-parser.add_argument('--output-col', default=None,
+parser.add_argument('--label-col', type=str, default='label',
+                    help='name for output label column(s)')
+parser.add_argument('--output-col', type=str, default=None,
                     help='name for logit/probs output column(s)')
-parser.add_argument('--output-type', default='prob',
+parser.add_argument('--output-type', type=str, default='prob',
                     help='output type column ("prob" for probabilities, "logit" for raw logits)')
+parser.add_argument('--label-type', type=str, default='description',
+                    help='type of label to output, one of "none", "name", "description", "detail"')
+parser.add_argument('--include-index', action='store_true', default=False,
+                    help='include the class index in results')
+parser.add_argument('--exclude-output', action='store_true', default=False,
+                    help='exclude logits/probs from results, just indices. topk must be set !=0.')
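One way to sanity-check the new options end to end is to hand parse_args an explicit argv; this snippet is illustrative and not part of the diff (the positional dataset path and the --topk flag are assumed from the existing script):

# hypothetical smoke test for the new arguments; values are illustrative
args = parser.parse_args([
    'path/to/images',            # positional dataset path assumed by the script
    '--label-type', 'name',
    '--include-index',
    '--topk', '5',
])
assert args.label_type == 'name' and args.include_index and args.topk == 5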

@@ -237,10 +244,26 @@ def main():
         **data_config,
     )

+    to_label = None
+    if args.label_type in ('name', 'description', 'detail'):
+        imagenet_subset = infer_imagenet_subset(model)
+        if imagenet_subset is not None:
+            dataset_info = ImageNetInfo(imagenet_subset)
+            if args.label_type == 'name':
+                to_label = lambda x: dataset_info.index_to_label_name(x)
+            elif args.label_type == 'detail':
+                to_label = lambda x: dataset_info.index_to_description(x, detailed=True)
+            else:
+                to_label = lambda x: dataset_info.index_to_description(x)
+            to_label = np.vectorize(to_label)
+        else:
+            _logger.error("Cannot deduce ImageNet subset from model, no labelling will be performed.")
+
     top_k = min(args.topk, args.num_classes)
     batch_time = AverageMeter()
     end = time.time()
     all_indices = []
+    all_labels = []
     all_outputs = []
     use_probs = args.output_type == 'prob'
     with torch.no_grad():
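For reference, a minimal standalone sketch of the labelling path wired up above; the model name is illustrative, and the vectorized lookup preserves the shape of whatever index array it is applied to:

import numpy as np
from timm.data import ImageNetInfo, infer_imagenet_subset
from timm.models import create_model

model = create_model('resnet50', pretrained=True)   # illustrative model choice
subset = infer_imagenet_subset(model)               # e.g. 'imagenet-1k', or None
if subset is not None:
    info = ImageNetInfo(subset)
    to_label = np.vectorize(lambda x: info.index_to_label_name(x))
    # a (batch, top_k) array of class indices maps to an equally shaped array of names
    print(to_label(np.array([[207, 281]])))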
@@ -254,7 +277,12 @@

             if top_k:
                 output, indices = output.topk(top_k)
-                all_indices.append(indices.cpu().numpy())
+                np_indices = indices.cpu().numpy()
+                if args.include_index:
+                    all_indices.append(np_indices)
+                if to_label is not None:
+                    np_labels = to_label(np_indices)
+                    all_labels.append(np_labels)

             all_outputs.append(output.cpu().numpy())

@@ -267,6 +295,7 @@
                     batch_idx, len(loader), batch_time=batch_time))

     all_indices = np.concatenate(all_indices, axis=0) if all_indices else None
+    all_labels = np.concatenate(all_labels, axis=0) if all_labels else None
     all_outputs = np.concatenate(all_outputs, axis=0).astype(np.float32)
     filenames = loader.dataset.filenames(basename=not args.fullname)

@@ -276,13 +305,20 @@
         if all_indices is not None:
             for i in range(all_indices.shape[-1]):
                 data_dict[f'{args.index_col}_{i}'] = all_indices[:, i]
+        if all_labels is not None:
+            for i in range(all_labels.shape[-1]):
+                data_dict[f'{args.label_col}_{i}'] = all_labels[:, i]
         for i in range(all_outputs.shape[-1]):
             data_dict[f'{output_col}_{i}'] = all_outputs[:, i]
     else:
         if all_indices is not None:
             if all_indices.shape[-1] == 1:
                 all_indices = all_indices.squeeze(-1)
             data_dict[args.index_col] = list(all_indices)
+        if all_labels is not None:
+            if all_labels.shape[-1] == 1:
+                all_labels = all_labels.squeeze(-1)
+            data_dict[args.label_col] = list(all_labels)
         if all_outputs.shape[-1] == 1:
             all_outputs = all_outputs.squeeze(-1)
         data_dict[output_col] = list(all_outputs)
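For a concrete picture of what lands in data_dict: with topk=2, labelling enabled, and the default column names, the separate-columns branch above yields a frame shaped roughly like this (all values are made up; output_col is assumed to resolve to the 'prob' output type):

import pandas as pd

data_dict = {
    'filename': ['a.jpg', 'b.jpg'],
    'index_0': [207, 281], 'index_1': [208, 285],
    'label_0': ['golden retriever', 'tabby'],
    'label_1': ['Labrador retriever', 'Egyptian cat'],
    'prob_0': [0.83, 0.71], 'prob_1': [0.09, 0.12],
}
df = pd.DataFrame(data=data_dict)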
@@ -291,7 +327,7 @@

     results_filename = args.results_file
     if results_filename:
-        filename_no_ext, ext = os.path.splitext(results_filename)[-1]
+        filename_no_ext, ext = os.path.splitext(results_filename)
         if ext and ext in _FMT_EXT.values():
             # if filename provided with one of expected ext,
             # remove it as it will be added back
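The corrected unpack matters because os.path.splitext returns a (root, ext) tuple; the old [-1] indexing kept only the extension string, so the two-name unpack could never succeed. A quick demonstration:

import os

root, ext = os.path.splitext('preds.json')   # ('preds', '.json')
# old form: root, ext = os.path.splitext('preds.json')[-1]
# that indexes out just '.json', and unpacking a 5-character string
# into two names raises "ValueError: too many values to unpack"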
@@ -308,14 +344,16 @@
         save_results(df, results_filename, fmt)

     print('--result')
-    print(json.dumps(dict(filename=results_filename)))
+    print(df.set_index(args.filename_col).to_json(orient='index', indent=4))


 def save_results(df, results_filename, results_format='csv', filename_col='filename'):
     results_filename += _FMT_EXT[results_format]
     if results_format == 'parquet':
         df.set_index(filename_col).to_parquet(results_filename)
     elif results_format == 'json':
         df.set_index(filename_col).to_json(results_filename, indent=4, orient='index')
+    elif results_format == 'json-records':
+        df.to_json(results_filename, lines=True, orient='records')
     elif results_format == 'json-split':
         df.to_json(results_filename, indent=4, orient='split', index=False)
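An illustrative call mirroring how main() invokes save_results, for the new records-per-line format (filename and frame contents are made up):

import pandas as pd

df = pd.DataFrame({'filename': ['a.jpg'], 'prob': [0.83]})
save_results(df, 'preds', results_format='json-records')
# appends '.json' via _FMT_EXT and writes one JSON object per line (orient='records')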
2 changes: 2 additions & 0 deletions timm/data/__init__.py
@@ -4,6 +4,8 @@
 from .constants import *
 from .dataset import ImageDataset, IterableImageDataset, AugMixDataset
 from .dataset_factory import create_dataset
+from .dataset_info import DatasetInfo
+from .imagenet_info import ImageNetInfo, infer_imagenet_subset
 from .loader import create_loader
 from .mixup import Mixup, FastCollateMixup
 from .readers import create_reader
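With these exports in place, the info helpers resolve from the top-level timm.data namespace; a quick check (the no-argument ImageNetInfo default is assumed to be the imagenet-1k subset):

from timm.data import DatasetInfo, ImageNetInfo, infer_imagenet_subset

info = ImageNetInfo()                  # assumed to default to imagenet-1k
print(info.index_to_label_name(1))     # e.g. 'goldfish'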