-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
TransactionBoundQueryContextTest.scala
180 lines (155 loc) · 7.68 KB
/
TransactionBoundQueryContextTest.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
/*
* Copyright (c) 2002-2016 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.cypher.internal.spi.v3_1
import java.net.URL
import java.util.Collections
import org.mockito.Mockito._
import org.neo4j.cypher.internal.compiler.v3_1.helpers.DynamicIterable
import org.neo4j.cypher.internal.frontend.v3_1.SemanticDirection
import org.neo4j.cypher.internal.frontend.v3_1.test_helpers.CypherFunSuite
import org.neo4j.cypher.internal.spi.TransactionalContextWrapperv3_1
import org.neo4j.cypher.internal.spi.v3_1.TransactionBoundQueryContext.IndexSearchMonitor
import org.neo4j.cypher.javacompat.internal.GraphDatabaseCypherService
import org.neo4j.graphdb._
import org.neo4j.graphdb.config.Setting
import org.neo4j.graphdb.factory.GraphDatabaseSettings
import org.neo4j.kernel.api._
import org.neo4j.kernel.api.security.AccessMode
import org.neo4j.kernel.impl.api.{KernelStatement, KernelTransactionImplementation, StatementOperationParts}
import org.neo4j.kernel.impl.coreapi.{InternalTransaction, PropertyContainerLocker}
import org.neo4j.kernel.impl.proc.Procedures
import org.neo4j.kernel.impl.query.{Neo4jTransactionalContext, Neo4jTransactionalContextFactory, QuerySource}
import org.neo4j.storageengine.api.StorageStatement
import org.neo4j.test.TestGraphDatabaseFactory
import scala.collection.JavaConverters._
/**
 * Tests for [[TransactionBoundQueryContext]]: completing the underlying transaction through the
 * transactional context, iterating relationships of a node, and the URL checks made by
 * `getImportURL`.
 */
class TransactionBoundQueryContextTest extends CypherFunSuite {

  var graph: GraphDatabaseCypherService = null
  var outerTx: InternalTransaction = null
  var statement: KernelStatement = null
  val indexSearchMonitor = mock[IndexSearchMonitor]
  val locker = mock[PropertyContainerLocker]

  /**
   * Starts a fresh impermanent database. Also builds a mocked outer transaction and a
   * [[KernelStatement]] over mocked kernel/storage parts. These are used by the tests that
   * construct a [[Neo4jTransactionalContext]] by hand.
   */
  override def beforeEach() {
    super.beforeEach()
    graph = new GraphDatabaseCypherService(new TestGraphDatabaseFactory().newImpermanentDatabase())
    outerTx = mock[InternalTransaction]
    val kernelTransaction = mock[KernelTransactionImplementation]
    when(kernelTransaction.mode()).thenReturn(AccessMode.Static.FULL)
    val storeStatement = mock[StorageStatement]
    val operations = mock[StatementOperationParts](RETURNS_DEEP_STUBS)
    statement = new KernelStatement(kernelTransaction, null, storeStatement, new Procedures())
    statement.initialize(null, operations)
    statement.acquire()
  }

  /**
   * Shuts the database down. It then always delegates to the parent suite's `afterEach`,
   * mirroring the `super.beforeEach()` call above, even if the shutdown throws.
   */
  override def afterEach() {
    try {
      graph.getGraphDatabaseService.shutdown()
    } finally {
      super.afterEach()
    }
  }

  test("should mark transaction successful if successful") {
    // GIVEN
    when(outerTx.failure()).thenThrow(new AssertionError("Shouldn't be called"))
    val context = createContextOverMockedTransaction()

    // WHEN
    context.transactionalContext.close(success = true)

    // THEN
    verify(outerTx).success()
    verify(outerTx).close()
    verifyNoMoreInteractions(outerTx)
  }

  test("should mark transaction failed if not successful") {
    // GIVEN
    when(outerTx.success()).thenThrow(new AssertionError("Shouldn't be called"))
    val context = createContextOverMockedTransaction()

    // WHEN
    context.transactionalContext.close(success = false)

    // THEN
    verify(outerTx).failure()
    verify(outerTx).close()
    verifyNoMoreInteractions(outerTx)
  }

