In [None]:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import rasterio

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# # Uncomment to inspect the array returned when reading an image:
# # Sentinel-2 image bands : {"Blue": 0, "Green": 1, "Red": 2, "NIR": 3, "SWIR1": 4, "SWIR2": 5} 

# with rasterio.open("/home/llalla/Documents/SWOT/tuto_unet/data_folder/s2/5_tile_2048_0.tif") as src: 
#     s2_image = src.read()
# s2_image = s2_image[:3, :, :]
# s2_image.shape

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Paired Sentinel-1 / Sentinel-2 tiles and their masks, fused channel-wise.

    The same file name is looked up in the S1, S2 and mask folders. Each item is
    an early-fusion image (S1 bands followed by the first ``n_s2_bands`` S2 bands)
    together with its mask, both returned as float32 tensors.

    Args:
        thumbnail_name: sequence of tile file names shared by the three folders.
        s1_image_folder: folder holding the Sentinel-1 rasters.
        s2_image_folder: folder holding the Sentinel-2 rasters.
        mask_folder: folder holding the mask rasters.
        n_s2_bands: number of leading Sentinel-2 bands to keep. The default of 3
            keeps the visible bands; pass 6 to keep every optical band (the model's
            input channel count must then be adjusted to match).
    """

    def __init__(self, thumbnail_name, s1_image_folder, s2_image_folder, mask_folder, n_s2_bands=3):
        self.thumbnail_name = thumbnail_name
        self.s1_image_folder = s1_image_folder
        self.s2_image_folder = s2_image_folder
        self.mask_folder = mask_folder
        self.n_s2_bands = n_s2_bands

    def __len__(self):
        """Number of tiles in this split."""
        return len(self.thumbnail_name)

    @staticmethod
    def _read_raster(path):
        """Read every band of the raster at ``path`` and close the file."""
        with rasterio.open(path) as src:
            return src.read()

    def __getitem__(self, idx):
        """Return ``(image, mask)`` float32 tensors for the tile at position ``idx``."""
        name = self.thumbnail_name[idx]

        s1_image = self._read_raster(os.path.join(self.s1_image_folder, name))
        # Keep only the leading optical bands (bands sit on the first axis).
        s2_image = self._read_raster(os.path.join(self.s2_image_folder, name))[: self.n_s2_bands]
        mask = self._read_raster(os.path.join(self.mask_folder, name)).astype('float32')

        # Early fusion: stack radar and optical bands along the band axis.
        # Cast to float32 like the mask: the rasters may be stored as integers or
        # float64, which a float32 model cannot consume.
        image = np.concatenate((s1_image, s2_image), axis=0).astype('float32', copy=False)

        return torch.from_numpy(image), torch.from_numpy(mask)

In [None]:
# Root folder holding the "s1", "s2" and "masks" sub-folders.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
data_folder = "/home/llalla/Documents/SWOT/tuto_unet/data_folder/"
s1_image_folder = os.path.join(data_folder, "s1")
s2_image_folder = os.path.join(data_folder, "s2")
mask_folder = os.path.join(data_folder, "masks")

# A tile has the same file name in the s1, s2 and masks folders.
# os.listdir returns entries in arbitrary order, so sort them first: otherwise
# random_state alone does not make the split reproducible across runs/machines.
thumbnails = sorted(os.listdir(s1_image_folder))

# 80 % train, then the remaining 20 % is halved into test and validation.
train_images, test_images = train_test_split(thumbnails, test_size=0.2, random_state=42)
test_images, val_images = train_test_split(test_images, test_size=0.5, random_state=42)

In [None]:
# Sizes of the train / test / validation splits, in that order.
print(*map(len, (train_images, test_images, val_images)))

In [None]:
# train_images

In [None]:
# One Dataset per split; all three read from the same S1 / S2 / mask folders.
trainset, testset, valset = (
    Dataset(split_names, s1_image_folder, s2_image_folder, mask_folder)
    for split_names in (train_images, test_images, val_images)
)

In [None]:
# Batch size used for each split.
train_batch_size, val_batch_size, test_batch_size = 32, 32, 8

# Every loader uses two worker processes and reshuffles its samples on each pass.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset, batch_size=train_batch_size, shuffle=True, num_workers=2
)
valloader = torch.utils.data.DataLoader(
    dataset=valset, batch_size=val_batch_size, shuffle=True, num_workers=2
)
testloader = torch.utils.data.DataLoader(
    dataset=testset, batch_size=test_batch_size, shuffle=True, num_workers=2
)

In [None]:
class DoubleConv(nn.Module):
    """Two consecutive (3x3 convolution -> batch norm -> ReLU) stages.

    Padding of 1 keeps the spatial size unchanged; only the channel count moves
    from ``in_channels`` to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in -> out channels, the second keeps out -> out.
        for stage_in in (in_channels, out_channels):
            stages.append(nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1))
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        features = self.conv(x)
        return features

