In [62]:
import sys
import numpy as np
import os
import cv2
import logging
from image_utils import standardize
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as Fu
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
from unet_val import UNet
import torch.backends.cudnn as cudnn
from dataset_generator_2 import DataLoader,Dataset_sat
from IOU_computations import *
from random import randint

In [63]:
##########
# Experiment configuration: output locations, data layout and hyper-parameters.
# NOTE(review): the trailing "# alt:" numbers reproduce values the author left
# next to each setting; they look like alternative / full-run settings - confirm.
##########


def _ensure_directory(path):
    """Create `path` (including missing parents) unless it already exists."""
    if not os.path.exists(path):
        os.makedirs(path)


# --- Root folder of this experiment; every artefact is written below it.
GLOBAL_PATH = 'MODEL_BASIC_TEST_120/'
_ensure_directory(GLOBAL_PATH)

# --- Dataset sub-folder names.
# Not referenced by the visible cells (Trainer.train builds its own local
# TRAINING/ and VALIDATION/ paths); kept because other cells may rely on them.
PATH_TRAINING = 'TRAINING/'
PATH_VALIDATION = 'VALIDATION/'
PATH_TEST = 'TEST/'

PATH_INPUT = 'INPUT/'
PATH_OUTPUT = 'OUTPUT/'

# --- Data / network shape.
INPUT_CHANNELS = 9
OUTPUT_CHANNELS = 2
NB_CLASSES = 2

SIZE_PATCH = 120

# --- Checkpoint prefix, optional checkpoint to restore ('' = train from
# scratch, see Trainer.train) and folder receiving metrics/predictions.
MODEL_PATH_SAVE = GLOBAL_PATH + 'RESUNET_pytorch_BASIC_test'
MODEL_PATH_RESTORE = ''
TEST_SAVE = GLOBAL_PATH + 'TEST_SAVE/'
_ensure_directory(TEST_SAVE)

logging.basicConfig(level=logging.INFO, format='%(asctime)s %(message)s')

# --- Training schedule.
REC_SAVE = 200            # alt: 2000
DROPOUT = 0.9             # alt: 0.9
DEFAULT_BATCH_SIZE = 8    # alt: 10
DEFAULT_EPOCHS = 1        # alt: 50
DEFAULT_ITERATIONS = 618  # alt: 495
DEFAULT_VALID = 100       # alt: 100
DISPLAY_STEP = 100        # alt: 50

# --- Network / optimiser defaults.
DEFAULT_LAYERS = 3
DEFAULT_FEATURES_ROOT = 32
DEFAULT_FILTERS_SIZE = 3
DEFAULT_LR = 0.0001

# --- Scratch folder handed to the IoU scoring helper.
TMP_IOU = TEST_SAVE + 'TMP_IOU/'
_ensure_directory(TMP_IOU)

