In [None]:
import os
from glob import glob
import re
import logging
from tqdm import tqdm
from itertools import chain

import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


import torch
from torch.utils.data import Dataset,DataLoader,random_split
import torchvision.transforms as transforms
from PIL import Image
import albumentations as A

from models import encoders, decoders
from src import datasets, utils, metrics

In [None]:
class arguments:
    """Empty, mutable namespace that holds the run configuration.

    Fields are attached after construction, e.g. ``args.dump_path = "..."``.
    """

    def __init__(self) -> None:
        """Create the namespace with no attributes; callers assign them later."""
    
args = arguments()
args.log_level = "INFO"
args.dump_path = './results/field_delineation'

# Create the output directory, reusing it if it already exists. Files written
# there later (the "w"-mode log file, model checkpoints) replace earlier ones,
# so warn instead of pretending the run will stop.
if os.path.isdir(args.dump_path):
    print(f"Warning: {args.dump_path} already exists; previous results in it may be overwritten.")
os.makedirs(args.dump_path, exist_ok=True)

  
# Configure the root logger to write to the run's output directory.
def set_up_logger():
    """Route root logging to ``<args.dump_path>/output.log``.

    Every handler already attached to the root logger is detached first,
    because ``logging.basicConfig`` does nothing while the root logger has
    handlers (e.g. when this cell is re-run). The file is opened with
    ``filemode="w"`` (truncated on each call), and the level comes from
    ``args.log_level``. Reads the module-level ``args``.
    """
    root = logging.root
    while root.handlers:
        root.removeHandler(root.handlers[0])
    log_file = os.path.join(args.dump_path, "output.log")
    logging.basicConfig(
        level=args.log_level,
        filename=log_file,
        filemode="w",
        format="%(asctime)s - %(levelname)s - %(message)s",
    )
      
set_up_logger()
# Log the configuration as a readable attribute dict; logging the object
# itself would only record its default repr (class name and memory address).
logging.info(vars(args))
# Wall-clock timer for the whole run, read by the final "Code completed" log line.
overall_timer = utils.Timer()

In [None]:
def matchImageWithMask_crop(img_folder, mask_folder):
    """Pair every ``*.jpeg`` image with the ``*.png`` mask sharing its numeric id.

    A file's id is the first run of digits in its basename, parsed as an int
    (so ``"0012.jpeg"`` and ``"12.png"`` share id 12). Files whose names
    contain no digits are skipped with a warning.

    Args:
        img_folder: directory searched (non-recursively) for ``*.jpeg`` files.
        mask_folder: directory searched (non-recursively) for ``*.png`` files.

    Returns:
        ``(img_paths, mask_paths, unmatched_ids)``:
        ``img_paths[k]`` and ``mask_paths[k]`` are the real on-disk paths of the
        k-th matched pair. Images are visited in sorted path order, so the result
        and any seeded split built on it are reproducible.
        ``unmatched_ids`` lists the ids of images that have no mask.
    """
    def _file_id(path):
        # First run of digits in the basename as an int, or None if there is none.
        match = re.search(r"\d+", os.path.basename(path))
        return int(match.group(0)) if match else None

    # glob() order depends on the filesystem; sort so the pairing order is stable.
    img_filepaths = sorted(glob(os.path.join(img_folder, "*.jpeg")))
    mask_filepaths = sorted(glob(os.path.join(mask_folder, "*.png")))

    # id -> actual mask path. A dict gives O(1) lookups (a list scan was O(n*m)),
    # and keeping the real path avoids rebuilding it from the int id, which broke
    # names with prefixes or leading zeros.
    mask_by_id = {}
    for mask_path in mask_filepaths:
        mask_id = _file_id(mask_path)
        if mask_id is not None:
            mask_by_id[mask_id] = mask_path

    img_filepaths_keep = []
    mask_filepaths_keep = []
    unmatched_names = []
    for img_path in tqdm(img_filepaths):
        img_id = _file_id(img_path)
        if img_id is None:
            logging.warning(f"skipping {img_path}: no numeric id in file name")
            continue
        mask_path = mask_by_id.get(img_id)
        if mask_path is None:
            unmatched_names.append(img_id)
            logging.info(f"image {img_id} has no matching mask")
        else:
            img_filepaths_keep.append(img_path)
            mask_filepaths_keep.append(mask_path)

    n_images = len(img_filepaths_keep) + len(unmatched_names)
    summary = (f"{len(img_filepaths_keep)} images has matched masks, "
               f"{len(unmatched_names)}/{n_images} images have no matching masks")
    print(summary)
    logging.info(summary)

    return img_filepaths_keep, mask_filepaths_keep, unmatched_names

