# DrainageAI Workflow

Complete workflow for DrainageAI: setup, indices calculation, BYOL training, and inference.

In [None]:
import torch
import numpy as np
import os
import matplotlib.pyplot as plt
import rasterio
from google.colab import files
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.models as models
import copy
from tqdm import tqdm
import random

print(f"GPU available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU name: {torch.cuda.get_device_name(0)}")

!pip install rasterio geopandas scikit-image matplotlib tqdm
!mkdir -p data/imagery data/indices results

## 1. Upload and Prepare Data

In [None]:
print("Upload your multispectral imagery file (GeoTIFF format):")
uploaded = files.upload()

# A cancelled / empty upload would otherwise surface later as a bare
# IndexError when picking the first file; fail here with a clear message.
if not uploaded:
    raise RuntimeError(
        "No file was uploaded; re-run this cell and select a GeoTIFF file."
    )

# Save uploaded files (the mapping is filename -> raw file bytes)
for filename, content in uploaded.items():
    with open(f"data/imagery/{filename}", 'wb') as f:
        f.write(content)
    print(f"Saved {filename} to data/imagery/")

# Get the first uploaded file; the rest of the notebook works on this one
imagery_filename = next(iter(uploaded))
imagery_path = f"data/imagery/{imagery_filename}"

# Check image properties
with rasterio.open(imagery_path) as src:
    num_bands = src.count
    height = src.height
    width = src.width

print(f"Image has {num_bands} bands, dimensions: {width}x{height}")

## 2. Calculate Spectral Indices

In [None]:
def calculate_indices(imagery_path, output_path, indices="ndvi,msavi2",
                     red_band=3, nir_band=4):
    """Compute spectral indices from a raster and write them to a new raster.

    Args:
        imagery_path: Path of the input raster opened with ``rasterio``.
        output_path: Path of the raster to create; one band per index.
        indices: Comma-separated index names. ``"ndvi"`` and ``"msavi2"`` are
            recognised (case-insensitive, surrounding spaces ignored); other
            names are skipped.
        red_band: Band number of the red band as passed to ``src.read``.
            Clamped to the number of bands in the file.
        nir_band: Band number of the near-infrared band. Clamped to the
            number of bands; when the file has fewer than 4 bands the red
            band is used instead.

    Returns:
        ``numpy`` array of the computed indices stacked along a new first
        axis, in the order NDVI, MSAVI2 (for those that were requested).

    Raises:
        ValueError: If none of the requested names is a supported index.
    """
    print(f"Calculating spectral indices for {imagery_path}...")
    indices_to_calculate = {name.strip() for name in indices.lower().split(",")}

    if not indices_to_calculate & {"ndvi", "msavi2"}:
        raise ValueError(
            f"No supported index requested in {indices!r}; "
            "expected 'ndvi' and/or 'msavi2'."
        )

    with rasterio.open(imagery_path) as src:
        num_bands = src.count
        red_band = min(red_band, num_bands)
        nir_band = min(nir_band, num_bands) if num_bands >= 4 else red_band

        # Cast to float before any arithmetic: imagery is commonly stored as
        # unsigned integers, where `nir - red` would wrap around instead of
        # going negative and silently corrupt both indices.
        red = src.read(red_band).astype(np.float32)
        nir = src.read(nir_band).astype(np.float32)
        meta = src.meta.copy()

    calculated_indices = []
    band_names = []

    if "ndvi" in indices_to_calculate:
        # The small epsilon keeps the division defined where nir + red == 0.
        ndvi = (nir - red) / (nir + red + 1e-8)
        calculated_indices.append(ndvi)
        band_names.append("NDVI")

    if "msavi2" in indices_to_calculate:
        msavi2 = (2 * nir + 1 - np.sqrt((2 * nir + 1)**2 - 8 * (nir - red))) / 2
        calculated_indices.append(msavi2)
        band_names.append("MSAVI2")

    indices_stack = np.stack(calculated_indices)
    # The output keeps the input's georeferencing metadata but has one
    # float32 band per computed index.
    meta.update({
        'count': len(calculated_indices),
        'dtype': 'float32'
    })

    with rasterio.open(output_path, 'w', **meta) as dst:
        for i, (index, name) in enumerate(zip(calculated_indices, band_names), 1):
            dst.write(index.astype(np.float32), i)
            dst.set_band_description(i, name)

    print(f"Indices saved to {output_path}")
    return indices_stack

# Compute NDVI and MSAVI2 for the uploaded scene and keep both the path of the
# written raster and the in-memory stack for later cells.
indices_path = "data/indices/spectral_indices.tif"
indices = calculate_indices(
    imagery_path,
    indices_path,
    indices="ndvi,msavi2",
    red_band=3,
    nir_band=4,
)

