@@ -177,7 +177,7 @@ static struct aesgcm_ctx *snp_init_crypto(u8 *key, size_t keylen)
177177 return ctx ;
178178}
179179
180- static int verify_and_dec_payload (struct snp_guest_dev * snp_dev , void * payload , u32 sz )
180+ static int verify_and_dec_payload (struct snp_guest_dev * snp_dev , struct snp_guest_req * req )
181181{
182182 struct snp_guest_msg * resp_msg = & snp_dev -> secret_response ;
183183 struct snp_guest_msg * req_msg = & snp_dev -> secret_request ;
@@ -206,20 +206,19 @@ static int verify_and_dec_payload(struct snp_guest_dev *snp_dev, void *payload,
206206 * If the message size is greater than our buffer length then return
207207 * an error.
208208 */
209- if (unlikely ((resp_msg_hdr -> msg_sz + ctx -> authsize ) > sz ))
209+ if (unlikely ((resp_msg_hdr -> msg_sz + ctx -> authsize ) > req -> resp_sz ))
210210 return - EBADMSG ;
211211
212212 /* Decrypt the payload */
213213 memcpy (iv , & resp_msg_hdr -> msg_seqno , min (sizeof (iv ), sizeof (resp_msg_hdr -> msg_seqno )));
214- if (!aesgcm_decrypt (ctx , payload , resp_msg -> payload , resp_msg_hdr -> msg_sz ,
214+ if (!aesgcm_decrypt (ctx , req -> resp_buf , resp_msg -> payload , resp_msg_hdr -> msg_sz ,
215215 & resp_msg_hdr -> algo , AAD_LEN , iv , resp_msg_hdr -> authtag ))
216216 return - EBADMSG ;
217217
218218 return 0 ;
219219}
220220
221- static int enc_payload (struct snp_guest_dev * snp_dev , u64 seqno , int version , u8 type ,
222- void * payload , size_t sz )
221+ static int enc_payload (struct snp_guest_dev * snp_dev , u64 seqno , struct snp_guest_req * req )
223222{
224223 struct snp_guest_msg * msg = & snp_dev -> secret_request ;
225224 struct snp_guest_msg_hdr * hdr = & msg -> hdr ;
@@ -231,11 +230,11 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
231230 hdr -> algo = SNP_AEAD_AES_256_GCM ;
232231 hdr -> hdr_version = MSG_HDR_VER ;
233232 hdr -> hdr_sz = sizeof (* hdr );
234- hdr -> msg_type = type ;
235- hdr -> msg_version = version ;
233+ hdr -> msg_type = req -> msg_type ;
234+ hdr -> msg_version = req -> msg_version ;
236235 hdr -> msg_seqno = seqno ;
237- hdr -> msg_vmpck = vmpck_id ;
238- hdr -> msg_sz = sz ;
236+ hdr -> msg_vmpck = req -> vmpck_id ;
237+ hdr -> msg_sz = req -> req_sz ;
239238
240239 /* Verify the sequence number is non-zero */
241240 if (!hdr -> msg_seqno )
@@ -244,17 +243,17 @@ static int enc_payload(struct snp_guest_dev *snp_dev, u64 seqno, int version, u8
244243 pr_debug ("request [seqno %lld type %d version %d sz %d]\n" ,
245244 hdr -> msg_seqno , hdr -> msg_type , hdr -> msg_version , hdr -> msg_sz );
246245
247- if (WARN_ON ((sz + ctx -> authsize ) > sizeof (msg -> payload )))
246+ if (WARN_ON ((req -> req_sz + ctx -> authsize ) > sizeof (msg -> payload )))
248247 return - EBADMSG ;
249248
250249 memcpy (iv , & hdr -> msg_seqno , min (sizeof (iv ), sizeof (hdr -> msg_seqno )));
251- aesgcm_encrypt (ctx , msg -> payload , payload , sz , & hdr -> algo , AAD_LEN ,
252- iv , hdr -> authtag );
250+ aesgcm_encrypt (ctx , msg -> payload , req -> req_buf , req -> req_sz , & hdr -> algo ,
251+ AAD_LEN , iv , hdr -> authtag );
253252
254253 return 0 ;
255254}
256255
257- static int __handle_guest_request (struct snp_guest_dev * snp_dev , u64 exit_code ,
256+ static int __handle_guest_request (struct snp_guest_dev * snp_dev , struct snp_guest_req * req ,
258257 struct snp_guest_request_ioctl * rio )
259258{
260259 unsigned long req_start = jiffies ;
@@ -269,7 +268,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
269268 * sequence number must be incremented or the VMPCK must be deleted to
270269 * prevent reuse of the IV.
271270 */
272- rc = snp_issue_guest_request (exit_code , & snp_dev -> input , rio );
271+ rc = snp_issue_guest_request (req , & snp_dev -> input , rio );
273272 switch (rc ) {
274273 case - ENOSPC :
275274 /*
@@ -280,7 +279,7 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
280279 * IV reuse.
281280 */
282281 override_npages = snp_dev -> input .data_npages ;
283- exit_code = SVM_VMGEXIT_GUEST_REQUEST ;
282+ req -> exit_code = SVM_VMGEXIT_GUEST_REQUEST ;
284283
285284 /*
286285 * Override the error to inform callers the given extended
@@ -340,10 +339,8 @@ static int __handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
340339 return rc ;
341340}
342341
343- static int handle_guest_request (struct snp_guest_dev * snp_dev , u64 exit_code ,
344- struct snp_guest_request_ioctl * rio , u8 type ,
345- void * req_buf , size_t req_sz , void * resp_buf ,
346- u32 resp_sz )
342+ static int snp_send_guest_request (struct snp_guest_dev * snp_dev , struct snp_guest_req * req ,
343+ struct snp_guest_request_ioctl * rio )
347344{
348345 u64 seqno ;
349346 int rc ;
@@ -357,7 +354,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
357354 memset (snp_dev -> response , 0 , sizeof (struct snp_guest_msg ));
358355
359356 /* Encrypt the userspace provided payload in snp_dev->secret_request. */
360- rc = enc_payload (snp_dev , seqno , rio -> msg_version , type , req_buf , req_sz );
357+ rc = enc_payload (snp_dev , seqno , req );
361358 if (rc )
362359 return rc ;
363360
@@ -368,7 +365,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
368365 memcpy (snp_dev -> request , & snp_dev -> secret_request ,
369366 sizeof (snp_dev -> secret_request ));
370367
371- rc = __handle_guest_request (snp_dev , exit_code , rio );
368+ rc = __handle_guest_request (snp_dev , req , rio );
372369 if (rc ) {
373370 if (rc == - EIO &&
374371 rio -> exitinfo2 == SNP_GUEST_VMM_ERR (SNP_GUEST_VMM_ERR_INVALID_LEN ))
@@ -382,7 +379,7 @@ static int handle_guest_request(struct snp_guest_dev *snp_dev, u64 exit_code,
382379 return rc ;
383380 }
384381
385- rc = verify_and_dec_payload (snp_dev , resp_buf , resp_sz );
382+ rc = verify_and_dec_payload (snp_dev , req );
386383 if (rc ) {
387384 dev_alert (snp_dev -> dev , "Detected unexpected decode failure from ASP. rc: %d\n" , rc );
388385 snp_disable_vmpck (snp_dev );
@@ -401,6 +398,7 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
401398{
402399 struct snp_report_req * report_req = & snp_dev -> req .report ;
403400 struct snp_report_resp * report_resp ;
401+ struct snp_guest_req req = {};
404402 int rc , resp_len ;
405403
406404 lockdep_assert_held (& snp_cmd_mutex );
@@ -421,8 +419,16 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
421419 if (!report_resp )
422420 return - ENOMEM ;
423421
424- rc = handle_guest_request (snp_dev , SVM_VMGEXIT_GUEST_REQUEST , arg , SNP_MSG_REPORT_REQ ,
425- report_req , sizeof (* report_req ), report_resp -> data , resp_len );
422+ req .msg_version = arg -> msg_version ;
423+ req .msg_type = SNP_MSG_REPORT_REQ ;
424+ req .vmpck_id = vmpck_id ;
425+ req .req_buf = report_req ;
426+ req .req_sz = sizeof (* report_req );
427+ req .resp_buf = report_resp -> data ;
428+ req .resp_sz = resp_len ;
429+ req .exit_code = SVM_VMGEXIT_GUEST_REQUEST ;
430+
431+ rc = snp_send_guest_request (snp_dev , & req , arg );
426432 if (rc )
427433 goto e_free ;
428434
@@ -438,6 +444,7 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
438444{
439445 struct snp_derived_key_req * derived_key_req = & snp_dev -> req .derived_key ;
440446 struct snp_derived_key_resp derived_key_resp = {0 };
447+ struct snp_guest_req req = {};
441448 int rc , resp_len ;
442449 /* Response data is 64 bytes and max authsize for GCM is 16 bytes. */
443450 u8 buf [64 + 16 ];
@@ -460,8 +467,16 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
460467 sizeof (* derived_key_req )))
461468 return - EFAULT ;
462469
463- rc = handle_guest_request (snp_dev , SVM_VMGEXIT_GUEST_REQUEST , arg , SNP_MSG_KEY_REQ ,
464- derived_key_req , sizeof (* derived_key_req ), buf , resp_len );
470+ req .msg_version = arg -> msg_version ;
471+ req .msg_type = SNP_MSG_KEY_REQ ;
472+ req .vmpck_id = vmpck_id ;
473+ req .req_buf = derived_key_req ;
474+ req .req_sz = sizeof (* derived_key_req );
475+ req .resp_buf = buf ;
476+ req .resp_sz = resp_len ;
477+ req .exit_code = SVM_VMGEXIT_GUEST_REQUEST ;
478+
479+ rc = snp_send_guest_request (snp_dev , & req , arg );
465480 if (rc )
466481 return rc ;
467482
@@ -482,6 +497,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
482497{
483498 struct snp_ext_report_req * report_req = & snp_dev -> req .ext_report ;
484499 struct snp_report_resp * report_resp ;
500+ struct snp_guest_req req = {};
485501 int ret , npages = 0 , resp_len ;
486502 sockptr_t certs_address ;
487503
@@ -529,9 +545,17 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
529545 return - ENOMEM ;
530546
531547 snp_dev -> input .data_npages = npages ;
532- ret = handle_guest_request (snp_dev , SVM_VMGEXIT_EXT_GUEST_REQUEST , arg , SNP_MSG_REPORT_REQ ,
533- & report_req -> data , sizeof (report_req -> data ),
534- report_resp -> data , resp_len );
548+
549+ req .msg_version = arg -> msg_version ;
550+ req .msg_type = SNP_MSG_REPORT_REQ ;
551+ req .vmpck_id = vmpck_id ;
552+ req .req_buf = & report_req -> data ;
553+ req .req_sz = sizeof (report_req -> data );
554+ req .resp_buf = report_resp -> data ;
555+ req .resp_sz = resp_len ;
556+ req .exit_code = SVM_VMGEXIT_EXT_GUEST_REQUEST ;
557+
558+ ret = snp_send_guest_request (snp_dev , & req , arg );
535559
536560 /* If certs length is invalid then copy the returned length */
537561 if (arg -> vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN ) {
@@ -1057,7 +1081,7 @@ static int __init sev_guest_probe(struct platform_device *pdev)
10571081 misc -> name = DEVICE_NAME ;
10581082 misc -> fops = & snp_guest_fops ;
10591083
1060- /* initial the input address for guest request */
1084+ /* Initialize the input addresses for guest request */
10611085 snp_dev -> input .req_gpa = __pa (snp_dev -> request );
10621086 snp_dev -> input .resp_gpa = __pa (snp_dev -> response );
10631087 snp_dev -> input .data_gpa = __pa (snp_dev -> certs_data );
0 commit comments