Adding coordinate mapping output to the PyTorch extension
lnstadrum committed Mar 27, 2024
1 parent 035cb3f commit e82f65f
Showing 7 changed files with 174 additions and 33 deletions.
1 change: 0 additions & 1 deletion README.md
@@ -254,5 +254,4 @@ Mixup:
# Roadmap

* Test in a multi-GPU setup
* Extend to object detection: enable bounding boxes / keypoints transformation
* Extend to semantic segmentation: enable nearest-neighbor resampling for segmentation masks
5 changes: 2 additions & 3 deletions pytorch/example.py
@@ -5,7 +5,6 @@

from torchvision.datasets import Flowers102
from torchvision.transforms import PILToTensor
from torch.utils.data import DataLoader
from fast_augment_torch import CenterCrop, FastAugment


@@ -24,10 +23,10 @@
x = []
y = []
for _ in range(20):
image, label = dataset[random.randint(0, len(dataset))]
image, label = dataset[random.randint(0, len(dataset) - 1)]
x.append(crop(image.cuda().permute(1, 2, 0).contiguous()))
y.append(
torch.nn.functional.one_hot(torch.LongTensor([label - 1]), num_classes=102)
torch.nn.functional.one_hot(torch.LongTensor([label]), num_classes=102)
)

x = torch.stack(x, dim=0)
47 changes: 30 additions & 17 deletions pytorch/fast_augment_torch/__init__.py
@@ -59,7 +59,7 @@ def __init__(self, size, translation=0):
self.size = size

def __call__(self, x):
x, _ = self.backend(
x, _, _ = self.backend(
x, _empty_tensor, output_size=self.size, is_float32_output=False
)
return x
@@ -102,7 +102,7 @@ def __init__(
color_inversion=False,
flip_horizontally=True,
flip_vertically=False,
seed=0,
seed=0
):
"""Creates a FastAugment object used to apply a set of random geometry and
color transformations to batches of images.
@@ -208,7 +208,7 @@ def __init__(
color_inversion=color_inversion,
flip_horizontally=flip_horizontally,
flip_vertically=flip_vertically,
seed=seed,
seed=seed
)

def set_seed(self, seed: int):
@@ -219,30 +219,43 @@ def set_seed(self, seed: int):
"""
self.backend.set_seed(seed)

def __call__(self, x, y=None, output_size=None, output_type=torch.float32):
def __call__(self, x, y=None, output_size=None, output_type=torch.float32, output_mapping=False):
"""Applies a sequence of random transformations to images in a batch.
Args:
x: A `Tensor` of `uint8` or `float32` type containing an input
image or batch in channels-last layout (`HWC` or `NHWC`).
3-channel color images are expected (`C=3`).
y: A `Tensor` of `float32` type containing input labels in
one-hot format. Its outermost dimension is expected to match
the batch size. Optional, can be empty or None.
output_size: A list `[W, H]` specifying the output batch width and height
in pixels. If none, the input size is kept (default).
output_type: Output image datatype. Can be `float32` or `uint8`.
Default: `float32`.
x: A `Tensor` of `uint8` or `float32` type containing an input
image or batch in channels-last layout (`HWC` or `NHWC`).
3-channel color images are expected (`C=3`).
y: A `Tensor` of `float32` type containing input labels in
one-hot format. Its outermost dimension is expected to match
the batch size. Optional, can be empty or None.
output_size: A list `[W, H]` specifying the output batch width and height
in pixels. If none, the input size is kept (default).
output_type: Output image datatype. Can be `float32` or `uint8`.
Default: `float32`.
output_mapping: If `True`, the applied transformations are given as the
last output argument. These are 3x3 matrices mapping input
homogeneous coordinates in pixels to output coordinates in
pixels.
"""
if output_type not in [torch.uint8, torch.float32]:
raise ValueError(f"Unsupported output type: {output_type}")

x_, y_ = self.backend(
x_, y_, mapping = self.backend(
input=x,
input_labels=_empty_tensor if y is None else y,
output_size=output_size or [],
is_float32_output=output_type == torch.float32,
output_mapping=output_mapping
)
if y is None:

if y is None and not output_mapping:
return x_
return x_, y_

outputs = (x_,)
if y is not None:
outputs += (y_,)
if output_mapping:
outputs += (mapping,)

return outputs
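
For context, a minimal usage sketch of the new output_mapping flag exposed by FastAugment.__call__ above; the batch shape, the default constructor arguments and the sample keypoint are illustrative rather than taken from the repository:

import torch
from fast_augment_torch import FastAugment

augment = FastAugment()                                    # default augmentation settings
batch = torch.randint(0, 256, (8, 224, 224, 3), dtype=torch.uint8).cuda()   # NHWC, C=3

# with no labels and output_mapping=True, the call returns (images, mapping)
images, mapping = augment(batch, output_mapping=True)      # mapping: (8, 3, 3) float32

# project an input pixel coordinate into every augmented image
point = torch.tensor([100.0, 50.0, 1.0])                   # homogeneous (x, y, 1)
projected = mapping @ point                                # (8, 3)
xy_out = projected[:, :2] / projected[:, 2:]               # perspective division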
37 changes: 36 additions & 1 deletion pytorch/test.py
@@ -1,7 +1,6 @@
from fast_augment_torch import FastAugment, BYPASS_PARAMS
import numpy
import torch
import tempfile
import unittest


@@ -229,5 +228,41 @@ def test_uint8_vs_float32(self):
self.assertTrue(torch.equal(output_batch_ref, output_batch_test))


class CoordinatesMappingTest(unittest.TestCase):
def test_coordinates_mapping(self):
# generate random batch of zeros with a bright spot at a known position
input_batch = torch.zeros((30, 120, 250, 3), dtype=torch.uint8).cuda()
y, x = 28, 222
input_batch[:, y-2:y+2, x-2:x+2, :] = 255

# perform augmentation
augment = FastAugment(gamma_corr=0,
brightness=0,
hue=0,
saturation=0,
mixup=0,
cutout=0,
translation=0.1,
rotation=30,
scale=0.2,
perspective=15,
flip_horizontally=True,
flip_vertically=True,
prescale=2.0)
output_batch, mappings = augment(input_batch,
output_type=torch.uint8,
output_mapping=True,
output_size=(400, 400))

# get coordinates of the spot in the augmented images
coords = torch.matmul(mappings, torch.tensor([x, y, 1], dtype=torch.float32).t())
coords = (coords[:, :2] / coords[:, 2:3]).round().to(torch.int32).numpy()

# make sure it is in the output images
for image, (x, y) in zip(output_batch, coords):
if x >= 0 and x < image.shape[-2] and y >= 0 and y < image.shape[-3]:
self.assertEqual(image[y, x, 0], 255)


if __name__ == "__main__":
unittest.main()
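
Since the mapping gives per-image homographies in pixel coordinates, it can also carry bounding boxes or keypoints through the augmentation (the object-detection item mentioned in the README roadmap). The helper below is a hypothetical sketch layered on top of the returned mapping, not code from this repository:

import torch

def transform_boxes(boxes_xyxy, mapping):
    # boxes_xyxy: (N, 4) boxes in input pixel coordinates; mapping: a single 3x3 homography
    x1, y1, x2, y2 = boxes_xyxy.unbind(dim=1)
    ones = torch.ones_like(x1)
    # four corners per box in homogeneous coordinates: (N, 4, 3)
    corners = torch.stack([
        torch.stack([x1, y1, ones], dim=1),
        torch.stack([x2, y1, ones], dim=1),
        torch.stack([x2, y2, ones], dim=1),
        torch.stack([x1, y2, ones], dim=1),
    ], dim=1)
    projected = corners @ mapping.T                # row vectors times H^T == H times column vectors
    xy = projected[..., :2] / projected[..., 2:]   # perspective division
    # re-fit axis-aligned boxes around the projected corners
    return torch.cat([xy.amin(dim=1), xy.amax(dim=1)], dim=1)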
93 changes: 89 additions & 4 deletions src/kernel_base.hpp
@@ -79,7 +79,9 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
protected:
std::default_random_engine rnd;

KernelBase(): rnd(randomDevice()) {}
KernelBase() : rnd(randomDevice())
{
}

static inline void reportCudaError(cudaError_t status, const std::string &message)
{
@@ -119,6 +121,7 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
* in host memory
* @param outputLabelsPtr Pointer to the output class probabilities tensor
* in host memory
* @param outputMappingPtr Pointer to the output homography tensor in host memory
* @param batchSize Batch size; 0 if 3-dimensional input tensor is
* given
* @param inputHeight Input batch height in pixels
@@ -132,8 +135,9 @@
*/
template <typename in_t, typename out_t>
void run(const Settings &settings, const in_t *inputPtr, out_t *outputPtr, const float *inputLabelsPtr,
float *outputLabelsPtr, int64_t batchSize, int64_t inputHeight, int64_t inputWidth, int64_t outputHeight,
int64_t outputWidth, int64_t numClasses, cudaStream_t stream, BufferAllocationArgs... allocationArgs)
float *outputLabelsPtr, float *outputMappingPtr, int64_t batchSize, int64_t inputHeight,
int64_t inputWidth, int64_t outputHeight, int64_t outputWidth, int64_t numClasses, cudaStream_t stream,
BufferAllocationArgs... allocationArgs)
{
// correct batchSize value (can be zero if input is a 3-dim tensor)
const bool isBatch = batchSize > 0;
@@ -278,6 +282,86 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
outLabel[i] = (1 - f) * inLabel[i] + f * mixLabel[i];
}
}

// fill output mapping tensor
if (outputMappingPtr)
{
float *ptr = outputMappingPtr;
for (size_t i = 0; i < paramsCpu.size(); ++i, ptr += 9)
{
// compute homography in normalized coordinates following the kernel implementation
const auto &a = paramsCpu[i].geom;
ptr[0] = 2.0f * (a[1][1] * a[2][2] - a[1][2] * a[2][1]);
ptr[1] = -2.0f * (a[1][0] * a[2][2] - a[1][2] * a[2][0]);
ptr[2] = a[2][2] * (a[1][0] - a[1][1]) + a[1][2] * (a[2][1] - a[2][0]);

ptr[3] = -2.0f * (a[0][1] * a[2][2] - a[0][2] * a[2][1]);
ptr[4] = 2.0f * (a[0][0] * a[2][2] - a[0][2] * a[2][0]);
ptr[5] = a[2][2] * (a[0][1] - a[0][0]) + a[0][2] * (a[2][0] - a[2][1]);

ptr[6] = 2.0f * (a[0][1] * a[1][2] - a[0][2] * a[1][1]);
ptr[7] = -2.0f * (a[0][0] * a[1][2] - a[0][2] * a[1][0]);
ptr[8] =
2.0f * (a[0][0] * a[1][1] * a[2][2] - a[0][0] * a[1][2] * a[2][1] - a[0][1] * a[1][0] * a[2][2] +
a[0][1] * a[1][2] * a[2][0] + a[0][2] * a[1][0] * a[2][1] - a[0][2] * a[1][1] * a[2][0]) +
a[0][2] * (a[1][1] - a[1][0]) + a[1][2] * (a[0][0] - a[0][1]);

// take into account flipping
if (paramsCpu[i].flags & FLAG_HORIZONTAL_FLIP)
{
ptr[2] += ptr[0];
ptr[0] = -ptr[0];
ptr[5] += ptr[3];
ptr[3] = -ptr[3];
ptr[8] += ptr[6];
ptr[6] = -ptr[6];
}

if (paramsCpu[i].flags & FLAG_VERTICAL_FLIP)
{
ptr[2] += ptr[1];
ptr[1] = -ptr[1];
ptr[5] += ptr[4];
ptr[4] = -ptr[4];
ptr[8] += ptr[7];
ptr[7] = -ptr[7];
}

// use input pixel coordinates
ptr[0] /= inputWidth;
ptr[1] /= inputHeight;
ptr[2] += 0.5f * (ptr[0] + ptr[1]);
ptr[3] /= inputWidth;
ptr[4] /= inputHeight;
ptr[5] += 0.5f * (ptr[3] + ptr[4]);
ptr[6] /= inputWidth;
ptr[7] /= inputHeight;
ptr[8] += 0.5f * (ptr[6] + ptr[7]);

// take into account the random translation
const float *translation = paramsCpu[i].translation;
ptr[0] -= (translation[0] - 0.5f) * ptr[6];
ptr[1] -= (translation[0] - 0.5f) * ptr[7];
ptr[2] -= (translation[0] - 0.5f) * ptr[8];
ptr[3] -= (translation[1] - 0.5f) * ptr[6];
ptr[4] -= (translation[1] - 0.5f) * ptr[7];
ptr[5] -= (translation[1] - 0.5f) * ptr[8];

// use output pixel coordinates
ptr[0] *= outputWidth;
ptr[1] *= outputWidth;
ptr[2] *= outputWidth;
ptr[3] *= outputHeight;
ptr[4] *= outputHeight;
ptr[5] *= outputHeight;
ptr[0] -= 0.5f * ptr[6];
ptr[1] -= 0.5f * ptr[7];
ptr[2] -= 0.5f * ptr[8];
ptr[3] -= 0.5f * ptr[6];
ptr[4] -= 0.5f * ptr[7];
ptr[5] -= 0.5f * ptr[8];
}
}
}

public:
@@ -286,7 +370,8 @@ template <class TempGPUBuffer, typename... BufferAllocationArgs> class KernelBas
*
* @param seed the seed value
*/
void setRandomSeed(int seed) {
void setRandomSeed(int seed)
{
rnd.seed(seed);
}
};
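
A side note on the coordinate-space handling in the mapping code above: the per-column divisions by the input size and the per-row multiplications by the output size amount to sandwiching the normalized-coordinates homography between two coordinate-change matrices. The sketch below restates that equivalence in PyTorch, assuming the convention the code suggests (pixel index i corresponds to normalized coordinate (i + 0.5) / size); it is an illustration only and leaves out the random-translation adjustment handled separately in KernelBase::run:

import torch

def to_pixel_coordinates(h_norm, in_w, in_h, out_w, out_h):
    # input pixel coordinates -> normalized input coordinates: x_n = (x_px + 0.5) / W
    pix_to_norm = torch.tensor([[1.0 / in_w, 0.0, 0.5 / in_w],
                                [0.0, 1.0 / in_h, 0.5 / in_h],
                                [0.0, 0.0, 1.0]])
    # normalized output coordinates -> output pixel coordinates: x_px = x_n * W - 0.5
    norm_to_pix = torch.tensor([[float(out_w), 0.0, -0.5],
                                [0.0, float(out_h), -0.5],
                                [0.0, 0.0, 1.0]])
    # same effect as the element-wise row/column scalings applied to the mapping above
    return norm_to_pix @ h_norm @ pix_to_norm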
20 changes: 15 additions & 5 deletions src/pytorch.cpp
@@ -140,7 +140,8 @@ class TorchKernel : public fastaugment::KernelBase<TorchTempGPUBuffer, c10::cuda
}

std::vector<torch::Tensor> operator()(const torch::Tensor &input, const torch::Tensor &labels,
const std::vector<int64_t> &outputSize, bool isFloat32Output)
const std::vector<int64_t> &outputSize, bool isFloat32Output,
bool outputMapping)
{
// check output size
if (!outputSize.empty() && outputSize.size() != 2)
@@ -203,11 +204,20 @@
torch::Tensor output = torch::empty(outputShape, outputOptions);
torch::Tensor outputLabels = torch::empty_like(labels);

torch::Tensor mapping;
if (outputMapping)
{
auto opts = torch::TensorOptions().dtype(torch::kFloat32);
mapping = torch::empty({batchSize, 3, 3}, opts);
}
auto outputMappingPtr = outputMapping ? mapping.expect_contiguous()->data_ptr<float>() : nullptr;

// launch the kernel
launchKernel(input, output, inputLabelsPtr, outputLabels.data_ptr<float>(), batchSize, inputHeight, inputWidth,
outputHeight, outputWidth, noLabels ? 0 : labels.size(1), stream.stream(), stream);
launchKernel(input, output, inputLabelsPtr, outputLabels.data_ptr<float>(), outputMappingPtr, batchSize,
inputHeight, inputWidth, outputHeight, outputWidth, noLabels ? 0 : labels.size(1), stream.stream(),
stream);

return {output, outputLabels};
return {output, outputLabels, mapping};
}
};

@@ -227,5 +237,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, module)
.def("set_seed", &TorchKernel::setRandomSeed, py::arg("seed"))

.def("__call__", &TorchKernel::operator(), py::arg("input"), py::arg("input_labels"), py::arg("output_size"),
py::arg("is_float32_output"));
py::arg("is_float32_output"), py::arg("output_mapping") = false);
}
4 changes: 2 additions & 2 deletions src/tensorflow.cpp
@@ -214,8 +214,8 @@ class FastAugmentTFOpKernel : public OpKernel,
{
fastaugment::KernelBase<TFTempGPUBuffer, OpKernelContext *>::run<in_t, out_t>(
*this, inputTensor.flat<in_t>().data(), outputTensor->flat<out_t>().data(), inputLabelsPtr,
outputLabelsTensor->flat<float>().data(), batchSize, inputHeight, inputWidth, outputHeight, outputWidth,
noLabels ? 0 : labelsShape.dim_size(1), stream, context);
outputLabelsTensor->flat<float>().data(), nullptr, batchSize, inputHeight, inputWidth, outputHeight,
outputWidth, noLabels ? 0 : labelsShape.dim_size(1), stream, context);
}
catch (std::exception &ex)
{
