Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ui): model manager UI pass #5886

Merged
merged 24 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
bd23c83
feat(ui): model manager UI tweaks
psychedelicious Mar 7, 2024
9f4ca32
feat(ui): use stickyscrollable for models list
psychedelicious Mar 7, 2024
b6a7c3a
feat(ui): improved model list styling
psychedelicious Mar 7, 2024
21b5c70
feat(ui): add main model trigger phrases
psychedelicious Mar 7, 2024
0e88fba
fix(mm): only loras and main models get `trigger_phrases`
psychedelicious Mar 7, 2024
f4878fc
chore(ui): typegen
psychedelicious Mar 7, 2024
5a6cec9
fix(ui): missing translation
psychedelicious Mar 7, 2024
07dbbe6
fix(ui): typing issues related to trigger phrase changes
psychedelicious Mar 7, 2024
832133f
fix(ui): do not persist model manager state
psychedelicious Mar 7, 2024
01e4fef
fix(mm): model images reload when changed
psychedelicious Mar 7, 2024
70ae9e3
feat(ui): model header styling
psychedelicious Mar 7, 2024
dc90085
feat(ui): move model save/close buttons to model header
psychedelicious Mar 7, 2024
791e25b
fix(ui): reset model edit form state with new values
psychedelicious Mar 7, 2024
79ed1b6
tweak(ui): use check icon for model save button
psychedelicious Mar 7, 2024
cdc7edd
tweak(ui): update default settings layouts
psychedelicious Mar 7, 2024
127ea01
tweak(ui): style model edit
psychedelicious Mar 7, 2024
ef04d8f
tweak(ui): style trigger phrases
psychedelicious Mar 7, 2024
8d5c6b2
tweak(ui): add colors to base/format badges
psychedelicious Mar 7, 2024
2a074d8
perf(mm): add manual query cache updates for the update model route
psychedelicious Mar 7, 2024
c5d488b
fix(ui): default settings linked incorrectly
psychedelicious Mar 7, 2024
06b4ef2
fix(ui): display trigger phrases for loras in mm editor
psychedelicious Mar 7, 2024
bb2b1b4
fix(ui): clear pending trigger phrase immediately
psychedelicious Mar 7, 2024
09e618c
chore(ui): lint
psychedelicious Mar 7, 2024
aa74c20
docs(mm): update comment about model images
psychedelicious Mar 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from send2trash import send2trash

from invokeai.app.services.invoker import Invoker
from invokeai.app.util.misc import uuid_string
from invokeai.app.util.thumbnails import make_thumbnail

from .model_images_base import ModelImageFileStorageBase
Expand Down Expand Up @@ -56,7 +57,12 @@ def get_url(self, model_key: str) -> str | None:
if not self._validate_path(path):
return

return self._invoker.services.urls.get_model_image_url(model_key)
url = self._invoker.services.urls.get_model_image_url(model_key)

# The image URL never changes, so we must add random query string to it to prevent caching
url += f"?{uuid_string()}"

return url

def delete(self, model_key: str) -> None:
try:
Expand Down
24 changes: 14 additions & 10 deletions invokeai/backend/model_manager/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,6 @@ class ModelConfigBase(BaseModel):
source_api_response: Optional[str] = Field(
description="The original API response from the source, as stringified JSON.", default=None
)
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)
default_settings: Optional[ModelDefaultSettings] = Field(
description="Default settings for this model", default=None
)
Expand Down Expand Up @@ -187,21 +186,24 @@ class DiffusersConfigBase(ModelConfigBase):
repo_variant: Optional[ModelRepoVariant] = ModelRepoVariant.Default


class LoRALyCORISConfig(ModelConfigBase):
class LoRAConfigBase(ModelConfigBase):
type: Literal[ModelType.LoRA] = ModelType.LoRA
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)


class LoRALyCORISConfig(LoRAConfigBase):
"""Model config for LoRA/Lycoris models."""

type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.LyCORIS] = ModelFormat.LyCORIS

@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.LoRA.value}.{ModelFormat.LyCORIS.value}")


class LoRADiffusersConfig(ModelConfigBase):
class LoRADiffusersConfig(LoRAConfigBase):
"""Model config for LoRA/Diffusers models."""

type: Literal[ModelType.LoRA] = ModelType.LoRA
format: Literal[ModelFormat.Diffusers] = ModelFormat.Diffusers

@staticmethod
Expand Down Expand Up @@ -275,10 +277,14 @@ def get_tag() -> Tag:
return Tag(f"{ModelType.TextualInversion.value}.{ModelFormat.EmbeddingFolder.value}")


class MainCheckpointConfig(CheckpointConfigBase):
class MainConfigBase(ModelConfigBase):
type: Literal[ModelType.Main] = ModelType.Main
trigger_phrases: Optional[set[str]] = Field(description="Set of trigger phrases for this model", default=None)


class MainCheckpointConfig(CheckpointConfigBase, MainConfigBase):
"""Model config for main checkpoint models."""

