Skip to content

Commit

Permalink
8303737: C2: Load can bypass subtype check that enforces it's from th…
Browse files Browse the repository at this point in the history
…e right object type

Backport-of: 52983ed529182901db4e33857bfeab2727e235df
  • Loading branch information
shipilev committed Nov 1, 2023
1 parent ae5b92c commit c3c2b9f
Show file tree
Hide file tree
Showing 8 changed files with 496 additions and 49 deletions.
81 changes: 58 additions & 23 deletions src/hotspot/share/opto/castnode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@
//=============================================================================
// If input is already higher or equal to cast type, then this is an identity.
Node* ConstraintCastNode::Identity(PhaseGVN* phase) {
if (_dependency == UnconditionalDependency) {
return this;
}
Node* dom = dominating_cast(phase, phase);
if (dom != nullptr) {
return dom;
}
if (_dependency != RegularDependency) {
return this;
}
return phase->type(in(1))->higher_equal_speculative(_type) ? in(1) : this;
return higher_equal_types(phase, in(1)) ? in(1) : this;
}

//------------------------------Value------------------------------------------
Expand Down Expand Up @@ -100,47 +100,62 @@ Node *ConstraintCastNode::Ideal(PhaseGVN *phase, bool can_reshape) {
return (in(0) && remove_dead_region(phase, can_reshape)) ? this : nullptr;
}

uint ConstraintCastNode::hash() const {
return TypeNode::hash() + (int)_dependency + (_extra_types != nullptr ? _extra_types->hash() : 0);
}

bool ConstraintCastNode::cmp(const Node &n) const {
return TypeNode::cmp(n) && ((ConstraintCastNode&)n)._dependency == _dependency;
if (!TypeNode::cmp(n)) {
return false;
}
ConstraintCastNode& cast = (ConstraintCastNode&) n;
if (cast._dependency != _dependency) {
return false;
}
if (_extra_types == nullptr || cast._extra_types == nullptr) {
return _extra_types == cast._extra_types;
}
return _extra_types->eq(cast._extra_types);
}

uint ConstraintCastNode::size_of() const {
return sizeof(*this);
}

