In [1]:
import os
import cv2
import json
import tifffile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
import multiprocessing
from joblib import Parallel, delayed
from tqdm import tqdm

import torch
from torchvision import transforms
from transformers import ViTModel, ViTImageProcessor


  from .autonotebook import tqdm as notebook_tqdm


In [24]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224")
vit = ViTModel.from_pretrained("google/vit-base-patch16-224").to(device)
vit.eval()

Some weights of ViTModel were not initialized from the model checkpoint at google/vit-base-patch16-224 and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


ViTModel(
  (embeddings): ViTEmbeddings(
    (patch_embeddings): ViTPatchEmbeddings(
      (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
    )
    (dropout): Dropout(p=0.0, inplace=False)
  )
  (encoder): ViTEncoder(
    (layer): ModuleList(
      (0-11): 12 x ViTLayer(
        (attention): ViTAttention(
          (attention): ViTSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
          )
          (output): ViTSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.0, inplace=False)
          )
        )
        (intermediate): ViTIntermediate(
          (dense): Linear(in_features=768, out_features=3072, bias=True)
          (intermediate_act_fn): GELUActivation()
        )
        (output): ViTOutput(
          (d

In [26]:
DATA_PATH = "/playpen/jesse/HIPI/preprocess/data"
he_path = "/playpen/jesse/HIPI/preprocess/data/CRC03-HE.ome.tif"
csv_file = "/playpen/jesse/HIPI/preprocess/data/CRC03_new_coordinates.csv"

he_image = tifffile.imread(he_path)
df = pd.read_csv(csv_file)


In [27]:
train_df, temp_df = train_test_split(df, test_size=0.2, random_state=42)
val_df, test_df = train_test_split(temp_df, test_size=0.5, random_state=42)


In [None]:
def extract_patch(he_image, x, y, area, size=224):
    image = np.transpose(he_image, (1, 2, 0))
    radius = int(np.sqrt(area / np.pi))

    x_min, x_max = max(0, x - radius), min(image.shape[1], x + radius)
    y_min, y_max = max(0, y - radius), min(image.shape[0], y + radius)
    
    if x_max <= x_min or y_max <= y_min or x_min >= image.shape[1] or y_min >= image.shape[0]:
        patch = np.zeros((16, 16, 3), dtype=np.float32)
    else:
        patch = image[y_min:y_max, x_min:x_max]
        temp_patch = np.zeros((16, 16, 3), dtype=np.float32)
        h, w = min(16, patch.shape[0]), min(16, patch.shape[1])
        temp_patch[:h, :w] = patch[:h, :w]
        patch = temp_patch

    return cv2.resize(patch, (size, size), interpolation=cv2.INTER_AREA)


In [None]:
def extract_features(patch):
    pil_image = transforms.ToPILImage()(patch)
    inputs = processor(images=pil_image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = vit(**inputs)
        features = outputs.last_hidden_state[:, 0, :]
    return features.squeeze().cpu().numpy()

In [29]:
def process_row(row, he_image):
    x, y, area = int(row['X']), int(row['Y']), int(row['AREA'])
    patch = extract_patch(he_image, x, y, area)
    features = extract_features(patch)
    return features, (x, y)

def save_features(features, coords, filename):
    np.savez(filename, features=features, coords=coords)
    print(f"Saved features to {filename}")

In [30]:
def process_data_parallel(dataframe, he_image, output_file, n_jobs=-1):
    if n_jobs == -1:
        n_jobs = multiprocessing.cpu_count()
    
    print(f"Processing {len(dataframe)} samples using {n_jobs} CPUs...")
    
    shared_he_image = he_image.copy()
    
    results = Parallel(n_jobs=n_jobs)(
        delayed(process_row)(row, shared_he_image) 
        for _, row in tqdm(dataframe.iterrows(), total=len(dataframe))
    )
    
    features, coords = zip(*results)
    features_array = np.array(features)
    coords_array = np.array(coords)
    save_features(features_array, coords_array, output_file)
    
    return features_array, coords_array

In [31]:
def run_kmeans(features, n_clusters=5):
    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    cluster_labels = kmeans.fit_predict(features)
    return kmeans, cluster_labels

In [None]:
n_cpus = -1  
    
features_file = os.path.join(DATA_PATH, "extracted_features.npz")
features, coords = process_data_parallel(df, he_image, features_file, n_jobs=n_cpus)

n_clusters = 5
kmeans, cluster_labels = run_kmeans(features, n_clusters=n_clusters)



Processing 563877 samples using 152 CPUs...


In [2]:
csv_file = "/playpen/jesse/HIPI/preprocess/data/CRC03_new_coordinates.csv"

df = pd.read_csv(csv_file)
for i, row in df.iterrows():
    cell_ids = row.name
    print(cell_ids)
    break

0
