diff --git a/auto_round/modelling/gpt_oss.py b/auto_round/modelling/gpt_oss.py index 78f73075c..2acae3bb1 100644 --- a/auto_round/modelling/gpt_oss.py +++ b/auto_round/modelling/gpt_oss.py @@ -103,20 +103,27 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: # Use the original router (it returns scores and indices already softmaxed over top-k) router_scores, router_indices = self.router(x) # scores: [tokens, E], indices: [tokens, k] - out = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x) - - # Accumulate expert outputs for chosen experts only - for j in range(self.top_k): - idx = router_indices[:, j] - w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1) - unique_experts = torch.unique(idx) - for e in unique_experts: - mask = idx == e - out[mask] += self.experts[e](x[mask]) * w[mask] - - out = out.view(B, T, H) - router_scores = router_scores.view(B * T, -1) # shape doesn't matter much; it’s ignored by the decoder - return out, router_scores + final_hidden_states = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x) + num_all_tokens, total_num_experts = x.size(0), self.num_experts + mask_weights = torch.zeros((num_all_tokens, total_num_experts), dtype=x.dtype, device=x.device) + topk_ids, experts_mask = router_indices, router_scores + topk_ids = topk_ids.to(torch.int64) + + mask_weights.scatter_(-1, topk_ids, 1) + + mask_weights = mask_weights[:num_all_tokens, :total_num_experts] + mask_weights = mask_weights.transpose(0, 1) + experts_mask = experts_mask[:num_all_tokens, :total_num_experts] + experts_mask = experts_mask.transpose(0, 1) + num_experts = total_num_experts + for expert_index in range(num_experts): + mask_weight = mask_weights[expert_index].unsqueeze(1) + current_state_static = x * mask_weight + expert = self.experts[expert_index] + expert_output = expert(current_state_static) + expert_output = expert_output * experts_mask[expert_index].unsqueeze(1) + final_hidden_states += expert_output + return final_hidden_states.view(B, T, H), router_scores.view(B * T, -1) def get_replacement_info(config):