-
Notifications
You must be signed in to change notification settings - Fork 0
/
decode_and_crop_jpeg.cpp
158 lines (132 loc) · 5.07 KB
/
decode_and_crop_jpeg.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#include <torch/extension.h>
#include <torch/types.h>

#include <stdio.h>  // must precede jpeglib.h, which uses FILE and size_t
#include <jpeglib.h>
#include <setjmp.h>

#include <cstdint>
#include <cstring>
#include <iostream>
#include <random>
#include <vector>
namespace {
// Bytes handed to libjpeg when a skip request runs past the end of the input.
// NOTE(review): a complete EOI marker is the two bytes 0xFF 0xD9 (libjpeg's own
// jdatasrc.c inserts both); this buffer holds only the JPEG_EOI code byte --
// confirm whether that is intended.
const JOCTET EOI_BUFFER[1] = {JPEG_EOI};

// libjpeg error manager extended with a message buffer and a jump target.
// `pub` must remain the first member: libjpeg only stores a jpeg_error_mgr*
// (cinfo->err), which the callbacks below cast back to this type.
struct torch_jpeg_error_mgr {
  struct jpeg_error_mgr pub;              /* "public" fields */
  char jpegLastErrorMsg[JMSG_LENGTH_MAX]; /* error messages */
  jmp_buf setjmp_buffer;                  /* for return to caller */
};

using torch_jpeg_error_ptr = struct torch_jpeg_error_mgr*;

// error_exit hook: format libjpeg's pending message into jpegLastErrorMsg and
// longjmp to the setjmp point armed by the caller. Never returns.
void torch_jpeg_error_exit(j_common_ptr cinfo) {
  auto* myerr = reinterpret_cast<torch_jpeg_error_ptr>(cinfo->err);
  // The message is only recorded for the caller; it is not printed
  // (output_message is deliberately not invoked).
  (*(cinfo->err->format_message))(cinfo, myerr->jpegLastErrorMsg);
  longjmp(myerr->setjmp_buffer, 1);
}

// In-memory source manager. `pub` must remain the first member so the
// jpeg_source_mgr* stored in cinfo->src can be cast back to this type.
struct torch_jpeg_mgr {
  struct jpeg_source_mgr pub;
  const JOCTET* data;
  size_t len;
};

// Nothing to set up: the whole input is already in memory.
void torch_jpeg_init_source(j_decompress_ptr /*cinfo*/) {}

// The entire input is supplied up front, so libjpeg asking for more data means
// the stream ended early. Report it through the error manager's jump target
// instead of returning.
boolean torch_jpeg_fill_input_buffer(j_decompress_ptr cinfo) {
  auto* myerr = reinterpret_cast<torch_jpeg_error_ptr>(cinfo->err);
  strcpy(myerr->jpegLastErrorMsg, "Image is incomplete or truncated");
  longjmp(myerr->setjmp_buffer, 1);
}

// Advance past `num_bytes` of input. Non-positive counts are ignored (as in
// libjpeg's stock jdatasrc.c); without this guard a negative count would wrap
// to a huge size_t and discard the remainder of the input.
void torch_jpeg_skip_input_data(j_decompress_ptr cinfo, long num_bytes) {
  if (num_bytes <= 0) {
    return;
  }
  auto* src = reinterpret_cast<torch_jpeg_mgr*>(cinfo->src);
  const auto skip = static_cast<size_t>(num_bytes);
  if (src->pub.bytes_in_buffer < skip) {
    // Skipping over all of the remaining data; hand libjpeg EOI_BUFFER.
    src->pub.next_input_byte = EOI_BUFFER;
    src->pub.bytes_in_buffer = sizeof(EOI_BUFFER);
  } else {
    // Skipping over only some of the remaining data.
    src->pub.next_input_byte += skip;
    src->pub.bytes_in_buffer -= skip;
  }
}

// Nothing to release: the input buffer is owned by the caller.
void torch_jpeg_term_source(j_decompress_ptr /*cinfo*/) {}

// Point `cinfo` at the `len` bytes starting at `data`. The buffer is not
// copied, so it must stay alive until decompression is finished. The manager
// struct itself comes from libjpeg's permanent pool on first use and is reused
// on later calls.
void torch_jpeg_set_source_mgr(
    j_decompress_ptr cinfo,
    const unsigned char* data,
    size_t len) {
  if (cinfo->src == nullptr) { // first time: allocate the manager
    cinfo->src = static_cast<struct jpeg_source_mgr*>(
        (*cinfo->mem->alloc_small)(
            reinterpret_cast<j_common_ptr>(cinfo),
            JPOOL_PERMANENT,
            sizeof(torch_jpeg_mgr)));
  }
  auto* src = reinterpret_cast<torch_jpeg_mgr*>(cinfo->src);
  src->pub.init_source = torch_jpeg_init_source;
  src->pub.fill_input_buffer = torch_jpeg_fill_input_buffer;
  src->pub.skip_input_data = torch_jpeg_skip_input_data;
  src->pub.resync_to_restart = jpeg_resync_to_restart; // libjpeg default
  src->pub.term_source = torch_jpeg_term_source;
  // Expose the whole caller buffer at once.
  src->data = reinterpret_cast<const JOCTET*>(data);
  src->len = len;
  src->pub.bytes_in_buffer = len;
  src->pub.next_input_byte = src->data;
}
} // namespace
torch::Tensor decode_and_crop_jpeg(const torch::Tensor& data,
unsigned int crop_y,
unsigned int crop_x,
unsigned int crop_height,
unsigned int crop_width) {
struct jpeg_decompress_struct cinfo;
struct torch_jpeg_error_mgr jerr;
auto datap = data.data_ptr<uint8_t>();
// Setup decompression structure
cinfo.err = jpeg_std_error(&jerr.pub);
jerr.pub.error_exit = torch_jpeg_error_exit;
/* Establish the setjmp return context for my_error_exit to use. */
setjmp(jerr.setjmp_buffer);
jpeg_create_decompress(&cinfo);
torch_jpeg_set_source_mgr(&cinfo, datap, data.numel());
// read info from header.
jpeg_read_header(&cinfo, TRUE);
int channels = cinfo.num_components;
jpeg_start_decompress(&cinfo);
int stride = crop_width * channels;
auto tensor =
torch::empty({int64_t(crop_height), int64_t(crop_width), channels}, torch::kU8);
auto ptr = tensor.data_ptr<uint8_t>();
unsigned int update_width = crop_width;
jpeg_crop_scanline(&cinfo, &crop_x, &update_width);
jpeg_skip_scanlines(&cinfo, crop_y);
const int offset = (cinfo.output_width - crop_width) * channels;
uint8_t* temp = nullptr;
if(offset > 0) temp = new uint8_t[cinfo.output_width * channels];
while (cinfo.output_scanline < crop_y + crop_height) {
/* jpeg_read_scanlines expects an array of pointers to scanlines.
* Here the array is only one element long, but you could ask for
* more than one scanline at a time if that's more convenient.
*/
if(offset>0){
jpeg_read_scanlines(&cinfo, &temp, 1);
memcpy(ptr, temp + offset, stride);
}
else
jpeg_read_scanlines(&cinfo, &ptr, 1);
ptr += stride;
}
if(offset > 0){
delete[] temp;
temp = nullptr;
}
if (cinfo.output_scanline < cinfo.output_height) {
// Skip the rest of scanlines, required by jpeg_destroy_decompress.
jpeg_skip_scanlines(&cinfo,
cinfo.output_height - crop_y - crop_height);
}
jpeg_finish_decompress(&cinfo);
jpeg_destroy_decompress(&cinfo);
return tensor.permute({2, 0, 1});
}
// Python binding: registers decode_and_crop_jpeg on the extension module under
// the same name, with the function name doubling as its docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, ext) {
  const char* const doc = "decode_and_crop_jpeg";
  ext.def("decode_and_crop_jpeg", &decode_and_crop_jpeg, doc);
}