class Down(nn.Module):
    """Encoder step: 2x2 max-pooling (halves height and width), then a DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(kernel_size=2)
        features = DoubleConv(in_channels, out_channels)
        self.down = nn.Sequential(pool, features)

    def forward(self, x):
        downsampled = self.down(x)
        return downsampled

class Up(nn.Module):
    """Decoder step: upsample, align with the skip connection, concatenate, DoubleConv.

    Args:
        in_channels: channels of the tensor to upsample; it is reduced to
            ``in_channels // 2`` before being concatenated with the skip tensor.
        out_channels: channels produced by the final DoubleConv.
        bilinear: if True, upsample with bilinear interpolation followed by a 1x1
            convolution; otherwise use a stride-2 transposed convolution.
    """

    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()
        half_channels = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half_channels, kernel_size=2, stride=2)
        else:
            self.up = nn.Sequential(
                nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
                nn.Conv2d(in_channels, half_channels, kernel_size=1),
            )

        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1`` and merge it with the skip tensor ``x2``."""
        upsampled = self.up(x1)

        # Pad the upsampled map so its height/width match the skip tensor's.
        gap_h = x2.shape[2] - upsampled.shape[2]
        gap_w = x2.shape[3] - upsampled.shape[3]
        left, top = gap_w // 2, gap_h // 2
        upsampled = F.pad(upsampled, [left, gap_w - left, top, gap_h - top])

        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)

class OutConv(nn.Module):
    """Output head: a 1x1 convolution followed by a sigmoid.

    The returned values are therefore in (0, 1), not raw scores.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        scores = self.conv(x)
        return self.sigmoid(scores)

In [None]:
class UNet(nn.Module):
    """U-Net with four encoder steps, a bottleneck and four decoder steps.

    Args:
        n_channels: number of input channels.
        n_classes: number of output channels produced by the output head.

    ``forward`` returns the output of ``OutConv``, which ends with a sigmoid:
    the values are per-pixel probabilities, not raw logits.
    """

    def __init__(self, n_channels, n_classes):
        super().__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        self.down4 = Down(512, 1024)
        self.up1 = Up(1024, 512)
        self.up2 = Up(512, 256)
        self.up3 = Up(256, 128)
        self.up4 = Up(128, 64)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        # Encoder: keep every intermediate map for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        # Decoder: each step must consume the PREVIOUS decoder output together with
        # the matching skip connection. Feeding the encoder maps (x4, x3, x2) here
        # instead happens to be shape-compatible, but it silently discards the
        # bottleneck and the first three decoder steps.
        # NOTE: weights trained with the former, unchained forward pass still load
        # but will not reproduce their old predictions; retrain after this fix.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        probabilities = self.outc(x)
        return probabilities

In [None]:
class DICE_BCE_Loss(nn.Module):
    """Smoothed soft-Dice loss plus binary cross-entropy.

    Despite its name, ``logits`` must already contain probabilities in [0, 1]:
    the binary cross-entropy term is computed directly on it.

    Args:
        smooth: constant added to both the Dice numerator and denominator.
    """

    def __init__(self, smooth=1):
        super().__init__()
        self.smooth = smooth

    def forward(self, logits, targets):
        # Dice term, computed over every element of the batch at once.
        overlap = (logits * targets).sum()
        total = (logits + targets).sum()
        dice_score = (2 * overlap + self.smooth) / (total + self.smooth)

        # Mean binary cross-entropy over all elements.
        bce_term = F.binary_cross_entropy(logits, targets)

        return (1. - dice_score) + bce_term
    
def dice_coeff(logits, targets):
    """Soft Dice coefficient between predictions and targets, as a Python float.

    Computed as ``2 * sum(logits * targets) / sum(logits + targets)`` over every
    element. The inputs are used as given (no thresholding).

    Returns:
        float: the coefficient, or 1.0 when both tensors sum to zero (nothing to
        find and nothing predicted).
    """
    # Pure metric: do not build an autograd graph when called on training outputs.
    with torch.no_grad():
        overlap = 2 * (logits * targets).sum()
        total = (logits + targets).sum()
        if total == 0:
            # Always a float, so callers get one return type on every path.
            return 1.0
        return (overlap / total).item()

In [None]:
def _run_epoch(model, loader, loss, run_device, optimizer=None):
    """Run one pass over ``loader``; weights are updated only when ``optimizer`` is given.

    Returns the loss and Dice coefficient averaged over the batches of the pass.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    total_loss, total_dice = 0.0, 0.0
    with torch.set_grad_enabled(training):
        for images, masks in loader:
            images, masks = images.to(run_device), masks.to(run_device)
            if training:
                optimizer.zero_grad()
            outputs = model(images)
            batch_loss = loss(outputs, masks)
            if training:
                batch_loss.backward()
                optimizer.step()
            total_loss += batch_loss.item()
            total_dice += dice_coeff(outputs, masks)
    n_batches = len(loader)
    return total_loss / n_batches, total_dice / n_batches