In [64]:
 class Trainer(object):
    """
    Trains a unet instance

    :param net: the unet instance to train
    :param batch_size: size of training batch
    :param lr: learning rate
    :param nb_classes: number of classes, forwarded to the dataset factory
    """
    def __init__(self, net, batch_size=10, lr=0.0001,nb_classes=2):
        self.net = net
        self.batch_size = batch_size
        self.lr = lr
        self.nb_classes=nb_classes

    @staticmethod
    def _scalar(value):
        """
        Return a 0-dim / 1-element loss as a Python number.

        ``loss.data[0]`` raises on 0-dim tensors in recent PyTorch releases,
        while very old releases have no ``.item()``; support both.
        """
        if hasattr(value, 'item'):
            return value.item()
        return value.data[0]

    def _initialize(self, prediction_path):
        """Create the Adam optimizer and remember where predictions are written."""
        self.optimizer = optim.Adam(self.net.parameters(),lr=self.lr)
        self.prediction_path = prediction_path


    def train(self, data_provider_path, save_path='', restore_path='', training_iters=4, epochs=3, dropout=0.9, display_step=1, validation_batch_size=30,rec_save=1, prediction_path = '',data_aug=None):
        """
        Launches the training process

        :param data_provider_path: where the DATASET folder is (must contain TRAINING/ and VALIDATION/)
        :param save_path: path prefix where to store checkpoints
        :param restore_path: path of a saved state dict to restore ('' trains from scratch)
        :param training_iters: only used to size the training-loss array allocated by
            save_metrics; each epoch iterates over the whole training loader
        :param epochs: number of epochs
        :param dropout: accepted for interface compatibility; not used by this method
        :param display_step: number of optimisation steps between minibatch stats
        :param validation_batch_size: batch size of the validation loader
        :param rec_save: number of optimisation steps between checkpoints
        :param prediction_path: path where to save predictions on each epoch
        :param data_aug: accepted for interface compatibility; not used by this method
        :returns: save_path, or an explanatory message when save_path is empty
        """
        # Validate the request before touching the filesystem so that an early
        # return never leaves metric files open.
        if epochs == 0:
            return save_path
        if save_path=='':
            return 'Specify a path where to store the Model'

        PATH_TRAINING=data_provider_path+'TRAINING/'
        PATH_VALIDATION=data_provider_path+'VALIDATION/'

        self._initialize(prediction_path)

        if restore_path=='':
            print('Model trained from scratch')
        else:
            # NOTE(review): torch.load unpickles the file; only restore trusted checkpoints.
            self.net.load_state_dict(torch.load(restore_path))
            print('Model loaded from {}'.format(restore_path))

        loss_train,file_train,loss_verif,file_verif,IOU_verif,IOU_file_verif,IOU_acc_verif,IOU_acc_file_verif,f1_IOU_verif,f1_IOU_file_verif=save_metrics(epochs,training_iters,TEST_SAVE,'a')
        metric_files=(file_train,file_verif,IOU_file_verif,IOU_acc_file_verif,f1_IOU_file_verif)

        try:
            val_generator=Dataset_sat.from_root_folder(PATH_VALIDATION,self.nb_classes)
            val_loader = DataLoader(val_generator, batch_size=validation_batch_size,shuffle=False, num_workers=4)
            # len(val_loader) is already a number of batches; pick one valid
            # batch index (randint is inclusive on both ends).
            RBD=randint(0,max(len(val_loader)-1,0))
            self.store_init(val_loader,"_init",validation_batch_size,RBD)

            train_generator=Dataset_sat.from_root_folder(PATH_TRAINING,self.nb_classes)

            logging.info("Start optimization")

            counter=0

            for epoch in range(epochs):
                total_loss = 0
                train_loader = DataLoader(train_generator, batch_size=self.batch_size,shuffle=True, num_workers=4)
                for i_batch,sample_batch in enumerate(train_loader):
                    batch_x=standardize(sample_batch['input'])
                    batch_y=sample_batch['groundtruth']
                    _,loss=predict(self.net,batch_x,batch_y)
                    batch_loss=self._scalar(loss)
                    total_loss+=batch_loss
                    # One line per optimisation step in loss_train.txt.
                    file_train.write(str(batch_loss)+'\n')

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    counter+=1
                    if counter % display_step == 0:
                        self.output_minibatch_stats(counter,batch_x,batch_y)
                    if counter % rec_save == 0:
                        torch.save(self.net.state_dict(),save_path + 'CP{}.pth'.format(counter))
                        print('Checkpoint {} saved !'.format(counter))

                logging.info("Epoch {:}, Average training loss= {:.4f}".format(epoch,total_loss/max(len(train_loader),1)))

                # save_patches is keyword-only in store_validation.
                error_rate_v,loss_v,iou_v,iou_acc_v,f1_v=self.store_validation(val_loader, "epoch_%s"%epoch,validation_batch_size,RBD,save_patches=True)
                IOU_verif[epoch]=iou_v
                IOU_acc_verif[epoch]=iou_acc_v
                f1_IOU_verif[epoch]=f1_v
                loss_verif[epoch]=loss_v

                IOU_file_verif.write(str(IOU_verif[epoch])+'\n')
                IOU_acc_file_verif.write(str(IOU_acc_verif[epoch])+'\n')
                f1_IOU_file_verif.write(str(f1_IOU_verif[epoch])+'\n')
                file_verif.write(str(loss_verif[epoch])+'\n')
                # Make the metrics of finished epochs visible even if a later epoch crashes.
                for metric_file in metric_files:
                    metric_file.flush()

                print(" Loss {:.4f}, Error rate {:.4f} ,Validation IoU {:.4f}, Validation IoU_acc {:.4f}%,Validation F1 IoU {:.4f}%".format(loss_v,error_rate_v,iou_v,iou_acc_v,f1_v))
        finally:
            for metric_file in metric_files:
                metric_file.close()

        return save_path

    def output_minibatch_stats(self, step, batch_x, batch_y):
        """Log loss and pixel error rate of the current network on one minibatch."""
        # Calculate batch loss and accuracy (module-level predict, not a method)
        predictions,loss=predict(self.net,batch_x,batch_y)
        loss=self._scalar(loss)
        predictions=predictions.data.cpu().numpy()
        groundtruth=np.asarray(batch_y)
        logging.info("Iter {:}, Minibatch Loss= {:.4f}, Minibatch error= {:.1f}%".format(step,loss,error_rate(predictions, groundtruth)))

    def store_train(self,train_loader,name,training_batch_size):
        """Placeholder: only prints a marker."""
        print('nada')


    def store_init(self,val_loader,name,validation_batch_size,random_batch_display,*,save_patches=True):
        """
        Evaluate the (typically untrained) network on the validation loader and log
        mean loss / error rate; optionally plot and save batch `random_batch_display`.

        NOTE(review): the network is not switched to eval mode here - confirm
        whether dropout should be disabled during evaluation.
        """
        loss_v=0
        error_rate_v=0
        for i_batch,sample_batch in enumerate(val_loader):
            batch_x=standardize(sample_batch['input'])
            probs,loss=predict(self.net,batch_x,sample_batch['groundtruth'])
            loss_v+=self._scalar(loss)
            prediction_v=probs.data.cpu().numpy()
            groundtruth=np.asarray(sample_batch['groundtruth'])
            error_rate_v+=error_rate(prediction_v,groundtruth)
            if i_batch==random_batch_display and save_patches:
                batch_x=np.asarray(batch_x)
                # Channels 5, 3, 2 of the input are stacked as the display image.
                pansharp=np.stack((batch_x[:,:,:,5],batch_x[:,:,:,3],batch_x[:,:,:,2]),axis=3)
                plot_summary(prediction_v,groundtruth,pansharp,name,self.prediction_path,save_patches)
        # Avoid dividing by zero on an empty loader.
        nb_batches=max(len(val_loader),1)
        loss_v/=nb_batches
        error_rate_v/=nb_batches
        logging.info("Verification  loss= {:.4f},error= {:.1f}%".format(loss_v,error_rate_v))


    def store_validation(self,val_loader, name,validation_batch_size,random_batch_display,*,save_patches=True):
        """
        Evaluate the network on the validation loader.

        :returns: (error_rate, loss, iou, iou_acc, f1), each averaged over the batches

        NOTE(review): the network is not switched to eval mode here - confirm
        whether dropout should be disabled during evaluation.
        """
        loss_v=0
        iou_v=0
        iou_acc_v=0
        f1_v=0
        error_rate_v=0

        for i_batch,sample_batch in enumerate(val_loader):
            # Feed the network the same standardized inputs it is trained on.
            batch_x=standardize(sample_batch['input'])
            probs,loss=predict(self.net,batch_x,sample_batch['groundtruth'])
            loss_v+=self._scalar(loss)

            prediction_v=probs.data.cpu().numpy()
            groundtruth=np.asarray(sample_batch['groundtruth'])
            iou_acc,f1,iou=predict_score_batch(TMP_IOU,np.argmax(groundtruth,3),np.argmax(prediction_v,3))
            iou_acc_v+=iou_acc
            iou_v+=iou
            f1_v+=f1
            error_rate_v+=error_rate(prediction_v,groundtruth)
            if i_batch==random_batch_display and save_patches:
                # The display image is built from the raw (non-standardized) input.
                raw_x=np.asarray(sample_batch['input'])
                pansharp=np.stack((raw_x[:,:,:,5],raw_x[:,:,:,3],raw_x[:,:,:,2]),axis=3)
                plot_summary(prediction_v,groundtruth,pansharp,name,self.prediction_path,save_patches)

        # Avoid dividing by zero on an empty loader.
        nb_batches=max(len(val_loader),1)
        loss_v/=nb_batches
        iou_v/=nb_batches
        iou_acc_v/=nb_batches
        f1_v/=nb_batches
        error_rate_v/=nb_batches

        logging.info("Verification  loss= {:.4f},error= {:.1f}%, IOU = {:.4f}, IOU Precision = {:.4f}%, F1 IOU= {:.4f}%".format(loss_v,error_rate_v,iou_v,iou_acc_v,f1_v))

        return error_rate_v,loss_v,iou_v,iou_acc_v,f1_v