type: Literal[ModelType.Main] = ModelType.Main
variant: ModelVariantType = ModelVariantType.Normal
prediction_type: SchedulerPredictionType = SchedulerPredictionType.Epsilon
upcast_attention: bool = False
Expand All @@ -288,11 +294,9 @@ def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Checkpoint.value}")


class MainDiffusersConfig(DiffusersConfigBase):
class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
"""Model config for main diffusers models."""

type: Literal[ModelType.Main] = ModelType.Main

@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
Expand Down
2 changes: 2 additions & 0 deletions invokeai/frontend/web/public/locales/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,8 @@
"syncModels": "Sync Models",
"syncModelsDesc": "If your models are out of sync with the backend, you can refresh them up using this option. This is generally handy in cases where you add models to the InvokeAI root folder or autoimport directory after the application has booted.",
"triggerPhrases": "Trigger Phrases",
"loraTriggerPhrases": "LoRA Trigger Phrases",
"mainModelTriggerPhrases": "Main Model Trigger Phrases",
"typePhraseHere": "Type phrase here",
"upcastAttention": "Upcast Attention",
"uploadImage": "Upload Image",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import type { ButtonProps } from '@invoke-ai/ui-library';
import { Button } from '@invoke-ai/ui-library';
import { useFeatureStatus } from 'features/system/hooks/useFeatureStatus';
import { memo } from 'react';
import { useTranslation } from 'react-i18next';
import { PiArrowsClockwiseBold } from 'react-icons/pi';

import { useSyncModels } from './useSyncModels';

export const SyncModelsButton = memo((props: Omit<ButtonProps, 'aria-label'>) => {
const { t } = useTranslation();
const { syncModels, isLoading } = useSyncModels();
const isSyncModelEnabled = useFeatureStatus('syncModels').isFeatureEnabled;

if (!isSyncModelEnabled) {
return null;
}

return (
<Button
leftIcon={<PiArrowsClockwiseBold />}
isLoading={isLoading}
onClick={syncModels}
size="sm"
variant="ghost"
{...props}
>
{t('modelManager.syncModels')}
</Button>
);
});

SyncModelsButton.displayName = 'SyncModelsButton';
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,5 @@ export const modelManagerV2PersistConfig: PersistConfig<ModelManagerState> = {
name: modelManagerV2Slice.name,
initialState: initialModelManagerState,
migrate: migrateModelManagerState,
persistDenylist: [],
persistDenylist: ['selectedModelKey', 'selectedModelMode', 'filteredModelType', 'searchTerm'],
};
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import { Box, Button, Flex, Heading } from '@invoke-ai/ui-library';
import { Button, Flex, Heading, Spacer } from '@invoke-ai/ui-library';
import { useAppDispatch } from 'app/store/storeHooks';
import { SyncModelsIconButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsIconButton';
import { SyncModelsButton } from 'features/modelManagerV2/components/SyncModels/SyncModelsButton';
import { setSelectedModelKey } from 'features/modelManagerV2/store/modelManagerV2Slice';
import { useCallback } from 'react';
import { useTranslation } from 'react-i18next';
import { PiPlusBold } from 'react-icons/pi';

import ModelList from './ModelManagerPanel/ModelList';
import { ModelListNavigation } from './ModelManagerPanel/ModelListNavigation';
Expand All @@ -16,20 +17,19 @@ export const ModelManager = () => {
}, [dispatch]);

return (
<Box layerStyle="first" p={3} borderRadius="base" w="50%" h="full">
<Flex w="full" p={3} justifyContent="space-between" alignItems="center">
<Flex gap={2}>
<Heading fontSize="xl">{t('common.modelManager')}</Heading>
<SyncModelsIconButton />
</Flex>
<Button colorScheme="invokeYellow" onClick={handleClickAddModel}>
<Flex flexDir="column" layerStyle="first" p={4} gap={4} borderRadius="base" w="50%" h="full">
<Flex w="full" gap={4} justifyContent="space-between" alignItems="center">
<Heading fontSize="xl">{t('common.modelManager')}</Heading>
<Spacer />
<SyncModelsButton size="sm" />
<Button size="sm" colorScheme="invokeYellow" leftIcon={<PiPlusBold />} onClick={handleClickAddModel}>
{t('modelManager.addModels')}
</Button>
</Flex>
<Box layerStyle="second" p={3} borderRadius="base" w="full" h="full">
<Flex flexDir="column" layerStyle="second" p={4} gap={4} borderRadius="base" w="full" h="full">
<ModelListNavigation />
<ModelList />
</Box>
</Box>
</Flex>
</Flex>
);
};
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import { Badge } from '@invoke-ai/ui-library';
import { MODEL_TYPE_SHORT_MAP } from 'features/parameters/types/constants';
import { memo } from 'react';
import type { BaseModelType } from 'services/api/types';

type Props = {
base: BaseModelType;
};

