In [6]:
from datasets.smpldata import SmplData
from inference_hoi_model import HoiResult
from object_contact_prediction.cpdm_dno_conds import find_contiguous_static_blocks
import torch
import torch.nn.functional as F
from typing import Optional

def resample(smpl_data: SmplData, n_frames: int) -> SmplData:
    """Resample every tensor field of ``smpl_data`` along axis 0 to ``n_frames``.

    Floating-point tensors are linearly interpolated with ``align_corners=True``,
    so the first and last frames are kept exactly. Non-floating tensors
    (bool / integer), which ``F.interpolate`` cannot process in linear mode,
    are resampled by taking the nearest source frame on the same sample grid.
    0-dim tensors, tensors already of length ``n_frames`` and non-tensor fields
    are passed through unchanged.

    NOTE(review): every tensor field is treated as time-major (axis 0 = frame).
    A per-sequence field stored as a tensor would be resampled incorrectly —
    confirm against SmplData.

    Args:
        smpl_data: sequence to resample.
        n_frames: number of frames in the output.

    Returns:
        A new SmplData constructed from the resampled fields.
    """
    def _resample_tensor(x: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
        if x is None:
            return None
        if x.dim() == 0:
            # A scalar has no frame axis to resample.
            return x
        T = x.shape[0]
        if T == n_frames:
            return x
        if not x.is_floating_point():
            # Linear interpolation is not defined for bool/int dtypes; pick the
            # nearest source frame on the align_corners=True sample grid instead.
            src_idx = torch.linspace(0, T - 1, n_frames, device=x.device).round().long()
            return x[src_idx]
        rest = x.shape[1:]
        # Flatten non-time dims into channels: (T, *rest) -> (1, C, T).
        x_flat = x.reshape(T, -1).transpose(0, 1).unsqueeze(0)
        y_flat = F.interpolate(x_flat, size=n_frames, mode='linear', align_corners=True)
        # (1, C, n_frames) -> (n_frames, *rest)
        return y_flat.squeeze(0).transpose(0, 1).reshape(n_frames, *rest)

    data_dict = smpl_data.to_dict()
    resampled = {
        key: _resample_tensor(val) if isinstance(val, torch.Tensor) else val
        for key, val in data_dict.items()
    }
    return SmplData(**resampled)

def remove_short_false_segments(arr: torch.Tensor, min_length: int) -> torch.Tensor:
    """Return a copy of the 1-D mask ``arr`` with short runs of False set to True.

    Every maximal run of False values whose length is strictly less than
    ``min_length`` is filled with True. This includes runs touching either end
    of the sequence. The input tensor is never modified.

    Args:
        arr: 1-D tensor, interpreted as a boolean mask (``arr.bool()``).
        min_length: runs of False shorter than this are filled.

    Returns:
        A new bool tensor, on the same device as ``arr`` and with the same length.
    """
    # clone(): Tensor.bool() returns ``arr`` itself when it is already bool, and
    # writing into it would silently modify the caller's tensor.
    out = arr.bool().clone()
    pad = torch.zeros(1, dtype=torch.bool, device=out.device)
    # False segments of ``out`` are the True runs of ``~out``. With False
    # padding on both sides, each such run produces exactly one
    # (start, exclusive end) transition pair.
    padded = torch.cat([pad, ~out, pad])
    boundaries = torch.nonzero(padded[1:] != padded[:-1]).flatten().tolist()
    for start, end in zip(boundaries[::2], boundaries[1::2]):
        if end - start < min_length:
            out[start:end] = True
    return out

def get_longest_contact_range(smpldata: SmplData, *, contact_threshold: float = 0.5,
                              min_gap_length: int = 2):
    """Return the ``(start, end)`` bounds of the longest block found in ``smpldata``.

    A frame is marked as having contact when any entry of ``smpldata.contact``
    along axis 1 exceeds ``contact_threshold``. That per-frame mask is passed
    through ``remove_short_false_segments(min_length=min_gap_length)``. The
    negated mask is then handed to ``find_contiguous_static_blocks``. The block
    with the largest ``end - start`` is returned; on ties, the last such block
    wins.

    NOTE(review): whether the returned blocks are contact or non-contact spans,
    and whether ``end`` is inclusive, follows ``find_contiguous_static_blocks``'s
    convention — confirm against that helper.

    Args:
        smpldata: sequence whose ``contact`` tensor is inspected.
        contact_threshold: value above which a contact entry counts as active.
        min_gap_length: forwarded as ``min_length`` to ``remove_short_false_segments``.

    Raises:
        ValueError: if ``find_contiguous_static_blocks`` returns no block.
    """
    # contact is presumably (seq, n_anchors); reducing over axis 1 yields (seq,).
    has_contact = (smpldata.contact > contact_threshold).any(dim=1)
    has_contact = remove_short_false_segments(has_contact, min_length=min_gap_length)
    contact_blocks = list(find_contiguous_static_blocks(~has_contact))
    if not contact_blocks:
        raise ValueError("no contact block found in smpldata")
    # max() keeps the first maximum; scanning in reverse therefore returns the
    # last longest block, which is what a stable sort followed by [-1] gives.
    start, end = max(reversed(contact_blocks), key=lambda rng: rng[1] - rng[0])
    return start, end

result_path = "/home/dcor/roeyron/trumans_utils/results/Results_May20/1b_inference_only_0/cphoi__cphoi_05011024_c15p100_v0__model000120000__0014__s10_bowl_pass_1__bowl__The_person_is_passing_a_bowl__phase0.pickle"

# Number of frames in the clip produced by either option.
N_CLIP_FRAMES = 30

# Option 1 - with resampling: cut the longest contact range, then resample it
# to exactly N_CLIP_FRAMES frames.
result = HoiResult.load(result_path)
smpldata = result.smpldata
start, end = get_longest_contact_range(smpldata)
smpldata = smpldata.cut(start, end)
smpldata = resample(smpldata, N_CLIP_FRAMES)

# Option 2 - without resampling: cut an N_CLIP_FRAMES window centred on the
# middle of the longest contact range (the result is reloaded from disk).
result = HoiResult.load(result_path)
smpldata = result.smpldata
start, end = get_longest_contact_range(smpldata)
mid = (start + end) // 2
# Clamp the window into the sequence. Near the edges, mid - 15 could go
# negative (which cut() may treat as an offset from the end if it slices) or
# mid + 15 could run past the last frame.
# NOTE(review): assumes axis 0 of contact is the frame axis, as
# get_longest_contact_range does.
n_total = smpldata.contact.shape[0]
win_start = max(0, min(mid - N_CLIP_FRAMES // 2, n_total - N_CLIP_FRAMES))
smpldata = smpldata.cut(win_start, win_start + N_CLIP_FRAMES)