diff --git a/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py b/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py
index e893bb4a..2558c8d0 100644
--- a/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py
+++ b/src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py
@@ -200,6 +200,7 @@ def __init__(self,
         self.enable_tensorrt = enable_tensorrt and TENSORRT_AVAILABLE
         self.force_rebuild = force_rebuild
         self._first_frame = True
+        self.prev_input_frame = None  # Store previous input frame for temporal flow computation

         # Model paths
         self.models_dir = Path("models") / "temporal_net"
@@ -370,7 +371,7 @@ def _process_core(self, image: Image.Image) -> Image.Image:
             image: Current input image

         Returns:
-            Warped previous frame for temporal guidance, or fallback for first frame
+            Processed frame for temporal guidance, or fallback for first frame
         """
         # Convert to tensor and use tensor processing path for efficiency
         tensor = self.pil_to_tensor(image)
@@ -385,11 +386,21 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
             tensor: Current input tensor

         Returns:
-            Warped previous frame tensor for temporal guidance
+            Processed frame tensor for temporal guidance
         """
-        # Check if we have a pipeline reference and previous output
-        if (self.pipeline_ref is not None and
+        # Normalize input tensor
+        input_tensor = tensor
+        if input_tensor.max() > 1.0:
+            input_tensor = input_tensor / 255.0
+
+        # Ensure consistent format
+        if input_tensor.dim() == 4 and input_tensor.shape[0] == 1:
+            input_tensor = input_tensor[0]
+
+        # Check if we have previous input frame and pipeline output for temporal processing
+        if (self.prev_input_frame is not None and
+            self.pipeline_ref is not None and
            hasattr(self.pipeline_ref, 'prev_image_result') and
            self.pipeline_ref.prev_image_result is not None and
            not self._first_frame):
@@ -399,52 +410,50 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
             # Convert from VAE output format [-1, 1] to [0, 1]
             prev_output = (prev_output / 2.0 + 0.5).clamp(0, 1)

-            # Normalize input tensor
-            input_tensor = tensor
-            if input_tensor.max() > 1.0:
-                input_tensor = input_tensor / 255.0
-
-            # Ensure consistent format
+            # Ensure consistent format for previous output
             if prev_output.dim() == 4 and prev_output.shape[0] == 1:
                 prev_output = prev_output[0]
-            if input_tensor.dim() == 4 and input_tensor.shape[0] == 1:
-                input_tensor = input_tensor[0]

             try:
-                # Compute optical flow and warp on GPU using TensorRT
-                warped_tensor = self._compute_and_warp_tensor(input_tensor, prev_output)
+                # Compute optical flow between consecutive input frames
+                # then create flow visualization
+                flow_image_tensor = self._compute_flow_image_tensor(input_tensor, prev_output)

                 # Check output format
                 output_format = self.params.get('output_format', 'concat')
                 if output_format == "concat":
-                    # Concatenate current frame + warped frame for TemporalNet2 (6 channels)
-                    result_tensor = self._concatenate_frames_tensor(input_tensor, warped_tensor)
+                    # Concatenate previous output + flow visualization for TemporalNet2 (6 channels)
+                    result_tensor = self._concatenate_frames_tensor(prev_output, flow_image_tensor)
                 else:
-                    # Return only warped frame (3 channels)
-                    result_tensor = warped_tensor
+                    # Return only flow visualization (3 channels)
+                    result_tensor = flow_image_tensor

                 # Ensure correct output format
                 if result_tensor.dim() == 3:
                     result_tensor = result_tensor.unsqueeze(0)
                 result = result_tensor.to(device=self.device, dtype=self.dtype)
+
+                # Store current input for next iteration
+                self.prev_input_frame = input_tensor.detach().clone()
             except Exception as e:
                 logger.error(f"_process_tensor_core: TensorRT optical flow failed: {e}")
                 output_format = self.params.get('output_format', 'concat')
                 if output_format == "concat":
-                    # Create 6-channel fallback by concatenating current frame with itself
-                    result_tensor = self._concatenate_frames_tensor(input_tensor, input_tensor)
+                    # Create 6-channel fallback: previous output + black flow image
+                    black_flow = torch.zeros_like(prev_output)
+                    result_tensor = self._concatenate_frames_tensor(prev_output, black_flow)
                     if result_tensor.dim() == 3:
                         result_tensor = result_tensor.unsqueeze(0)
                     result = result_tensor.to(device=self.device, dtype=self.dtype)
                 else:
-                    # Create 6-channel fallback by concatenating current frame with itself
-                    result_tensor = self._concatenate_frames_tensor(input_tensor, input_tensor)
-                    if result_tensor.dim() == 3:
-                        result_tensor = result_tensor.unsqueeze(0)
-                    result = result_tensor.to(device=self.device, dtype=self.dtype)
+                    # Return black flow image as fallback
+                    black_flow = torch.zeros_like(input_tensor)
+                    if black_flow.dim() == 3:
+                        black_flow = black_flow.unsqueeze(0)
+                    result = black_flow.to(device=self.device, dtype=self.dtype)
         else:
-            # First frame or no previous output available
+            # First frame or no previous data available
             self._first_frame = False
             if tensor.dim() == 3:
                 tensor = tensor.unsqueeze(0)
@@ -452,46 +461,42 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
             # Handle 6-channel output for first frame
             output_format = self.params.get('output_format', 'concat')
             if output_format == "concat":
-                # For first frame, duplicate current frame to create 6-channel output
-                if tensor.dim() == 4 and tensor.shape[0] == 1:
-                    current_tensor = tensor[0]
-                else:
-                    current_tensor = tensor
-                result_tensor = self._concatenate_frames_tensor(current_tensor, current_tensor)
+                # For first frame: current frame + black flow image
+                black_flow = torch.zeros_like(input_tensor)
+                result_tensor = self._concatenate_frames_tensor(input_tensor, black_flow)
                 if result_tensor.dim() == 3:
                     result_tensor = result_tensor.unsqueeze(0)
                 result = result_tensor.to(device=self.device, dtype=self.dtype)
             else:
-                # Create 6-channel fallback by concatenating current frame with itself
-                if tensor.dim() == 4 and tensor.shape[0] == 1:
-                    current_tensor = tensor[0]
-                else:
-                    current_tensor = tensor
-                result_tensor = self._concatenate_frames_tensor(current_tensor, current_tensor)
-                if result_tensor.dim() == 3:
-                    result_tensor = result_tensor.unsqueeze(0)
-                result = result_tensor.to(device=self.device, dtype=self.dtype)
+                # Return black flow image for first frame
+                black_flow = torch.zeros_like(input_tensor)
+                if black_flow.dim() == 3:
+                    black_flow = black_flow.unsqueeze(0)
+                result = black_flow.to(device=self.device, dtype=self.dtype)
+
+            # Store current input for next iteration
+            self.prev_input_frame = input_tensor.detach().clone()

         return result

-    def _compute_and_warp_tensor(self, current_tensor: torch.Tensor, prev_tensor: torch.Tensor) -> torch.Tensor:
+    def _compute_flow_image_tensor(self, current_tensor: torch.Tensor, prev_output_tensor: torch.Tensor) -> torch.Tensor:
         """
-        Compute optical flow using TensorRT and warp previous tensor
+        Compute optical flow between consecutive input frames and convert to flow visualization

         Args:
             current_tensor: Current input frame tensor (CHW format, [0,1]) on GPU
-            prev_tensor: Previous pipeline output tensor (CHW format, [0,1]) on GPU
+            prev_output_tensor: Previous pipeline output tensor (CHW format, [0,1]) on GPU

         Returns:
-            Warped previous frame tensor on GPU
+            Flow visualization tensor (RGB image showing flow vectors) on GPU
         """
         target_width, target_height = self.get_target_dimensions()

         # Convert to float32 for TensorRT processing
         current_tensor = current_tensor.to(device=self.device, dtype=torch.float32)
-        prev_tensor = prev_tensor.to(device=self.device, dtype=torch.float32)
+        prev_input_tensor = self.prev_input_frame.to(device=self.device, dtype=torch.float32)

-        # Resize for flow computation if needed (keep on GPU)
+        # Resize input frames for flow computation if needed (keep on GPU)
         if current_tensor.shape[-1] != self.detect_resolution or current_tensor.shape[-2] != self.detect_resolution:
             current_resized = F.interpolate(
                 current_tensor.unsqueeze(0),
@@ -499,40 +504,44 @@ def _compute_and_warp_tensor(self, current_tensor: torch.Tensor, prev_tensor: to
                 mode='bilinear',
                 align_corners=False
             ).squeeze(0)
-            prev_resized = F.interpolate(
-                prev_tensor.unsqueeze(0),
+            prev_input_resized = F.interpolate(
+                prev_input_tensor.unsqueeze(0),
                 size=(self.detect_resolution, self.detect_resolution),
                 mode='bilinear',
                 align_corners=False
             ).squeeze(0)
         else:
             current_resized = current_tensor
-            prev_resized = prev_tensor
+            prev_input_resized = prev_input_tensor

-        # Compute optical flow using TensorRT
-        flow = self._compute_optical_flow_tensorrt(current_resized, prev_resized)
+        # Compute optical flow between consecutive input frames
+        flow = self._compute_optical_flow_tensorrt(prev_input_resized, current_resized)

         # Apply flow strength scaling (GPU operation)
         flow_strength = self.params.get('flow_strength', 1.0)
         if flow_strength != 1.0:
             flow = flow * flow_strength

-        # Warp previous frame using flow (GPU operation)
-        warped_frame = self._warp_frame_tensor(prev_resized, flow)
+        # Convert flow to RGB visualization using flow_to_image
+        flow_image = flow_to_image(flow.unsqueeze(0)).squeeze(0)  # flow_to_image expects batch dimension
+
+        # Convert from [0,255] uint8 to [0,1] float format
+        flow_image = flow_image.float() / 255.0

-        # Resize back to target resolution if needed (keep on GPU)
-        if warped_frame.shape[-1] != target_width or warped_frame.shape[-2] != target_height:
-            warped_frame = F.interpolate(
-                warped_frame.unsqueeze(0),
+        # Resize to target resolution if needed (keep on GPU)
+        if flow_image.shape[-1] != target_width or flow_image.shape[-2] != target_height:
+            flow_image = F.interpolate(
+                flow_image.unsqueeze(0),
                 size=(target_height, target_width),
                 mode='bilinear',
                 align_corners=False
             ).squeeze(0)

-        # Convert to processor's dtype only at the very end
-        result = warped_frame.to(dtype=self.dtype)
+        # Convert to processor's dtype
+        result = flow_image.to(dtype=self.dtype)

         return result

+
     def _compute_optical_flow_tensorrt(self, frame1: torch.Tensor, frame2: torch.Tensor) -> torch.Tensor:
         """
@@ -567,102 +576,47 @@ def _compute_optical_flow_tensorrt(self, frame1: torch.Tensor, frame2: torch.Ten
         cuda_stream = torch.cuda.current_stream().cuda_stream
         result = self.trt_engine.infer(feed_dict, cuda_stream)

-        flow = result['flow'][0]  # Remove batch dimension
+        flow = result['flow'][0]

         return flow

-    def _warp_frame_tensor(self, frame: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
-        """
-        Warp frame using optical flow with cached coordinate grids
-
-        Args:
-            frame: Frame to warp (CHW format)
-            flow: Optical flow (2HW format)
-
-        Returns:
-            Warped frame tensor
-        """
-        H, W = frame.shape[-2:]
-
-        # Use cached grid if available
-        grid_key = (H, W)
-        if grid_key not in self._grid_cache:
-            grid_y, grid_x = torch.meshgrid(
-                torch.arange(H, device=self.device, dtype=torch.float32),
-                torch.arange(W, device=self.device, dtype=torch.float32),
-                indexing='ij'
-            )
-            self._grid_cache[grid_key] = (grid_x, grid_y)
-        else:
-            grid_x, grid_y = self._grid_cache[grid_key]
-
-        # Apply flow to coordinates
-        new_x = grid_x + flow[0]
-        new_y = grid_y + flow[1]
-
-        # Normalize coordinates to [-1, 1] for grid_sample
-        new_x = 2.0 * new_x / (W - 1) - 1.0
-        new_y = 2.0 * new_y / (H - 1) - 1.0
-
-        # Create sampling grid (HW2 format for grid_sample)
-        grid = torch.stack([new_x, new_y], dim=-1).unsqueeze(0)
-
-        # Warp frame
-        warped_batch = F.grid_sample(
-            frame.unsqueeze(0),
-            grid,
-            mode='bilinear',
-            padding_mode='border',
-            align_corners=True
-        )
-
-        result = warped_batch.squeeze(0)
-
-        return result
-
-    def _concatenate_frames(self, current_image: Image.Image, warped_image: Image.Image) -> Image.Image:
-        """Concatenate current frame and warped previous frame for TemporalNet2 (6-channel input)"""
-        # Convert to tensors and use tensor concatenation for consistency
-        current_tensor = self.pil_to_tensor(current_image).squeeze(0)
-        warped_tensor = self.pil_to_tensor(warped_image).squeeze(0)
-        result_tensor = self._concatenate_frames_tensor(current_tensor, warped_tensor)
-        return self.tensor_to_pil(result_tensor)

-    def _concatenate_frames_tensor(self, current_tensor: torch.Tensor, warped_tensor: torch.Tensor) -> torch.Tensor:
+    def _concatenate_frames_tensor(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor) -> torch.Tensor:
         """
-        Concatenate current frame and warped previous frame tensors for TemporalNet2 (6-channel input)
+        Concatenate two frame tensors for TemporalNet2 (6-channel input)

         Args:
-            current_tensor: Current input frame tensor (CHW format)
-            warped_tensor: Warped previous frame tensor (CHW format)
+            first_tensor: First frame tensor (CHW format)
+            second_tensor: Second frame tensor (CHW format)

         Returns:
             Concatenated tensor (6CHW format)
         """
         # Ensure same size
-        if current_tensor.shape != warped_tensor.shape:
+        if first_tensor.shape != second_tensor.shape:
             target_width, target_height = self.get_target_dimensions()

-            if current_tensor.shape[-2:] != (target_height, target_width):
-                current_tensor = F.interpolate(
-                    current_tensor.unsqueeze(0),
+            if first_tensor.shape[-2:] != (target_height, target_width):
+                first_tensor = F.interpolate(
+                    first_tensor.unsqueeze(0),
                     size=(target_height, target_width),
                     mode='bilinear',
                     align_corners=False
                 ).squeeze(0)

-            if warped_tensor.shape[-2:] != (target_height, target_width):
-                warped_tensor = F.interpolate(
-                    warped_tensor.unsqueeze(0),
+            if second_tensor.shape[-2:] != (target_height, target_width):
+                second_tensor = F.interpolate(
+                    second_tensor.unsqueeze(0),
                     size=(target_height, target_width),
                     mode='bilinear',
                     align_corners=False
                 ).squeeze(0)

-        # Concatenate along channel dimension: [current_R, current_G, current_B, warped_R, warped_G, warped_B]
-        concatenated = torch.cat([current_tensor, warped_tensor], dim=0)
+        # Concatenate along channel dimension: [first_R, first_G, first_B, second_R, second_G, second_B]
+        concatenated = torch.cat([first_tensor, second_tensor], dim=0)

         return concatenated
@@ -671,6 +625,7 @@ def reset(self):
         Reset the preprocessor state (useful for new sequences)
         """
         self._first_frame = True
+        self.prev_input_frame = None
         # Clear caches to free memory
         self._grid_cache.clear()
         self._tensor_cache.clear()
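Note: the flow_to_image call introduced above is assumed to be torchvision.utils.flow_to_image (its import is outside the hunks shown). A minimal sketch of the visualization step under that assumption, with a random displacement field standing in for the TensorRT flow output:

    import torch
    from torchvision.utils import flow_to_image

    # Illustrative (2, H, W) flow field in pixels; the patch obtains this from the TensorRT engine
    flow = torch.randn(2, 256, 256) * 5.0
    flow_rgb = flow_to_image(flow.unsqueeze(0)).squeeze(0)  # uint8 RGB visualization in [0, 255]
    flow_rgb = flow_rgb.float() / 255.0                     # [0, 1] float, matching the control-image range used above
    print(flow_rgb.shape)                                   # torch.Size([3, 256, 256])

In the "concat" path this 3-channel flow visualization is stacked after the previous pipeline output to form the 6-channel TemporalNet2 control input.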