diff --git a/prisma/migrations/20230308164237_create_category/migration.sql b/prisma/migrations/20230308164237_create_category/migration.sql
new file mode 100644
index 0000000000..eb019a57eb
--- /dev/null
+++ b/prisma/migrations/20230308164237_create_category/migration.sql
@@ -0,0 +1,5 @@
+-- CreateEnum
+CREATE TYPE "CategoryType" AS ENUM ('RealisticModels', 'SemiRealisticModels', 'AnimeModels', 'Models', 'Characters', 'Places', 'Concepts', 'Clothings', 'Styles', 'Poses', 'QualityEnhancements', 'Others');
+
+-- AlterTable
+ALTER TABLE "Model" ADD COLUMN "category" "CategoryType" NOT NULL DEFAULT 'Models';
diff --git a/prisma/schema.prisma b/prisma/schema.prisma
index 17235f8aff..b9bdf98794 100644
--- a/prisma/schema.prisma
+++ b/prisma/schema.prisma
@@ -244,6 +244,21 @@ enum ModelType {
Poses
}
+enum CategoryType {
+ RealisticModels
+ SemiRealisticModels
+ AnimeModels
+ Models
+ Characters
+ Places
+ Concepts
+ Clothings
+ Styles
+ Poses
+ QualityEnhancements
+ Others
+}
+
enum ImportStatus {
Pending
Processing
@@ -295,6 +310,7 @@ model Model {
name String
description String?
type ModelType
+ category CategoryType @default(Models)
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
lastVersionAt DateTime?
diff --git a/src/components/Gallery/GalleryResources.tsx b/src/components/Gallery/GalleryResources.tsx
index 78946de3d4..ee071b36b3 100644
--- a/src/components/Gallery/GalleryResources.tsx
+++ b/src/components/Gallery/GalleryResources.tsx
@@ -47,9 +47,14 @@ export function GalleryResources({ imageId, modelId, reviewId }: Props) {
{connections.model.name}
+
+
{splitUppercase(connections.model.type)}
+
+ {splitUppercase(connections.model.category)}
+
{splitUppercase(data.type)}
+
+ {splitUppercase(data.category)}
+
{data.status !== ModelStatus.Published && (
{data.status}
diff --git a/src/components/InfiniteModels/InfiniteModelsFilters.tsx b/src/components/InfiniteModels/InfiniteModelsFilters.tsx
index 15a3a5c672..62049b5b95 100644
--- a/src/components/InfiniteModels/InfiniteModelsFilters.tsx
+++ b/src/components/InfiniteModels/InfiniteModelsFilters.tsx
@@ -1,6 +1,12 @@
import { create } from 'zustand';
import { useEffect } from 'react';
-import { ModelType, MetricTimeframe, CheckpointType, ModelStatus } from '@prisma/client';
+import {
+ ModelType,
+ MetricTimeframe,
+ CheckpointType,
+ ModelStatus,
+ CategoryType,
+} from '@prisma/client';
import { BrowsingMode, ModelSort } from '~/server/common/enums';
import { SelectMenu } from '~/components/SelectMenu/SelectMenu';
import { splitUppercase } from '~/utils/string-helpers';
@@ -34,6 +40,7 @@ export const useFilters = create<{
setPeriod: (period?: MetricTimeframe) => void;
setTypes: (types?: ModelType[]) => void;
setCheckpointType: (checkpointType?: CheckpointType) => void;
+ setCategories: (categories?: CategoryType[]) => void;
setBaseModels: (baseModels?: BaseModel[]) => void;
setBrowsingMode: (browsingMode?: BrowsingMode, keep?: boolean) => void;
setStatus: (status?: ModelStatus[]) => void;
@@ -64,6 +71,12 @@ export const useFilters = create<{
!!type ? setCookie('f_ckptType', type) : deleteCookie('f_ckptType');
});
},
+ setCategories: (categories) => {
+ set((state) => {
+ state.filters.categories = categories;
+ !!categories?.length ? setCookie('f_categories', categories) : deleteCookie('f_categories');
+ });
+ },
setBaseModels: (baseModels) => {
set((state) => {
state.filters.baseModels = baseModels;
@@ -159,6 +172,8 @@ export function InfiniteModelsFilter() {
const defaultBrowsingMode = user?.showNsfw ? BrowsingMode.All : BrowsingMode.SFW;
const setTypes = useFilters((state) => state.setTypes);
const types = useFilters((state) => state.filters.types ?? cookies.types ?? []);
+ const setCategories = useFilters((state) => state.setCategories);
+ const categories = useFilters((state) => state.filters.categories ?? cookies.categories ?? []);
const setStatus = useFilters((state) => state.setStatus);
const status = useFilters((state) => state.filters.status ?? cookies.status ?? []);
const setBaseModels = useFilters((state) => state.setBaseModels);
@@ -178,12 +193,14 @@ export function InfiniteModelsFilter() {
const filterLength =
types.length +
+ categories.length +
baseModels.length +
status.length +
(showNSFWToggle && browsingMode !== defaultBrowsingMode ? 1 : 0) +
(showCheckpointType && checkpointType !== 'all' ? 1 : 0);
const handleClear = () => {
setTypes([]);
+ setCategories([]);
setBaseModels([]);
setStatus([]);
setBrowsingMode(defaultBrowsingMode);
@@ -275,6 +292,20 @@ export function InfiniteModelsFilter() {
))}
+
+ setCategories(categories)}
+ multiple
+ my={4}
+ >
+ {Object.values(CategoryType).map((cat, index) => (
+
+ {splitUppercase(cat)}
+
+ ))}
+
{showCheckpointType ? (
<>
diff --git a/src/components/Model/ModelForm/ModelForm.tsx b/src/components/Model/ModelForm/ModelForm.tsx
index 39cafc1611..7bd43bf413 100644
--- a/src/components/Model/ModelForm/ModelForm.tsx
+++ b/src/components/Model/ModelForm/ModelForm.tsx
@@ -20,6 +20,7 @@ import {
Model,
ModelStatus,
ModelType,
+ CategoryType,
TagTarget,
} from '@prisma/client';
import { openConfirmModal } from '@mantine/modals';
@@ -185,6 +186,7 @@ export function ModelForm({ model }: Props) {
allowNoCredit: model?.allowNoCredit ?? true,
allowDifferentLicense: model?.allowDifferentLicense ?? true,
type: model?.type ?? ModelType.Checkpoint,
+ category: CategoryType.Models,
status: model?.status ?? ModelStatus.Published,
tagsOnModels: model?.tagsOnModels.map(({ tag }) => tag.name) ?? [],
modelVersions: model?.modelVersions.map(({ images, files, baseModel, ...version }) => ({
@@ -268,7 +270,12 @@ export function ModelForm({ model }: Props) {
}, [tagsOnModels, tags]);
const mutating = addMutation.isLoading || updateMutation.isLoading;
- const [type, allowDerivatives, status] = form.watch(['type', 'allowDerivatives', 'status']);
+ const [type, category, allowDerivatives, status] = form.watch([
+ 'type',
+ 'category',
+ 'allowDerivatives',
+ 'status',
+ ]);
const acceptsTrainedWords = ['Checkpoint', 'TextualInversion', 'LORA'].includes(type);
const isTextualInversion = type === 'TextualInversion';
@@ -498,6 +505,21 @@ export function ModelForm({ model }: Props) {
{errors.checkpointType.message}
)}
+
+
+ ({
+ label: splitUppercase(cat),
+ value: cat,
+ }))}
+ withAsterisk
+ />
+
+ {errors.category && {errors.category.message}}
+
),
},
+ {
+ label: 'Category',
+ value: (
+
+
+ {splitUppercase(model.category)}
+
+
+ ),
+ },
{
label: 'Downloads',
value: {(model.rank?.downloadCountAllTime ?? 0).toLocaleString()},
diff --git a/src/providers/CookiesProvider.tsx b/src/providers/CookiesProvider.tsx
index ff790696c7..49dafbcc00 100644
--- a/src/providers/CookiesProvider.tsx
+++ b/src/providers/CookiesProvider.tsx
@@ -5,6 +5,7 @@ import {
MetricTimeframe,
ModelStatus,
ModelType,
+ CategoryType,
} from '@prisma/client';
import React, { createContext, useContext } from 'react';
import { z } from 'zod';
@@ -23,6 +24,7 @@ export const modelFilterSchema = z.object({
sort: z.nativeEnum(ModelSort).optional(),
period: z.nativeEnum(MetricTimeframe).optional(),
types: z.nativeEnum(ModelType).array().optional(),
+ categories: z.nativeEnum(CategoryType).array().optional(),
checkpointType: z.nativeEnum(CheckpointType).optional(),
baseModels: z.enum(constants.baseModels).array().optional(),
browsingMode: z.nativeEnum(BrowsingMode).optional(),
@@ -77,6 +79,7 @@ export function parseCookies(
sort: cookies?.['f_sort'],
period: cookies?.['f_period'],
types: cookies?.['f_types'],
+ categories: cookies?.['f_categories'],
baseModels: cookies?.['f_baseModels'],
browsingMode: cookies?.['f_browsingMode'],
status: cookies?.['f_status'],
@@ -110,6 +113,7 @@ const zodParse = z
sort: z.string(),
period: z.string(),
types: z.string(),
+ categories: z.string(),
baseModels: z.string(),
browsingMode: z.string(),
status: z.string(),
diff --git a/src/server/controllers/model.controller.ts b/src/server/controllers/model.controller.ts
index 1b90af1e48..05d260fdeb 100644
--- a/src/server/controllers/model.controller.ts
+++ b/src/server/controllers/model.controller.ts
@@ -149,6 +149,7 @@ export const getModelsInfiniteHandler = async ({
id: true,
name: true,
type: true,
+ category: true,
nsfw: true,
status: true,
createdAt: true,
diff --git a/src/server/schema/model.schema.ts b/src/server/schema/model.schema.ts
index f53daa592a..bec1b18fe8 100644
--- a/src/server/schema/model.schema.ts
+++ b/src/server/schema/model.schema.ts
@@ -1,5 +1,6 @@
import {
ModelType,
+ CategoryType,
ModelStatus,
MetricTimeframe,
CommercialUse,
@@ -39,6 +40,11 @@ export const getAllModelsSchema = z.object({
.transform((rel) => (!rel ? undefined : Array.isArray(rel) ? rel : [rel]))
.optional(),
checkpointType: z.nativeEnum(CheckpointType).optional(),
+ categories: z
+ .union([z.nativeEnum(CategoryType), z.nativeEnum(CategoryType).array()])
+ .optional()
+ .transform((rel) => (!rel ? undefined : Array.isArray(rel) ? rel : [rel]))
+ .optional(),
baseModels: z
.union([z.enum(constants.baseModels), z.enum(constants.baseModels).array()])
.optional()
@@ -80,6 +86,7 @@ export const modelSchema = licensingSchema.extend({
name: z.string().min(1, 'Name cannot be empty.'),
description: getSanitizedStringSchema().nullish(),
type: z.nativeEnum(ModelType),
+ category: z.nativeEnum(CategoryType),
status: z.nativeEnum(ModelStatus),
checkpointType: z.nativeEnum(CheckpointType).nullish(),
tagsOnModels: z.array(tagSchema).nullish(),
diff --git a/src/server/selectors/model.selector.ts b/src/server/selectors/model.selector.ts
index 6fade71970..515fdbdc73 100644
--- a/src/server/selectors/model.selector.ts
+++ b/src/server/selectors/model.selector.ts
@@ -114,6 +114,7 @@ export const modelWithDetailsSelect = (includeNSFW = true, user?: SessionUser) =
poi: true,
nsfw: true,
type: true,
+ category: true,
updatedAt: true,
deletedAt: true,
status: true,
diff --git a/src/server/services/image.service.ts b/src/server/services/image.service.ts
index 2b93c8a51a..2309129ad0 100644
--- a/src/server/services/image.service.ts
+++ b/src/server/services/image.service.ts
@@ -188,6 +188,7 @@ export const getImageConnectionsById = ({ id, modelId, reviewId }: GetImageConne
id: true,
name: true,
type: true,
+ category: true,
rank: {
select: {
downloadCountAllTime: true,
diff --git a/src/server/services/model.service.ts b/src/server/services/model.service.ts
index edaaa3a73d..fe7b65373d 100644
--- a/src/server/services/model.service.ts
+++ b/src/server/services/model.service.ts
@@ -49,6 +49,7 @@ export const getModels = async ({
username,
baseModels,
types,
+ categories,
sort,
period = MetricTimeframe.AllTime,
rating,
@@ -139,6 +140,7 @@ export const getModels = async ({
: undefined,
user: username || user ? { username: username ?? user } : undefined,
type: types?.length ? { in: types } : undefined,
+ category: categories?.length ? { in: categories } : undefined,
nsfw:
browsingMode === BrowsingMode.All
? undefined