## 3. BYOL Model Implementation

In [None]:
class BYOLProjector(nn.Module):
    """Two-layer MLP head: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_features: Size of the last dimension of the incoming features.
        hidden_dim: Width of the hidden layer.
        out_dim: Size of the produced projection.
    """

    def __init__(self, in_features, hidden_dim=4096, out_dim=256):
        super().__init__()
        layers = [
            nn.Linear(in_features, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        ]
        self.projection = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the projection MLP and return the result."""
        projected = self.projection(x)
        return projected

class BYOLPredictor(nn.Module):
    """Prediction MLP: Linear -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_dim: Size of the last dimension of the incoming projection.
        hidden_dim: Width of the hidden layer.
        out_dim: Size of the produced prediction.
    """

    def __init__(self, in_dim=256, hidden_dim=4096, out_dim=256):
        super().__init__()
        stages = (
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, out_dim),
        )
        self.predictor = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the predictor MLP and return the result."""
        prediction = self.predictor(x)
        return prediction

class BYOLModel(nn.Module):
    """BYOL (Bootstrap Your Own Latent) model on a ResNet-50 encoder.

    The model holds an *online* branch (encoder -> projector -> predictor)
    that is trained by gradient descent and a *target* branch (encoder ->
    projector) that is a frozen copy updated as an exponential moving
    average of the online branch. A small sigmoid head on top of the
    2048-dimensional encoder features is used by ``forward`` once
    ``fine_tuned`` is set to ``True``.

    Args:
        pretrained: Forwarded to ``torchvision.models.resnet50``.
        with_indices: When true, the first convolution is replaced by a
            5-channel one (3 + 2); the first three input channels reuse the
            original ResNet weights, the two extra ones are Kaiming-initialised.
        momentum: EMA coefficient used by ``update_target_network``.
    """

    def __init__(self, pretrained=True, with_indices=True, momentum=0.99):
        super().__init__()

        backbone = models.resnet50(pretrained=pretrained)

        if with_indices:
            first_conv = backbone.conv1
            new_conv = nn.Conv2d(
                in_channels=3 + 2,  # RGB + 2 indices
                out_channels=64,
                kernel_size=7,
                stride=2,
                padding=3,
                bias=False
            )

            with torch.no_grad():
                # Keep the existing filters for the first three channels and
                # give the extra index channels a fresh initialisation.
                new_conv.weight[:, :3] = first_conv.weight
                nn.init.kaiming_normal_(
                    new_conv.weight[:, 3:],
                    mode='fan_out',
                    nonlinearity='relu'
                )

            backbone.conv1 = new_conv

        # Drop the last child of the ResNet (its classification layer) so the
        # encoder ends at the pooled feature map.
        self.online_encoder = nn.Sequential(*list(backbone.children())[:-1])
        self.online_projector = BYOLProjector(2048)
        self.predictor = BYOLPredictor()
        self.target_encoder = copy.deepcopy(self.online_encoder)
        self.target_projector = copy.deepcopy(self.online_projector)

        # The target branch is never trained by back-propagation.
        for param in self.target_encoder.parameters():
            param.requires_grad = False
        for param in self.target_projector.parameters():
            param.requires_grad = False

        self.momentum = momentum
        self.prediction_head = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1),
            nn.Sigmoid()
        )

        self.fine_tuned = False
        self.with_indices = with_indices

    def forward(self, x):
        """Return flattened online-encoder features for ``x``.

        When ``self.fine_tuned`` is true the features are passed through
        ``prediction_head`` and its output is returned instead.
        """
        features = self.online_encoder(x)
        features = torch.flatten(features, 1)
        if self.fine_tuned:
            return self.prediction_head(features)
        return features

    def byol_forward(self, x):
        """Run both branches on a single view.

        Returns:
            Tuple ``(online_pred, target_proj, online_features)`` where all
            three were computed from the same input ``x``. The target branch
            runs under ``torch.no_grad()``.
        """
        online_features = self.online_encoder(x)
        online_features = torch.flatten(online_features, 1)
        online_proj = self.online_projector(online_features)
        online_pred = self.predictor(online_proj)

        with torch.no_grad():
            target_features = self.target_encoder(x)
            target_features = torch.flatten(target_features, 1)
            target_proj = self.target_projector(target_features)

        return online_pred, target_proj, online_features

    def byol_loss(self, view1, view2):
        """Compute the symmetric BYOL loss for two views of the same samples.

        Returns:
            Tuple ``(loss, features)``: the scalar loss and the mean of the
            online encoder features of the two views.
        """
        # byol_forward(v) yields the online prediction AND the target
        # projection of that same view v.
        online_pred1, target_proj1, online_feat1 = self.byol_forward(view1)
        online_pred2, target_proj2, online_feat2 = self.byol_forward(view2)

        online_pred1 = F.normalize(online_pred1, dim=-1)
        online_pred2 = F.normalize(online_pred2, dim=-1)
        target_proj1 = F.normalize(target_proj1, dim=-1)
        target_proj2 = F.normalize(target_proj2, dim=-1)

        # Cross-view pairing: the online prediction of one view is matched
        # against the target projection of the OTHER view. Pairing a view
        # with its own target projection would give the network nothing to
        # learn about augmentation invariance.
        loss1 = 2 - 2 * (online_pred1 * target_proj2).sum(dim=-1).mean()
        loss2 = 2 - 2 * (online_pred2 * target_proj1).sum(dim=-1).mean()

        loss = (loss1 + loss2) / 2
        return loss, (online_feat1 + online_feat2) / 2

    def _ema_update(self, online_module, target_module):
        """Move ``target_module`` parameters towards ``online_module`` by EMA."""
        for online_params, target_params in zip(
            online_module.parameters(), target_module.parameters()
        ):
            # target <- momentum * target + (1 - momentum) * online
            target_params.data.mul_(self.momentum).add_(
                online_params.data, alpha=1 - self.momentum
            )

    def update_target_network(self):
        """Apply one EMA step to the target encoder and target projector."""
        self._ema_update(self.online_encoder, self.target_encoder)
        self._ema_update(self.online_projector, self.target_projector)

    def save(self, path):
        """Save the state dicts of all sub-modules plus the two flags to ``path``."""
        torch.save({
            'online_encoder': self.online_encoder.state_dict(),
            'online_projector': self.online_projector.state_dict(),
            'predictor': self.predictor.state_dict(),
            'target_encoder': self.target_encoder.state_dict(),
            'target_projector': self.target_projector.state_dict(),
            'prediction_head': self.prediction_head.state_dict(),
            'fine_tuned': self.fine_tuned,
            'with_indices': self.with_indices
        }, path)

