In [1]:
# !wget  https://hf.co/torchgeo/satlas/resolve/081d6607431bf36bdb59c223777cbb267131b8f2/sentinel2_resnet50_mi_ms-da5413d2.pth 

In [2]:
import os
from torchvision import transforms
import torch
import torch.nn as nn
import rasterio
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import List, Optional
import torch
import torch.nn as nn
from torch.optim import Adam
from tqdm import tqdm
from dataset_create import WaterDataset,prepare,pad_to_stride
from seg_model import seg_model_consturct,get_loss

In [3]:
weights_path="sentinel2_resnet50_mi_ms-da5413d2.pth"
model = seg_model_consturct(
    model= 'deeplabv3+',
    backbone= 'resnet50',
    weights= weights_path,
    in_channels=10,
    num_classes=1,
)

  state_dict = torch.load(weights)


In [4]:
train_names = {v for v in os.listdir('../train/images/')  if "9" not in v }
test_names = {v for v in os.listdir('../train/images/') if "9"  in v }


In [5]:
from torch.utils.data import DataLoader, Dataset, random_split

train_ds = WaterDataset(
    img_path='../train/images/',
    mask_path='../train/masks/',
    file_names=train_names
)
trans  = train_ds.trans 
total_size = len(train_ds)
val_size = int(total_size * 0.2)
train_size = total_size - val_size


train_ds, test_ds = random_split(
    train_ds,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(42)
)


In [6]:
# dl = DataLoader(ds)

