@@ -921,48 +921,60 @@ transform::detail::checkApplyToOne(Operation *transformOp,
921921  //  Check that the right kind of value was produced.
922922  for  (const  auto  &[ptr, res] :
923923       llvm::zip (partialResult, transformOp->getResults ())) {
924-     if  (ptr.isNull ()) {
925-       return  emitDiag () << " null result #" getResultNumber ()
926-                         << "  produced" 
924+     if  (ptr.isNull ())
925+       continue ;
926+     if  (res.getType ().template  isa <TransformHandleTypeInterface>() &&
927+         !ptr.is <Operation *>()) {
928+       return  emitDiag () << " application of " 
929+                         << "  expected to produce an Operation * for result #" 
930+                         << res.getResultNumber ();
927931    }
928-     if  (ptr. is <Operation * >() &&
929-         !res. getType (). template   isa <TransformHandleTypeInterface >()) {
932+     if  (res. getType (). template   isa <TransformParamTypeInterface >() &&
933+         !ptr. is <Attribute >()) {
930934      return  emitDiag () << " application of " 
931935                        << "  expected to produce an Attribute for result #" 
932936                        << res.getResultNumber ();
933937    }
934-     if  (ptr. is <Attribute >() &&
935-         !res. getType (). template   isa <TransformParamTypeInterface >()) {
938+     if  (res. getType (). template   isa <TransformValueHandleTypeInterface >() &&
939+         !ptr. is <Value >()) {
936940      return  emitDiag () << " application of " 
937-                         << "  expected to produce an Operation *  for result #" 
941+                         << "  expected to produce a Value  for result #" 
938942                        << res.getResultNumber ();
939943    }
940944  }
941945  return  success ();
942946}
943947
948+ template  <typename  T>
949+ static  SmallVector<T> castVector (ArrayRef<transform::MappedValue> range) {
950+   return  llvm::to_vector (llvm::map_range (
951+       range, [](transform::MappedValue value) { return  value.get <T>(); }));
952+ }
953+ 
944954void  transform::detail::setApplyToOneResults (
945955    Operation *transformOp, TransformResults &transformResults,
946956    ArrayRef<ApplyToEachResultList> results) {
957+   SmallVector<SmallVector<MappedValue>> transposed;
958+   transposed.resize (transformOp->getNumResults ());
959+   for  (const  ApplyToEachResultList &partialResults : results) {
960+     if  (llvm::any_of (partialResults,
961+                      [](MappedValue value) { return  value.isNull (); }))
962+       continue ;
963+     assert (transformOp->getNumResults () == partialResults.size () &&
964+            " expected as many partial results as op as results" 
965+     for  (auto  &[i, value] : llvm::enumerate (partialResults))
966+       transposed[i].push_back (value);
967+   }
968+ 
947969  for  (OpResult r : transformOp->getResults ()) {
970+     unsigned  position = r.getResultNumber ();
948971    if  (r.getType ().isa <TransformParamTypeInterface>()) {
949-       auto  params = llvm::to_vector (
950-           llvm::map_range (results, [r](const  ApplyToEachResultList &oneResult) {
951-             return  oneResult[r.getResultNumber ()].get <Attribute>();
952-           }));
953-       transformResults.setParams (r, params);
972+       transformResults.setParams (r,
973+                                  castVector<Attribute>(transposed[position]));
954974    } else  if  (r.getType ().isa <TransformValueHandleTypeInterface>()) {
955-       auto  values = llvm::to_vector (
956-           llvm::map_range (results, [r](const  ApplyToEachResultList &oneResult) {
957-             return  oneResult[r.getResultNumber ()].get <Value>();
958-           }));
959-       transformResults.setValues (r, values);
975+       transformResults.setValues (r, castVector<Value>(transposed[position]));
960976    } else  {
961-       auto  payloads = llvm::to_vector (
962-           llvm::map_range (results, [r](const  ApplyToEachResultList &oneResult) {
963-             return  oneResult[r.getResultNumber ()].get <Operation *>();
964-           }));
965-       transformResults.set (r, payloads);
977+       transformResults.set (r, castVector<Operation *>(transposed[position]));
966978    }
967979  }
968980}
0 commit comments