In [None]:
import os
import torch
import numpy as np
from skimage.io import imread


In [None]:
# Input root: expected to contain one sub-folder per label, each holding .tif stacks.
# Output folder: receives one .pt tensor file per stack.
input_dir = "G:/Shared drives/Posner Group Current/Cole's Files/ANSA/RPA on glass slides/100_serial/processed"
output_dir = "G:/Shared drives/Posner Group Current/Cole's Files/ANSA/RPA on glass slides/100_serial/torch_tensors"

# Create the destination up front; an already-existing folder is reused without error.
os.makedirs(name=output_dir, exist_ok=True)

# Function to get all image paths and their labels
# Function to get all image paths and their labels
def get_image_paths_and_labels(root_dir, extensions=(".tif",)):
    """Collect image paths under ``root_dir`` together with per-folder labels.

    Each immediate sub-directory of ``root_dir`` is treated as a label; entries
    directly inside it whose names end with one of ``extensions``
    (case-insensitive) are collected. Non-directory entries of ``root_dir``
    are skipped.

    Both directory levels are visited in sorted (plain lexicographic) order so
    the numbering is reproducible -- ``os.listdir`` returns entries in
    arbitrary order. The counter advances only for matching files, so labels
    run ``<label>_1``, ``<label>_2``, ... with no gaps even when other files
    share the folder.

    Args:
        root_dir: Directory whose sub-directories hold the images.
        extensions: Filename suffix or tuple of suffixes to accept, compared
            case-insensitively. Defaults to ``(".tif",)``.

    Returns:
        A list of ``(image_path, "<label>_<n>")`` tuples.
    """
    if isinstance(extensions, str):
        # A bare string would otherwise be iterated character by character.
        extensions = (extensions,)
    suffixes = tuple(ext.lower() for ext in extensions)

    image_data = []
    for label in sorted(os.listdir(root_dir)):
        label_path = os.path.join(root_dir, label)
        if not os.path.isdir(label_path):
            continue
        # Filter first, then number, so skipped files do not consume an index.
        image_names = [
            name for name in sorted(os.listdir(label_path))
            if name.lower().endswith(suffixes)
        ]
        for idx, image_name in enumerate(image_names, start=1):
            image_path = os.path.join(label_path, image_name)
            image_data.append((image_path, f"{label}_{idx}"))
    return image_data

def _load_stack(img_path):
    """Read an image file and return it as a 3-D ``(frames, height, width)`` array.

    A 2-D result from ``imread`` (for example a single-frame TIFF) is promoted
    to a one-frame stack so it can be padded like the others. Any other rank
    (for example an array with an extra channel axis) is rejected with an
    error naming the file, instead of a bare "values to unpack" error.
    """
    stack = imread(img_path)
    if stack.ndim == 2:
        stack = stack[np.newaxis, ...]
    if stack.ndim != 3:
        raise ValueError(
            f"{img_path}: expected a 2-D image or 3-D stack, got shape {stack.shape}"
        )
    return stack


# Get all image paths and labels
image_data = get_image_paths_and_labels(input_dir)
if not image_data:
    print(f"No .tif files found under {input_dir}")

# Find the largest image dimensions across all time series stacks.
# NOTE: imread decodes each whole stack here just to read its shape; the
# stacks are decoded again in the save loop below.
max_frames, max_height, max_width = 0, 0, 0
for img_path, _ in image_data:
    image_stack = _load_stack(img_path)
    f, h, w = image_stack.shape
    max_frames = max(max_frames, f)
    max_height = max(max_height, h)
    max_width = max(max_width, w)

print(f"Largest dimensions found: {max_frames}x{max_height}x{max_width}")

# Pad and save images as .pt files
for img_path, label in image_data:
    image_stack = _load_stack(img_path)  # Read the full stack as (frames, height, width)
    f, h, w = image_stack.shape

    # Calculate padding needed (never negative: the maxima were taken over all stacks)
    pad_frames = max_frames - f
    pad_height = max_height - h
    pad_width = max_width - w

    # Zero-pad at the end of every axis so all stacks share one shape
    padded_image = np.pad(image_stack,
                          ((0, pad_frames), (0, pad_height), (0, pad_width)),
                          mode='constant', constant_values=0)

    # Cast in NumPy before handing the data to torch: torch.tensor() cannot
    # convert uint16/uint32 NumPy arrays on older PyTorch releases, and
    # uint16 is a common TIFF pixel type. unsqueeze(0) adds a leading
    # channel dimension of size 1 for grayscale.
    tensor_image = torch.from_numpy(padded_image.astype(np.float32)).unsqueeze(0)
    save_path = os.path.join(output_dir, f"{label}.pt")
    torch.save(tensor_image, save_path)

    print(f"Saved: {save_path}")

print("Done")
