Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion router/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,10 @@ fn main() -> Result<(), RouterError> {
"Inferred max batch total tokens: {max_supported_batch_total_tokens}"
);
}
if max_total_tokens as u32 > max_supported_batch_total_tokens {
return Err(RouterError::ArgumentValidation(format!("`max_total_tokens` must be <= `max_batch_total_tokens`. Given: {max_total_tokens} and {max_supported_batch_total_tokens}")));
}

max_supported_batch_total_tokens
}
};
Expand Down Expand Up @@ -270,7 +274,7 @@ fn main() -> Result<(), RouterError> {
ngrok_authtoken,
ngrok_edge,
)
.await?;
.await?;
Ok(())
})
}
Expand Down
2 changes: 1 addition & 1 deletion server/Makefile-vllm
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
vllm_commit := d284b831c17f42a8ea63369a06138325f73c4cf9
vllm_commit := 084ca75d4271f8f67be731bc58e0d41d8e0afd3a

vllm:
# Clone vllm
Expand Down
35 changes: 15 additions & 20 deletions server/text_generation_server/utils/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,36 +219,31 @@ def load(config, prefix: str, weights):
)

def forward(self, input: torch.Tensor) -> torch.Tensor:
if not self.should_gather:
return super().forward(input)

world_size = self.process_group.size()
if len(input.shape) == 2 and isinstance(self.linear, FastLinear):
# Fast branch for single requests
if (
self.should_gather
and len(input.shape) == 2
and isinstance(self.linear, FastLinear)
and input.shape[0] == 1
):
out_dim = self.linear.weight.shape[0]

if input.shape[0] == 1:
world_out = input.new_empty(1, out_dim * world_size)
local_out = input.new_empty(1, out_dim)
gather_input = local_out
else:
world_out = input.new_empty(out_dim * world_size, input.shape[0])
gather_input = input.new_empty(out_dim, input.shape[0])
local_out = gather_input.T
world_out = input.new_empty(1, out_dim * world_size)
local_out = input.new_empty(1, out_dim)

torch.mm(input, self.linear.weight.T, out=local_out)

torch.distributed.all_gather_into_tensor(
world_out, gather_input, group=self.process_group
world_out, local_out, group=self.process_group
)

if input.shape[0] == 1:
return world_out
return world_out.T
return world_out

output = super().forward(input)
world_output = [
torch.empty_like(output) for _ in range(self.process_group.size())
]
if not self.should_gather:
return output

world_output = [torch.empty_like(output) for _ in range(world_size)]
torch.distributed.all_gather(world_output, output, group=self.process_group)
world_output = torch.cat(world_output, dim=-1)
return world_output
Expand Down