In [38]:
import csv
import tqdm
from pathlib import Path
from PIL import Image, ImageDraw
import numpy as np
import h5py
import faiss
import matplotlib.pyplot as plt

def load_bounding_boxes_csv(csv_file, imgnumber):
    """Read one slide's bounding boxes from a header-less CSV file.

    Every value of a non-empty row is parsed with ``float`` and then truncated
    to ``int`` (so ``"20.7"`` becomes ``20``). The slide number is prepended to
    each row so rows from several slides can be stacked and still be traced
    back to their source slide.

    Args:
        csv_file (str or Path): Path to the CSV file.
        imgnumber (int): Slide identifier prepended to every row.

    Returns:
        list[list[int]]: One ``[imgnumber, *values]`` list per non-empty row.
    """
    bboxes = []
    # newline="" is the mode the csv module documentation asks for when a
    # file is handed to csv.reader.
    with open(csv_file, "r", newline="") as f:
        for row in csv.reader(f):
            if not row:
                # Blank lines come back as an empty row; skip them instead of
                # emitting a bare [imgnumber] entry with no coordinates.
                continue
            # Parse via float first so values written as "12.0" are accepted.
            bboxes.append([imgnumber] + [int(float(x)) for x in row])
    return bboxes

def _placeholder_tile(size):
    """Return a light-grey ``size`` x ``size`` RGB tile crossed by two darker diagonals.

    Used for tiles that are missing on disk or fail to load, so gaps in the
    composite are visually distinct from the white background.
    """
    placeholder = Image.new("RGB", (size, size), (200, 200, 200))
    pen = ImageDraw.Draw(placeholder)
    pen.line((0, 0, size, size), fill=(150, 150, 150), width=3)
    pen.line((0, size, size, 0), fill=(150, 150, 150), width=3)
    return placeholder


