diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 4dbca868bc7..48da68fe7e3 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -2148,7 +2148,8 @@ extern "C" { }; enum ggml_scale_flag { - GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8) + GGML_SCALE_FLAG_ALIGN_CORNERS = (1 << 8), + GGML_SCALE_FLAG_ANTIALIAS = (1 << 9), }; // interpolate diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index df28d67fb0b..cd1b5e5b944 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2500,6 +2500,9 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten if (op->op_params[0] != GGML_SCALE_MODE_NEAREST) { return false; } + if (op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS) { + return false; + } return true; } case GGML_OP_POOL_2D: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 2745fc54e15..608e82af69f 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7420,6 +7420,65 @@ static void ggml_compute_forward_upscale_f32( } } } + } else if (mode == GGML_SCALE_MODE_BILINEAR && (mode_flags & GGML_SCALE_FLAG_ANTIALIAS)) { + // Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) + // https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp + auto triangle_filter = [](float x) -> float { + return std::max(1.0f - fabsf(x), 0.0f); + }; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = std::max(1.0f, 1.0f / sf1); + const float invscale1 = 1.0f / support1; + const float support0 = std::max(1.0f, 1.0f / sf0); + const float invscale0 = 1.0f / support0; + + for (int64_t i3 = 0; i3 < ne3; i3++) { + const int64_t i03 = i3 / sf3; + for (int64_t i2 = ith; i2 < ne2; i2 += nth) { + const int64_t i02 = i2 / sf2; + for (int64_t i1 = 0; i1 < ne1; i1++) { + const float y = ((float) i1 + pixel_offset) / sf1; + for (int64_t i0 = 0; i0 < ne0; i0++) { + const float x = ((float) i0 + pixel_offset) / sf0; + + // the range of source pixels that contribute + const int64_t x_min = std::max(x - support0 + pixel_offset, 0); + const int64_t x_max = std::min(x + support0 + pixel_offset, ne00); + const int64_t y_min = std::max(y - support1 + pixel_offset, 0); + const int64_t y_max = std::min(y + support1 + pixel_offset, ne01); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = *(const float *)((const char *)src0->data + sx*nb00 + sy*nb01 + i02*nb02 + i03*nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + float * dst_ptr = (float *)((char *)dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3); + *dst_ptr = val; + } + } + } + } } else if (mode == GGML_SCALE_MODE_BILINEAR) { for (int64_t i3 = 0; i3 < ne3; i3++) { const int64_t i03 = i3 / sf3; diff --git a/ggml/src/ggml-cuda/upscale.cu b/ggml/src/ggml-cuda/upscale.cu index 687c669304d..6bdf3cd996b 100644 --- a/ggml/src/ggml-cuda/upscale.cu +++ b/ggml/src/ggml-cuda/upscale.cu @@ -81,6 +81,76 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst, dst[index] = result; } +// Similar to F.interpolate(..., mode="bilinear", align_corners=False, antialias=True) +// https://github.com/pytorch/pytorch/blob/8871ff29b743948d1225389d5b7068f37b22750b/aten/src/ATen/native/cpu/UpSampleKernel.cpp +static __global__ void upscale_f32_bilinear_antialias(const float * src0, float * dst, + const int nb00, const int nb01, const int nb02, const int nb03, + const int ne00_src, const int ne01_src, + const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, + const float sf0, const float sf1, const float sf2, const float sf3, + const float pixel_offset) { + const int64_t index = threadIdx.x + blockIdx.x * blockDim.x; + const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst; + + if (index >= dst_total_elements) { + return; + } + + const int i10_dst = index % ne10_dst; + const int i11_dst = (index / ne10_dst) % ne11_dst; + const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst; + const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst); + + const int i02_src = (int)(i12_dst / sf2); + const int i03_src = (int)(i13_dst / sf3); + + const float y = ((float)i11_dst + pixel_offset) / sf1; + const float x = ((float)i10_dst + pixel_offset) / sf0; + + // support and invscale, minimum 1 pixel for bilinear + const float support1 = max(1.0f / sf1, 1.0f); + const float invscale1 = 1.0f / support1; + const float support0 = max(1.0f / sf0, 1.0f); + const float invscale0 = 1.0f / support0; + + // the range of source pixels that contribute + const int64_t x_min = max(int64_t(0), int64_t(x - support0 + pixel_offset)); + const int64_t x_max = min(int64_t(ne00_src), int64_t(x + support0 + pixel_offset)); + const int64_t y_min = max(int64_t(0), int64_t(y - support1 + pixel_offset)); + const int64_t y_max = min(int64_t(ne01_src), int64_t(y + support1 + pixel_offset)); + + // bilinear filter with antialiasing + float val = 0.0f; + float total_weight = 0.0f; + + auto triangle_filter = [](float x) -> float { + return max(1.0f - fabsf(x), 0.0f); + }; + + for (int64_t sy = y_min; sy < y_max; sy++) { + const float weight_y = triangle_filter((sy - y + pixel_offset) * invscale1); + + for (int64_t sx = x_min; sx < x_max; sx++) { + const float weight_x = triangle_filter((sx - x + pixel_offset) * invscale0); + const float weight = weight_x * weight_y; + + if (weight <= 0.0f) { + continue; + } + + const float pixel = *(const float *)((const char *)src0 + sx*nb00 + sy*nb01 + i02_src*nb02 + i03_src*nb03); + val += pixel * weight; + total_weight += weight; + } + } + + if (total_weight > 0.0f) { + val /= total_weight; + } + + dst[index] = val; +} + namespace bicubic_interpolation { // https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm __device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch) @@ -161,11 +231,15 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst, const int ne00_src, const int ne01_src, const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst, const float sf0, const float sf1, const float sf2, const float sf3, - const float pixel_offset, cudaStream_t stream) { + const float pixel_offset, bool antialias, cudaStream_t stream) { const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst; const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE; - upscale_f32_bilinear<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + if (antialias) { + upscale_f32_bilinear_antialias<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + } else { + upscale_f32_bilinear<<>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset); + } } static void upscale_f32_bicubic_cuda(const float * x, float * dst, @@ -207,9 +281,10 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { if (mode == GGML_SCALE_MODE_NEAREST) { upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream); } else if (mode == GGML_SCALE_MODE_BILINEAR) { + const bool antialias = (mode_flags & GGML_SCALE_FLAG_ANTIALIAS); upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - sf0, sf1, sf2, sf3, pixel_offset, stream); + sf0, sf1, sf2, sf3, pixel_offset, antialias, stream); } else if (mode == GGML_SCALE_MODE_BICUBIC) { upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index 09b1b503118..3aad16a3ff7 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -894,7 +894,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_POOL_1D: return false; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_POOL_2D: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_PAD: diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index e5302f4550e..277a30d30ed 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -3086,8 +3086,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_UPSCALE: { ggml_scale_mode mode = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & 0xFF); + const bool antialias = (ggml_scale_mode)(ggml_get_op_params_i32(op, 0) & GGML_SCALE_FLAG_ANTIALIAS); return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && - (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR); + (mode == GGML_SCALE_MODE_NEAREST || mode == GGML_SCALE_MODE_BILINEAR) && !antialias; } case GGML_OP_CONV_2D: return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) || diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 3f1bdfb9f1b..e82b51206e2 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4597,7 +4597,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_IM2COL: return true; case GGML_OP_UPSCALE: - return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; + return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 66dd0bfabd2..95966ce1d8e 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -14113,6 +14113,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } return true; case GGML_OP_UPSCALE: + return op->src[0]->type == GGML_TYPE_F32 && !(op->op_params[0] & GGML_SCALE_FLAG_ANTIALIAS); case GGML_OP_ACC: return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_CONCAT: diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b99345a2e93..17cf4d84bb8 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4891,6 +4891,8 @@ static struct ggml_tensor * ggml_interpolate_impl( int64_t ne3, uint32_t mode) { GGML_ASSERT((mode & 0xFF) < GGML_SCALE_MODE_COUNT); + // TODO: implement antialias for modes other than bilinear + GGML_ASSERT(!(mode & GGML_SCALE_FLAG_ANTIALIAS) || (mode & 0xFF) == GGML_SCALE_MODE_BILINEAR); struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, ne0, ne1, ne2, ne3); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 87a61aa1224..9645d0b3909 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -7660,7 +7660,7 @@ static std::vector> make_test_cases_eval() { // test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {i, 2, 1, 3}, rand() % i + 1)); //} - for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC}) { + for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR, GGML_SCALE_MODE_BICUBIC, ggml_scale_mode(GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS)}) { test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode)); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, {512, 512, 3, 2}, 2, mode, true)); test_cases.emplace_back(new test_interpolate(GGML_TYPE_F32, {2, 5, 7, 11}, {5, 7, 11, 13}, mode)); diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 52ea542decc..db477bbbe9b 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -2012,7 +2012,7 @@ struct clip_graph { ggml_tensor * pos_embd = model.position_embeddings; const int height = img.ny / patch_size; const int width = img.nx / patch_size; - const uint32_t mode = GGML_SCALE_MODE_BILINEAR; + const uint32_t mode = GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ANTIALIAS; const int n_per_side = (int)std::sqrt(pos_embd->ne[1]); GGML_ASSERT(pos_embd); @@ -2787,7 +2787,8 @@ struct clip_model_loader { { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); // ref: https://huggingface.co/LiquidAI/LFM2-VL-3B/blob/main/preprocessor_config.json - hparams.set_limit_image_tokens(64, 256); + // config above specifies number of tokens after downsampling, while here it is before, relax lowerbound to 64 + hparams.set_limit_image_tokens(64, 1024); } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: @@ -3737,12 +3738,13 @@ struct img_tool { const int width = inp_size.width; const int height = inp_size.height; + auto round_by_factor = [f = align_size](float x) { return static_cast(std::round(x / static_cast(f))) * f; }; auto ceil_by_factor = [f = align_size](float x) { return static_cast(std::ceil(x / static_cast(f))) * f; }; auto floor_by_factor = [f = align_size](float x) { return static_cast(std::floor(x / static_cast(f))) * f; }; // always align up first - int h_bar = std::max(align_size, ceil_by_factor(height)); - int w_bar = std::max(align_size, ceil_by_factor(width)); + int h_bar = std::max(align_size, round_by_factor(height)); + int w_bar = std::max(align_size, round_by_factor(width)); if (h_bar * w_bar > max_pixels) { const auto beta = std::sqrt(static_cast(height * width) / max_pixels); @@ -4357,7 +4359,8 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const std::array pad_color = {122, 116, 104}; clip_image_u8 resized_img; - img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color); + const bool pad = (ctx->proj_type() != PROJECTOR_TYPE_LFM2); + img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, pad, pad_color); clip_image_f32_ptr res(clip_image_f32_init()); normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index dfad9cd7957..6690bf30046 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -304,6 +304,10 @@ struct mtmd_context { img_beg = "<|im_start|>"; img_end = "<|im_end|>"; + } else if (proj == PROJECTOR_TYPE_LFM2) { + img_beg = "<|image_start|>"; + img_end = "<|image_end|>"; + } }