def split_train_test(img_filepaths, mask_filepaths, test_percent=0.2, seed=123):
    """Randomly split aligned image/mask paths into a train part and a test part.

    Args:
        img_filepaths: sequence of image paths.
        mask_filepaths: sequence of mask paths aligned index-by-index with
            ``img_filepaths``.
        test_percent: fraction (0..1) of pairs placed in the test part. The test
            size is ``int(N * test_percent)``, i.e. rounded down.
        seed: sampler seed. The same seed and input order give the same split.

    Returns:
        ``(train_imgs, train_masks, test_imgs, test_masks)`` as numpy arrays.
        Train entries keep their original relative order; test entries are in
        sampling order.
    """
    n_items = len(img_filepaths)
    test_size = int(n_items * test_percent)

    # A private RNG draws exactly what ``random.seed(seed); random.sample(...)``
    # would, without resetting the global ``random`` state that other code in
    # the process may use.
    rng = random.Random(seed)
    test_idx = rng.sample(range(n_items), test_size)
    test_set = set(test_idx)  # O(1) membership instead of scanning a list
    train_idx = [i for i in range(n_items) if i not in test_set]

    imgs = np.asarray(img_filepaths)
    masks = np.asarray(mask_filepaths)
    return imgs[train_idx], masks[train_idx], imgs[test_idx], masks[test_idx]

    
    
# Alternative data location used on the cluster: "/scratch/yc506/crop_delineation/batch/"
img_folder = "./crop_delineation/imgs"
mask_folder = "./crop_delineation/masks"

# train : test : val = 6 : 2 : 2 (fractions of all matched pairs)
test_percent = 0.2
val_percent = 0.2
split_seed = 123

img_filepaths_keep, mask_filepaths_keep, unmatched_names = matchImageWithMask_crop(img_folder, mask_folder)

# Carve out the test set first, then split what remains into train/val. The
# val fraction is re-expressed relative to the remainder so that it comes out
# as (approximately, after rounding) val_percent of the total.
train_img_filepaths, train_mask_filepaths, test_img_filepaths, test_mask_filepaths = split_train_test(
    img_filepaths_keep, mask_filepaths_keep,
    test_percent=test_percent, seed=split_seed)
train_img_filepaths, train_mask_filepaths, val_img_filepaths, val_mask_filepaths = split_train_test(
    train_img_filepaths, train_mask_filepaths,
    test_percent=val_percent / (1 - test_percent), seed=split_seed)

In [None]:
# Image dimension in PyTorch (B, C, H, W)
# Image dimension in Numpy  (H, W, C)

def img_np2torch_dim(X):
    """Reorder a channel-last (H, W, C) numpy image into a channel-first array.

    Only the first three channels are kept. The result is a new array with
    the same dtype and shape (3, H, W).
    """
    channels = [X[:, :, c] for c in range(3)]
    return np.stack(channels, axis=0)

def img_torch2np(X):
    """Convert a channel-first (C, H, W) tensor to a channel-last numpy array (H, W, C)."""
    return X.permute(1, 2, 0).numpy()