def load_image_with_bbox(bbox, center_crop=False, dataset_dir="dataset",
                         tile_size=256, crop_size=512):
    """
    Load a 3x3 composite of tiles around the given bounding box, draw the box,
    and optionally return a square crop centered on it.

    Tiles are read from ``<dataset_dir>/img<imgnumber>/tiles/tile_<row>_<col>.jpeg``,
    where ``row``/``col`` are the global coordinates of the tile's top-left
    corner (multiples of ``tile_size``). Missing or unreadable tiles are
    replaced by a grey placeholder with a diagonal cross.

    Args:
        bbox (sequence): [imgnumber, row, col, height, width]
            - imgnumber: The image number (e.g., 5 for 'img5')
            - row, col: Top-left coordinates of the bounding box in global coordinates
            - height, width: Size of the bounding box
        center_crop (bool): If True, return a ``crop_size`` x ``crop_size`` crop
            centered on the bounding box (shifted to stay inside the composite).
            If False, return the full ``3*tile_size`` square composite.
        dataset_dir (str or Path): Root directory holding the ``img<N>`` folders.
        tile_size (int): Edge length of one tile in pixels. Presumably every
            tile file has this size — TODO confirm for edge tiles.
        crop_size (int): Edge length of the crop returned when ``center_crop`` is True.

    Returns:
        PIL.Image: The assembled RGB image with the bounding box drawn in green.
    """
    imgnumber, bbox_row, bbox_col, bbox_height, bbox_width = bbox

    composite_size = tile_size * 3
    half_crop = crop_size // 2
    tiles_dir = Path(dataset_dir) / f"img{imgnumber}" / "tiles"

    # Global origin of the tile containing the bbox's top-left corner; that
    # tile becomes the centre of the 3x3 composite.
    tile_row_start = (bbox_row // tile_size) * tile_size
    tile_col_start = (bbox_col // tile_size) * tile_size

    composite_image = Image.new("RGB", (composite_size, composite_size), (255, 255, 255))

    offsets = (-tile_size, 0, tile_size)
    for i, drow in enumerate(offsets):
        for j, dcol in enumerate(offsets):
            tile_name = f"tile_{int(tile_row_start + drow)}_{int(tile_col_start + dcol)}.jpeg"
            tile_path = tiles_dir / tile_name
            tile_image = None
            if tile_path.exists():
                try:
                    # The context manager closes the file handle that Image.open
                    # keeps open for lazy decoding; convert() forces the decode
                    # while the file is still open.
                    with Image.open(tile_path) as opened:
                        tile_image = opened.convert("RGB")
                except Exception:
                    # Best-effort: an unreadable tile falls back to the placeholder.
                    tile_image = None
            if tile_image is None:
                tile_image = _placeholder_tile(tile_size)
            composite_image.paste(tile_image, (j * tile_size, i * tile_size))

    # The composite's top-left pixel sits one tile up and left of the centre tile.
    bbox_row_rel = bbox_row - (tile_row_start - tile_size)
    bbox_col_rel = bbox_col - (tile_col_start - tile_size)

    ImageDraw.Draw(composite_image).rectangle(
        [bbox_col_rel, bbox_row_rel, bbox_col_rel + bbox_width, bbox_row_rel + bbox_height],
        outline="green", width=3)

    if not center_crop:
        return composite_image

    bbox_center_row = bbox_row_rel + bbox_height / 2
    bbox_center_col = bbox_col_rel + bbox_width / 2

    # Put the bbox centre at the crop centre, then shift (not shrink) the
    # window so it stays inside the composite.
    max_offset = composite_size - crop_size
    left = max(0, min(int(bbox_center_col - half_crop), max_offset))
    upper = max(0, min(int(bbox_center_row - half_crop), max_offset))
    right = min(left + crop_size, composite_size)
    lower = min(upper + crop_size, composite_size)
    cropped_image = composite_image.crop((left, upper, right, lower))

    # Only reachable when crop_size exceeds the composite: centre the
    # composite on a white canvas of the requested size.
    w, h = cropped_image.size
    if w < crop_size or h < crop_size:
        padded = Image.new("RGB", (crop_size, crop_size), (255, 255, 255))
        padded.paste(cropped_image, ((crop_size - w) // 2, (crop_size - h) // 2))
        cropped_image = padded

    return cropped_image


In [None]:

# Load per-slide feature matrices and bounding boxes, then stack them so that
# row k of `feats` describes the box in row k of `bboxes`.
SLIDE_IDS = [1, 2, 3, 5, 6]  # all slides we have
OUT_DIR = Path('out')

feat_chunks = []
all_bboxes = []
for slide_id in tqdm.tqdm(SLIDE_IDS):
    mask_dir = OUT_DIR / f"img{slide_id}" / 'masks'
    # Read the whole dataset into memory, then close the HDF5 file.
    with h5py.File(mask_dir / 'features.h5', 'r') as h5:
        slide_feats = h5["dataset"][:]
    slide_bboxes = load_bounding_boxes_csv(mask_dir / 'global_bboxes.txt', slide_id)
    assert len(slide_feats) == len(slide_bboxes), (
        f"img{slide_id}: {len(slide_feats)} feature rows vs {len(slide_bboxes)} boxes")
    feat_chunks.append(slide_feats)
    all_bboxes.extend(slide_bboxes)

feats = np.concatenate(feat_chunks, axis=0)
# One flat list -> one array; unlike concatenating per-slide arrays, this also
# works when a slide contributes zero boxes.
bboxes = np.array(all_bboxes)



In [None]:
# Sanity check: render one bounding box (row 3) as a centred crop.
sample_idx = 3
sample_crop = load_image_with_bbox(bboxes[sample_idx], center_crop=True)
print(sample_crop.size)
sample_crop

# Compute PCA (2-D projection of the features)

In [21]:
# Seed NumPy's global RNG (the later np.random calls in this notebook draw from it).
np.random.seed(0)

# Project the features onto their top-2 principal components for plotting.
n_dims_in = feats.shape[1]
pca_matrix = faiss.PCAMatrix(n_dims_in, 2)
pca_matrix.train(feats)
feats_pca = pca_matrix.apply(feats)

In [None]:
# Draw a reproducible subsample of distinct points for the 2-D PCA plots.
num_subsample = 10000
np.random.seed(0)
n_points = feats_pca.shape[0]
# replace=False prevents the same point being plotted/clustered twice, and
# capping the size keeps the call valid when fewer points are available.
subsampled_indices = np.random.choice(
    n_points, size=min(num_subsample, n_points), replace=False)
subsample = feats_pca[subsampled_indices]

plt.figure()
plt.scatter(subsample[:, 0], subsample[:, 1], marker='.')
plt.title(f"PCA projection ({len(subsample)} sampled points)")
plt.xlabel("PC 1")
plt.ylabel("PC 2");

In [23]:
# Cluster the 2-D PCA subsample. The dimension is read from the data itself so
# it stays in sync with the number of PCA components.
ncentroids = 32
niter = 32
verbose = False
dimension = subsample.shape[1]
kmeans = faiss.Kmeans(dimension, ncentroids, niter=niter, verbose=verbose, gpu=True)
kmeans.train(subsample)
# Nearest-centroid search: I[:, 0] is each point's cluster id.
D, I = kmeans.index.search(subsample, 1)

In [None]:
# One random RGB colour per cluster; centroids are overlaid as red crosses.
colors = np.random.rand(ncentroids, 3)
cluster_labels = I[:, 0]
plt.figure()
for cluster_id, color in enumerate(colors):
    in_cluster = cluster_labels == cluster_id
    plt.scatter(subsample[in_cluster, 0], subsample[in_cluster, 1], marker='.', color=color)
plt.scatter(kmeans.centroids[:, 0], kmeans.centroids[:, 1], marker='x', color='red')

In [None]:
# For each 2-D PCA cluster, show up to 16 member crops as a 4x4 grid.
COLLAGE_GRID = 4
CROP_PX = 512  # must match the crop size load_image_with_bbox returns
for i in tqdm.tqdm(range(ncentroids)):
    # Map cluster membership (over the subsample) back to global row indices.
    cluster_subsampled_indices = subsampled_indices[I[:, 0] == i][:COLLAGE_GRID ** 2]
    collage = np.zeros((CROP_PX * COLLAGE_GRID, CROP_PX * COLLAGE_GRID, 3), dtype=np.uint8)
    for slot, bbox in enumerate(bboxes[cluster_subsampled_indices]):
        r, c = divmod(slot, COLLAGE_GRID)
        collage[r*CROP_PX:(r+1)*CROP_PX, c*CROP_PX:(c+1)*CROP_PX] = np.asarray(
            load_image_with_bbox(bbox, center_crop=True))
    plt.figure(figsize=(10, 10))
    plt.title(f"Cluster {i}")
    plt.imshow(collage)
    plt.axis('off')
    # Render this figure now instead of keeping all 32 open until the cell
    # ends (matplotlib warns once more than 20 figures are open).
    plt.show()

    


# Full-dimensional k-means (32 clusters)

In [None]:
ncentroids, niter, verbose = 32, 32, False
dimension = feats.shape[1]

print('fitting kmeans...')
kmeans = faiss.Kmeans(
    dimension,
    ncentroids,
    niter=niter,
    verbose=verbose,
    gpu=True,
)
kmeans.train(feats)
print('done')

# Assign every feature vector to its nearest centroid (I[:, 0] = cluster id).
D, I = kmeans.index.search(feats, 1)
print(feats.shape, I.shape)

In [None]:
# For each full-feature cluster, show up to 16 randomly chosen member crops.
COLLAGE_GRID = 4
CROP_PX = 512  # must match the crop size load_image_with_bbox returns
for i in tqdm.tqdm(range(ncentroids)):
    member_indices = np.flatnonzero(I[:, 0] == i)
    if member_indices.size == 0:
        # np.random.choice raises on an empty population; skip empty clusters.
        tqdm.tqdm.write(f"Cluster {i} is empty")
        continue
    # Sample without replacement so no crop appears twice in the same grid.
    random_indices = np.random.choice(
        member_indices, size=min(COLLAGE_GRID ** 2, member_indices.size), replace=False)
    collage = np.zeros((CROP_PX * COLLAGE_GRID, CROP_PX * COLLAGE_GRID, 3), dtype=np.uint8)
    for slot, bbox in enumerate(bboxes[random_indices]):
        r, c = divmod(slot, COLLAGE_GRID)
        collage[r*CROP_PX:(r+1)*CROP_PX, c*CROP_PX:(c+1)*CROP_PX] = np.asarray(
            load_image_with_bbox(bbox, center_crop=True))
    plt.figure(figsize=(10, 10))
    plt.title(f"Cluster {i} ({member_indices.size} members)")
    plt.imshow(collage)
    plt.axis('off')
    # Render this figure now instead of keeping all of them open until the
    # cell ends (matplotlib warns once more than 20 figures are open).
    plt.show()

# Compute k-means with ~4k (4096) clusters

In [None]:
# Large codebook: 4096 centroids over the full-dimensional features.
ncentroids = 4 * 1024
niter, verbose = 100, True
dimension = feats.shape[1]
kmeans = faiss.Kmeans(
    dimension,
    ncentroids,
    niter=niter,
    verbose=verbose,
    gpu=True,
)
kmeans.train(feats)

In [None]:
# Exact k-nearest-neighbour search over all feature vectors (L2 distance).
k = 3
index = faiss.IndexFlatL2(feats.shape[1])
# IndexFlatL2 starts empty: the vectors must be added before search() can
# return any neighbours.
index.add(feats)
# Distinct result names keep this cell from overwriting the k-means
# assignments in D, I that the collage cells read. Each point's first
# neighbour is normally itself (distance 0).
knn_distances, knn_indices = index.search(feats, k)
print(knn_indices[:1])

In [None]:
# Neighbours of a single query point. Searching only that row avoids repeating
# the full all-vs-all search just to print one result, and leaves D untouched.
query_row = 2
_, query_neighbors = index.search(feats[query_row:query_row + 1], k)
print(query_neighbors[0])

In [None]:
# Approximate nearest-neighbour index (IVF-PQ) on the GPU via cuVS.
# (matplotlib is already imported in the first cell.)
import cuvs.neighbors.ivf_pq as ivfpq
import cupy as cp

# NOTE(review): inner product only ranks like cosine similarity if the feature
# vectors are L2-normalised — confirm, or switch to an L2 metric to match the
# faiss IndexFlatL2 used above.
params = ivfpq.IndexParams(metric='inner_product')
# Explicit float32 device copy, the same precision faiss works in above.
dataset = cp.asarray(feats, dtype=cp.float32)
index = ivfpq.build(params, dataset)

In [None]:
n_neighbors = 3
distances, neighbors = ivfpq.search(ivfpq.SearchParams(), index, dataset, n_neighbors)
# Results are GPU arrays: copy them to host NumPy and show both together.
# A bare `distances` line in the middle of a cell is never displayed.
distances_host = cp.asarray(distances).get()
neighbors_host = cp.asarray(neighbors).get()
distances_host[:5], neighbors_host[:5]