In [60]:
from preprocessing import ProcessData
from models import SatelliteDataset
import numpy as np
from sklearn.metrics import f1_score
import torch
import satlaspretrain_models
from torch.utils.data import DataLoader, Dataset
import os

In [55]:
# Default patch geometry used by Sentinel2SegmentationDataset.
PATCH_SIZE = 256  # Side length of each square patch; adjust as needed (256, 512, etc.)
STRIDE = 128  # Step between patch origins; a value below PATCH_SIZE gives overlapping patches


class Sentinel2SegmentationDataset(Dataset):
    """Dataset of fixed-size (image_patch, label_patch) pairs cut from full scenes."""

    def __init__(self, images, labels, patch_size=PATCH_SIZE, stride=STRIDE, transform=None):
        """
        images: Tensor or numpy array of shape (N, C, H, W)
        labels: Tensor or numpy array of shape (N, H, W)
        patch_size: Side length of the square patches (default PATCH_SIZE)
        stride: Step between patch origins (default STRIDE)
        transform: Optional callable applied to the image patch only in __getitem__

        Raises ValueError if patch_size/stride are not positive, if images and
        labels disagree in length or spatial size, or if a scene is smaller
        than one patch.
        """
        if patch_size <= 0 or stride <= 0:
            raise ValueError(
                f"patch_size and stride must be positive, got patch_size={patch_size}, stride={stride}"
            )
        if len(images) != len(labels):
            raise ValueError(
                f"images and labels must have the same length, got {len(images)} and {len(labels)}"
            )

        self.images = images
        self.labels = labels
        self.patch_size = patch_size
        self.stride = stride
        self.transform = transform
        self.patches = []  # Store (image_patch, label_patch) pairs

        self.create_patches()

    def create_patches(self):
        """Rebuilds self.patches by sliding a window over every scene.

        The list is rebuilt from scratch on every call, so calling this method
        again does not duplicate patches. Windows that would run past the
        bottom/right border are not emitted.
        """
        N, C, H, W = self.images.shape
        size, step = self.patch_size, self.stride

        if N > 0 and (H < size or W < size):
            raise ValueError(
                f"scenes of size {H}x{W} are smaller than patch_size={size}; no patch can be extracted"
            )

        patches = []
        for i in range(N):
            img = self.images[i]  # Shape (C, H, W)
            lbl = self.labels[i]  # Shape (H, W)

            # The same (y, x) window is cut from both arrays, so their spatial
            # sizes have to agree or image and label patches would not line up.
            if tuple(lbl.shape[:2]) != (H, W):
                raise ValueError(
                    f"label {i} has spatial size {tuple(lbl.shape[:2])}, expected {(H, W)}"
                )

            for y in range(0, H - size + 1, step):
                for x in range(0, W - size + 1, step):
                    img_patch = img[:, y:y + size, x:x + size]  # (C, size, size)
                    lbl_patch = lbl[y:y + size, x:x + size]  # (size, size)
                    patches.append((img_patch, lbl_patch))

        self.patches = patches

    def __len__(self):
        """Number of extracted patches."""
        return len(self.patches)

    def __getitem__(self, idx):
        """Returns (image_patch as float32 tensor, label_patch as int64 tensor)."""
        img_patch, lbl_patch = self.patches[idx]

        # Apply transformations if any (image only; the label patch is left untouched)
        if self.transform is not None:
            img_patch = self.transform(img_patch)

        return torch.as_tensor(img_patch, dtype=torch.float32), torch.as_tensor(lbl_patch, dtype=torch.long)


In [2]:
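# Preprocessing run, left disabled: the next cell loads previously saved
# results via load_preprocessed_data(). Uncomment to regenerate them.
#test = ProcessData()
#test.preprocess()
#test.save_preprocessed()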

In [40]:
# Load the previously saved preprocessed arrays.
test = ProcessData()
test.prepared_data, test.labels = test.load_preprocessed_data()

# Keep only the leading channels along axis 1.
# NOTE(review): 9 presumably matches the band count expected by the
# "Sentinel2_SwinT_SI_MS" checkpoint loaded further down; confirm band order.
NUM_INPUT_BANDS = 9
test.prepared_data = test.prepared_data[:, :NUM_INPUT_BANDS, :, :]

Loaded preprocessed data from /Users/bragehs/Documents/INF367A-DeforestationDrivers


In [41]:
# Confirm the array shape after the channel slice above; the patch dataset
# later unpacks this shape as (N, C, H, W).
test.prepared_data.shape

(176, 9, 1024, 1024)

In [42]:
# Training subset: the first `train_sample` scenes and their labels.
train_sample = 5
train_data, train_labels = (arr[:train_sample] for arr in (test.prepared_data, test.labels))
train_dataset = SatelliteDataset(train_data, train_labels)

In [43]:
# Validation subset: the `val_sample` scenes that directly follow the training scenes.
val_sample = 2
val_window = slice(train_sample, train_sample + val_sample)
val_data = test.prepared_data[val_window]
val_labels = test.labels[val_window]
val_dataset = SatelliteDataset(val_data, val_labels)

In [44]:
# Test subset: the `test_sample` scenes that directly follow the validation scenes.
# The window starts after train + val so that no validation scene is reused
# for testing (starting at `train_sample` would include all of them).
test_sample = 1
test_start = train_sample + val_sample
test_end = test_start + test_sample
test_data = test.prepared_data[test_start:test_end]
test_labels = test.labels[test_start:test_end]
test_dataset = SatelliteDataset(test_data, test_labels)

In [45]:
# Weights manager from satlaspretrain_models; used below to fetch the pretrained model.
weights_manager = satlaspretrain_models.Weights()

In [46]:
# Experiment configuration.
num_epochs = 1
val_step = 1  # run validation every `val_step` epochs
device = torch.device("cpu")

# NOTE(review): the training loop in this notebook uses the loss returned by
# the model itself, so `criterion` looks unused — confirm before removing.
criterion = torch.nn.CrossEntropyLoss()

# Directory that receives the model weights; created up front if it is missing.
save_path = "weights/"
os.makedirs(save_path, exist_ok=True)

In [47]:
# Build the pretrained segmentation model (FPN enabled, SEGMENT head,
# 5 output categories) and place it on the configured device.
model = weights_manager.get_pretrained_model(
    "Sentinel2_SwinT_SI_MS",
    fpn=True,
    head=satlaspretrain_models.Head.SEGMENT,
    num_categories=5,
    device='cpu',
).to(device)

In [48]:
# Adam over all model parameters with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)

In [56]:
# Patch-based datasets for training and validation. These reuse the names
# `train_dataset` / `val_dataset` assigned in the split cells above.
train_dataset = Sentinel2SegmentationDataset(train_data, train_labels)
val_dataset = Sentinel2SegmentationDataset(val_data, val_labels)

loader_kwargs = dict(batch_size=8, num_workers=0)
train_dataloader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

# Check dataset output on the first validation batch (nothing is printed if the loader is empty)
first_batch = next(iter(val_dataloader), None)
if first_batch is not None:
    img_patch, lbl_patch = first_batch
    print("Image Patch Shape:", img_patch.shape)  # (batch, C, patch_size, patch_size)
    print("Label Patch Shape:", lbl_patch.shape)  # (batch, patch_size, patch_size)

Image Patch Shape: torch.Size([2, 9, 256, 256])
Label Patch Shape: torch.Size([2, 256, 256])


In [80]:
# Training loop.
# Each epoch: train on (a capped number of) batches, then every `val_step`
# epochs evaluate on the whole validation loader and save the weights.

# Cap on training batches per epoch. 1 keeps the original smoke-test behaviour
# (a single batch per epoch); set to None to iterate over the full loader.
max_train_batches = 1

for epoch in range(num_epochs):
    print("Starting Epoch...", epoch)

    # Validation switches the model to eval mode, so switch back explicitly at
    # the start of every epoch (also covers re-running this cell).
    model.train()

    for batch_idx, (data, target) in enumerate(train_dataloader):
        data = data.to(device)
        target = target.to(device)

        # Clear gradients before the forward/backward pass so that nothing left
        # over from an earlier (possibly interrupted) run leaks into this step.
        optimizer.zero_grad()

        # The model returns (predictions, loss) when targets are passed in.
        output, loss = model(data, target)
        print("Train Loss = ", loss.item())

        loss.backward()
        optimizer.step()

        if max_train_batches is not None and batch_idx + 1 >= max_train_batches:
            break

    # Validation.
    if epoch % val_step == 0:
        model.eval()
        val_predictions = []
        val_targets = []

        # No gradients are needed for evaluation; skipping them saves memory.
        with torch.no_grad():
            for val_data, val_target in val_dataloader:
                val_data = val_data.to(device)
                val_target = val_target.to(device)

                val_output, val_loss = model(val_data, val_target)

                # Class index with the highest score along axis 1, flattened per batch.
                val_predictions.append(np.argmax(val_output.cpu().numpy(), axis=1).ravel())
                val_targets.append(val_target.cpu().numpy().astype(int).ravel())

        # Score once over all validation pixels instead of once per batch.
        if val_predictions:
            val_f1 = f1_score(
                np.concatenate(val_targets),
                np.concatenate(val_predictions),
                average='weighted',
            )
            print("Validation F1 score = ", val_f1)

        # Save the model weights after each validation pass.
        torch.save(model.state_dict(), save_path + str(epoch) + '_model_weights.pth')

Starting Epoch... 0
Train Loss =  tensor(0.8449, grad_fn=<MeanBackward0>)
Validation F1 score =  0.5500531786614055
