In [1]:
# Directory holding the clustering artefacts for this run; every pickle below is read/written here.
frame_to_cluster_mapping_path = (
    "cluster_analysis/nov/18/"
    "23_18_56_10000_0.6_0.1_0.1_0.1_0.1/"
)

In [2]:
import torch
import numpy as np
import pickle


# Extract Euler angles (degrees) from a batch of rotation matrices.
def get_euler_angles_from_rotation_matrix(rot_matrix):
    """Recover (pitch, yaw, roll) in degrees from a batch of rotation matrices.

    This is the inverse of ``get_rotation_matrix``. The input is expected in that
    function's transposed layout, i.e. ``rot_matrix[b] == (Rz @ Ry @ Rx).T``.

    Args:
        rot_matrix: tensor of shape (B, 3, 3).

    Returns:
        Tuple ``(pitch, yaw, roll)`` of (B,) tensors in degrees. Pitch is the
        rotation about x, yaw about y and roll about z. In gimbal lock
        (|yaw| == 90 deg) roll is fixed to 0 and the whole x/z rotation is
        reported as pitch.
    """
    # Undo the transpose applied by get_rotation_matrix, so m == Rz @ Ry @ Rx.
    m = rot_matrix.permute(0, 2, 1)

    # m[2, 0] == -sin(yaw). Clamp so float round-off (|value| slightly > 1)
    # cannot turn arcsin into NaN.
    yaw = torch.arcsin(torch.clamp(-m[:, 2, 0], -1.0, 1.0))

    # Gimbal lock when cos(yaw) ~ 0: pitch and roll are no longer separable.
    # (Same test as isclose(cos(yaw), 0, atol=1e-6), without a CPU-side tensor.)
    gimbal_lock = torch.abs(torch.cos(yaw)) <= 1e-6

    pitch = torch.where(
        ~gimbal_lock,
        torch.atan2(m[:, 2, 1], m[:, 2, 2]),
        # With roll fixed to 0, row 1 of m is [0, cos(pitch), -sin(pitch)].
        torch.atan2(-m[:, 1, 2], m[:, 1, 1]),
    )

    roll = torch.where(
        ~gimbal_lock,
        torch.atan2(m[:, 1, 0], m[:, 0, 0]),
        torch.zeros_like(yaw),
    )

    pitch = pitch * 180 / torch.pi
    yaw = yaw * 180 / torch.pi
    roll = roll * 180 / torch.pi

    return pitch, yaw, roll


def get_rotation_matrix(pitch_, yaw_, roll_):
    """Build a batch of rotation matrices from Euler angles given in degrees.

    The rotation is composed as ``Rz(roll) @ Ry(yaw) @ Rx(pitch)``, and the
    transpose of that product is returned.

    Args:
        pitch_: rotation about the x axis, in degrees. A 1-D tensor is treated
            as one angle per batch element.
        yaw_: rotation about the y axis, in degrees.
        roll_: rotation about the z axis, in degrees.

    Returns:
        Tensor of shape (B, 3, 3), where B is the batch size of ``pitch_``.
    """
    def to_radian_column(angle_deg):
        # Degrees -> radians; lift 1-D input to a (B, 1) column for concatenation.
        angle_rad = angle_deg / 180 * torch.pi
        return angle_rad if angle_rad.ndim != 1 else angle_rad.unsqueeze(1)

    x = to_radian_column(pitch_)
    y = to_radian_column(yaw_)
    z = to_radian_column(roll_)

    batch = x.shape[0]
    device = x.device
    one = torch.ones([batch, 1]).to(device)
    zero = torch.zeros([batch, 1]).to(device)

    def as_matrix(*entries):
        # Nine (B, 1) columns given in row-major order -> (B, 3, 3).
        return torch.cat(entries, dim=1).reshape([batch, 3, 3])

    cos_x, sin_x = torch.cos(x), torch.sin(x)
    cos_y, sin_y = torch.cos(y), torch.sin(y)
    cos_z, sin_z = torch.cos(z), torch.sin(z)

    rot_x = as_matrix(one, zero, zero,
                      zero, cos_x, -sin_x,
                      zero, sin_x, cos_x)
    rot_y = as_matrix(cos_y, zero, sin_y,
                      zero, one, zero,
                      -sin_y, zero, cos_y)
    rot_z = as_matrix(cos_z, -sin_z, zero,
                      sin_z, cos_z, zero,
                      zero, zero, one)

    combined = (rot_z @ rot_y) @ rot_x
    return combined.transpose(1, 2)

