Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Priority training option #1183

Merged
merged 8 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ function ResourceSelectCard({

const resourceFilter = resources.find((x) => x.type === data.type);
const versions = data.versions.filter((version) => {
if (isTraining && version.baseModel === 'SDXL Lightning') return false;
if (isTraining && !['SD 1.4', 'SD 1.5', 'SDXL 1.0', 'Pony'].includes(version.baseModel))
return false;
if (canGenerate === undefined) return true;
return (
version.canGenerate === canGenerate &&
Expand Down
47 changes: 41 additions & 6 deletions src/components/Training/Form/TrainingImages.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ import {
import { saveAs } from 'file-saver';
import JSZip from 'jszip';
import { isEqual } from 'lodash-es';
import React, { useEffect, useState } from 'react';
import React, { useEffect, useRef, useState } from 'react';
import { dialogStore } from '~/components/Dialog/dialogStore';
import { ImageDropzone } from '~/components/Image/ImageDropzone/ImageDropzone';
import { useSignalContext } from '~/components/Signals/SignalsProvider';
Expand All @@ -57,7 +57,11 @@ import {
useTrainingImageStore,
} from '~/store/training.store';
import { TrainingModelData } from '~/types/router';
import { showErrorNotification, showSuccessNotification } from '~/utils/notifications';
import {
showErrorNotification,
showSuccessNotification,
showWarningNotification,
} from '~/utils/notifications';
import { bytesToKB } from '~/utils/number-helpers';
import { trpc } from '~/utils/trpc';
import { isDefined } from '~/utils/type-guards';
Expand Down Expand Up @@ -95,8 +99,8 @@ const imageExts: { [key: string]: string } = {
webp: 'image/webp',
};

const maxWidth = 1024;
const maxHeight = 1024;
const maxWidth = 2048;
const maxHeight = 2048;

const createImage = (url: string): Promise<HTMLImageElement> =>
new Promise((resolve, reject) => {
Expand Down Expand Up @@ -140,6 +144,7 @@ export const TrainingFormImages = ({ model }: { model: NonNullable<TrainingModel
const [loadingZip, setLoadingZip] = useState<boolean>(false);
const [modelFileId, setModelFileId] = useState<number | undefined>(undefined);
const [selectedTags, setSelectedTags] = useState<string[]>([]);
const showImgResize = useRef(false);

const theme = useMantineTheme();
const { classes, cx } = useStyles();
Expand Down Expand Up @@ -221,11 +226,13 @@ export const TrainingFormImages = ({ model }: { model: NonNullable<TrainingModel
if (width > maxWidth) {
height = height * (maxWidth / width);
width = maxWidth;
showImgResize.current = true;
}
} else {
if (height > maxHeight) {
width = width * (maxHeight / height);
height = maxHeight;
showImgResize.current = true;
}
}

Expand Down Expand Up @@ -280,6 +287,15 @@ export const TrainingFormImages = ({ model }: { model: NonNullable<TrainingModel
})
);

if (showImgResize.current) {
showWarningNotification({
title: 'Some images resized',
message: `Max allowed image dimensions are ${maxWidth}x${maxHeight}.`,
autoClose: 5000,
});
showImgResize.current = false;
}

if (showNotif) {
if (parsedFiles.length > 0) {
showSuccessNotification({
Expand Down Expand Up @@ -316,6 +332,15 @@ export const TrainingFormImages = ({ model }: { model: NonNullable<TrainingModel
})
);

if (showImgResize.current) {
showWarningNotification({
title: 'Some images resized',
message: `Max allowed image dimensions are ${maxWidth}x${maxHeight}.`,
autoClose: 5000,
});
showImgResize.current = false;
}

const filteredFiles = newFiles.flat().filter(isDefined);
if (filteredFiles.length > MAX_FILES_ALLOWED - imageList.length) {
showErrorNotification({
Expand Down Expand Up @@ -487,7 +512,7 @@ export const TrainingFormImages = ({ model }: { model: NonNullable<TrainingModel
});

try {
await upload(
const uploadResp = await upload(
{
file: blobFile,
type: UploadType.TrainingImages,
Expand Down Expand Up @@ -526,14 +551,24 @@ export const TrainingFormImages = ({ model }: { model: NonNullable<TrainingModel
icon: <IconX size={18} />,
color: 'red',
title: 'Failed to upload archive.',
message: '',
message: 'Please try again (or contact us if it continues)',
});
}
} else {
throw new Error('Missing version data.');
}
}
);
if (!uploadResp) {
setZipping(false);
updateNotification({
id: notificationId,
icon: <IconX size={18} />,
color: 'red',
title: 'Failed to upload archive.',
message: 'Please try again (or contact us if it continues)',
});
}
} catch (e) {
setZipping(false);
updateNotification({
Expand Down
36 changes: 28 additions & 8 deletions src/components/Training/Form/TrainingSubmit.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import { CurrencyBadge } from '~/components/Currency/CurrencyBadge';
import { CurrencyIcon } from '~/components/Currency/CurrencyIcon';
import { DescriptionTable } from '~/components/DescriptionTable/DescriptionTable';
import InputResourceSelect from '~/components/ImageGeneration/GenerationForm/ResourceSelect';
import { InfoPopover } from '~/components/InfoPopover/InfoPopover';
import {
blockedCustomModels,
goBack,
Expand Down Expand Up @@ -521,6 +522,7 @@ export const TrainingFormSubmit = ({ model }: { model: NonNullable<TrainingModel
samplePrompt2: z.string(),
samplePrompt3: z.string(),
staging: z.boolean().optional(),
highPriority: z.boolean().optional(),
});

// @ts-ignore ignoring because the reducer will use default functions in the next step in place of actual values
Expand All @@ -529,6 +531,7 @@ export const TrainingFormSubmit = ({ model }: { model: NonNullable<TrainingModel
samplePrompt2: thisTrainingDetails?.samplePrompts?.[1] ?? '',
samplePrompt3: thisTrainingDetails?.samplePrompts?.[2] ?? '',
staging: false,
highPriority: false,
...(thisTrainingDetails?.params
? thisTrainingDetails.params
: trainingSettings.reduce((a, v) => ({ ...a, [v.name]: v.default }), {})),
Expand Down Expand Up @@ -560,7 +563,7 @@ export const TrainingFormSubmit = ({ model }: { model: NonNullable<TrainingModel
});

const watchFields = form.watch(['maxTrainEpochs', 'numRepeats', 'trainBatchSize']);
const watchFieldsBuzz = form.watch(['targetSteps']);
const watchFieldsBuzz = form.watch(['targetSteps', 'highPriority']);
const watchFieldOptimizer = form.watch('optimizerType');

// apply default overrides for base model upon selection
Expand Down Expand Up @@ -603,7 +606,7 @@ export const TrainingFormSubmit = ({ model }: { model: NonNullable<TrainingModel
}, [watchFields]);

useEffect(() => {
const [targetSteps] = watchFieldsBuzz;
const [targetSteps, highPriority] = watchFieldsBuzz;
const eta = calcEta({
cost: status.cost,
baseModel: formBaseModelType,
Expand All @@ -614,6 +617,7 @@ export const TrainingFormSubmit = ({ model }: { model: NonNullable<TrainingModel
cost: status.cost,
eta,
isCustom,
isPriority: highPriority ?? false,
});
setEtaMins(eta);
setBuzzCost(price);
Expand All @@ -636,6 +640,7 @@ export const TrainingFormSubmit = ({ model }: { model: NonNullable<TrainingModel
trpc.training.createRequestDryRun.useQuery(
{
baseModel: formBaseModel,
isPriority: watchFieldsBuzz[1],
// cost: debouncedEtaMins,
},
{
Expand Down Expand Up @@ -773,6 +778,7 @@ export const TrainingFormSubmit = ({ model }: { model: NonNullable<TrainingModel
samplePrompt2,
samplePrompt3,
staging,
highPriority,
customModelSelect, //unsent
optimizerArgs, //unsent
...paramData
Expand Down Expand Up @@ -810,6 +816,7 @@ export const TrainingFormSubmit = ({ model }: { model: NonNullable<TrainingModel
samplePrompts: [samplePrompt1, samplePrompt2, samplePrompt3],
params: paramData,
staging,
highPriority,
},
};

Expand Down Expand Up @@ -1295,10 +1302,28 @@ export const TrainingFormSubmit = ({ model }: { model: NonNullable<TrainingModel
</Accordion.Panel>
</Accordion.Item>
</Accordion>
<Group mt="lg">
<InputSwitch
name="highPriority"
label={
<Group spacing={4} noWrap>
<InfoPopover size="xs" iconProps={{ size: 16 }}>
Jump to the front of the training queue and ensure that your training run is
uninterrupted.
</InfoPopover>
<Text>High Priority</Text>
</Group>
}
labelPosition="left"
/>
{currentUser?.isModerator && (
<InputSwitch name="staging" label="Test Mode" labelPosition="left" />
)}
</Group>
<Paper
shadow="xs"
radius="sm"
mt="lg"
mt="md"
w="fit-content"
px="md"
py="xs"
Expand Down Expand Up @@ -1339,11 +1364,6 @@ export const TrainingFormSubmit = ({ model }: { model: NonNullable<TrainingModel
</>
)}
</Stack>
{currentUser?.isModerator && (
<Group position="right" my="xs">
<InputSwitch name="staging" label="Test Mode" labelPosition="left" />
</Group>
)}
<Group mt="xl" position="right">
<Button variant="default" onClick={() => goBack(model.id, thisStep)}>
Back
Expand Down
8 changes: 8 additions & 0 deletions src/pages/api/auth/impersonate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { z } from 'zod';
import { civTokenEncrypt } from '~/pages/api/auth/civ-token';
import { dbRead } from '~/server/db/client';
import { getFeatureFlags } from '~/server/services/feature-flags.service';
import { trackModActivity } from '~/server/services/moderator.service';
import { AuthedEndpoint } from '~/server/utils/endpoint-helpers';

const schema = z.object({
Expand Down Expand Up @@ -30,6 +31,13 @@ export default AuthedEndpoint(async function handler(req, res, user) {

try {
const token = civTokenEncrypt(userId.toString());

await trackModActivity(user.id, {
entityType: 'impersonate',
entityId: userId,
activity: 'on',
});

return res.status(200).json({ token });
} catch (error: unknown) {
return res.status(500).send((error as Error).message);
Expand Down
2 changes: 2 additions & 0 deletions src/server/http/orchestrator/orchestrator.types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ export namespace Orchestrator {
export type ClearAssetsJobResponse = Orchestrator.JobResponse<ClearAssetsJob>;

const imageResourceTrainingJobInputDryRunSchema = z.object({
priority: z.union([z.number(), z.enum(['high', 'normal', 'low'])]),
// interruptible: z.boolean(),
model: z.string(),
cost: z.number(),
trainingData: z.string(),
Expand Down
1 change: 1 addition & 0 deletions src/server/schema/model-version.schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ export const trainingDetailsObj = z.object({
params: trainingDetailsParams.optional(),
samplePrompts: z.array(z.string()).optional(),
staging: z.boolean().optional(),
highPriority: z.boolean().optional(),
});

export const modelVersionUpsertSchema = z.object({
Expand Down
5 changes: 5 additions & 0 deletions src/server/schema/training.schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export const createTrainingRequestSchema = z.object({
export type CreateTrainingRequestDryRunInput = z.infer<typeof createTrainingRequestDryRunSchema>;
export const createTrainingRequestDryRunSchema = z.object({
baseModel: z.string().nullable(),
isPriority: z.boolean().optional(),
// cost: z.number().optional(),
});

Expand Down Expand Up @@ -41,6 +42,8 @@ const trainingCostSchema = z.object({
hourlyCost: z.number().min(0),
baseBuzz: z.number().min(0),
customModelBuzz: z.number().min(0),
priorityBuzz: z.number().min(0),
priorityBuzzPct: z.number().min(0),
minEta: z.number().min(1),
});
export type TrainingCost = z.infer<typeof trainingCostSchema>;
Expand All @@ -64,6 +67,8 @@ export const defaultTrainingCost: TrainingCost = {
hourlyCost: 0.44,
baseBuzz: 500,
customModelBuzz: 500,
priorityBuzz: 100,
priorityBuzzPct: 0.1,
minEta: 5,
};

Expand Down
6 changes: 6 additions & 0 deletions src/server/services/moderator.service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ type ReportModActivity = {
activity: 'review';
};

type ImpersonateModActivity = {
entityType: 'impersonate';
activity: 'on' | 'off'; // off is currently not used
};

type ModActivity = {
entityId?: number | number[];
} & (
Expand All @@ -41,6 +46,7 @@ type ModActivity = {
| ImageModActivity
| ReportModActivity
| ArticleModActivity
| ImpersonateModActivity
);

export async function trackModActivity(userId: number, input: ModActivity) {
Expand Down
Loading