Skip to content

Commit

Permalink
habanalabs: validate packet id during CB parse
Browse files Browse the repository at this point in the history
[ Upstream commit bc75be2 ]

During command buffer parsing, driver extracts packet id
from user buffer. Driver must validate this packet id, since it is
being used in order to extract information from internal structures.

Signed-off-by: Ofir Bitton <obitton@habana.ai>
Reviewed-by: Oded Gabbay <oded.gabbay@gmail.com>
Signed-off-by: Oded Gabbay <oded.gabbay@gmail.com>
Signed-off-by: Sasha Levin <sashal@kernel.org>
  • Loading branch information
ofirbitt authored and gregkh committed Sep 9, 2020
1 parent b18c607 commit fb8b459
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 0 deletions.
35 changes: 35 additions & 0 deletions drivers/misc/habanalabs/gaudi/gaudi.c
Expand Up @@ -154,6 +154,29 @@ static const u16 gaudi_packet_sizes[MAX_PACKET_ID] = {
[PACKET_LOAD_AND_EXE] = sizeof(struct packet_load_and_exe)
};

static inline bool validate_packet_id(enum packet_id id)
{
switch (id) {
case PACKET_WREG_32:
case PACKET_WREG_BULK:
case PACKET_MSG_LONG:
case PACKET_MSG_SHORT:
case PACKET_CP_DMA:
case PACKET_REPEAT:
case PACKET_MSG_PROT:
case PACKET_FENCE:
case PACKET_LIN_DMA:
case PACKET_NOP:
case PACKET_STOP:
case PACKET_ARB_POINT:
case PACKET_WAIT:
case PACKET_LOAD_AND_EXE:
return true;
default:
return false;
}
}

static const char * const
gaudi_tpc_interrupts_cause[GAUDI_NUM_OF_TPC_INTR_CAUSE] = {
"tpc_address_exceed_slm",
Expand Down Expand Up @@ -3859,6 +3882,12 @@ static int gaudi_validate_cb(struct hl_device *hdev,
PACKET_HEADER_PACKET_ID_MASK) >>
PACKET_HEADER_PACKET_ID_SHIFT);

if (!validate_packet_id(pkt_id)) {
dev_err(hdev->dev, "Invalid packet id %u\n", pkt_id);
rc = -EINVAL;
break;
}

pkt_size = gaudi_packet_sizes[pkt_id];
cb_parsed_length += pkt_size;
if (cb_parsed_length > parser->user_cb_size) {
Expand Down Expand Up @@ -4082,6 +4111,12 @@ static int gaudi_patch_cb(struct hl_device *hdev,
PACKET_HEADER_PACKET_ID_MASK) >>
PACKET_HEADER_PACKET_ID_SHIFT);

if (!validate_packet_id(pkt_id)) {
dev_err(hdev->dev, "Invalid packet id %u\n", pkt_id);
rc = -EINVAL;
break;
}

pkt_size = gaudi_packet_sizes[pkt_id];
cb_parsed_length += pkt_size;
if (cb_parsed_length > parser->user_cb_size) {
Expand Down
31 changes: 31 additions & 0 deletions drivers/misc/habanalabs/goya/goya.c
Expand Up @@ -139,6 +139,25 @@ static u16 goya_packet_sizes[MAX_PACKET_ID] = {
[PACKET_STOP] = sizeof(struct packet_stop)
};

static inline bool validate_packet_id(enum packet_id id)
{
switch (id) {
case PACKET_WREG_32:
case PACKET_WREG_BULK:
case PACKET_MSG_LONG:
case PACKET_MSG_SHORT:
case PACKET_CP_DMA:
case PACKET_MSG_PROT:
case PACKET_FENCE:
case PACKET_LIN_DMA:
case PACKET_NOP:
case PACKET_STOP:
return true;
default:
return false;
}
}

static u64 goya_mmu_regs[GOYA_MMU_REGS_NUM] = {
mmDMA_QM_0_GLBL_NON_SECURE_PROPS,
mmDMA_QM_1_GLBL_NON_SECURE_PROPS,
Expand Down Expand Up @@ -3381,6 +3400,12 @@ static int goya_validate_cb(struct hl_device *hdev,
PACKET_HEADER_PACKET_ID_MASK) >>
PACKET_HEADER_PACKET_ID_SHIFT);

if (!validate_packet_id(pkt_id)) {
dev_err(hdev->dev, "Invalid packet id %u\n", pkt_id);
rc = -EINVAL;
break;
}

pkt_size = goya_packet_sizes[pkt_id];
cb_parsed_length += pkt_size;
if (cb_parsed_length > parser->user_cb_size) {
Expand Down Expand Up @@ -3616,6 +3641,12 @@ static int goya_patch_cb(struct hl_device *hdev,
PACKET_HEADER_PACKET_ID_MASK) >>
PACKET_HEADER_PACKET_ID_SHIFT);

if (!validate_packet_id(pkt_id)) {
dev_err(hdev->dev, "Invalid packet id %u\n", pkt_id);
rc = -EINVAL;
break;
}

pkt_size = goya_packet_sizes[pkt_id];
cb_parsed_length += pkt_size;
if (cb_parsed_length > parser->user_cb_size) {
Expand Down

0 comments on commit fb8b459

Please sign in to comment.