In [2]:
pip install torchvision 

Collecting torchvisionNote: you may need to restart the kernel to use updated packages.

  Downloading torchvision-0.22.1-cp312-cp312-win_amd64.whl.metadata (6.1 kB)
Collecting torch==2.7.1 (from torchvision)
  Downloading torch-2.7.1-cp312-cp312-win_amd64.whl.metadata (28 kB)
Collecting sympy>=1.13.3 (from torch==2.7.1->torchvision)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting networkx (from torch==2.7.1->torchvision)
  Downloading networkx-3.5-py3-none-any.whl.metadata (6.3 kB)
Collecting fsspec (from torch==2.7.1->torchvision)
  Downloading fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy>=1.13.3->torch==2.7.1->torchvision)
  Downloading mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Downloading torchvision-0.22.1-cp312-cp312-win_amd64.whl (1.7 MB)
   ---------------------------------------- 0.0/1.7 MB ? eta -:--:--
   ------------ --------------------------- 0.5/1.7 MB 5.6 MB/s eta 0:00:01
   -------------------


[notice] A new release of pip is available: 25.0 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
import os
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score, jaccard_score
import rasterio
from glob import glob

ModuleNotFoundError: No module named 'sklearn'

In [None]:
class UNet(nn.Module):
    """U-Net for per-pixel (semantic) segmentation.

    Encoder: four double-conv stages (64 -> 128 -> 256 -> 512 channels) separated
    by 2x2 max-pooling; the 512-channel stage is the bottleneck. Decoder: three
    transposed-conv upsamplings, each concatenated with the matching encoder
    feature map (skip connection) and refined by a double-conv stage. A final
    1x1 conv produces per-class logits.

    Args:
        in_channels: number of channels in the input image (e.g. raster bands).
        out_classes: number of output classes (logit channels).

    Input ``(N, in_channels, H, W)`` -> output logits ``(N, out_classes, H, W)``.
    H and W need not be divisible by 8: each upsampled map is zero-padded to its
    skip connection's spatial size before concatenation.
    """

    def __init__(self, in_channels=3, out_classes=3):
        super().__init__()

        def CBR(in_c, out_c):
            # (Conv3x3 -> BatchNorm -> ReLU) x 2; padding=1 keeps H and W unchanged.
            return nn.Sequential(
                nn.Conv2d(in_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_c, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(inplace=True),
            )

        # Previously hard-coded to 3, which ignored `in_channels`.
        self.enc1 = CBR(in_channels, 64)
        self.enc2 = CBR(64, 128)
        self.enc3 = CBR(128, 256)
        self.enc4 = CBR(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.up3 = nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.dec3 = CBR(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.dec2 = CBR(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.dec1 = CBR(128, 64)

        self.final = nn.Conv2d(64, out_classes, 1)

    @staticmethod
    def _pad_to_match(x, ref):
        """Zero-pad ``x`` spatially so its H and W equal those of ``ref``.

        MaxPool2d floors odd sizes, so upsampling can come back up to one pixel
        short of the skip connection; the padding is split as evenly as possible.
        """
        dh = ref.size(-2) - x.size(-2)
        dw = ref.size(-1) - x.size(-1)
        if dh or dw:
            # F.pad order for the last two dims: (left, right, top, bottom).
            x = nn.functional.pad(x, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])
        return x

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))  # bottleneck

        d3 = self._pad_to_match(self.up3(e4), e3)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = self._pad_to_match(self.up2(d3), e2)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = self._pad_to_match(self.up1(d2), e1)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        return self.final(d1)

In [None]:
class TOADataset(Dataset):
    """Paired image/mask raster dataset for segmentation.

    Item ``idx`` reads all bands of ``image_paths[idx]`` as float32 and band 1 of
    ``mask_paths[idx]`` as int64 labels. Images and masks are paired purely by
    list position, so both lists must be in the same order.

    Args:
        image_paths: sequence of image raster paths.
        mask_paths: sequence of mask raster paths, same length and order.
        transform: optional callable applied to the image tensor only. The mask
            is NOT transformed, so spatial augmentations here would desynchronise
            image and mask.

    Raises:
        ValueError: if ``image_paths`` and ``mask_paths`` differ in length.
    """

    def __init__(self, image_paths, mask_paths, transform=None):
        if len(image_paths) != len(mask_paths):
            raise ValueError(
                f"got {len(image_paths)} image paths but {len(mask_paths)} mask paths"
            )
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, mask)``: float32 image of all bands, int64 mask from band 1."""
        with rasterio.open(self.image_paths[idx]) as img_src:
            # Raw pixel values, no scaling (the old `/ 1.0` was a no-op).
            # NOTE(review): consider per-band normalisation for training stability.
            img = torch.from_numpy(img_src.read().astype(np.float32))
        with rasterio.open(self.mask_paths[idx]) as mask_src:
            mask = torch.from_numpy(mask_src.read(1).astype(np.int64))

        if self.transform is not None:
            img = self.transform(img)

        # as_tensor avoids the warning torch.tensor() emits when handed a tensor.
        return torch.as_tensor(img, dtype=torch.float32), mask

In [None]:
def _macro_f1_iou(confusion):
    """Return ``(macro_f1, macro_iou)`` from a ``(C, C)`` confusion matrix.

    Rows index the true class and columns the predicted class. Only classes that
    occur in the labels or the predictions are averaged, which matches sklearn's
    ``f1_score`` / ``jaccard_score`` with ``average='macro'``.

    Raises:
        ValueError: if the matrix is all zeros (no pixels were counted).
    """
    cm = confusion.double()
    tp = cm.diagonal()
    fp = cm.sum(dim=0) - tp  # predicted as the class, truly another class
    fn = cm.sum(dim=1) - tp  # truly the class, predicted as another class
    present = (tp + fp + fn) > 0
    if not bool(present.any()):
        raise ValueError("no labelled pixels were counted; cannot compute F1/IoU")
    tp, fp, fn = tp[present], fp[present], fn[present]
    f1 = (2 * tp / (2 * tp + fp + fn)).mean().item()
    iou = (tp / (tp + fp + fn)).mean().item()
    return f1, iou


def train_model(model, dataloader, optimizer, loss_fn, num_epochs=20, device=None):
    """Train a segmentation model and record per-epoch loss, macro F1 and macro IoU.

    Metrics are computed on the training batches themselves, with the model in
    train mode. They are accumulated in a running confusion matrix instead of
    collecting every pixel into Python lists.

    Args:
        model: module whose output has classes on dim 1. It is assumed to return
            logits of shape (N, C, H, W) with labels of shape (N, H, W); TODO confirm.
        dataloader: iterable of ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: callable ``loss_fn(outputs, labels)`` returning a scalar tensor.
        num_epochs: number of passes over ``dataloader``.
        device: where batches are moved. Defaults to the device of the model's
            first parameter (CPU if it has none); this replaces the old implicit
            reliance on a notebook-global ``device``.

    Returns:
        dict with lists ``"loss"`` (sum of batch losses per epoch), ``"f1"`` and
        ``"iou"`` (macro averages per epoch), each of length ``num_epochs``.

    Raises:
        ValueError: if an epoch yields no batches or no labelled pixels.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.train()
    history = {"loss": [], "f1": [], "iou": []}

    for epoch in range(num_epochs):
        epoch_loss = 0.0
        n_batches = 0
        confusion = None

        for imgs, labels in dataloader:
            imgs, labels = imgs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(imgs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            n_batches += 1

            with torch.no_grad():
                num_classes = outputs.shape[1]
                preds = outputs.argmax(dim=1)
                # Skip labels outside [0, C) (e.g. an ignore_index) so bincount stays in range.
                valid = (labels >= 0) & (labels < num_classes)
                flat = labels[valid] * num_classes + preds[valid]
                batch_cm = torch.bincount(flat, minlength=num_classes * num_classes)
                batch_cm = batch_cm.reshape(num_classes, num_classes)
            confusion = batch_cm if confusion is None else confusion + batch_cm

        if n_batches == 0:
            raise ValueError("dataloader yielded no batches")
        f1, iou = _macro_f1_iou(confusion.cpu())

        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss:.4f} - F1: {f1:.4f} - IoU: {iou:.4f}")
        history['loss'].append(epoch_loss)
        history['f1'].append(f1)
        history['iou'].append(iou)

    return history


