Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Analytics: stop double reporting posthog utds #8801

Merged
merged 1 commit into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package im.vector.app.features

import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import im.vector.app.InstrumentedTest
import im.vector.app.features.analytics.ReportedDecryptionFailurePersistence
import kotlinx.coroutines.test.runTest
import org.amshove.kluent.shouldBeEqualTo
import org.junit.Test
import org.junit.runner.RunWith

/**
 * Instrumented tests for [ReportedDecryptionFailurePersistence], run against the target application context.
 */
@RunWith(AndroidJUnit4::class)
class ReportedDecryptionFailurePersistenceTest : InstrumentedTest {

    private val context = InstrumentationRegistry.getInstrumentation().targetContext

    /**
     * Events marked as reported must be seen as reported by the same instance,
     * and by a fresh instance once the state has been persisted and loaded again.
     */
    @Test
    fun shouldPersistReportedUtds() = runTest {
        val firstStore = ReportedDecryptionFailurePersistence(context)
        firstStore.load()

        val reportedIds = listOf("$0000", "$0001", "$0002", "$0003")
        for (reportedId in reportedIds) {
            firstStore.markAsReported(reportedId)
        }

        for (reportedId in reportedIds) {
            firstStore.hasBeenReported(reportedId) shouldBeEqualTo true
        }

        // An id that was never marked must not be seen as reported.
        firstStore.hasBeenReported("$0004") shouldBeEqualTo false

        firstStore.persist()

        // A brand new instance must recover the reported ids from disk.
        val reloadedStore = ReportedDecryptionFailurePersistence(context)
        reloadedStore.load()

        for (reportedId in reportedIds) {
            reloadedStore.hasBeenReported(reportedId) shouldBeEqualTo true
        }
    }

