diff --git a/src/class.c b/src/class.c index 3246564ec7..35c3aa0405 100644 --- a/src/class.c +++ b/src/class.c @@ -131,6 +131,19 @@ mrb_class_outer_module(mrb_state *mrb, struct RClass *c) return mrb_class_ptr(outer); } +static void +check_if_class_or_module(mrb_state *mrb, mrb_value obj) +{ + switch (mrb_type(obj)) { + case MRB_TT_CLASS: + case MRB_TT_SCLASS: + case MRB_TT_MODULE: + return; + default: + mrb_raisef(mrb, E_TYPE_ERROR, "%S is not a class/module", mrb_inspect(mrb, obj)); + } +} + static struct RClass* define_module(mrb_state *mrb, mrb_sym name, struct RClass *outer) { @@ -160,6 +173,7 @@ mrb_define_module(mrb_state *mrb, const char *name) MRB_API struct RClass* mrb_vm_define_module(mrb_state *mrb, mrb_value outer, mrb_sym id) { + check_if_class_or_module(mrb, outer); return define_module(mrb, id, mrb_class_ptr(outer)); } @@ -232,15 +246,7 @@ mrb_vm_define_class(mrb_state *mrb, mrb_value outer, mrb_value super, mrb_sym id else { s = 0; } - switch (mrb_type(outer)) { - case MRB_TT_CLASS: - case MRB_TT_SCLASS: - case MRB_TT_MODULE: - break; - default: - mrb_raisef(mrb, E_TYPE_ERROR, "%S is not a class/module", outer); - break; - } + check_if_class_or_module(mrb, outer); c = define_class(mrb, id, s, mrb_class_ptr(outer)); mrb_class_inherited(mrb, mrb_class_real(c->super), c); diff --git a/test/t/class.rb b/test/t/class.rb index d4ecf99d09..720fd37fa3 100644 --- a/test/t/class.rb +++ b/test/t/class.rb @@ -383,3 +383,8 @@ def class_variable assert_equal("value", ClassVariableTest.class_variable) end + +assert('class with non-class/module outer raises TypeError') do + assert_raise(TypeError) { class 0::C1; end } + assert_raise(TypeError) { class []::C2; end } +end diff --git a/test/t/module.rb b/test/t/module.rb index 9852328ce2..ecb9694757 100644 --- a/test/t/module.rb +++ b/test/t/module.rb @@ -533,3 +533,7 @@ def modfunc; end assert_true M.respond_to?(:modfunc) end +assert('module with non-class/module outer raises TypeError') do + assert_raise(TypeError) { module 0::M1 end } + assert_raise(TypeError) { module []::M2 end } +end