def get_mean_and_std(dataloader):
    """Per-channel mean and standard deviation over every pixel of a dataset.

    Used to derive ``transforms.Normalize`` parameters. Sums are accumulated
    per pixel rather than averaged per batch, so the statistics are exact even
    when batches differ in size (e.g. a smaller final batch).

    Args:
        dataloader: iterable of ``(X, y)`` pairs. ``X`` is a batch tensor laid
            out (B, C, H, W); ``y`` is ignored.

    Returns:
        ``(mean, std)``: 1-D tensors of length C, in ``X``'s dtype when it is
        floating point (otherwise the default float dtype).

    Raises:
        ValueError: if the loader yields no pixels.
    """
    pixel_sum = 0
    pixel_sq_sum = 0
    n_pixels = 0
    out_dtype = torch.get_default_dtype()
    for X, _ in dataloader:
        if X.is_floating_point():
            out_dtype = X.dtype
        # Accumulate in float64 to limit round-off over many pixels.
        X64 = X.to(torch.float64)
        # Reduce over batch, height and width, but not over the channels.
        pixel_sum = pixel_sum + X64.sum(dim=[0, 2, 3])
        pixel_sq_sum = pixel_sq_sum + (X64 ** 2).sum(dim=[0, 2, 3])
        n_pixels += X.shape[0] * X.shape[2] * X.shape[3]

    if n_pixels == 0:
        raise ValueError("get_mean_and_std: dataloader yielded no pixels")

    mean = pixel_sum / n_pixels
    # std = sqrt(E[X^2] - (E[X])^2). Clamp tiny negative values caused by
    # floating-point round-off so the square root never yields NaN.
    var = (pixel_sq_sum / n_pixels - mean ** 2).clamp(min=0.0)
    return mean.to(out_dtype), var.sqrt().to(out_dtype)

In [None]:
# define fieldDelineationDataset using PyTorch Dataset

class fieldDelineationDataset(Dataset):
    """In-memory dataset of (image, mask) pairs for field delineation.

    All selected pairs are decoded eagerly in ``__init__``. Pairs whose image
    or mask cannot be opened are skipped and counted.

    Args:
        image_filepaths: sequence of image paths.
        mask_filepaths: sequence of mask paths aligned index-by-index with
            ``image_filepaths``.
        transform: optional callable applied to the (possibly augmented) numpy
            image in ``__getitem__``. When ``None``, the image is converted to a
            channel-first tensor with ``img_np2torch_dim``.
        augmentations: optional albumentations-style callable that takes
            ``image=``/``mask=`` keywords and returns a dict with ``'image'``
            and ``'mask'``.
        subsample: only the first ``ceil(N / subsample)`` pairs are loaded. The
            default of 10 keeps the historical behaviour; pass 1 to load all.

    Raises:
        ValueError: if the two path sequences differ in length or
            ``subsample < 1``.
    """

    def __init__(self, image_filepaths, mask_filepaths, transform=None, augmentations=None, subsample=10):
        if len(image_filepaths) != len(mask_filepaths):
            raise ValueError(
                f"got {len(image_filepaths)} image paths but {len(mask_filepaths)} mask paths")
        if subsample < 1:
            raise ValueError(f"subsample must be >= 1, got {subsample}")

        self.transform = transform
        self.aug = augmentations

        # Numeric id (first run of digits in the basename) of every requested image.
        self.img_names = [int(re.search(r"\d+", os.path.basename(x)).group(0)) for x in image_filepaths]

        # Paths and decoded arrays of the pairs that were actually loaded.
        self.img_filepaths = []
        self.mask_filepaths = []
        self.images = []
        self.masks = []
        num_unmatched = 0
        n_keep = len(range(0, len(image_filepaths), subsample))  # == ceil(N / subsample)
        for i in tqdm(range(n_keep)):
            img_path = image_filepaths[i]
            mask_path = mask_filepaths[i]
            try:
                im = self._load(img_path)
                mask = self._load(mask_path)
            except OSError:
                # Missing or unreadable file (PIL's open errors are OSError
                # subclasses). Anything else is a real bug and must propagate.
                num_unmatched += 1
                continue
            self.img_filepaths.append(img_path)
            self.mask_filepaths.append(mask_path)
            self.images.append(im)
            self.masks.append(mask)
        logging.info(f"{len(self.images)} images has matched masks, {num_unmatched} images have no matching masks")
        print(f"{len(self.images)} images has matched masks, {num_unmatched} images have no matching masks")

    @staticmethod
    def _load(path):
        """Decode an image file into a numpy array and close the file handle."""
        with Image.open(path) as img:
            return np.asarray(img)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index: int):
        """Return ``(image, mask)`` for ``index``.

        The mask is divided by 255, optionally augmented together with the
        image, and given a leading channel axis ((1, H, W) for 2-D masks). It
        is returned as a LongTensor; the cast truncates fractional values. The
        image is returned as a FloatTensor.
        """
        X = self.images[index]
        # Scale 8-bit masks (presumably 0/255) to 0/1.
        y = self.masks[index] / 255

        # Augment with albumentations on numpy arrays, before tensor conversion,
        # passing image and mask together so both get the same transform.
        if self.aug:
            augmented = self.aug(image=X, mask=y)
            X = augmented['image']
            y = augmented['mask']

        y = torch.from_numpy(np.expand_dims(y, 0))

        if self.transform:
            X = self.transform(X)
        else:
            # Default: numpy (H, W, C) -> channel-first tensor.
            X = torch.from_numpy(img_np2torch_dim(X))

        return X.type(torch.FloatTensor), y.type(torch.LongTensor)

