Skip to content

Commit

Permalink
Use safer client socket IDs
Browse files Browse the repository at this point in the history
  • Loading branch information
hluk committed May 14, 2018
1 parent 8bb30cc commit 6eb9625
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 29 deletions.
36 changes: 20 additions & 16 deletions src/app/clipboardserver.cpp
Expand Up @@ -157,10 +157,13 @@ void ClipboardServer::stopMonitoring()
COPYQ_LOG("Terminating monitor");

for (auto it = m_clients.constBegin(); it != m_clients.constEnd(); ++it) {
const auto actionId = it.value().proxy->actionId();
const auto &clientData = it.value();
if (!clientData.isValid())
continue;

const auto actionId = clientData.proxy->actionId();
if ( actionId == m_monitor->id() ) {
const auto client = it.key();
client->sendMessage(QByteArray(), CommandStop);
clientData.client->sendMessage(QByteArray(), CommandStop);
break;
}
}
Expand Down Expand Up @@ -321,8 +324,9 @@ bool ClipboardServer::hasRunningCommands() const
void ClipboardServer::terminateClients(int waitMs)
{
for (auto it = m_clients.constBegin(); it != m_clients.constEnd(); ++it) {
const auto client = it.key();
client->sendMessage(QByteArray(), CommandStop);
const auto &clientData = it.value();
if (clientData.isValid())
clientData.client->sendMessage(QByteArray(), CommandStop);
}

waitForClientsToFinish(waitMs);
Expand All @@ -339,7 +343,7 @@ void ClipboardServer::waitForClientsToFinish(int waitMs)
void ClipboardServer::onClientNewConnection(const ClientSocketPtr &client)
{
auto proxy = new ScriptableProxy(m_wnd, client.get());
m_clients.insert( client.get(), ClientData(client, proxy) );
m_clients.insert( client->id(), ClientData(client, proxy) );
connect( this, &ClipboardServer::closeClients,
client.get(), &ClientSocket::close );
connect( client.get(), &ClientSocket::messageReceived,
Expand All @@ -352,16 +356,16 @@ void ClipboardServer::onClientNewConnection(const ClientSocketPtr &client)
}

void ClipboardServer::onClientMessageReceived(
const QByteArray &message, int messageCode, ClientSocket *client)
const QByteArray &message, int messageCode, ClientSocketId clientId)
{
Q_UNUSED(client);
switch (messageCode) {
case CommandFunctionCall: {
auto proxy = m_clients.value(client).proxy;
if (!proxy)
const auto &clientData = m_clients.value(clientId);
if (!clientData.isValid())
return;
const auto result = proxy->callFunction(message);
client->sendMessage(result, CommandFunctionCallReturnValue);

const auto result = clientData.proxy->callFunction(message);
clientData.client->sendMessage(result, CommandFunctionCallReturnValue);
break;
}
default:
Expand All @@ -370,15 +374,15 @@ void ClipboardServer::onClientMessageReceived(
}
}

void ClipboardServer::onClientDisconnected(ClientSocket *client)
void ClipboardServer::onClientDisconnected(ClientSocketId clientId)
{
m_clients.remove(client);
m_clients.remove(clientId);
}

void ClipboardServer::onClientConnectionFailed(ClientSocket *client)
void ClipboardServer::onClientConnectionFailed(ClientSocketId clientId)
{
log("Client connection failed", LogWarning);
m_clients.remove(client);
m_clients.remove(clientId);
}

void ClipboardServer::onMonitorFinished()
Expand Down
15 changes: 11 additions & 4 deletions src/app/clipboardserver.h
Expand Up @@ -23,6 +23,7 @@
#include "app.h"
#include "common/clipboardmode.h"
#include "common/server.h"
#include "common/clientsocket.h"

#include <QMap>
#include <QPointer>
Expand Down Expand Up @@ -78,9 +79,9 @@ public slots:

private slots:
void onClientNewConnection(const ClientSocketPtr &client);
void onClientMessageReceived(const QByteArray &message, int messageCode, ClientSocket *client);
void onClientDisconnected(ClientSocket *client);
void onClientConnectionFailed(ClientSocket *client);
void onClientMessageReceived(const QByteArray &message, int messageCode, ClientSocketId clientId);
void onClientDisconnected(ClientSocketId clientId);
void onClientConnectionFailed(ClientSocketId clientId);

/** An error occurred on monitor connection. */
void onMonitorFinished();
Expand Down Expand Up @@ -138,10 +139,16 @@ private slots:
, proxy(proxy)
{
}

bool isValid() const
{
return client && proxy;
}

ClientSocketPtr client;
ScriptableProxy *proxy = nullptr;
};
QMap<ClientSocket*, ClientData> m_clients;
QMap<ClientSocketId, ClientData> m_clients;
};

#endif // CLIPBOARDSERVER_H
8 changes: 4 additions & 4 deletions src/common/clientsocket.cpp
Expand Up @@ -31,7 +31,7 @@
namespace {

const int bigMessageThreshold = 5 * 1024 * 1024;
int lastSocketId = 0;
ClientSocketId lastSocketId = 0;

const quint32 protocolMagicNumber = 0x0C090701;
const quint32 protocolVersion = 1;
Expand Down Expand Up @@ -166,7 +166,7 @@ void ClientSocket::start()
{
if ( !m_socket || !m_socket->waitForConnected(4000) )
{
emit connectionFailed(this);
emit connectionFailed(id());
return;
}

Expand Down Expand Up @@ -278,7 +278,7 @@ void ClientSocket::onReadyRead()
m_hasMessageLength = false;
m_message = m_message.mid(length);

emit messageReceived(msg, messageCode, this);
emit messageReceived(msg, messageCode, id());
}
}

Expand Down Expand Up @@ -307,7 +307,7 @@ void ClientSocket::onStateChanged(QLocalSocket::LocalSocketState state)
if (m_hasMessageLength)
log("ERROR: Socket disconnected before receiving message", LogError);

emit disconnected(this);
emit disconnected(id());
}
}
}
Expand Down
12 changes: 7 additions & 5 deletions src/common/clientsocket.h
Expand Up @@ -24,6 +24,8 @@
#include <QObject>
#include <QPointer>

using ClientSocketId = qulonglong;

class LocalSocketGuard
{
public:
Expand Down Expand Up @@ -55,7 +57,7 @@ class ClientSocket : public QObject
~ClientSocket();

/// Return socket ID unique in process (thread-safe).
int id() const { return m_socketId; }
ClientSocketId id() const { return m_socketId; }

void waitForReadyRead();

Expand All @@ -73,9 +75,9 @@ class ClientSocket : public QObject
bool isClosed() const;

signals:
void messageReceived(const QByteArray &message, int messageCode, ClientSocket *client);
void disconnected(ClientSocket *client);
void connectionFailed(ClientSocket *client);
void messageReceived(const QByteArray &message, int messageCode, ClientSocketId clientId);
void disconnected(ClientSocketId clientId);
void connectionFailed(ClientSocketId clientId);

private:
void onReadyRead();
Expand All @@ -85,7 +87,7 @@ class ClientSocket : public QObject
void error(const QString &errorMessage);

LocalSocketGuard m_socket;
int m_socketId;
ClientSocketId m_socketId;
bool m_closed;

bool m_hasMessageLength = false;
Expand Down

0 comments on commit 6eb9625

Please sign in to comment.