In [65]:
def predict(net,batch_x,batch_y):
    """
    Run `net` on a channels-last batch and compute the training loss.

    :param net: network called on a channels-first float tensor
    :param batch_x: input batch; axis 3 is moved to axis 1 before the forward pass
    :param batch_y: target batch with the same axis layout as the prediction
        (presumably one-hot, as callers take an argmax over axis 3 - confirm)
    :returns: (probs, loss) where probs is the softmax over the class axis,
        permuted back to channels-last, and loss is the binary cross-entropy
        computed from the raw network outputs
    """
    X=batch_x.permute(0,3,1,2)
    X = Variable(X).type(torch.FloatTensor)
    Y=batch_y.permute(0,3,1,2)
    Y = Variable(Y).type(torch.FloatTensor)

    # Follow the network's device instead of unconditionally requiring CUDA.
    first_param = next(iter(net.parameters()), None)
    if first_param is not None and first_param.is_cuda:
        X = X.cuda()
        Y = Y.cuda()

    y_pred=net(X)
    probs = Fu.softmax(y_pred,dim=1)
    # binary_cross_entropy_with_logits applies a sigmoid itself, so it must
    # receive the raw logits; feeding it softmax output squashes the signal twice.
    loss=Fu.binary_cross_entropy_with_logits(y_pred,Y)
    probs=probs.permute(0,2,3,1)
    return probs,loss

