Skip to content

Commit

Permalink
[NFC] polish applications/Chat/coati/models/generation.py code style (#…
Browse files — browse the repository at this point in the history
  • Loading branch information
yangluo7 authored and binmakeswell committed Jul 26, 2023
1 parent dc1b612 commit 709e121
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions applications/Chat/coati/models/generation.py
Original file line number · Diff line number · Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch.nn as nn
import torch.nn.functional as F


try:
from transformers.generation_logits_process import (
LogitsProcessorList,
Expand Down Expand Up @@ -148,12 +147,12 @@ def generate(model: nn.Module,


@torch.no_grad()
def generate_with_actor(actor_model: nn.Module,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
def generate_with_actor(
actor_model: nn.Module,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
"""Generate token sequence with actor model. Refer to `generate` for more details.
"""
# generate sequences
Expand Down

0 comments on commit 709e121

Please sign in to comment.