# Convert scalar pitch, yaw and roll (degrees) to a single rotation matrix.
def euler_angles_to_rotation_matrix(pitch, yaw, roll):
    """Build one rotation matrix from scalar Euler angles given in degrees.

    Thin wrapper around ``get_rotation_matrix`` for scalar inputs.

    Args:
        pitch: rotation about x, in degrees (Python or NumPy scalar).
        yaw: rotation about y, in degrees.
        roll: rotation about z, in degrees.

    Returns:
        float32 ``np.ndarray`` of shape (1, 3, 3), in the layout that
        ``get_rotation_matrix`` produces.
    """
    # Wrap each scalar in a length-1 batch so get_rotation_matrix sees shape (1,).
    pitch_ = torch.tensor([pitch], dtype=torch.float32)
    yaw_ = torch.tensor([yaw], dtype=torch.float32)
    roll_ = torch.tensor([roll], dtype=torch.float32)

    R = get_rotation_matrix(pitch_, yaw_, roll_)

    # Already (1, 3, 3); only move to host and convert to a float32 NumPy array.
    return R.cpu().numpy().astype(np.float32)


# Flatten one frame descriptor into the 202-value vector that unflatten_vector reverses.
def extract_full_vector(frame_data):
    """Flatten one frame's descriptor dict into a single 1-D vector.

    Layout (the slicing that ``unflatten_vector`` assumes):
        [0:2]     c_d_eyes         [2:3]     c_d_lip
        [3:5]     c_eyes           [5:6]     c_lip
        [6:7]     scale            [7:10]    t
        [10:13]   pitch, yaw, roll in degrees (derived from R)
        [13:76]   exp              [76:139]  x_s
        [139:202] kp
    for 202 values in total. These sizes are not enforced here.

    Args:
        frame_data: dict with 'c_d_eyes_lst', 'c_d_lip_lst' and
            'driving_template_dct' whose 'motion'[0] holds 'c_eyes_lst',
            'c_lip_lst', 'scale', 't', 'R', 'exp', 'x_s' and 'kp'.

    Returns:
        1-D ``np.ndarray``, the concatenation of the components above. The
        rotation matrix itself is not included; only its Euler angles are.

    Prints a warning when the top-level eye/lip values differ from the copies
    stored in the motion template.
    """
    c_d_eyes = frame_data['c_d_eyes_lst'][0].reshape(-1)  # 2 values
    c_d_lip = frame_data['c_d_lip_lst'][0].reshape(-1)    # 1 value

    motion = frame_data['driving_template_dct']['motion'][0]
    c_eyes = motion['c_eyes_lst'][0].reshape(-1)          # 2 values
    c_lip = motion['c_lip_lst'][0].reshape(-1)            # 1 value

    scale = np.array(motion['scale']).reshape(-1)         # 1 value
    t = motion['t'].reshape(-1)                           # 3 values
    R = motion['R'].reshape(1, 3, 3)                      # batch of one 3x3 matrix
    exp = motion['exp'].reshape(-1)                       # 63 values
    x_s = motion['x_s'].reshape(-1)                       # 63 values
    kp = motion['kp'].reshape(-1)                         # 63 values

    # Replace the 9 matrix entries by 3 Euler angles so the vector can be lerped.
    pitch, yaw, roll = get_euler_angles_from_rotation_matrix(torch.tensor(R))
    euler_angles = np.array([pitch.item(), yaw.item(), roll.item()])

    if not np.array_equal(c_d_eyes, c_eyes):
        print("Eyes arrays not equal")
    if not np.array_equal(c_d_lip, c_lip):
        print("Lip arrays not equal")

    return np.concatenate([c_d_eyes, c_d_lip, c_eyes, c_lip, scale, t, euler_angles, exp, x_s, kp])

