Torchscript compatibility #83

Open
SergeySandler opened this issue Mar 1, 2024 · 6 comments

Comments

@SergeySandler

Making the PyTorch TAPIR model compatible with TorchScript tracing is easy: change the end of TAPIR.forward() in https://github.com/google-deepmind/tapnet/blob/main/torch/tapir_model.py#L196-L209 from

    out = dict(
        occlusion=torch.mean(
            torch.stack(trajectories['occlusion'][p::p]), dim=0
        ),
        tracks=torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
        expected_dist=torch.mean(
            torch.stack(trajectories['expected_dist'][p::p]), dim=0
        ),
        unrefined_occlusion=trajectories['occlusion'][:-1],
        unrefined_tracks=trajectories['tracks'][:-1],
        unrefined_expected_dist=trajectories['expected_dist'][:-1],
    )

    return out

to

    class Output(NamedTuple):  # requires: from typing import NamedTuple
        occlusion: torch.Tensor
        tracks: torch.Tensor
        expected_dist: torch.Tensor

    out = Output(torch.mean(torch.stack(trajectories['occlusion'][p::p]), dim=0),
                 torch.mean(torch.stack(trajectories['tracks'][p::p]), dim=0),
                 torch.mean(torch.stack(trajectories['expected_dist'][p::p]), dim=0)
                )

    return out

(assuming it is acceptable to drop the unrefined_* entries from the output). With that change, the following code

    import torch
    from tapnet.torch import tapir_model  # package path as in the tracebacks below

    model = tapir_model.TAPIR(pyramid_level=1)
    model.load_state_dict(torch.load('bootstapir_checkpoint.pt', map_location='cpu'))
    model = model.to(torch.device('cpu'))
    model.eval()
    dummy_input_frames = torch.randn(1, 32, 256, 256, 3, dtype=torch.float32, device=torch.device('cpu'))
    dummy_input_query_points = torch.randn(1, 20, 3, dtype=torch.float32, device=torch.device('cpu'))
    scriptModule = torch.jit.trace(model, (dummy_input_frames, dummy_input_query_points))
    torch.jit.save(scriptModule, 'bootstapir_checkpoint.ptc')

succeeds. Making the model compatible with TorchScript scripting is harder, however. Calling

    scriptModule = torch.jit.script(model)

fails with

Module 'BlockV2' has no attribute 'proj_conv' :
  File "C:\tapnet\tapnet\torch\nets.py", line 278
    x = torch.relu(x)
    if self.use_projection:
      shortcut = self.proj_conv(x)
                 ~~~~~~~~~~~~~~ <--- HERE

How can the model be made compatible with TorchScript scripting?
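Some context on why the NamedTuple change helps tracing: as far as I know, torch.jit.trace with its default strict=True rejects dict and list outputs and suggests a NamedTuple (or strict=False) instead, which would also explain why the list-valued unrefined_* entries get in the way. The sketch below uses a hypothetical ToyTracker module and a hypothetical TrackOutput type, not TAPIR itself, to show a fixed-field NamedTuple output tracing cleanly.

```python
from typing import List, NamedTuple

import torch
from torch import nn


class TrackOutput(NamedTuple):
    """Hypothetical fixed-field output, standing in for TAPIR's dict."""
    occlusion: torch.Tensor
    tracks: torch.Tensor


class ToyTracker(nn.Module):
    """Hypothetical stand-in for TAPIR: averages several refinement passes."""

    def __init__(self, num_passes: int = 3):
        super().__init__()
        self.num_passes = num_passes
        self.refine = nn.Linear(3, 3)

    def forward(self, query_points: torch.Tensor) -> TrackOutput:
        passes: List[torch.Tensor] = []
        x = query_points
        for _ in range(self.num_passes):
            x = self.refine(x)
            passes.append(x)
        stacked = torch.stack(passes)  # [num_passes, batch, points, 3]
        # Average over passes, similar in spirit to torch.mean(torch.stack(...), dim=0).
        return TrackOutput(
            occlusion=torch.mean(stacked[..., 0], dim=0),
            tracks=torch.mean(stacked[..., 1:], dim=0),
        )


toy = ToyTracker().eval()
query_points = torch.randn(1, 20, 3)
traced = torch.jit.trace(toy, (query_points,))  # tuple-like output is accepted
```

Indexing the traced output by position (out[0], out[1]) works whether it comes back as a plain tuple or as the NamedTuple.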

@SergeySandler
Author

SergeySandler commented Mar 1, 2024

It seems possible to work around the error above by adding an else clause to BlockV2.__init__ after

    if self.use_projection:
      self.proj_conv = nn.Conv2d(
          in_channels=channels_in,
          out_channels=channels_out,
          kernel_size=1,
          stride=stride,
          padding=0,
          bias=False,
      )

in https://github.com/google-deepmind/tapnet/blob/main/torch/nets.py#L225-L233, so that it looks like this:

    if self.use_projection:
      self.proj_conv = nn.Conv2d(...)
    else:
      self.proj_conv = DummyModel()

where DummyModel is a placeholder class:

    class DummyModel:

        def __init__(self):
            pass

        def forward(self):
            return torch.tensor(0)

        def __call__(self, input):
            return self.forward()
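A possibly simpler variant of this workaround is to assign the built-in nn.Identity() in the else branch. Because it is a real nn.Module with a typed forward, TorchScript can register it as a submodule and compile both branches of if self.use_projection:. The sketch below shows this on a hypothetical ToyBlock, not the real BlockV2. Whether the same change alone gets BlockV2 through torch.jit.script is an assumption.

```python
import torch
from torch import nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for BlockV2 with an optional 1x1 projection."""

    def __init__(self, channels_in: int, channels_out: int, stride: int = 1):
        super().__init__()
        self.use_projection = channels_in != channels_out or stride != 1
        if self.use_projection:
            self.proj_conv = nn.Conv2d(channels_in, channels_out, kernel_size=1,
                                       stride=stride, padding=0, bias=False)
        else:
            # Always define the attribute so TorchScript can compile both branches.
            self.proj_conv = nn.Identity()
        self.conv = nn.Conv2d(channels_in, channels_out, kernel_size=3,
                              stride=stride, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = torch.relu(x)
        if self.use_projection:
            shortcut = self.proj_conv(x)
        return self.conv(x) + shortcut


scripted_block = torch.jit.script(ToyBlock(8, 8))  # compiles without the missing-attribute error
```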

But then torch.jit.script(model) fails with

Arguments for call are not valid.
The following variants are available:
  
  aten::cat(Tensor[] tensors, int dim=0) -> Tensor:
  Keyword argument axis unknown.
  
  aten::cat.names(Tensor[] tensors, str dim) -> Tensor:
  Argument dim not provided.
  
  aten::cat.names_out(Tensor[] tensors, str dim, *, Tensor(a!) out) -> Tensor(a!):
  Argument dim not provided.
  
  aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!):
  Argument out not provided.

The original call is:
  File "C:\tapnet\tapnet\torch\nets.py", line 61
    prev_frame = torch.cat([x[0:1], x[:-1]], dim=0)
    next_frame = torch.cat([x[1:], x[-1:]], dim=0)
    resid = torch.cat([x, prev_frame, next_frame], axis=1) 
            ~~~~~~~~~ <--- HERE

This can be resolved by replacing resid = torch.cat([x, prev_frame, next_frame], axis=1) with resid = torch.cat([x, prev_frame, next_frame], dim=1). I'd like to know why 'axis' does not cause an unexpected-keyword-argument error in eager mode.
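A likely answer to the axis question: in eager mode, PyTorch's Python argument parser accepts NumPy-style aliases such as axis for dim, while TorchScript matches keyword arguments strictly against the operator schema (aten::cat(Tensor[] tensors, int dim=0)), which has no axis. The sketch below reproduces the frame-shifting snippet from nets.py as a standalone function (the name shift_concat is hypothetical) using dim=, so it scripts cleanly.

```python
import torch


def shift_concat(x: torch.Tensor) -> torch.Tensor:
    """Stacks each frame with its previous and next frame along the channel axis.

    Mirrors the nets.py snippet, but uses `dim=` instead of the NumPy-style `axis=`.
    """
    prev_frame = torch.cat([x[0:1], x[:-1]], dim=0)  # frame 0 is repeated at the start
    next_frame = torch.cat([x[1:], x[-1:]], dim=0)   # last frame is repeated at the end
    return torch.cat([x, prev_frame, next_frame], dim=1)


# Eager mode accepts the alias, which is why the original call worked outside TorchScript.
frames = torch.randn(4, 2, 8, 8)
assert torch.equal(torch.cat([frames, frames], axis=1), torch.cat([frames, frames], dim=1))

scripted_shift_concat = torch.jit.script(shift_concat)
```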
The next error is:

Unknown type constructor Mapping:
  File "C:\tapnet\tapnet\torch\tapir_model.py", line 145
      get_query_feats: bool = False,
      refinement_resolutions: Optional[List[Tuple[int, int]]] = None,
  ) -> Mapping[str, torch.Tensor]:
       ~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
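TorchScript's type system does not include typing.Mapping (hence "Unknown type constructor"). The supported container annotations include Dict, List, Tuple and Optional, so a likely fix is to annotate the return type as Dict[str, torch.Tensor]. Here is a minimal sketch on a hypothetical function summarize_passes, not taken from tapnet:

```python
from typing import Dict, List

import torch


def summarize_passes(passes: List[torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Returns the mean and the last of several refinement passes.

    `Dict[str, torch.Tensor]` is scriptable; `Mapping[str, torch.Tensor]` is not.
    """
    stacked = torch.stack(passes)
    return {'mean': torch.mean(stacked, dim=0), 'last': stacked[-1]}


scripted_summarize = torch.jit.script(summarize_passes)
```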

@sgjheywa
Collaborator

sgjheywa commented Mar 6, 2024

Hi,

Thanks for raising the issue. Prior to release, we were also able to trace the model using the same method you described, but in testing it showed very little performance improvement. Can I ask what the use case for scripting is here? Thanks

@SergeySandler
Author

SergeySandler commented Mar 8, 2024

@sgjheywa, scripting (torch.jit.script) makes it possible to save a model with dynamic dimensions, whereas tracing supports only static dimensions. Achieving JIT compatibility required many code changes; please review #85.
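One concrete illustration of the difference, using a hypothetical function rather than TAPIR: tracing records only the operations executed for the example input, so shape-dependent Python control flow is frozen into the trace, whereas scripting compiles the if statement itself and re-evaluates it for each input.

```python
import torch


def scale_long_clips(frames: torch.Tensor) -> torch.Tensor:
    # Shape-dependent control flow: only clips with more than 2 frames are scaled.
    if frames.shape[0] > 2:
        return frames * 2.0
    return frames


long_clip = torch.ones(3, 4)
traced = torch.jit.trace(scale_long_clips, (long_clip,))  # branch chosen for 3 frames is baked in
scripted = torch.jit.script(scale_long_clips)             # the `if` is compiled into the graph

short_clip = torch.ones(1, 4)
print(traced(short_clip))    # still doubled: the trace replays the 3-frame branch
print(scripted(short_clip))  # unchanged: the condition is evaluated again
```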

@sgjheywa
Collaborator

Sorry, I am familiar with scripting; I'm just trying to figure out the use case here. Since the model is compatible with torch.compile, this seems unnecessary. Thanks

@SergeySandler
Author

@sgjheywa, the use case is LibTorch integration in C++. The model can be compiled with torch.compile, but that does not help, since a compiled model cannot be saved with torch.jit.save. Am I missing something? Thank you.
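For the LibTorch route, here is a sketch of the round trip, using a hypothetical TinyHead module rather than TAPIR. The scripted module is serialized with torch.jit.save, and the resulting file is what C++ code would open with torch::jit::load (declared in torch/script.h). In this sketch, an in-memory buffer and Python's torch.jit.load stand in for the file and the C++ side.

```python
import io

import torch
from torch import nn


class TinyHead(nn.Module):
    """Hypothetical module standing in for a scripted TAPIR."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    def forward(self, query_points: torch.Tensor) -> torch.Tensor:
        return self.linear(query_points)


scripted = torch.jit.script(TinyHead().eval())
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)  # a path such as 'model.pt' works as well
buffer.seek(0)
reloaded = torch.jit.load(buffer)  # C++ equivalent: torch::jit::load("model.pt")
```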

@pubyLu

pubyLu commented May 20, 2024

Hello! May I ask whether you have implemented model training for the Python version of TAPIR?