In [None]:
# Per-channel normalization statistics. They are computed on the training
# images with only ToTensor applied (no normalization, no augmentation).
compute_normalization_params = True

train_dataset = fieldDelineationDataset(
    train_img_filepaths, train_mask_filepaths, transform=transforms.ToTensor()
)

if not compute_normalization_params:
    # Previously computed statistics, used to skip the pass over the data.
    mean = [0.2397, 0.2972, 0.3173]
    std = [0.1876, 0.1223, 0.1136]
else:
    torch.manual_seed(12)
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
    mean, std = get_mean_and_std(train_loader)
    print(mean, std)

In [None]:
# Image pipeline: numpy (H, W, C) -> tensor, then per-channel normalization.
transform_norm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

# Inverse of the normalization above, used when plotting. Normalize(-m/s, 1/s)
# maps (x - m) / s back to x. It is built per channel from `mean`/`std`
# rather than assuming exactly three channels.
invTrans = transforms.Compose([
    transforms.Normalize(
        mean=[-m / s for m, s in zip(mean, std)],
        std=[1 / s for s in std],
    ),
])

In [None]:
# Joint image/mask augmentation: 90-degree rotations, vertical and horizontal
# flips, and transposes, each firing independently with probability `prob`.
prob = 0.25
aug = A.Compose([
    A.RandomRotate90(p=prob),
    A.VerticalFlip(p=prob),
    A.HorizontalFlip(p=prob),
    A.Transpose(p=prob),
])

In [None]:
# Sanity-check the augmentation pipeline on one training sample: undo the
# normalization, apply `aug` jointly to image and mask, and show both versions.
sample_idx = 20
X, y = train_dataset[sample_idx]
X_ = img_torch2np(invTrans(X))
y_ = img_torch2np(y)
X_aug = aug(image=X_, mask=y_)

panels = [
    (X_, "original image"),
    (X_aug['image'], "augmented image"),
    (y_, "original mask"),
    (X_aug['mask'], "augmented mask"),
]
# A single row of panels needs a short, wide figure. A height of 60 inches
# left the figure mostly empty vertical space.
fig, ax = plt.subplots(1, len(panels), figsize=(15, 4))
for axis, (image, title) in zip(ax, panels):
    axis.imshow(image)
    axis.set_title(title)
    axis.axis("off")
plt.show()

In [None]:
# From here on, training samples are augmented and normalized on the fly.
train_dataset.aug = aug
train_dataset.transform = transform_norm

# Validation samples use the training statistics for normalization and are
# never augmented.
valid_dataset = fieldDelineationDataset(
    val_img_filepaths, val_mask_filepaths, transform=transform_norm
)

In [None]:
def show_samples(dataset, n=3, figsize=(20, 8)):
    """Plot the first ``n`` (image, mask) pairs of ``dataset`` in one row.

    Images are de-normalized with ``invTrans`` before display.
    """
    fig, axes = plt.subplots(1, 2 * n, figsize=figsize)
    for i in range(n):
        X, y = dataset[i]
        axes[2 * i].imshow(img_torch2np(invTrans(X)))  # undo normalization for display
        axes[2 * i + 1].imshow(img_torch2np(y))
        axes[2 * i].axis('off')
        axes[2 * i + 1].axis('off')
    plt.show()