In [7]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 128
num_workers = 4
max_epochs = 500
train_loader = DataLoader(train_ds,batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(test_ds,batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [8]:
import os
import rasterio
import argparse
import numpy as np
import geopandas as gpd
import matplotlib.pyplot as plt

from shapely import affinity
from shapely import Point

from sklearn.metrics import f1_score



In [9]:
def flooded_houses(
    gdf,
    lats: np.ndarray, 
    lons: np.ndarray, 
    pred: np.ndarray, 
    ground_truth: np.ndarray 
):
    

    flooded_pred = []
    flooded_gt = []
    pred = pred.flatten()  # Flatten the prediction array
    ground_truth = ground_truth.flatten()  # Flatten the ground_truth array
    
    for _, row in gdf.iterrows():
        polygon = row.geometry
        # Scale the polygon for more accurate coverage
        scaled_polygon = affinity.scale(polygon, xfact=1.5, yfact=1.5)
        
        # Get the polygon's bounding box (xmin, ymin, xmax, ymax)
        xmin, ymin, xmax, ymax = scaled_polygon.bounds

        # Find the indices of points that fall inside the bounding box of the polygon
        selected_indices = np.where((ymin <= lats) & (lats <= ymax) & (xmin <= lons) & (lons <= xmax))
        
        lats_to_check = lats[selected_indices]
        lons_to_check = lons[selected_indices]
        flood_pred_to_check = pred[selected_indices]
        flood_gt_to_check = ground_truth[selected_indices]

        # Check if at least one point inside the polygon is flooded in the prediction mask
        is_flooded_pred = any(
            flood_pred_to_check[i] and scaled_polygon.contains(Point(lons_to_check[i], lats_to_check[i]))
            for i in range(len(flood_pred_to_check))
        )

        # Check if at least one point inside the polygon is flooded in the ground truth mask
        is_flooded_gt = any(
            flood_gt_to_check[i] and scaled_polygon.contains(Point(lons_to_check[i], lats_to_check[i]))
            for i in range(len(flood_gt_to_check))
        )

        flooded_pred.append(1 if is_flooded_pred else 0)
        flooded_gt.append(1 if is_flooded_gt else 0)

    return f1_score(flooded_gt, flooded_pred, average='macro')




def test_metric(pre_gt_path, pre_pred, post_gt_path, post_pred, osm_path):
    gdf = gpd.read_file(osm_path)
    gdf = gdf.to_crs(4326)
    gdf.tags.unique()
    
    with rasterio.open(pre_gt_path) as multi_band_src:
        pre_mask = multi_band_src.read(1)
        pre_height, pre_width = pre_mask.shape
        pre_cols, pre_rows = np.meshgrid(np.arange(pre_width), np.arange(pre_height))
        pre_x, pre_y = rasterio.transform.xy(multi_band_src.transform, pre_rows, pre_cols) 
        pre_lons, pre_lats = np.array(pre_x), np.array(pre_y)

    with rasterio.open(post_gt_path) as multi_band_src:
        post_mask = multi_band_src.read(1)
        post_height, post_width = post_mask.shape
        post_cols, post_rows = np.meshgrid(np.arange(post_width), np.arange(post_height))
        post_x, post_y = rasterio.transform.xy(multi_band_src.transform, post_rows, post_cols) 
        post_lons, post_lats = np.array(post_x), np.array(post_y)
    
    f1_water = (f1_score(pre_mask, pre_pred, average='macro') + f1_score(post_mask, post_pred, average='macro'))/2

    pre_f1 = flooded_houses(gdf, pre_lats, pre_lons, pre_pred, pre_mask)
    post_f1 = flooded_houses(gdf, post_lats, post_lons, post_pred, post_mask)
    avg_f1_business = (pre_f1 + post_f1) / 2

    return (f1_water + avg_f1_business)/2

In [10]:
import numpy as np

In [11]:
def get_pred(model,file,paths,trans,device):
    output_stride=256
    with rasterio.open(f'{paths}/{file}') as src:
        image = src.read().astype(np.float32)  
        image = torch.tensor(image)
        size_out = image.size()
        image = pad_to_stride(image, output_stride)
        h_splits = image.shape[1] // output_stride
        w_splits = image.shape[2] // output_stride
        tensor_split = image.unfold(1, output_stride, output_stride).unfold(2, output_stride, output_stride)
        tensor_split = tensor_split.contiguous().view(-1, image.shape[0], output_stride, output_stride)
        tensor_split = trans(tensor_split)
        meta = src.meta

    model.eval()
    sig = torch.nn.Sigmoid()
    outputs = []

    for tensor in tensor_split:
        with torch.no_grad():
            output = sig(model(tensor.to(device).unsqueeze(0))).cpu()
            outputs.append(output)
    output_rows = []
    for h in range(h_splits):
        row_outputs = outputs[h * w_splits : (h + 1) * w_splits]
        output_rows.append(torch.cat(row_outputs, dim=-1))  

    final_output = torch.cat(output_rows, dim=-2)


    meta['count'] = 1
    pred = (final_output.squeeze(0)<0.5).cpu().numpy().astype("int32")


    pred = pred[0][:size_out[1], :size_out[2]]
    return pred


In [12]:
from torch.optim.lr_scheduler import StepLR
def train_segmentation(model, train_loader,criterion,val_loader, trans,num_epochs=25, learning_rate=1e-5,step_size=5, gamma=0.1):
    model = model.to(device)
    
    optimizer = Adam(model.parameters(), lr=learning_rate,weight_decay=1e-5)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
    for epoch in range(num_epochs):
        model.train()  
        train_loss = 0.0
        
        for images, masks in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            
            outputs = model(images)
          
            loss = criterion(outputs, masks)
            
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
        
        avg_train_loss = train_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}")
        
        model.eval()  
        val_loss = 0.0
        with torch.no_grad():
            for images, masks in tqdm(val_loader, desc="Validating"):
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                loss = criterion(outputs, masks)
                val_loss += loss.item()

   
            
        with rasterio.open("../train/masks/9_1.tif") as multi_band_src:
            pre_mask = multi_band_src.read(1)
            
        with rasterio.open("../train/masks/9_2.tif") as multi_band_src:
            post_mask = multi_band_src.read(1)
        
        avg_val_loss = val_loss / len(val_loader)
        pre_pred = get_pred(model,"9_1.tif",'../train/images/',trans,device)
        post_pred = get_pred(model,"9_2.tif",'../train/images/',trans,device)
        metric = test_metric("../train/masks/9_1.tif", pre_pred, "../train/masks/9_2.tif", post_pred, "../train/osm/9.geojson")
        print(f"Epoch [{epoch+1}/{num_epochs}], Validation Loss: {avg_val_loss:.4f} , Validation metric: {metric:.4f}")
        scheduler.step()
        print(f"Epoch [{epoch+1}/{num_epochs}], Learning Rate: {scheduler.get_last_lr()[0]:.6f}")
        
    
    print("Training complete.")




In [13]:
criterion = get_loss('ce')
train_segmentation(model, train_loader,criterion,val_loader, trans,num_epochs=25, learning_rate=1e-5,step_size=7, gamma=0.1)


Training Epoch 1/25: 100%|██████████| 27/27 [00:28<00:00,  1.05s/it]


Epoch [1/25], Train Loss: 0.6179


Validating: 100%|██████████| 7/7 [00:04<00:00,  1.56it/s]


Epoch [1/25], Validation Loss: 0.6220 , Validation metric: 0.2012
Epoch [1/25], Learning Rate: 0.000010


Training Epoch 2/25: 100%|██████████| 27/27 [00:27<00:00,  1.04s/it]


Epoch [2/25], Train Loss: 0.5703


Validating: 100%|██████████| 7/7 [00:04<00:00,  1.55it/s]


Epoch [2/25], Validation Loss: 0.5923 , Validation metric: 0.2134
Epoch [2/25], Learning Rate: 0.000010


Training Epoch 3/25: 100%|██████████| 27/27 [00:28<00:00,  1.04s/it]


Epoch [3/25], Train Loss: 0.5133


Validating: 100%|██████████| 7/7 [00:04<00:00,  1.53it/s]


Epoch [3/25], Validation Loss: 0.5870 , Validation metric: 0.2010
Epoch [3/25], Learning Rate: 0.000010


Training Epoch 4/25: 100%|██████████| 27/27 [00:28<00:00,  1.04s/it]


Epoch [4/25], Train Loss: 0.4717


Validating: 100%|██████████| 7/7 [00:04<00:00,  1.56it/s]


Epoch [4/25], Validation Loss: 0.5019 , Validation metric: 0.1968
Epoch [4/25], Learning Rate: 0.000010


Training Epoch 5/25: 100%|██████████| 27/27 [00:27<00:00,  1.03s/it]


Epoch [5/25], Train Loss: 0.4379


Validating: 100%|██████████| 7/7 [00:04<00:00,  1.57it/s]


Epoch [5/25], Validation Loss: 0.4679 , Validation metric: 0.1966
Epoch [5/25], Learning Rate: 0.000010


Training Epoch 6/25: 100%|██████████| 27/27 [00:28<00:00,  1.04s/it]


Epoch [6/25], Train Loss: 0.4043


Validating: 100%|██████████| 7/7 [00:04<00:00,  1.57it/s]


Epoch [6/25], Validation Loss: 0.4353 , Validation metric: 0.3216
Epoch [6/25], Learning Rate: 0.000010


Training Epoch 7/25: 100%|██████████| 27/27 [00:27<00:00,  1.04s/it]


Epoch [7/25], Train Loss: 0.3769


Validating: 100%|██████████| 7/7 [00:04<00:00,  1.56it/s]


Epoch [7/25], Validation Loss: 0.4170 , Validation metric: 0.3216
Epoch [7/25], Learning Rate: 0.000001


Training Epoch 8/25: 100%|██████████| 27/27 [00:27<00:00,  1.03s/it]


Epoch [8/25], Train Loss: 0.3613


Validating: 100%|██████████| 7/7 [00:04<00:00,  1.56it/s]


Epoch [8/25], Validation Loss: 0.4094 , Validation metric: 0.3216
Epoch [8/25], Learning Rate: 0.000001


Training Epoch 9/25: 100%|██████████| 27/27 [00:27<00:00,  1.03s/it]


Epoch [9/25], Train Loss: 0.3575


Validating: 100%|██████████| 7/7 [00:04<00:00,  1.56it/s]


Epoch [9/25], Validation Loss: 0.4084 , Validation metric: 0.3216
Epoch [9/25], Learning Rate: 0.000001


Training Epoch 10/25: 100%|██████████| 27/27 [00:27<00:00,  1.04s/it]


Epoch [10/25], Train Loss: 0.3552


Validating: 100%|██████████| 7/7 [00:04<00:00,  1.56it/s]


Epoch [10/25], Validation Loss: 0.3940 , Validation metric: 0.3216
Epoch [10/25], Learning Rate: 0.000001


Training Epoch 11/25: 100%|██████████| 27/27 [00:27<00:00,  1.03s/it]


Epoch [11/25], Train Loss: 0.3551


Validating: 100%|██████████| 7/7 [00:04<00:00,  1.57it/s]


Epoch [11/25], Validation Loss: 0.4207 , Validation metric: 0.3216
Epoch [11/25], Learning Rate: 0.000001


Training Epoch 12/25: 100%|██████████| 27/27 [00:28<00:00,  1.04s/it]


Epoch [12/25], Train Loss: 0.3521


Validating: 100%|██████████| 7/7 [00:04<00:00,  1.56it/s]


Epoch [12/25], Validation Loss: 0.4121 , Validation metric: 0.3216
Epoch [12/25], Learning Rate: 0.000001


Training Epoch 13/25:  33%|███▎      | 9/27 [00:11<00:22,  1.27s/it]


KeyboardInterrupt: 

In [None]:
# from torch.utils.data import DataLoader, Dataset, random_split

# total_size = len(dataset)
#         val_size = int(total_size * self.val_split)
#         train_size = total_size - val_size


#         self.train_set, self.val_set = random_split(
#             dataset,
#             [train_size, val_size],
#             generator=torch.Generator().manual_seed(42)
#         )

In [None]:
# torch.save(trainer.model.model.state_dict(), "weight_ckpt/best.pth")


In [None]:
# !ls weight_ckpt

In [None]:
# import pickle 

In [None]:
# with open("weight_ckpt/best.pikle","wb") as f:
#     pickle.dump(trainer.model.model, f)