# U-Net Model - Orbital Response

## Model Initialisation

In [None]:
import os
import torch
from torchvision import transforms
import torchvision.models as models
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision.io import decode_image
from PIL import Image

In [3]:
# Load pre-trained ResNet34 model
resnet34 = models.resnet34(pretrained=True)



In [12]:
# Replacing the first convolutionary layer to accept 6 channels (concatonated pre and post images)
resnet34.conv1 = torch.nn.Conv2d(
    in_channels=6,
    out_channels=64,
    kernel_size=7,
    stride=2,
    padding=3,
    bias=False
)


## Concatonating Images for Model Input

*Concatenate “pre” and “post” 3-channel disaster images from each pair into a single tensor with 6 channels and use it as an input to the U-Net model*

**NOTE: designed to retrieve images from local file**

Here, we are using the Pytorch Dataset class to control the loading of image data (features + labels) and the transformations that are applied (concatonation)

In [None]:
from torch.utils.data import Dataset

In [39]:
#Testing image to tensor
img_path = "../preprocessed_test/images/hurricane-florence_00000000_post_disaster.png"
img = Image.    open(img_path)
to_tensor = transforms.ToTensor()
img_tensor = to_tensor(img)

In [16]:
img_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

mask_transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor()
])

In [None]:
class PrimaryDataset(Dataset):
    def __init__(self, root_dir, transform=None, mask_transform=None):
        self.image_dir = os.path.join(root_dir, "images")
        self.mask_dir = os.path.join(root_dir, "masks")
        self.transform = transform
        self.mask_transform = mask_transform

        # Find all base IDs (e.g. "hurricane-florence_00000000")
        all_files = os.listdir(self.image_dir)
        self.ids = sorted(list(set(
            f.replace("_pre_disaster.png", "").replace("_post_disaster.png", "")
            for f in all_files
            if f.endswith(".png")
        )))

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        image_id = self.ids[idx]

        # File paths
        pre_path = os.path.join(self.image_dir, f"{image_id}_pre_disaster.png")
        post_path = os.path.join(self.image_dir, f"{image_id}_post_disaster.png")
        post_mask_path = os.path.join(self.mask_dir, f"{image_id}_post_disaster_mask.png")

        # Open images
        pre_img = Image.open(pre_path)
        post_img = Image.open(post_path)
        post_mask = Image.open(post_mask_path)

        # Apply image transforms
        if self.transform:
            pre_img = self.transform(pre_img)
            post_img = self.transform(post_img)

        image = torch.cat([pre_img, post_img], dim=0)  # shape: [6, H, W]

        # Apply mask transforms
        if self.mask_transform:
            post_mask = self.mask_transform(post_mask)
        else:
            post_mask = torch.from_numpy(np.array(post_mask)).unsqueeze(0).float()

        return image, post_mask

In [63]:
dataset = PrimaryDataset("../preprocessed_test", transform=img_transform, mask_transform=mask_transform)

In [64]:
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

x_batch, y_batch = next(iter(dataloader))

UnboundLocalError: local variable 'pre_img' referenced before assignment

## Model Training 

## Model Evaluation

## Conversion of Output Mask to RGB .png File