In [1]:
from preprocessing import ProcessData
import numpy as np
from sklearn.metrics import f1_score
import torch
from torch.utils.data import DataLoader, Dataset
import os
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from torch.utils.tensorboard import SummaryWriter
import timm
#import satlaspretrain_models
from models import UNet


In [43]:
# Define patch size and stride
PATCH_SIZE = 128  # Adjust as needed (256, 512, etc.)
STRIDE = 128  # Overlapping patches

class Sentinel2SegmentationDataset(Dataset):
    def __init__(self, images, labels, patch_size=PATCH_SIZE, stride=STRIDE, transform=None):
        """
        images: Tensor or numpy array of shape (N, 12, H, W)
        labels: Tensor or numpy array of shape (N, H, W)
        patch_size: Size of the patches (default 256)
        stride: Stride for patching (default 128 for overlapping)
        transform: Optional image transformations
        """
        self.images = images
        self.labels = labels
        self.patch_size = patch_size
        self.stride = stride
        self.transform = transform
        self.patches = []  # Store (image_patch, label_patch) pairs

        self.create_patches()

    def create_patches(self):
        """Extracts patches from the dataset."""
        N, C, H, W = self.images.shape

        for i in range(N):
            img = self.images[i]  # Shape (12, H, W)
            lbl = self.labels[i]  # Shape (H, W)

            # Divide the 1024 image dimension into 512x512 patches with no overlap
            for y in range(0, H, self.patch_size):
                for x in range(0, W, self.patch_size):
                    img_patch = img[:, y:y+self.patch_size, x:x+self.patch_size]
                    lbl_patch = lbl[y:y+self.patch_size, x:x+self.patch_size]
                    self.patches.append((img_patch, lbl_patch))

    def __len__(self):
        return len(self.patches)

    def __getitem__(self, idx):
        img_patch, lbl_patch = self.patches[idx]

        # Apply transformations if any
        if self.transform:
            img_patch = self.transform(img_patch)

        return torch.as_tensor(img_patch, dtype=torch.float32), torch.as_tensor(lbl_patch, dtype=torch.long)


In [3]:
#data = ProcessData()
#data.preprocess()
#data.save_preprocessed()

In [4]:
data = ProcessData()
data.prepared_data, data.labels = data.load_preprocessed_data()

Loaded preprocessed data from /Users/bragehs/Documents/INF367A-DeforestationDrivers


In [5]:
if torch.cuda.is_available():
    device = torch.device("cuda")  # Uses the first available GPU
    print("Using:", torch.cuda.get_device_name(0))

In [6]:
#bands = [1, 2, 3, 4, 5, 6, 7, 10, 11]
#prepared_data = data.prepared_data[:, bands]

In [7]:
prepared_data = data.prepared_data.astype('float32') / 10000
#prepared_data = prepared_data.astype('float32') / 10000

In [9]:
labels = data.labels

In [10]:
prepared_data.shape, labels.shape

((176, 12, 1024, 1024), (176, 1024, 1024))

In [11]:
from sklearn.model_selection import train_test_split

# First split into train and temp (val + test)
X_train, X_val, y_train, y_val = train_test_split(prepared_data, labels, test_size=0.15, random_state=42)

In [31]:
device = torch.device('cpu')

In [12]:
UNet_model = UNet(in_channels=12, num_classes=5)

  init.xavier_normal(m.weight)
  init.constant(m.bias, 0)


In [13]:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(residual)
        out = self.relu(out)
        return out

