diff --git a/src/lib/withMessages.js b/src/lib/withMessages.js
index cf47823..befead5 100644
--- a/src/lib/withMessages.js
+++ b/src/lib/withMessages.js
@@ -12,7 +12,7 @@ function enhanceWithMessages(keyPrefix, WrappedComponent) {
/**
* The enhancer HOC.
*/
- function Enhancer(props) {
+ function WithMessages(props) {
const messageSourceApi = useMessageSource(keyPrefix);
if (process.env.NODE_ENV !== 'production') {
const hasOwn = Object.prototype.hasOwnProperty;
@@ -28,11 +28,17 @@ function enhanceWithMessages(keyPrefix, WrappedComponent) {
);
}
- return ;
+ // eslint-disable-next-line react/prop-types
+ const { forwardedRef, ...rest } = props;
+ return ;
}
- Enhancer.displayName = `WithMessages(${wrappedComponentName})`;
- return hoistNonReactStatics(Enhancer, WrappedComponent);
+ WithMessages.displayName = `WithMessages(${wrappedComponentName})`;
+
+ return hoistNonReactStatics(
+ React.forwardRef((props, ref) => ),
+ WrappedComponent,
+ );
}
/**
diff --git a/src/lib/withMessages.test.js b/src/lib/withMessages.test.js
index f28675a..c9c0574 100644
--- a/src/lib/withMessages.test.js
+++ b/src/lib/withMessages.test.js
@@ -1,4 +1,4 @@
-import React from 'react';
+import React, { Component } from 'react';
import TestRenderer from 'react-test-renderer';
import { Provider as MessageSourceProvider } from './MessageSourceContext';
import * as MessageSource from './withMessages';
@@ -164,4 +164,42 @@ describe('withMessages', () => {
expect(levelOneComponent.children[0]).toBe('Hello World');
expect(levelTwoComponent.children[0]).toBe('Hallo Welt');
});
+
+ it('supports ref forwarding', () => {
+ const NestedHOC = MessageSource.withMessages('hello')(
+ class Nested extends Component {
+ myMethod = () => {
+ return 100;
+ };
+
+ render() {
+ const { getMessageWithNamedParams } = this.props; // eslint-disable-line react/prop-types
+ return {getMessageWithNamedParams('hello.world')};
+ }
+ },
+ );
+
+ // eslint-disable-next-line react/no-multi-comp
+ class MyFwRefComponent extends Component {
+ nestedRef = React.createRef();
+
+ render() {
+ return ;
+ }
+ }
+
+ const renderer = TestRenderer.create(
+
+
+ ,
+ );
+
+ const { root } = renderer;
+ const fwRefCompInstance = root.findByType(MyFwRefComponent);
+
+ expect(fwRefCompInstance.instance).toBeDefined();
+ expect(fwRefCompInstance.instance.nestedRef.current).toBeDefined();
+ expect(fwRefCompInstance.instance.nestedRef.current.myMethod).toBeDefined();
+ expect(fwRefCompInstance.instance.nestedRef.current.myMethod()).toBe(100);
+ });
});