diff --git a/src/lib/withMessages.js b/src/lib/withMessages.js index cf47823..befead5 100644 --- a/src/lib/withMessages.js +++ b/src/lib/withMessages.js @@ -12,7 +12,7 @@ function enhanceWithMessages(keyPrefix, WrappedComponent) { /** * The enhancer HOC. */ - function Enhancer(props) { + function WithMessages(props) { const messageSourceApi = useMessageSource(keyPrefix); if (process.env.NODE_ENV !== 'production') { const hasOwn = Object.prototype.hasOwnProperty; @@ -28,11 +28,17 @@ function enhanceWithMessages(keyPrefix, WrappedComponent) { ); } - return ; + // eslint-disable-next-line react/prop-types + const { forwardedRef, ...rest } = props; + return ; } - Enhancer.displayName = `WithMessages(${wrappedComponentName})`; - return hoistNonReactStatics(Enhancer, WrappedComponent); + WithMessages.displayName = `WithMessages(${wrappedComponentName})`; + + return hoistNonReactStatics( + React.forwardRef((props, ref) => ), + WrappedComponent, + ); } /** diff --git a/src/lib/withMessages.test.js b/src/lib/withMessages.test.js index f28675a..c9c0574 100644 --- a/src/lib/withMessages.test.js +++ b/src/lib/withMessages.test.js @@ -1,4 +1,4 @@ -import React from 'react'; +import React, { Component } from 'react'; import TestRenderer from 'react-test-renderer'; import { Provider as MessageSourceProvider } from './MessageSourceContext'; import * as MessageSource from './withMessages'; @@ -164,4 +164,42 @@ describe('withMessages', () => { expect(levelOneComponent.children[0]).toBe('Hello World'); expect(levelTwoComponent.children[0]).toBe('Hallo Welt'); }); + + it('supports ref forwarding', () => { + const NestedHOC = MessageSource.withMessages('hello')( + class Nested extends Component { + myMethod = () => { + return 100; + }; + + render() { + const { getMessageWithNamedParams } = this.props; // eslint-disable-line react/prop-types + return {getMessageWithNamedParams('hello.world')}; + } + }, + ); + + // eslint-disable-next-line react/no-multi-comp + class MyFwRefComponent extends Component { + nestedRef = React.createRef(); + + render() { + return ; + } + } + + const renderer = TestRenderer.create( + + + , + ); + + const { root } = renderer; + const fwRefCompInstance = root.findByType(MyFwRefComponent); + + expect(fwRefCompInstance.instance).toBeDefined(); + expect(fwRefCompInstance.instance.nestedRef.current).toBeDefined(); + expect(fwRefCompInstance.instance.nestedRef.current.myMethod).toBeDefined(); + expect(fwRefCompInstance.instance.nestedRef.current.myMethod()).toBe(100); + }); });