diff --git a/components/wpa_supplicant/src/ap/wpa_auth.c b/components/wpa_supplicant/src/ap/wpa_auth.c index c062ccd44fb..ce76023a56f 100644 --- a/components/wpa_supplicant/src/ap/wpa_auth.c +++ b/components/wpa_supplicant/src/ap/wpa_auth.c @@ -958,6 +958,12 @@ int wpa_auth_pmksa_add_sae(struct wpa_authenticator *wpa_auth, const u8 *addr, return -1; } +void wpa_auth_add_sae_pmkid(struct wpa_state_machine *sm, const u8 *pmkid) +{ + os_memcpy(sm->pmkid, pmkid, PMKID_LEN); + sm->pmkid_set = 1; +} + static int wpa_gmk_to_gtk(const u8 *gmk, const char *label, const u8 *addr, const u8 *gnonce, u8 *gtk, size_t gtk_len) { @@ -1541,15 +1547,38 @@ SM_STATE(WPA_PTK, PTKSTART) pmkid[0] = WLAN_EID_VENDOR_SPECIFIC; pmkid[1] = RSN_SELECTOR_LEN + PMKID_LEN; RSN_SELECTOR_PUT(&pmkid[2], RSN_KEY_DATA_PMKID); - - { + if (sm->pmksa) { + wpa_hexdump(MSG_DEBUG, + "RSN: Message 1/4 PMKID from PMKSA entry", + sm->pmksa->pmkid, PMKID_LEN); + os_memcpy(&pmkid[2 + RSN_SELECTOR_LEN], + sm->pmksa->pmkid, PMKID_LEN); +#ifdef CONFIG_SAE + } else if (wpa_key_mgmt_sae(sm->wpa_key_mgmt)) { + if (sm->pmkid_set) { + wpa_hexdump(MSG_DEBUG, + "RSN: Message 1/4 PMKID from SAE", + sm->pmkid, PMKID_LEN); + os_memcpy(&pmkid[2 + RSN_SELECTOR_LEN], + sm->pmkid, PMKID_LEN); + } else { + /* No PMKID available */ + wpa_printf(MSG_DEBUG, + "RSN: No SAE PMKID available for message 1/4"); + pmkid = NULL; + } +#endif /* CONFIG_SAE */ + } else { /* * Calculate PMKID since no PMKSA cache entry was * available with pre-calculated PMKID. */ - rsn_pmkid(sm->PMK, PMK_LEN, sm->wpa_auth->addr, - sm->addr, &pmkid[2 + RSN_SELECTOR_LEN], - wpa_key_mgmt_sha256(sm->wpa_key_mgmt)); + rsn_pmkid(sm->PMK, sm->pmk_len, sm->wpa_auth->addr, + sm->addr, &pmkid[2 + RSN_SELECTOR_LEN], + sm->wpa_key_mgmt); + wpa_hexdump(MSG_DEBUG, + "RSN: Message 1/4 PMKID derived from PMK", + &pmkid[2 + RSN_SELECTOR_LEN], PMKID_LEN); } } wpa_send_eapol(sm->wpa_auth, sm, @@ -2539,6 +2568,14 @@ bool wpa_ap_join(struct sta_info *sta, uint8_t *bssid, uint8_t *wpa_ie, } status_code = wpa_validate_wpa_ie(hapd->wpa_auth, sta->wpa_sm, wpa_ie, wpa_ie_len, rsnxe, rsnxe_len); + +#ifdef CONFIG_SAE + if (wpa_auth_uses_sae(sta->wpa_sm) && sta->sae && + sta->sae->state == SAE_ACCEPTED) { + wpa_auth_add_sae_pmkid(sta->wpa_sm, sta->sae->pmkid); + } +#endif /* CONFIG_SAE */ + resp = wpa_res_to_status_code(status_code); send_resp: diff --git a/components/wpa_supplicant/src/ap/wpa_auth_i.h b/components/wpa_supplicant/src/ap/wpa_auth_i.h index de8322f2fc0..672cc09425c 100644 --- a/components/wpa_supplicant/src/ap/wpa_auth_i.h +++ b/components/wpa_supplicant/src/ap/wpa_auth_i.h @@ -90,6 +90,7 @@ struct wpa_state_machine { unsigned int pmk_r1_name_valid:1; #endif /* CONFIG_IEEE80211R */ unsigned int is_wnmsleep:1; + unsigned int pmkid_set:1; u8 req_replay_counter[WPA_REPLAY_COUNTER_LEN]; int req_replay_counter_used; diff --git a/components/wpa_supplicant/src/common/wpa_common.c b/components/wpa_supplicant/src/common/wpa_common.c index 34e704f6404..05e00463aad 100644 --- a/components/wpa_supplicant/src/common/wpa_common.c +++ b/components/wpa_supplicant/src/common/wpa_common.c @@ -1171,26 +1171,28 @@ int wpa_pmk_to_ptk(const u8 *pmk, size_t pmk_len, const char *label, * PMKID = HMAC-SHA1-128(PMK, "PMK Name" || AA || SPA) */ void rsn_pmkid(const u8 *pmk, size_t pmk_len, const u8 *aa, const u8 *spa, - u8 *pmkid, int use_sha256) + u8 *pmkid, int akmp) { - char title[9]; + char *title = "PMK Name"; const u8 *addr[3]; const size_t len[3] = { 8, ETH_ALEN, ETH_ALEN }; unsigned char hash[SHA256_MAC_LEN]; - os_memcpy(title, "PMK Name", sizeof("PMK Name")); addr[0] = (u8 *) title; addr[1] = aa; addr[2] = spa; #ifdef CONFIG_IEEE80211W - if (use_sha256) { + if (wpa_key_mgmt_sha256(akmp)) { + wpa_printf(MSG_DEBUG, "RSN: Derive PMKID using HMAC-SHA-256"); hmac_sha256_vector(pmk, pmk_len, 3, addr, len, hash); - } - else + } else #endif /* CONFIG_IEEE80211W */ - hmac_sha1_vector(pmk, pmk_len, 3, addr, len, hash); - memcpy(pmkid, hash, PMKID_LEN); + { + wpa_printf(MSG_DEBUG, "RSN: Derive PMKID using HMAC-SHA-1"); + hmac_sha1_vector(pmk, pmk_len, 3, addr, len, hash); + } + os_memcpy(pmkid, hash, PMKID_LEN); } diff --git a/components/wpa_supplicant/src/common/wpa_common.h b/components/wpa_supplicant/src/common/wpa_common.h index d456794ab20..4462c81f003 100644 --- a/components/wpa_supplicant/src/common/wpa_common.h +++ b/components/wpa_supplicant/src/common/wpa_common.h @@ -432,7 +432,7 @@ int wpa_pmk_to_ptk(const u8 *pmk, size_t pmk_len, const char *label, struct wpa_ptk *ptk, int akmp, int cipher); void rsn_pmkid(const u8 *pmk, size_t pmk_len, const u8 *aa, const u8 *spa, - u8 *pmkid, int use_sha256); + u8 *pmkid, int akmp); int wpa_cipher_key_len(int cipher); int wpa_cipher_rsc_len(int cipher);