diff --git a/oracle/migrator.go b/oracle/migrator.go index b1d8b54..e1050f2 100644 --- a/oracle/migrator.go +++ b/oracle/migrator.go @@ -484,6 +484,33 @@ func (m Migrator) DropConstraint(value interface{}, name string) error { return m.Migrator.DropConstraint(value, name) } +// CreateType creates or replaces an Oracle user-defined type +func (m Migrator) CreateType(typeName, typeKind, typeof string) error { + if typeName == "" || typeKind == "" || typeof == "" { + return fmt.Errorf("createType: both typeName and definition are required") + } + + sql := fmt.Sprintf(`CREATE OR REPLACE TYPE "%s" AS %s OF %s`, strings.ToLower(typeName), typeKind, typeof) + return m.DB.Exec(sql).Error +} + +// DropType drops an Oracle user-defined type +func (m Migrator) DropType(typeName string) error { + sql := fmt.Sprintf(`DROP TYPE "%s" FORCE`, strings.ToLower(typeName)) + return m.DB.Exec(sql).Error +} + +// HasType checks whether a user-defined type exists +func (m Migrator) HasType(typeName string) bool { + if typeName == "" { + return false + } + + var count int + err := m.DB.Raw(`SELECT COUNT(*) FROM USER_TYPES WHERE TYPE_NAME = UPPER(?)`, typeName).Scan(&count).Error + return err == nil && count > 0 +} + // DropIndex drops the index with the specified `name` from the table associated with `value` func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { diff --git a/tests/migrate_test.go b/tests/migrate_test.go index 2da7589..a8e2d48 100644 --- a/tests/migrate_test.go +++ b/tests/migrate_test.go @@ -50,6 +50,7 @@ import ( "time" + "github.com/oracle-samples/gorm-oracle/oracle" . "github.com/oracle-samples/gorm-oracle/tests/utils" "github.com/stretchr/testify/assert" @@ -1970,6 +1971,84 @@ func TestOracleSequences(t *testing.T) { } } +func TestOracleTypeCreateDrop(t *testing.T) { + if DB.Dialector.Name() != "oracle" { + t.Skip("Skipping Oracle type test: not running on Oracle") + } + + const typeName = "email_list" + const tableName = "email_varray_tab" + + // Assert that DB.Migrator() is an oracle.Migrator (so we can use Oracle-specific methods) + m, ok := DB.Migrator().(oracle.Migrator) + if !ok { + t.Skip("Skipping: current dialect migrator is not Oracle-specific") + } + + // Drop type if it exists + t.Run("drop_existing_type_if_any", func(t *testing.T) { + err := m.DropType(typeName) + if err != nil && !strings.Contains(err.Error(), "ORA-04043") { + t.Fatalf("Unexpected error dropping type: %v", err) + } + }) + + // Create new VARRAY type + t.Run("create_varray_type", func(t *testing.T) { + err := m.CreateType(typeName, "VARRAY(10)", "VARCHAR2(60)") + if err != nil { + t.Fatalf("Failed to create Oracle type: %v", err) + } + + // Verify it exists + var count int + if err := DB.Raw(`SELECT COUNT(*) FROM USER_TYPES WHERE TYPE_NAME = LOWER(?)`, typeName).Scan(&count).Error; err != nil { + t.Fatalf("Failed to verify created type: %v", err) + } + if count == 0 { + t.Fatalf("Expected Oracle type %s to exist", typeName) + } + }) + + // Create table using the custom type + t.Run("create_table_using_custom_type", func(t *testing.T) { + createTableSQL := fmt.Sprintf(` + CREATE TABLE "%s" ( + "ID" NUMBER PRIMARY KEY, + "EMAILS" "%s" + )`, tableName, typeName) + + if err := DB.Exec(createTableSQL).Error; err != nil { + t.Fatalf("Failed to create table using type %s: %v", typeName, err) + } + + // Verify table exists + if !m.HasTable(tableName) { + t.Fatalf("Expected table %s to exist", tableName) + } + }) + + // Drop table and type + t.Run("drop_table_and_type", func(t *testing.T) { + if err := m.DropTable(tableName); err != nil { + t.Fatalf("Failed to drop table %s: %v", tableName, err) + } + + if err := m.DropType(typeName); err != nil { + t.Fatalf("Failed to drop type %s: %v", typeName, err) + } + + // Verify type is gone + var count int + if err := DB.Raw(`SELECT COUNT(*) FROM USER_TYPES WHERE TYPE_NAME = LOWER(?)`, typeName).Scan(&count).Error; err != nil { + t.Fatalf("Failed to verify dropped type: %v", err) + } + if count > 0 { + t.Fatalf("Expected Oracle type %s to be dropped", typeName) + } + }) +} + func TestOracleIndexes(t *testing.T) { if DB.Dialector.Name() != "oracle" { return