Skip to content

Commit

Permalink
[srp-server] process completed update from proxy from taskelt (#9398)
Browse files Browse the repository at this point in the history
This commit enhances `Srp::Server` to process and commit the completed
`UpdateMetadata` entries (signaled by the "proxy service handler"
calling `HandleServiceUpdateResult()`) from a `Tasklet`. This change
is helpful in the case where the `HandleServiceUpdateResult
()` callback is invoked directly from the "update service handler"
itself. While `Srp::Server` can handle this situation, the change
makes it easier for platform implementations of advertising proxy.

In particular, it addresses an issue with the `otbr` advertising proxy
implementation. This implementation can potentially access an already
freed `Host` object. This can happen because the implementation may
hold on to the `Host` object while iterating over its `Service`
entries as advertising an earlier `Service` of the same `Host` may
fail immediately and invoke the callback directly. This would then
cause the `Host` to be freed by `Srp::Server`.
  • Loading branch information
abtink committed Sep 6, 2023
1 parent 5c051ff commit d3608df
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 23 deletions.
20 changes: 20 additions & 0 deletions src/core/common/linked_list.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,26 @@ template <typename Type> class LinkedList
aPrevEntry.SetNext(&aEntry);
}

/**
* Pushes an entry after the tail in the linked list.
*
* @param[in] aEntry A reference to an entry to push into the list.
*
*/
void PushAfterTail(Type &aEntry)
{
Type *tail = GetTail();

if (tail == nullptr)
{
Push(aEntry);
}
else
{
PushAfter(aEntry, *tail);
}
}