## 4. Dataset and Training

In [None]:
class MultiViewDataset(Dataset):
    """Dataset yielding two views of each raster for BYOL training.

    Each item reads every band of the imagery raster and, when an indices
    raster is given for that item, concatenates its bands after the imagery
    bands along the first (band) axis.

    Args:
        imagery_paths: Sequence of raster paths opened with ``rasterio``.
        indices_paths: Optional sequence parallel to ``imagery_paths``;
            entries may be ``None``. Defaults to no indices for any item.
        transform: Optional callable applied to the stacked float tensor,
            once per view. It receives imagery and index bands together so
            that a random spatial augmentation moves them identically.
    """

    def __init__(self, imagery_paths, indices_paths=None, transform=None):
        self.imagery_paths = imagery_paths
        self.indices_paths = indices_paths if indices_paths is not None else [None] * len(imagery_paths)
        self.transform = transform

    def __len__(self):
        return len(self.imagery_paths)

    def __getitem__(self, idx):
        """Return ``(view1, view2)`` for item ``idx``."""
        with rasterio.open(self.imagery_paths[idx]) as src:
            stacked = torch.from_numpy(src.read()).float()

        indices_path = self.indices_paths[idx]
        if indices_path is not None:
            with rasterio.open(indices_path) as src:
                indices = torch.from_numpy(src.read()).float()
            stacked = torch.cat([stacked, indices], dim=0)

        return self._make_views(stacked)

    def _make_views(self, stacked):
        """Build the two views from the stacked imagery(+indices) tensor."""
        if not self.transform:
            # Without augmentation both views are the same tensor.
            return stacked, stacked

        # Transform the already-stacked tensor: applying a random transform
        # to imagery and indices separately could flip one but not the
        # other, leaving the index bands spatially misaligned with the
        # imagery inside a single view.
        view1 = self.transform(stacked)
        view2 = self.transform(stacked)
        return view1, view2

# Define augmentations
class RandomAugmentation:
    """Randomly mirror a tensor along axis 2 and then along axis 1.

    Each flip is decided independently by one ``random.random()`` draw and
    is applied when the draw is greater than 0.5.
    """

    def __call__(self, x):
        # Axis 2 first (horizontal), then axis 1 (vertical); one draw each,
        # always in this order.
        for axis in (2, 1):
            draw = random.random()
            if draw > 0.5:
                x = torch.flip(x, [axis])
        return x

In [None]:
# Create dataset and data loader
transform = RandomAugmentation()
batch_size = 4