def save_metrics(epochs,training_iters,prediction_path,mode):
    """
    Allocate the metric buffers and open the matching text files.

    :param epochs: length of every per-epoch buffer
    :param training_iters: the training-loss buffer has training_iters*epochs entries
    :param prediction_path: prefix (concatenated as-is) of the files to open
    :param mode: mode string handed to open()
    :returns: 10-tuple alternating zero-filled array / open file, in the order
        training loss, validation loss, IoU, IoU accuracy, F1 IoU.
        The caller is responsible for closing the files.
    """
    # Training loss comes first and is the only per-iteration buffer.
    collected = [
        np.zeros(training_iters*epochs),
        open(prediction_path+'loss_train.txt',mode),
    ]
    # The remaining metrics hold one value per epoch.
    per_epoch_files = (
        'loss_verif.txt',
        'iou_verif.txt',
        'iou_acc_verif.txt',
        'f1_iou_verif.txt',
    )
    for filename in per_epoch_files:
        collected.append(np.zeros(epochs))
        collected.append(open(prediction_path+filename,mode))

    return tuple(collected)
def error_rate(predictions, labels):
    """
    Percentage of positions whose predicted class differs from the label.

    Both arrays are reduced with an argmax over axis 3; the count of matching
    positions is divided by the product of the first three dimensions of
    `predictions`.
    """
    predicted_class = np.argmax(predictions, 3)
    expected_class = np.argmax(labels, 3)
    nb_correct = np.sum(predicted_class == expected_class)

    dim0, dim1, dim2 = predictions.shape[:3]
    nb_positions = dim0*dim1*dim2

    accuracy_percent = 100.0 * nb_correct / nb_positions
    return 100.0 - accuracy_percent
