# Utils

In [None]:
import pickle
import csv
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import time
import pandas as pd

## Convert the full dataset to a cell-only dataset

In [None]:
# Load the full feature/bounding-box dataset.
# NOTE(review): pickle can execute arbitrary code while loading; only open
# dataset files that come from a trusted source.
with open('full_dataset.pkl', 'rb') as f:
    data = pickle.load(f)
feats, bboxes = data['feats'], data['bboxes']

# Each CSV row is read as "<index into feats>,<integer label>".
# Presumably the label is 1 for "is a cell" and 0 otherwise - TODO confirm.
annfile = "annotations_is_cell.csv"
with open(annfile, 'r', newline='') as csvfile:
    # Blank rows (e.g. a trailing empty line) are skipped instead of crashing
    # on row[0] below.
    anns = [row for row in csv.reader(csvfile) if row]
indices = np.array([int(row[0]) for row in anns])
classes = np.array([int(row[1]) for row in anns])

# train a cell classification model
X = feats[indices]  # features of the annotated samples only
y = classes
# Logistic regression: a single linear layer followed by a sigmoid
class LogisticRegression(nn.Module):
    """Binary classifier mapping ``input_dim`` features to one sigmoid output."""

    def __init__(self, input_dim):
        super().__init__()
        # One output unit; the sigmoid is applied in forward().
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Return the sigmoid of the linear layer's output for ``x``."""
        logits = self.linear(x)
        return torch.sigmoid(logits)

# Convert X and y to PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
# BCELoss needs targets with the same shape as the (N, 1) model output.
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)

# Initialize the model, loss function, and optimizer
input_dim = X.shape[1]
model = LogisticRegression(input_dim)
criterion = nn.BCELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, weight_decay=0.01)

# Train the model (full-batch gradient descent: every step uses all of X)
num_epochs = 10000
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(X_tensor)
    loss = criterion(outputs, y_tensor)

    # Backward pass and optimization
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch+1) % 10 == 0:
        # "\r" keeps the progress report on a single, overwritten line.
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}', end="\r")
# Terminate the "\r" progress line so the next print does not overwrite it.
print()

# predict
st = time.perf_counter()
with torch.no_grad():
    # Cast to float32 like the training tensor: the model weights are float32,
    # so float64 features would otherwise raise a dtype mismatch.
    preds = model(torch.tensor(feats, dtype=torch.float32))
print(time.perf_counter()-st)
# Boolean mask over all samples: True where the predicted output exceeds 0.5.
cell_indices = preds.flatten().numpy() > 0.5
print((cell_indices).sum())

# save the new dataset
with open("cell_dataset.pkl", "wb") as f:
    pickle.dump({"feats": feats[cell_indices], "bboxes": bboxes[cell_indices]}, f)

## Convert TSV point annotations to per-class binary CSV labels

In [None]:
# Load bounding boxes and annotation mapping
# NOTE(review): pickle can execute arbitrary code while loading; only open
# files that come from a trusted source.
with open('cell_dataset.pkl', 'rb') as f:
    cell_bboxes = pickle.load(f)['bboxes']  # shape (N,5): (imgnumber, row, col, height, width)
anno_map = {1: "Image_01.vsi - 40x_BF_Z_01-points.tsv",
            2: "Image.vsi - 40x_BF_Z_01-points.tsv",
            5: "Image_05.vsi - 40x_BF_Z_01-points.tsv"}
classes = ["lymphocyte", "lymphoplasmocyte", "plasmocyte"]

# Load annotations: image number -> DataFrame (columns "x", "y" and "class"
# are read below)
all_annos = {}
for k, v in anno_map.items():
    all_annos[k] = pd.read_csv("annotations/" + v, sep="\t")

cell_bboxes = np.array(cell_bboxes)
indices = np.arange(cell_bboxes.shape[0])
# Box centers as (row, col): top-left corner plus half of (height, width).
centers_row_col = cell_bboxes[:, 1:3] + cell_bboxes[:, 3:5] / 2

