In [124]:
import numpy as np
import operator
import pickle
from scipy.spatial import KDTree
import tifffile

In [125]:
# Load list of outer+core segments
segments = []
timepoints = [0,1,2,3]

for timepoint in timepoints:
    with open(f'output/outer_segments_{timepoint}.pkl', 'rb') as f:
        segments.append(pickle.load(f))
    print(f"Number of outer segments in T{timepoint}: {len(segments[timepoint])}")

Number of outer segments in T0: 74
Number of outer segments in T1: 66
Number of outer segments in T2: 52
Number of outer segments in T3: 76


In [126]:
import numpy as np
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean

# Example segments
segments_0 = segments[0]
segments_1 = segments[1]
segments_2 = segments[2]
segments_3 = segments[3]

In [127]:
print(len(segments_0))
print(len(segments_1))
print(len(segments_2))
print(len(segments_3))

74
66
52
76


In [128]:
segments_0_ends = []
segments_1_ends = []
segments_2_ends = []
segments_3_ends = []

for segment in segments_0:
    segment = [segment[0], segment[-1]]
    segments_0_ends.append(segment)

for segment in segments_1:
    segment = [segment[0], segment[-1]]
    segments_1_ends.append(segment)

for segment in segments_2:
    segment = [segment[0], segment[-1]]
    segments_2_ends.append(segment)

for segment in segments_3:
    segment = [segment[0], segment[-1]]
    segments_3_ends.append(segment)

In [133]:
import numpy as np
from fastdtw import fastdtw
from scipy.spatial.distance import euclidean
import faiss
import concurrent.futures

def preprocess_segment(segment, max_length):
    """
    Flatten and pad/truncate a segment to a fixed length.
    """
    flat_segment = np.array(segment).flatten()
    if len(flat_segment) < max_length:
        # Pad with zeros if the segment is too short
        flat_segment = np.pad(flat_segment, (0, max_length - len(flat_segment)), 'constant')
    else:
        # Truncate if the segment is too long
        flat_segment = flat_segment[:max_length]
    return flat_segment

def build_faiss_index(segments, max_length):
    """
    Build a FAISS index for the given segments.
    """
    dim = max_length
    index = faiss.IndexFlatL2(dim)
    vectors = np.array([preprocess_segment(seg, max_length) for seg in segments])
    index.add(vectors)
    return index, vectors

def compute_dtw_distance(seg1, seg2):
    """
    Compute the DTW distance between two segments.
    """
    distance, _ = fastdtw(seg1, seg2, dist=euclidean)
    return distance

def find_best_match_faiss(segment, indices, vectors, max_length, top_k=5):
    """
    Find the best match for the given segment using FAISS and then refine using DTW.
    """
    query_vector = preprocess_segment(segment, max_length).reshape(1, -1)
    D, I = indices.search(query_vector, top_k)  # D is the distances, I is the indices
    
    min_distance = float('inf')
    best_index = -1

    for idx in I[0]:
        candidate_segment = vectors[idx].reshape(-1, 3)
        dist = compute_dtw_distance(segment, candidate_segment)
        if dist < min_distance:
            min_distance = dist
            best_index = idx

    return best_index, min_distance

def best_match(segment, timepoint1_segments, timepoint2_segments, timepoint3_segments):
    # Determine the maximum length for padding/truncation
    max_length = max(
        max(len(np.array(seg).flatten()) for seg in timepoint1_segments),
        max(len(np.array(seg).flatten()) for seg in timepoint2_segments),
        max(len(np.array(seg).flatten()) for seg in timepoint3_segments)
    )

    # Build FAISS indices
    index1, vectors1 = build_faiss_index(timepoint1_segments, max_length)
    index2, vectors2 = build_faiss_index(timepoint2_segments, max_length)
    index3, vectors3 = build_faiss_index(timepoint3_segments, max_length)

    # Find best matches using FAISS and then refine using DTW
    best_match_1, best_match_1_dist = find_best_match_faiss(segment, index1, vectors1, max_length)
    best_match_2, best_match_2_dist = find_best_match_faiss(segment, index2, vectors2, max_length)
    best_match_3, best_match_3_dist = find_best_match_faiss(segment, index3, vectors3, max_length)
    
    return (best_match_1, best_match_2, best_match_3), (best_match_1_dist, best_match_2_dist, best_match_3_dist)

# Example usage
segment_of_interest = segments_0[71]
timepoint1_segments = segments_1
timepoint2_segments = segments_2
timepoint3_segments = segments_3

best_indices, distances = best_match(segment_of_interest, timepoint1_segments, timepoint2_segments, timepoint3_segments)
print(f"Best matches at indices: {best_indices}")
print(f"Distances: {distances}")


Best matches at indices: (62, 50, 74)
Distances: (61679.72859467454, 64170.1844208672, 60884.17945236171)


In [130]:
segment_matches = []
segment_distances = []

for segment in segments_0:
    best_indices, distances = best_match(segment, timepoint1_segments, timepoint2_segments, timepoint3_segments)
    segment_matches.append(best_indices)
    segment_distances.append(distances)

In [131]:
len(segment_matches)

74

In [132]:
print(segment_matches[107])
print(segment_distances[107])

IndexError: list index out of range