In [1]:
from preprocessing import ProcessData
from models import SatelliteDataset
import numpy as np
from sklearn.metrics import f1_score
import torch
import satlaspretrain_models
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
import os
import tqdm
import torch.nn as nn

In [4]:
# Define patch size and stride
PATCH_SIZE = 256  # Adjust as needed (256, 512, etc.)
STRIDE = 128  # Overlapping patches

class Sentinel2SegmentationDataset(Dataset):
    def __init__(self, images, labels, patch_size=PATCH_SIZE, stride=STRIDE, transform=None):
        """
        images: Tensor or numpy array of shape (N, 12, H, W)
        labels: Tensor or numpy array of shape (N, H, W)
        patch_size: Size of the patches (default 256)
        stride: Stride for patching (default 128 for overlapping)
        transform: Optional image transformations
        """
        self.images = images
        self.labels = labels
        self.patch_size = patch_size
        self.stride = stride
        self.transform = transform
        self.patches = []  # Store (image_patch, label_patch) pairs

        self.create_patches()

    def create_patches(self):
        """Extracts patches from the dataset."""
        N, C, H, W = self.images.shape

        for i in range(N):
            img = self.images[i]  # Shape (12, H, W)
            lbl = self.labels[i]  # Shape (H, W)

            for y in range(0, H - self.patch_size + 1, self.stride):
                for x in range(0, W - self.patch_size + 1, self.stride):
                    img_patch = img[:, y:y + self.patch_size, x:x + self.patch_size]  # (12, 256, 256)
                    lbl_patch = lbl[y:y + self.patch_size, x:x + self.patch_size]  # (256, 256)
                    
                    self.patches.append((img_patch, lbl_patch))

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        img_patch, lbl_patch = self.patches[idx]

        # Apply transformations if any
        if self.transform:
            img_patch = self.transform(img_patch)

        return torch.as_tensor(img_patch, dtype=torch.float32), torch.as_tensor(lbl_patch, dtype=torch.long)


In [5]:
#test = ProcessData()
#test.preprocess()
#test.save_preprocessed()

In [6]:
test = ProcessData()
test.prepared_data, test.labels = test.load_preprocessed_data()
test.prepared_data = test.prepared_data[:, :9, :, :]

Loaded preprocessed data from /Users/bragehs/Documents/INF367A-DeforestationDrivers


In [7]:
test.prepared_data.shape

(176, 9, 1024, 1024)

In [8]:
train_sample = 20
train_data = test.prepared_data[:train_sample]
train_labels = test.labels[:train_sample]

In [9]:
val_sample = 15
val_data = test.prepared_data[train_sample:train_sample + val_sample]
val_labels = test.labels[train_sample:train_sample + val_sample]

In [10]:
test_sample = 1
test_data = test.prepared_data[174:175]
test_labels = test.labels[train_sample:val_sample + train_sample + test_sample]

In [11]:
test_data.shape

(1, 9, 1024, 1024)

In [12]:
weights_manager = satlaspretrain_models.Weights()

In [13]:
class FocalLoss(nn.Module):
    def __init__(self, alpha=None, gamma=2.0):
        super(FocalLoss, self).__init__()
        self.alpha = alpha if alpha is not None else torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0])
        self.gamma = gamma

    def forward(self, logits, targets):
        ce_loss = torch.nn.functional.cross_entropy(logits, targets, weight=self.alpha, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        return focal_loss.mean()



In [14]:
# Experiment arguments.
device = torch.device('cpu')
num_epochs = 20
class_weights = torch.tensor([0.1, 1.0, 1.0, 1.0, 1.0]).to(device)
criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
criterion_2 = FocalLoss()
val_step = 10  # evaluate every val_step epochs
save_path = os.path.split(os.getcwd())[0] + '/weights/'  # where to save model weights
os.makedirs(save_path, exist_ok=True)

In [15]:
save_path

'/Users/bragehs/Documents/weights/'

In [16]:
model = weights_manager.get_pretrained_model("Sentinel2_SwinT_SI_MS", fpn=True, 
                                             head=satlaspretrain_models.Head.SEGMENT, 
                                                num_categories=5, device='cpu')
model = model.to(device)

In [17]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [18]:
train_dataset = Sentinel2SegmentationDataset(train_data, train_labels)
val_dataset = Sentinel2SegmentationDataset(val_data, val_labels)
test_dataset = Sentinel2SegmentationDataset(test_data, test_labels)
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=0)


