diff --git a/prisma/migrations/20230308164237_create_category/migration.sql b/prisma/migrations/20230308164237_create_category/migration.sql new file mode 100644 index 0000000000..eb019a57eb --- /dev/null +++ b/prisma/migrations/20230308164237_create_category/migration.sql @@ -0,0 +1,5 @@ +-- CreateEnum +CREATE TYPE "CategoryType" AS ENUM ('RealisticModels', 'SemiRealisticModels', 'AnimeModels', 'Models', 'Characters', 'Places', 'Concepts', 'Clothings', 'Styles', 'Poses', 'QualityEnhancements', 'Others'); + +-- AlterTable +ALTER TABLE "Model" ADD COLUMN "category" "CategoryType" NOT NULL DEFAULT 'Models'; diff --git a/prisma/schema.prisma b/prisma/schema.prisma index 17235f8aff..b9bdf98794 100644 --- a/prisma/schema.prisma +++ b/prisma/schema.prisma @@ -244,6 +244,21 @@ enum ModelType { Poses } +enum CategoryType { + RealisticModels + SemiRealisticModels + AnimeModels + Models + Characters + Places + Concepts + Clothings + Styles + Poses + QualityEnhancements + Others +} + enum ImportStatus { Pending Processing @@ -295,6 +310,7 @@ model Model { name String description String? type ModelType + category CategoryType @default(Models) createdAt DateTime @default(now()) updatedAt DateTime @updatedAt lastVersionAt DateTime? diff --git a/src/components/Gallery/GalleryResources.tsx b/src/components/Gallery/GalleryResources.tsx index 78946de3d4..ee071b36b3 100644 --- a/src/components/Gallery/GalleryResources.tsx +++ b/src/components/Gallery/GalleryResources.tsx @@ -47,9 +47,14 @@ export function GalleryResources({ imageId, modelId, reviewId }: Props) { {connections.model.name} + + {splitUppercase(connections.model.type)} + + {splitUppercase(connections.model.category)} + {splitUppercase(data.type)} + + {splitUppercase(data.category)} + {data.status !== ModelStatus.Published && ( {data.status} diff --git a/src/components/InfiniteModels/InfiniteModelsFilters.tsx b/src/components/InfiniteModels/InfiniteModelsFilters.tsx index 15a3a5c672..62049b5b95 100644 --- a/src/components/InfiniteModels/InfiniteModelsFilters.tsx +++ b/src/components/InfiniteModels/InfiniteModelsFilters.tsx @@ -1,6 +1,12 @@ import { create } from 'zustand'; import { useEffect } from 'react'; -import { ModelType, MetricTimeframe, CheckpointType, ModelStatus } from '@prisma/client'; +import { + ModelType, + MetricTimeframe, + CheckpointType, + ModelStatus, + CategoryType, +} from '@prisma/client'; import { BrowsingMode, ModelSort } from '~/server/common/enums'; import { SelectMenu } from '~/components/SelectMenu/SelectMenu'; import { splitUppercase } from '~/utils/string-helpers'; @@ -34,6 +40,7 @@ export const useFilters = create<{ setPeriod: (period?: MetricTimeframe) => void; setTypes: (types?: ModelType[]) => void; setCheckpointType: (checkpointType?: CheckpointType) => void; + setCategories: (categories?: CategoryType[]) => void; setBaseModels: (baseModels?: BaseModel[]) => void; setBrowsingMode: (browsingMode?: BrowsingMode, keep?: boolean) => void; setStatus: (status?: ModelStatus[]) => void; @@ -64,6 +71,12 @@ export const useFilters = create<{ !!type ? setCookie('f_ckptType', type) : deleteCookie('f_ckptType'); }); }, + setCategories: (categories) => { + set((state) => { + state.filters.categories = categories; + !!categories?.length ? setCookie('f_categories', categories) : deleteCookie('f_categories'); + }); + }, setBaseModels: (baseModels) => { set((state) => { state.filters.baseModels = baseModels; @@ -159,6 +172,8 @@ export function InfiniteModelsFilter() { const defaultBrowsingMode = user?.showNsfw ? BrowsingMode.All : BrowsingMode.SFW; const setTypes = useFilters((state) => state.setTypes); const types = useFilters((state) => state.filters.types ?? cookies.types ?? []); + const setCategories = useFilters((state) => state.setCategories); + const categories = useFilters((state) => state.filters.categories ?? cookies.categories ?? []); const setStatus = useFilters((state) => state.setStatus); const status = useFilters((state) => state.filters.status ?? cookies.status ?? []); const setBaseModels = useFilters((state) => state.setBaseModels); @@ -178,12 +193,14 @@ export function InfiniteModelsFilter() { const filterLength = types.length + + categories.length + baseModels.length + status.length + (showNSFWToggle && browsingMode !== defaultBrowsingMode ? 1 : 0) + (showCheckpointType && checkpointType !== 'all' ? 1 : 0); const handleClear = () => { setTypes([]); + setCategories([]); setBaseModels([]); setStatus([]); setBrowsingMode(defaultBrowsingMode); @@ -275,6 +292,20 @@ export function InfiniteModelsFilter() { ))} + + setCategories(categories)} + multiple + my={4} + > + {Object.values(CategoryType).map((cat, index) => ( + + {splitUppercase(cat)} + + ))} + {showCheckpointType ? ( <> diff --git a/src/components/Model/ModelForm/ModelForm.tsx b/src/components/Model/ModelForm/ModelForm.tsx index 39cafc1611..7bd43bf413 100644 --- a/src/components/Model/ModelForm/ModelForm.tsx +++ b/src/components/Model/ModelForm/ModelForm.tsx @@ -20,6 +20,7 @@ import { Model, ModelStatus, ModelType, + CategoryType, TagTarget, } from '@prisma/client'; import { openConfirmModal } from '@mantine/modals'; @@ -185,6 +186,7 @@ export function ModelForm({ model }: Props) { allowNoCredit: model?.allowNoCredit ?? true, allowDifferentLicense: model?.allowDifferentLicense ?? true, type: model?.type ?? ModelType.Checkpoint, + category: CategoryType.Models, status: model?.status ?? ModelStatus.Published, tagsOnModels: model?.tagsOnModels.map(({ tag }) => tag.name) ?? [], modelVersions: model?.modelVersions.map(({ images, files, baseModel, ...version }) => ({ @@ -268,7 +270,12 @@ export function ModelForm({ model }: Props) { }, [tagsOnModels, tags]); const mutating = addMutation.isLoading || updateMutation.isLoading; - const [type, allowDerivatives, status] = form.watch(['type', 'allowDerivatives', 'status']); + const [type, category, allowDerivatives, status] = form.watch([ + 'type', + 'category', + 'allowDerivatives', + 'status', + ]); const acceptsTrainedWords = ['Checkpoint', 'TextualInversion', 'LORA'].includes(type); const isTextualInversion = type === 'TextualInversion'; @@ -498,6 +505,21 @@ export function ModelForm({ model }: Props) { {errors.checkpointType.message} )} + + + ({ + label: splitUppercase(cat), + value: cat, + }))} + withAsterisk + /> + + {errors.category && {errors.category.message}} + ), }, + { + label: 'Category', + value: ( + + + {splitUppercase(model.category)} + + + ), + }, { label: 'Downloads', value: {(model.rank?.downloadCountAllTime ?? 0).toLocaleString()}, diff --git a/src/providers/CookiesProvider.tsx b/src/providers/CookiesProvider.tsx index ff790696c7..49dafbcc00 100644 --- a/src/providers/CookiesProvider.tsx +++ b/src/providers/CookiesProvider.tsx @@ -5,6 +5,7 @@ import { MetricTimeframe, ModelStatus, ModelType, + CategoryType, } from '@prisma/client'; import React, { createContext, useContext } from 'react'; import { z } from 'zod'; @@ -23,6 +24,7 @@ export const modelFilterSchema = z.object({ sort: z.nativeEnum(ModelSort).optional(), period: z.nativeEnum(MetricTimeframe).optional(), types: z.nativeEnum(ModelType).array().optional(), + categories: z.nativeEnum(CategoryType).array().optional(), checkpointType: z.nativeEnum(CheckpointType).optional(), baseModels: z.enum(constants.baseModels).array().optional(), browsingMode: z.nativeEnum(BrowsingMode).optional(), @@ -77,6 +79,7 @@ export function parseCookies( sort: cookies?.['f_sort'], period: cookies?.['f_period'], types: cookies?.['f_types'], + categories: cookies?.['f_categories'], baseModels: cookies?.['f_baseModels'], browsingMode: cookies?.['f_browsingMode'], status: cookies?.['f_status'], @@ -110,6 +113,7 @@ const zodParse = z sort: z.string(), period: z.string(), types: z.string(), + categories: z.string(), baseModels: z.string(), browsingMode: z.string(), status: z.string(), diff --git a/src/server/controllers/model.controller.ts b/src/server/controllers/model.controller.ts index 1b90af1e48..05d260fdeb 100644 --- a/src/server/controllers/model.controller.ts +++ b/src/server/controllers/model.controller.ts @@ -149,6 +149,7 @@ export const getModelsInfiniteHandler = async ({ id: true, name: true, type: true, + category: true, nsfw: true, status: true, createdAt: true, diff --git a/src/server/schema/model.schema.ts b/src/server/schema/model.schema.ts index f53daa592a..bec1b18fe8 100644 --- a/src/server/schema/model.schema.ts +++ b/src/server/schema/model.schema.ts @@ -1,5 +1,6 @@ import { ModelType, + CategoryType, ModelStatus, MetricTimeframe, CommercialUse, @@ -39,6 +40,11 @@ export const getAllModelsSchema = z.object({ .transform((rel) => (!rel ? undefined : Array.isArray(rel) ? rel : [rel])) .optional(), checkpointType: z.nativeEnum(CheckpointType).optional(), + categories: z + .union([z.nativeEnum(CategoryType), z.nativeEnum(CategoryType).array()]) + .optional() + .transform((rel) => (!rel ? undefined : Array.isArray(rel) ? rel : [rel])) + .optional(), baseModels: z .union([z.enum(constants.baseModels), z.enum(constants.baseModels).array()]) .optional() @@ -80,6 +86,7 @@ export const modelSchema = licensingSchema.extend({ name: z.string().min(1, 'Name cannot be empty.'), description: getSanitizedStringSchema().nullish(), type: z.nativeEnum(ModelType), + category: z.nativeEnum(CategoryType), status: z.nativeEnum(ModelStatus), checkpointType: z.nativeEnum(CheckpointType).nullish(), tagsOnModels: z.array(tagSchema).nullish(), diff --git a/src/server/selectors/model.selector.ts b/src/server/selectors/model.selector.ts index 6fade71970..515fdbdc73 100644 --- a/src/server/selectors/model.selector.ts +++ b/src/server/selectors/model.selector.ts @@ -114,6 +114,7 @@ export const modelWithDetailsSelect = (includeNSFW = true, user?: SessionUser) = poi: true, nsfw: true, type: true, + category: true, updatedAt: true, deletedAt: true, status: true, diff --git a/src/server/services/image.service.ts b/src/server/services/image.service.ts index 2b93c8a51a..2309129ad0 100644 --- a/src/server/services/image.service.ts +++ b/src/server/services/image.service.ts @@ -188,6 +188,7 @@ export const getImageConnectionsById = ({ id, modelId, reviewId }: GetImageConne id: true, name: true, type: true, + category: true, rank: { select: { downloadCountAllTime: true, diff --git a/src/server/services/model.service.ts b/src/server/services/model.service.ts index edaaa3a73d..fe7b65373d 100644 --- a/src/server/services/model.service.ts +++ b/src/server/services/model.service.ts @@ -49,6 +49,7 @@ export const getModels = async ({ username, baseModels, types, + categories, sort, period = MetricTimeframe.AllTime, rating, @@ -139,6 +140,7 @@ export const getModels = async ({ : undefined, user: username || user ? { username: username ?? user } : undefined, type: types?.length ? { in: types } : undefined, + category: categories?.length ? { in: categories } : undefined, nsfw: browsingMode === BrowsingMode.All ? undefined