def train(model, trainloader, optimizer, loss, epochs=10, val_loader=None):
    """Train ``model`` and validate it after every epoch.

    Args:
        model: network to optimise; batches are moved to the device of its parameters.
        trainloader: iterable of ``(images, masks)`` training batches.
        optimizer: optimizer already bound to ``model``'s parameters.
        loss: callable ``loss(outputs, masks)`` returning a scalar tensor.
        epochs: number of passes over ``trainloader``.
        val_loader: iterable of ``(images, masks)`` validation batches. When
            omitted, the notebook-level ``valloader`` is used, as before.

    Returns:
        Four lists with one value per epoch, in this order:
        ``train_losses, train_dices, val_losses, val_dices``.
    """
    if val_loader is None:
        # Backward compatible: fall back to the loader defined at notebook level.
        val_loader = valloader

    # Follow the model's own device instead of relying on a global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else device

    train_losses, val_losses = [], []
    train_dices, val_dices = [], []
    for epoch in tqdm(range(epochs)):
        train_loss, train_dice = _run_epoch(model, trainloader, loss, run_device, optimizer)
        train_losses.append(train_loss)
        train_dices.append(train_dice)

        # Validation: no optimizer, so eval mode and no gradient tracking.
        val_loss, val_dice = _run_epoch(model, val_loader, loss, run_device)
        val_losses.append(val_loss)
        val_dices.append(val_dice)
        print(f"Epoch: {epoch + 1}  Train Loss: {train_loss:.4f} | Train DICE Coeff: {train_dice:.4f} | Val Loss: {val_loss:.4f} | Val DICE Coeff: {val_dice:.4f}")

    return train_losses, train_dices, val_losses, val_dices

In [None]:
epochs = 30
# 5 input channels = 2 radar bands + 3 optical bands (use 8 if all 6 optical bands are kept).
model = UNet(n_channels=5, n_classes=1).to(device)
loss = DICE_BCE_Loss(smooth=1)
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [None]:
# Run the training loop and keep the per-epoch curves for plotting.
(train_losses, train_dices,
 val_losses, val_dices) = train(model, trainloader, optimizer, loss, epochs=epochs)

# Visualise Loss and DICE during training 

In [None]:
# Left panel: DICE coefficient per epoch; right panel: loss per epoch.
epoch_axis = np.arange(epochs)
fig, (dice_ax, loss_ax) = plt.subplots(1, 2, figsize=(10, 6))

