In [26]:
!pip install rasterio



# Preprocess

In [3]:
import rasterio
import torch
from torchvision import transforms
import numpy as np

def load_tiff_band(tiff_path, band_index=1, max_value=65535.0):
    """
    Load one band of a (possibly multi-band) TIFF file as a normalized tensor.

    Raw pixel values are cast to float32 and divided by ``max_value``. The
    default, 65535, is the largest value representable in UINT16, so UINT16
    rasters land in [0, 1]. Rasters stored in another dtype (e.g. UINT8, or
    float data) need a different ``max_value``; this function does not
    inspect the raster dtype for you.

    :param tiff_path: Path to the .tif/.tiff file.
    :param band_index: The band to extract (1-indexed, as in rasterio).
    :param max_value: Divisor used for normalization; must be positive.
    :return: float32 PyTorch tensor of shape [1, H, W] (one channel, channel-first).
    :raises ValueError: If ``max_value`` is not positive.
    """
    if max_value <= 0:
        raise ValueError(f"max_value must be positive, got {max_value!r}")

    # rasterio band indices are 1-based; reading a single index yields a 2-D [H, W] array.
    with rasterio.open(tiff_path) as dataset:
        band_data = dataset.read(band_index)

    # Cast first so the normalized result is float32.
    normalized = band_data.astype(np.float32) / np.float32(max_value)

    # from_numpy shares memory with the freshly allocated `normalized` array instead of copying it.
    band_tensor = torch.from_numpy(normalized)

    # Add the channel dimension -> [1, H, W].
    return band_tensor.unsqueeze(0)

# Model

In [7]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class BinaryClassifierCNN(nn.Module):
    """
    Small fully-convolutional binary classifier.

    Three 3x3 conv + ReLU layers (stride 1, padding 1, so H and W are preserved)
    are followed by global average pooling and a single linear output unit.

    Because nothing downsamples between the convolutions, each pooled feature
    only sees a 7x7-pixel receptive field. Activations are also kept at full
    input resolution with up to 128 channels, so memory grows with H * W.
    NOTE(review): consider striding/pooling between convs for large rasters.
    """

    def __init__(self, in_channels=1):
        """
        :param in_channels: Number of input channels (1 for a single-band raster).
        """
        super().__init__()

        # Convolutional feature extractor (spatial size preserved by padding=1).
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)

        # Global average pooling makes the head independent of the input H and W.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

        # Single output unit for binary classification.
        self.fc = nn.Linear(128, 1)

    def forward_logits(self, x):
        """
        Compute raw (pre-sigmoid) scores.

        Pair this with ``nn.BCEWithLogitsLoss``, which is more numerically stable
        than applying a sigmoid and then ``nn.BCELoss``.

        :param x: Batched input of shape [B, in_channels, H, W].
        :return: Logits of shape [B, 1].
        """
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.pool(x)          # [B, 128, 1, 1]
        x = torch.flatten(x, 1)   # [B, 128]
        return self.fc(x)         # [B, 1]

    def forward(self, x):
        """
        Compute class-1 probabilities.

        :param x: Batched input of shape [B, in_channels, H, W].
        :return: Probabilities in (0, 1) of shape [B, 1] (sigmoid of the logits).
        """
        return torch.sigmoid(self.forward_logits(x))

# Train

In [15]:
import torch.optim as optim

# Seed torch's RNG so weight initialization (and therefore the training log) is reproducible.
SEED = 42
torch.manual_seed(SEED)

# Choose the device first: the torch.optim docs recommend moving a model to its
# device before constructing the optimizer that references its parameters.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = BinaryClassifierCNN(in_channels=1).to(device)
criterion = nn.BCELoss()  # forward() already applies a sigmoid, so plain BCE on probabilities.
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [23]:
# Labelled training images: (path, label) with 1 = positive class, 0 = negative class.
train_samples = [('tir.tiff', 1), ('tir2.tiff', 0)]

# Each band loads as [1, H, W]; stacking into one batch requires identical H and W.
train_bands = [load_tiff_band(path, band_index=1) for path, _ in train_samples]
band_shapes = {tuple(band.shape) for band in train_bands}
assert len(band_shapes) == 1, f"Training images must all have the same shape, got {sorted(band_shapes)}"

band_tensor = torch.stack(train_bands).to(device)  # Shape: [N, 1, H, W]
label = torch.tensor([y for _, y in train_samples], dtype=torch.float32).unsqueeze(1).to(device)  # Shape: [N, 1]

# Full-batch training: every epoch runs one forward/backward pass over all N images.
num_epochs = 10
model.train()  # Nothing in the loop switches to eval mode, so set training mode once.
for epoch in range(num_epochs):
    optimizer.zero_grad()

    outputs = model(band_tensor)  # Probabilities, shape [N, 1]
    loss = criterion(outputs, label)

    loss.backward()
    optimizer.step()

    # Accuracy of the same forward pass the loss was computed on
    # (i.e. the weights *before* this epoch's update).
    predicted_class = (outputs.detach() > 0.5).float()
    correct = (predicted_class == label).sum().item()
    accuracy = correct / label.size(0)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Accuracy: {accuracy:.4f}")



Epoch [1/10], Loss: 0.6911, Accuracy: 0.5000
Epoch [2/10], Loss: 0.6910, Accuracy: 0.5000
Epoch [3/10], Loss: 0.6907, Accuracy: 0.5000
Epoch [4/10], Loss: 0.6902, Accuracy: 0.5000
Epoch [5/10], Loss: 0.6897, Accuracy: 0.5000
Epoch [6/10], Loss: 0.6891, Accuracy: 0.5000
Epoch [7/10], Loss: 0.6884, Accuracy: 0.5000
Epoch [8/10], Loss: 0.6875, Accuracy: 0.5000
Epoch [9/10], Loss: 0.6866, Accuracy: 0.5000
Epoch [10/10], Loss: 0.6854, Accuracy: 1.0000


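# Predict

Sanity check on `tir2.tiff`. This image is also one of the two training images, so the result says nothing about how the model performs on unseen data.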

In [25]:
# Switch model to evaluation mode
model.eval()

# Load a new TIFF image for inference
tiff_path = 'tir2.tiff'
band_tensor = load_tiff_band(tiff_path, band_index=1).to(device)

# Add batch dimension [B, C, H, W]
band_tensor = band_tensor.unsqueeze(0)

# Make a prediction (no gradient calculation needed)
with torch.no_grad():
    output = model(band_tensor)
    predicted_class = (output > 0.5).float()  # Threshold at 0.5 to decide between class 0 and 1

print(f"Predicted class: {predicted_class.item()}")



Predicted class: 0.0