def unflatten_vector(avg_vector):
    """Rebuild a frame-descriptor dict from a flat vector.

    This is the inverse of ``extract_full_vector``'s layout. Each segment is cast
    to float32 and reshaped. Entries 10:13 (pitch, yaw, roll in degrees) are
    turned back into a rotation matrix with ``euler_angles_to_rotation_matrix``.

    Args:
        avg_vector: indexable sequence of at least 202 numbers.

    Returns:
        dict with 'c_d_eyes_lst', 'c_d_lip_lst' and
        'driving_template_dct' -> {'motion': [motion_dict]}.
    """
    def segment(start, stop, shape):
        # Take avg_vector[start:stop] as a float32 array of the given shape.
        return np.array(avg_vector[start:stop], dtype=np.float32).reshape(shape)

    eyes_driving = segment(0, 2, (1, 2))
    lip_driving = segment(2, 3, (1, 1))
    eyes_template = segment(3, 5, (1, 2))
    lip_template = segment(5, 6, (1, 1))
    scale_arr = segment(6, 7, (1, 1))
    translation = segment(7, 10, (1, 3))

    rotation = euler_angles_to_rotation_matrix(*(avg_vector[i] for i in (10, 11, 12)))

    keypoint_shape = (1, 21, 3)
    expression = segment(13, 76, keypoint_shape)
    source_kp = segment(76, 139, keypoint_shape)
    keypoints = segment(139, 202, keypoint_shape)

    motion_entry = {
        'scale': scale_arr,
        'R': rotation,
        't': translation,
        'c_eyes_lst': eyes_template,
        'c_lip_lst': lip_template,
        'exp': expression,
        'x_s': source_kp,
        'kp': keypoints,
    }
    return {
        'c_d_eyes_lst': eyes_driving,
        'c_d_lip_lst': lip_driving,
        'driving_template_dct': {'motion': [motion_entry]},
    }


In [3]:
def interpolate_clusters(cluster_id_1, cluster_id_2, average_descriptor_dict, num_intermediate_clusters=9):
    """Linearly interpolate between the descriptors of two clusters.

    Both descriptors are flattened with ``extract_full_vector`` and blended at
    evenly spaced fractions ``k / (num_intermediate_clusters + 1)`` for
    ``k = 1..num_intermediate_clusters``. Each blend is converted back with
    ``unflatten_vector``. The Euler-angle slots (indices 10:13, degrees) are
    interpolated along the shortest arc, so a wrap from e.g. 179 to -179 moves
    through 180 instead of sweeping through 0.

    Args:
        cluster_id_1: key of the start descriptor in ``average_descriptor_dict``.
        cluster_id_2: key of the end descriptor.
        average_descriptor_dict: mapping cluster id -> frame-descriptor dict.
        num_intermediate_clusters: number of in-between descriptors to create.

    Returns:
        ``(descriptor_1, descriptor_2, intermediate_vectors)``. The first two are
        the original dicts; ``intermediate_vectors`` is a list of
        ``num_intermediate_clusters`` dicts ordered from cluster 1 to cluster 2.
    """
    descriptor_1 = average_descriptor_dict[cluster_id_1]
    descriptor_2 = average_descriptor_dict[cluster_id_2]

    flattened_1 = extract_full_vector(descriptor_1)
    flattened_2 = extract_full_vector(descriptor_2)

    delta = flattened_2 - flattened_1
    # Pitch/roll come from atan2 and wrap at +/-180 deg: use the shortest signed difference.
    angle_slots = slice(10, 13)
    delta[angle_slots] = (delta[angle_slots] + 180.0) % 360.0 - 180.0

    intermediate_vectors = []
    for step in range(1, num_intermediate_clusters + 1):
        alpha = step / (num_intermediate_clusters + 1)
        intermediate_vectors.append(unflatten_vector(flattened_1 + alpha * delta))

    return descriptor_1, descriptor_2, intermediate_vectors

In [4]:
# Cluster id -> averaged frame-descriptor dict (consumed by interpolate_clusters).
# pickle executes code on load; only read artefacts produced by this pipeline.
with open(f"{frame_to_cluster_mapping_path}/averaged_descriptors_raw.pkl", 'rb') as descriptor_file:
    average_descriptor_dict = pickle.load(descriptor_file)

In [5]:
with open(f'{frame_to_cluster_mapping_path}/frame_to_cluster_mapping.pkl', 'rb') as mapping_file:
    frame_to_cluster_mapping = pickle.load(mapping_file)

# Keep only the cluster id (second element) of each pair, preserving frame order.
frame_to_cluster_mapping_transformed = {
    key: [cluster_id for _, cluster_id in pairs]
    for key, pairs in frame_to_cluster_mapping.items()
}