In [None]:
# Images and masks are paired by position after sorting their paths.
image_paths = sorted(glob("data/processed/images/*.tif"))
mask_paths = sorted(glob("data/processed/masks/*.tif"))

# Fail fast: an empty glob or unequal counts would otherwise mis-pair images and masks silently.
if not image_paths:
    raise FileNotFoundError("no images matched data/processed/images/*.tif")
if len(image_paths) != len(mask_paths):
    raise ValueError(f"found {len(image_paths)} images but {len(mask_paths)} masks")

# No augmentation yet. The previous identity Lambda was a no-op, so pass None.
transform = None
dataset = TOADataset(image_paths, mask_paths, transform=transform)
# A seeded generator makes the shuffle order reproducible across runs.
dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True,
    generator=torch.Generator().manual_seed(42),
)

In [None]:
# Seed before building the model so weight initialisation is reproducible.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Explicit sizes (these equal the UNet defaults): 3 input bands, 3 output classes.
model = UNet(in_channels=3, out_classes=3).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

history = train_model(model, dataloader, optimizer, loss_fn, num_epochs=20)

In [None]:
os.makedirs("outputs/logs", exist_ok=True)
os.makedirs("model", exist_ok=True)
torch.save(model.state_dict(), "model/unet_cloud_shadow.pth")

