Skip to content

Commit

Permalink
util: fix potential cross-tenants queries in InternalCallContextFactory
Browse files Browse the repository at this point in the history
Restrict usage of NonEntityDao. Because it doesn't validate the context,
this DAO should only be used internally. InternalCallContextFactory now
features safe wrappers to do the tenant check in code.

This also greatly simplifies the code as the CacheControllerDispatcher
doesn't need to be injected all over the place.

Signed-off-by: Pierre-Alexandre Meyer <pierre@mouraf.org>
  • Loading branch information
pierre committed Jan 22, 2015
1 parent 9613649 commit baa7c20
Show file tree
Hide file tree
Showing 39 changed files with 514 additions and 661 deletions.
@@ -1,7 +1,9 @@
/*
* Copyright 2010-2013 Ning, Inc.
* Copyright 2014-2015 Groupon, Inc
* Copyright 2014-2015 The Billing Project, LLC
*
* Ning licenses this file to you under the Apache License, version 2.0
* The Billing Project licenses this file to you under the Apache License, version 2.0
* (the "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at:
*
Expand All @@ -22,7 +24,6 @@
import javax.inject.Named;

import org.killbill.billing.ObjectType;
import org.killbill.billing.account.api.AccountInternalApi;
import org.killbill.billing.callcontext.InternalCallContext;
import org.killbill.billing.entitlement.EntitlementTransitionType;
import org.killbill.billing.events.AccountChangeInternalEvent;
Expand All @@ -46,12 +47,10 @@
import org.killbill.billing.lifecycle.glue.BusModule;
import org.killbill.billing.notification.plugin.api.ExtBusEventType;
import org.killbill.billing.subscription.api.SubscriptionBaseTransitionType;
import org.killbill.billing.util.cache.Cachable.CacheType;
import org.killbill.billing.util.cache.CacheControllerDispatcher;
import org.killbill.billing.util.callcontext.CallOrigin;
import org.killbill.billing.util.callcontext.InternalCallContextFactory;
import org.killbill.billing.util.callcontext.TenantContext;
import org.killbill.billing.util.callcontext.UserType;
import org.killbill.billing.util.dao.NonEntityDao;
import org.killbill.bus.api.BusEvent;
import org.killbill.bus.api.PersistentBus;
import org.killbill.bus.api.PersistentBus.EventBusException;
Expand All @@ -70,23 +69,14 @@ public class BeatrixListener {

private final PersistentBus externalBus;
private final InternalCallContextFactory internalCallContextFactory;
private final AccountInternalApi accountApi;
private final NonEntityDao nonEntityDao;
private final CacheControllerDispatcher cacheControllerDispatcher;

protected final ObjectMapper objectMapper;

@Inject
public BeatrixListener(@Named(BusModule.EXTERNAL_BUS_NAMED) final PersistentBus externalBus,
final InternalCallContextFactory internalCallContextFactory,
final AccountInternalApi accountApi,
final CacheControllerDispatcher cacheControllerDispatcher,
final NonEntityDao nonEntityDao) {
final InternalCallContextFactory internalCallContextFactory) {
this.externalBus = externalBus;
this.internalCallContextFactory = internalCallContextFactory;
this.accountApi = accountApi;
this.nonEntityDao = nonEntityDao;
this.cacheControllerDispatcher = cacheControllerDispatcher;
this.objectMapper = new ObjectMapper();
objectMapper.registerModule(new JodaModule());
objectMapper.disable(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS);
Expand Down Expand Up @@ -243,19 +233,20 @@ private BusEvent computeExtBusEventEntryFromBusInternalEvent(final BusInternalEv

default:
}
final UUID accountId = getAccountIdFromRecordId(event.getBusEventType(), objectId, context.getAccountRecordId());
final UUID tenantId = nonEntityDao.retrieveIdFromObject(context.getTenantRecordId(), ObjectType.TENANT, cacheControllerDispatcher.getCacheController(CacheType.OBJECT_ID));

final TenantContext tenantContext = internalCallContextFactory.createTenantContext(context);
final UUID accountId = getAccountId(event.getBusEventType(), objectId, objectType, tenantContext);

return eventBusType != null ?
new DefaultBusExternalEvent(objectId, objectType, eventBusType, accountId, tenantId, context.getAccountRecordId(), context.getTenantRecordId(), context.getUserToken()) :
new DefaultBusExternalEvent(objectId, objectType, eventBusType, accountId, tenantContext.getTenantId(), context.getAccountRecordId(), context.getTenantRecordId(), context.getUserToken()) :
null;
}

private UUID getAccountIdFromRecordId(final BusInternalEventType eventType, final UUID objectId, final Long recordId) {
private UUID getAccountId(final BusInternalEventType eventType, final UUID objectId, final ObjectType objectType, final TenantContext context) {
// accountRecord_id is not set for ACCOUNT_CREATE event as we are in the transaction and value is known yet
if (eventType == BusInternalEventType.ACCOUNT_CREATE) {
return objectId;
}
return nonEntityDao.retrieveIdFromObject(recordId, ObjectType.ACCOUNT, cacheControllerDispatcher.getCacheController(CacheType.OBJECT_ID));
return internalCallContextFactory.getAccountId(objectId, objectType, context);
}
}
@@ -1,7 +1,7 @@
/*
* Copyright 2010-2013 Ning, Inc.
* Copyright 2014 Groupon, Inc
* Copyright 2014 The Billing Project, LLC
* Copyright 2014-2015 Groupon, Inc
* Copyright 2014-2015 The Billing Project, LLC
*
* The Billing Project licenses this file to you under the Apache License, version 2.0
* (the "License"); you may not use this file except in compliance with the
Expand All @@ -21,7 +21,6 @@
import java.util.UUID;

import org.joda.time.DateTime;
import org.killbill.billing.ObjectType;
import org.killbill.billing.callcontext.InternalCallContext;
import org.killbill.billing.entitlement.api.DefaultBlockingTransitionInternalEvent;
import org.killbill.billing.entitlement.api.DefaultEntitlement;
Expand All @@ -34,12 +33,10 @@
import org.killbill.billing.entitlement.engine.core.EntitlementNotificationKeyAction;
import org.killbill.billing.platform.api.LifecycleHandlerType;
import org.killbill.billing.platform.api.LifecycleHandlerType.LifecycleLevel;
import org.killbill.billing.util.cache.Cachable.CacheType;
import org.killbill.billing.util.cache.CacheControllerDispatcher;
import org.killbill.billing.util.callcontext.CallContext;
import org.killbill.billing.util.callcontext.CallOrigin;
import org.killbill.billing.util.callcontext.InternalCallContextFactory;
import org.killbill.billing.util.callcontext.UserType;
import org.killbill.billing.util.dao.NonEntityDao;
import org.killbill.bus.api.BusEvent;
import org.killbill.bus.api.PersistentBus;
import org.killbill.bus.api.PersistentBus.EventBusException;
Expand All @@ -62,29 +59,23 @@ public class DefaultEntitlementService implements EntitlementService {

private final EntitlementApi entitlementApi;
private final BlockingStateDao blockingStateDao;
private final NonEntityDao nonEntityDao;
private final PersistentBus eventBus;
private final NotificationQueueService notificationQueueService;
private final InternalCallContextFactory internalCallContextFactory;
private final CacheControllerDispatcher controllerDispatcher;

private NotificationQueue entitlementEventQueue;

@Inject
public DefaultEntitlementService(final EntitlementApi entitlementApi,
final BlockingStateDao blockingStateDao,
final NonEntityDao nonEntityDao,
final PersistentBus eventBus,
final NotificationQueueService notificationQueueService,
final InternalCallContextFactory internalCallContextFactory,
final CacheControllerDispatcher controllerDispatcher) {
final InternalCallContextFactory internalCallContextFactory) {
this.entitlementApi = entitlementApi;
this.blockingStateDao = blockingStateDao;
this.nonEntityDao = nonEntityDao;
this.eventBus = eventBus;
this.notificationQueueService = notificationQueueService;
this.internalCallContextFactory = internalCallContextFactory;
this.controllerDispatcher = controllerDispatcher;
}

@Override
Expand All @@ -101,8 +92,8 @@ public void handleReadyNotification(final NotificationEvent inputKey, final Date
final InternalCallContext internalCallContext = internalCallContextFactory.createInternalCallContext(tenantRecordId, accountRecordId, "EntitlementQueue", CallOrigin.INTERNAL, UserType.SYSTEM, fromNotificationQueueUserToken);

if (inputKey instanceof EntitlementNotificationKey) {
final UUID tenantId = nonEntityDao.retrieveIdFromObject(tenantRecordId, ObjectType.TENANT, controllerDispatcher.getCacheController(CacheType.OBJECT_ID));
processEntitlementNotification((EntitlementNotificationKey) inputKey, tenantId, internalCallContext);
final CallContext callContext = internalCallContextFactory.createCallContext(internalCallContext);
processEntitlementNotification((EntitlementNotificationKey) inputKey, internalCallContext, callContext);
} else if (inputKey instanceof BlockingTransitionNotificationKey) {
processBlockingNotification((BlockingTransitionNotificationKey) inputKey, internalCallContext);
} else if (inputKey != null) {
Expand All @@ -121,10 +112,10 @@ public void handleReadyNotification(final NotificationEvent inputKey, final Date
}
}

private void processEntitlementNotification(final EntitlementNotificationKey key, final UUID tenantId, final InternalCallContext internalCallContext) {
private void processEntitlementNotification(final EntitlementNotificationKey key, final InternalCallContext internalCallContext, final CallContext callContext) {
final Entitlement entitlement;
try {
entitlement = entitlementApi.getEntitlementForId(key.getEntitlementId(), internalCallContext.toTenantContext(tenantId));
entitlement = entitlementApi.getEntitlementForId(key.getEntitlementId(), callContext);
} catch (final EntitlementApiException e) {
log.error("Error retrieving entitlement for id " + key.getEntitlementId(), e);
return;
Expand All @@ -139,11 +130,13 @@ private void processEntitlementNotification(final EntitlementNotificationKey key
try {
if (EntitlementNotificationKeyAction.CHANGE.equals(entitlementNotificationKeyAction) ||
EntitlementNotificationKeyAction.CANCEL.equals(entitlementNotificationKeyAction)) {
((DefaultEntitlement) entitlement).blockAddOnsIfRequired(key.getEffectiveDate(), internalCallContext.toTenantContext(tenantId), internalCallContext);
} else if (EntitlementNotificationKeyAction.PAUSE.equals(entitlementNotificationKeyAction)) {
entitlementApi.pause(key.getBundleId(), key.getEffectiveDate().toLocalDate(), internalCallContext.toCallContext(tenantId));
} else if (EntitlementNotificationKeyAction.RESUME.equals(entitlementNotificationKeyAction)) {
entitlementApi.resume(key.getBundleId(), key.getEffectiveDate().toLocalDate(), internalCallContext.toCallContext(tenantId));
((DefaultEntitlement) entitlement).blockAddOnsIfRequired(key.getEffectiveDate(), callContext, internalCallContext);
} else {
if (EntitlementNotificationKeyAction.PAUSE.equals(entitlementNotificationKeyAction)) {
entitlementApi.pause(key.getBundleId(), key.getEffectiveDate().toLocalDate(), callContext);
} else if (EntitlementNotificationKeyAction.RESUME.equals(entitlementNotificationKeyAction)) {
entitlementApi.resume(key.getBundleId(), key.getEffectiveDate().toLocalDate(), callContext);
}

This comment has been minimized.

Copy link
@sbrossie

sbrossie Jan 22, 2015

Member

Slightly confused with the reordering if the if-then-else but this is probably good.

This comment has been minimized.

Copy link
@pierre

pierre Jan 23, 2015

Author Member

Uh, weird indeed. I hadn't realized it (maybe some IDEA correction?). I'll revert to the previous else if / else if.

}
} catch (final EntitlementApiException e) {
log.error("Error processing event for entitlement {}" + entitlement.getId(), e);
Expand All @@ -164,7 +157,7 @@ private void processBlockingNotification(final BlockingTransitionNotificationKey

try {
eventBus.post(event);
} catch (EventBusException e) {
} catch (final EventBusException e) {
log.warn("Failed to post event {}", e);
}
}
Expand Down
@@ -1,7 +1,9 @@
/*
* Copyright 2010-2013 Ning, Inc.
* Copyright 2014-2015 Groupon, Inc
* Copyright 2014-2015 The Billing Project, LLC
*
* Ning licenses this file to you under the Apache License, version 2.0
* The Billing Project licenses this file to you under the Apache License, version 2.0
* (the "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at:
*
Expand All @@ -27,9 +29,6 @@
import javax.inject.Inject;

import org.joda.time.DateTimeZone;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.killbill.billing.ErrorCode;
import org.killbill.billing.ObjectType;
import org.killbill.billing.callcontext.InternalCallContext;
Expand All @@ -41,15 +40,14 @@
import org.killbill.billing.subscription.api.SubscriptionBaseInternalApi;
import org.killbill.billing.subscription.api.user.SubscriptionBaseApiException;
import org.killbill.billing.subscription.api.user.SubscriptionBaseBundle;
import org.killbill.billing.util.cache.Cachable.CacheType;
import org.killbill.billing.util.cache.CacheControllerDispatcher;
import org.killbill.billing.util.callcontext.CallContext;
import org.killbill.billing.util.callcontext.InternalCallContextFactory;
import org.killbill.billing.util.callcontext.TenantContext;
import org.killbill.billing.util.customfield.ShouldntHappenException;
import org.killbill.billing.util.dao.NonEntityDao;
import org.killbill.billing.util.entity.Pagination;
import org.killbill.billing.util.entity.dao.DefaultPaginationHelper.SourcePaginationBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import com.google.common.base.Function;
import com.google.common.base.Optional;
Expand Down Expand Up @@ -85,31 +83,26 @@ public int compare(final SubscriptionBundle o1, final SubscriptionBundle o2) {
private final EntitlementInternalApi entitlementInternalApi;
private final SubscriptionBaseInternalApi subscriptionBaseInternalApi;
private final InternalCallContextFactory internalCallContextFactory;
private final NonEntityDao nonEntityDao;
private final CacheControllerDispatcher cacheControllerDispatcher;
private final EntitlementUtils entitlementUtils;

@Inject
public DefaultSubscriptionApi(final EntitlementInternalApi entitlementInternalApi, final SubscriptionBaseInternalApi subscriptionInternalApi,
final InternalCallContextFactory internalCallContextFactory, final EntitlementUtils entitlementUtils, final NonEntityDao nonEntityDao, final CacheControllerDispatcher cacheControllerDispatcher) {
final InternalCallContextFactory internalCallContextFactory, final EntitlementUtils entitlementUtils) {
this.entitlementInternalApi = entitlementInternalApi;
this.subscriptionBaseInternalApi = subscriptionInternalApi;
this.internalCallContextFactory = internalCallContextFactory;
this.nonEntityDao = nonEntityDao;
this.entitlementUtils = entitlementUtils;
this.cacheControllerDispatcher = cacheControllerDispatcher;
}

@Override
public Subscription getSubscriptionForEntitlementId(final UUID entitlementId, final TenantContext context) throws SubscriptionApiException {
final Long accountRecordId = nonEntityDao.retrieveAccountRecordIdFromObject(entitlementId, ObjectType.SUBSCRIPTION, cacheControllerDispatcher.getCacheController(CacheType.ACCOUNT_RECORD_ID));
final UUID accountId = nonEntityDao.retrieveIdFromObject(accountRecordId, ObjectType.ACCOUNT, cacheControllerDispatcher.getCacheController(CacheType.OBJECT_ID));
final UUID accountId = internalCallContextFactory.getAccountId(entitlementId, ObjectType.SUBSCRIPTION, context);

// Retrieve entitlements
final AccountEntitlements accountEntitlements;
try {
accountEntitlements = entitlementInternalApi.getAllEntitlementsForAccountId(accountId, context);
} catch (EntitlementApiException e) {
} catch (final EntitlementApiException e) {
throw new SubscriptionApiException(e);
}

Expand All @@ -127,8 +120,7 @@ public boolean apply(final Subscription subscription) {

@Override
public SubscriptionBundle getSubscriptionBundle(final UUID bundleId, final TenantContext context) throws SubscriptionApiException {
final Long accountRecordId = nonEntityDao.retrieveAccountRecordIdFromObject(bundleId, ObjectType.BUNDLE, cacheControllerDispatcher.getCacheController(CacheType.ACCOUNT_RECORD_ID));
final UUID accountId = nonEntityDao.retrieveIdFromObject(accountRecordId, ObjectType.ACCOUNT, cacheControllerDispatcher.getCacheController(CacheType.OBJECT_ID));
final UUID accountId = internalCallContextFactory.getAccountId(bundleId, ObjectType.BUNDLE, context);

final Optional<SubscriptionBundle> bundleOptional = Iterables.<SubscriptionBundle>tryFind(getSubscriptionBundlesForAccount(accountId, context),
new Predicate<SubscriptionBundle>() {
Expand Down
@@ -1,7 +1,9 @@
/*
* Copyright 2010-2013 Ning, Inc.
* Copyright 2014-2015 Groupon, Inc
* Copyright 2014-2015 The Billing Project, LLC
*
* Ning licenses this file to you under the Apache License, version 2.0
* The Billing Project licenses this file to you under the Apache License, version 2.0
* (the "License"); you may not use this file except in compliance with the
* License. You may obtain a copy of the License at:
*
Expand All @@ -19,19 +21,16 @@
import java.util.List;
import java.util.UUID;

import org.testng.Assert;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import org.killbill.billing.account.api.Account;
import org.killbill.billing.api.TestApiListener.NextEvent;
import org.killbill.billing.callcontext.InternalCallContext;
import org.killbill.billing.entitlement.EntitlementTestSuiteWithEmbeddedDB;
import org.killbill.billing.entitlement.api.BlockingState;
import org.killbill.billing.entitlement.api.BlockingStateType;
import org.killbill.billing.junction.DefaultBlockingState;
import org.killbill.billing.util.callcontext.CallOrigin;
import org.killbill.billing.util.callcontext.UserType;
import org.testng.Assert;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;

import com.google.common.base.Predicate;
import com.google.common.collect.Collections2;
Expand Down Expand Up @@ -82,7 +81,7 @@ public void testApiHistory() throws Exception {
final boolean blockBilling = false;

final Account account = accountApi.createAccount(getAccountData(7), callContext);
final InternalCallContext internalCallContext = internalCallContextFactory.createInternalCallContext(account.getId(), "TestBlockingApi", CallOrigin.TEST, UserType.SYSTEM, UUID.randomUUID());
final InternalCallContext internalCallContext = internalCallContextFactory.createInternalCallContext(account.getId(), callContext);

testListener.pushExpectedEvent(NextEvent.BLOCK);
final BlockingState state1 = new DefaultBlockingState(uuid, BlockingStateType.ACCOUNT, overdueStateName, service, blockChange, blockEntitlement, blockBilling, clock.getUTCNow());
Expand Down

0 comments on commit baa7c20

Please sign in to comment.