# Match every annotation point to its nearest bounding box ONCE. The match
# does not depend on the class being exported, so it is shared by all classes
# (the progress counter therefore counts each annotation a single time).
matches = []     # (bounding-box index, annotated class) per accepted match
visited = set()  # bounding boxes already claimed by an earlier annotation
counter = 0
for imgnum, annotations in all_annos.items():
    # Skip if no annotations for this image
    if annotations.empty:
        continue

    image_mask = cell_bboxes[:, 0] == imgnum
    # Skip images without any bounding box: argmin over an empty array raises.
    if not image_mask.any():
        continue

    # Hoisted out of the annotation loop: these only depend on the image.
    image_indices = indices[image_mask]
    image_centers = centers_row_col[image_mask]

    for _, anno in annotations.iterrows():
        counter += 1
        print('counter', counter, end='\r')
        anno_col, anno_row, anno_class = anno["x"], anno["y"], anno["class"]

        # L1 (Manhattan) distance from the annotation to every box center
        # of this image; keep the nearest box.
        dists = np.abs(image_centers - np.array([anno_row, anno_col])).sum(axis=1)
        nearest = np.argmin(dists)
        closest_idx = image_indices[nearest]
        h, w = cell_bboxes[closest_idx, 3:5]

        # Reject the match when the point is farther than (h + w) / 2 from
        # the center of the nearest box
        if dists[nearest] > (h + w) / 2:
            continue

        # Each bounding box is labelled at most once (first annotation wins)
        if closest_idx in visited:
            continue
        visited.add(closest_idx)
        matches.append((closest_idx, anno_class))

# One file per class: label 1 if the matched annotation has that class, else 0
for c in classes:
    lines = [f"{idx},1" if anno_class == c else f"{idx},0"
             for idx, anno_class in matches]
    with open(f"annotations_{c}.csv", "w") as f:
        f.write("\n".join(lines))


## visualize

In [None]:
from pathlib import Path
from PIL import Image, ImageDraw

def _placeholder_tile(tile_size):
    """Return a grey ``tile_size`` x ``tile_size`` RGB tile crossed by two diagonals."""
    placeholder = Image.new("RGB", (tile_size, tile_size), (200, 200, 200))
    draw_placeholder = ImageDraw.Draw(placeholder)
    draw_placeholder.line(
        (0, 0) + placeholder.size, fill=(150, 150, 150), width=3
    )
    draw_placeholder.line(
        (0, placeholder.size[1], placeholder.size[0], 0),
        fill=(150, 150, 150),
        width=3,
    )
    return placeholder


