Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions auto_round/modelling/gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,20 +103,27 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# Use the original router (it returns scores and indices already softmaxed over top-k)
router_scores, router_indices = self.router(x) # scores: [tokens, E], indices: [tokens, k]

out = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x)

# Accumulate expert outputs for chosen experts only
for j in range(self.top_k):
idx = router_indices[:, j]
w = router_scores[torch.arange(idx.size(0), device=idx.device), idx].unsqueeze(-1)
unique_experts = torch.unique(idx)
for e in unique_experts:
mask = idx == e
out[mask] += self.experts[e](x[mask]) * w[mask]

out = out.view(B, T, H)
router_scores = router_scores.view(B * T, -1) # shape doesn't matter much; it’s ignored by the decoder
return out, router_scores
final_hidden_states = self.shared_expert(x) if self.shared_expert is not None else torch.zeros_like(x)
num_all_tokens, total_num_experts = x.size(0), self.num_experts
mask_weights = torch.zeros((num_all_tokens, total_num_experts), dtype=x.dtype, device=x.device)
topk_ids, experts_mask = router_indices, router_scores
topk_ids = topk_ids.to(torch.int64)

mask_weights.scatter_(-1, topk_ids, 1)

mask_weights = mask_weights[:num_all_tokens, :total_num_experts]
mask_weights = mask_weights.transpose(0, 1)
experts_mask = experts_mask[:num_all_tokens, :total_num_experts]
experts_mask = experts_mask.transpose(0, 1)
num_experts = total_num_experts
for expert_index in range(num_experts):
mask_weight = mask_weights[expert_index].unsqueeze(1)
current_state_static = x * mask_weight
expert = self.experts[expert_index]
expert_output = expert(current_state_static)
expert_output = expert_output * experts_mask[expert_index].unsqueeze(1)
final_hidden_states += expert_output
return final_hidden_states.view(B, T, H), router_scores.view(B * T, -1)


def get_replacement_info(config):
Expand Down