Skip to content

Commit

Permalink
Image scoring (#680)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhishekkrthakur committed Jun 13, 2024
1 parent 6d568c9 commit f1f9040
Show file tree
Hide file tree
Showing 28 changed files with 1,036 additions and 6 deletions.
27 changes: 27 additions & 0 deletions configs/image_classification/local.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
task: image_classification
base_model: google/vit-base-patch16-224
project_name: autotrain-image-classification-model
log: tensorboard
backend: local

data:
path: data/
train_split: train # this folder inside data/ will be used for training, it contains the images in subfolders.
valid_split: null
column_mapping:
image_column: image
target_column: labels

params:
epochs: 2
batch_size: 4
lr: 2e-5
optimizer: adamw_torch
scheduler: linear
gradient_accumulation: 1
mixed_precision: fp16

hub:
username: ${HF_USERNAME}
token: ${HF_TOKEN}
push_to_hub: true
27 changes: 27 additions & 0 deletions configs/image_scoring/hub_dataset.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
task: image_regression
base_model: google/vit-base-patch16-224
project_name: autotrain-cats-vs-dogs-finetuned
log: tensorboard
backend: local

data:
path: cats_vs_dogs
train_split: train
valid_split: null
column_mapping:
image_column: image
target_column: labels

params:
epochs: 2
batch_size: 4
lr: 2e-5
optimizer: adamw_torch
scheduler: linear
gradient_accumulation: 1
mixed_precision: fp16

hub:
username: ${HF_USERNAME}
token: ${HF_TOKEN}
push_to_hub: true
28 changes: 28 additions & 0 deletions configs/image_scoring/local.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
task: image_regression
base_model: google/vit-base-patch16-224
project_name: autotrain-image-regression-model
log: tensorboard
backend: local

data:
path: data/
train_split: train # this folder inside data/ will be used for training, it contains the images and metadata.jsonl
valid_split: valid # this folder inside data/ will be used for validation, it contains the images and metadata.jsonl. can be set to null
# column mapping should not be changed for local datasets
column_mapping:
image_column: image
target_column: target

params:
epochs: 2
batch_size: 4
lr: 2e-5
optimizer: adamw_torch
scheduler: linear
gradient_accumulation: 1
mixed_precision: fp16

hub:
username: ${HF_USERNAME}
token: ${HF_TOKEN}
push_to_hub: true
4 changes: 4 additions & 0 deletions docs/source/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
title: LLM Finetuning
- local: image_classification
title: Image Classification
- local: image_regression
title: Image Scoring/Regression
- local: object_detection
title: Object Detection
- local: dreambooth
Expand All @@ -53,6 +55,8 @@
title: LLM Finetuning
- local: image_classification_params
title: Image Classification
- local: image_regression_params
title: Image Scoring/Regression
- local: object_detection_params
title: Object Detection
- local: dreambooth_params
Expand Down
58 changes: 58 additions & 0 deletions docs/source/image_regression.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Image Scoring/Regression

Image scoring is a form of supervised learning where a model is trained to predict a
score or value for an image. AutoTrain simplifies the process, enabling you to train a
state-of-the-art image scoring model by simply uploading labeled example images.


## Preparing your data

To ensure your image scoring model trains effectively, follow these guidelines for preparing your data:


### Organizing Images


Prepare a zip file containing your images and metadata.jsonl.


```
Archive.zip
├── 0001.png
├── 0002.png
├── 0003.png
├── .
├── .
├── .
└── metadata.jsonl
```

Example for `metadata.jsonl`:

```
{"file_name": "0001.png", "target": 0.5}
{"file_name": "0002.png", "target": 0.7}
{"file_name": "0003.png", "target": 0.3}
```

Please note that metadata.jsonl should contain the `file_name` and the `target` value for each image.


### Image Requirements

- Format: Ensure all images are in JPEG, JPG, or PNG format.

- Quantity: Include at least 5 images to provide the model with sufficient examples for learning.

- Exclusivity: The zip file should exclusively contain images and metadata.jsonl.
No additional files or nested folders should be included.


Some points to keep in mind:

- The images must be jpeg, jpg or png.
- There should be at least 5 images per class.
- There must not be any other files in the zip file.
- There must not be any other folders inside the zip folder.

When train.zip is decompressed, it creates no folders: only images and metadata.jsonl.
3 changes: 3 additions & 0 deletions docs/source/image_regression_params.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Image Scoring/Regression Parameters

The Parameters for image scoring/regression are same as the parameters for image classification.
4 changes: 1 addition & 3 deletions docs/source/object_detection.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,8 @@ No additional files or nested folders should be included.

Some points to keep in mind:

- The zip file should contain multiple folders (the classes), each folder should contain images of a single class.
- The name of the folder should be the name of the class.
- The images must be jpeg, jpg or png.
- There should be at least 5 images per class.
- There should be at least 5 images per split.
- There must not be any other files in the zip file.
- There must not be any other folders inside the zip folder.

Expand Down
20 changes: 20 additions & 0 deletions src/autotrain/app/api_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
from autotrain.trainers.image_regression.params import ImageRegressionParams
from autotrain.trainers.sent_transformers.params import SentenceTransformersParams
from autotrain.trainers.seq2seq.params import Seq2SeqParams
from autotrain.trainers.tabular.params import TabularParams
Expand Down Expand Up @@ -86,6 +87,7 @@ def create_api_base_model(base_class, class_name):
TextRegressionParamsAPI = create_api_base_model(TextRegressionParams, "TextRegressionParamsAPI")
TokenClassificationParamsAPI = create_api_base_model(TokenClassificationParams, "TokenClassificationParamsAPI")
SentenceTransformersParamsAPI = create_api_base_model(SentenceTransformersParams, "SentenceTransformersParamsAPI")
ImageRegressionParamsAPI = create_api_base_model(ImageRegressionParams, "ImageRegressionParamsAPI")


class LLMSFTColumnMapping(BaseModel):
Expand Down Expand Up @@ -122,6 +124,11 @@ class ImageClassificationColumnMapping(BaseModel):
target_column: str


class ImageRegressionColumnMapping(BaseModel):
image_column: str
target_column: str


class Seq2SeqColumnMapping(BaseModel):
text_column: str
target_column: str
Expand Down Expand Up @@ -201,6 +208,7 @@ class APICreateProjectModel(BaseModel):
"text-regression",
"tabular-classification",
"tabular-regression",
"image-regression",
]
base_model: str
hardware: Literal[
Expand Down Expand Up @@ -232,6 +240,7 @@ class APICreateProjectModel(BaseModel):
TextClassificationParamsAPI,
TextRegressionParamsAPI,
TokenClassificationParamsAPI,
ImageRegressionParamsAPI,
]
username: str
column_mapping: Optional[
Expand All @@ -254,6 +263,7 @@ class APICreateProjectModel(BaseModel):
STPairScoreColumnMapping,
STTripletColumnMapping,
STQAColumnMapping,
ImageRegressionColumnMapping,
]
] = None
hub_dataset: str
Expand Down Expand Up @@ -408,6 +418,14 @@ def validate_column_mapping(cls, values):
if not values.get("column_mapping").get("sentence2_column"):
raise ValueError("sentence2_column is required for st:qa")
values["column_mapping"] = STQAColumnMapping(**values["column_mapping"])
elif values.get("task") == "image-regression":
if not values.get("column_mapping"):
raise ValueError("column_mapping is required for image-regression")
if not values.get("column_mapping").get("image_column"):
raise ValueError("image_column is required for image-regression")
if not values.get("column_mapping").get("target_column"):
raise ValueError("target_column is required for image-regression")
values["column_mapping"] = ImageRegressionColumnMapping(**values["column_mapping"])
return values

