# Ensemble predictions of DNNs for potentially dominant woody species in Ticino, averaged across days of the year


## Load necessary packages

In [2]:
# Import modules
import ast
import os

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F

# Custom module
import models

## Global definitions

In [3]:
# =============================================================================
# Definitions
# =============================================================================

out_tag = ''

# Pickled model files of the two networks that are ensembled below
# (the names suggest an NDCG-trained and a cross-entropy-trained model —
# TODO confirm).
mod_ndcg = 'Mod_LR_ndcg.pth'
mod_cent = 'Mod_LR_cent.pth'

## Parameter dictionary for prediction
# Main
params = {}
params['batch_size'] = 500   # rows (pixels) of env_file per batch
params['num_workers'] = 2    # DataLoader worker processes
params['device'] = 'cuda:0'

env_file = 'Env_data_Tici.csv'

# Read indices of woody species: one Python literal per line.
# ast.literal_eval accepts the same literals as eval() but cannot execute
# arbitrary code; blank lines are skipped instead of raising a SyntaxError.
# The context manager closes the file (the bare open() never did).
with open('woody_cands2.txt', 'r', encoding='utf-8') as fh:
    woodyf = fh.read().splitlines()
lili = [ast.literal_eval(line.strip()) for line in woodyf if line.strip()]
wootens = torch.tensor(lili)

## Define data set

In [4]:
class FullDataset(torch.utils.data.Dataset):
    """Row-wise dataset over a CSV of environmental features.

    Originally described as the InfoFlora dataset as stored on Sauron.
    Each item is one CSV row (one pixel) returned as a 1-D NumPy array
    containing every column of the file, in file order.
    """

    def __init__(self, test_file):
        """
        Args:
            test_file (str): Path to the CSV file to read.
        """
        self.feat_frame = pd.read_csv(test_file, low_memory=False)

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.feat_frame)

    def __getitem__(self, idx):
        """Return row ``idx`` (positional) as a NumPy array copy."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # DataLoader samplers produce positions 0..len-1, so look rows up by
        # position rather than by index label (they only coincide for the
        # default RangeIndex).
        return np.array(self.feat_frame.iloc[idx])

## Create the Dataset

In [5]:
# One item per CSV row. Batches are served in file order (shuffle=False),
# so the rows of the saved output line up with the rows of env_file.
dase = FullDataset(test_file=env_file)
dataloader = torch.utils.data.DataLoader(
    dase,
    batch_size=params['batch_size'],
    shuffle=False,
    num_workers=params['num_workers'],
)

## Prepare the model

In [None]:
map_location = torch.device(params['device'])

# Load the two networks that are ensembled in the prediction loop.
# torch.load returns the pickled model objects themselves (they are called
# directly and switched to eval mode below); presumably this is why the
# custom `models` module is imported.
# NOTE(review): torch >= 2.6 defaults to weights_only=True, which refuses
# whole-model pickles. If loading fails there, pass weights_only=False
# (only for trusted files).
mod1 = torch.load(mod_ndcg, map_location=map_location)
mod1.eval()

# The second ensemble member must come from the other file. Loading
# mod_ndcg twice would reduce the "ensemble" to one model combined with
# itself.
mod2 = torch.load(mod_cent, map_location=map_location)
mod2.eval()

## Define day-of-year data

In [7]:
# Days of year 32 (Feb 1) .. 334 (Nov 30 in a non-leap year); the end of
# arange is exclusive.
doy = torch.arange(32, 335)

# Cyclic encoding: rescale the day to [-1, 1] over a 365-day year, then take
# the sine and cosine of the angle pi * dnorm.
dnorm = doy / 365 * 2 - 1
angle = np.pi * dnorm
dcos = torch.cos(angle)
dsin = torch.sin(angle)

# One (sin, cos) row per day.
tdoy = torch.column_stack((dsin, dcos))

# Index grids pairing every day (outer index, via indexing='ij') with every
# pixel slot of a full batch. Flattened row r refers to day r // batch_size
# and pixel r % batch_size.
ind1 = torch.arange(len(dcos))
ind2 = torch.arange(params['batch_size'])
grid_x, grid_y = torch.meshgrid(ind1, ind2, indexing='ij')
gd_x = grid_x.reshape(-1)
gd_y = grid_y.reshape(-1)

## Do the predictions

In [None]:
# Per batch: a (pixels x woody species) tensor of ensemble probabilities
# averaged over days of year.
list_out = []

# Number of day-of-year samples per pixel (rows of tdoy).
n_doy = ind1.numel()

# Loop over dataloader and predict
for feats in dataloader:
    n_pix = feats.size(0)

    # Pair every pixel's environmental variables with every day of year
    # (Feb 1 to Nov 30). meshgrid(..., indexing='ij') makes the day the outer
    # index, so row r of `stk` is day r // n_pix and pixel r % n_pix.
    # The precomputed gd_x/gd_y only fit full batches; a shorter (last)
    # batch gets its own grid. The size is kept in a local instead of being
    # written back into the global ind2, so re-running this cell does not
    # truncate later full batches.
    if n_pix == params['batch_size']:
        gdi_x, gdi_y = gd_x, gd_y
    else:
        grid_x, grid_y = torch.meshgrid(ind1, torch.arange(n_pix), indexing='ij')
        gdi_x, gdi_y = grid_x.reshape(-1), grid_y.reshape(-1)
    # NOTE(review): feats comes from NumPy rows (likely float64) while tdoy
    # is float32, so stk is promoted; confirm this matches the model dtype.
    stk = torch.column_stack((feats[gdi_y, :], tdoy[gdi_x, :]))
    stk = stk.to(params['device'])

    # Inference only. no_grad prevents building an autograd graph for the
    # n_doy * n_pix rows; calling .detach() afterwards would not avoid it.
    with torch.no_grad():
        prd1 = F.softmax(mod1(stk), dim=1)
        prd2 = F.softmax(mod2(stk), dim=1)
    del stk

    # Ensemble both models by the geometric mean of their class
    # probabilities (not renormalised, so rows need not sum to 1).
    prdens = torch.sqrt(prd1 * prd2)
    del prd1, prd2

    # Average across days of year for each pixel, keeping only the woody
    # species columns. The day-major row order lets this be a reshape + mean
    # instead of one boolean-mask pass over all rows per pixel.
    md = prdens[:, wootens].reshape(n_doy, n_pix, -1).mean(dim=0)
    del prdens
    tosta = torch.round(md, decimals=4)

    # Store on the CPU so GPU memory does not grow with every batch.
    list_out.append(tosta.cpu())
    print("sali")

## Concatenate batch results and save

In [None]:
# Stack the per-batch results into one table: one row per pixel (in
# env_file order), one column per woody species, values averaged over days.
tab_out = torch.cat(list_out, dim=0)
tab2 = tab_out.detach().to('cpu').numpy()

out_path = "EcoImg_spatpred_woody_Ticino.npy"
np.save(out_path, tab2)
