Add unet encoder-decoder and image output feature #3913

Merged
merged 1 commit on Jan 29, 2024
57 changes: 57 additions & 0 deletions examples/semantic_segmentation/camseq.py
@@ -0,0 +1,57 @@
import logging
import os
import shutil

import pandas as pd
import torch
import yaml
from torchvision.utils import save_image

from ludwig.api import LudwigModel
from ludwig.datasets import camseq

# clean out prior results
shutil.rmtree("./results", ignore_errors=True)

# load model training parameters from the YAML config into a Python dictionary
with open("./config_camseq.yaml") as f:
    config = yaml.safe_load(f.read())

# define the Ludwig model object that drives model training
model = LudwigModel(config, logging_level=logging.INFO)

# load the CamSeq dataset
df = camseq.load(split=False)

pred_set = df[0:1]  # hold out 1 image for prediction
data_set = df[1:]  # train, test, and validate on the remaining images

# initiate model training
(
    train_stats,  # training statistics
    _,
    output_directory,  # location where training results are saved to disk
) = model.train(
    dataset=data_set,
    experiment_name="simple_image_experiment",
    model_name="single_model",
    skip_save_processed_input=True,
)

# print("{}".format(model.model))

# predict
pred_set.reset_index(inplace=True)
pred_out_df, results = model.predict(pred_set)

if not isinstance(pred_out_df, pd.DataFrame):
    # Dask DataFrames must be materialized before row-wise iteration
    pred_out_df = pred_out_df.compute()
pred_out_df["image_path"] = pred_set["image_path"]
pred_out_df["mask_path"] = pred_set["mask_path"]

for index, row in pred_out_df.iterrows():
    pred_mask = torch.from_numpy(row["mask_path_predictions"])
    pred_mask_path = os.path.dirname(os.path.realpath(__file__)) + "/predicted_" + os.path.basename(row["mask_path"])
    print(f"\nSaving predicted mask to {pred_mask_path}")
    if torch.any(pred_mask.gt(1)):
        # save_image expects float values in [0, 1]; rescale class ids stored as 0-255
        pred_mask = pred_mask.float() / 255
    save_image(pred_mask, pred_mask_path)
    print("Input image_path: {}".format(row["image_path"]))
    print("Label mask_path: {}".format(row["mask_path"]))
    print(f"Predicted mask_path: {pred_mask_path}")
33 changes: 33 additions & 0 deletions examples/semantic_segmentation/config_camseq.yaml
@@ -0,0 +1,33 @@
input_features:
  - name: image_path
    type: image
    preprocessing:
      num_processes: 6
      infer_image_max_height: 1024
      infer_image_max_width: 1024
    encoder: unet

output_features:
  - name: mask_path
    type: image
    preprocessing:
      num_processes: 6
      infer_image_max_height: 1024
      infer_image_max_width: 1024
      infer_image_num_classes: true
    num_classes: 32
    decoder:
      type: unet
      num_fc_layers: 0
    loss:
      type: softmax_cross_entropy

combiner:
  type: concat
  num_fc_layers: 0

trainer:
  epochs: 100
  early_stop: -1
  batch_size: 1
  max_batch_size: 1
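
Note that both the unet encoder and decoder below reject heights or widths not divisible by 16, consistent with a U-Net that halves resolution at each of four downsampling levels. A small sketch of rounding a size up to the nearest valid dimension (the helper name is ours, not part of this PR):

# Hypothetical helper: round an image dimension up to the nearest multiple
# of 16, the granularity the unet encoder/decoder checks require.
def round_up_to_16(size: int) -> int:
    return ((size + 15) // 16) * 16

assert round_up_to_16(720) == 720  # already a multiple of 16
assert round_up_to_16(721) == 736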
2 changes: 2 additions & 0 deletions ludwig/constants.py
@@ -41,6 +41,8 @@
INFER_IMAGE_MAX_HEIGHT = "infer_image_max_height"
INFER_IMAGE_MAX_WIDTH = "infer_image_max_width"
INFER_IMAGE_SAMPLE_SIZE = "infer_image_sample_size"
INFER_IMAGE_NUM_CLASSES = "infer_image_num_classes"
IMAGE_MAX_CLASSES = 128
NUM_CLASSES = "num_classes"
NUM_CHANNELS = "num_channels"
REQUIRES_EQUAL_DIMENSIONS = "requires_equal_dimensions"
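
`INFER_IMAGE_NUM_CLASSES` backs the new `infer_image_num_classes` preprocessing option used in the example config, and `IMAGE_MAX_CLASSES` caps how many classes an image output feature may use. A hypothetical illustration of the idea, not Ludwig's actual implementation:

import torch

# Hypothetical sketch: infer the number of mask classes by counting unique
# pixel values, capped at IMAGE_MAX_CLASSES (128).
def infer_num_classes(mask: torch.Tensor, max_classes: int = 128) -> int:
    num_classes = int(torch.unique(mask).numel())
    if num_classes > max_classes:
        raise ValueError(f"{num_classes} classes exceeds the {max_classes} supported")
    return num_classes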
21 changes: 21 additions & 0 deletions ludwig/datasets/configs/camseq.yaml
@@ -0,0 +1,21 @@
version: 1.0
name: camseq
kaggle_dataset_id: carlolepelaars/camseq-semantic-segmentation
archive_filenames: camseq-semantic-segmentation.zip
sha256:
  camseq-semantic-segmentation.zip: ea3aeba2661d9b3e3ea406668e7d9240cb2ba0c7e374914bb6d866147faff502
loader: camseq.CamseqLoader
preserve_paths:
  - images
  - masks
description: |
  CamSeq01 Cambridge Labeled Objects in Video
  https://www.kaggle.com/datasets/carlolepelaars/camseq-semantic-segmentation
columns:
  - name: image_path
    type: image
  - name: mask_path
    type: image
output_features:
  - name: mask_path
    type: image
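
Since the dataset is fetched through `kaggle_dataset_id`, downloading it generally requires Kaggle API credentials. A minimal sketch using the standard Kaggle environment variables (the values are placeholders):

import os

# Placeholders; substitute your own Kaggle API token.
os.environ["KAGGLE_USERNAME"] = "your_username"
os.environ["KAGGLE_KEY"] = "your_api_key"

from ludwig.datasets import camseq
df = camseq.load(split=False)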
61 changes: 61 additions & 0 deletions ludwig/datasets/loaders/camseq.py
@@ -0,0 +1,61 @@
# Copyright (c) 2023 Aizen Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from typing import List

import pandas as pd

from ludwig.datasets.loaders.dataset_loader import DatasetLoader
from ludwig.utils.fs_utils import makedirs


class CamseqLoader(DatasetLoader):
    def transform_files(self, file_paths: List[str]) -> List[str]:
        if not os.path.exists(self.processed_dataset_dir):
            os.makedirs(self.processed_dataset_dir)

        # move images and masks into separate directories
        source_dir = self.raw_dataset_dir
        images_dir = os.path.join(source_dir, "images")
        masks_dir = os.path.join(source_dir, "masks")
        makedirs(images_dir, exist_ok=True)
        makedirs(masks_dir, exist_ok=True)

        data_files = []
        for f in os.listdir(source_dir):
            if f.endswith("_L.png"):  # masks
                dest_file = os.path.join(masks_dir, f)
            elif f.endswith(".png"):  # images
                dest_file = os.path.join(images_dir, f)
            else:
                continue
            source_file = os.path.join(source_dir, f)
            os.replace(source_file, dest_file)
            data_files.append(dest_file)

        return super().transform_files(data_files)

    def load_unprocessed_dataframe(self, file_paths: List[str]) -> pd.DataFrame:
        """Creates a dataframe of image paths and mask paths."""
        images_dir = os.path.join(self.processed_dataset_dir, "images")
        masks_dir = os.path.join(self.processed_dataset_dir, "masks")
        images = []
        masks = []
        for f in os.listdir(images_dir):
            images.append(os.path.join(images_dir, f))
            mask_f = f[:-4] + "_L.png"
            masks.append(os.path.join(masks_dir, mask_f))

        return pd.DataFrame({"image_path": images, "mask_path": masks})
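
The loader leans on the CamSeq naming convention: each label mask shares its frame's filename plus an `_L` suffix. A small illustration of the pairing rule in `load_unprocessed_dataframe` (the frame name is a hypothetical example):

f = "0016E5_07959.png"      # a frame under images/
mask_f = f[:-4] + "_L.png"  # its mask under masks/: "0016E5_07959_L.png"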
1 change: 1 addition & 0 deletions ludwig/decoders/__init__.py
@@ -1,5 +1,6 @@
# register all decoders
import ludwig.decoders.generic_decoders # noqa
import ludwig.decoders.image_decoders # noqa
import ludwig.decoders.llm_decoders # noqa
import ludwig.decoders.sequence_decoders # noqa
import ludwig.decoders.sequence_tagger # noqa
91 changes: 91 additions & 0 deletions ludwig/decoders/image_decoders.py
@@ -0,0 +1,91 @@
#! /usr/bin/env python
# Copyright (c) 2023 Aizen Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
from typing import Dict, Optional, Type

import torch

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import ENCODER_OUTPUT_STATE, HIDDEN, IMAGE, LOGITS, PREDICTIONS
from ludwig.decoders.base import Decoder
from ludwig.decoders.registry import register_decoder
from ludwig.modules.convolutional_modules import UNetUpStack
from ludwig.schema.decoders.image_decoders import ImageDecoderConfig, UNetDecoderConfig

logger = logging.getLogger(__name__)


@DeveloperAPI
@register_decoder("unet", IMAGE)
class UNetDecoder(Decoder):
    def __init__(
        self,
        input_size: int,
        height: int,
        width: int,
        num_channels: int = 1,
        num_classes: int = 2,
        conv_norm: Optional[str] = None,
        decoder_config=None,
        **kwargs,
    ):
        super().__init__()
        self.config = decoder_config
        self.num_classes = num_classes

        logger.debug(f" {self.name}")
        if num_classes < 2:
            raise ValueError(f"Invalid `num_classes` {num_classes} for unet decoder")
        if height % 16 or width % 16:
            raise ValueError(f"Invalid `height` {height} or `width` {width} for unet decoder")

        self.unet = UNetUpStack(
            img_height=height,
            img_width=width,
            out_channels=num_classes,
            norm=conv_norm,
        )

        self.input_reshape = list(self.unet.input_shape)
        self.input_reshape.insert(0, -1)
        self._output_shape = (height, width)

    def forward(self, combiner_outputs: Dict[str, torch.Tensor], target: torch.Tensor):
        hidden = combiner_outputs[HIDDEN]
        skips = combiner_outputs[ENCODER_OUTPUT_STATE]

        # unflatten combiner outputs
        hidden = hidden.reshape(self.input_reshape)

        logits = self.unet(hidden, skips)
        predictions = logits.argmax(dim=1).squeeze(1).byte()

        return {LOGITS: logits, PREDICTIONS: predictions}

    def get_prediction_set(self):
        return {LOGITS, PREDICTIONS}

    @staticmethod
    def get_schema_cls() -> Type[ImageDecoderConfig]:
        return UNetDecoderConfig

    @property
    def output_shape(self) -> torch.Size:
        return torch.Size(self._output_shape)

    @property
    def input_shape(self) -> torch.Size:
        return self.unet.input_shape
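
For reference, the decoder's forward contract pairs per-class logits with hard predictions. A minimal shape sketch, with batch size, class count, and spatial dimensions assumed for illustration:

import torch

logits = torch.randn(2, 32, 128, 192)      # [batch, num_classes, height, width]
predictions = logits.argmax(dim=1).byte()  # [batch, height, width] class ids
assert predictions.shape == (2, 128, 192)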
48 changes: 46 additions & 2 deletions ludwig/encoders/image/base.py
@@ -19,18 +19,19 @@
import torch

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import ENCODER_OUTPUT, IMAGE
from ludwig.constants import ENCODER_OUTPUT, ENCODER_OUTPUT_STATE, IMAGE
from ludwig.encoders.base import Encoder
from ludwig.encoders.registry import register_encoder
from ludwig.encoders.types import EncoderOutputDict
from ludwig.modules.convolutional_modules import Conv2DStack, ResNet
from ludwig.modules.convolutional_modules import Conv2DStack, ResNet, UNetDownStack
from ludwig.modules.fully_connected_modules import FCStack
from ludwig.modules.mlp_mixer_modules import MLPMixer
from ludwig.schema.encoders.image.base import (
    ImageEncoderConfig,
    MLPMixerConfig,
    ResNetConfig,
    Stacked2DCNNConfig,
    UNetEncoderConfig,
    ViTConfig,
)
from ludwig.utils.torch_utils import FreezeModule
@@ -424,3 +425,46 @@ def input_shape(self) -> torch.Size:
    @property
    def output_shape(self) -> torch.Size:
        return torch.Size(self._output_shape)


@DeveloperAPI
@register_encoder("unet", IMAGE)
class UNetEncoder(ImageEncoder):
    def __init__(
        self,
        height: int,
        width: int,
        num_channels: int = 3,
        conv_norm: Optional[str] = None,
        encoder_config=None,
        **kwargs,
    ):
        super().__init__()
        self.config = encoder_config

        logger.debug(f" {self.name}")
        if height % 16 or width % 16:
            raise ValueError(f"Invalid `height` {height} or `width` {width} for unet encoder")

        self.unet = UNetDownStack(
            img_height=height,
            img_width=width,
            in_channels=num_channels,
            norm=conv_norm,
        )

    def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
        hidden, skips = self.unet(inputs)
        return {ENCODER_OUTPUT: hidden, ENCODER_OUTPUT_STATE: skips}

    @staticmethod
    def get_schema_cls() -> Type[ImageEncoderConfig]:
        return UNetEncoderConfig

    @property
    def output_shape(self) -> torch.Size:
        return self.unet.output_shape

    @property
    def input_shape(self) -> torch.Size:
        return self.unet.input_shape
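
A minimal usage sketch of the encoder (the 256x256 input size is assumed; both dimensions must be divisible by 16). The skip connections returned under `ENCODER_OUTPUT_STATE` are what the unet decoder consumes to restore full resolution:

import torch

# Assumed example input: a batch of one 3-channel 256x256 image.
encoder = UNetEncoder(height=256, width=256, num_channels=3)
outputs = encoder(torch.randn(1, 3, 256, 256))
bottleneck = outputs[ENCODER_OUTPUT]   # downsampled feature map
skips = outputs[ENCODER_OUTPUT_STATE]  # per-level skip connections for the decoder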
3 changes: 2 additions & 1 deletion ludwig/features/feature_registries.py
@@ -43,7 +43,7 @@
)
from ludwig.features.date_feature import DateFeatureMixin, DateInputFeature
from ludwig.features.h3_feature import H3FeatureMixin, H3InputFeature
from ludwig.features.image_feature import ImageFeatureMixin, ImageInputFeature
from ludwig.features.image_feature import ImageFeatureMixin, ImageInputFeature, ImageOutputFeature
from ludwig.features.number_feature import NumberFeatureMixin, NumberInputFeature, NumberOutputFeature
from ludwig.features.sequence_feature import SequenceFeatureMixin, SequenceInputFeature, SequenceOutputFeature
from ludwig.features.set_feature import SetFeatureMixin, SetInputFeature, SetOutputFeature
@@ -108,6 +108,7 @@ def get_output_type_registry() -> Dict:
        TIMESERIES: TimeseriesOutputFeature,
        VECTOR: VectorOutputFeature,
        CATEGORY_DISTRIBUTION: CategoryDistributionOutputFeature,
        IMAGE: ImageOutputFeature,
    }
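
Registering `IMAGE` here is what allows a config to declare `type: image` under `output_features`; the feature type then resolves to `ImageOutputFeature`. A one-line sketch of the lookup:

assert get_output_type_registry()[IMAGE] is ImageOutputFeature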