class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(ASPP, self).__init__()
        self.aspp1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.aspp2 = nn.Conv2d(in_channels, out_channels, 3, padding=6, dilation=6, bias=False)
        self.aspp3 = nn.Conv2d(in_channels, out_channels, 3, padding=12, dilation=12, bias=False)
        self.aspp4 = nn.Conv2d(in_channels, out_channels, 3, padding=18, dilation=18, bias=False)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        
        self.conv_out = nn.Conv2d(out_channels * 5, out_channels, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        
    def forward(self, x):
        size = x.size()[2:]
        x1 = self.aspp1(x)
        x2 = self.aspp2(x)
        x3 = self.aspp3(x)
        x4 = self.aspp4(x)
        x5 = self.avg_pool(x)
        x5 = self.conv1(x5)
        x5 = F.interpolate(x5, size=size, mode='bilinear', align_corners=True)
        x = torch.cat((x1, x2, x3, x4, x5), dim=1)
        x = self.conv_out(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class EnhancedSatelliteSegmentationModel(nn.Module):
    def __init__(self, in_channels=12, num_classes=5):
        super(EnhancedSatelliteSegmentationModel, self).__init__()
        
        # Initial Conv Block
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        
        # Encoder path with residual blocks
        self.enc1 = ResidualBlock(64, 64)
        self.enc2 = ResidualBlock(64, 128, stride=2)
        self.enc3 = ResidualBlock(128, 256, stride=2)
        self.enc4 = ResidualBlock(256, 512, stride=2)
        
        # ASPP module
        self.aspp = ASPP(512, 256)
        
        # Decoder path with skip connections
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(256, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True)
        )
        self.skip1 = nn.Conv2d(256, 256, kernel_size=1)
        
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.skip2 = nn.Conv2d(128, 128, kernel_size=1)
        
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True)
        )
        self.skip3 = nn.Conv2d(64, 64, kernel_size=1)
        
        # Additional upsampling to match input resolution
        self.final_up1 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        
        self.final_up2 = nn.Sequential(
            nn.ConvTranspose2d(32, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True)
        )
        
        self.final_conv = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, num_classes, kernel_size=1)
        )
        
        self.dropout = nn.Dropout2d(p=0.1)

    def forward(self, x):
        # Store input size for potential interpolation
        input_size = x.size()[2:]
        
        # Initial convolution
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        # Encoder path with skip connections
        skip1 = self.enc1(x)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        x = self.enc4(skip3)
        
        # ASPP module
        x = self.aspp(x)
        
        # Decoder path with skip connections
        x = self.dec1(x)
        x = x + self.skip1(skip3)
        x = self.dropout(x)
        
        x = self.dec2(x)
        x = x + self.skip2(skip2)
        x = self.dropout(x)
        
        x = self.dec3(x)
        x = x + self.skip3(skip1)
        x = self.dropout(x)
        
        # Additional upsampling to match input resolution
        x = self.final_up1(x)
        x = self.final_up2(x)
        x = self.final_conv(x)
        
        # Optional: force output to match input spatial dimensions exactly
        if x.size()[2:] != input_size:
            x = F.interpolate(x, size=input_size, mode='bilinear', align_corners=True)
            
        return x
# Example instantiation (for testing):
model = EnhancedSatelliteSegmentationModel(in_channels=12, num_classes=5)
print(model)