@model_validator(mode="before")
Expand Down Expand Up @@ -441,6 +459,8 @@ def validate_params(cls, values):
values["params"] = TokenClassificationParamsAPI(**values["params"])
elif values.get("task").startswith("st:"):
values["params"] = SentenceTransformersParamsAPI(**values["params"])
elif values.get("task") == "image-regression":
values["params"] = ImageRegressionParamsAPI(**values["params"])
return values


Expand Down
1 change: 1 addition & 0 deletions src/autotrain/app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ def fetch_models():
_mc["text-classification"] = _fetch_text_classification_models()
_mc["llm"] = _fetch_llm_models()
_mc["image-classification"] = _fetch_image_classification_models()
_mc["image-regression"] = _fetch_image_classification_models()
_mc["dreambooth"] = _fetch_dreambooth_models()
_mc["seq2seq"] = _fetch_seq2seq_models()
_mc["token-classification"] = _fetch_token_classification_models()
Expand Down
37 changes: 37 additions & 0 deletions src/autotrain/app/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from autotrain.trainers.clm.params import LLMTrainingParams
from autotrain.trainers.dreambooth.params import DreamBoothTrainingParams
from autotrain.trainers.image_classification.params import ImageClassificationParams
from autotrain.trainers.image_regression.params import ImageRegressionParams
from autotrain.trainers.object_detection.params import ObjectDetectionParams
from autotrain.trainers.sent_transformers.params import SentenceTransformersParams
from autotrain.trainers.seq2seq.params import Seq2SeqParams
Expand Down Expand Up @@ -126,6 +127,10 @@
mixed_precision="fp16",
log="tensorboard",
).model_dump()
PARAMS["image-regression"] = ImageRegressionParams(
mixed_precision="fp16",
log="tensorboard",
).model_dump()