Node* ConstraintCastNode::make_cast(int opcode, Node* c, Node *n, const Type *t, DependencyType dependency) {
Node* ConstraintCastNode::make_cast(int opcode, Node* c, Node* n, const Type* t, DependencyType dependency,
const TypeTuple* extra_types) {
switch(opcode) {
case Op_CastII: {
Node* cast = new CastIINode(n, t, dependency);
Node* cast = new CastIINode(n, t, dependency, false, extra_types);
cast->set_req(0, c);
return cast;
}
case Op_CastLL: {
Node* cast = new CastLLNode(n, t, dependency);
Node* cast = new CastLLNode(n, t, dependency, extra_types);
cast->set_req(0, c);
return cast;
}
case Op_CastPP: {
Node* cast = new CastPPNode(n, t, dependency);
Node* cast = new CastPPNode(n, t, dependency, extra_types);
cast->set_req(0, c);
return cast;
}
case Op_CastFF: {
Node* cast = new CastFFNode(n, t, dependency);
Node* cast = new CastFFNode(n, t, dependency, extra_types);
cast->set_req(0, c);
return cast;
}
case Op_CastDD: {
Node* cast = new CastDDNode(n, t, dependency);
Node* cast = new CastDDNode(n, t, dependency, extra_types);
cast->set_req(0, c);
return cast;
}
case Op_CastVV: {
Node* cast = new CastVVNode(n, t, dependency);
Node* cast = new CastVVNode(n, t, dependency, extra_types);
cast->set_req(0, c);
return cast;
}
case Op_CheckCastPP: return new CheckCastPPNode(c, n, t, dependency);
case Op_CheckCastPP: return new CheckCastPPNode(c, n, t, dependency, extra_types);
default:
fatal("Bad opcode %d", opcode);
}
Expand All @@ -150,10 +165,10 @@ Node* ConstraintCastNode::make_cast(int opcode, Node* c, Node *n, const Type *t,
Node* ConstraintCastNode::make(Node* c, Node *n, const Type *t, DependencyType dependency, BasicType bt) {
switch(bt) {
case T_INT: {
return make_cast(Op_CastII, c, n, t, dependency);
return make_cast(Op_CastII, c, n, t, dependency, nullptr);
}
case T_LONG: {
return make_cast(Op_CastLL, c, n, t, dependency);
return make_cast(Op_CastLL, c, n, t, dependency, nullptr);
}
default:
fatal("Bad basic type %s", type2name(bt));
Expand Down Expand Up @@ -186,7 +201,7 @@ TypeNode* ConstraintCastNode::dominating_cast(PhaseGVN* gvn, PhaseTransform* pt)
u->outcnt() > 0 &&
u->Opcode() == opc &&
u->in(0) != nullptr &&
u->bottom_type()->higher_equal(type())) {
higher_equal_types(gvn, u)) {
if (pt->is_dominator(u->in(0), ctl)) {
return u->as_Type();
}
Expand All @@ -202,9 +217,28 @@ TypeNode* ConstraintCastNode::dominating_cast(PhaseGVN* gvn, PhaseTransform* pt)
return nullptr;
}

bool ConstraintCastNode::higher_equal_types(PhaseGVN* phase, const Node* other) const {
const Type* t = phase->type(other);
if (!t->higher_equal_speculative(type())) {
return false;
}
if (_extra_types != nullptr) {
for (uint i = 0; i < _extra_types->cnt(); ++i) {
if (!t->higher_equal_speculative(_extra_types->field_at(i))) {
return false;
}
}
}
return true;
}

#ifndef PRODUCT
void ConstraintCastNode::dump_spec(outputStream *st) const {
TypeNode::dump_spec(st);
if (_extra_types != nullptr) {
st->print(" extra types: ");
_extra_types->dump_on(st);
}
if (_dependency != RegularDependency) {
st->print(" %s dependency", _dependency == StrongDependency ? "strong" : "unconditional");
}
Expand Down Expand Up @@ -523,20 +557,21 @@ Node* CastP2XNode::Identity(PhaseGVN* phase) {
return this;
}

Node* ConstraintCastNode::make_cast_for_type(Node* c, Node* in, const Type* type, DependencyType dependency) {
Node* ConstraintCastNode::make_cast_for_type(Node* c, Node* in, const Type* type, DependencyType dependency,
const TypeTuple* types) {
Node* cast= nullptr;
if (type->isa_int()) {
cast = make_cast(Op_CastII, c, in, type, dependency);
cast = make_cast(Op_CastII, c, in, type, dependency, types);
} else if (type->isa_long()) {
cast = make_cast(Op_CastLL, c, in, type, dependency);
cast = make_cast(Op_CastLL, c, in, type, dependency, types);
} else if (type->isa_float()) {
cast = make_cast(Op_CastFF, c, in, type, dependency);
cast = make_cast(Op_CastFF, c, in, type, dependency, types);
} else if (type->isa_double()) {
cast = make_cast(Op_CastDD, c, in, type, dependency);
cast = make_cast(Op_CastDD, c, in, type, dependency, types);
} else if (type->isa_vect()) {
cast = make_cast(Op_CastVV, c, in, type, dependency);
cast = make_cast(Op_CastVV, c, in, type, dependency, types);
} else if (type->isa_ptr()) {
cast = make_cast(Op_CastPP, c, in, type, dependency);
cast = make_cast(Op_CastPP, c, in, type, dependency, types);
}
return cast;
}
Expand Down
60 changes: 40 additions & 20 deletions src/hotspot/share/opto/castnode.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,20 @@ class ConstraintCastNode: public TypeNode {
const DependencyType _dependency;
virtual bool cmp( const Node &n ) const;
virtual uint size_of() const;
virtual uint hash() const; // Check the type
const Type* widen_type(const PhaseGVN* phase, const Type* res, BasicType bt) const;

private:
// PhiNode::Ideal() transforms a Phi that merges a single uncasted value into a single cast pinned at the region.
// The types of cast nodes eliminated as a consequence of this transformation are collected and stored here so the
// type dependencies carried by the cast are known. The cast can then be eliminated if the type of its input is
// narrower (or equal) than all the types it carries.
const TypeTuple* _extra_types;

public:
ConstraintCastNode(Node *n, const Type *t, DependencyType dependency)
: TypeNode(t,2), _dependency(dependency) {
ConstraintCastNode(Node* n, const Type* t, ConstraintCastNode::DependencyType dependency,
const TypeTuple* extra_types)
: TypeNode(t,2), _dependency(dependency), _extra_types(extra_types) {
init_class_id(Class_ConstraintCast);
init_req(1, n);
}
Expand All @@ -59,14 +68,15 @@ class ConstraintCastNode: public TypeNode {
virtual bool depends_only_on_test() const { return _dependency == RegularDependency; }
bool carry_dependency() const { return _dependency != RegularDependency; }
TypeNode* dominating_cast(PhaseGVN* gvn, PhaseTransform* pt) const;
static Node* make_cast(int opcode, Node* c, Node *n, const Type *t, DependencyType dependency);
static Node* make_cast(int opcode, Node* c, Node* n, const Type* t, DependencyType dependency, const TypeTuple* extra_types);
static Node* make(Node* c, Node *n, const Type *t, DependencyType dependency, BasicType bt);

#ifndef PRODUCT
virtual void dump_spec(outputStream *st) const;
#endif

static Node* make_cast_for_type(Node* c, Node* in, const Type* type, DependencyType dependency);
static Node* make_cast_for_type(Node* c, Node* in, const Type* type, DependencyType dependency,
const TypeTuple* types);

Node* optimize_integer_cast(PhaseGVN* phase, BasicType bt);

Expand All @@ -91,6 +101,16 @@ class ConstraintCastNode: public TypeNode {
}
}
}

bool higher_equal_types(PhaseGVN* phase, const Node* other) const;

int extra_types_count() const {
return _extra_types == nullptr ? 0 : _extra_types->cnt();
}

const Type* extra_type_at(int i) const {
return _extra_types->field_at(i);
}
};

//------------------------------CastIINode-------------------------------------
Expand All @@ -103,12 +123,12 @@ class CastIINode: public ConstraintCastNode {
virtual uint size_of() const;

public:
CastIINode(Node* n, const Type* t, DependencyType dependency = RegularDependency, bool range_check_dependency = false)
: ConstraintCastNode(n, t, dependency), _range_check_dependency(range_check_dependency) {
CastIINode(Node* n, const Type* t, DependencyType dependency = RegularDependency, bool range_check_dependency = false, const TypeTuple* types = nullptr)
: ConstraintCastNode(n, t, dependency, types), _range_check_dependency(range_check_dependency) {
init_class_id(Class_CastII);
}
CastIINode(Node* ctrl, Node* n, const Type* t, DependencyType dependency = RegularDependency, bool range_check_dependency = false)
: ConstraintCastNode(n, t, dependency), _range_check_dependency(range_check_dependency) {
: ConstraintCastNode(n, t, dependency, nullptr), _range_check_dependency(range_check_dependency) {
init_class_id(Class_CastII);
init_req(0, ctrl);
}
Expand All @@ -134,12 +154,12 @@ class CastIINode: public ConstraintCastNode {
class CastLLNode: public ConstraintCastNode {
public:
CastLLNode(Node* ctrl, Node* n, const Type* t, DependencyType dependency = RegularDependency)
: ConstraintCastNode(n, t, dependency) {
: ConstraintCastNode(n, t, dependency, nullptr) {
init_class_id(Class_CastLL);
init_req(0, ctrl);
}
CastLLNode(Node* n, const Type* t, DependencyType dependency = RegularDependency)
: ConstraintCastNode(n, t, dependency){
CastLLNode(Node* n, const Type* t, DependencyType dependency = RegularDependency, const TypeTuple* types = nullptr)
: ConstraintCastNode(n, t, dependency, types) {
init_class_id(Class_CastLL);
}

Expand All @@ -151,8 +171,8 @@ class CastLLNode: public ConstraintCastNode {

class CastFFNode: public ConstraintCastNode {
public:
CastFFNode(Node* n, const Type* t, DependencyType dependency = RegularDependency)
: ConstraintCastNode(n, t, dependency){
CastFFNode(Node* n, const Type* t, DependencyType dependency = RegularDependency, const TypeTuple* types = nullptr)
: ConstraintCastNode(n, t, dependency, types) {
init_class_id(Class_CastFF);
}
virtual int Opcode() const;
Expand All @@ -161,8 +181,8 @@ class CastFFNode: public ConstraintCastNode {

class CastDDNode: public ConstraintCastNode {
public:
CastDDNode(Node* n, const Type* t, DependencyType dependency = RegularDependency)
: ConstraintCastNode(n, t, dependency){
CastDDNode(Node* n, const Type* t, DependencyType dependency = RegularDependency, const TypeTuple* types = nullptr)
: ConstraintCastNode(n, t, dependency, types) {
init_class_id(Class_CastDD);
}
virtual int Opcode() const;
Expand All @@ -171,8 +191,8 @@ class CastDDNode: public ConstraintCastNode {

class CastVVNode: public ConstraintCastNode {
public:
CastVVNode(Node* n, const Type* t, DependencyType dependency = RegularDependency)
: ConstraintCastNode(n, t, dependency){
CastVVNode(Node* n, const Type* t, DependencyType dependency = RegularDependency, const TypeTuple* types = nullptr)
: ConstraintCastNode(n, t, dependency, types) {
init_class_id(Class_CastVV);
}
virtual int Opcode() const;
Expand All @@ -184,8 +204,8 @@ class CastVVNode: public ConstraintCastNode {
// cast pointer to pointer (different type)
class CastPPNode: public ConstraintCastNode {
public:
CastPPNode (Node *n, const Type *t, DependencyType dependency = RegularDependency)
: ConstraintCastNode(n, t, dependency) {
CastPPNode (Node *n, const Type *t, DependencyType dependency = RegularDependency, const TypeTuple* types = nullptr)
: ConstraintCastNode(n, t, dependency, types) {
}
virtual int Opcode() const;
virtual uint ideal_reg() const { return Op_RegP; }
Expand All @@ -195,8 +215,8 @@ class CastPPNode: public ConstraintCastNode {
// for _checkcast, cast pointer to pointer (different type), without JOIN,
class CheckCastPPNode: public ConstraintCastNode {
public:
CheckCastPPNode(Node *c, Node *n, const Type *t, DependencyType dependency = RegularDependency)
: ConstraintCastNode(n, t, dependency) {
CheckCastPPNode(Node *c, Node *n, const Type *t, DependencyType dependency = RegularDependency, const TypeTuple* types = nullptr)
: ConstraintCastNode(n, t, dependency, types) {
init_class_id(Class_CheckCastPP);
init_req(0, c);
}
Expand Down

1 comment on commit c3c2b9f

@openjdk-notifier
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.