# Check dataset output
for img_patch, lbl_patch in train_dataloader:
    print("Image Patch Shape:", img_patch.shape)  # Expected: (8, 12, 256, 256)
    print("Label Patch Shape:", lbl_patch.shape)  # Expected: (8, 256, 256)
    break

Image Patch Shape: torch.Size([8, 9, 256, 256])
Label Patch Shape: torch.Size([8, 256, 256])


In [2]:
# Training loop.
writer = SummaryWriter()
for epoch in range(num_epochs):
    model.train()
    print(f"Starting Epoch... {epoch}")
    epoch_loss = 0
    num_batches = 0
    
    # Progress bar for each epoch
    progress_bar = tqdm.tqdm(train_dataloader, desc=f'Epoch {epoch}')
    
    for data, target in progress_bar:
        data = data.to(device)
        target = target.to(device)
        
        output = model(data)
        loss = criterion_2(output[0], target)
        loss.backward()
        epoch_loss += loss.item()
        num_batches += 1
 
 
        optimizer.step()
        optimizer.zero_grad()
        
        # Update progress bar
        progress_bar.set_postfix({'batch_loss': f'{loss.item():.4f}'})
    
    # Calculate and log average epoch loss
    avg_epoch_loss = epoch_loss / num_batches
    writer.add_scalar('Training/Loss', avg_epoch_loss, epoch)
    print(f"Epoch {epoch} Average Loss = {avg_epoch_loss:.4f}")

    # Validation.
    if epoch == 0 or (epoch % val_step == 0):
        model.eval()
        val_predictions = []
        val_targets = []

        with torch.no_grad():
            for val_data, val_target in val_dataloader:
                val_data = val_data.to(device)
                val_target = val_target.to(device)

                val_output, val_loss = model(val_data, val_target)

                pred = np.argmax(val_output.cpu().detach().numpy(), axis=1).flatten()
                true = ((val_target.cpu().detach().numpy()).astype(int)).flatten()
                
                val_predictions.extend(pred)
                val_targets.extend(true)

    val_f1 = f1_score(val_predictions, val_targets, average='weighted')
    print("Validation F1 score = ", val_f1)

    # Save the model checkpoint at the end of each epoch.
    torch.save(model.state_dict(), save_path + '_model_weights.pth')

NameError: name 'num_epochs' is not defined

In [19]:
def load_model(save_path):
    # 1. Initialize model architecture
    
    # 2. Load the saved weights
    weights_path = save_path + '_model_weights.pth'
    model.load_state_dict(torch.load(weights_path))
    
    # 3. Set to evaluation mode
    model.eval()
    
    return model

In [20]:
loaded_model = load_model(save_path)

In [21]:
def predict(model, test_dataset):
    test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)
    
    predictions = []
    print("Starting prediction...")
    print(test_dataset.__len__())   
    for data, target in test_dataloader:
        data = data.to(device)
        target = target.to(device)
        
        output = model(data)[0]
        print(type(output))
        pred = np.argmax(output.cpu().detach().numpy(), axis=1).flatten()
        predictions.extend(pred)
        print(pred)
    return predictions

In [139]:
def evaluate(predictions, targets):
    f1 = f1_score(predictions, targets, average='weighted')
    return f1

In [140]:
predictions = predict(loaded_model, test_dataset)

Starting prediction...
49
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 ... 1 1 1]
<class 'torch.Tensor'>
[1 1 1 .

In [144]:
test_labels.flatten().shape

(16777216,)