N = 3
show_samples(train_dataset, N)
show_samples(valid_dataset, N)

In [None]:
args.encoder = 'swav'
args.decoder = 'unet'
args.fine_tune_encoder = True

# Encoder initialization: "imagenet" builds it with encoders.load(args.encoder);
# anything else loads the SwAV checkpoint at `swav_encoder_path`.
pretrain = "geoNet_subset"
# pretrain = "imagenet"
# Machine-specific default path; set SWAV_ENCODER_PATH to override it.
swav_encoder_path = os.environ.get(
    "SWAV_ENCODER_PATH", "/home/mh613/updatedswav/rgbpaaath/checkpoints/ckp-eval.pth")

if pretrain == "imagenet":
    encoder = encoders.load(args.encoder)
else:
    encoder = encoders._load_swav_pretrained(swav_encoder_path)

# The first layer of the decoder depends on the encoder's output dimension.
decoder = decoders.load(args.decoder, encoder)

TRAINING

In [None]:
# Parameters handed to the optimizer: encoder + decoder when fine-tuning,
# decoder only otherwise.
args.fine_tune_encoder = False  # set True to also fine-tune the encoder
if args.fine_tune_encoder:
    params = list(encoder.parameters()) + list(decoder.parameters())
else:
    # Materialize as a list: a bare generator is exhausted by the first
    # optimizer built from it, so re-running the optimizer cell would fail
    # with an empty parameter list.
    params = list(decoder.parameters())

In [None]:
def set_device(d):
    """Resolve a device spec to a device string.

    ``"auto"`` becomes ``"cuda"`` when CUDA is available and ``"cpu"``
    otherwise; any other value is returned unchanged.
    """
    if d != "auto":
        return d
    return "cuda" if torch.cuda.is_available() else "cpu"

# Where to run. set_device("auto") would pick CUDA when available; this
# notebook pins the CPU.
args.device = 'cpu'
DEVICE = set_device(args.device)
print(f"Device is {DEVICE}")
encoder = encoder.to(DEVICE)
decoder = decoder.to(DEVICE)

# Optimization hyperparameters and the loss to minimize.
args.lr = 1e-3
args.weight_decay = 0.0
args.criterion = "softiou"
args.epochs = 10
args.batch_size = 20

# Mini-batch loaders: training batches are shuffled, validation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False)

optimizer = torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)
criterion = metrics.load(args.criterion, DEVICE)

In [None]:
def save_model(enc, dec, dump_path, name):
    """Save both models' ``state_dict``s as ``<dump_path>/enc_<name>`` and
    ``<dump_path>/dec_<name>`` (encoder first)."""
    for prefix, module in (("enc_", enc), ("dec_", dec)):
        torch.save(module.state_dict(), os.path.join(dump_path, prefix + name))
    
def train(loader, encoder, decoder, optimizer, criterion):
    """Run one training epoch and return the sample-weighted mean loss.

    When ``args.fine_tune_encoder`` is set, the encoder runs in train mode with
    autograd. Otherwise it runs in eval mode under ``torch.no_grad()``, so no
    gradients flow into it. Batches are moved to ``DEVICE``, and progress is
    printed every 10 batches.
    """
    fine_tune = args.fine_tune_encoder
    encoder.train(bool(fine_tune))  # train(False) is what eval() does
    decoder.train()

    criterion = criterion.to(DEVICE)
    meter = utils.AverageMeter()
    n_batches = len(loader)
    for step, (inp, target) in enumerate(loader):
        verbose = step % 10 == 0
        if verbose:
            print(f"Beginning batch {step} of {n_batches}")
        logging.debug(f"Training batch {step}...")
        inp, target = inp.to(DEVICE), target.to(DEVICE)

        if fine_tune:
            features = encoder(inp)
        else:
            with torch.no_grad():
                features = encoder(inp)
        loss = criterion(decoder(features), target)

        if verbose:
            print(f"\t Train Loss: {loss.item()}")
        optimizer.zero_grad()
        loss.backward()
        meter.update(loss.item(), inp.size(0))
        optimizer.step()

    return meter.avg

