In [4]:
# Install git-lfs and fetch a fresh copy of the ml4rs helper repository (old clone is removed first).
# NOTE(review): git-lfs is presumably needed for LFS-tracked files in the repo — confirm;
# on images where `apt install` prompts, `apt-get install -y git-lfs` avoids blocking the cell.
!apt install git-lfs
!rm -rf ml4rs
!git clone  https://github.com/fzimmermann89/ml4rs/

Reading package lists... Done
Building dependency tree       
Reading state information... Done
The following NEW packages will be installed:
  git-lfs
0 upgraded, 1 newly installed, 0 to remove and 14 not upgraded.
Need to get 2,129 kB of archives.
After this operation, 7,662 kB of additional disk space will be used.
Get:1 http://archive.ubuntu.com/ubuntu bionic/universe amd64 git-lfs amd64 2.3.4-1 [2,129 kB]
Fetched 2,129 kB in 2s (1,343 kB/s)
Selecting previously unselected package git-lfs.
(Reading database ... 144793 files and directories currently installed.)
Preparing to unpack .../git-lfs_2.3.4-1_amd64.deb ...
Unpacking git-lfs (2.3.4-1) ...
Setting up git-lfs (2.3.4-1) ...
Processing triggers for man-db (2.8.3-2ubuntu0.1) ...
Cloning into 'ml4rs'...
remote: Enumerating objects: 6312, done.[K
remote: Counting objects: 100% (6312/6312), done.[K
remote: Compressing objects: 100% (6240/6240), done.[K
remote: Total 6312 (delta 22), reused 6308 (delta 18), pack-reused 0[K
Receiv

In [None]:
import sys

# Make packages from the cloned repo (presumably including `cd`, imported below) importable.
# The membership guard keeps this cell idempotent: re-running it does not add duplicate entries.
if 'ml4rs' not in sys.path:
    sys.path.append('ml4rs')

In [5]:
# Download the dataset archive from Google Drive (wget -q: quiet) and extract it into the
# working directory (unzip -o: overwrite existing files without prompting).
!wget "https://drive.google.com/uc?export=download&id=1ES5bALNZcS5AwiLZKuZW-aBNYiac80Nz" -O ds.zip -q && unzip -o ds.zip

Archive:  ds.zip
   creating: WV2_3bands_Site1/
  inflating: WV2_3bands_Site1/gt.bmp  
  inflating: WV2_3bands_Site1/t1.bmp  
  inflating: WV2_3bands_Site1/t2.bmp  


In [8]:
import random
from pathlib import Path  # used for dataset paths; previously only available via the star import below

import numpy as np
import torch
import torch.nn
import matplotlib.pyplot as plt

# NOTE(review): star import kept because WV_S1 and split (used below) come from it;
# prefer `from cd.ds import WV_S1, split` once the module's exports are confirmed.
from cd.ds import *

# Seed every RNG the pipeline may touch so data splits/augmentation/init are reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds all CUDA devices

device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Build the full dataset from the extracted Site1 folder.
# NOTE(review): the second argument (64) looks like a patch/tile size — confirm in cd.ds.WV_S1.
data_root = Path('WV2_3bands_Site1/')
ds = WV_S1(data_root, 64)



In [None]:
# Split into train / test / validation parts; the two fractions presumably size the
# test and validation parts — confirm against cd.ds.split.
trainds, testds, valds = split(ds, 0.1, 0.1)

# Turn on augmentation for the training part.
# NOTE(review): if split() returns views over one shared dataset object, this flag may
# also affect testds/valds — confirm.
trainds.augment = True

In [10]:
from torch.utils.data import DataLoader

BATCH_SIZE = 64
NUM_WORKERS = 8  # tune to the machine's CPU count; DataLoader warns if this exceeds it

train_loader = DataLoader(trainds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True)
# Validation metrics are aggregated over all batches, so order is irrelevant:
# shuffling only adds overhead and run-to-run non-determinism.
validate_loader = DataLoader(valds, batch_size=BATCH_SIZE, shuffle=False,
                             num_workers=NUM_WORKERS, pin_memory=True)


In [11]:
from cd.models.fcef.siamunet_diff import SiamUnet_diff

# Siamese U-Net (difference variant) in float32 on the selected device.
# NOTE(review): constructor args (3, 2) presumably mean input channels and output classes — confirm.
model = SiamUnet_diff(3, 2).float().to(device)

In [12]:
def count_parameters(model):
    """Return the total number of scalar elements across ``model``'s parameters that require gradients."""
    trainable = filter(lambda param: param.requires_grad, model.parameters())
    return sum(map(lambda param: param.numel(), trainable))

n_trainable = count_parameters(model)
print('Number of trainable parameters:', n_trainable)

Number of trainable parameters: 1350146


In [13]:
from torch import nn

# NLLLoss expects log-probabilities; this assumes SiamUnet_diff ends in a LogSoftmax — confirm.
criterion = nn.NLLLoss().to(device)

LEARNING_RATE = 1e-3  # Adam's default, made explicit so it is visible and tunable
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-4)
# Multiply the learning rate by 0.95 on every scheduler.step() (called once per epoch by train()).
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)

In [18]:
import time
def train_epoch(train_loader, model, criterion, optimizer, device):
    """
    Train ``model`` for one full pass over ``train_loader``.

    :param train_loader: iterable yielding ``(im1, im2, cm)`` batches
    :param model: model called as ``model(im1, im2)``; switched to train mode
    :param criterion: loss function applied as ``criterion(output, cm)``
    :param optimizer: optimizer updating ``model``'s parameters
    :param device: device the batches are moved to
    :returns: mean per-batch loss as a Python float (``nan`` if the loader is empty)
    """
    model.train()
    losses = []
    for im1, im2, cm in train_loader:
        # Images as float32, targets as int64 class indices (the dtype NLLLoss requires).
        im1 = im1.float().to(device)
        im2 = im2.float().to(device)
        cm = cm.long().to(device)

        optimizer.zero_grad()
        output = model(im1, im2)
        loss = criterion(output, cm)
        loss.backward()
        optimizer.step()

        # .item() yields a plain Python float and keeps no reference to the autograd graph.
        losses.append(loss.item())

    if not losses:
        # Avoid np.mean([])'s RuntimeWarning; the result is still nan.
        return float("nan")
    return float(np.mean(losses))

In [19]:
def predict(im1, im2, model, device):
    """
    Return per-pixel class predictions for a batch of image pairs.

    Switches ``model`` to eval mode (it is left there), moves it to ``device`` and
    runs the forward pass under ``torch.no_grad()`` so no autograd graph is built.

    :param im1: first image batch (cast to float32)
    :param im2: second image batch (cast to float32)
    :param model: model called as ``model(im1, im2)``
    :param device: device to run inference on
    :returns: int64 tensor on ``device`` with the argmax over dim 1 of the model output
        (dim 1 is presumably the class dimension — confirm against the model)
    """
    model.eval()
    model = model.to(device)
    with torch.no_grad():
        output = model(im1.float().to(device), im2.float().to(device))
    # Same indices as torch.max(output, 1)[1]; both return the first maximum on ties.
    return output.argmax(dim=1)


def validate(validate_loader, model, device):
    """
    Compute pixel-wise precision and recall of the positive class over a loader.

    Predictions come from ``predict``; both predictions and ground truth ``cm`` are
    converted to boolean, so any non-zero label counts as positive.

    :param validate_loader: iterable yielding ``(im1, im2, cm)`` batches
    :param model: model passed to ``predict``
    :param device: device to run inference on
    :returns: ``(precision, recall)`` as floats; a metric whose denominator is zero
        (no positive predictions / no positive pixels / empty loader) is ``nan``
    """
    tp = fp = fn = 0
    # No gradients are needed for evaluation; this also saves memory if predict() is changed.
    with torch.no_grad():
        for im1, im2, cm in validate_loader:
            gt = cm.cpu().numpy().astype(bool)
            pr = predict(im1, im2, model, device).cpu().numpy().astype(bool)
            tp += int(np.logical_and(pr, gt).sum())
            fp += int(np.logical_and(pr, ~gt).sum())
            fn += int(np.logical_and(~pr, gt).sum())
    precision = tp / (tp + fp) if tp + fp else float("nan")
    recall = tp / (tp + fn) if tp + fn else float("nan")
    return precision, recall


def checkpoint():
    """Hook invoked after every epoch by ``train``; currently does nothing.

    TODO: persist model/optimizer state here.
    """
    return None


In [22]:
from tqdm.notebook import tqdm


def train(train_loader, validate_loader, model, criterion, optimizer, scheduler, nepochs, device):
    """
    Train for ``nepochs`` epochs, validating, stepping the scheduler and checkpointing after each.

    :param train_loader: loader passed to ``train_epoch``
    :param validate_loader: loader passed to ``validate``
    :param model: model being trained
    :param criterion: loss function
    :param optimizer: optimizer updating ``model``
    :param scheduler: LR scheduler, stepped once per epoch
    :param nepochs: number of epochs to run
    :param device: device to train/evaluate on
    :returns: list with one dict per epoch (keys ``epoch`` (1-based), ``loss``,
        ``precision``, ``recall``), e.g. for plotting learning curves
    """
    history = []
    for epoch in tqdm(range(nepochs)):
        loss = train_epoch(train_loader, model, criterion, optimizer, device)
        precision, recall = validate(validate_loader, model, device)
        # Report 1-based epoch numbers so the final line reads "epoch N/N".
        print(f"epoch {epoch + 1}/{nepochs} --- loss:{loss}  precision:{precision}  recall:{recall}")
        history.append({"epoch": epoch + 1, "loss": loss, "precision": precision, "recall": recall})
        scheduler.step()
        checkpoint()
    return history


In [None]:
NUM_EPOCHS = 20
training_result = train(train_loader, validate_loader, model, criterion, optimizer, scheduler,
                        nepochs=NUM_EPOCHS, device=device)

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

epoch 0/20 --- loss:0.4366258680820465  precision:0.9173195487709662  recall:0.9749565058144859
epoch 1/20 --- loss:0.395425409078598  precision:0.9201768112668786  recall:0.9792887556084607
epoch 2/20 --- loss:0.3765321373939514  precision:0.9214270374744981  recall:0.9821959985349327
epoch 3/20 --- loss:0.32296642661094666  precision:0.9212627520282753  recall:0.9845252266275982
epoch 4/20 --- loss:0.30552199482917786  precision:0.9230765099229261  recall:0.9835580532918231
epoch 5/20 --- loss:0.2965083718299866  precision:0.9261298549722264  recall:0.9808911729695083
epoch 6/20 --- loss:0.32120615243911743  precision:0.9246762887928948  recall:0.9837068491896347