In [6]:
# Build a row-normalised cluster-to-cluster transition probability matrix.
def compute_basic_transition_matrix(frame_to_cluster_mapping_transformed):
    """Estimate first-order transition probabilities between clusters.

    Every pair of consecutive cluster ids inside each sequence counts as one
    transition. Counts are normalised per source row. Rows with no outgoing
    transitions stay all zeros.

    Args:
        frame_to_cluster_mapping_transformed: mapping key -> ordered list of
            cluster ids.

    Returns:
        ``(transition_probs, cluster_to_idx)``. ``transition_probs`` is an
        (n, n) float array; ``cluster_to_idx`` maps each cluster id (sorted
        ascending) to its row/column index.
    """
    sequences = list(frame_to_cluster_mapping_transformed.values())
    ordered_clusters = sorted({cid for seq in sequences for cid in seq})
    cluster_to_idx = {cid: position for position, cid in enumerate(ordered_clusters)}
    size = len(ordered_clusters)

    counts = np.zeros((size, size))
    for seq in sequences:
        positions = [cluster_to_idx[cid] for cid in seq]
        for src, dst in zip(positions, positions[1:]):
            counts[src, dst] += 1

    totals = counts.sum(axis=1, keepdims=True)
    # Empty rows would divide by zero; their counts are 0, so any tiny divisor yields 0.
    totals = np.where(totals == 0, 1e-10, totals)
    return counts / totals, cluster_to_idx
def get_nonzero_transitions(transition_probs, cluster_to_idx, cluster_id):
    """List the outgoing transitions of ``cluster_id`` with non-zero probability.

    Args:
        transition_probs: 2-D array indexed as ``[source_idx, target_idx]``.
        cluster_to_idx: mapping cluster id -> row/column index.
        cluster_id: source cluster to inspect.

    Returns:
        List of ``(next_cluster_id, probability)`` tuples in ascending column
        order. Self-transitions are included.

    Raises:
        KeyError: if ``cluster_id`` is not in ``cluster_to_idx``.
    """
    row = transition_probs[cluster_to_idx[cluster_id]]

    # Invert the mapping once instead of rescanning every entry for each non-zero
    # column (previously O(k * n) per call). setdefault keeps the first id per index,
    # matching what a linear scan would return.
    idx_to_cluster = {}
    for cluster, idx in cluster_to_idx.items():
        idx_to_cluster.setdefault(idx, cluster)

    return [(idx_to_cluster[idx], row[idx]) for idx in np.where(row > 0)[0]]

# Row-normalised transition probabilities plus the cluster id -> row/column index map.
(transition_probs, cluster_to_idx) = compute_basic_transition_matrix(
    frame_to_cluster_mapping_transformed
)

In [7]:
# For every cluster, the clusters it moves to with non-zero probability (self-transitions dropped).
all_nonzero_transitions_dict = {
    cluster: [
        next_cluster
        for next_cluster, _ in get_nonzero_transitions(transition_probs, cluster_to_idx, cluster)
        if next_cluster != cluster
    ]
    for cluster in cluster_to_idx
}

# Show a single example instead of every cluster; the full dict is large.
if all_nonzero_transitions_dict:
    sample_cluster, sample_transitions = next(iter(all_nonzero_transitions_dict.items()))
    print("Non-zero transitions (excluding self-transitions) for the first cluster:")
    print(f"Cluster {sample_cluster}: {sample_transitions}")

total_transitions = sum(len(next_clusters) for next_clusters in all_nonzero_transitions_dict.values())
print(f"Total number of non-zero transitions (excluding self-transitions): {total_transitions}")


All transitions with non-zero probabilities for each cluster (excluding itself):
Cluster 0: [9685]
Total number of non-zero transitions (excluding self-transitions): 23004


In [8]:
# Number of synthetic descriptors inserted between each pair of connected clusters.
NUM_INTERMEDIATE_CLUSTERS = 2

# interpolated_clusters_dict[src][dst] = [src descriptor, *intermediates, dst descriptor]
interpolated_clusters_dict = {}
for cluster, next_clusters in all_nonzero_transitions_dict.items():
    for next_cluster in next_clusters:
        descriptor_1, descriptor_2, intermediate_descriptors = interpolate_clusters(
            cluster, next_cluster, average_descriptor_dict,
            num_intermediate_clusters=NUM_INTERMEDIATE_CLUSTERS,
        )
        # Only clusters with at least one outgoing transition get an entry.
        interpolated_clusters_dict.setdefault(cluster, {})[next_cluster] = (
            [descriptor_1] + intermediate_descriptors + [descriptor_2]
        )

In [24]:
import pickle

# Interpolation paths: {src_cluster: {dst_cluster: [descriptor, ...]}}.
with open(f"{frame_to_cluster_mapping_path}/interpolated_descriptors.pkl", 'wb') as file:
    pickle.dump(interpolated_clusters_dict, file)

# Transition adjacency lists: {cluster: [next_cluster, ...]} (self-transitions excluded).
with open(f"{frame_to_cluster_mapping_path}/all_nonzero_transitions_dict.pkl", 'wb') as file:
    pickle.dump(all_nonzero_transitions_dict, file)