@torch.no_grad()
def test(data_loader, encoder, decoder, criterion):
    """Evaluate both models on ``data_loader`` with autograd disabled.

    Returns the sample-weighted mean loss. Progress is printed every 10 batches.
    """
    encoder.eval()
    decoder.eval()
    criterion = criterion.to(DEVICE)
    meter = utils.AverageMeter()
    for step, (inp, target) in enumerate(data_loader):
        report = step % 10 == 0
        if report:
            print(f"Testing batch {step}")
        inp = inp.to(DEVICE)
        target = target.to(DEVICE)

        loss = criterion(decoder(encoder(inp)), target)
        meter.update(loss.item(), inp.size(0))
        if report:
            print(f"\t Test Loss: {loss.item()}")

    return meter.avg

In [None]:
epoch_timer = utils.Timer()
monitor = utils.PerformanceMonitor(args.dump_path)
best_val_loss = float("inf")

for epoch in range(args.epochs):
    print(f"Beginning epoch {epoch}")
    logging.info(f"Beginning epoch {epoch}...")

    loss_train = train(train_loader, encoder, decoder, optimizer, criterion)
    monitor.log(epoch, "train", loss_train)

    loss_val = test(valid_loader, encoder, decoder, criterion)
    monitor.log(epoch, "val", loss_val)
    logging.info(f"Epoch {epoch} took {epoch_timer.minutes_elapsed()} minutes.")
    epoch_timer.reset()

    # Keep the checkpoint with the lowest validation loss seen so far.
    if loss_val < best_val_loss:
        logging.info("Saving model")
        save_model(encoder, decoder, args.dump_path, "best.pt")
        best_val_loss = loss_val

save_model(encoder, decoder, args.dump_path, "final.pt")
logging.info(f"Code completed in {overall_timer.minutes_elapsed()} minutes.")

In [None]:
# Loss curves read back from performance.csv (columns: epoch, stage, loss).
train_progress = pd.read_csv(os.path.join(args.dump_path, "performance.csv"))
fig, ax = plt.subplots()
for stage, label in (("train", "train"), ("val", "validation")):
    rows = train_progress[train_progress["stage"] == stage]
    ax.plot(rows.epoch, rows.loss, label=label)
ax.set_xlabel("epoch")
# Name the criterion actually trained on; a hard-coded "Dice coefficient loss"
# label did not match args.criterion ("softiou").
ax.set_ylabel(f"{args.criterion} loss")
ax.legend()
plt.show()

In [None]:
bestOrFinal = "best"  # which checkpoint to load: "best" or "final"
weights_folder = args.dump_path  # e.g. "results/field_delineation_10ep/"
encoderWeights_path = os.path.join(weights_folder, f"enc_{bestOrFinal}.pt")
decoderWeights_path = os.path.join(weights_folder, f"dec_{bestOrFinal}.pt")

# Rebuild both models from the saved weights.
encoder_trained = encoders._load_swav_pretrained(encoderWeights_path)
decoder_trained = decoders.load(args.decoder, encoder)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
decoder_trained.load_state_dict(torch.load(decoderWeights_path, map_location=DEVICE))
# Move both models to DEVICE, because the inference code sends its inputs there.
encoder_trained = encoder_trained.to(DEVICE).eval()
decoder_trained = decoder_trained.to(DEVICE).eval()

In [None]:
def plot_prediction(X0, y, pred, title):
    """Show the input image, true mask and predicted map side by side.

    ``X0`` is a normalized channel-first image tensor, de-normalized with
    ``invTrans`` for display. ``y`` is the mask tensor. ``pred`` is the model
    output for a batch of one; its leading axis is squeezed away.
    """
    panels = (
        (img_torch2np(invTrans(X0)), "input"),
        (img_torch2np(y), "true label"),
        (img_torch2np(pred.detach().cpu().squeeze(0)), "predicted label"),
    )
    fig, axes = plt.subplots(1, 3)
    for ax, (image, caption) in zip(axes, panels):
        ax.imshow(image)
        ax.set_title(caption)
        ax.axis('off')
    fig.suptitle(title, y=0.8, fontsize=20)
    fig.tight_layout()
    plt.show()
    

