Skip to content

Commit

Permalink
[ruby] Handle Rescue Exception Lists (#4575)
Browse files Browse the repository at this point in the history
  • Loading branch information
DavidBakerEffendi committed May 20, 2024
1 parent 22260c7 commit ec29b98
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -604,20 +604,23 @@ trait AstForExpressionsCreator(implicit withSchemaValidation: ValidationMode) {
protected def astForRescueExpression(node: RescueExpression): Ast = {
val tryAst = astForStatementList(node.body.asStatementList)
val rescueAsts = node.rescueClauses
.map {
case x: RescueClause =>
// TODO: add exception assignment
astForStatementList(x.thenClause.asStatementList)
case x => astForUnknown(x)
.map { x =>
val classes =
x.exceptionClassList.map(e => scope.tryResolveTypeReference(e.text).map(_.name).getOrElse(e.text)).toSeq
val variables = x.variables
.flatMap { v =>
handleVariableOccurrence(v)
scope.lookupVariable(v.text)
}
.collect {
case x: NewLocal => Ast(x.dynamicTypeHintFullName(classes))
case x: NewMethodParameterIn => Ast(x.dynamicTypeHintFullName(classes))
}
.toList
astForStatementList(x.thenClause.asStatementList).withChildren(variables)
}
val elseAst = node.elseClause.map {
case x: ElseClause => astForStatementList(x.thenClause.asStatementList)
case x => astForUnknown(x)
}
val ensureAst = node.ensureClause.map {
case x: EnsureClause => astForStatementList(x.thenClause.asStatementList)
case x => astForUnknown(x)
}
val elseAst = node.elseClause.map { x => astForStatementList(x.thenClause.asStatementList) }
val ensureAst = node.ensureClause.map { x => astForStatementList(x.thenClause.asStatementList) }
tryCatchAst(
NewControlStructure()
.controlStructureType(ControlStructureTypes.TRY)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
// Ensure never returns a value, only the main body, rescue & else clauses
RescueExpression(
transform(body),
rescueClauses.map(transform),
elseClause.map(transform).orElse(defaultElseBranch(node.span)),
rescueClauses.map(transform).collect { case x: RescueClause => x },
elseClause.map(transform).orElse(defaultElseBranch(node.span)).collect { case x: ElseClause => x },
ensureClause
)(node.span)
case WhileExpression(condition, body) => WhileExpression(condition, transform(body))(node.span)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ object RubyIntermediateAst {
}

implicit class RubyNodeHelper(node: RubyNode) {
def asStatementList = node match
def asStatementList: StatementList = node match {
case stmtList: StatementList => stmtList
case _ => StatementList(List(node))(node.span)

}
}

final case class Unknown()(span: TextSpan) extends RubyNode(span)
Expand Down Expand Up @@ -150,16 +150,16 @@ object RubyIntermediateAst {

final case class RescueExpression(
body: RubyNode,
rescueClauses: List[RubyNode],
elseClause: Option[RubyNode],
ensureClause: Option[RubyNode]
rescueClauses: List[RescueClause],
elseClause: Option[ElseClause],
ensureClause: Option[EnsureClause]
)(span: TextSpan)
extends RubyNode(span)
with ControlFlowExpression

final case class RescueClause(
exceptionClassList: Option[RubyNode],
assignment: Option[RubyNode],
variables: Option[RubyNode],
thenClause: RubyNode
)(span: TextSpan)
extends RubyNode(span)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1083,10 +1083,11 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}

override def visitBodyStatement(ctx: RubyParser.BodyStatementContext): RubyNode = {
val body = visit(ctx.compoundStatement())
val rescueClauses = Option(ctx.rescueClause.asScala).fold(List())(_.map(visit).toList)
val elseClause = Option(ctx.elseClause).map(visit)
val ensureClause = Option(ctx.ensureClause).map(visit)
val body = visit(ctx.compoundStatement())
val rescueClauses =
Option(ctx.rescueClause.asScala).fold(List())(_.map(visit).toList).collect { case x: RescueClause => x }
val elseClause = Option(ctx.elseClause).map(visit).collect { case x: ElseClause => x }
val ensureClause = Option(ctx.ensureClause).map(visit).collect { case x: EnsureClause => x }

if (rescueClauses.isEmpty && elseClause.isEmpty && ensureClause.isEmpty) {
visit(ctx.compoundStatement())
Expand All @@ -1096,16 +1097,14 @@ class RubyNodeCreator extends RubyParserBaseVisitor[RubyNode] {
}

override def visitExceptionClassList(ctx: RubyParser.ExceptionClassListContext): RubyNode = {
// Requires implementing multiple rhs with splatting
logger.warn(s"Exception class lists are not handled: '${ctx.toTextSpan}'")
Unknown()(ctx.toTextSpan)
Option(ctx.multipleRightHandSide()).map(visitMultipleRightHandSide).getOrElse(visit(ctx.operatorExpression()))
}

override def visitRescueClause(ctx: RubyParser.RescueClauseContext): RubyNode = {
val exceptionClassList = Option(ctx.exceptionClassList).map(visit)
val elseClause = Option(ctx.exceptionVariableAssignment).map(visit)
val variables = Option(ctx.exceptionVariableAssignment).map(visit)
val thenClause = visit(ctx.thenClause)
RescueClause(exceptionClassList, elseClause, thenClause)(ctx.toTextSpan)
RescueClause(exceptionClassList, variables, thenClause)(ctx.toTextSpan)
}

override def visitEnsureClause(ctx: RubyParser.EnsureClauseContext): RubyNode = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -575,27 +575,32 @@ class ClassTests extends RubyCode2CpgFixture {
"Bodies that aren't StatementList" should {
val cpg = code("""
| class EventWebhook
| # * *Args* :
| # - +public_key+ -> elliptic curve public key
| # - +payload+ -> event payload in the request body
| # - +signature+ -> signature value obtained from the 'X-Twilio-Email-Event-Webhook-Signature' header
| # - +timestamp+ -> timestamp value obtained from the 'X-Twilio-Email-Event-Webhook-Timestamp' header
| ERRORS = [CustomErrorA, CustomErrorB]
|
| def verify_signature(public_key, payload, signature, timestamp)
| verify_engine
| timestamped_playload = "#{timestamp}#{payload}"
| payload_digest = Digest::SHA256.digest(timestamped_playload)
| timestamped_payload = "#{timestamp}#{payload}"
| payload_digest = Digest::SHA256.digest(timestamped_payload)
| decoded_signature = Base64.decode64(signature)
| public_key.dsa_verify_asn1(payload_digest, decoded_signature)
| rescue StandardError
| rescue *ERRORS => splat_errors
| false
| rescue StandardError => some_variable
| false
| end
| end
|""".stripMargin)
"not throw an execption" in {
inside(cpg.method.name("verify_signature").l) {
case verifySigMethod :: Nil => // Passing case
case _ => fail("Expected method for verify_sginature")
}

"successfully parse and create the method" in {
cpg.method.nameExact("verify_signature").nonEmpty shouldBe true
}

"create the `StandardError` local variable" in {
cpg.local.nameExact("some_variable").dynamicTypeHintFullName.toList shouldBe List("__builtin.StandardError")
}

"create the splatted error local variable" in {
cpg.local.nameExact("splat_errors").size shouldBe 1
}
}

Expand Down

0 comments on commit ec29b98

Please sign in to comment.