# Loss is unbounded while F1/IoU lie in [0, 1]. Separate axes keep both readable.
epochs = range(1, len(history['loss']) + 1)
fig, (ax_loss, ax_metrics) = plt.subplots(1, 2, figsize=(12, 4))

ax_loss.plot(epochs, history['loss'], label='Loss')
ax_loss.set(title="Loss", xlabel="Epoch", ylabel="Loss")
ax_loss.legend()

ax_metrics.plot(epochs, history['f1'], label='F1')
ax_metrics.plot(epochs, history['iou'], label='IoU')
ax_metrics.set(title="Macro F1 / IoU", xlabel="Epoch", ylabel="Value")
ax_metrics.legend()

fig.suptitle("Training Metrics")
fig.tight_layout()
fig.savefig("outputs/logs/training_plot.png")
plt.show()


In [None]:
def mask_to_shapefile(mask_path, output_shapefile_path):
    """Vectorise the non-zero pixels of a mask raster into a shapefile.

    Band 1 is read. Each connected region of equal, non-zero value becomes a
    polygon whose ``class`` attribute is that pixel value. The output uses the
    raster's CRS.

    Args:
        mask_path: path to the mask raster.
        output_shapefile_path: destination path passed to ``GeoDataFrame.to_file``.
    """
    # These names were referenced but never imported. They are imported locally
    # because only this export step needs geopandas/shapely.
    import geopandas as gpd
    from rasterio.features import shapes
    from shapely.geometry import shape

    with rasterio.open(mask_path) as src:
        mask = src.read(1)
        affine = src.transform  # renamed so the notebook-global `transform` isn't shadowed
        crs = src.crs  # read while the dataset is still open
    # NOTE(review): rasterio.features.shapes accepts only certain dtypes
    # (e.g. uint8/int16/int32/float32); confirm the mask rasters' dtype.

    geometries, classes = [], []
    # mask > 0 restricts polygonisation to labelled (non-background) pixels.
    for geom, val in shapes(mask, mask=mask > 0, transform=affine):
        geometries.append(shape(geom))
        classes.append(int(val))

    # GeoDataFrame.from_records does not accept `crs` or set the geometry column,
    # so build the frame directly. This also works when there are no polygons.
    gdf = gpd.GeoDataFrame({"class": classes}, geometry=geometries, crs=crs)
    gdf.to_file(output_shapefile_path)

    print(f"✅ Shapefile saved to {output_shapefile_path}")