# Only one scene is available, which would give batches of a single sample.
# The BatchNorm1d layers in BYOLProjector / BYOLPredictor cannot compute
# batch statistics from one sample in training mode and raise an error, so
# the scene is listed `batch_size` times; every copy is augmented on its own.
dataset = MultiViewDataset(
    imagery_paths=[imagery_path] * batch_size,
    indices_paths=[indices_path] * batch_size,
    transform=transform
)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize model and optimizer
# NOTE(review): BYOLModel(with_indices=True) builds a 5-channel first
# convolution (3 + 2), while MultiViewDataset stacks every imagery band plus
# every index band — confirm the uploaded scene has exactly 3 bands.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BYOLModel(pretrained=True, with_indices=True).to(device)

# Only the online branch is optimised; the target branch is updated by
# model.update_target_network().
optimizer = optim.Adam(
    list(model.online_encoder.parameters()) +
    list(model.online_projector.parameters()) +
    list(model.predictor.parameters()),
    lr=0.0001
)

In [None]:
# Self-supervised BYOL training: one optimizer step followed by one EMA
# update of the target network per batch.
model.train()
epochs = 5  # Reduced for demonstration
losses = []  # mean loss of every epoch, in order

for epoch in range(epochs):
    epoch_loss = 0.0
    pbar = tqdm(data_loader, desc=f"Epoch {epoch+1}/{epochs}")

    for view1, view2 in pbar:
        view1, view2 = view1.to(device), view2.to(device)

        loss, _ = model.byol_loss(view1, view2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The target branch follows the freshly updated online branch.
        model.update_target_network()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        pbar.set_postfix({"Loss": batch_loss})

    avg_loss = epoch_loss / len(data_loader)
    losses.append(avg_loss)

    print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}")

# Persist all sub-module weights
model.save("results/byol_model.pth")

## 5. Inference and Visualization

In [None]:
def run_inference(model, imagery_path, indices_path, output_path):
    """Run ``model`` on one scene and save the result as a NumPy file.

    All bands of the imagery raster and of the indices raster are read,
    concatenated along the channel axis, and passed through ``model`` as a
    batch of one.

    Args:
        model: A ``torch.nn.Module``; it is switched to eval mode.
        imagery_path: Path of the imagery raster.
        indices_path: Path of the indices raster.
        output_path: Destination passed to ``numpy.save``.

    Returns:
        The model output for the single sample as a NumPy array (the batch
        dimension is removed).
    """
    model.eval()

    # Use the device the model actually lives on instead of depending on a
    # module-level `device` variable; fall back to CPU for a model without
    # parameters.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    # Load imagery and indices
    with rasterio.open(imagery_path) as src:
        imagery = src.read()

    with rasterio.open(indices_path) as src:
        indices = src.read()

    # Convert to torch tensors and add the batch dimension
    imagery = torch.from_numpy(imagery).float().unsqueeze(0).to(model_device)
    indices = torch.from_numpy(indices).float().unsqueeze(0).to(model_device)

    # Combine imagery and indices along the channel axis
    # NOTE(review): the channel count is (imagery bands + index bands) and
    # must match what the model's first layer expects — confirm for your data.
    x = torch.cat([imagery, indices], dim=1)

    # Extract features
    with torch.no_grad():
        features = model(x)

    # Convert to numpy, dropping the batch dimension
    features = features.cpu().numpy()[0]

    # Save features
    np.save(output_path, features)

    print(f"Inference completed. Results saved to {output_path}")
    return features

# Extract features for the uploaded scene and store them under results/
output_path = "results/drainage_features.npy"
features = run_inference(
    model=model,
    imagery_path=imagery_path,
    indices_path=indices_path,
    output_path=output_path,
)

In [None]:
# Side-by-side figure: the first three imagery bands as a colour image (left)
# and the first band of the indices raster, written as NDVI (right).
with rasterio.open(imagery_path) as src:
    rgb = src.read([1, 2, 3])
# Move the band axis last for imshow, then scale by the global maximum.
rgb = np.transpose(rgb, (1, 2, 0))
rgb = rgb / rgb.max()

with rasterio.open(indices_path) as src:
    ndvi = src.read(1)

plt.figure(figsize=(10, 4))

# Left panel
plt.subplot(1, 2, 1)
plt.imshow(rgb)
plt.title("RGB Image")
plt.axis('off')

# Right panel
plt.subplot(1, 2, 2)
plt.imshow(ndvi, cmap='RdYlGn')
plt.title("NDVI")
plt.colorbar(shrink=0.5)
plt.axis('off')

plt.tight_layout()
plt.savefig("results/visualization.png")
plt.show()