EnhancedSatelliteSegmentationModel(
  (conv1): Conv2d(12, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (enc1): ResidualBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (shortcut): Sequential()
  )
  (enc2): ResidualBlock(
    (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (re

In [17]:
x = torch.randn(16, 12, 256, 256) 
output = model(x)
print(output.shape)

torch.Size([16, 5, 256, 256])


In [15]:
x = torch.randn(16, 12, 256, 256) 
output = UNet_model(x)
print(output.shape)

torch.Size([16, 5, 256, 256])


In [20]:
len(prepared_data)

176

In [44]:
batch_size = 8
train_dataset = Sentinel2SegmentationDataset(X_train[:7], y_train[:7])
val_dataset = Sentinel2SegmentationDataset(X_val[:7], y_val[:7])
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)



# Check dataset output
for img_patch, lbl_patch in train_dataloader:
    print("Image Patch Shape:", img_patch.shape)  # Expected: (batch_size, 9, 512, 512)
    print("Label Patch Shape:", lbl_patch.shape)  # Expected: (batch_size, 512, 512)
    break

Image Patch Shape: torch.Size([8, 12, 128, 128])
Label Patch Shape: torch.Size([8, 128, 128])


In [45]:
train_dataset.__len__(), val_dataset.__len__()

(448, 448)

In [28]:
X_train.shape

(149, 12, 1024, 1024)

In [46]:
import numpy as np
from sklearn.utils.class_weight import compute_class_weight
# Convert each label tensor to a NumPy array
all_labels = np.concatenate([
    y.detach().cpu().numpy().flatten() if torch.is_tensor(y) else np.array(y).flatten()
    for _, y in train_dataset
])
classes = np.unique(all_labels)
weights = compute_class_weight('balanced', classes=classes, y=all_labels)
print("Class weights:", weights, "for classes:", classes)

Class weights: [ 0.28784642  1.11019248  1.74212239 28.21244571 63.60237425] for classes: [0 1 2 3 4]


In [32]:
weights = torch.tensor(weights, dtype=torch.float32).to(device)

In [23]:
def print_model_weights(model):
    # Print first layer weights from a few layers
    for name, param in model.named_parameters():
        if 'weight' in name:
            print(f"{name}: mean={param.data.mean():.4f}, std={param.data.std():.4f}")

In [24]:
print_model_weights(model)

conv1.weight: mean=-0.0000, std=0.0238
bn1.weight: mean=1.0000, std=0.0000
enc1.conv1.weight: mean=0.0001, std=0.0241
enc1.bn1.weight: mean=1.0000, std=0.0000
enc1.conv2.weight: mean=-0.0001, std=0.0240
enc1.bn2.weight: mean=1.0000, std=0.0000
enc2.conv1.weight: mean=0.0000, std=0.0241
enc2.bn1.weight: mean=1.0000, std=0.0000
enc2.conv2.weight: mean=0.0001, std=0.0170
enc2.bn2.weight: mean=1.0000, std=0.0000
enc2.shortcut.0.weight: mean=0.0002, std=0.0726
enc2.shortcut.1.weight: mean=1.0000, std=0.0000
enc3.conv1.weight: mean=-0.0001, std=0.0170
enc3.bn1.weight: mean=1.0000, std=0.0000
enc3.conv2.weight: mean=-0.0000, std=0.0120
enc3.bn2.weight: mean=1.0000, std=0.0000
enc3.shortcut.0.weight: mean=0.0001, std=0.0511
enc3.shortcut.1.weight: mean=1.0000, std=0.0000
enc4.conv1.weight: mean=-0.0000, std=0.0120
enc4.bn1.weight: mean=1.0000, std=0.0000
enc4.conv2.weight: mean=-0.0000, std=0.0085
enc4.bn2.weight: mean=1.0000, std=0.0000
enc4.shortcut.0.weight: mean=-0.0001, std=0.0360
enc4.sh

In [33]:
# Experiment arguments.
num_epochs = 40
val_step = 1  # evaluate every val_step epochs
#loss_fn = nn.CrossEntropyLoss(weight=weights)
save_path = os.path.split(os.getcwd())[0] + '/weights/'  # where to save model weights
os.makedirs(save_path, exist_ok=True)

In [71]:
class WeightedSegmentationLoss(nn.Module):
    def __init__(self, num_classes, weights=None):
        super().__init__()
        if weights is not None:
            weights = torch.tensor(weights, dtype=torch.float32)
        self.ce = nn.CrossEntropyLoss(weight=weights)
        self.num_classes = num_classes
        
    def forward(self, pred, target):
        # Ensure pred and target are float32
        pred = pred.float()
        target = target.long()  # CrossEntropyLoss expects long targets
        
        # Basic cross entropy
        ce_loss = self.ce(pred, target)
        
        # Add auxiliary losses if needed
        aux_loss = 0
        
        return ce_loss + aux_loss


In [95]:
class LogCoshDiceLoss(nn.Module):
    def __init__(self, num_classes, epsilon=1e-6):
        super(LogCoshDiceLoss, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon

    def forward(self, pred, target):
        # pred shape: (B, C, H, W)
        # target shape: (B, H, W)
        
        # Convert predictions to probabilities
        pred_probs = F.softmax(pred, dim=1)  # (B, C, H, W)
        
        dice_scores = []
        for cls in range(self.num_classes):
            # Create binary masks for each class
            pred_cls = pred_probs[:, cls]  # (B, H, W)
            target_cls = (target == cls).float()  # (B, H, W)
            
            # Calculate intersection and union
            intersection = (pred_cls * target_cls).sum(dim=(1, 2))
            cardinality = pred_cls.sum(dim=(1, 2)) + target_cls.sum(dim=(1, 2))
            
            # Calculate Dice coefficient
            dice = (2. * intersection + self.epsilon) / (cardinality + self.epsilon)
            dice_scores.append(dice)
        
        # Stack dice scores for all classes
        dice_scores = torch.stack(dice_scores, dim=1)  # (B, C)
        
        # Compute Log-Cosh loss
        loss = torch.log(torch.cosh(1. - dice_scores))
        
        return loss.mean()

In [91]:
class WeightedIoULoss(nn.Module):
    def __init__(self, weights=None):
        super(WeightedIoULoss, self).__init__()
        self.weights = weights

    def forward(self, pred, target):
        # pred shape: (B, C, H, W)
        # target shape: (B, H, W)
        
        B, C, H, W = pred.shape
        
        # Convert predictions to probabilities
        pred = F.softmax(pred, dim=1)
        
        # One-hot encode target
        target = target.permute(0, 3, 1, 2).float()
        
        # Calculate intersection and union
        intersection = (pred * target).sum(dim=(2, 3))  # (B, C)
        union = pred.sum(dim=(2, 3)) + target.sum(dim=(2, 3)) - intersection  # (B, C)
        
        # Calculate IoU
        iou = (intersection + 1e-7) / (union + 1e-7)  # (B, C)
        
        if self.weights is not None:
            # Make sure weights is a tensor on the same device
            if not isinstance(self.weights, torch.Tensor):
                self.weights = torch.tensor(self.weights, device=pred.device)
            iou = iou * self.weights  # Apply class weights
        
        # Average over classes and batch
        iou_loss = 1 - iou.mean()
        
        return iou_loss

In [37]:
def validate(model, val_loader, criterion):
    model.eval()
    val_loss = 0
    with torch.no_grad():
        for batch_idx, (images, targets) in enumerate(val_loader):
            images, targets = images.to(device), targets.to(device)
            outputs = model(images)
            if isinstance(outputs, tuple):
                outputs = outputs[0]  # Get main output if model returns multiple outputs

            # Ensure correct shape before loss calculation
            if outputs.shape[0] != targets.shape[0]:
                outputs = outputs.permute(0, 2, 1)
            loss = criterion(outputs, targets)
            val_loss += loss.item()

    return val_loss / len(val_loader)

In [38]:
from tqdm.auto import tqdm 

In [39]:
train_dataloader.dataset[0][0].shape

torch.Size([12, 256, 256])

In [98]:
def train_with_warmup(model, train_loader, val_loader, num_epochs, filename,
                      num_warm_up= 5, learning_rate = 1e-3, patience = 5,
                      weight_decay = 1e-2):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr = learning_rate, weight_decay = weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=patience)
    criterion = LogCoshDiceLoss(num_classes=5)

    # Warmup phase
    model.train()
    model = model.float()
    print("Starting warmup phase...")
    for epoch in range(num_warm_up):  # epochs of warm up:
        with tqdm(train_loader, desc=f'Warmup Epoch {epoch+1}/5') as pbar:
            for batch_idx, (images, targets) in enumerate(pbar):
                images, targets = images.to(device), targets.to(device)
                optimizer.zero_grad()
                outputs = model(images)
                # Ensure correct shape before loss calculation
                if outputs.shape[0] != targets.shape[0]:
                    outputs.permute(2, 3, 0, 1).contiguous().view(-1, 5)

                loss = criterion(outputs, targets) / 4  # Accumulate gradients over 4 steps
                loss.backward()

                if (batch_idx + 1) % 4 == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    optimizer.zero_grad()

                pbar.set_postfix({'loss': loss.item() * 4})  # Show accumulated loss
    optimizer.zero_grad()
    print("Starting main training phase...")
    # Main training phase
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        model.train()
        train_loss = 0
        # Training loop with progress bar
        with tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}') as pbar:
            for batch_idx, (images, targets) in enumerate(pbar):
                images, targets = images.to(device).float(), targets.to(device).float()
                optimizer.zero_grad()
                outputs = model(images)

                # Ensure correct shape before loss calculation
                if outputs.shape[0] != targets.shape[0]:
                    outputs.permute(2, 3, 0, 1).contiguous().view(-1, 5)
                loss = criterion(outputs, targets)
                loss.backward()

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()

                train_loss += loss.item()
                pbar.set_postfix({'train_loss': loss.item()})

        # Calculate average training loss
        avg_train_loss = train_loss / len(train_loader)

        # Validation phase
        val_loss = validate(model, val_loader, criterion)
        scheduler.step(val_loss)

        # Save best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), save_path + filename)

        # Print epoch results
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {avg_train_loss:.4f}')
        print(f'Val Loss: {val_loss:.4f}')
        print(f'Learning Rate: {optimizer.param_groups[0]["lr"]:.6f}')
        print('-' * 50)


In [99]:
train_with_warmup(
    model=UNet_model,
    train_loader=train_dataloader, 
    val_loader=val_dataloader,      
    num_epochs=10,
    filename = 'unet.pth',
    num_warm_up=0
)

Starting warmup phase...
Starting main training phase...


Epoch 1/10:   0%|          | 0/56 [00:00<?, ?it/s]

Epoch 1/10:
Train Loss: 0.3294
Val Loss: 0.3497
Learning Rate: 0.001000
--------------------------------------------------


Epoch 2/10:   0%|          | 0/56 [00:00<?, ?it/s]

Epoch 2/10:
Train Loss: 0.3285
Val Loss: 0.3467
Learning Rate: 0.001000
--------------------------------------------------


Epoch 3/10:   0%|          | 0/56 [00:00<?, ?it/s]

Epoch 3/10:
Train Loss: 0.3262
Val Loss: 0.3295
Learning Rate: 0.001000
--------------------------------------------------


Epoch 4/10:   0%|          | 0/56 [00:00<?, ?it/s]

Epoch 4/10:
Train Loss: 0.0960
Val Loss: 0.1071
Learning Rate: 0.001000
--------------------------------------------------


Epoch 5/10:   0%|          | 0/56 [00:00<?, ?it/s]

Epoch 5/10:
Train Loss: 0.0802
Val Loss: 0.1071
Learning Rate: 0.001000
--------------------------------------------------


Epoch 6/10:   0%|          | 0/56 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [103]:
def load_model(save_path, model_name):
    # 1. Initialize model architecture
    
    # 2. Load the saved weights
    weights_path = save_path + model_name
    model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
    
    # 3. Set to evaluation mode
    model.eval()
    
    return model

In [104]:
save_path = os.path.split(os.getcwd())[0] + '/weights/'  # where to save model weights
loaded_model = load_model(save_path, "unet.pth")

RuntimeError: Error(s) in loading state_dict for EnhancedSatelliteSegmentationModel:
	Missing key(s) in state_dict: "conv1.weight", "bn1.weight", "bn1.bias", "bn1.running_mean", "bn1.running_var", "enc1.conv1.weight", "enc1.bn1.weight", "enc1.bn1.bias", "enc1.bn1.running_mean", "enc1.bn1.running_var", "enc1.conv2.weight", "enc1.bn2.weight", "enc1.bn2.bias", "enc1.bn2.running_mean", "enc1.bn2.running_var", "enc2.conv1.weight", "enc2.bn1.weight", "enc2.bn1.bias", "enc2.bn1.running_mean", "enc2.bn1.running_var", "enc2.conv2.weight", "enc2.bn2.weight", "enc2.bn2.bias", "enc2.bn2.running_mean", "enc2.bn2.running_var", "enc2.shortcut.0.weight", "enc2.shortcut.1.weight", "enc2.shortcut.1.bias", "enc2.shortcut.1.running_mean", "enc2.shortcut.1.running_var", "enc3.conv1.weight", "enc3.bn1.weight", "enc3.bn1.bias", "enc3.bn1.running_mean", "enc3.bn1.running_var", "enc3.conv2.weight", "enc3.bn2.weight", "enc3.bn2.bias", "enc3.bn2.running_mean", "enc3.bn2.running_var", "enc3.shortcut.0.weight", "enc3.shortcut.1.weight", "enc3.shortcut.1.bias", "enc3.shortcut.1.running_mean", "enc3.shortcut.1.running_var", "enc4.conv1.weight", "enc4.bn1.weight", "enc4.bn1.bias", "enc4.bn1.running_mean", "enc4.bn1.running_var", "enc4.conv2.weight", "enc4.bn2.weight", "enc4.bn2.bias", "enc4.bn2.running_mean", "enc4.bn2.running_var", "enc4.shortcut.0.weight", "enc4.shortcut.1.weight", "enc4.shortcut.1.bias", "enc4.shortcut.1.running_mean", "enc4.shortcut.1.running_var", "aspp.aspp1.weight", "aspp.aspp2.weight", "aspp.aspp3.weight", "aspp.aspp4.weight", "aspp.conv1.weight", "aspp.conv_out.weight", "aspp.bn.weight", "aspp.bn.bias", "aspp.bn.running_mean", "aspp.bn.running_var", "dec1.0.weight", "dec1.0.bias", "dec1.1.weight", "dec1.1.bias", "dec1.1.running_mean", "dec1.1.running_var", "skip1.weight", "skip1.bias", "dec2.0.weight", "dec2.0.bias", "dec2.1.weight", "dec2.1.bias", "dec2.1.running_mean", "dec2.1.running_var", "skip2.weight", "skip2.bias", "dec3.0.weight", "dec3.0.bias", "dec3.1.weight", "dec3.1.bias", "dec3.1.running_mean", "dec3.1.running_var", "skip3.weight", "skip3.bias", "final_up1.0.weight", "final_up1.0.bias", "final_up1.1.weight", "final_up1.1.bias", "final_up1.1.running_mean", "final_up1.1.running_var", "final_up2.0.weight", "final_up2.0.bias", "final_up2.1.weight", "final_up2.1.bias", "final_up2.1.running_mean", "final_up2.1.running_var", "final_conv.0.weight", "final_conv.0.bias", "final_conv.1.weight", "final_conv.1.bias", "final_conv.1.running_mean", "final_conv.1.running_var", "final_conv.3.weight", "final_conv.3.bias". 
	Unexpected key(s) in state_dict: "conv_final.weight", "conv_final.bias", "down_convs.0.conv1.weight", "down_convs.0.conv1.bias", "down_convs.0.conv2.weight", "down_convs.0.conv2.bias", "down_convs.1.conv1.weight", "down_convs.1.conv1.bias", "down_convs.1.conv2.weight", "down_convs.1.conv2.bias", "down_convs.2.conv1.weight", "down_convs.2.conv1.bias", "down_convs.2.conv2.weight", "down_convs.2.conv2.bias", "down_convs.3.conv1.weight", "down_convs.3.conv1.bias", "down_convs.3.conv2.weight", "down_convs.3.conv2.bias", "down_convs.4.conv1.weight", "down_convs.4.conv1.bias", "down_convs.4.conv2.weight", "down_convs.4.conv2.bias", "up_convs.0.upconv.weight", "up_convs.0.upconv.bias", "up_convs.0.conv1.weight", "up_convs.0.conv1.bias", "up_convs.0.conv2.weight", "up_convs.0.conv2.bias", "up_convs.1.upconv.weight", "up_convs.1.upconv.bias", "up_convs.1.conv1.weight", "up_convs.1.conv1.bias", "up_convs.1.conv2.weight", "up_convs.1.conv2.bias", "up_convs.2.upconv.weight", "up_convs.2.upconv.bias", "up_convs.2.conv1.weight", "up_convs.2.conv1.bias", "up_convs.2.conv2.weight", "up_convs.2.conv2.bias", "up_convs.3.upconv.weight", "up_convs.3.upconv.bias", "up_convs.3.conv1.weight", "up_convs.3.conv1.bias", "up_convs.3.conv2.weight", "up_convs.3.conv2.bias". 

In [9]:
print(loaded_model)

Model(
  (backbone): ResnetBackbone(
    (resnet): ResNet(
      (conv1): Conv2d(9, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu): ReLU(inp

In [23]:
def predict(model, test_dataloader):
    print("Starting prediction...")
    with torch.no_grad():
        model.eval()
        predictions_flattened = []
        targets = []
        predictions = []
        with torch.no_grad():
            for data, target in test_dataloader:
                data = data.to(device)
                target = target.to(device)

                output = model(data)
                pred = np.argmax(output.cpu().detach().numpy(), axis=1)
                true = ((target.cpu().detach().numpy()).astype(int))
                print('unique predictions:', np.unique(pred), 'unique labels:', np.unique(true))
                predictions_flattened.extend(pred.flatten())
                targets.extend(true.flatten())
                predictions.extend(pred)
    f1 = f1_score(predictions_flattened, targets, average='weighted')
    print("F1 score = ", f1)
    return predictions

In [24]:
def evaluate(predictions, targets):
    f1 = f1_score(predictions, targets, average='weighted')
    return f1

In [101]:
def predict_probs(model, test_dataloader):
    """
    Predicts the probabilities of each class for each pixel in the image.

    Args:
        model (torch.nn.Module): The trained model.
        test_dataloader (torch.utils.data.DataLoader): The DataLoader for the test dataset.
        device (str): The device to use for prediction (e.g., 'cuda' or 'cpu').

    Returns:
        list: A list of predictions. Each element in the list corresponds to an image,
              and contains 64 patches of predictions, each patch containing 128x128
              probabilities for each class.
    """
    print("Starting prediction...")
    with torch.no_grad():
        model.eval()
        predictions = []
        for data, _ in test_dataloader:  # Changed to _ because target is not used
            data = data.to(device)

            output = model(data)[0]  # Assuming model returns a tuple, take the first element
            probs = torch.nn.functional.softmax(output, dim=1).cpu().numpy()  # Convert to probabilities

            # Iterate through the batch
            batch_predictions = []
            for batch_idx in range(data.shape[0]):
                # Get coordinates of each pixel in the image
                height, width = data.shape[2], data.shape[3]
                x_coords, y_coords = np.meshgrid(np.arange(height), np.arange(width), indexing='ij')
                x_coords = x_coords.flatten()
                y_coords = y_coords.flatten()

                # Combine coordinates and probabilities
                image_predictions = []
                for i in range(len(x_coords)):
                    x = x_coords[i]
                    y = y_coords[i]
                    pixel_probs = probs[batch_idx, :, x, y].tolist()  # Extract probabilities for the pixel
                    image_predictions.append((x, y, pixel_probs))  # Append as tuple
                batch_predictions.append(image_predictions)
            predictions.append(batch_predictions)

        return predictions


In [100]:
test_dataset = Sentinel2SegmentationDataset(prepared_data[70:71], labels[70:71])
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

In [45]:
device = torch.device('cpu')

In [46]:
predictions = predict_probs(loaded_model, test_dataloader)

Starting prediction...


In [61]:
test_dataloader.__len__()

1

In [62]:
len(predictions), len(predictions[0]), len(predictions[0][0])

(1, 16, 65536)

In [47]:
predictions

[[[(0,
    0,
    [0.3730228841304779,
     0.15259605646133423,
     0.16907618939876556,
     0.15266335010528564,
     0.15264159440994263]),
   (0,
    1,
    [0.3982568383216858,
     0.1496375948190689,
     0.15282943844795227,
     0.14963789284229279,
     0.14963820576667786]),
   (0,
    2,
    [0.40153706073760986,
     0.1492329239845276,
     0.15076415240764618,
     0.1492329239845276,
     0.14923293888568878]),
   (0,
    3,
    [0.40372976660728455,
     0.14895857870578766,
     0.14939454197883606,
     0.14895857870578766,
     0.14895857870578766]),
   (0,
    4,
    [0.40394991636276245,
     0.14893083274364471,
     0.14925751090049744,
     0.14893083274364471,
     0.14893083274364471]),
   (0,
    5,
    [0.4043707847595215,
     0.14887775480747223,
     0.14899590611457825,
     0.14887775480747223,
     0.14887775480747223]),
   (0,
    6,
    [0.4043619632720947,
     0.14887887239456177,
     0.14900138974189758,
     0.14887887239456177,
     0.148878

In [48]:
def reconstruct_from_patches(predictions, original_size=(1024, 1024), patch_size=128):
    """
    Reconstructs a full image from patches with probability lists.

    Args:
        predictions (list): A list of lists of lists of tuples. The outer list corresponds to images,
                            the middle list corresponds to batches, and the inner list contains tuples
                            of (x-coordinate, y-coordinate, probability list).
        original_size (tuple): The size of the original image (height, width). Default is (1024, 1024).
        patch_size (int): The size of the patches. Default is 128.

    Returns:
        numpy.ndarray: A 3D numpy array representing the reconstructed image with shape (1024, 1024, 5).
                       Each pixel contains a probability distribution over the 5 classes.
    """
    # Initialize an empty array to hold the reconstructed image
    full_img = np.zeros((original_size[0], original_size[1], 5))  # Shape: (1024, 1024, 5)

    # Iterate through the images
    for image_idx, image_predictions in enumerate(predictions):
        # Iterate through the batches
        for batch_idx, batch_patches in enumerate(image_predictions):
            # Iterate through the patches and place them in the full image
            for patch_data in batch_patches:
                x, y, probs = patch_data  # Unpack the tuple

                # Calculate the patch indices
                patch_x_idx = (image_idx * len(image_predictions) + batch_idx) // (original_size[0] // patch_size)
                patch_y_idx = (image_idx * len(image_predictions) + batch_idx) % (original_size[1] // patch_size)

                # Calculate the actual pixel coordinates in the full image
                pixel_x = x + patch_x_idx * patch_size
                pixel_y = y + patch_y_idx * patch_size

                # Place the probability list into the corresponding pixel in the full image
                full_img[pixel_x, pixel_y, :] = probs  # Assign the probability list to the pixel

    return full_img

In [49]:
reconstructed_image = reconstruct_from_patches(predictions, patch_size=256)

In [50]:
reconstructed_image

array([[[0.37302288, 0.15259606, 0.16907619, 0.15266335, 0.15264159],
        [0.39825684, 0.14963759, 0.15282944, 0.14963789, 0.14963821],
        [0.40153706, 0.14923292, 0.15076415, 0.14923292, 0.14923294],
        ...,
        [0.3977589 , 0.14969815, 0.15314646, 0.14969815, 0.14969829],
        [0.38398221, 0.15131298, 0.16207212, 0.15131283, 0.15131989],
        [0.3455472 , 0.15526834, 0.18719296, 0.15533042, 0.15666106]],

       [[0.39539313, 0.14998488, 0.15465155, 0.14998485, 0.14998555],
        [0.40352476, 0.14898436, 0.14952219, 0.14898436, 0.14898436],
        [0.40448773, 0.148863  , 0.14892329, 0.148863  , 0.148863  ],
        ...,
        [0.40425923, 0.14889185, 0.14906526, 0.14889185, 0.14889185],
        [0.40134531, 0.14925677, 0.15088437, 0.14925677, 0.14925677],
        [0.38897282, 0.15074347, 0.15877783, 0.1507435 , 0.15076227]],

       [[0.4019312 , 0.14918385, 0.15051726, 0.14918385, 0.14918388],
        [0.40451849, 0.1488591 , 0.14890419, 0.1488591 , 0.1

In [51]:
import pickle

with open('reconstructed_image.pkl', 'wb') as f:
    pickle.dump(reconstructed_image, f)

In [71]:
import matplotlib.pyplot as plt

In [72]:
y_test = labels[70:71]

In [73]:
y_test.shape

(1, 1024, 1024)

In [74]:
from PIL import Image



# Assuming y_test is a numpy array or torch tensor containing the
# class labels for the entire 1024x1024 image.
# Example:
# y_test = ... # Your test labels (numpy array or torch tensor)

# If y_test is a torch tensor, convert it to a numpy array
if isinstance(y_test, torch.Tensor):
    y_test = y_test.cpu().numpy()  # Move to CPU if it's on GPU

# Ensure y_test is of integer type
y_test = y_test.astype(np.uint8)

# Create a color mapping for the classes
color_map = {
    0: [0, 0, 0],      # background: black
    1: [255, 0, 0],    # plantation: red
    2: [0, 255, 0],    # grassland_shrubland: green
    3: [0, 0, 255],    # mining: blue
    4: [255, 255, 0]     # logging: yellow
}

# Create an RGB image where each class is represented by its color
height, width = y_test[0].shape
rgb_image = np.zeros((height, width, 3), dtype=np.uint8)
for i in range(height):
    for j in range(width):
        class_id = y_test[0][i, j]
        rgb_image[i, j] = color_map[class_id]

# Create a PIL image from the numpy array
img = Image.fromarray(rgb_image)

# Save the image to a file
img.save("rasterized_image.png")

print("Rasterized image saved to rasterized_image.png")

Rasterized image saved to rasterized_image.png