@dataclass
Expand Down Expand Up @@ -168,6 +173,8 @@ def munge(self):
return self._munge_params_text_reg()
elif self.task.startswith("st:"):
return self._munge_params_sent_transformers()
elif self.task == "image-regression":
return self._munge_params_img_reg()
else:
raise ValueError(f"Unknown task: {self.task}")

Expand Down Expand Up @@ -315,6 +322,22 @@ def _munge_params_img_clf(self):

return ImageClassificationParams(**_params)

def _munge_params_img_reg(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
_params["log"] = "tensorboard"
if not self.using_hub_dataset:
_params["image_column"] = "autotrain_image"
_params["target_column"] = "autotrain_label"
_params["valid_split"] = "validation"
else:
_params["image_column"] = self.column_mapping.get("image" if not self.api else "image_column", "image")
_params["target_column"] = self.column_mapping.get("target" if not self.api else "target_column", "target")
_params["train_split"] = self.train_split
_params["valid_split"] = self.valid_split

return ImageRegressionParams(**_params)

def _munge_params_img_obj_det(self):
_params = self._munge_common_params()
_params["model"] = self.base_model
Expand Down Expand Up @@ -511,6 +534,20 @@ def get_task_params(task, param_type):
"early_stopping_threshold",
]
task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}
if task == "image-regression" and param_type == "basic":
more_hidden_params = [
"warmup_ratio",
"weight_decay",
"max_grad_norm",
"seed",
"logging_steps",
"auto_find_batch_size",
"save_total_limit",
"evaluation_strategy",
"early_stopping_patience",
"early_stopping_threshold",
]
task_params = {k: v for k, v in task_params.items() if k not in more_hidden_params}
if task == "image-object-detection" and param_type == "basic":
more_hidden_params = [
"warmup_ratio",
Expand Down
5 changes: 5 additions & 0 deletions src/autotrain/app/templates/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@
fields = ['image', 'label'];
fieldNames = ['image', 'label'];
break;
case 'image-regression':
fields = ['image', 'label'];
fieldNames = ['image', 'target'];
break;
case 'image-object-detection':
fields = ['image', 'objects'];
fieldNames = ['image', 'objects'];
Expand Down Expand Up @@ -200,6 +204,7 @@
<optgroup label="Image Tasks">
<option value="dreambooth">DreamBooth LoRA</option>
<option value="image-classification">Image Classification</option>
<option value="image-regression">Image Scoring/Regression</option>
<option value="image-object-detection">Object Detection</option>
</optgroup>
<optgroup label="Tabular Tasks">
Expand Down
13 changes: 13 additions & 0 deletions src/autotrain/app/ui_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
AutoTrainDataset,
AutoTrainDreamboothDataset,
AutoTrainImageClassificationDataset,
AutoTrainImageRegressionDataset,
AutoTrainObjectDetectionDataset,
)
from autotrain.help import get_app_help
Expand Down Expand Up @@ -437,6 +438,8 @@ async def fetch_model_choices(
hub_models = MODEL_CHOICE["text-regression"]
elif task == "image-object-detection":
hub_models = MODEL_CHOICE["image-object-detection"]
elif task == "image-regression":
hub_models = MODEL_CHOICE["image-regression"]
else:
raise NotImplementedError

Expand Down Expand Up @@ -539,6 +542,16 @@ async def handle_form(
percent_valid=None, # TODO: add to UI
local=hardware.lower() == "local-ui",
)
elif task == "image-regression":
dset = AutoTrainImageRegressionDataset(
train_data=training_files[0],
token=token,
project_name=project_name,
username=autotrain_user,
valid_data=validation_files[0] if validation_files else None,
percent_valid=None, # TODO: add to UI
local=hardware.lower() == "local-ui",
)
elif task == "image-object-detection":
dset = AutoTrainObjectDetectionDataset(
train_data=training_files[0],
Expand Down
Loading

0 comments on commit f1f9040

Please sign in to comment.