From 4485a8f69c2072aadd6ac038b77d2c418520498c Mon Sep 17 00:00:00 2001 From: Michael Schmidt Date: Fri, 28 Jul 2023 04:28:36 +0200 Subject: [PATCH] Upgrade to Navi 0.6.0 (#2001) --- .../image_dimension/resize/resize_to_side.py | 12 +- package-lock.json | 14 +-- package.json | 2 +- src/common/nodes/TypeState.ts | 5 +- src/common/types/chainner-builtin.ts | 85 +++++++------ src/common/types/chainner-scope.ts | 15 +-- src/common/types/explain.ts | 19 ++- src/common/types/json.ts | 67 +---------- src/common/types/mismatch.ts | 43 ++++--- src/common/types/pretty.ts | 113 ++++++++++++------ src/common/types/util.ts | 82 ++++++------- src/renderer/components/TypeTag.tsx | 86 ++++++------- .../components/inputs/DirectoryInput.tsx | 16 +-- 13 files changed, 274 insertions(+), 285 deletions(-) diff --git a/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize_to_side.py b/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize_to_side.py index 954a990ef..738cb4c00 100644 --- a/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize_to_side.py +++ b/backend/src/packages/chaiNNer_standard/image_dimension/resize/resize_to_side.py @@ -135,25 +135,25 @@ def compareCondition(b: uint): bool { SideSelection::Width => if compareCondition(w) { same } else { Size { width: target, - height: max(int & round((target / w) * h), 1) + height: max(round((target / w) * h), 1) } }, SideSelection::Height => if compareCondition(h) { same } else { Size { - width: max(int & round((target / h) * w), 1), + width: max(round((target / h) * w), 1), height: target } }, SideSelection::ShorterSide => if compareCondition(min(h, w)) { same } else { Size { - width: max(int & round((target / min(h, w)) * w), 1), - height: max(int & round((target / min(h, w)) * h), 1) + width: max(round((target / min(h, w)) * w), 1), + height: max(round((target / min(h, w)) * h), 1) } }, SideSelection::LongerSide => if compareCondition(max(h, w)) { same } else { Size { - width: max(int & round((target / max(h, w)) * w), 1), - height: max(int & round((target / max(h, w)) * h), 1) + width: max(round((target / max(h, w)) * w), 1), + height: max(round((target / max(h, w)) * h), 1) } }, }; diff --git a/package-lock.json b/package-lock.json index 27110c150..d0ef80879 100644 --- a/package-lock.json +++ b/package-lock.json @@ -10,7 +10,7 @@ "license": "GPLv3", "dependencies": { "@babel/plugin-transform-react-jsx": "^7.17.12", - "@chainner/navi": "^0.5.0", + "@chainner/navi": "^0.6.0", "@chakra-ui/icons": "^2.0.11", "@chakra-ui/react": "^2.3.5", "@emotion/react": "^11.9.0", @@ -953,9 +953,9 @@ "dev": true }, "node_modules/@chainner/navi": { - "version": "0.5.0", - "resolved": "https://registry.npmjs.org/@chainner/navi/-/navi-0.5.0.tgz", - "integrity": "sha512-90WKzq27kzUmZhFaJiwsHJjvTSphVhMXJ5sSGRt8rgOynYCaiciUEZJZWwGukkFCuRbNeoD9kBHRHuvop8hA6A==" + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/@chainner/navi/-/navi-0.6.0.tgz", + "integrity": "sha512-R3VtwlVArE9ngIduLqMAd8GOOEotVR/Rgy4+d5UobDPIj7c/5Fs9ZXb3j7SrAxlOtRgzrD979SUYuLT0n4+WOA==" }, "node_modules/@chakra-ui/accordion": { "version": "2.1.1", @@ -25845,9 +25845,9 @@ "dev": true }, "@chainner/navi": { - "version": "0.5.0", - "resolved": "https://registry.npmjs.org/@chainner/navi/-/navi-0.5.0.tgz", - "integrity": "sha512-90WKzq27kzUmZhFaJiwsHJjvTSphVhMXJ5sSGRt8rgOynYCaiciUEZJZWwGukkFCuRbNeoD9kBHRHuvop8hA6A==" + "version": "0.6.0", + "resolved": "https://registry.npmjs.org/@chainner/navi/-/navi-0.6.0.tgz", + "integrity": "sha512-R3VtwlVArE9ngIduLqMAd8GOOEotVR/Rgy4+d5UobDPIj7c/5Fs9ZXb3j7SrAxlOtRgzrD979SUYuLT0n4+WOA==" }, "@chakra-ui/accordion": { "version": "2.1.1", diff --git a/package.json b/package.json index 4f81911a2..eaa43d041 100644 --- a/package.json +++ b/package.json @@ -104,7 +104,7 @@ }, "dependencies": { "@babel/plugin-transform-react-jsx": "^7.17.12", - "@chainner/navi": "^0.5.0", + "@chainner/navi": "^0.6.0", "@chakra-ui/icons": "^2.0.11", "@chakra-ui/react": "^2.3.5", "@emotion/react": "^11.9.0", diff --git a/src/common/nodes/TypeState.ts b/src/common/nodes/TypeState.ts index 96a6eb9df..b19028657 100644 --- a/src/common/nodes/TypeState.ts +++ b/src/common/nodes/TypeState.ts @@ -1,7 +1,8 @@ -import { EvaluationError, NonNeverType, StructType, Type, isSameType } from '@chainner/navi'; +import { EvaluationError, NonNeverType, Type, isSameType } from '@chainner/navi'; import { EdgeData, InputId, NodeData, OutputId, SchemaId } from '../common-types'; import { log } from '../log'; import { FunctionDefinition, FunctionInstance } from '../types/function'; +import { nullType } from '../types/util'; import { EMPTY_MAP } from '../util'; import { EdgeState } from './EdgeState'; import type { Edge, Node } from 'reactflow'; @@ -95,7 +96,7 @@ export class TypeState { } if (inputValue === undefined && definition.inputNullable.has(id)) { - return new StructType('null'); + return nullType; } return undefined; diff --git a/src/common/types/chainner-builtin.ts b/src/common/types/chainner-builtin.ts index 782b0620f..eabb7bedd 100644 --- a/src/common/types/chainner-builtin.ts +++ b/src/common/types/chainner-builtin.ts @@ -1,21 +1,22 @@ import { Arg, Int, - IntIntervalType, Intrinsic, NeverType, NumberPrimitive, StringLiteralType, StringPrimitive, StringType, + StructInstanceType, StructType, - StructTypeField, + createInstance, + getStructDescriptor, handleNumberLiterals, intersect, literal, wrapQuaternary, + wrapScopedUnary, wrapTernary, - wrapUnary, } from '@chainner/navi'; import path from 'path'; import { ColorJson } from '../common-types'; @@ -386,48 +387,46 @@ export const padCenter = wrapTernary((filePath: StringPrimitive) => { - if (filePath.type === 'literal') { - const base = path.basename(filePath.value); - const ext = path.extname(base); - const basename = ext ? base.slice(0, -ext.length) : base; - return new StructType('SplitFilePath', [ - new StructTypeField( - 'dir', - new StructType('Directory', [ - new StructTypeField('path', literal(path.dirname(filePath.value))), - ]) - ), - new StructTypeField('basename', literal(basename)), - new StructTypeField('ext', literal(ext)), - ]); +export const splitFilePath = wrapScopedUnary( + (scope, filePath: StringPrimitive): StructInstanceType => { + const splitFilePathDesc = getStructDescriptor(scope, 'SplitFilePath'); + const directoryDesc = getStructDescriptor(scope, 'Directory'); + + if (filePath.type === 'literal') { + const base = path.basename(filePath.value); + const ext = path.extname(base); + const basename = ext ? base.slice(0, -ext.length) : base; + + return createInstance(splitFilePathDesc, { + dir: createInstance(directoryDesc, { + path: literal(path.dirname(filePath.value)), + }), + basename: literal(basename), + ext: literal(ext), + }); + } + return createInstance(splitFilePathDesc); } - return new StructType('SplitFilePath', [ - new StructTypeField( - 'dir', - new StructType('Directory', [new StructTypeField('path', StringType.instance)]) - ), - new StructTypeField('basename', StringType.instance), - new StructTypeField('ext', StringType.instance), - ]); -}); +); -export const parseColorJson = wrapUnary((json) => { - if (json.type === 'literal') { - try { - const value = JSON.parse(json.value) as unknown; - if (value && typeof value === 'object' && 'kind' in value && 'values' in value) { - const color = value as ColorJson; - return new StructType('Color', [ - new StructTypeField('channels', literal(color.values.length)), - ]); +export const parseColorJson = wrapScopedUnary( + (scope, json: StringPrimitive): Arg => { + const colorDesc = getStructDescriptor(scope, 'Color'); + + if (json.type === 'literal') { + try { + const value = JSON.parse(json.value) as unknown; + if (value && typeof value === 'object' && 'kind' in value && 'values' in value) { + const color = value as ColorJson; + return createInstance(colorDesc, { + channels: literal(color.values.length), + }); + } + } catch { + // noop } - } catch { - // noop + return NeverType.instance; } - return NeverType.instance; + return createInstance(colorDesc); } - return new StructType('Color', [ - new StructTypeField('channels', new IntIntervalType(1, Infinity)), - ]); -}); +); diff --git a/src/common/types/chainner-scope.ts b/src/common/types/chainner-scope.ts index 2a90fc3b8..299c53478 100644 --- a/src/common/types/chainner-scope.ts +++ b/src/common/types/chainner-scope.ts @@ -6,6 +6,7 @@ import { SourceDocument, Type, globalScope, + makeScoped, parseDefinitions, } from '@chainner/navi'; import { lazy } from '../util'; @@ -151,12 +152,12 @@ intrinsic def parseColorJson(json: string): Color; export const getChainnerScope = lazy((): Scope => { const builder = new ScopeBuilder('Chainner scope', globalScope); - const intrinsic: Record Type> = { - formatPattern: formatTextPattern, - regexReplace, - padStart, - padEnd, - padCenter, + const intrinsic: Record Type> = { + formatPattern: makeScoped(formatTextPattern), + regexReplace: makeScoped(regexReplace), + padStart: makeScoped(padStart), + padEnd: makeScoped(padEnd), + padCenter: makeScoped(padCenter), splitFilePath, parseColorJson, }; @@ -167,7 +168,7 @@ export const getChainnerScope = lazy((): Scope => { if (!(d.name in intrinsic)) { throw new Error(`Unable to find definition for intrinsic ${d.name}`); } - const fn = intrinsic[d.name] as (...args: Type[]) => Type; + const fn = intrinsic[d.name] as (scope: Scope, ...args: Type[]) => Type; builder.add(IntrinsicFunctionDefinition.from(d, fn)); } else { builder.add(d); diff --git a/src/common/types/explain.ts b/src/common/types/explain.ts index 773ba40b2..838a6e13f 100644 --- a/src/common/types/explain.ts +++ b/src/common/types/explain.ts @@ -3,13 +3,14 @@ import { NumberPrimitive, NumericLiteralType, StringPrimitive, - StructType, + StructValueType, Type, UnionType, ValueType, + isStructInstance, } from '@chainner/navi'; import { joinEnglish } from '../util'; -import { IntNumberType, isColor, isDirectory, isImage } from './util'; +import { IntNumberType, getFields, isColor, isDirectory, isImage } from './util'; const isInt = (n: Type, min = -Infinity, max = Infinity): n is IntIntervalType => { return n.underlying === 'number' && n.type === 'int-interval' && n.min === min && n.max === max; @@ -77,17 +78,14 @@ const explainString = (s: StringPrimitive): string | undefined => { if (s.excluded.size === 1 && s.excluded.has('')) return 'a non-empty string'; }; -const explainStruct = (s: StructType, options: ExplainOptions): string | undefined => { +const explainStruct = (s: StructValueType, options: ExplainOptions): string | undefined => { const detailed = (base: string | undefined, detail: string): string | undefined => { if (options.detailed && base) return `${base} ${detail}`; return base; }; if (isImage(s)) { - const width = s.fields[0].type; - const height = s.fields[1].type; - const channels = s.fields[2].type; - + const { width, height, channels } = getFields(s); if (isInt(width, 1) && isInt(height, 1)) { if (isInt(channels, 1)) return detailed('an image', 'of any size and any colorspace'); return detailed(formatChannelNumber(channels, 'image'), 'of any size'); @@ -95,18 +93,17 @@ const explainStruct = (s: StructType, options: ExplainOptions): string | undefin } if (isColor(s)) { - const channels = s.fields[0].type; - + const { channels } = getFields(s); if (isInt(channels, 1)) return detailed('a color', 'of any colorspace'); return formatChannelNumber(channels, 'color'); } if (isDirectory(s)) { - const path = s.fields[0].type; + const { path } = getFields(s); if (path.type === 'string') return 'a directory path'; } - if (s.name === 'Seed') { + if (isStructInstance(s) && s.descriptor.name === 'Seed') { return 'a seed (for randomness)'; } }; diff --git a/src/common/types/json.ts b/src/common/types/json.ts index a2a33998b..21caa07bd 100644 --- a/src/common/types/json.ts +++ b/src/common/types/json.ts @@ -1,4 +1,5 @@ import { + Bounds, Expression, FieldAccessExpression, FunctionCallExpression, @@ -94,12 +95,6 @@ export interface MatchExpressionJson { arms: MatchArmJson[]; } -const toNumberJson = (number: number): NumberJson => { - if (Number.isNaN(number)) return 'NaN'; - if (number === Infinity) return 'inf'; - if (number === -Infinity) return '-inf'; - return number; -}; const fromNumberJson = (number: NumberJson): number => { if (number === 'NaN') return NaN; if (number === 'inf') return Infinity; @@ -107,64 +102,6 @@ const fromNumberJson = (number: NumberJson): number => { return number; }; -export const toJson = (e: Expression): ExpressionJson => { - switch (e.type) { - case 'any': - return 'any'; - case 'never': - return 'never'; - case 'number': - return 'number'; - case 'string': - return 'string'; - case 'interval': - return { type: 'interval', min: toNumberJson(e.min), max: toNumberJson(e.max) }; - case 'int-interval': - return { type: 'int-interval', min: toNumberJson(e.min), max: toNumberJson(e.max) }; - case 'literal': - if (e.underlying === 'number') { - return { type: 'numeric-literal', value: toNumberJson(e.value) }; - } - return { type: 'string-literal', value: e.value }; - case 'union': - return { type: 'union', items: e.items.map(toJson) }; - case 'intersection': - return { type: 'intersection', items: e.items.map(toJson) }; - case 'struct': - return { - type: 'named', - name: e.name, - fields: Object.fromEntries(e.fields.map((f) => [e.name, toJson(f.type)])), - }; - case 'named': - return { - type: 'named', - name: e.name, - }; - case 'field-access': - return { type: 'field-access', of: toJson(e.of), field: e.field }; - case 'function-call': - return { type: 'function-call', name: e.functionName, args: e.args.map(toJson) }; - case 'match': { - return { - type: 'match', - of: toJson(e.of), - arms: e.arms.map((a) => ({ - pattern: toJson(a.pattern), - binding: a.binding, - to: toJson(a.to), - })), - }; - } - case 'scope': - throw new Error('Converting scoped expressions to JSON is currently not supported.'); - case 'inverted-set': - throw new Error('Converting scoped expressions to JSON is currently not supported.'); - default: - return assertNever(e); - } -}; - export const fromJson = (e: ExpressionJson): Expression => { if (typeof e === 'boolean') { return new NamedExpression(e ? 'true' : 'false'); @@ -186,7 +123,7 @@ export const fromJson = (e: ExpressionJson): Expression => { case 'string-literal': return new StringLiteralType(e.value); case 'interval': - return new IntervalType(fromNumberJson(e.min), fromNumberJson(e.max)); + return new IntervalType(fromNumberJson(e.min), fromNumberJson(e.max), Bounds.Inclusive); case 'int-interval': return new IntIntervalType(fromNumberJson(e.min), fromNumberJson(e.max)); case 'union': diff --git a/src/common/types/mismatch.ts b/src/common/types/mismatch.ts index 434dbcc87..3b32ae1cd 100644 --- a/src/common/types/mismatch.ts +++ b/src/common/types/mismatch.ts @@ -1,16 +1,17 @@ import { IntersectionExpression, NamedExpression, - StructType, + StructInstanceType, Type, evaluate, isDisjointWith, + isStructInstance, } from '@chainner/navi'; import { assertNever } from '../util'; import { getChainnerScope } from './chainner-scope'; import { explain, formatChannelNumber } from './explain'; import { prettyPrintType } from './pretty'; -import { isColor, isImage } from './util'; +import { getFields, isColor, isImage } from './util'; export type AssignmentErrorTrace = FieldAssignmentError | GeneralAssignmentError; export interface GeneralAssignmentError { @@ -20,8 +21,8 @@ export interface GeneralAssignmentError { } export interface FieldAssignmentError { type: 'Field'; - assigned: StructType; - definition: StructType; + assigned: StructInstanceType; + definition: StructInstanceType; field: string; inner: AssignmentErrorTrace; } @@ -36,22 +37,22 @@ export const generateAssignmentErrorTrace = ( } if ( - assigned.type === 'struct' && - definition.type === 'struct' && - assigned.name === definition.name + isStructInstance(assigned) && + isStructInstance(definition) && + assigned.descriptor === definition.descriptor ) { // find the first field that causes the mismatch for (let i = 0; i < assigned.fields.length; i += 1) { const a = assigned.fields[i]; const d = definition.fields[i]; - const inner = generateAssignmentErrorTrace(a.type, d.type); + const inner = generateAssignmentErrorTrace(a, d); if (inner) { return { type: 'Field', assigned, definition, - field: a.name, + field: assigned.descriptor.fields[i].name, inner, }; } @@ -69,8 +70,18 @@ const shortTypeNames = (t: Type): Set => { return new Set([t.underlying]); case 'never': return new Set(); - case 'struct': - return new Set([t.name]); + case 'struct': { + switch (t.type) { + case 'instance': + return new Set([t.descriptor.name]); + case 'struct': + return new Set([t.toString()]); + case 'inverted-set': + return new Set([`not(${[...t.excluded].map((d) => d.name).join(' | ')})`]); + default: + return assertNever(t); + } + } case 'union': return new Set(t.items.flatMap((i) => [...shortTypeNames(i)])); default: @@ -115,7 +126,7 @@ export const printErrorTrace = (trace: AssignmentErrorTrace): string[] => { prettyPrintType ); return [ - `The **${trace.assigned.name}** types are incompatible because **${trace.field}: ${a}** is not connectable with **${trace.field}: ${d}**.`, + `The **${trace.assigned.descriptor.name}** types are incompatible because **${trace.field}: ${a}** is not connectable with **${trace.field}: ${d}**.`, ]; } @@ -139,8 +150,8 @@ export const simpleError = ( ); if (isImage(d)) { - const aChannels = assigned.fields[2].type; - const dChannels = d.fields[2].type; + const aChannels = getFields(assigned).channels; + const dChannels = getFields(d).channels; if (isDisjointWith(aChannels, dChannels)) { const aString = formatChannelNumber(aChannels, 'image'); @@ -163,8 +174,8 @@ export const simpleError = ( ); if (isColor(d)) { - const aChannels = assigned.fields[0].type; - const dChannels = d.fields[0].type; + const aChannels = getFields(assigned).channels; + const dChannels = getFields(d).channels; if (isDisjointWith(aChannels, dChannels)) { const aString = formatChannelNumber(aChannels, 'color'); diff --git a/src/common/types/pretty.ts b/src/common/types/pretty.ts index 9fcc742bb..0c577b608 100644 --- a/src/common/types/pretty.ts +++ b/src/common/types/pretty.ts @@ -1,20 +1,23 @@ +/* eslint-disable @typescript-eslint/no-use-before-define */ /* eslint-disable no-continue */ -import { Type, ValueType } from '@chainner/navi'; +import { + NumberPrimitive, + StringPrimitive, + StructValueType, + Type, + UnionType, + ValueType, +} from '@chainner/navi'; import { assertNever } from '../util'; -export const prettyPrintType = (type: Type): string => { +const prettyPrintNumber = (type: NumberPrimitive): string => { switch (type.type) { - case 'any': - case 'never': case 'literal': case 'number': - case 'string': case 'interval': + case 'non-int-interval': return type.toString(); - case 'inverted-set': - return `not(${[...type.excluded].map((s) => JSON.stringify(s)).join(' | ')})`; - case 'int-interval': if (type.min === -Infinity && type.max === Infinity) { return 'int'; @@ -27,38 +30,80 @@ export const prettyPrintType = (type: Type): string => { } return type.toString(); - case 'union': { - const literals: number[] = []; - const other: ValueType[] = []; - for (const item of type.items) { - if (item.underlying === 'number') { - if (item.type === 'literal' && Number.isFinite(item.value)) { - literals.push(item.value); - continue; - } - if (item.type === 'int-interval' && item.min + 1 === item.max) { - literals.push(item.min); - literals.push(item.max); - continue; - } - } - other.push(item); - } - - const union = [...literals, ...other.map(prettyPrintType)].join(' | '); + default: + return assertNever(type); + } +}; +const prettyPrintString = (type: StringPrimitive): string => { + switch (type.type) { + case 'literal': + case 'string': + return type.toString(); - // hacky way to detect boolean - if (union === 'false | true') return 'bool'; + case 'inverted-set': + return `not(${[...type.excluded].map((s) => JSON.stringify(s)).join(' | ')})`; - return union; + default: + return assertNever(type); + } +}; +const prettyPrintStruct = (type: StructValueType): string => { + switch (type.type) { + case 'instance': { + if (type.fields.length === 0) return type.descriptor.name; + const fields = type.descriptor.fields + .map((f, i) => `${f.name}: ${prettyPrintType(type.fields[i])}`) + .join(', '); + return `${type.descriptor.name} { ${fields} }`; } + case 'inverted-set': + return `not(${[...type.excluded].map((s) => s.name).join(' | ')})`; case 'struct': - if (type.fields.length === 0) return type.name; - return `${type.name} { ${type.fields - .map((f) => `${f.name}: ${prettyPrintType(f.type)}`) - .join(', ')} }`; + return type.toString(); + + default: + return assertNever(type); + } +}; +const prettyPrintUnion = (type: UnionType): string => { + const literals: number[] = []; + const other: ValueType[] = []; + for (const item of type.items) { + if (item.underlying === 'number') { + if (item.type === 'literal' && Number.isFinite(item.value)) { + literals.push(item.value); + continue; + } + if (item.type === 'int-interval' && item.min + 1 === item.max) { + literals.push(item.min); + literals.push(item.max); + continue; + } + } + other.push(item); + } + + const union = [...literals, ...other.map(prettyPrintType)].join(' | '); + + // hacky way to detect boolean + if (union === 'false | true') return 'bool'; + return union; +}; +export const prettyPrintType = (type: Type): string => { + switch (type.underlying) { + case 'any': + case 'never': + return type.toString(); + case 'number': + return prettyPrintNumber(type); + case 'string': + return prettyPrintString(type); + case 'struct': + return prettyPrintStruct(type); + case 'union': + return prettyPrintUnion(type); default: return assertNever(type); } diff --git a/src/common/types/util.ts b/src/common/types/util.ts index e303667ec..b0c8e000c 100644 --- a/src/common/types/util.ts +++ b/src/common/types/util.ts @@ -1,64 +1,60 @@ import { - Expression, IntIntervalType, NonNeverType, NumericLiteralType, StringPrimitive, - StructExpression, - StructExpressionField, - StructType, + StructDescriptor, + StructInstanceType, Type, UnionType, + getStructDescriptor, + isStructInstance, without, } from '@chainner/navi'; +import { getChainnerScope } from './chainner-scope'; export type IntNumberType = | NumericLiteralType | IntIntervalType | UnionType; -export const isImage = ( - type: Type -): type is StructType & { - readonly name: 'Image'; - readonly fields: readonly [ - { readonly name: 'width'; readonly type: IntNumberType }, - { readonly name: 'height'; readonly type: IntNumberType }, - { readonly name: 'channels'; readonly type: IntNumberType } - ]; -} => { - return type.type === 'struct' && type.name === 'Image' && type.fields.length === 3; +interface KnownStructDefinitions { + Image: { + readonly width: IntNumberType; + readonly height: IntNumberType; + readonly channels: IntNumberType; + }; + Color: { + readonly channels: IntNumberType; + }; + Directory: { + readonly path: StringPrimitive; + }; +} +interface KnownInstance { + readonly descriptor: StructDescriptor & { readonly name: N }; +} +const createAssertFn = ( + name: N +): ((type: Type) => type is StructInstanceType & KnownInstance) => { + const fn = (type: Type) => isStructInstance(type) && type.descriptor.name === name; + return fn as never; }; -export const isColor = ( - type: Type -): type is StructType & { - readonly name: 'Color'; - readonly fields: readonly [{ readonly name: 'channels'; readonly type: IntNumberType }]; -} => { - return type.type === 'struct' && type.name === 'Color' && type.fields.length === 1; +export const isImage = createAssertFn('Image'); +export const isColor = createAssertFn('Color'); +export const isDirectory = createAssertFn('Directory'); + +export const getFields = ( + type: StructInstanceType & KnownInstance +): KnownStructDefinitions[N] => { + const fields: Record = {}; + type.descriptor.fields.forEach((field, i) => { + fields[field.name] = type.fields[i]; + }); + return fields as never; }; -export const isDirectory = ( - type: Type -): type is StructType & { - readonly name: 'Directory'; - readonly fields: readonly [{ readonly name: 'path'; readonly type: StringPrimitive }]; -} => { - return type.type === 'struct' && type.name === 'Directory' && type.fields.length === 1; -}; - -export const getField = (struct: StructType, field: string): NonNeverType | undefined => { - return struct.fields.find((f) => f.name === field)?.type; -}; - -const nullType = new StructType('null'); +export const nullType = getStructDescriptor(getChainnerScope(), 'null').default; export const withoutNull = (type: Type): Type => without(type, nullType); - -export const struct = (name: string, fields: Record): StructExpression => { - return new StructExpression( - name, - Object.entries(fields).map(([n, e]) => new StructExpressionField(n, e)) - ); -}; diff --git a/src/renderer/components/TypeTag.tsx b/src/renderer/components/TypeTag.tsx index 31296adf5..16297666a 100644 --- a/src/renderer/components/TypeTag.tsx +++ b/src/renderer/components/TypeTag.tsx @@ -1,8 +1,14 @@ -import { NeverType, Type, isNumericLiteral, isStringLiteral } from '@chainner/navi'; +import { + NeverType, + Type, + isNumericLiteral, + isStringLiteral, + isStructInstance, +} from '@chainner/navi'; import { Tag, Tooltip, forwardRef } from '@chakra-ui/react'; import React, { memo } from 'react'; import { useTranslation } from 'react-i18next'; -import { getField, isColor, isDirectory, isImage, withoutNull } from '../../common/types/util'; +import { getFields, isColor, isDirectory, isImage, withoutNull } from '../../common/types/util'; import { assertNever } from '../../common/util'; const getColorMode = (channels: number) => { @@ -28,68 +34,68 @@ const getTypeText = (type: Type): TagValue[] => { if (isStringLiteral(type)) return [{ kind: 'string', value: type.value }]; const tags: TagValue[] = []; - if (type.type === 'struct') { - if (isImage(type)) { - const [width, height, channels] = type.fields; - if (isNumericLiteral(width.type) && isNumericLiteral(height.type)) { - tags.push({ - kind: 'literal', - value: `${width.type.toString()}x${height.type.toString()}`, - }); - } - if (isNumericLiteral(channels.type)) { - const mode = getColorMode(channels.type.value); - if (mode) { - tags.push({ kind: 'literal', value: mode }); - } + if (isImage(type)) { + const { width, height, channels } = getFields(type); + if (isNumericLiteral(width) && isNumericLiteral(height)) { + tags.push({ + kind: 'literal', + value: `${width.toString()}x${height.toString()}`, + }); + } + if (isNumericLiteral(channels)) { + const mode = getColorMode(channels.value); + if (mode) { + tags.push({ kind: 'literal', value: mode }); } } + } - if (isColor(type)) { - const [channels] = type.fields; - if (isNumericLiteral(channels.type)) { - const mode = getColorMode(channels.type.value); - if (mode) { - tags.push({ kind: 'literal', value: mode }); - } + if (isColor(type)) { + const { channels } = getFields(type); + if (isNumericLiteral(channels)) { + const mode = getColorMode(channels.value); + if (mode) { + tags.push({ kind: 'literal', value: mode }); } } + } - if (isDirectory(type)) { - const [path] = type.fields; - - if (isStringLiteral(path.type)) { - tags.push({ kind: 'path', value: path.type.value }); - } + if (isDirectory(type)) { + const { path } = getFields(type); + if (isStringLiteral(path)) { + tags.push({ kind: 'path', value: path.value }); } + } + if (isStructInstance(type)) { if ( - type.name === 'PyTorchModel' || - type.name === 'NcnnNetwork' || - type.name === 'OnnxModel' + type.descriptor.name === 'PyTorchModel' || + type.descriptor.name === 'NcnnNetwork' || + type.descriptor.name === 'OnnxModel' ) { - const scale = getField(type, 'scale') ?? NeverType.instance; + const scale = type.getField('scale') ?? NeverType.instance; if (isNumericLiteral(scale)) { tags.push({ kind: 'literal', value: `${scale.toString()}x` }); } - const subType = getField(type, 'subType') ?? NeverType.instance; + const subType = type.getField('subType') ?? NeverType.instance; if (isStringLiteral(subType)) { tags.push({ kind: 'literal', value: subType.value }); } } } + if (type.type === 'union') { if (type.items.length === 2) { const [color, image] = type.items; if (isColor(color) && isImage(image)) { - const colorChannels = color.fields[0]; - const imageChannels = image.fields[2]; + const colorChannels = getFields(color).channels; + const imageChannels = getFields(image).channels; if ( - isNumericLiteral(colorChannels.type) && - isNumericLiteral(imageChannels.type) && - colorChannels.type.value === imageChannels.type.value + isNumericLiteral(colorChannels) && + isNumericLiteral(imageChannels) && + colorChannels.value === imageChannels.value ) { - const mode = getColorMode(colorChannels.type.value); + const mode = getColorMode(colorChannels.value); if (mode) { tags.push({ kind: 'literal', value: mode }); } diff --git a/src/renderer/components/inputs/DirectoryInput.tsx b/src/renderer/components/inputs/DirectoryInput.tsx index 3eb6e566e..6a919ce70 100644 --- a/src/renderer/components/inputs/DirectoryInput.tsx +++ b/src/renderer/components/inputs/DirectoryInput.tsx @@ -1,4 +1,4 @@ -import { Type } from '@chainner/navi'; +import { Type, isStringLiteral } from '@chainner/navi'; import { Icon, Input, @@ -15,6 +15,7 @@ import { useTranslation } from 'react-i18next'; import { BsFolderPlus } from 'react-icons/bs'; import { MdContentCopy, MdFolder } from 'react-icons/md'; import { ipcRenderer } from '../../../common/safeIpc'; +import { getFields, isDirectory } from '../../../common/types/util'; import { useContextMenu } from '../../hooks/useContextMenu'; import { useInputRefactor } from '../../hooks/useInputRefactor'; import { useLastDirectory } from '../../hooks/useLastDirectory'; @@ -22,15 +23,10 @@ import { MaybeLabel } from './InputContainer'; import { InputProps } from './props'; const getDirectoryPath = (type: Type): string | undefined => { - if ( - type.type === 'struct' && - type.name === 'Directory' && - type.fields.length > 0 && - type.fields[0].name === 'path' - ) { - const pathType = type.fields[0].type; - if (pathType.underlying === 'string' && pathType.type === 'literal') { - return pathType.value; + if (isDirectory(type)) { + const { path } = getFields(type); + if (isStringLiteral(path)) { + return path.value; } } return undefined;