def plot_summary(predictions,labels,pansharp,epoch,prediction_path,save_patches):
    """
    Plot input image, ground truth and prediction for each patch of a batch
    (one column per patch) and optionally save every panel as a JPEG.

    :param predictions: per-class scores; reduced with an argmax over axis 3
    :param labels: targets; reduced with an argmax over axis 3
    :param pansharp: sequence of display images, one per patch
    :param epoch: string label used in the titles and in the file names
    :param prediction_path: prefix (concatenated as-is) of the saved files
    :param save_patches: when true, save the three panels of each patch
    """
    nb_patches=len(pansharp)
    # squeeze=False keeps `axs` two-dimensional even for a single patch, so the
    # [row, column] indexing below never fails.
    fig,axs=plt.subplots(3, nb_patches,figsize=(8*nb_patches,24),squeeze=False)

    axs[0,0].set_title(epoch+' Pansharpened ', fontsize='large')
    axs[1,0].set_title(epoch+' Groundtruth ', fontsize='large')
    axs[2,0].set_title(epoch+' Predictions ', fontsize='large')

    labels=np.argmax(labels, 3)
    logits=np.argmax(predictions, 3)
    for i in range(nb_patches):

        axs[0,i].imshow(pansharp[i])
        axs[1,i].imshow(labels[i])
        axs[2,i].imshow(logits[i])


        if save_patches:
            plt.imsave(prediction_path+epoch+'_Panchro_'+str(i)+'.jpg',pansharp[i])
            plt.imsave(prediction_path+epoch+'_Groundtruth_'+str(i)+'.jpg',labels[i])
            # NOTE(review): predictions are saved inverted (1 - class index) while
            # the ground truth is not - confirm this is intentional.
            plt.imsave(prediction_path+epoch+'_Predictions_'+str(i)+'.jpg',1-logits[i])

    plt.subplots_adjust()
    plt.show()
    # This function runs once per epoch; release the figure so they do not pile up.
    plt.close(fig)

In [None]:
if __name__ == '__main__':
    # Dataset root; Trainer.train appends TRAINING/ and VALIDATION/ to it.
    # Alternative location kept from the original notebook:
    #     root_folder ='/scratch/SPACENET_DATA_PROCESSED/DATASET/120_x_120_8_bands_pansh/'
    root_folder = '../DATA_GHANA/DATASET/120_x_120_8_bands/'

    # Build the network on the GPU and let cuDNN benchmark its algorithms.
    model=UNet(INPUT_CHANNELS,NB_CLASSES,DEFAULT_LAYERS,DEFAULT_FEATURES_ROOT,DROPOUT)
    model.cuda()
    cudnn.benchmark = True

    trainer=Trainer(
        net=model,
        batch_size=DEFAULT_BATCH_SIZE,
        lr=DEFAULT_LR,
        nb_classes=NB_CLASSES,
    )
    trainer.train(
        data_provider_path=root_folder,
        save_path=MODEL_PATH_SAVE,
        restore_path=MODEL_PATH_RESTORE,
        training_iters=DEFAULT_ITERATIONS,
        epochs=DEFAULT_EPOCHS,
        dropout=DROPOUT,
        display_step=DISPLAY_STEP,
        validation_batch_size=DEFAULT_VALID,
        rec_save=REC_SAVE,
        prediction_path=TEST_SAVE,
    )
    

Model trained from scratch
