In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from kymatio.torch import Scattering2D
import segmentation_models_pytorch as smp
import albumentations as A
from torch.utils.data import Dataset, DataLoader
import cv2
import numpy as np
import glob


In [None]:
# 1. Configurazione Scattering Wavelet
class ScatteringPreprocessor(nn.Module):
      """Apply a 2-D wavelet scattering transform and fold coefficients into channels.

      kymatio's ``Scattering2D`` maps an input of shape ``(..., C, H, W)`` to
      ``(..., C, K, H / 2**J, W / 2**J)``, where ``K`` is the number of
      scattering coefficients per input channel.  This module merges the ``C``
      and ``K`` axes so the result can feed ordinary 2-D convolutions, e.g.
      ``(B, C, H, W) -> (B, C * K, H / 2**J, W / 2**J)``.

      Args:
            J: Number of scattering scales (spatial downsampling by ``2**J``).
            shape: Spatial size ``(H, W)`` the transform is built for; inputs
                  are expected to have exactly this size.
      """

      def __init__(self, J=2, shape=(256, 256)):
            super().__init__()
            self.scattering = Scattering2D(J=J, shape=shape)
            self.J = J
            self.shape = shape

      def forward(self, x):
            """Return scattering features of ``x`` with channel and coefficient axes merged.

            Args:
                  x: Tensor of shape ``(B, C, H, W)`` or unbatched ``(C, H, W)``
                        whose spatial size equals ``self.shape``.

            Returns:
                  Tensor of shape ``(B, C * K, h, w)``, or ``(C * K, h, w)`` for
                  unbatched input.

            Raises:
                  ValueError: If ``x`` has fewer than three dimensions.
            """
            if x.dim() < 3:
                  raise ValueError(
                        f"Expected input of shape (B, C, H, W) or (C, H, W), got {tuple(x.shape)}"
                  )
            # Scattering2D rejects non-contiguous tensors (e.g. the result of permute()).
            coeffs = self.scattering(x.contiguous())
            # Merge the input-channel axis with the coefficient axis; flatten() copies
            # when a view is impossible, unlike view().
            return coeffs.flatten(start_dim=-4, end_dim=-3)


# 2. Architettura Custom Unet con Scattering
class ScatteringRiverSegmenter(nn.Module):
      """Segmentation network operating on wavelet-scattering features.

      Pipeline: scattering transform (downsamples by ``2**J``) -> two
      conv + max-pool stages (/4) -> two transposed-conv stages (x4) ->
      1x1 conv + sigmoid -> bilinear resize back to the input resolution.

      Args:
            J: Number of scattering scales.
            input_shape: Spatial size ``(H, W)`` of the RGB images the model
                  accepts; the scattering transform is built for this size.
            num_classes: Number of output channels (per-pixel probabilities).
      """

      def __init__(self, J=2, input_shape=(256, 256), num_classes=1):
            super().__init__()
            # forward() needs J; it was previously read without being stored.
            self.J = J
            self.input_shape = tuple(input_shape)

            self.scattering = ScatteringPreprocessor(J=J, shape=self.input_shape)
            # Probe with a dummy RGB batch to learn how many channels the
            # scattering stage emits; no autograd graph is needed for this.
            with torch.no_grad():
                  dummy_out = self.scattering(torch.zeros(1, 3, *self.input_shape))
            scat_channels = dummy_out.shape[1]

            # Small encoder-decoder on top of the scattering features.
            self.encoder = nn.Sequential(
                  nn.Conv2d(scat_channels, 64, 3, padding=1),
                  nn.ReLU(),
                  nn.MaxPool2d(2),
                  nn.Conv2d(64, 128, 3, padding=1),
                  nn.ReLU(),
                  nn.MaxPool2d(2)
            )

            self.decoder = nn.Sequential(
                  nn.ConvTranspose2d(128, 64, 2, stride=2),
                  nn.ReLU(),
                  nn.ConvTranspose2d(64, 32, 2, stride=2),
                  nn.ReLU(),
                  nn.Conv2d(32, num_classes, 1),
                  nn.Sigmoid()
            )

      def forward(self, x):
            """Predict per-pixel probabilities for a batch of RGB images.

            Args:
                  x: Image tensor of shape ``(B, 3, H, W)`` with ``(H, W)`` equal
                        to ``input_shape``.

            Returns:
                  Tensor of shape ``(B, num_classes, H, W)`` with values in (0, 1).
            """
            out_size = x.shape[-2:]
            x = self.scattering(x)
            x = self.encoder(x)
            x = self.decoder(x)
            # Resize to the exact input resolution instead of scaling by 2**J:
            # MaxPool2d floors odd sizes, so a fixed scale factor only restores the
            # input size when it is divisible by 2**(J + 2).
            return F.interpolate(x, size=out_size, mode='bilinear', align_corners=False)