  test("should return fresh but equal iterators") {
    // GIVEN
    val relTypeName = "LINK"
    val node = createMiniGraph(relTypeName)

    inReadTransaction { context =>
      // WHEN
      val iterable = DynamicIterable(context.getRelationshipsForIds(node, SemanticDirection.BOTH, None))

      // THEN
      val iteratorA: Iterator[Relationship] = iterable.iterator
      val iteratorB: Iterator[Relationship] = iterable.iterator
      iteratorA should not equal iteratorB
      iteratorA.toList should equal(iteratorB.toList)
      // actual value on the left so failure messages report the observed size correctly
      iterable.size should equal(2)
    }
  }

  test("should deny non-whitelisted URL protocols for loading") {
    inReadTransaction { context =>
      // THEN
      context.getImportURL(new URL("http://localhost:7474/data.csv")) should equal(Right(new URL("http://localhost:7474/data.csv")))
      context.getImportURL(new URL("file:///tmp/foo/data.csv")) should equal(Right(new URL("file:///tmp/foo/data.csv")))
      context.getImportURL(new URL("jar:file:/tmp/blah.jar!/tmp/foo/data.csv")) should equal(Left("loading resources via protocol 'jar' is not permitted"))
    }
  }

  test("should deny file URLs when not allowed by config") {
    // GIVEN: replace the default database with one that disallows file URLs
    graph.getGraphDatabaseService.shutdown()
    val config = Map[Setting[_], String](GraphDatabaseSettings.allow_file_urls -> "false")
    graph = new GraphDatabaseCypherService(new TestGraphDatabaseFactory().newImpermanentDatabase(config.asJava))

    inReadTransaction { context =>
      // THEN
      context.getImportURL(new URL("http://localhost:7474/data.csv")) should equal (Right(new URL("http://localhost:7474/data.csv")))
      context.getImportURL(new URL("file:///tmp/foo/data.csv")) should equal (Left("configuration property 'dbms.security.allow_csv_import_from_file_urls' is false"))
    }
  }

  /**
   * Builds a query context over a hand-made [[Neo4jTransactionalContext]] that wraps the mocked
   * `outerTx` and the `statement` prepared in `beforeEach`.
   *
   * Callers stub `outerTx` before calling this. The remaining constructor arguments are passed
   * as `null`, exactly as in the original tests.
   */
  private def createContextOverMockedTransaction(): TransactionBoundQueryContext = {
    val tc = new Neo4jTransactionalContext(graph, outerTx, KernelTransaction.Type.`implicit`, AccessMode.Static.FULL,
      statement, null, locker, null, null, null, null)
    new TransactionBoundQueryContext(TransactionalContextWrapperv3_1(tc))(indexSearchMonitor)
  }

  /**
   * Runs `body` against a [[TransactionBoundQueryContext]] over a new explicit READ transaction on
   * the current `graph`.
   *
   * The transaction is marked successful only when `body` completes normally, and it is always
   * closed. A failing assertion therefore never leaves it open when `afterEach` shuts the
   * database down.
   */
  private def inReadTransaction(body: TransactionBoundQueryContext => Unit): Unit = {
    val tx = graph.beginTransaction(KernelTransaction.Type.explicit, AccessMode.Static.READ)
    try {
      val transactionalContext = TransactionalContextWrapperv3_1(createTransactionContext(graph, tx))
      body(new TransactionBoundQueryContext(transactionalContext)(indexSearchMonitor))
      tx.success()
    } finally {
      tx.close()
    }
  }

  /**
   * Creates a transactional context for `transaction` via [[Neo4jTransactionalContextFactory]].
   * It uses an unknown query source, placeholder query text and an empty parameter map.
   */
  private def createTransactionContext(graphDatabaseCypherService: GraphDatabaseCypherService, transaction: InternalTransaction) = {
    val contextFactory = new Neo4jTransactionalContextFactory(graphDatabaseCypherService, new PropertyContainerLocker)
    contextFactory.newContext(QuerySource.UNKNOWN, transaction, "no query", Collections.emptyMap())
  }

  /**
   * Creates three nodes and two relationships of type `relTypeName` in a committed WRITE
   * transaction. One relationship is outgoing from the returned node and one is incoming to it.
   *
   * @return the node that has both relationships attached
   */
  private def createMiniGraph(relTypeName: String): Node = {
    val relType = RelationshipType.withName(relTypeName)
    val tx = graph.beginTransaction(KernelTransaction.Type.explicit, AccessMode.Static.WRITE)
    try {
      val node = graph.createNode()
      val other1 = graph.createNode()
      val other2 = graph.createNode()
      node.createRelationshipTo(other1, relType)
      other2.createRelationshipTo(node, relType)
      tx.success()
      node
    } finally {
      tx.close()
    }
  }
}