In order to train the LSTM branch, first we need to import libraries.

In [1]:
import os
import h5py
import datetime
from tensorflow.keras.applications import VGG16
from tensorflow.keras import Input
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Conv2D
from tensorflow.keras.layers import Flatten
from tensorflow.keras.callbacks import ModelCheckpoint

Now, we define some useful functions.

In [2]:
def smooth_curve(points,factor=0.9):
    smoothed_points=[]
    for point in points:
        if smoothed_points:
            previous=smoothed_points[-1]
            smoothed_points.append(previous*factor+point*(1-factor))
        else:
            smoothed_points.append(point)
    return smoothed_points

def plot_loss(history, directory, bias=0, cols=2):
    # summarize history for accuracy
    plt.figure(figsize=(20,7))
    keys=list(history.keys())
    rows=math.ceil(len(history.keys())/4)
    counter=math.ceil(len(history.keys())/2)

    for i in range(counter):
        plt.subplot(rows,cols,i+1)
        plt.plot(smooth_curve(history[keys[i]][bias:]))
        plt.plot(smooth_curve(history[keys[i+counter]][bias:]))
        plt.ylabel('metric')
        plt.xlabel('epoch')
        plt.legend([keys[i],keys[i+counter]], loc='upper right')
        
    plt.show()
    file=os.path.join(directory,'training_metrics.png')
    plt.savefig(file)
    