dice_ax.plot(epoch_axis, train_dices, label="Train DICE")
dice_ax.plot(epoch_axis, val_dices, label="Val DICE")
dice_ax.set_xlabel("Epoch")
dice_ax.set_ylabel("DICE Coeff")
dice_ax.legend()

loss_ax.plot(epoch_axis, train_losses, label="Train Loss")
loss_ax.plot(epoch_axis, val_losses, label="Val Loss")
loss_ax.set_xlabel("Epoch")
loss_ax.set_ylabel("Loss")
loss_ax.legend()

# Visualise test images

In [None]:
# Take a single batch from the test loader.
images, masks = next(iter(testloader))

# Forward pass without gradient tracking, then binarise the outputs at 0.5.
with torch.no_grad():
    pred = model(images.to(device)).cpu() > 0.5

def display_batch(images, masks, pred):
    """Plot a batch as five stacked rows: two S1 bands, S2 colour, masks, predictions.

    Within each row the samples of the batch are laid side by side.

    Args:
        images: CPU tensor (batch, channels, H, W) in early-fusion order, i.e.
            channels 0-1 are the two S1 bands and channels 2-4 the first three
            S2 bands, stored as Blue, Green, Red.
        masks: CPU tensor (batch, channels, H, W) of reference masks.
        pred: CPU tensor (batch, channels, H, W) of thresholded predictions.
    """
    def _as_strip(batch):
        # (B, C, H, W) -> (B, H, W, C) for matplotlib, then put the B tiles side
        # by side: (H, B * W, C).
        return np.concatenate(batch.permute(0, 2, 3, 1).numpy(), axis=1)

    image_strip = _as_strip(images)
    mask_strip = _as_strip(masks)
    pred_strip = _as_strip(pred)

    # Separate the radar bands from the optical ones.
    s1_0 = image_strip[:, :, 0]  # single-band image
    s1_1 = image_strip[:, :, 1]  # single-band image
    # The optical bands are stored Blue, Green, Red while imshow expects
    # Red, Green, Blue: take exactly three bands and reverse their order.
    s2_rgb = image_strip[:, :, 2:5][:, :, ::-1]

    fig, ax = plt.subplots(5, 1, figsize=(20, 6))

    ax[0].imshow(s1_0)
    ax[0].set_title('s1_0 Images')
    ax[1].imshow(s1_1)
    ax[1].set_title('s1_1 Images')
    ax[2].imshow(s2_rgb)
    ax[2].set_title('s2 Images')
    ax[3].imshow(mask_strip, cmap='gray')
    ax[3].set_title('Masks')
    ax[4].imshow(pred_strip, cmap='gray')
    ax[4].set_title('Predictions')

    # Lay out once the titles exist so they are taken into account.
    fig.tight_layout()

# Show the test batch next to its reference masks and thresholded predictions.
display_batch(images=images, masks=masks, pred=pred)

In [None]:
# Persist only the weights (state dict); the relative path resolves against the
# current working directory.
torch.save(obj=model.state_dict(), f='unet_s1s2_earlyfusion_RGB.pth')

# load previously trained model

In [None]:
model = UNet(5, 1).to(device) # 3 optical bands + 2 radars
# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved from a GPU session can still be restored on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
# NOTE(review): absolute, machine-specific path — adjust before running elsewhere.
model.load_state_dict(
    torch.load("/home/llalla/Documents/SWOT/tuto_unet/unet_s1s2_earlyfusion_RGB.pth", map_location=device)
)
model.eval()

In [None]:
# Take a single batch from the validation loader.
images, masks = next(iter(valloader))

# Forward pass without gradient tracking, then binarise the outputs at 0.5.
with torch.no_grad():
    pred = model(images.to(device)).cpu() > 0.5

In [None]:
# Reuse the plotting helper defined earlier in the notebook instead of repeating
# its body here. The inline copy overwrote `images`, `masks` and `pred` with numpy
# arrays, so the cell failed when run a second time; this call leaves them untouched.
display_batch(images, masks, pred)