# 3. Dataset con Augmentation
class RiverDataset(Dataset):
      """Paired drone-image / river-mask dataset read from two folders.

      Images are the ``*.jpg`` files in ``img_dir`` and masks the ``*.png``
      files in ``mask_dir``.  They are paired by sorted filename order, so both
      folders must hold the same number of files whose sorted orders correspond.

      Each item is ``(image, mask)``:
            * ``image`` -- float tensor ``(3, H, W)``, RGB, after ``transform``;
            * ``mask``  -- float tensor ``(H, W)`` with values in {0, 1}.

      Args:
            img_dir: Folder containing the ``*.jpg`` images.
            mask_dir: Folder containing the ``*.png`` masks.
            scattering: Kept for backward compatibility and stored as
                  ``self.scattering``, but no longer applied here.
                  ScatteringRiverSegmenter runs the scattering transform in its
                  own ``forward``, so applying it here too would transform the
                  images twice.  It would also run a device-resident module
                  inside DataLoader worker processes on CPU tensors.
            transform: Optional albumentations-style callable invoked as
                  ``transform(image=..., mask=...)`` and returning a dict with
                  ``'image'`` and ``'mask'`` numpy arrays.

      Raises:
            FileNotFoundError: If no images are found or the first image cannot
                  be read.
            ValueError: If the number of images and masks differ.
      """

      def __init__(self, img_dir, mask_dir, scattering, transform=None):
            self.img_paths = sorted(glob.glob(f"{img_dir}/*.jpg"))
            self.mask_paths = sorted(glob.glob(f"{mask_dir}/*.png"))
            self.transform = transform
            self.scattering = scattering

            if not self.img_paths:
                  raise FileNotFoundError(f"No *.jpg images found in {img_dir!r}")
            if len(self.img_paths) != len(self.mask_paths):
                  raise ValueError(
                        f"Found {len(self.img_paths)} images but {len(self.mask_paths)} masks; "
                        "images and masks are paired by sorted order and must match one-to-one"
                  )

            # Native (H, W) of the first image, recorded for reference.
            self.original_size = self._read(self.img_paths[0]).shape[:2]

      def __len__(self):
            """Return the number of image/mask pairs (required by DataLoader samplers)."""
            return len(self.img_paths)

      def __getitem__(self, idx):
            """Load, augment and tensorise the ``idx``-th image/mask pair."""
            img = cv2.cvtColor(self._read(self.img_paths[idx]), cv2.COLOR_BGR2RGB)
            mask = self._read(self.mask_paths[idx], cv2.IMREAD_GRAYSCALE)

            if self.transform:
                  augmented = self.transform(image=img, mask=mask)
                  img = augmented['image']
                  mask = augmented['mask']

            # HWC -> CHW; made contiguous because the scattering transform applied
            # later by the model rejects strided (permuted) tensors.
            image = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1))).float()
            # Binarise: PNG masks are often stored as 0/255, while binary losses
            # such as the Dice loss used for training expect targets in {0, 1}.
            target = torch.from_numpy((np.asarray(mask) > 0).astype(np.float32))
            return image, target

      @staticmethod
      def _read(path, *imread_args):
            """Read an image with OpenCV, raising instead of silently returning None."""
            image = cv2.imread(path, *imread_args)
            if image is None:
                  raise FileNotFoundError(f"Could not read image: {path}")
            return image