def load_image_with_bbox(bbox, center_crop=False, root=''):
    """
    Load a 3x3 composite of tiles around the given bounding box and optionally
    return a 512x512 crop centered on the bounding box.

    Tiles are read from ``<root>/dataset/img<imgnumber>/tiles/tile_<row>_<col>.jpeg``.
    A tile that is missing or cannot be read is replaced by a grey placeholder.

    Args:
        bbox (list or tuple): [imgnumber, row, col, height, width]
            - imgnumber: The image number (e.g., 5 for 'img5')
            - row, col: Top-left coordinates of the bounding box in global coordinates
            - height, width: Size of the bounding box
        center_crop (bool): If True, return a 512x512 crop centered on the bounding box.
                            If False, return the full 768x768 (3x3 tiles) composite.
        root (str): Directory that contains the ``dataset`` folder.

    Returns:
        PIL.Image: The assembled image with bounding box drawn.
    """
    # Unpack the bounding box
    imgnumber, bbox_row, bbox_col, bbox_height, bbox_width = bbox

    TILE_SIZE = 256
    COMPOSITE_SIZE = TILE_SIZE * 3  # 768x768
    CROP_SIZE = 512
    HALF_CROP = CROP_SIZE // 2

    # Directory with tiles. The image number is coerced to int (as the tile
    # coordinates are below) so a row taken from a float array gives "img5",
    # not "img5.0".
    imgname = f"img{int(imgnumber)}"
    tiles_dir = Path(root) / f"dataset/{imgname}/tiles"

    # Determine the tile grid start
    # We find the tile that contains the top-left corner of the bbox
    tile_row_start = (bbox_row // TILE_SIZE) * TILE_SIZE
    tile_col_start = (bbox_col // TILE_SIZE) * TILE_SIZE

    # Create a composite image (3x3 tiles)
    composite_image = Image.new(
        "RGB", (COMPOSITE_SIZE, COMPOSITE_SIZE), (255, 255, 255)
    )

    # Load surrounding 3x3 tiles
    offsets = (-TILE_SIZE, 0, TILE_SIZE)
    for i, drow in enumerate(offsets):
        for j, dcol in enumerate(offsets):
            tile_row = tile_row_start + drow
            tile_col = tile_col_start + dcol
            tile_path = tiles_dir / f"tile_{int(tile_row)}_{int(tile_col)}.jpeg"

            tile_image = None
            if tile_path.exists():
                try:
                    # convert() fully decodes the pixels, so the file handle
                    # can be closed as soon as the with-block exits.
                    with Image.open(tile_path) as opened:
                        tile_image = opened.convert("RGB")
                except Exception:
                    # Best effort: an unreadable tile becomes a placeholder.
                    tile_image = None

            if tile_image is None:
                tile_image = _placeholder_tile(TILE_SIZE)
            composite_image.paste(tile_image, (j * TILE_SIZE, i * TILE_SIZE))

    # Draw the bounding box on the composite image
    # The composite image's center tile corresponds to (tile_row_start, tile_col_start)
    # in global coords, so its top-left tile starts one tile up and to the left.
    composite_top_row = tile_row_start - TILE_SIZE
    composite_left_col = tile_col_start - TILE_SIZE

    bbox_row_rel = bbox_row - composite_top_row
    bbox_col_rel = bbox_col - composite_left_col
    bbox_bottom_rel = bbox_row_rel + bbox_height
    bbox_right_rel = bbox_col_rel + bbox_width

    draw = ImageDraw.Draw(composite_image)
    draw.rectangle(
        [bbox_col_rel, bbox_row_rel, bbox_right_rel, bbox_bottom_rel],
        outline="green",
        width=3,
    )

    if not center_crop:
        return composite_image

    # We want to produce a 512x512 crop centered on the bbox center
    bbox_center_row = bbox_row_rel + bbox_height / 2
    bbox_center_col = bbox_col_rel + bbox_width / 2

    # The BBox center should map to the center of the crop (256, 256); the
    # window is then shifted back inside the composite when it would overhang.
    # Because the crop is smaller than the composite, the clamped window
    # always fits, so the result is exactly 512x512 and needs no padding.
    max_origin = COMPOSITE_SIZE - CROP_SIZE
    left = min(max(int(bbox_center_col - HALF_CROP), 0), max_origin)
    upper = min(max(int(bbox_center_row - HALF_CROP), 0), max_origin)

    return composite_image.crop(
        (left, upper, left + CROP_SIZE, upper + CROP_SIZE)
    )


import matplotlib.pyplot as plt

# e.g. show annotations
# NOTE(review): pickle can execute arbitrary code while loading; only open
# files that come from a trusted source.
with open('cell_dataset.pkl', 'rb') as f:
    cell_bboxes = pickle.load(f)['bboxes']
# No header row: column 0 is the bounding-box index, column 1 the 0/1 label.
annotations = pd.read_csv("annotations_lymphoplasmocyte.csv", header=None)

print(annotations)
counter = 0
for _, ann in annotations.iterrows():
    # Only rows labelled 1 are displayed
    if ann[1] != 1:
        continue
    idx = ann[0]
    bbox = cell_bboxes[idx]
    img = load_image_with_bbox(bbox, center_crop=True, root='.')
    plt.figure()
    plt.title(f"Index: {idx}")
    plt.imshow(img)
    counter += 1
    # Stops once more than 10 figures exist, i.e. at most 11 are shown.
    if counter > 10:
        break

# certainty = torch.abs(preds-0.5).flatten()
# cell_probs = preds.flatten()
# nocell_probs = 1 - cell_probs
# selected_indices = torch.argsort(-certainty, descending=True).flatten().numpy()

# for i in range(5):
#     idx = selected_indices[i]
#     bbox = bboxes[idx]
#     img = load_image_with_bbox(bbox, center_crop=True, root='.')
#     plt.figure()
#     plt.title(f"Index: {idx}, Prediction: {preds[idx].item()}")
#     plt.imshow(img)