The KdTree implementation at https://github.com/Voidious/Diamond/blob/master/ags/utils/dataStructures/trees/thirdGenKD/KdTree.java
has a nearest-neighbour search. How hard would it be to port that into JTS?
I tried nearest-neighbour queries with STRtree, but query performance was terrible compared to that KdTree (on 500k 2D points).
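
For reference, the kind of search I mean looks roughly like the sketch below. It is a minimal 2D k-d tree written from scratch for illustration. It is not the Diamond code and not an existing JTS API, and the class name `KdNnSketch` and its methods are made up. It builds a balanced tree by median split and then does a branch-and-bound nearest-neighbour lookup, skipping the far subtree whenever the splitting line is further away than the best match found so far.

```java
import java.util.Arrays;
import java.util.Comparator;

/** Minimal 2D k-d tree with nearest-neighbour search (hypothetical sketch, not JTS API). */
public class KdNnSketch {
    private static final class Node {
        final double[] pt;   // {x, y}
        Node left, right;
        Node(double[] pt) { this.pt = pt; }
    }

    private final Node root;

    public KdNnSketch(double[][] points) {
        double[][] copy = points.clone(); // shallow copy so the caller's array order is untouched
        root = build(copy, 0, copy.length, 0);
    }

    // Build a balanced tree by splitting on the median, alternating x (axis 0) and y (axis 1).
    private static Node build(double[][] pts, int lo, int hi, int depth) {
        if (lo >= hi) return null;
        final int axis = depth % 2;
        Arrays.sort(pts, lo, hi, Comparator.comparingDouble((double[] p) -> p[axis]));
        int mid = (lo + hi) >>> 1;
        Node n = new Node(pts[mid]);
        n.left = build(pts, lo, mid, depth + 1);
        n.right = build(pts, mid + 1, hi, depth + 1);
        return n;
    }

    /** Returns the stored point closest to (x, y), or null if the tree is empty. */
    public double[] nearest(double x, double y) {
        if (root == null) return null;
        return nearest(root, new double[] {x, y}, 0, root.pt);
    }

    private static double[] nearest(Node n, double[] q, int depth, double[] best) {
        if (n == null) return best;
        if (dist2(n.pt, q) < dist2(best, q)) best = n.pt;
        int axis = depth % 2;
        double diff = q[axis] - n.pt[axis];
        Node near = diff < 0 ? n.left : n.right;
        Node far = diff < 0 ? n.right : n.left;
        best = nearest(near, q, depth + 1, best);
        // Only descend the far side if the splitting line is closer than the best match so far.
        if (diff * diff < dist2(best, q)) best = nearest(far, q, depth + 1, best);
        return best;
    }

    // Squared Euclidean distance; avoids the sqrt since only comparisons are needed.
    private static double dist2(double[] a, double[] b) {
        double dx = a[0] - b[0], dy = a[1] - b[1];
        return dx * dx + dy * dy;
    }
}
```

Usage would be something like `new KdNnSketch(points).nearest(x, y)`, with `points` as an array of `{x, y}` pairs. A port into JTS would presumably work on `Coordinate` objects and hang off the existing index classes instead, but that part is guesswork on my side.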