# 4. Pipeline di Training
def train_river_segmenter(*, img_dir='path/to/images', mask_dir='path/to/masks',
                          epochs=50, batch_size=8, lr=1e-4, num_workers=4,
                          J=2, input_shape=(256, 256),
                          save_path='river_segmenter_scattering.pth'):
      """Train a ScatteringRiverSegmenter on a folder dataset and save its weights.

      All arguments are keyword-only and default to the values previously
      hard-coded, so ``train_river_segmenter()`` behaves as before.

      Args:
            img_dir: Folder with the ``*.jpg`` training images.
            mask_dir: Folder with the matching ``*.png`` masks.
            epochs: Number of passes over the dataset.
            batch_size: Mini-batch size.
            lr: AdamW learning rate.
            num_workers: DataLoader worker processes.
            J: Number of scattering scales.
            input_shape: ``(H, W)`` images are resized to and the model is built for.
            save_path: Where the trained ``state_dict`` is written.

      Returns:
            The trained model (still on the training device).

      Raises:
            RuntimeError: If the data loader yields no batches.
      """
      device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

      model = ScatteringRiverSegmenter(J=J, input_shape=input_shape).to(device)
      # Passed only because RiverDataset's constructor requires it.  Kept on the
      # CPU so that, if the dataset applies it, it matches the CPU tensors built
      # in DataLoader workers (CUDA cannot be used from forked worker processes).
      scattering = ScatteringPreprocessor(J=J, shape=input_shape)

      # Drone-specific augmentations.
      transform = A.Compose([
            A.Resize(*input_shape),
            A.RandomRotate90(),
            A.RandomBrightnessContrast(p=0.5),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            # Typical drone sensor noise.  NOTE(review): in albumentations versions
            # where var_limit is a variance in uint8 pixel units, (0.001, 0.005)
            # adds almost no noise -- confirm the intended strength.
            A.GaussNoise(var_limit=(0.001, 0.005)),
            A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
      ])

      dataset = RiverDataset(
            img_dir=img_dir,
            mask_dir=mask_dir,
            scattering=scattering,
            transform=transform
      )
      train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                                num_workers=num_workers)

      optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
      # The model ends in a sigmoid, so the loss receives probabilities.
      loss_fn = smp.losses.DiceLoss(smp.losses.BINARY_MODE, from_logits=False)

      for epoch in range(epochs):
            model.train()
            running_loss = 0.0
            n_batches = 0
            for inputs, masks in train_loader:
                  inputs, masks = inputs.to(device), masks.to(device)

                  optimizer.zero_grad()
                  outputs = model(inputs)
                  loss = loss_fn(outputs, masks.unsqueeze(1))
                  loss.backward()
                  optimizer.step()

                  running_loss += loss.item()
                  n_batches += 1

            if n_batches == 0:
                  raise RuntimeError("Training data loader produced no batches")

            # Validation and richer logging to be added per dataset.
            # Report the epoch-mean loss rather than the last batch's loss.
            print(f'Epoch {epoch+1}, Loss: {running_loss / n_batches:.4f}')

      torch.save(model.state_dict(), save_path)
      return model


# 5. Inference con Post-Processing
class RiverSegmenter:
      """Inference wrapper: load trained weights, segment an image file, clean the mask.

      Args:
            model_path: Path to a ``state_dict`` saved by ``train_river_segmenter``.
            J: Number of scattering scales; must match the trained model.
            device: Torch device (string or ``torch.device``).  ``None`` selects
                  CUDA when available, otherwise the CPU.
            input_shape: ``(H, W)`` the network runs at; must match training.
      """

      def __init__(self, model_path, J=2, device=None, input_shape=(256, 256)):
            if device is None:
                  device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.device = device
            # Retained for backward compatibility; predict() does not use it because
            # the model applies the scattering transform itself.
            self.scattering = ScatteringPreprocessor(J=J, shape=input_shape).to(device)
            self.model = ScatteringRiverSegmenter(J=J, input_shape=input_shape).to(device)
            # map_location lets weights saved from a GPU load on a CPU-only host.
            self.model.load_state_dict(torch.load(model_path, map_location=device))
            self.model.eval()

            self.transform = A.Compose([
                  A.Resize(*input_shape),
                  A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
            ])

      def predict(self, image_path, threshold=0.5):
            """Segment the river in the image at ``image_path``.

            Args:
                  image_path: Path to an image readable by OpenCV.
                  threshold: Probability above which a pixel is labelled river.

            Returns:
                  ``uint8`` mask with the image's original ``(H, W)`` and values 0 or 255,
                  after a 5x5 elliptical morphological closing.

            Raises:
                  FileNotFoundError: If the image cannot be read.
            """
            bgr = cv2.imread(image_path)
            if bgr is None:
                  raise FileNotFoundError(f"Could not read image: {image_path}")
            img = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            original_size = img.shape[:2]

            augmented = self.transform(image=img)
            # HWC -> CHW, made contiguous because the scattering transform rejects
            # strided (permuted) tensors.
            chw = np.ascontiguousarray(augmented['image'].transpose(2, 0, 1))
            img_tensor = torch.from_numpy(chw).float().unsqueeze(0).to(self.device)

            with torch.no_grad():
                  # Feed the image itself: the model applies scattering internally, so
                  # passing precomputed coefficients would scatter them a second time.
                  pred = self.model(img_tensor).squeeze().cpu().numpy()

            # Back to the original resolution (cv2.resize takes (width, height)).
            mask = cv2.resize(pred, original_size[::-1])
            binary_mask = (mask > threshold).astype(np.uint8)

            # Morphological closing fills small holes in the river mask.
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
            cleaned = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)

            return cleaned * 255


# Utilizzo
if __name__ == "__main__":
      # End-to-end demo: train and save the weights, then segment one sample
      # drone image with them and write the cleaned mask to disk.
      train_river_segmenter()
      river_segmenter = RiverSegmenter('river_segmenter_scattering.pth')
      cv2.imwrite('river_mask.png', river_segmenter.predict('drone_image.jpg'))