In [1]:
from preprocessing import PreProcessing
import satlaspretrain_models
from training import *
from postprocessing import PostProcessing

  check_for_updates()


In [2]:
weights_manager = satlaspretrain_models.Weights()
satlas_model = weights_manager.get_pretrained_model("Sentinel2_SwinT_SI_MS", fpn=True, 
                                             head=satlaspretrain_models.Head.SEGMENT, 
                                                num_categories=5, device= "cpu")
device = torch.device("cpu")
satlas_model = satlas_model.to(device)

In [3]:
def load_model(model, save_path, model_name):
    # 1. Initialize model architecture
    
    # 2. Load the saved weights
    weights_path = save_path + model_name
    state = torch.load(weights_path, map_location=torch.device('cpu'))

    model.load_state_dict(state['model_state_dict'])
    #model.load_state_dict(state)
    # 3. Set to evaluation mode
    model.eval()
    
    return model

In [4]:
loaded_model = load_model(satlas_model, '/Users/bragehs/Documents/weights/', "best_combined_loss.pth")

In [5]:
print(loaded_model)

Model(
  (backbone): SwinBackbone(
    (backbone): SwinTransformer(
      (features): Sequential(
        (0): Sequential(
          (0): Conv2d(9, 96, kernel_size=(4, 4), stride=(4, 4))
          (1): Permute()
          (2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
        )
        (1): Sequential(
          (0): SwinTransformerBlockV2(
            (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
            (attn): ShiftedWindowAttentionV2(
              (qkv): Linear(in_features=96, out_features=288, bias=True)
              (proj): Linear(in_features=96, out_features=96, bias=True)
              (cpb_mlp): Sequential(
                (0): Linear(in_features=2, out_features=512, bias=True)
                (1): ReLU(inplace=True)
                (2): Linear(in_features=512, out_features=3, bias=False)
              )
            )
            (stochastic_depth): StochasticDepth(p=0.0, mode=row)
            (norm2): LayerNorm((96,), eps=1e-05, elementwise_af

In [6]:
test_data = PreProcessing(train_set=False)
test_data.preprocess()

Processing evaluation_0.tif
Processing evaluation_1.tif
Processing evaluation_2.tif
Processing evaluation_3.tif
Processing evaluation_4.tif
Processing evaluation_5.tif
Processing evaluation_6.tif
Processing evaluation_7.tif
Processing evaluation_8.tif
Processing evaluation_9.tif
Processing evaluation_10.tif
Processing evaluation_11.tif
Processing evaluation_12.tif
Processing evaluation_13.tif
Processing evaluation_14.tif
Processing evaluation_15.tif
Processing evaluation_16.tif
Processing evaluation_17.tif
Processing evaluation_18.tif
Processing evaluation_19.tif
Processing evaluation_20.tif
Processing evaluation_21.tif
Processing evaluation_22.tif
Processing evaluation_23.tif
Processing evaluation_24.tif
Processing evaluation_25.tif
Processing evaluation_26.tif
Processing evaluation_27.tif
Processing evaluation_28.tif
Processing evaluation_29.tif
Processing evaluation_30.tif
Processing evaluation_31.tif
Processing evaluation_32.tif
Processing evaluation_33.tif
Processing evaluation_34

In [7]:
test_data = test_data.prepared_data
test_data.shape

(118, 9, 1024, 1024)

In [8]:
post = PostProcessing(loaded_model, test_data[:2])

In [None]:
#post_opt = post.post_process_optimized(patch_size=64)
#polygons = post.converter(post_opt)

In [21]:
class TestDataset(Dataset):
    """Dataset class for test data without labels"""
    def __init__(
        self,
        images: np.ndarray,  # Shape: (N, C, H, W)
        patch_size: int = 256,
        stride: int = 256
    ):
        self.images = images
        self.patch_size = patch_size
        self.stride = stride
        self.patches, self.positions = self._create_patches()

    def _create_patches(self):
        """Create patches and store their original positions"""
        patches = []
        positions = []  # Store (image_idx, y, x) for each patch
        N, C, H, W = self.images.shape

        for img_idx in range(N):
            image = self.images[img_idx]
            for y in range(0, H - self.patch_size + 1, self.stride):
                for x in range(0, W - self.patch_size + 1, self.stride):
                    img_patch = image[:, y:y + self.patch_size, x:x + self.patch_size]
                    patches.append(img_patch)
                    positions.append((img_idx, y, x))

        return patches, positions

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        return torch.as_tensor(self.patches[idx], dtype=torch.float32)

In [70]:
test_dataset = TestDataset(
    images=test_data[:2],
    patch_size=128,
    stride=128
)
test_dataloader = DataLoader(
                            test_dataset, 
                            batch_size=64, 
                            shuffle=False, 
                            num_workers=0,
    )

In [118]:
from time import sleep

In [147]:
def predict_probs(model, test_dataloader):
    model.eval()
    all_probs = []
    with torch.no_grad():
        for data in test_dataloader:
            data = data.to(device)
            output = model(data)[0]  # Shape: (batch_size, 5, 128, 128)
            probs = torch.nn.functional.softmax(output, dim=1)
            # Permute to (batch_size, 128, 128, 5) and collect
            probs = probs.permute(0, 2, 3, 1).cpu().numpy()
            all_probs.append(probs)
    # Concatenate all batches into (num_patches, 128, 128, 5)
    return np.concatenate(all_probs, axis=0)

In [148]:
preds = predict_probs(loaded_model, test_dataloader)

In [164]:
def stitch_patches(prob_patches, positions, image_shape=(1024, 1024), patch_size=128):
    num_classes = prob_patches.shape[-1]
    stitched_images = []
    # Extract unique image indices from positions
    image_indices = sorted(set(pos[0] for pos in positions))
    for img_idx in image_indices:
        full_image = np.zeros((image_shape[0], image_shape[1], num_classes))
        # Iterate through all patches and place them if they belong to the current image
        for i, (curr_idx, y, x) in enumerate(positions):
            if curr_idx == img_idx:
                full_image[y:y+patch_size, x:x+patch_size, :] = prob_patches[i]
        stitched_images.append(full_image)
    return [torch.tensor(img) for img in stitched_images]


def post_process_torch(outputs, gamma=0.5):
    all_images = []
    for img in outputs:  # Each img is a tensor of shape (1024, 1024, 5)
        # Pad the image to handle borders
        padded = torch.nn.functional.pad(img, (0, 0, 1, 1, 1, 1))  # Pad H and W by 1
        
        # Compute contributions from 4-directional neighbors
        top = padded[:-2, 1:-1, :]    # Shape: (1024, 1024, 5)
        bottom = padded[2:, 1:-1, :]
        left = padded[1:-1, :-2, :]
        right = padded[1:-1, 2:, :]
        
        # Combine neighbor contributions and add to original probabilities
        neighbors = gamma * (top + bottom + left + right)
        combined_probs = img + neighbors
        
        # Get final predictions
        class_predictions = torch.argmax(combined_probs, dim=2)
        all_images.append(class_predictions)
    
    return all_images  # List of (1024, 1024) tensors with class indices

In [156]:
positions = test_dataloader.dataset.positions  # Get positions from dataset
stitched_tensors = stitch_patches(preds, positions)

In [163]:
stitched_tensors[1][0][1023]

tensor([0.2092, 0.1700, 0.2911, 0.1693, 0.1603], dtype=torch.float64)

In [165]:
post_processed = post_process_torch(stitched_tensors)

In [175]:
import cv2

In [190]:
def converter(tensors):
    """
    Convert multiple image tensors to polygons.
    
    Args:
        tensors: List of tensors or single tensor. Each tensor should be (1024, 1024) 
                containing class labels 0-4 (either numpy array or torch.Tensor)
    
    Returns:
        list[dict]: List of dictionaries, one per image, where each dictionary
                contains polygons for each class {0: [...], 1: [...], ...}
    """
    # Handle single tensor case and convert to numpy
    if isinstance(tensors, (np.ndarray, torch.Tensor)):
        tensors = [tensors]
    
    all_image_polygons = []
    
    for idx, tensor in enumerate(tensors):
        print("Converting tensor to polygons for image", idx)
        image_polygons = {}
        
        # Convert to numpy and ensure proper 2D format
        if isinstance(tensor, torch.Tensor):
            # Handle both CPU and GPU tensors
            tensor_np = tensor.cpu().detach().numpy().squeeze().astype(np.uint8)
            print(tensor_np.shape)
        else:
            tensor_np = tensor.squeeze().astype(np.uint8)
        
        # Critical validation
        if tensor_np.ndim != 2:
            raise ValueError(f"Input tensor must be 2D after squeezing. Got {tensor_np.shape}")
        
        for class_id in range(5):
            # Create binary mask
            mask = (tensor_np == class_id).astype(np.uint8)
            
            # Verify OpenCV requirements
            if not isinstance(mask, np.ndarray):
                raise TypeError(f"Mask must be numpy array, got {type(mask)}")
            if mask.dtype != np.uint8:
                mask = mask.astype(np.uint8)
            
            # Find contours
            contours, _ = cv2.findContours(
                mask, 
                cv2.RETR_EXTERNAL, 
                cv2.CHAIN_APPROX_SIMPLE
            )
            
            # Convert and filter contours
            image_polygons[class_id] = [
                contour.squeeze().tolist() 
                for contour in contours 
                if contour.shape[0] >= 3  # Minimum 3 points for polygon
            ]
        
        all_image_polygons.append(image_polygons)
    
    return all_image_polygons

In [191]:
polygons = converter(post_processed)

Converting tensor to polygons for image 0
(1024, 1024)


TypeError: Mask must be numpy array, got <class 'numpy.ndarray'>