const BASE_COLOR_MAP: Record<BaseModelType, string> = {
any: 'base',
'sd-1': 'green',
'sd-2': 'teal',
sdxl: 'invokeBlue',
'sdxl-refiner': 'invokeBlue',
};

const ModelBaseBadge = ({ base }: Props) => {
return (
<Badge flexGrow={0} colorScheme={BASE_COLOR_MAP[base]} variant="subtle">
{MODEL_TYPE_SHORT_MAP[base]}
</Badge>
);
};

export default memo(ModelBaseBadge);
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import { Badge } from '@invoke-ai/ui-library';
import { memo } from 'react';
import type { AnyModelConfig } from 'services/api/types';

type Props = {
format: AnyModelConfig['format'];
};

const FORMAT_NAME_MAP: Record<AnyModelConfig['format'], string> = {
diffusers: 'diffusers',
lycoris: 'lycoris',
checkpoint: 'checkpoint',
invokeai: 'internal',
embedding_file: 'embedding',
embedding_folder: 'embedding',
};

const FORMAT_COLOR_MAP: Record<AnyModelConfig['format'], string> = {
diffusers: 'base',
lycoris: 'base',
checkpoint: 'orange',
invokeai: 'base',
embedding_file: 'base',
embedding_folder: 'base',
};

const ModelFormatBadge = ({ format }: Props) => {
return (
<Badge flexGrow={0} colorScheme={FORMAT_COLOR_MAP[format]} variant="subtle">
{FORMAT_NAME_MAP[format]}
</Badge>
);
};

export default memo(ModelFormatBadge);
Original file line number Diff line number Diff line change
@@ -1,24 +1,39 @@
import { Box, Image } from '@invoke-ai/ui-library';
import { Flex, Icon, Image } from '@invoke-ai/ui-library';
import { typedMemo } from 'common/util/typedMemo';
import { PiImage } from 'react-icons/pi';

type Props = {
image_url?: string;
image_url?: string | null;
};

export const MODEL_IMAGE_THUMBNAIL_SIZE = '40px';
const FALLBACK_ICON_SIZE = '24px';

const ModelImage = ({ image_url }: Props) => {
if (!image_url) {
return <Box height="50px" minWidth="50px" />;
return (
<Flex
height={MODEL_IMAGE_THUMBNAIL_SIZE}
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
bg="base.650"
borderRadius="base"
alignItems="center"
justifyContent="center"
>
<Icon color="base.500" as={PiImage} boxSize={FALLBACK_ICON_SIZE} />
</Flex>
);
}

return (
<Image
src={image_url}
objectFit="cover"
objectPosition="50% 50%"
height="50px"
width="50px"
minHeight="50px"
minWidth="50px"
height={MODEL_IMAGE_THUMBNAIL_SIZE}
width={MODEL_IMAGE_THUMBNAIL_SIZE}
minHeight={MODEL_IMAGE_THUMBNAIL_SIZE}
minWidth={MODEL_IMAGE_THUMBNAIL_SIZE}
borderRadius="base"
/>
);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { Flex, Spinner, Text } from '@invoke-ai/ui-library';
import type { EntityState } from '@reduxjs/toolkit';
import { useAppSelector } from 'app/store/storeHooks';
import ScrollableContent from 'common/components/OverlayScrollbars/ScrollableContent';
import { forEach } from 'lodash-es';
import { memo } from 'react';
import { ALL_BASE_MODELS } from 'services/api/constants';
Expand Down Expand Up @@ -73,8 +74,8 @@ const ModelList = () => {
});

return (
<Flex flexDirection="column" p={4}>
<Flex flexDirection="column" maxHeight={window.innerHeight - 130} overflow="scroll">
<ScrollableContent>
<Flex flexDirection="column" w="full" h="full" gap={4}>
{/* Main Model List */}
{isLoadingMainModels && <FetchingModelsLoader loadingMessage="Loading Main..." />}
{!isLoadingMainModels && filteredMainModels.length > 0 && (
Expand Down Expand Up @@ -118,7 +119,7 @@ const ModelList = () => {
<ModelListWrapper title="T2I Adapters" modelList={filteredT2iAdapterModels} key="t2i-adapters" />
)}
</Flex>
</Flex>
</ScrollableContent>
);
};

Expand Down Expand Up @@ -148,7 +149,7 @@ const modelsFilter = <T extends AnyModelConfig>(

const FetchingModelsLoader = memo(({ loadingMessage }: { loadingMessage?: string }) => {
return (
<Flex flexDirection="column" gap={4} borderRadius={4} p={4} bg="base.800">
<Flex flexDirection="column" gap={4} borderRadius="base" p={4} bg="base.800">
<Flex justifyContent="center" alignItems="center" flexDirection="column" p={4} gap={8}>
<Spinner />
<Text variant="subtext">{loadingMessage ? loadingMessage : 'Fetching...'}</Text>
Expand Down

This file was deleted.

Loading
Loading