213 changes: 84 additions & 129 deletions src/streamdiffusion/preprocessing/processors/temporal_net_tensorrt.py
@@ -200,6 +200,7 @@ def __init__(self,
self.enable_tensorrt = enable_tensorrt and TENSORRT_AVAILABLE
self.force_rebuild = force_rebuild
self._first_frame = True
self.prev_input_frame = None # Store previous input frame for temporal flow computation

# Model paths
self.models_dir = Path("models") / "temporal_net"
@@ -370,7 +371,7 @@ def _process_core(self, image: Image.Image) -> Image.Image:
image: Current input image

Returns:
Warped previous frame for temporal guidance, or fallback for first frame
Processed frame for temporal guidance, or fallback for first frame
"""
# Convert to tensor and use tensor processing path for efficiency
tensor = self.pil_to_tensor(image)
@@ -385,11 +386,21 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
tensor: Current input tensor

Returns:
Warped previous frame tensor for temporal guidance
Processed frame tensor for temporal guidance
"""

# Check if we have a pipeline reference and previous output
if (self.pipeline_ref is not None and
# Normalize input tensor
input_tensor = tensor
if input_tensor.max() > 1.0:
input_tensor = input_tensor / 255.0

# Ensure consistent format
if input_tensor.dim() == 4 and input_tensor.shape[0] == 1:
input_tensor = input_tensor[0]

# Check if we have previous input frame and pipeline output for temporal processing
if (self.prev_input_frame is not None and
self.pipeline_ref is not None and
hasattr(self.pipeline_ref, 'prev_image_result') and
self.pipeline_ref.prev_image_result is not None and
not self._first_frame):
@@ -399,140 +410,138 @@ def _process_tensor_core(self, tensor: torch.Tensor) -> torch.Tensor:
# Convert from VAE output format [-1, 1] to [0, 1]
prev_output = (prev_output / 2.0 + 0.5).clamp(0, 1)

# Normalize input tensor
input_tensor = tensor
if input_tensor.max() > 1.0:
input_tensor = input_tensor / 255.0

# Ensure consistent format
# Ensure consistent format for previous output
if prev_output.dim() == 4 and prev_output.shape[0] == 1:
prev_output = prev_output[0]
if input_tensor.dim() == 4 and input_tensor.shape[0] == 1:
input_tensor = input_tensor[0]

try:
# Compute optical flow and warp on GPU using TensorRT
warped_tensor = self._compute_and_warp_tensor(input_tensor, prev_output)
# Compute optical flow between consecutive input frames
# then create flow visualization
flow_image_tensor = self._compute_flow_image_tensor(input_tensor, prev_output)

# Check output format
output_format = self.params.get('output_format', 'concat')
if output_format == "concat":
# Concatenate current frame + warped frame for TemporalNet2 (6 channels)
result_tensor = self._concatenate_frames_tensor(input_tensor, warped_tensor)
# Concatenate previous output + flow visualization for TemporalNet2 (6 channels)
result_tensor = self._concatenate_frames_tensor(prev_output, flow_image_tensor)
else:
# Return only warped frame (3 channels)
result_tensor = warped_tensor
# Return only flow visualization (3 channels)
result_tensor = flow_image_tensor

# Ensure correct output format
if result_tensor.dim() == 3:
result_tensor = result_tensor.unsqueeze(0)

result = result_tensor.to(device=self.device, dtype=self.dtype)

# Store current input for next iteration
self.prev_input_frame = input_tensor.detach().clone()
except Exception as e:
logger.error(f"_process_tensor_core: TensorRT optical flow failed: {e}")
output_format = self.params.get('output_format', 'concat')
if output_format == "concat":
# Create 6-channel fallback by concatenating current frame with itself
result_tensor = self._concatenate_frames_tensor(input_tensor, input_tensor)
# Create 6-channel fallback: previous output + black flow image
black_flow = torch.zeros_like(prev_output)
result_tensor = self._concatenate_frames_tensor(prev_output, black_flow)
if result_tensor.dim() == 3:
result_tensor = result_tensor.unsqueeze(0)
result = result_tensor.to(device=self.device, dtype=self.dtype)
else:
# Create 6-channel fallback by concatenating current frame with itself
result_tensor = self._concatenate_frames_tensor(input_tensor, input_tensor)
if result_tensor.dim() == 3:
result_tensor = result_tensor.unsqueeze(0)
result = result_tensor.to(device=self.device, dtype=self.dtype)
# Return black flow image as fallback
black_flow = torch.zeros_like(input_tensor)
if black_flow.dim() == 3:
black_flow = black_flow.unsqueeze(0)
result = black_flow.to(device=self.device, dtype=self.dtype)
else:
# First frame or no previous output available
# First frame or no previous data available
self._first_frame = False
if tensor.dim() == 3:
tensor = tensor.unsqueeze(0)

# Handle 6-channel output for first frame
output_format = self.params.get('output_format', 'concat')
if output_format == "concat":
# For first frame, duplicate current frame to create 6-channel output
if tensor.dim() == 4 and tensor.shape[0] == 1:
current_tensor = tensor[0]
else:
current_tensor = tensor
result_tensor = self._concatenate_frames_tensor(current_tensor, current_tensor)
# For first frame: current frame + black flow image
black_flow = torch.zeros_like(input_tensor)
result_tensor = self._concatenate_frames_tensor(input_tensor, black_flow)
if result_tensor.dim() == 3:
result_tensor = result_tensor.unsqueeze(0)
result = result_tensor.to(device=self.device, dtype=self.dtype)
else:
# Create 6-channel fallback by concatenating current frame with itself
if tensor.dim() == 4 and tensor.shape[0] == 1:
current_tensor = tensor[0]
else:
current_tensor = tensor
result_tensor = self._concatenate_frames_tensor(current_tensor, current_tensor)
if result_tensor.dim() == 3:
result_tensor = result_tensor.unsqueeze(0)
result = result_tensor.to(device=self.device, dtype=self.dtype)
# Return black flow image for first frame
black_flow = torch.zeros_like(input_tensor)
if black_flow.dim() == 3:
black_flow = black_flow.unsqueeze(0)
result = black_flow.to(device=self.device, dtype=self.dtype)

# Store current input for next iteration
self.prev_input_frame = input_tensor.detach().clone()

return result

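As a reading aid, here is a minimal standalone sketch of the control-input assembly this hunk implements, assuming 3xHxW float tensors in [0, 1] and the default "concat" output format (illustrative only, not part of the diff):

# Sketch: building the 6-channel TemporalNet2 control input (not part of this PR).
# Assumes 3xHxW float tensors in [0, 1].
import torch

def assemble_control(current, prev_output=None, flow_rgb=None):
    if prev_output is None or flow_rgb is None:
        # First frame (or missing history): pair the current frame with a
        # black flow image so the control input is always 6 channels.
        return torch.cat([current, torch.zeros_like(current)], dim=0)
    # Later frames: previous pipeline output + flow visualization.
    return torch.cat([prev_output, flow_rgb], dim=0)

frame = torch.rand(3, 512, 512)
first = assemble_control(frame)                                   # (6, 512, 512)
later = assemble_control(frame, frame, torch.zeros_like(frame))   # (6, 512, 512)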
def _compute_and_warp_tensor(self, current_tensor: torch.Tensor, prev_tensor: torch.Tensor) -> torch.Tensor:
def _compute_flow_image_tensor(self, current_tensor: torch.Tensor, prev_output_tensor: torch.Tensor) -> torch.Tensor:
"""
Compute optical flow using TensorRT and warp previous tensor
Compute optical flow between consecutive input frames and convert to flow visualization

Args:
current_tensor: Current input frame tensor (CHW format, [0,1]) on GPU
prev_tensor: Previous pipeline output tensor (CHW format, [0,1]) on GPU
prev_output_tensor: Previous pipeline output tensor (CHW format, [0,1]) on GPU

Returns:
Warped previous frame tensor on GPU
Flow visualization tensor (RGB image showing flow vectors) on GPU
"""
target_width, target_height = self.get_target_dimensions()

# Convert to float32 for TensorRT processing
current_tensor = current_tensor.to(device=self.device, dtype=torch.float32)
prev_tensor = prev_tensor.to(device=self.device, dtype=torch.float32)
prev_input_tensor = self.prev_input_frame.to(device=self.device, dtype=torch.float32)

# Resize for flow computation if needed (keep on GPU)
# Resize input frames for flow computation if needed (keep on GPU)
if current_tensor.shape[-1] != self.detect_resolution or current_tensor.shape[-2] != self.detect_resolution:
current_resized = F.interpolate(
current_tensor.unsqueeze(0),
size=(self.detect_resolution, self.detect_resolution),
mode='bilinear',
align_corners=False
).squeeze(0)
prev_resized = F.interpolate(
prev_tensor.unsqueeze(0),
prev_input_resized = F.interpolate(
prev_input_tensor.unsqueeze(0),
size=(self.detect_resolution, self.detect_resolution),
mode='bilinear',
align_corners=False
).squeeze(0)
else:
current_resized = current_tensor
prev_resized = prev_tensor
prev_input_resized = prev_input_tensor

# Compute optical flow using TensorRT
flow = self._compute_optical_flow_tensorrt(current_resized, prev_resized)
# Compute optical flow between consecutive input frames
flow = self._compute_optical_flow_tensorrt(prev_input_resized, current_resized)

# Apply flow strength scaling (GPU operation)
flow_strength = self.params.get('flow_strength', 1.0)
if flow_strength != 1.0:
flow = flow * flow_strength

# Warp previous frame using flow (GPU operation)
warped_frame = self._warp_frame_tensor(prev_resized, flow)
# Convert flow to RGB visualization using flow_to_image
flow_image = flow_to_image(flow.unsqueeze(0)).squeeze(0) # flow_to_image expects batch dimension

# Convert from [0,255] uint8 to [0,1] float format
flow_image = flow_image.float() / 255.0

# Resize back to target resolution if needed (keep on GPU)
if warped_frame.shape[-1] != target_width or warped_frame.shape[-2] != target_height:
warped_frame = F.interpolate(
warped_frame.unsqueeze(0),
# Resize to target resolution if needed (keep on GPU)
if flow_image.shape[-1] != target_width or flow_image.shape[-2] != target_height:
flow_image = F.interpolate(
flow_image.unsqueeze(0),
size=(target_height, target_width),
mode='bilinear',
align_corners=False
).squeeze(0)

# Convert to processor's dtype only at the very end
result = warped_frame.to(dtype=self.dtype)
# Convert to processor's dtype
result = flow_image.to(dtype=self.dtype)

return result

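The flow visualization above comes from torchvision's flow_to_image, which maps each per-pixel flow vector to a color. A minimal sketch of that conversion, assuming torchvision 0.12 or newer (illustrative only, not part of the diff):

# Sketch: converting a 2xHxW optical flow field to an RGB image in [0, 1]
# (not part of this PR; assumes torchvision >= 0.12).
import torch
from torchvision.utils import flow_to_image

flow = torch.randn(2, 512, 512)          # per-pixel (dx, dy) displacements
rgb = flow_to_image(flow.unsqueeze(0))   # (1, 3, 512, 512) uint8 visualization
rgb = rgb.squeeze(0).float() / 255.0     # (3, 512, 512) float in [0, 1]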

def _compute_optical_flow_tensorrt(self, frame1: torch.Tensor, frame2: torch.Tensor) -> torch.Tensor:
"""
Expand Down Expand Up @@ -567,102 +576,47 @@ def _compute_optical_flow_tensorrt(self, frame1: torch.Tensor, frame2: torch.Ten

cuda_stream = torch.cuda.current_stream().cuda_stream
result = self.trt_engine.infer(feed_dict, cuda_stream)
flow = result['flow'][0] # Remove batch dimension
flow = result['flow'][0]

return flow



def _warp_frame_tensor(self, frame: torch.Tensor, flow: torch.Tensor) -> torch.Tensor:
"""
Warp frame using optical flow with cached coordinate grids

Args:
frame: Frame to warp (CHW format)
flow: Optical flow (2HW format)

Returns:
Warped frame tensor
"""
H, W = frame.shape[-2:]

# Use cached grid if available
grid_key = (H, W)
if grid_key not in self._grid_cache:
grid_y, grid_x = torch.meshgrid(
torch.arange(H, device=self.device, dtype=torch.float32),
torch.arange(W, device=self.device, dtype=torch.float32),
indexing='ij'
)
self._grid_cache[grid_key] = (grid_x, grid_y)
else:
grid_x, grid_y = self._grid_cache[grid_key]

# Apply flow to coordinates
new_x = grid_x + flow[0]
new_y = grid_y + flow[1]

# Normalize coordinates to [-1, 1] for grid_sample
new_x = 2.0 * new_x / (W - 1) - 1.0
new_y = 2.0 * new_y / (H - 1) - 1.0

# Create sampling grid (HW2 format for grid_sample)
grid = torch.stack([new_x, new_y], dim=-1).unsqueeze(0)

# Warp frame
warped_batch = F.grid_sample(
frame.unsqueeze(0),
grid,
mode='bilinear',
padding_mode='border',
align_corners=True
)

result = warped_batch.squeeze(0)

return result

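The _warp_frame_tensor shown above normalizes flow-displaced pixel coordinates to [-1, 1] via 2*x/(W-1) - 1 before calling grid_sample with align_corners=True. A quick sanity check of that convention, assuming zero flow should reproduce the input frame (illustrative only, not part of the diff):

# Sketch: with zero flow, grid_sample warping is the identity (up to
# interpolation error). Not part of this PR.
import torch
import torch.nn.functional as F

H, W = 64, 64
frame = torch.rand(3, H, W)
flow = torch.zeros(2, H, W)

grid_y, grid_x = torch.meshgrid(
    torch.arange(H, dtype=torch.float32),
    torch.arange(W, dtype=torch.float32),
    indexing='ij')
new_x = 2.0 * (grid_x + flow[0]) / (W - 1) - 1.0
new_y = 2.0 * (grid_y + flow[1]) / (H - 1) - 1.0
grid = torch.stack([new_x, new_y], dim=-1).unsqueeze(0)  # (1, H, W, 2)

warped = F.grid_sample(frame.unsqueeze(0), grid, mode='bilinear',
                       padding_mode='border', align_corners=True).squeeze(0)
assert torch.allclose(warped, frame, atol=1e-5)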
def _concatenate_frames(self, current_image: Image.Image, warped_image: Image.Image) -> Image.Image:
"""Concatenate current frame and warped previous frame for TemporalNet2 (6-channel input)"""
# Convert to tensors and use tensor concatenation for consistency
current_tensor = self.pil_to_tensor(current_image).squeeze(0)
warped_tensor = self.pil_to_tensor(warped_image).squeeze(0)
result_tensor = self._concatenate_frames_tensor(current_tensor, warped_tensor)
return self.tensor_to_pil(result_tensor)

def _concatenate_frames_tensor(self, current_tensor: torch.Tensor, warped_tensor: torch.Tensor) -> torch.Tensor:
def _concatenate_frames_tensor(self, first_tensor: torch.Tensor, second_tensor: torch.Tensor) -> torch.Tensor:
"""
Concatenate current frame and warped previous frame tensors for TemporalNet2 (6-channel input)
Concatenate two frame tensors for TemporalNet2 (6-channel input)

Args:
current_tensor: Current input frame tensor (CHW format)
warped_tensor: Warped previous frame tensor (CHW format)
first_tensor: First frame tensor (CHW format)
second_tensor: Second frame tensor (CHW format)

Returns:
Concatenated tensor (6CHW format)
"""
# Ensure same size
if current_tensor.shape != warped_tensor.shape:
if first_tensor.shape != second_tensor.shape:
target_width, target_height = self.get_target_dimensions()

if current_tensor.shape[-2:] != (target_height, target_width):
current_tensor = F.interpolate(
current_tensor.unsqueeze(0),
if first_tensor.shape[-2:] != (target_height, target_width):
first_tensor = F.interpolate(
first_tensor.unsqueeze(0),
size=(target_height, target_width),
mode='bilinear',
align_corners=False
).squeeze(0)

if warped_tensor.shape[-2:] != (target_height, target_width):
warped_tensor = F.interpolate(
warped_tensor.unsqueeze(0),
if second_tensor.shape[-2:] != (target_height, target_width):
second_tensor = F.interpolate(
second_tensor.unsqueeze(0),
size=(target_height, target_width),
mode='bilinear',
align_corners=False
).squeeze(0)

# Concatenate along channel dimension: [current_R, current_G, current_B, warped_R, warped_G, warped_B]
concatenated = torch.cat([current_tensor, warped_tensor], dim=0)
# Concatenate along channel dimension: [first_R, first_G, first_B, second_R, second_G, second_B]
concatenated = torch.cat([first_tensor, second_tensor], dim=0)

return concatenated

@@ -671,6 +625,7 @@ def reset(self):
Reset the preprocessor state (useful for new sequences)
"""
self._first_frame = True
self.prev_input_frame = None
# Clear caches to free memory
self._grid_cache.clear()
self._tensor_cache.clear()