@@ -208,38 +208,37 @@ _png_module::read_png(const Py::Tuple& args) {
208208
209209 png_init_io (png_ptr, fp);
210210 png_set_sig_bytes (png_ptr, 8 );
211-
212211 png_read_info (png_ptr, info_ptr);
213212
214213 png_uint_32 width = info_ptr->width ;
215214 png_uint_32 height = info_ptr->height ;
216- bool do_gray_conversion = (info_ptr->bit_depth < 8 &&
217- info_ptr->color_type == PNG_COLOR_TYPE_GRAY);
218215
219216 int bit_depth = info_ptr->bit_depth ;
220- if (bit_depth == 16 ) {
221- png_set_strip_16 (png_ptr);
222- } else if (bit_depth < 8 ) {
217+
218+ // Unpack 1, 2, and 4-bit images
219+ if (bit_depth < 8 )
223220 png_set_packing (png_ptr);
224- }
225221
226- // convert misc color types to rgb for simplicity
227- if (info_ptr->color_type == PNG_COLOR_TYPE_GRAY ||
228- info_ptr->color_type == PNG_COLOR_TYPE_GRAY_ALPHA) {
229- png_set_gray_to_rgb (png_ptr);
230- } else if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE) {
222+ // If sig bits are set, shift data
223+ png_color_8p sig_bit;
224+ if ((info_ptr->color_type != PNG_COLOR_TYPE_PALETTE) && png_get_sBIT (png_ptr, info_ptr, &sig_bit))
225+ png_set_shift (png_ptr, sig_bit);
226+
227+ // Convert big endian to little
228+ if (bit_depth == 16 )
229+ png_set_swap (png_ptr);
230+
231+ // Convert palletes to full RGB
232+ if (info_ptr->color_type == PNG_COLOR_TYPE_PALETTE)
231233 png_set_palette_to_rgb (png_ptr);
232- }
234+
235+ // If there's an alpha channel convert gray to RGB
236+ if (info_ptr->color_type == PNG_COLOR_TYPE_GRAY_ALPHA)
237+ png_set_gray_to_rgb (png_ptr);
233238
234239 png_set_interlace_handling (png_ptr);
235240 png_read_update_info (png_ptr, info_ptr);
236241
237- bool rgba = info_ptr->color_type == PNG_COLOR_TYPE_RGBA;
238- if ( (info_ptr->color_type != PNG_COLOR_TYPE_RGB) && !rgba) {
239- std::cerr << " Found color type " << (int )info_ptr->color_type << std::endl;
240- throw Py::RuntimeError (" _image_module::readpng: cannot handle color_type" );
241- }
242-
243242 /* read file */
244243 if (setjmp (png_jmpbuf (png_ptr)))
245244 throw Py::RuntimeError (" _image_module::readpng: error during read_image" );
@@ -255,37 +254,36 @@ _png_module::read_png(const Py::Tuple& args) {
255254 npy_intp dimensions[3 ];
256255 dimensions[0 ] = height; // numrows
257256 dimensions[1 ] = width; // numcols
258- dimensions[2 ] = 4 ;
259-
260- PyArrayObject *A = (PyArrayObject *) PyArray_SimpleNew (3 , dimensions, PyArray_FLOAT);
261-
262- if (do_gray_conversion) {
263- float max_value = (float )((1L << bit_depth) - 1 );
264- for (png_uint_32 y = 0 ; y < height; y++) {
265- png_byte* row = row_pointers[y];
266- for (png_uint_32 x = 0 ; x < width; x++) {
267- float value = row[x] / max_value;
268- size_t offset = y*A->strides [0 ] + x*A->strides [1 ];
269- *(float *)(A->data + offset + 0 *A->strides [2 ]) = value;
270- *(float *)(A->data + offset + 1 *A->strides [2 ]) = value;
271- *(float *)(A->data + offset + 2 *A->strides [2 ]) = value;
272- *(float *)(A->data + offset + 3 *A->strides [2 ]) = 1 .0f ;
273- }
274- }
275- } else {
276- for (png_uint_32 y = 0 ; y < height; y++) {
277- png_byte* row = row_pointers[y];
278- for (png_uint_32 x = 0 ; x < width; x++) {
279- png_byte* ptr = (rgba) ? &(row[x*4 ]) : &(row[x*3 ]);
280- size_t offset = y*A->strides [0 ] + x*A->strides [1 ];
281- *(float *)(A->data + offset + 0 *A->strides [2 ]) = (float )(ptr[0 ]/255.0 );
282- *(float *)(A->data + offset + 1 *A->strides [2 ]) = (float )(ptr[1 ]/255.0 );
283- *(float *)(A->data + offset + 2 *A->strides [2 ]) = (float )(ptr[2 ]/255.0 );
284- *(float *)(A->data + offset + 3 *A->strides [2 ]) = rgba ? (float )(ptr[3 ]/255.0 ) : 1 .0f ;
285- }
257+ if (info_ptr->color_type & PNG_COLOR_MASK_ALPHA)
258+ dimensions[2 ] = 4 ; // RGBA images
259+ else if (info_ptr->color_type & PNG_COLOR_MASK_COLOR)
260+ dimensions[2 ] = 3 ; // RGB images
261+ else
262+ dimensions[2 ] = 1 ; // Greyscale images
263+ // For gray, return an x by y array, not an x by y by 1
264+ int num_dims = (info_ptr->color_type & PNG_COLOR_MASK_COLOR) ? 3 : 2 ;
265+
266+ double max_value = (1 << ((bit_depth < 8 ) ? 8 : bit_depth)) - 1 ;
267+ PyArrayObject *A = (PyArrayObject *) PyArray_SimpleNew (num_dims, dimensions, PyArray_FLOAT);
268+
269+ for (png_uint_32 y = 0 ; y < height; y++) {
270+ png_byte* row = row_pointers[y];
271+ for (png_uint_32 x = 0 ; x < width; x++) {
272+ size_t offset = y*A->strides [0 ] + x*A->strides [1 ];
273+ if (bit_depth == 16 ) {
274+ png_uint_16* ptr = &reinterpret_cast <png_uint_16*> (row)[x * dimensions[2 ]];
275+ for (png_uint_32 p = 0 ; p < dimensions[2 ]; p++)
276+ *(float *)(A->data + offset + p*A->strides [2 ]) = (float )(ptr[p]) / max_value;
277+ } else {
278+ png_byte* ptr = &(row[x * dimensions[2 ]]);
279+ for (png_uint_32 p = 0 ; p < dimensions[2 ]; p++)
280+ {
281+ *(float *)(A->data + offset + p*A->strides [2 ]) = (float )(ptr[p]) / max_value;
282+ }
283+ }
286284 }
287285 }
288-
286+
289287 // free the png memory
290288 png_read_end (png_ptr, info_ptr);
291289 png_destroy_read_struct (&png_ptr, &info_ptr, png_infopp_NULL);
0 commit comments