In [None]:
import numpy as np
import torch
from lstm_model_arch import TennisPointLSTM
import scipy.ndimage
from typing import Optional

"""
My test.py file is my current evaluation file. however, it just looks at sequences, not the whole video. 
I need to further establish my post processing pipeline. The final output of my pipeline should be 
a csv of start_time, end_times, which i can compare to the annotated targets. 
For now, we will use the same gaussian smoothing and hysteresis filtering that we're using in the test.py file. 

Your task is to write a new file that runs the inference on an entire video's sequence file
"""

GAUSSIAN_SIGMA = 1.5  # for smoothing


In [2]:
def load_model_from_checkpoint(
    checkpoint_path: str,
    input_size: int = 360,
    hidden_size: int = 128,
    num_layers: int = 2,
    bidirectional: bool = True,
    return_logits: bool = False,
):
    """Load model weights from checkpoint, adapting architecture if needed."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    ckpt = torch.load(checkpoint_path, map_location=device)

    # Extract model state dict
    if isinstance(ckpt, dict) and 'model_state_dict' in ckpt:
        state_dict = ckpt['model_state_dict']
    elif isinstance(ckpt, dict) and any(k.startswith('lstm.') or k.startswith('fc.') for k in ckpt.keys()):
        state_dict = ckpt
    else:
        # Fallback: attempt to use as state_dict
        state_dict = ckpt

    # Infer architecture from weights if possible
    inferred_input_size = input_size
    inferred_hidden_size = hidden_size
    inferred_num_layers = num_layers
    inferred_bidirectional = bidirectional

    try:
        # weight_ih_l0 shape: (4*hidden_size, input_size)
        w_ih_l0 = state_dict.get('lstm.weight_ih_l0', None)
        if w_ih_l0 is not None:
            inferred_hidden_size = w_ih_l0.shape[0] // 4
            inferred_input_size = w_ih_l0.shape[1]

        # Determine num_layers by counting layers
        layer_indices = set()
        for k in state_dict.keys():
            if k.startswith('lstm.weight_ih_l'):
                try:
                    idx_str = k.split('lstm.weight_ih_l')[1]
                    idx = int(idx_str.split('_')[0]) if '_' in idx_str else int(idx_str)
                    layer_indices.add(idx)
                except Exception:
                    pass
        if layer_indices:
            inferred_num_layers = max(layer_indices) + 1

        # Bidirectionality: presence of any reverse weights
        inferred_bidirectional = any('_reverse' in k for k in state_dict.keys())
    except Exception:
        pass

    # Build model with inferred architecture
    model = TennisPointLSTM(
        input_size=inferred_input_size,
        hidden_size=inferred_hidden_size,
        num_layers=inferred_num_layers,
        dropout=0.2,
        bidirectional=inferred_bidirectional,
        return_logits=return_logits,
    )

    # Load strictly now that shapes should match
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()
    
    print(
        f"Loaded checkpoint: {checkpoint_path} "
        f"(input_size={inferred_input_size}, hidden_size={inferred_hidden_size}, "
        f"num_layers={inferred_num_layers}, bidirectional={inferred_bidirectional})"
    )
    return model, device


In [14]:
model_path = 'checkpoints/seq_len150/best_model.pth'
model, device = load_model_from_checkpoint(model_path, bidirectional=True, return_logits=False)



# load whole feature npz file for a specific video
video_feature_path = 'pose_data/features/yolos_0.25conf_15fps_0s_to_99999s/Aditi Narayan ï½œ Matchplay_features.npz'
data = np.load(video_feature_path)
targets = data['targets']
# create our ordered list of sequences with 50% overlap: must carefully track frame numbers

num_frames = len(data['features'])
sequence_length = 150 
overlap = 75
if num_frames < sequence_length:
    raise ValueError("input video too short")


Loaded checkpoint: checkpoints/seq_len150/best_model.pth (input_size=360, hidden_size=128, num_layers=2, bidirectional=True)


In [15]:
# Generate sequence start indices with 50% overlap (150 frame steps)
# Ensure we cover all frames without gaps
start_idxs = []
idx = 0
while idx + sequence_length <= num_frames:
    start_idxs.append(idx)
    idx += overlap

# If the last sequence doesn't reach the end, add one more sequence
if start_idxs[-1] + sequence_length < num_frames:
    start_idxs.append(num_frames - sequence_length)

print(f"Generated {len(start_idxs)} sequences for {num_frames} frames")
print(f"Coverage: {start_idxs[0]} to {start_idxs[-1] + sequence_length}")
print(f"Start indices: {start_idxs[:5]}...{start_idxs[-5:] if len(start_idxs) > 5 else start_idxs}")

# Check for gaps
for i in range(len(start_idxs) - 1):
    gap = start_idxs[i+1] - (start_idxs[i] + sequence_length)
    if gap > 0:
        print(f"WARNING: Gap of {gap} frames between sequences {i} and {i+1}")
    elif gap < -overlap:
        print(f"WARNING: Excessive overlap of {-gap} frames between sequences {i} and {i+1}")

ordered_sequences = []
output_arr = np.full((3, num_frames), np.nan)


Generated 232 sequences for 17439 frames
Coverage: 0 to 17439
Start indices: [0, 75, 150, 225, 300]...[17025, 17100, 17175, 17250, 17289]


In [16]:
# now we construct the feature lists, perform inference, and fill output array, tracking start indexes
print(f"Running inference on {len(start_idxs)} sequences...")
for seq_idx, i in enumerate(start_idxs):
    # slice features and convert to tensor of shape (1, sequence_length, input_size)
    seq_np = data['features'][i:i+sequence_length, :].astype(np.float32)
    seq_tensor = torch.from_numpy(seq_np).unsqueeze(0).to(device)
    with torch.no_grad():
        output_tensor = model(seq_tensor)  # (1, seq_len, 1)
    output_sequence = output_tensor.squeeze().detach().cpu().numpy()  # (seq_len,)
    
    # Find which row to place this sequence in
    placed = False
    for row in range(3):
        if np.isnan(output_arr[row, i:i+sequence_length]).all():
            output_arr[row, i:i+sequence_length] = output_sequence
            if seq_idx < 5 or seq_idx >= len(start_idxs) - 5:  # Debug first/last few
                print(f"  Seq {seq_idx}: frames {i}-{i+sequence_length-1} -> row {row}")
            placed = True
            break
    
    if not placed:
        print(f"ERROR: Could not place sequence {seq_idx} (frames {i}-{i+sequence_length-1})")
        raise ValueError('res arr filling logic messed up')

# now we have filled res_arr. next, get 1, num_frames array by averaging over 0th axis, and apply gaussian smoothing
print("Checking output_arr coverage...")
for row in range(3):
    nan_count = np.isnan(output_arr[row, :]).sum()
    print(f"  Row {row}: {nan_count}/{num_frames} NaNs ({100*nan_count/num_frames:.1f}%)")

avg_probs = np.nanmean(output_arr, axis=0)
nan_count_avg = np.isnan(avg_probs).sum()
print(f"avg_probs: {nan_count_avg}/{num_frames} NaNs ({100*nan_count_avg/num_frames:.1f}%)")

if nan_count_avg > 0:
    # Find NaN ranges
    nan_mask = np.isnan(avg_probs)
    nan_starts = np.where(np.diff(np.concatenate(([False], nan_mask))))[0]
    nan_ends = np.where(np.diff(np.concatenate((nan_mask, [False]))))[0]
    print("NaN ranges:")
    for start, end in zip(nan_starts, nan_ends):
        print(f"  frames {start}-{end-1} ({end-start} frames)")

Running inference on 232 sequences...
  Seq 0: frames 0-149 -> row 0
  Seq 1: frames 75-224 -> row 1
  Seq 2: frames 150-299 -> row 0
  Seq 3: frames 225-374 -> row 1
  Seq 4: frames 300-449 -> row 0
  Seq 227: frames 17025-17174 -> row 1
  Seq 228: frames 17100-17249 -> row 0
  Seq 229: frames 17175-17324 -> row 1
  Seq 230: frames 17250-17399 -> row 0
  Seq 231: frames 17289-17438 -> row 2
Checking output_arr coverage...
  Row 0: 39/17439 NaNs (0.2%)
  Row 1: 189/17439 NaNs (1.1%)
  Row 2: 17289/17439 NaNs (99.1%)
avg_probs: 0/17439 NaNs (0.0%)


In [17]:
simple = (avg_probs >= 0.5).astype(int)
np.sum(simple == targets)/len(avg_probs)

np.float64(0.7253856299099719)

In [18]:
np.unique(targets)

array([0, 1])

In [19]:
smoothed_probs = scipy.ndimage.gaussian_filter1d(avg_probs.astype(np.float32), sigma=GAUSSIAN_SIGMA)
_ = smoothed_probs  # silence variable display in notebooks

# perform hysteresis filtering on smoothed sequence

# use hysteresis for start/end times, write to csv

# %%%
smoothed_probs.shape
# %%

print("avg_probs head:", avg_probs[:30])
print("smoothed_probs head:", smoothed_probs[:30])
print(
    "smoothed stats:",
    "min=", float(np.nanmin(smoothed_probs)),
    "max=", float(np.nanmax(smoothed_probs)),
    "nans=", int(np.isnan(smoothed_probs).sum()),
)
# %%
# now run hysteresis filtering to get actual discrete in vs out of point vals? 
# or something to get discrete in vs out of points


# %%



def hysteresis_threshold(
    values: np.ndarray,
    low: float = 0.3,
    high: float = 0.7,
    min_duration: int = 0,
) -> np.ndarray:
    """Apply 1D hysteresis thresholding to a probability-like signal.

    - Enter active state when values >= high
    - Exit active state when values < low
    - Optional min_duration suppresses short active segments
    Returns a 0/1 array of the same length.
    """
    assert 0.0 <= low < high <= 1.0, "Require 0 <= low < high <= 1"
    n = len(values)
    pred = np.zeros(n, dtype=np.int8)
    active = False
    start_idx: Optional[int] = None

    for i in range(n):
        v = values[i]
        if not active:
            if v >= high:
                active = True
                start_idx = i
        else:
            if v < low:
                end_idx = i
                if start_idx is not None and (end_idx - start_idx) >= max(0, min_duration):
                    pred[start_idx:end_idx] = 1
                active = False
                start_idx = None

    # Handle active segment reaching the end
    if active and start_idx is not None:
        end_idx = n
        if (end_idx - start_idx) >= max(0, min_duration):
            pred[start_idx:end_idx] = 1

    return pred.astype(np.int32)

# %%
for high_thresh in range(50, 90, 5):
    for low_thresh in range(10, 50, 5):
        # now we compare output!
        HIGH_THRESHOLD = high_thresh/100  # for starting a point
        LOW_THRESHOLD =   low_thresh / 100 # for ending a point


        filtered_sequence = hysteresis_threshold(smoothed_probs, LOW_THRESHOLD, HIGH_THRESHOLD, min_duration=6)
        accuracy = np.sum(filtered_sequence == targets) / num_frames
        print(f"Accuracy: {accuracy:.3f}, High: {HIGH_THRESHOLD}, Low: {LOW_THRESHOLD} ")


# %%
HIGH_THRESHOLD = 60  # for starting a point
LOW_THRESHOLD =   40 # for ending a point


filtered_sequence = hysteresis_threshold(smoothed_probs, LOW_THRESHOLD, HIGH_THRESHOLD, min_duration=6)
accuracy = np.sum(filtered_sequence == targets) / num_frames
print(f"Accuracy: {accuracy:.3f}, High: {HIGH_THRESHOLD}, Low: {LOW_THRESHOLD} ")

np.unique(filtered_sequence)

# %% next, based on point vs not in point, we create final start_time,end_time times


avg_probs head: [0.25714421 0.24497803 0.24257775 0.23969215 0.23660417 0.23375887
 0.23123586 0.22903688 0.22754726 0.22592807 0.22493242 0.22408922
 0.2234481  0.22302155 0.2226958  0.22251448 0.22248556 0.22246866
 0.22262025 0.22260508 0.22263879 0.22263867 0.22269064 0.22282159
 0.22285062 0.22299479 0.22307621 0.22311153 0.22327054 0.22334903]
smoothed_probs head: [0.25013953 0.24738353 0.24353737 0.23994568 0.23684816 0.23409574
 0.23164853 0.2295414  0.22778259 0.22634749 0.22520576 0.22432293
 0.22365868 0.22317642 0.22284788 0.22265036 0.22255994 0.22254567
 0.22257072 0.2226044  0.22263642 0.22267479 0.22272979 0.22280224
 0.22288595 0.2229744  0.22306432 0.22315988 0.2232749  0.22343719]
smoothed stats: min= 0.002126772655174136 max= 0.8780487775802612 nans= 0
Accuracy: 0.706, High: 0.5, Low: 0.1 
Accuracy: 0.717, High: 0.5, Low: 0.15 
Accuracy: 0.718, High: 0.5, Low: 0.2 
Accuracy: 0.718, High: 0.5, Low: 0.25 
Accuracy: 0.719, High: 0.5, Low: 0.3 
Accuracy: 0.724, High: 0.

AssertionError: Require 0 <= low < high <= 1