Adapted preprocess_fn support for Stage1 and refactored some code #14
Conversation
Did so because I changed the resizing (cv2 to skimage) when training some classification models
Also fixed model_name.lower() to only apply when model_name is a string, so the function runs properly when model_name isn't one.
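Roughly, a fix like that looks as follows (illustrative names, not the repository's exact code):

```python
# Hypothetical sketch: only lowercase model_name when it is actually a string,
# so passing a non-string (e.g. an already-instantiated model object) no longer
# raises AttributeError on .lower().
def normalize_model_name(model_name):
    if isinstance(model_name, str):
        return model_name.lower()
    return model_name  # leave non-string inputs untouched

print(normalize_model_name("EfficientNet"))  # efficientnet
print(normalize_model_name(42))              # 42
```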
I think we need to change the workflow from preprocess->TTA->predict to TTA->preprocess->predict, as done in
Inconsistent with Stage2, but it's not really necessary to do so for models in this repository.
…nce.classification Because inference.classification.Stage1 is extremely memory-intensive from extensive TTA.
…on code into run_classification_prediction Also, fixed an import bug with preprocess_input (it was imported from the wrong module). I refactored the prediction code into run_classification_prediction to be consistent with segmentation-only's run_segmentation_prediction and because it's cleaner. The major change is refactoring the prediction process to work on batched lists of filepaths, to deal with the memory issues when doing TTA.
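The batching idea can be sketched like this (hypothetical helper names, not the repo's exact implementation): split the filepaths up front, then load/TTA/predict one batch at a time so only that batch's images live in memory.

```python
# Minimal sketch of filepath batching to cap memory use during prediction.
def batch_fpaths(fpaths, batch_size):
    """Split a list of filepaths into consecutive batches of batch_size."""
    return [fpaths[i:i + batch_size] for i in range(0, len(fpaths), batch_size)]

def run_prediction(fpaths, batch_size, predict_batch):
    """predict_batch loads, TTA's, and predicts on one batch of filepaths."""
    preds = []
    for batch in batch_fpaths(fpaths, batch_size):
        preds.extend(predict_batch(batch))  # only one batch in memory at a time
    return preds

fpaths = [f"img_{i}.png" for i in range(5)]
# stand-in predictor: just returns the length of each path string
print(run_prediction(fpaths, 2, lambda batch: [len(p) for p in batch]))
```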
Did so for more control over the inference speed
Refactored the save path into the variable save_path. Also updated the rest of the functions to be compatible with TTA_Classification, such as TTA_Classification_All, run_classification_prediction, and Stage1
…der of TTA/preprocessing This was to address the issue where grayscale models were being trained with [TTA->preprocess_input] instead of the inference pipeline's [preprocess_input->TTA]. It's also cleaner this way because np.invert is supposed to only work with integers in [0, 255], not the preprocessed/resized inputs. The current implementation is pretty hacky, so it'll need some refactoring in the future, but it's fine for now.
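For context, np.invert's integer-only behavior can be seen directly (plain numpy, not the repo's code):

```python
import numpy as np

# np.invert is a bitwise operation: it only makes sense on integer images in
# [0, 255], where it maps each value v to 255 - v. On a preprocessed/resized
# float image it raises a TypeError.
img = np.array([[0, 255], [100, 200]], dtype=np.uint8)
print(np.invert(img))  # each value becomes 255 - value

preprocessed = img.astype(np.float32) / 255.0  # typical preprocessed input
try:
    np.invert(preprocessed)
except TypeError:
    print("np.invert rejects float inputs")
```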
…s_fn in the arguments Reduced the number of **kwargs occurrences in the functions' arguments (only Stage1 has it now) by making Stage1's preprocess_fn a partial'd function. Fixed the documentation to accommodate this and added support in TTA_Classification for cases where preprocess_fn is None. Also fixed a bug where batch_test_fpaths was called instead of test_fpaths_batched.
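The partial'ing pattern, in general terms (function names and arguments here are illustrative, not the repository's actual signatures):

```python
from functools import partial

# Bind extra preprocessing arguments once with functools.partial, so downstream
# functions receive a single one-argument callable (or None) instead of
# threading **kwargs through every call.
def example_preprocess(img, size, mean):
    # stand-in for real resizing/preprocessing; simple arithmetic for the demo
    return (img - mean) / size

preprocess_fn = partial(example_preprocess, size=256, mean=128)
print(preprocess_fn(384))  # (384 - 128) / 256 = 1.0

def apply_preprocess(img, preprocess_fn=None):
    # handle the preprocess_fn=None case explicitly
    return img if preprocess_fn is None else preprocess_fn(img)

print(apply_preprocess(5))                  # unchanged when no preprocess_fn
print(apply_preprocess(384, preprocess_fn))
```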
…ation dataframe If index=True, then there will be an extra column "Unnamed: 0", which causes an error when submitting to Kaggle.
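The pitfall is easy to reproduce with pandas alone: with index=True the index is written as an unnamed first column, which round-trips back as "Unnamed: 0".

```python
import io
import pandas as pd

df = pd.DataFrame({"ImageId": ["a", "b"], "Label": [0, 1]})

# index=True writes the index as an unnamed first column
with_index = pd.read_csv(io.StringIO(df.to_csv(index=True)))
print(with_index.columns.tolist())    # ['Unnamed: 0', 'ImageId', 'Label']

# index=False keeps only the real columns
without_index = pd.read_csv(io.StringIO(df.to_csv(index=False)))
print(without_index.columns.tolist())  # ['ImageId', 'Label']
```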
Did so because I use it for grayscale classification model inference anyways.
…ng it Did so to make Stage2 less bulky and reduce repeat code.
Kaggle handles NaNs as empty masks anyways, but this just makes the results cleaner.
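A minimal sketch of that cleanup ("EncodedPixels" as the mask column is an assumption here, the usual Kaggle RLE column name):

```python
import pandas as pd

# Replace NaN run-length encodings with empty strings so the submission has
# explicit empty masks instead of NaNs.
sub = pd.DataFrame({"ImageId": ["a", "b"], "EncodedPixels": ["1 5", None]})
sub["EncodedPixels"] = sub["EncodedPixels"].fillna("")
print(sub["EncodedPixels"].tolist())  # ['1 5', '']
```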
That's currently not a priority. I'll raise it in an issue for the future, I guess.
Major Changes
- `Stage1`: refactored the prediction code into `run_classification_prediction`.
- `Stage2`: refactored to use `run_seg_prediction`, like in `segmentation_only.SegmentationOnlyInference`.
- `ensemble_classification_from_df` now saves the output .csv file using `index=False` (a16a482) to prevent an extra column that yields an error when submitting to Kaggle.
- `segmentation_only.SegmentationOnlyInference` no longer has NaNs (<- wasn't a fatal bug because Kaggle handles NaNs as empty masks, but it's just cleaner this way).
- `load_pretrained_classification_model` now handles the edge case where `model_name="efficientnet"` and `pretrained="nih"` better.

Additions
- New arguments for `inference.classification.Stage1`: `fpaths_batch_size`, `n_tta_iter_per_image`, `tta_then_preprocess`, `preprocess_fn`, `**kwargs`.
- `Stage1` now predicts in batched lists of filepaths to reduce the memory overhead.
- `preprocess_fn` is now compatible with `Stage1`.
- `Stage1(tta=True)` support for the grayscale classification models that are trained with data augmentation, using `tta_then_preprocess`.
- `TTA_Classification`, `TTA_Classification_All`
- `io.utils.resize_and_preprocess`