@@ -26,7 +26,7 @@ function FocusLock(
26
26
{ className, disabled, autoFocus, restoreFocus, children } : FocusLockProps ,
27
27
ref : React . Ref < FocusLockRef >
28
28
) {
29
- const returnFocusToRef = useRef < HTMLOrSVGElement | null > ( null ) ;
29
+ const restoreFocusTargetRef = useRef < HTMLOrSVGElement | null > ( null ) ;
30
30
const containerRef = useRef < HTMLDivElement | null > ( null ) ;
31
31
32
32
const focusFirst = ( ) => {
@@ -44,21 +44,25 @@ function FocusLock(
44
44
// Captures focus when `autoFocus` is set, and the component is mounted or
45
45
// `disabled` changes from true to false.
46
46
useEffect ( ( ) => {
47
+ const assignRestoreFocusTarget = ( ) => {
48
+ if ( document . activeElement && ! containerRef . current ?. contains ( document . activeElement as Node ) ) {
49
+ restoreFocusTargetRef . current = document . activeElement as unknown as HTMLOrSVGElement ;
50
+ }
51
+ } ;
47
52
if ( autoFocus && ! disabled ) {
48
- returnFocusToRef . current = document . activeElement as HTMLOrSVGElement | null ;
53
+ assignRestoreFocusTarget ( ) ;
49
54
focusFirst ( ) ;
50
55
}
51
56
} , [ autoFocus , disabled ] ) ;
52
57
53
- // Restore focus if `restoreFocus` is set, and `disabled` changes from false
54
- // to true.
58
+ // Restore focus if `restoreFocus` is set, and `disabled` changes from false to true.
55
59
const [ previouslyDisabled , setPreviouslyDisabled ] = useState ( ! ! disabled ) ;
56
60
useEffect ( ( ) => {
57
61
if ( previouslyDisabled !== ! ! disabled ) {
58
62
setPreviouslyDisabled ( ! ! disabled ) ;
59
63
if ( restoreFocus && disabled ) {
60
- returnFocusToRef . current ?. focus ( ) ;
61
- returnFocusToRef . current = null ;
64
+ restoreFocusTargetRef . current ?. focus ( ) ;
65
+ restoreFocusTargetRef . current = null ;
62
66
}
63
67
}
64
68
} , [ previouslyDisabled , disabled , restoreFocus ] ) ;
@@ -68,8 +72,8 @@ function FocusLock(
68
72
const restoreFocusHandler = useCallback (
69
73
( elem : HTMLDivElement | null ) => {
70
74
if ( elem === null && restoreFocus ) {
71
- returnFocusToRef . current ?. focus ( ) ;
72
- returnFocusToRef . current = null ;
75
+ restoreFocusTargetRef . current ?. focus ( ) ;
76
+ restoreFocusTargetRef . current = null ;
73
77
}
74
78
} ,
75
79
[ restoreFocus ]
0 commit comments