    /**
     * Marking more events than the bloom filter is sized for must make the store start over with a new filter
     * that still knows the most recent events but has forgotten the oldest ones.
     */
    @Test
    fun testSaturation() = runTest {
        val store = ReportedDecryptionFailurePersistence(context)

        repeat(6000) { index ->
            val i = index + 1
            store.markAsReported("000$i")
        }

        // The bloom filter should have been saturated by now, which would make the false positive rate too high.
        // A new filter is expected to have replaced it, still containing the recently reported events.
        for (i in 5800..6000) {
            store.hasBeenReported("000$i") shouldBeEqualTo true
        }

        // The oldest events are expected to be gone.
        for (i in 1..1000) {
            store.hasBeenReported("000$i") shouldBeEqualTo false
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ private const val MAX_WAIT_MILLIS = 60_000
class DecryptionFailureTracker @Inject constructor(
private val analyticsTracker: AnalyticsTracker,
private val sessionDataSource: ActiveSessionDataSource,
private val decryptionFailurePersistence: ReportedDecryptionFailurePersistence,
private val clock: Clock
) : Session.Listener, LiveEventListener {

Expand All @@ -76,9 +77,6 @@ class DecryptionFailureTracker @Inject constructor(
// Only accessed on a `post` call, ensuring sequential access
private val trackedEventsMap = mutableMapOf<String, DecryptionFailure>()

// List of eventId that have been reported, to avoid double reporting
private val alreadyReported = mutableListOf<String>()

// Mutex to ensure sequential access to internal state
private val mutex = Mutex()

Expand All @@ -98,10 +96,16 @@ class DecryptionFailureTracker @Inject constructor(
this.scope = scope
}
observeActiveSession()
post {
decryptionFailurePersistence.load()
}
}

fun stop() {
Timber.v("Stop DecryptionFailureTracker")
post {
decryptionFailurePersistence.persist()
}
activeSessionSourceDisposable.cancel(CancellationException("Closing DecryptionFailureTracker"))

activeSession?.removeListener(this)
Expand All @@ -123,6 +127,7 @@ class DecryptionFailureTracker @Inject constructor(
delay(CHECK_INTERVAL)
post {
checkFailures()
decryptionFailurePersistence.persist()
currentTicker = null
if (trackedEventsMap.isNotEmpty()) {
// Reschedule
Expand All @@ -136,15 +141,15 @@ class DecryptionFailureTracker @Inject constructor(
.distinctUntilChanged()
.onEach {
Timber.v("Active session changed ${it.getOrNull()?.myUserId}")
it.orNull()?.let { session ->
it.getOrNull()?.let { session ->
post {
onSessionActive(session)
}
}
}.launchIn(scope)
}

private fun onSessionActive(session: Session) {
private suspend fun onSessionActive(session: Session) {
Timber.v("onSessionActive ${session.myUserId} previous: ${activeSession?.myUserId}")
val sessionId = session.sessionId
if (sessionId == activeSession?.sessionId) {
Expand Down Expand Up @@ -201,7 +206,8 @@ class DecryptionFailureTracker @Inject constructor(
// already tracked
return
}
if (alreadyReported.contains(eventId)) {
if (decryptionFailurePersistence.hasBeenReported(eventId)) {
Timber.v("Event $eventId already reported")
// already reported
return
}
Expand Down Expand Up @@ -236,7 +242,7 @@ class DecryptionFailureTracker @Inject constructor(
}
}

private fun handleEventDecrypted(eventId: String) {
private suspend fun handleEventDecrypted(eventId: String) {
Timber.v("Handle event decrypted $eventId time: ${clock.epochMillis()}")
// Only consider if it was tracked as a failure
val trackedFailure = trackedEventsMap[eventId] ?: return
Expand Down Expand Up @@ -269,7 +275,7 @@ class DecryptionFailureTracker @Inject constructor(
}

// This will mutate the trackedEventsMap, so don't call it while iterating on it.
private fun reportFailure(decryptionFailure: DecryptionFailure) {
private suspend fun reportFailure(decryptionFailure: DecryptionFailure) {
Timber.v("Report failure for event ${decryptionFailure.failedEventId}")
val error = decryptionFailure.toAnalyticsEvent()

Expand All @@ -278,10 +284,10 @@ class DecryptionFailureTracker @Inject constructor(
// now remove from tracked
trackedEventsMap.remove(decryptionFailure.failedEventId)
// mark as already reported
alreadyReported.add(decryptionFailure.failedEventId)
decryptionFailurePersistence.markAsReported(decryptionFailure.failedEventId)
}

private fun checkFailures() {
private suspend fun checkFailures() {
val now = clock.epochMillis()
Timber.v("Check failures now $now")
// report the definitely failed
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
/*
* Copyright (c) 2024 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package im.vector.app.features.analytics

import android.content.Context
import android.util.LruCache
import com.google.common.hash.BloomFilter
import com.google.common.hash.Funnels
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.withContext
import timber.log.Timber
import java.io.File
import java.io.FileOutputStream
import javax.inject.Inject

// Name of the file, inside the application cache directory, in which the bloom filter is saved.
private const val REPORTED_UTD_FILE_NAME = "im.vector.analytics.reported_utd"
// Number of insertions the bloom filter is created for.
private const val EXPECTED_INSERTIONS = 5000

/**
 * This class is used to keep track of the reported decryption failures to avoid double reporting.
 * It uses a bloom filter to limit the memory/disk usage, backed by a small in-memory cache of the most
 * recently reported event ids.
 *
 * The bloom filter is saved to / restored from a file in the application cache directory by [persist] and [load].
 * Both operations are best-effort: failures are logged and never thrown to the caller.
 */
class ReportedDecryptionFailurePersistence @Inject constructor(
        private val context: Context,
) {

    // Keep a cache of recent reported failures in memory.
    // They will be moved to a new bloom filter if the previous one is getting saturated.
    // Should be around 30KB max in memory.
    // Also allows to have 0% false positive rate for recent failures.
    private val inMemoryReportedFailures: LruCache<String, Unit> = LruCache(IN_MEMORY_CACHE_SIZE)

    // Thread-safe and lock-free.
    // The expected insertions is 5000, and expected false positive probability of 3% when close to max capability.
    // The persisted size is expected to be around 5KB (100 times less than if it was raw strings).
    private var bloomFilter: BloomFilter<String> = newBloomFilter()

    /**
     * Mark an event as reported.
     * If the bloom filter is getting close to its expected capacity, it is replaced by a new one
     * containing only the recently reported events, and the new filter is persisted.
     * @param eventId the event id to mark as reported.
     */
    suspend fun markAsReported(eventId: String) {
        // Add to in memory cache.
        inMemoryReportedFailures.put(eventId, Unit)
        bloomFilter.put(eventId)

        // Check if the filter is getting saturated, and replace it if so.
        if (bloomFilter.approximateElementCount() > EXPECTED_INSERTIONS - SATURATION_MARGIN) {
            // The filter is getting saturated, and the false positive rate is increasing.
            // It's time to replace the filter with a new one, seeded with the in-memory cache.
            // The new filter is fully populated before being published, so that a lookup never sees it empty.
            bloomFilter = newBloomFilter().also { freshFilter ->
                inMemoryReportedFailures.snapshot().keys.forEach { recentEventId ->
                    freshFilter.put(recentEventId)
                }
            }
            persist()
        }
        Timber.v("## Bloom filter stats: expectedFpp: ${bloomFilter.expectedFpp()}, size: ${bloomFilter.approximateElementCount()}")
    }

    /**
     * Check if an event has been reported.
     * The in-memory cache is checked first, then the bloom filter, so a false positive is possible
     * for events that are not in the cache.
     * @param eventId the event id to check.
     * @return true if the event has been reported (or if the bloom filter reports a false positive).
     */
    fun hasBeenReported(eventId: String): Boolean {
        // First check in memory cache.
        if (inMemoryReportedFailures.get(eventId) != null) {
            return true
        }
        return bloomFilter.mightContain(eventId)
    }

    /**
     * Load the reported failures from disk.
     * If there is no saved file, or if it cannot be read, the current filter is kept and the error is logged.
     */
    suspend fun load() {
        withContext(Dispatchers.IO) {
            try {
                val file = storageFile()
                if (file.exists()) {
                    file.inputStream().use {
                        bloomFilter = BloomFilter.readFrom(it, Funnels.stringFunnel(Charsets.UTF_8))
                    }
                }
            } catch (e: Throwable) {
                Timber.e(e, "## Failed to load reported failures")
            }
        }
    }

    /**
     * Persist the reported failures to disk.
     * The filter is first written to a temporary file which is then renamed to the final file, so that an
     * interrupted write cannot leave a truncated file behind. Failures are logged and not thrown.
     */
    suspend fun persist() {
        withContext(Dispatchers.IO) {
            try {
                val file = storageFile()
                val tempFile = File(file.parentFile, file.name + TEMP_FILE_SUFFIX)
                // FileOutputStream creates the file if needed, and truncates any left over temporary file.
                FileOutputStream(tempFile).buffered().use {
                    bloomFilter.writeTo(it)
                }
                if (tempFile.renameTo(file)) {
                    Timber.v("## Successfully saved reported failures, size: ${file.length()}")
                } else {
                    // Keep the previously saved file untouched and drop the temporary one.
                    tempFile.delete()
                    Timber.e("## Failed to save reported failures")
                }
            } catch (e: Throwable) {
                Timber.e(e, "## Failed to save reported failures")
            }
        }
    }

    // Creates an empty bloom filter sized for EXPECTED_INSERTIONS entries.
    private fun newBloomFilter(): BloomFilter<String> {
        return BloomFilter.create<String>(Funnels.stringFunnel(Charsets.UTF_8), EXPECTED_INSERTIONS)
    }

    // File, in the application cache directory, where the bloom filter is saved.
    private fun storageFile(): File {
        return File(context.applicationContext.cacheDir, REPORTED_UTD_FILE_NAME)
    }

    private companion object {
        // Maximum number of recently reported event ids kept in memory.
        const val IN_MEMORY_CACHE_SIZE = 300

        // The filter is replaced when its approximate element count gets within this margin of EXPECTED_INSERTIONS.
        const val SATURATION_MARGIN = 500

        // Suffix of the temporary file used while saving the filter.
        const val TEMP_FILE_SUFFIX = ".tmp"
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import im.vector.app.test.fakes.FakeAnalyticsTracker
import im.vector.app.test.fakes.FakeClock
import im.vector.app.test.fakes.FakeSession
import im.vector.app.test.shared.createTimberTestRule
import io.mockk.coEvery
import io.mockk.every
import io.mockk.just
import io.mockk.mockk
Expand Down Expand Up @@ -60,9 +61,24 @@ class DecryptionFailureTrackerTest {

private val fakeClock = FakeClock()

val reportedEvents = mutableSetOf<String>()

private val fakePersistence = mockk<ReportedDecryptionFailurePersistence> {

coEvery { load() } just runs
coEvery { persist() } just runs
coEvery { markAsReported(any()) } coAnswers {
reportedEvents.add(firstArg())
}
every { hasBeenReported(any()) } answers {
reportedEvents.contains(firstArg())
}
}

private val decryptionFailureTracker = DecryptionFailureTracker(
fakeAnalyticsTracker,
fakeActiveSessionDataSource.instance,
fakePersistence,
fakeClock
)

Expand Down Expand Up @@ -101,6 +117,7 @@ class DecryptionFailureTrackerTest {

@Before
fun setupTest() {
reportedEvents.clear()
fakeMxOrgTestSession.fakeCryptoService.fakeCrossSigningService.givenIsCrossSigningVerifiedReturns(false)
}

Expand Down
Loading