diff --git a/src/index.tsx b/src/index.tsx index e1ce255..a989a30 100644 --- a/src/index.tsx +++ b/src/index.tsx @@ -2,13 +2,22 @@ import * as React from 'react' import { useCallback, useEffect } from 'react' import { useBeforeUnload, useLocation, useNavigation } from 'react-router-dom' +type Direction = "vertical" | "horizontal"; +type ScrollAttribute = "scrollTop" | "scrollLeft"; +const DIRECTION: { [direction in Direction]: ScrollAttribute } = { + vertical: 'scrollTop', + horizontal: 'scrollLeft', +}; + export function ElementScrollRestoration({ elementQuery, + direction = "vertical", ...props -}: { elementQuery: string } & React.HTMLProps) { +}: { elementQuery: string; direction?: Direction } & React.HTMLProps) { const STORAGE_KEY = `position:${elementQuery}` const navigation = useNavigation() const location = useLocation() + const scrollAttribute = DIRECTION[direction] const updatePositions = useCallback(() => { const element = document.querySelector(elementQuery) @@ -26,7 +35,7 @@ export function ElementScrollRestoration({ } const newPositions = { ...positions, - [location.key]: element.scrollTop, + [location.key]: element[scrollAttribute], } sessionStorage.setItem(STORAGE_KEY, JSON.stringify(newPositions)) }, [STORAGE_KEY, elementQuery, location.key]) @@ -39,9 +48,9 @@ export function ElementScrollRestoration({ const positions = JSON.parse( sessionStorage.getItem(STORAGE_KEY) || '{}', ) as any - const storedY = positions[window.history.state.key] - if (typeof storedY === 'number') { - element.scrollTop = storedY + const stored = positions[window.history.state.key] + if (typeof stored === 'number') { + element[scrollAttribute] = stored } } catch (error: unknown) { console.error(error) @@ -56,7 +65,7 @@ export function ElementScrollRestoration({ updatePositions() }) - function restoreScroll(storageKey: string, elementQuery: string) { + function restoreScroll(storageKey: string, elementQuery: string, scrollAttribute: ScrollAttribute) { const element = document.querySelector(elementQuery) if (!element) { console.warn(`Element not found: ${elementQuery}. Cannot restore scroll.`) @@ -70,9 +79,9 @@ export function ElementScrollRestoration({ const positions = JSON.parse( sessionStorage.getItem(storageKey) || '{}', ) as any - const storedY = positions[window.history.state.key] - if (typeof storedY === 'number') { - element.scrollTop = storedY + const stored = positions[window.history.state.key] + if (typeof stored === 'number') { + element[scrollAttribute] = stored } } catch (error: unknown) { console.error(error) @@ -86,7 +95,7 @@ export function ElementScrollRestoration({ dangerouslySetInnerHTML={{ __html: `(${restoreScroll})(${JSON.stringify( STORAGE_KEY, - )}, ${JSON.stringify(elementQuery)})`, + )}, ${JSON.stringify(elementQuery)}, ${JSON.stringify(scrollAttribute)})`, }} /> )