def preprocess(data_cube, pixels_per_sample=1, average='True', dim=3):
    #Data preprocessing
    idx=[]
    data_prepro=[]
    data_ready=np.reshape(data_cube,(-1,150))
    for i in range(len(data_ready)):
        if data_ready[i].all()==0: idx.append(i)
    #delete black pixels
    data_ready=np.delete(data_ready,idx,axis=0)
    #reshape data
    if  average=='Full':#average over ALL pixels
        data_prepro=np.mean(data_ready,axis=0,keepdims=True)
        data_prepro=np.nan_to_num(data_prepro)
        data_prepro=np.reshape(data_prepro,(1,150,1))
    else:
        for j in range(data_ready.shape[0]//pixels_per_sample):
            if average=='False' or pixels_per_sample==1:
                data_prepro.append(data_ready[j:j+pixels_per_sample])            
            else:
                #Calculates average of the pixels if needed
                data_aux=np.mean(data_ready[j:j+pixels_per_sample],axis=(0))
                data_aux=np.reshape(data_aux,(1,-1))
                data_prepro.append(data_aux)
                
        data_prepro=np.asarray(data_prepro)
    
        if dim==3: 
            data_prepro=np.swapaxes(data_prepro,1,2)
        elif dim==2:
            data_prepro=np.reshape(data_prepro,(-1,150))
        else:
            print('Wrong data dimensions selected.')
    
    return data_prepro

    
#Data generator
def hs_generator(h5f, gt='moisture', mode='train',
                 pixels_per_sample=1, average='True', dim=3, full_sample=False,
                 scaled=False, box_size=10):
    
    while True:
        
        for key in h5f.keys():
            data_temp=h5f[key]
            print('\nLoading group: ', key)
            
            for key in data_temp.keys():
                rand_split=np.random.randint(100)
                        
                if mode=='train':
                    if key[0]!='4' and rand_split<=80: #80% for training
                        #print('\nTraining set. Loading HS image from plot: ', key)
                        hs_temp=data_temp[key]
                        #Obtains the target value
                        for key in hs_temp.attrs.keys():
                            if key==gt:
                                print('{} => {}'.format(key, hs_temp.attrs[key]))
                                target=hs_temp.attrs[key]
                                
                         #Gets the data in the hdf5 file
                        sample_hs=hs_temp[()]
                        if full_sample==True:
                            sample=hs_temp[()]
                            if scaled==True:
                                small_sample=[]
                                for i in range(int(np.floor(sample.shape[0]/box_size))):
                                    for j in range(int(np.floor(sample.shape[1]/box_size))):
                                        min_box=sample[i*box_size:(i+1)*box_size,j*box_size:(j+1)*box_size,:]
                                        small_sample.append(min_box)
                                sample=np.asarray(small_sample)
                        else:
                            sample=preprocess(sample_hs,pixels_per_sample,average,dim)  
                            
                        if gt=='treatment':
                            if target[:4]=='L233':
                                target=0
                            else:
                                target=1
                        
                        y=np.asarray([target for i in range(sample.shape[0])])
                        
                        yield sample,y
                                                
                elif mode=='val':
                    if key[0]!='4' and rand_split>80: #20% for validation
                        #print('\nValidation set. Loading HS image from plot: ', key)
                        hs_temp=data_temp[key]
                        #Obtains the target value
                        for key in hs_temp.attrs.keys():
                            if key==gt:
                                print('{} => {}'.format(key, hs_temp.attrs[key]))
                                target=hs_temp.attrs[key]
                                
                         #Gets the data in the hdf5 file
                        sample_hs=hs_temp[()]
                        if full_sample==True:
                            sample=hs_temp[()]
                            if scaled==True:
                                small_sample=[]
                                for i in range(int(np.floor(sample.shape[0]/box_size))):
                                    for j in range(int(np.floor(sample.shape[1]/box_size))):
                                        min_box=sample[i*box_size:(i+1)*box_size,j*box_size:(j+1)*box_size,:]
                                        small_sample.append(min_box)
                                sample=np.asarray(small_sample)
                        else:
                            sample=preprocess(sample_hs,pixels_per_sample,average,dim)  
                        
                        if gt=='treatment':
                            if target[:4]=='L233':
                                target=0
                            else:
                                target=1
                                
                        y=np.asarray([target for i in range(sample.shape[0])])
                        
                        yield sample,y
                                            
                else:
                    ValueError    

Now, we define some variables for this model.

In [3]:
gt='moisture'
num_epochs=2000
box_size=20
optimizer='Adadelta'
out_layer='block3_pool'

Then we define some directories and load the data.

In [5]:
main_dir='/your/directory/'
data_dir=os.path.join(main_dir,'data_hs')

saving_dir=os.path.join(main_dir,'results_github','canola_hsi')
string_dir='mkdir -p '+saving_dir
os.system(string_dir)

#%%
#Load data and labels
file_name_h5=os.path.join(data_dir,'NUE_canola_hsi_dataset.hdf5')
h5f=h5py.File(file_name_h5,'r')

print('\nOpening dataset...')

print('\nGeneral information about the dataset:')
for i in h5f.attrs.keys():
      print('{} => {}'.format(i, h5f.attrs[i]))

print('\nGroups contained in dataset:')
for key in h5f.keys():
    print(key)


Opening dataset...

General information about the dataset:
creator => Julio Torres-Tello
institution => University of Saskatchewan
year => 2020
crop => canola
bands => 400-1000nm
scanner => Corning microHSI SHARK

Groups contained in dataset:
03092019_DAS95
09092019_DAS101
16082019_DAS77
16092019_DAS108
20082019_DAS81
25092019_DAS117
27082019_DAS88
30082019_DAS91


In [6]:
#Data generators
train_gen_plot=hs_generator(h5f, gt, mode='train', full_sample=True, scaled=True, box_size=box_size)
val_gen_plot=hs_generator(h5f, gt, mode='val', full_sample=True, scaled=True, box_size=box_size)

Now, we implement the spatial (VGG16-pretrained based) model.

In [7]:
#VGG module
vgg_block=VGG16(weights='imagenet', include_top=False)
deep_model=Model(inputs=vgg_block.input,outputs=vgg_block.get_layer(out_layer).output)

# Create model

input_spatial=Input(shape=(box_size, box_size, 150),name='input_spatial')
spatial=Conv2D(3,kernel_size=(1,1),activation='relu',
               data_format='channels_last')(input_spatial)
spatial=deep_model(spatial)

fc=Flatten()(spatial)
fc=Dense(120,activation='relu')(fc)
fc=Dense(84,activation='relu')(fc)
output=Dense(1,activation='linear')(fc)

model=Model(input_spatial,output)

deep_model.trainable=False
model.compile(optimizer=optimizer,loss='mape',
              metrics=['mae','mse'])

print("\nModel created successfully!\n")
model.summary()


Model created successfully!

Model: "model_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_spatial (InputLayer)   [(None, 20, 20, 150)]     0         
_________________________________________________________________
conv2d (Conv2D)              (None, 20, 20, 3)         453       
_________________________________________________________________
model (Model)                multiple                  1735488   
_________________________________________________________________
flatten (Flatten)            (None, 1024)              0         
_________________________________________________________________
dense (Dense)                (None, 120)               123000    
_________________________________________________________________
dense_1 (Dense)              (None, 84)                10164     
_________________________________________________________________
dense_2 (Dense)              

Now, we train the model. In this example, we used only 5 epochs; however, the correct value was already mentioned.

In [8]:
callbacks_list=[ModelCheckpoint(os.path.join(saving_dir,'model_lstm.h5'),save_best_only=True)]

#Training the model
init_time=datetime.datetime.now()
print('Model starts training at:',init_time)   
reg=model.fit(train_gen,epochs=5,verbose=False,shuffle=True,
                        steps_per_epoch=308,validation_data=val_gen,
                        validation_steps=77,callbacks=callbacks_list)
end_time=datetime.datetime.now()
print('Model finishes training at:',end_time)

total_time=end_time-init_time
print('Total duration of model training:',total_time)

print("\nTraining finished correctly!\n")

#closing data file
h5f.close()

Model starts training at: 2021-06-24 20:46:10.717611

Model trained.

Model finishes training at: 2021-06-24 20:46:10.717820
Total duration of model training: 0:00:00.000209

Training finished correctly!