# Qualitative check on a few fixed samples from each split. This is inference
# only, so autograd is disabled to avoid building computation graphs.
with torch.no_grad():
    for split_name, dataset in (("Train", train_dataset), ("Validation", valid_dataset)):
        for i in [1, 5, 7]:
            X0, y = dataset[i]
            X = torch.unsqueeze(X0, 0).to(DEVICE)
            pred = decoder_trained(encoder_trained(X))
            pred_prob = torch.sigmoid(pred)  # map predicted values to probabilities
            plot_prediction(X0, y, pred_prob, split_name)

In [None]:
def get_predictions(dataloader, encoder, decoder):
    """Run ``decoder(encoder(img))`` over every batch of ``dataloader``.

    Args:
        dataloader: iterable of ``(img, mask)`` tensor batches.
        encoder, decoder: models, applied as-is (the caller sets eval mode).

    Returns:
        ``(preds, targets)`` numpy arrays. ``preds`` holds sigmoid probabilities
        and ``targets`` the masks. Both are concatenated along the batch axis, so
        a smaller final batch is fine. An empty loader gives two empty arrays.
    """
    preds = []
    targets = []
    with torch.no_grad():
        # Iterate the loader passed in, not a global one.
        for img, mask in dataloader:
            output = decoder(encoder(img.to(DEVICE)))
            preds.append(torch.sigmoid(output).cpu().numpy())
            targets.append(mask.cpu().numpy())

    if not preds:
        return np.empty((0,)), np.empty((0,))
    return np.concatenate(preds, axis=0), np.concatenate(targets, axis=0)

def get_dice_score(preds, targets, smooth=1):
    """Dice coefficient ``(2*sum(p*t) + smooth) / (sum(t) + sum(p) + smooth)``.

    Both inputs are flattened, so any shapes with the same number of elements
    work. Passing probabilities instead of 0/1 predictions gives a "soft" Dice.
    """
    # References:
    # https://github.com/sustainlab-group/ParcelDelineation/blob/master/utils/metrics.py
    # https://discuss.pytorch.org/t/calculating-dice-coefficient/44154
    truth = np.array(targets).ravel()
    guess = np.array(preds).ravel()
    overlap = (truth * guess).sum()
    denominator = truth.sum() + guess.sum() + smooth
    return (2. * overlap + smooth) / denominator


def get_IoU(preds, targets, thresh):
    """Intersection-over-Union of thresholded predictions against targets.

    Args:
        preds: predicted scores or probabilities (any shape).
        targets: ground-truth mask with the same number of elements; any
            nonzero value counts as positive.
        thresh: scores ``>= thresh`` count as positive predictions.

    Returns:
        ``|P & T| / |P | T|`` as a float. Returns NaN when both prediction and
        target are empty, because IoU is undefined there.
    """
    # Compare instead of thresholding in place. Writing 1s and then zeroing
    # everything below `thresh` wrongly cleared all positives when thresh > 1.
    pred_mask = np.asarray(preds).ravel() >= thresh
    target_mask = np.asarray(targets).ravel().astype(bool)
    union = np.logical_or(pred_mask, target_mask).sum()
    if union == 0:
        return float("nan")
    intersection = np.logical_and(pred_mask, target_mask).sum()
    return float(intersection / union)

In [None]:
# Final evaluation on the held-out test split.
test_dataset = fieldDelineationDataset(test_img_filepaths, test_mask_filepaths,
                                       transform=transform_norm)
test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False)

preds, targets = get_predictions(test_loader, encoder_trained, decoder_trained)
dice_score = get_dice_score(preds, targets, smooth=0.00001)
iou = get_IoU(preds, targets, 0.5)
print("Performance on test data: Dice score %.3f and IoU %.3f" % (dice_score, iou))