/**
* Pops an entry from head of the linked list.
*
Expand Down
67 changes: 48 additions & 19 deletions src/core/net/srp_server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ Server::Server(Instance &aInstance)
, mSocket(aInstance)
, mLeaseTimer(aInstance)
, mOutstandingUpdatesTimer(aInstance)
, mCompletedUpdateTask(aInstance)
, mServiceUpdateId(Random::NonCrypto::GetUint32())
, mPort(kUdpPortMin)
, mState(kStateDisabled)
Expand Down Expand Up @@ -381,26 +382,26 @@ bool Server::HasNameConflictsWith(Host &aHost) const

void Server::HandleServiceUpdateResult(ServiceUpdateId aId, Error aError)
{
UpdateMetadata *update = mOutstandingUpdates.FindMatching(aId);
UpdateMetadata *update = mOutstandingUpdates.RemoveMatching(aId);

if (update != nullptr)
{
HandleServiceUpdateResult(update, aError);
}
else
if (update == nullptr)
{
LogInfo("Delayed SRP host update result, the SRP update has been committed (updateId = %lu)", ToUlong(aId));
ExitNow();
}
}

void Server::HandleServiceUpdateResult(UpdateMetadata *aUpdate, Error aError)
{
LogInfo("Handler result of SRP update (id = %lu) is received: %s", ToUlong(aUpdate->GetId()),
ErrorToString(aError));
update->SetError(aError);

LogInfo("Handler result of SRP update (id = %lu) is received: %s", ToUlong(update->GetId()), ErrorToString(aError));

IgnoreError(mOutstandingUpdates.Remove(*aUpdate));
CommitSrpUpdate(aError, *aUpdate);
aUpdate->Free();
// We add new `update` at the tail of the `mCompletedUpdates` list
// so that updates are processed in the order we receive the
// `HandleServiceUpdateResult()` callbacks for them. The
// completed updates are processed from `mCompletedUpdateTask`
// and `ProcessCompletedUpdates()`.

mCompletedUpdates.PushAfterTail(*update);
mCompletedUpdateTask.Post();

if (mOutstandingUpdates.IsEmpty())
{
Expand All @@ -410,6 +411,19 @@ void Server::HandleServiceUpdateResult(UpdateMetadata *aUpdate, Error aError)
{
mOutstandingUpdatesTimer.FireAt(mOutstandingUpdates.GetTail()->GetExpireTime());
}

exit:
return;
}

void Server::ProcessCompletedUpdates(void)
{
UpdateMetadata *update;

while ((update = mCompletedUpdates.Pop()) != nullptr)
{
CommitSrpUpdate(*update);
}
}

void Server::CommitSrpUpdate(Error aError, Host &aHost, const MessageMetadata &aMessageMetadata)
Expand All @@ -418,11 +432,13 @@ void Server::CommitSrpUpdate(Error aError, Host &aHost, const MessageMetadata &a
aMessageMetadata.mTtlConfig, aMessageMetadata.mLeaseConfig);
}

void Server::CommitSrpUpdate(Error aError, UpdateMetadata &aUpdateMetadata)
void Server::CommitSrpUpdate(UpdateMetadata &aUpdateMetadata)
{
CommitSrpUpdate(aError, aUpdateMetadata.GetHost(), aUpdateMetadata.GetDnsHeader(),
CommitSrpUpdate(aUpdateMetadata.GetError(), aUpdateMetadata.GetHost(), aUpdateMetadata.GetDnsHeader(),
aUpdateMetadata.IsDirectRxFromClient() ? &aUpdateMetadata.GetMessageInfo() : nullptr,
aUpdateMetadata.GetTtlConfig(), aUpdateMetadata.GetLeaseConfig());

aUpdateMetadata.Free();
}

void Server::CommitSrpUpdate(Error aError,
Expand Down Expand Up @@ -1663,10 +1679,22 @@ void Server::HandleLeaseTimer(void)

void Server::HandleOutstandingUpdatesTimer(void)
{
while (!mOutstandingUpdates.IsEmpty() && mOutstandingUpdates.GetTail()->GetExpireTime() <= TimerMilli::GetNow())
TimeMilli now = TimerMilli::GetNow();
UpdateMetadata *update;

while ((update = mOutstandingUpdates.GetTail()) != nullptr)
{
LogInfo("Outstanding service update timeout (updateId = %lu)", ToUlong(mOutstandingUpdates.GetTail()->GetId()));
HandleServiceUpdateResult(mOutstandingUpdates.GetTail(), kErrorResponseTimeout);
if (update->GetExpireTime() > now)
{
mOutstandingUpdatesTimer.FireAtIfEarlier(update->GetExpireTime());
break;
}

LogInfo("Outstanding service update timeout (updateId = %lu)", ToUlong(update->GetId()));

IgnoreError(mOutstandingUpdates.Remove(*update));
update->SetError(kErrorResponseTimeout);
CommitSrpUpdate(*update);
}
}

Expand Down Expand Up @@ -2097,6 +2125,7 @@ Server::UpdateMetadata::UpdateMetadata(Instance &aInstance, Host &aHost, const M
, mTtlConfig(aMessageMetadata.mTtlConfig)
, mLeaseConfig(aMessageMetadata.mLeaseConfig)
, mHost(aHost)
, mError(kErrorNone)
, mIsDirectRxFromClient(aMessageMetadata.IsDirectRxFromClient())
{
if (aMessageMetadata.mMessageInfo != nullptr)
Expand Down
14 changes: 10 additions & 4 deletions src/core/net/srp_server.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -912,6 +912,8 @@ class Server : public InstanceLocator, private NonCopyable
const LeaseConfig &GetLeaseConfig(void) const { return mLeaseConfig; }
Host &GetHost(void) { return mHost; }
const Ip6::MessageInfo &GetMessageInfo(void) const { return mMessageInfo; }
Error GetError(void) const { return mError; }
void SetError(Error aError) { mError = aError; }
bool IsDirectRxFromClient(void) const { return mIsDirectRxFromClient; }
bool Matches(ServiceUpdateId aId) const { return mId == aId; }

Expand All @@ -926,6 +928,7 @@ class Server : public InstanceLocator, private NonCopyable
LeaseConfig mLeaseConfig; // Lease config to use when processing the message.
Host &mHost; // The `UpdateMetadata` has no ownership of this host.
Ip6::MessageInfo mMessageInfo; // Valid when `mIsDirectRxFromClient` is true.
Error mError;
bool mIsDirectRxFromClient;
};

Expand All @@ -948,7 +951,7 @@ class Server : public InstanceLocator, private NonCopyable

void InformUpdateHandlerOrCommit(Error aError, Host &aHost, const MessageMetadata &aMetadata);
void CommitSrpUpdate(Error aError, Host &aHost, const MessageMetadata &aMessageMetadata);
void CommitSrpUpdate(Error aError, UpdateMetadata &aUpdateMetadata);
void CommitSrpUpdate(UpdateMetadata &aUpdateMetadata);
void CommitSrpUpdate(Error aError,
Host &aHost,
const Dns::UpdateHeader &aDnsHeader,
Expand Down Expand Up @@ -998,15 +1001,16 @@ class Server : public InstanceLocator, private NonCopyable
void HandleLeaseTimer(void);
static void HandleOutstandingUpdatesTimer(Timer &aTimer);
void HandleOutstandingUpdatesTimer(void);
void ProcessCompletedUpdates(void);

void HandleServiceUpdateResult(UpdateMetadata *aUpdate, Error aError);
const UpdateMetadata *FindOutstandingUpdate(const MessageMetadata &aMessageMetadata) const;
static const char *AddressModeToString(AddressMode aMode);

void UpdateResponseCounters(Dns::Header::Response aResponseCode);

using LeaseTimer = TimerMilliIn<Server, &Server::HandleLeaseTimer>;
using UpdateTimer = TimerMilliIn<Server, &Server::HandleOutstandingUpdatesTimer>;
using LeaseTimer = TimerMilliIn<Server, &Server::HandleLeaseTimer>;
using UpdateTimer = TimerMilliIn<Server, &Server::HandleOutstandingUpdatesTimer>;
using CompletedUpdatesTask = TaskletIn<Server, &Server::ProcessCompletedUpdates>;

Ip6::Udp::Socket mSocket;

Expand All @@ -1022,6 +1026,8 @@ class Server : public InstanceLocator, private NonCopyable

UpdateTimer mOutstandingUpdatesTimer;
LinkedList<UpdateMetadata> mOutstandingUpdates;
LinkedList<UpdateMetadata> mCompletedUpdates;
CompletedUpdatesTask mCompletedUpdateTask;

ServiceUpdateId mServiceUpdateId;
uint16_t mPort;
Expand Down
10 changes: 10 additions & 0 deletions tests/unit/test_linked_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,16 @@ void TestLinkedList(void)
list.RemoveAllMatching(kBetaType, removedList);
VerifyLinkedListContent(&list, &a, &b, &e, nullptr);
VerifyLinkedListContent(&removedList, &f, &d, &c, nullptr);

list.Clear();
list.PushAfterTail(a);
VerifyLinkedListContent(&list, &a, nullptr);
list.PushAfterTail(b);
VerifyLinkedListContent(&list, &a, &b, nullptr);
list.PushAfterTail(c);
VerifyLinkedListContent(&list, &a, &b, &c, nullptr);
list.PushAfterTail(d);
VerifyLinkedListContent(&list, &a, &b, &c, &d, nullptr);
}

void TestOwningList(void)
Expand Down

0 comments on commit d3608df

Please sign in to comment.