In [23]:
# Configuration
import os
import cv2


# --- Input data -------------------------------------------------------------
# Built with os.path.join so the notebook also runs on non-Windows hosts
# (a literal "Datasets\Padded" is a single file name on POSIX systems).
IMAGE_DIR = os.path.join("Datasets", "Padded")  # Directory where 4500x4500 images are stored
CSV_PATH = "eccentricity_data.csv"  # CSV file with bounding boxes

# --- Tiling parameters ------------------------------------------------------
PATCH_SIZE = 1024  # Size of tiles to extract from the full image
STRIDE = 1024      # Stride for sliding window tiling

# --- Outputs ----------------------------------------------------------------
# Output directory to store tiles and processed annotations
PATCH_OUTPUT_DIR = "tiles/"
ANNOTATION_OUTPUT_PATH = "tile_annotations.csv"

# Create the tile directory up front; exist_ok keeps the cell re-runnable.
os.makedirs(PATCH_OUTPUT_DIR, exist_ok=True)


In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from PIL import Image
import numpy as np
import torchvision.transforms as transforms

class StreakDataset(Dataset):
    """Sliding-window patch dataset for star / streak detection.

    Every image referenced in the annotation CSV is loaded as grayscale and
    cut into ``patch_size`` x ``patch_size`` windows at ``stride`` spacing.
    All patches are extracted eagerly in ``__init__`` and kept in memory.

    Each item is ``(patch, target)`` where ``target`` is a dict with:

    * ``'boxes'``  - float32 tensor of shape ``(N, 4)`` holding
      ``[x1, y1, x2, y2]`` in patch-local pixels, clipped to the patch.
      ``N`` may be 0, in which case the shape is still ``(0, 4)``.
    * ``'labels'`` - int64 tensor of shape ``(N,)``; ``1`` for rows whose
      ``object_type`` is ``'star'`` and ``2`` for everything else. ``0`` is
      left free for the detector's background class.

    The CSV is expected to provide the columns ``image``, ``bbox_x``,
    ``bbox_y``, ``bbox_width``, ``bbox_height`` and ``object_type``.
    """

    # Class ids start at 1 because id 0 is reserved for background by the
    # detection head (the model is built with num_classes=3).
    STAR_LABEL = 1
    STREAK_LABEL = 2

    def __init__(self, csv_file, image_dir, transform=None, patch_size=256, stride=128):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            image_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform applied to the
                patch only. It must not move pixels around (flip/rotate),
                because the boxes are not transformed with it.
            patch_size (int): Size of patches to extract from large image
            stride (int): Stride for patch extraction

        Raises:
            ValueError: if an image listed in the CSV cannot be read.
        """
        self.annotations = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.transform = transform
        self.patch_size = patch_size
        self.stride = stride

        # Group annotations by image
        self.image_groups = self.annotations.groupby('image')

        # Precompute all possible patches for all images.
        # `patch_image_names[i]` records which image `patches[i]` was cut
        # from, so __getitem__ can look up the matching annotations.
        self.patches = []
        self.patch_image_names = []
        for img_name, _ in self.image_groups:
            img_path = os.path.join(self.image_dir, str(img_name))
            img_patches = self._extract_patches(img_path)
            self.patches.extend(img_patches)
            self.patch_image_names.extend([img_name] * len(img_patches))

    def _extract_patches(self, img_path):
        """Cut one image into patches.

        Returns:
            list of ``(PIL.Image, (x, y))`` where ``(x, y)`` is the top-left
            corner of the patch in full-image pixel coordinates. Windows that
            would extend past the image border are not produced.

        Raises:
            ValueError: if OpenCV cannot load ``img_path``.
        """
        patches = []
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)  # Load as grayscale

        if img is None:
            raise ValueError(f"Failed to load image at {img_path}")

        height, width = img.shape  # OpenCV uses (height, width) format

        for y in range(0, height - self.patch_size + 1, self.stride):
            for x in range(0, width - self.patch_size + 1, self.stride):
                patch = img[y:y+self.patch_size, x:x+self.patch_size]
                patch = Image.fromarray(patch)  # Convert back to PIL for compatibility
                patches.append((patch, (x, y)))

        return patches

    def __len__(self):
        """Total number of patches over all images."""
        return len(self.patches)

    def __getitem__(self, idx):
        """Return ``(patch, target)`` for patch ``idx`` (see class docstring)."""
        patch, (x_offset, y_offset) = self.patches[idx]
        img_name = self.patch_image_names[idx]

        if self.transform:
            patch = self.transform(patch)

        # Get annotations for the image this patch came from
        img_annotations = self.image_groups.get_group(img_name)
        boxes = []
        labels = []

        for _, row in img_annotations.iterrows():
            x, y, w, h = row['bbox_x'], row['bbox_y'], row['bbox_width'], row['bbox_height']

            # Shift into patch coordinates and clip to the patch extent.
            x1 = max(x - x_offset, 0)
            y1 = max(y - y_offset, 0)
            x2 = min(x + w - x_offset, self.patch_size)
            y2 = min(y + h - y_offset, self.patch_size)

            # Nothing left after clipping -> the box does not overlap this
            # patch (degenerate boxes are rejected by the detector anyway).
            if x2 <= x1 or y2 <= y1:
                continue

            boxes.append([float(x1), float(y1), float(x2), float(y2)])
            labels.append(self.STAR_LABEL if row['object_type'] == 'star' else self.STREAK_LABEL)

        # reshape(-1, 4) keeps the (0, 4) shape for patches without objects.
        target = {
            'boxes': torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4),
            'labels': torch.tensor(labels, dtype=torch.int64)
        }

        return patch, target

In [27]:
import torchvision
from torchvision.models.detection import FasterRCNN
from torchvision.models.detection.rpn import AnchorGenerator
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torch import nn
import math

class StreakDetectionModel(nn.Module):
    """Faster R-CNN detector for stars and streaks with a direction head.

    The wrapped ``FasterRCNN`` (``self.model``) uses a ResNet-50 trunk whose
    first convolution is converted to accept single-channel (grayscale)
    images. In inference mode an extra ``'directions'`` entry is added to
    every per-image output dict, holding one value per detected box.

    NOTE: ``direction_predictor`` does not contribute to any loss returned in
    training mode, so its outputs stay at their initial (untrained) values
    unless it is trained separately.

    Args:
        backbone_name (str): Only ``'resnet50'`` is supported.
        pretrained (bool): Load ImageNet weights for the backbone.

    Raises:
        ValueError: for an unsupported ``backbone_name``.
    """

    def __init__(self, backbone_name='resnet50', pretrained=True):
        super().__init__()

        # Load backbone
        if backbone_name == 'resnet50':
            resnet = torchvision.models.resnet50(pretrained=pretrained)

            # The dataset yields 1-channel tensors, but the stock stem expects
            # RGB. Summing the RGB kernels gives the same response the
            # original stem would have on a gray image replicated three times.
            rgb_stem = resnet.conv1
            gray_stem = nn.Conv2d(
                1,
                rgb_stem.out_channels,
                kernel_size=rgb_stem.kernel_size,
                stride=rgb_stem.stride,
                padding=rgb_stem.padding,
                bias=False
            )
            gray_stem.weight = nn.Parameter(
                rgb_stem.weight.detach().sum(dim=1, keepdim=True)
            )
            resnet.conv1 = gray_stem

            # Remove the average pool and fully connected layers
            backbone = nn.Sequential(*list(resnet.children())[:-2])
            backbone.out_channels = 2048
        else:
            raise ValueError(f"Unsupported backbone: {backbone_name}")

        # Custom anchor sizes and aspect ratios optimized for stars/streaks
        # Stars are small and square, streaks are elongated.
        # The backbone returns a single feature map, so the generator needs
        # exactly ONE tuple of sizes and ONE tuple of aspect ratios.
        anchor_sizes = ((8, 16, 32, 64, 128),)
        aspect_ratios = ((0.5, 1.0, 2.0),)

        anchor_generator = AnchorGenerator(
            sizes=anchor_sizes,
            aspect_ratios=aspect_ratios
        )

        # ROI pooling ('0' is the key FasterRCNN assigns to a lone feature map)
        roi_pooler = torchvision.ops.MultiScaleRoIAlign(
            featmap_names=['0'],
            output_size=7,
            sampling_ratio=2
        )

        # Box head
        box_head = self._create_box_head(backbone.out_channels)

        # Faster R-CNN model. The box predictor is created internally from
        # num_classes. FasterRCNN has no `transform` argument, so the
        # normalisation / resize settings are passed explicitly and the
        # custom transform is installed afterwards.
        self.model = FasterRCNN(
            backbone,
            num_classes=3,  # background, star, streak
            min_size=512,
            max_size=512,
            image_mean=[0.5],
            image_std=[0.5],
            rpn_anchor_generator=anchor_generator,
            box_roi_pool=roi_pooler,
            box_head=box_head
        )
        self.model.transform = CustomTransform()

        # Add directional prediction for streaks
        self.direction_predictor = nn.Sequential(
            nn.Linear(backbone.out_channels * 7 * 7, 256),
            nn.ReLU(),
            nn.Linear(256, 1)  # Predicts angle in radians
        )

    def _create_box_head(self, in_channels):
        """Build the two-layer MLP applied to each pooled 7x7 RoI feature.

        The leading ``Flatten`` turns the ``(N, C, 7, 7)`` RoI-align output
        into ``(N, C * 49)`` before the first linear layer.
        """
        return nn.Sequential(
            nn.Flatten(start_dim=1),
            nn.Linear(in_channels * 7 * 7, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Dropout(0.5)
        )

    def forward(self, images, targets=None):
        """Run detection.

        Args:
            images: list of ``(1, H, W)`` tensors, or a ``(B, 1, H, W)`` batch.
            targets: list of target dicts; required in training mode.

        Returns:
            Training mode: the loss dict produced by ``FasterRCNN``.
            Eval mode: a list with one dict per image containing the
            ``FasterRCNN`` outputs plus ``'directions'`` of shape ``(N, 1)``.

        Raises:
            ValueError: in training mode when ``targets`` is ``None``.
        """
        if self.training and targets is None:
            raise ValueError("In training mode, targets should be passed")

        # Standard Faster R-CNN forward pass
        outputs = self.model(images, targets)

        if not self.training:
            # Add direction prediction in inference mode, one image at a time
            # so every output is paired with its own image.
            for img, output in zip(images, outputs):
                boxes = output.get('boxes')
                if boxes is None:
                    continue
                if boxes.shape[0] == 0:
                    output['directions'] = boxes.new_zeros((0, 1))
                    continue

                # Re-apply the detector's own normalisation / resize so the
                # features match what the detection heads saw.
                image_list, _ = self.model.transform([img])
                feature_map = self.model.backbone(image_list.tensors)
                resized_size = image_list.image_sizes[0]

                # Detections are reported in input-image pixels; RoI align
                # needs them in the resized image's coordinate frame.
                orig_h, orig_w = img.shape[-2:]
                new_h, new_w = resized_size
                scale_x = new_w / orig_w
                scale_y = new_h / orig_h
                scaled_boxes = boxes * boxes.new_tensor(
                    [scale_x, scale_y, scale_x, scale_y]
                )

                # Get ROI features
                features = self.model.roi_heads.box_roi_pool(
                    {'0': feature_map},
                    [scaled_boxes],
                    [resized_size]
                )
                features = features.flatten(1)

                # Predict direction
                output['directions'] = self.direction_predictor(features)

        return outputs


class CustomTransform(GeneralizedRCNNTransform):
    """Detection transform configured for single-channel 512-pixel inputs.

    Images are normalised with mean 0.5 / std 0.5 on their single channel and
    resized with ``min_size == max_size == 512``.
    """

    def __init__(self):
        super().__init__(
            min_size=512,  # Minimum size of the image
            max_size=512,   # Maximum size of the image
            image_mean=[0.5],  # Grayscale mean
            image_std=[0.5]    # Grayscale std
        )

    def postprocess(self, result, image_shapes, original_image_sizes):
        """Map detections back to the original image size.

        Accepts either the usual list of per-image result dicts or a single
        result dict. A single dict is wrapped, processed by the parent
        implementation and unwrapped again, so its boxes are rescaled like
        any other result and every extra key (for example ``'directions'``)
        is carried through unchanged. ``image_shapes`` and
        ``original_image_sizes`` must then contain exactly one entry.
        """
        if isinstance(result, dict):
            (processed,) = super().postprocess(
                [result], image_shapes, original_image_sizes
            )
            return processed
        return super().postprocess(result, image_shapes, original_image_sizes)

In [32]:
def collate_fn(batch):
    """Transpose a batch of samples into per-field tuples.

    ``[(img0, tgt0), (img1, tgt1)]`` becomes ``((img0, img1), (tgt0, tgt1))``.
    Nothing is stacked, so samples may differ in size.
    """
    return (*zip(*batch),)

def get_transform(train):
    """Build the image transform handed to the dataset.

    Args:
        train (bool): Kept for API compatibility; the same pipeline is
            currently returned for training and evaluation.

    Returns:
        A ``transforms.Compose`` that converts the patch to a float tensor.
    """
    transform_list = [transforms.ToTensor()]  # Start with converting to tensor

    # No random flips / rotations here: these transforms only receive the
    # image, so any geometric change would leave the bounding boxes pointing
    # at the wrong pixels. Geometric augmentation has to be implemented where
    # image and boxes can be transformed together.

    # Combine all transforms
    return transforms.Compose(transform_list)

def train_model(num_epochs=10, batch_size=4, num_workers=0, learning_rate=0.005,
                checkpoint_path='streak_detection_model.pth'):
    """Train the streak detector on patches and save its weights.

    Reads the dataset location from the module-level ``CSV_PATH`` and
    ``IMAGE_DIR`` constants.

    Args:
        num_epochs (int): Number of passes over the dataset.
        batch_size (int): Patches per optimisation step.
        num_workers (int): DataLoader worker processes. Defaults to 0 because
            the dataset class and collate function live in the notebook
            itself, which worker processes started with the "spawn" method
            cannot import.
        learning_rate (float): Initial SGD learning rate.
        checkpoint_path (str): Where the final ``state_dict`` is written.
    """
    # Initialize dataset and dataloader
    dataset = StreakDataset(
        csv_file=CSV_PATH,
        image_dir=IMAGE_DIR,
        transform=get_transform(train=True),
        patch_size=256,
        stride=128
    )

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        collate_fn=collate_fn
    )

    # Initialize model
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = StreakDetectionModel(backbone_name='resnet50', pretrained=True)
    model.to(device)

    # Optimizer and learning rate scheduler
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=learning_rate, momentum=0.9, weight_decay=0.0005)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    # Loss function (included in Faster R-CNN)

    # Training loop
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        num_batches = 0

        for images, targets in data_loader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

            # In training mode the model returns a dict of loss terms.
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())

            optimizer.zero_grad()
            losses.backward()
            optimizer.step()

            running_loss += losses.item()
            num_batches += 1

        lr_scheduler.step()

        # One line per epoch so long runs show progress.
        print(f"Epoch {epoch + 1}/{num_epochs} - mean loss: {running_loss / max(num_batches, 1):.4f}")

        # Evaluation (implement as needed)
        # evaluate(model, validation_loader, device)

    # Save model
    torch.save(model.state_dict(), checkpoint_path)

In [33]:
def _window_origins(length, patch_size, stride):
    """Start offsets of sliding windows that together cover ``[0, length)``."""
    if length <= patch_size:
        return [0]
    origins = list(range(0, length - patch_size + 1, stride))
    last_origin = length - patch_size
    if origins[-1] != last_origin:
        # Extra window flush with the far edge so the margin is not skipped.
        origins.append(last_origin)
    return origins


def detect_streaks_full_image(model, image_path, patch_size=512, stride=256, device='cuda'):
    """Run the detector over a large image using overlapping patches.

    The image is loaded as grayscale, scanned with a sliding window, and the
    per-patch detections are shifted into full-image coordinates and merged
    with non-maximum suppression (IoU threshold 0.5, applied across classes).

    Args:
        model: Detector returning, per image, a dict with ``'boxes'``,
            ``'labels'``, ``'scores'`` and optionally ``'directions'``.
            It is moved to ``device`` and switched to eval mode.
        image_path: Path of the image to analyse.
        patch_size (int): Side length of the square window.
        stride (int): Step between windows. An additional window is placed
            flush with the right / bottom edge when the regular grid would
            leave a margin unscanned.
        device: Torch device (or device string). A CUDA request falls back to
            the CPU when CUDA is not available.

    Returns:
        ``None`` when nothing was detected, otherwise a dict with the lists
        ``'boxes'`` (arrays ``[x1, y1, x2, y2]``), ``'labels'``, ``'scores'``
        and ``'directions'``. ``'directions'`` is ``None`` unless every patch
        with detections supplied it.
    """
    device = torch.device(device)
    if device.type == 'cuda' and not torch.cuda.is_available():
        device = torch.device('cpu')

    # Inference needs eval mode (training mode requires targets) and the
    # model on the same device as the patches.
    model.to(device)
    model.eval()

    # Load image
    with Image.open(image_path) as source:
        img = source.convert('L')
    width, height = img.size

    to_tensor = transforms.ToTensor()

    # Per-patch results, concatenated once at the end
    box_chunks = []
    label_chunks = []
    score_chunks = []
    direction_chunks = []
    directions_complete = True

    # Process image in patches
    for y in _window_origins(height, patch_size, stride):
        for x in _window_origins(width, patch_size, stride):
            patch = img.crop((x, y, x + patch_size, y + patch_size))
            patch_tensor = to_tensor(patch).unsqueeze(0).to(device)

            # Inference
            with torch.no_grad():
                output = model(patch_tensor)[0]

            boxes = output['boxes'].detach().cpu().clone()
            if boxes.shape[0] == 0:
                continue

            # Convert boxes to original image coordinates
            boxes[:, [0, 2]] += x
            boxes[:, [1, 3]] += y

            box_chunks.append(boxes)
            label_chunks.append(output['labels'].detach().cpu())
            score_chunks.append(output['scores'].detach().cpu())

            if 'directions' in output:
                direction_chunks.append(output['directions'].detach().cpu())
            else:
                # A gap would shift every later direction against its box.
                directions_complete = False

    if not box_chunks:
        return None

    boxes = torch.cat(box_chunks)
    labels = torch.cat(label_chunks)
    scores = torch.cat(score_chunks)

    # Apply non-maximum suppression to remove overlapping boxes
    keep_indices = torchvision.ops.nms(boxes, scores, iou_threshold=0.5)

    final_directions = None
    if directions_complete and direction_chunks:
        final_directions = list(torch.cat(direction_chunks)[keep_indices].numpy())

    return {
        'boxes': list(boxes[keep_indices].numpy()),
        'labels': list(labels[keep_indices].numpy()),
        'scores': list(scores[keep_indices].numpy()),
        'directions': final_directions
    }

In [34]:
# Run the full training pipeline: builds the patch dataset from CSV_PATH /
# IMAGE_DIR, trains the detector and writes the weights to disk.
train_model()

AttributeError: 'NoneType' object has no attribute 'size'

In [16]:
def tile_image_and_annotations(image_path, annotations_df, patch_size=1024, stride=1024,
                               output_dir=None):
    """Cut one image into tiles and remap its bounding boxes per tile.

    Only full-size tiles are considered (windows reaching past the image
    border are skipped), and a tile is written to disk only when at least one
    annotation overlaps it. Boxes are clipped to the tile.

    Args:
        image_path: Path of the source image.
        annotations_df: DataFrame with the columns ``image``, ``bbox_x``,
            ``bbox_y``, ``bbox_width``, ``bbox_height`` and ``object_type``.
            Rows are matched on the image's base file name.
        patch_size (int): Side length of the square tiles.
        stride (int): Step between tile origins.
        output_dir: Directory for the tile PNGs; defaults to the module-level
            ``PATCH_OUTPUT_DIR``.

    Returns:
        list of dicts, one per (tile, box) pair, with the keys ``tile_name``,
        ``bbox_x``, ``bbox_y``, ``bbox_width``, ``bbox_height`` and
        ``object_type`` in tile-local pixels.

    Raises:
        ValueError: if the image cannot be read.
        IOError: if a tile cannot be written.
    """
    if output_dir is None:
        output_dir = PATCH_OUTPUT_DIR

    image_name = os.path.basename(image_path)
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f"Failed to load image at {image_path}")
    h, w = img.shape[:2]

    image_annotations = annotations_df[annotations_df['image'] == image_name]

    # Corner coordinates do not depend on the tile, so compute them once.
    full_image_boxes = []
    for _, row in image_annotations.iterrows():
        xmin, ymin = row['bbox_x'], row['bbox_y']
        full_image_boxes.append(
            (xmin, ymin, xmin + row['bbox_width'], ymin + row['bbox_height'], row["object_type"])
        )

    tile_id = 0
    new_annotations = []

    for y in range(0, h, stride):
        for x in range(0, w, stride):
            tile = img[y:y+patch_size, x:x+patch_size]
            if tile.shape[0] < patch_size or tile.shape[1] < patch_size:
                continue  # skip edge tiles

            # tile_id only advances when a tile is saved, so this name is the
            # one the tile will get if it has any boxes.
            tile_filename = f"{image_name}_tile_{tile_id}.png"

            tile_boxes = []
            for xmin, ymin, xmax, ymax, object_type in full_image_boxes:
                # Convert to local coordinates, clipped to the tile
                new_xmin = max(xmin - x, 0)
                new_ymin = max(ymin - y, 0)
                new_xmax = min(xmax - x, patch_size)
                new_ymax = min(ymax - y, patch_size)

                # Non-positive extent means the box does not overlap the tile.
                box_width = new_xmax - new_xmin
                box_height = new_ymax - new_ymin
                if box_width > 0 and box_height > 0:
                    tile_boxes.append({
                        "tile_name": tile_filename,
                        "bbox_x": new_xmin,
                        "bbox_y": new_ymin,
                        "bbox_width": box_width,
                        "bbox_height": box_height,
                        "object_type": object_type
                    })

            if tile_boxes:
                tile_path = os.path.join(output_dir, tile_filename)
                if not cv2.imwrite(tile_path, tile):
                    # Without the file the annotations would point nowhere.
                    raise IOError(f"Failed to write tile to {tile_path}")
                new_annotations.extend(tile_boxes)
                tile_id += 1

    return new_annotations


In [17]:
# Load annotations
annotations_df = pd.read_csv(CSV_PATH)

# Column layout of the tile-level CSV. Passing it explicitly keeps the header
# in the file even when no tile produced an annotation.
TILE_ANNOTATION_COLUMNS = [
    "tile_name", "bbox_x", "bbox_y", "bbox_width", "bbox_height", "object_type"
]

# For all images in the directory (sorted: os.listdir order is arbitrary)
all_new_annotations = []
for image_file in sorted(os.listdir(IMAGE_DIR)):
    if image_file.lower().endswith(('.png', '.jpg', '.jpeg', '.tif', '.tiff')):
        image_path = os.path.join(IMAGE_DIR, image_file)
        print(f"Tiling {image_file}...")
        new_anns = tile_image_and_annotations(image_path, annotations_df)
        all_new_annotations.extend(new_anns)

# Save new tile-level annotations
tile_df = pd.DataFrame(all_new_annotations, columns=TILE_ANNOTATION_COLUMNS)
tile_df.to_csv(ANNOTATION_OUTPUT_PATH, index=False)
print(f"Saved {len(tile_df)} annotations for tiled patches.")


Tiling Raw_Observation_009_Set1.tiff...
Saved 31 annotations for tiled patches.
