Skip to content

Commit

Permalink
feat(postgres): add support for native enums (#4296)
Browse files Browse the repository at this point in the history
By default, the PostgreSQL driver, represents enums as a text columns
with check constraints. Since v6, you can opt-in for a native enums by
setting the `nativeEnumName` option.

```ts
@entity()
export class User {

  @enum({ items: () => UserRole, nativeEnumName: 'user_role' })
  role!: UserRole;

}

export enum UserRole {
  ADMIN = 'admin',
  MODERATOR = 'moderator',
  USER = 'user',
}
```

Closes #2764
  • Loading branch information
B4nan committed Nov 5, 2023
1 parent e649480 commit 8515380
Show file tree
Hide file tree
Showing 17 changed files with 459 additions and 26 deletions.
68 changes: 68 additions & 0 deletions docs/docs/defining-entities.md
Expand Up @@ -550,6 +550,74 @@ properties: {
</TabItem>
</Tabs>

### PostgreSQL native enums

By default, the PostgreSQL driver, represents enums as a text columns with check constraints. Since v6, you can opt-in for a native enums by setting the `nativeEnumName` option.

<Tabs
groupId="entity-def"
defaultValue="reflect-metadata"
values={[
{label: 'reflect-metadata', value: 'reflect-metadata'},
{label: 'ts-morph', value: 'ts-morph'},
{label: 'EntitySchema', value: 'entity-schema'},
]
}>
<TabItem value="reflect-metadata">

```ts title="./entities/Author.ts"
@Entity()
export class User {

@Enum({ items: () => UserRole, nativeEnumName: 'user_role' })
role!: UserRole;

}

export enum UserRole {
ADMIN = 'admin',
MODERATOR = 'moderator',
USER = 'user',
}
```

</TabItem>
<TabItem value="ts-morph">

```ts title="./entities/Author.ts"
@Entity()
export class User {

@Enum({ items: () => UserRole, nativeEnumName: 'user_role' })
role!: UserRole;

}

export enum UserRole {
ADMIN = 'admin',
MODERATOR = 'moderator',
USER = 'user',
}
```

</TabItem>
<TabItem value="entity-schema">

```ts title="./entities/Author.ts"
export enum UserRole {
ADMIN = 'admin',
MODERATOR = 'moderator',
USER = 'user',
}

properties: {
role: { enum: true, nativeEnumName: 'user_role', items: () => UserRole },
},
```

</TabItem>
</Tabs>

## Enum arrays

We can also use array of values for enum, in that case, `EnumArrayType` type will be used automatically, that will validate items on flush.
Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/decorators/Enum.ts
Expand Up @@ -22,4 +22,6 @@ export function Enum<T extends object>(options: EnumOptions<AnyEntity> | (() =>
export interface EnumOptions<T> extends PropertyOptions<T> {
items?: (number | string)[] | (() => Dictionary);
array?: boolean;
/** for postgres, by default it uses text column with check constraint */
nativeEnumName?: string;
}
5 changes: 5 additions & 0 deletions packages/core/src/metadata/EntitySchema.ts
Expand Up @@ -120,6 +120,11 @@ export class EntitySchema<T = any, U = never> {
prop.enum = false;
}

// force string labels on native enums
if (prop.nativeEnumName && Array.isArray(prop.items)) {
prop.items = prop.items.map(val => '' + val);
}

this.addProperty(name, this.internal ? type : type || 'enum', prop);
}

Expand Down
4 changes: 3 additions & 1 deletion packages/core/src/metadata/MetadataDiscovery.ts
Expand Up @@ -1205,7 +1205,9 @@ export class MetadataDiscovery {
private getMappedType(prop: EntityProperty): Type<unknown> {
let t = prop.columnTypes?.[0] ?? prop.type?.toLowerCase();

if (prop.enum) {
if (prop.nativeEnumName) {
t = 'enum';
} else if (prop.enum) {
t = prop.items?.every(item => Utils.isString(item)) ? 'enum' : 'tinyint';
}

Expand Down
5 changes: 5 additions & 0 deletions packages/core/src/platforms/Platform.ts
Expand Up @@ -48,6 +48,11 @@ export abstract class Platform {
return false;
}

/** for postgres native enums */
supportsNativeEnums(): boolean {
return false;
}

getSchemaHelper(): unknown {
return undefined;
}
Expand Down
1 change: 1 addition & 0 deletions packages/core/src/typings.ts
Expand Up @@ -305,6 +305,7 @@ export interface EntityProperty<T = any> {
hidden?: boolean;
enum?: boolean;
items?: (number | string)[];
nativeEnumName?: string; // for postgres, by default it uses text column with check constraint
version?: boolean;
concurrencyCheck?: boolean;
eager?: boolean;
Expand Down
26 changes: 23 additions & 3 deletions packages/knex/src/schema/DatabaseSchema.ts
Expand Up @@ -11,13 +11,15 @@ export class DatabaseSchema {

private tables: DatabaseTable[] = [];
private namespaces = new Set<string>();
private nativeEnums: Dictionary<unknown[]> = {}; // for postgres

constructor(private readonly platform: AbstractSqlPlatform,
readonly name: string) { }

addTable(name: string, schema: string | undefined | null, comment?: string): DatabaseTable {
const namespaceName = schema ?? this.name;
const table = new DatabaseTable(this.platform, name, namespaceName);
table.nativeEnums = this.nativeEnums;
table.comment = comment;
this.tables.push(table);

Expand All @@ -40,6 +42,15 @@ export class DatabaseSchema {
return !!this.getTable(name);
}

setNativeEnums(nativeEnums: Dictionary<unknown[]>): void {
this.nativeEnums = nativeEnums;
this.tables.forEach(t => t.nativeEnums = nativeEnums);
}

getNativeEnums(): Dictionary<unknown[]> {
return this.nativeEnums;
}

hasNamespace(namespace: string) {
return this.namespaces.has(namespace);
}
Expand All @@ -48,20 +59,29 @@ export class DatabaseSchema {
return [...this.namespaces];
}

static async create(connection: AbstractSqlConnection, platform: AbstractSqlPlatform, config: Configuration, schemaName?: string): Promise<DatabaseSchema> {
const schema = new DatabaseSchema(platform, schemaName ?? config.get('schema'));
static async create(connection: AbstractSqlConnection, platform: AbstractSqlPlatform, config: Configuration, schemaName?: string, schemas?: string[]): Promise<DatabaseSchema> {
const schema = new DatabaseSchema(platform, schemaName ?? config.get('schema') ?? platform.getDefaultSchemaName());
const allTables = await connection.execute<Table[]>(platform.getSchemaHelper()!.getListTablesSQL());
const parts = config.get('migrations').tableName!.split('.');
const migrationsTableName = parts[1] ?? parts[0];
const migrationsSchemaName = parts.length > 1 ? parts[0] : config.get('schema', platform.getDefaultSchemaName());
const tables = allTables.filter(t => t.table_name !== migrationsTableName || (t.schema_name && t.schema_name !== migrationsSchemaName));
await platform.getSchemaHelper()!.loadInformationSchema(schema, connection, tables);
await platform.getSchemaHelper()!.loadInformationSchema(schema, connection, tables, schemas && schemas.length > 0 ? schemas : undefined);

return schema;
}

static fromMetadata(metadata: EntityMetadata[], platform: AbstractSqlPlatform, config: Configuration, schemaName?: string): DatabaseSchema {
const schema = new DatabaseSchema(platform, schemaName ?? config.get('schema'));
const nativeEnums: Dictionary<unknown[]> = {};

for (const meta of metadata) {
meta.props
.filter(prop => prop.nativeEnumName)
.forEach(prop => nativeEnums[prop.nativeEnumName!] = prop.items?.map(val => '' + val) ?? []);
}

schema.setNativeEnums(nativeEnums);

for (const meta of metadata) {
const table = schema.addTable(meta.collection, this.getSchemaName(meta, config, schemaName));
Expand Down
6 changes: 4 additions & 2 deletions packages/knex/src/schema/DatabaseTable.ts
Expand Up @@ -12,6 +12,7 @@ export class DatabaseTable {
private indexes: IndexDef[] = [];
private checks: CheckDef[] = [];
private foreignKeys: Dictionary<ForeignKey> = {};
public nativeEnums: Dictionary<unknown[]> = {}; // for postgres
public comment?: string;

constructor(private readonly platform: AbstractSqlPlatform,
Expand Down Expand Up @@ -54,7 +55,7 @@ export class DatabaseTable {
const type = v.name in enums ? 'enum' : v.type;
v.mappedType = this.platform.getMappedType(type);
v.default = v.default?.toString().startsWith('nextval(') ? null : v.default;
v.enumItems = enums[v.name] || [];
v.enumItems ??= enums[v.name] || [];
o[v.name] = v;

return o;
Expand Down Expand Up @@ -100,11 +101,12 @@ export class DatabaseTable {
autoincrement: prop.autoincrement ?? primary,
primary,
nullable: this.columns[field]?.nullable ?? !!prop.nullable,
nativeEnumName: prop.nativeEnumName,
length: prop.length,
precision: prop.precision,
scale: prop.scale,
default: prop.defaultRaw,
enumItems: prop.items?.every(Utils.isString) ? prop.items as string[] : undefined,
enumItems: prop.nativeEnumName || prop.items?.every(Utils.isString) ? prop.items as string[] : undefined,
comment: prop.comment,
extra: prop.extra,
ignoreSchemaChanges: prop.ignoreSchemaChanges,
Expand Down
10 changes: 9 additions & 1 deletion packages/knex/src/schema/SchemaHelper.ts
Expand Up @@ -63,7 +63,15 @@ export abstract class SchemaHelper {
return {};
}

async loadInformationSchema(schema: DatabaseSchema, connection: AbstractSqlConnection, tables: Table[]): Promise<void> {
getDropNativeEnumSQL(name: string, schema?: string): string {
throw new Error('Not supported by given driver');
}

getAlterNativeEnumSQL(name: string, schema?: string, value?: string): string {
throw new Error('Not supported by given driver');
}

async loadInformationSchema(schema: DatabaseSchema, connection: AbstractSqlConnection, tables: Table[], schemas?: string[]): Promise<void> {
for (const t of tables) {
const table = schema.addTable(t.table_name, t.schema_name);
table.comment = t.table_comment;
Expand Down
33 changes: 31 additions & 2 deletions packages/knex/src/schema/SqlSchemaGenerator.ts
Expand Up @@ -117,7 +117,8 @@ export class SqlSchemaGenerator extends AbstractSchemaGenerator<AbstractSqlDrive
await this.ensureDatabase();
const wrap = options.wrap ?? this.options.disableForeignKeys;
const metadata = this.getOrderedMetadata(options.schema).reverse();
const schema = await DatabaseSchema.create(this.connection, this.platform, this.config, options.schema);
const schemas = this.getTargetSchema(options.schema).getNamespaces();
const schema = await DatabaseSchema.create(this.connection, this.platform, this.config, options.schema, schemas);
let ret = '';

// remove FKs explicitly if we can't use cascading statement and we don't disable FK checks (we need this for circular relations)
Expand All @@ -138,6 +139,13 @@ export class SqlSchemaGenerator extends AbstractSchemaGenerator<AbstractSqlDrive
ret += await this.dump(this.dropTable(meta.collection, this.getSchemaName(meta, options)), '\n');
}

if (this.platform.supportsNativeEnums()) {
for (const columnName of Object.keys(schema.getNativeEnums())) {
const sql = this.helper.getDropNativeEnumSQL(columnName, options.schema ?? this.config.get('schema'));
ret += await this.dump(this.knex.schema.raw(sql), '\n');
}
}

if (options.dropMigrationsTable) {
ret += await this.dump(this.dropTable(this.config.get('migrations').tableName!, this.config.get('schema')), '\n');
}
Expand Down Expand Up @@ -185,10 +193,12 @@ export class SqlSchemaGenerator extends AbstractSchemaGenerator<AbstractSqlDrive
options.safe ??= false;
options.dropTables ??= true;
const toSchema = this.getTargetSchema(options.schema);
const fromSchema = options.fromSchema ?? await DatabaseSchema.create(this.connection, this.platform, this.config, options.schema);
const schemas = toSchema.getNamespaces();
const fromSchema = options.fromSchema ?? await DatabaseSchema.create(this.connection, this.platform, this.config, options.schema, schemas);
const wildcardSchemaTables = Object.values(this.metadata.getAll()).filter(meta => meta.schema === '*').map(meta => meta.tableName);
fromSchema.prune(options.schema, wildcardSchemaTables);
toSchema.prune(options.schema, wildcardSchemaTables);
toSchema.setNativeEnums(fromSchema.getNativeEnums());

return { fromSchema, toSchema };
}
Expand Down Expand Up @@ -318,6 +328,13 @@ export class SqlSchemaGenerator extends AbstractSchemaGenerator<AbstractSqlDrive
private alterTable(diff: TableDifference, safe: boolean): Knex.SchemaBuilder[] {
const ret: Knex.SchemaBuilder[] = [];
const [schemaName, tableName] = this.splitTableName(diff.name);
const changedNativeEnums: [string, string[], string[]][] = [];

for (const { column, changedProperties } of Object.values(diff.changedColumns)) {
if (column.nativeEnumName && changedProperties.has('enumItems') && column.nativeEnumName in diff.fromTable.nativeEnums) {
changedNativeEnums.push([column.nativeEnumName, column.enumItems!, diff.fromTable.getColumn(column.name)!.enumItems!]);
}
}

ret.push(this.createSchemaBuilder(schemaName).alterTable(tableName, table => {
for (const index of Object.values(diff.removedIndexes)) {
Expand Down Expand Up @@ -363,6 +380,10 @@ export class SqlSchemaGenerator extends AbstractSchemaGenerator<AbstractSqlDrive
continue;
}

if (changedProperties.size === 1 && changedProperties.has('enumItems') && column.nativeEnumName) {
continue;
}

const col = this.helper.createTableColumn(table, column, diff.fromTable, changedProperties).alter();
this.helper.configureColumn(column, col, this.knex, changedProperties);
}
Expand Down Expand Up @@ -422,6 +443,14 @@ export class SqlSchemaGenerator extends AbstractSchemaGenerator<AbstractSqlDrive
}
}));

if (this.platform.supportsNativeEnums()) {
changedNativeEnums.forEach(([enumName, itemsNew, itemsOld]) => {
// postgres allows only adding new items
const newItems = itemsNew.filter(val => !itemsOld.includes(val));
ret.push(...newItems.map(val => this.knex.schema.raw(this.helper.getAlterNativeEnumSQL(enumName, schemaName, val))));
});
}

return ret;
}

Expand Down
1 change: 1 addition & 0 deletions packages/knex/src/typings.ts
Expand Up @@ -47,6 +47,7 @@ export interface Column {
scale?: number;
default?: string | null;
comment?: string;
nativeEnumName?: string;
enumItems?: string[];
primary?: boolean;
unique?: boolean;
Expand Down
26 changes: 16 additions & 10 deletions packages/postgresql/src/PostgreSqlConnection.ts
Expand Up @@ -71,19 +71,29 @@ export class PostgreSqlConnection extends AbstractSqlConnection {
const type = col.getColumnType();
const colName = this.client.wrapIdentifier(col.getColumnName(), col.columnBuilder.queryContext());
const constraintName = `${this.tableNameRaw.replace(/^.*\.(.*)$/, '$1')}_${col.getColumnName()}_check`;
that.dropColumnDefault.call(this, col, colName);
const useNative = col.args?.[2]?.useNative;
const alterType = col.columnBuilder.alterType;
const alterNullable = col.columnBuilder.alterNullable;
const defaultTo = col.modified.defaultTo;

if (defaultTo != null) {
that.dropColumnDefault.call(this, col, colName);
}

if (col.type === 'enu' && !useNative) {
if (alterType) {
this.pushQuery({ sql: `alter table ${quotedTableName} alter column ${colName} type text using (${colName}::text)`, bindings: [] });
}

if (col.type === 'enu') {
this.pushQuery({ sql: `alter table ${quotedTableName} alter column ${colName} type text using (${colName}::text)`, bindings: [] });
/* istanbul ignore else */
if (options.createForeignKeyConstraints) {
if (options.createForeignKeyConstraints && alterNullable) {
this.pushQuery({ sql: `alter table ${quotedTableName} add constraint "${constraintName}" ${type.replace(/^text /, '')}`, bindings: [] });
}
} else if (type === 'uuid') {
// we need to drop the default as it would be invalid
this.pushQuery({ sql: `alter table ${quotedTableName} alter column ${colName} drop default`, bindings: [] });
this.pushQuery({ sql: `alter table ${quotedTableName} alter column ${colName} type ${type} using (${colName}::text::uuid)`, bindings: [] });
} else {
} else if (alterType) {
this.pushQuery({ sql: `alter table ${quotedTableName} alter column ${colName} type ${type} using (${colName}::${type})`, bindings: [] });
}

Expand Down Expand Up @@ -124,11 +134,7 @@ export class PostgreSqlConnection extends AbstractSqlConnection {
const quotedTableName = this.tableName();
const defaultTo = col.modified.defaultTo;

if (!defaultTo) {
return;
}

if (defaultTo[0] === null) {
if (defaultTo?.[0] == null) {
this.pushQuery({ sql: `alter table ${quotedTableName} alter column ${colName} drop default`, bindings: [] });
}
}
Expand Down
10 changes: 9 additions & 1 deletion packages/postgresql/src/PostgreSqlPlatform.ts
Expand Up @@ -17,6 +17,10 @@ export class PostgreSqlPlatform extends AbstractSqlPlatform {
return true;
}

override supportsNativeEnums(): boolean {
return true;
}

override supportsCustomPrimaryKeyNames(): boolean {
return true;
}
Expand Down Expand Up @@ -128,7 +132,11 @@ export class PostgreSqlPlatform extends AbstractSqlPlatform {
return 'double precision';
}

override getEnumTypeDeclarationSQL(column: { fieldNames: string[]; items?: unknown[] }): string {
override getEnumTypeDeclarationSQL(column: { fieldNames: string[]; items?: unknown[]; nativeEnumName?: string }): string {
if (column.nativeEnumName) {
return column.nativeEnumName;
}

if (column.items?.every(item => Utils.isString(item))) {
return 'text';
}
Expand Down

0 comments on commit 8515380

Please sign in to comment.