# Following the exercises from https://pawarlab.slack.com/files/U06GE8MMN49/F07866US999/proposal_pipeline_randomforest_mgrainger_mscthesis.odt

## Importing data and packages

In [None]:
# Data
asv_table <- read.csv("../data/seqtable_readyforanalysis.csv", sep = "\t")
meta_data <- read.csv("../data/metadata_Time0D-7D-4M_May2022_wJSDpart_ext.csv", sep = "\t")
cluster_data <- read.csv("../data/max_tot_ext_network_table.tsv", sep = "\t")
function_data <- read.csv("../data/20151016_Functions_remainder.csv")

In [None]:
function_data

In [None]:
meta_data

In [None]:
asv_table

## Correcting data
Steps:
1. Convert the function data to consistent units, normalise by cell count, and take the logarithm.
2. Add a column to the ASV table with the logarithm of each sample's total abundance, so that the number of reads can be controlled for later.
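
The first step can be illustrated on a toy table. This is a minimal sketch, not the notebook's actual data: the values are made up, and only the column names (`ATP7`, `CPM7`) and the pseudocount follow the notebook. It shows why a small constant is added before taking the log.

```r
# Toy function table; column names mirror the notebook, values are made up.
toy <- data.frame(
  ATP7 = c(2000, 0, 500),   # nanomolar
  CPM7 = c(100, 50, 250)    # cell counts
)

pseudo <- 1e-5  # same pseudocount as used in this notebook

# Convert units, normalise by cell count, then log-transform.
toy$ATP7 <- toy$ATP7 / 1000                      # nanomolar -> micromolar
toy$ATP7.norm <- toy$ATP7 / toy$CPM7
toy$log.ATP7.norm <- log(toy$ATP7.norm + pseudo)

# Without the pseudocount the zero measurement maps to -Inf.
raw.ok    <- is.finite(log(toy$ATP7.norm))   # second entry is FALSE
pseudo.ok <- is.finite(toy$log.ATP7.norm)    # all TRUE
```

The pseudocount sets a floor of `log(1e-5)` for zero measurements, so its size relative to the smallest non-zero value is worth checking before interpreting the results.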

In [None]:
# Changing units
function_data$mgCO2.7 <- function_data$mgCO2.7 * 1000 # Converting CO2 from milligrams to micrograms
names(function_data)[names(function_data) == "mgCO2.7"] <- "μgCO2.7" # Changing column name accordingly
function_data$ATP7 <- function_data$ATP7 / 1000 # Changing nanomolar to micromolar
function_data$ATP14 <- function_data$ATP14 / 1000

# Normalising by number of cells
function_data$ATP7.norm <- function_data$ATP7 / function_data$CPM7
function_data$mG7.norm <- function_data$mG7 / function_data$CPM7
function_data$mN7.norm <- function_data$mN7 / function_data$CPM7
function_data$mX7.norm <- function_data$mX7 / function_data$CPM7
function_data$mP7.norm <- function_data$mP7 / function_data$CPM7
function_data$μgCO2.7.norm <- function_data$μgCO2.7 / function_data$CPM7

# Taking log of normalised functions (a small pseudocount avoids log(0))
function_data$log.ATP7.norm <- log(function_data$ATP7.norm + 0.00001)
function_data$log.mG7.norm <- log(function_data$mG7.norm + 0.00001)
function_data$log.mN7.norm <- log(function_data$mN7.norm + 0.00001)
function_data$log.mX7.norm <- log(function_data$mX7.norm + 0.00001)
function_data$log.mP7.norm <- log(function_data$mP7.norm + 0.00001)
function_data$log.μgCO2.7.norm <- log(function_data$μgCO2.7.norm + 0.00001)

# Adding column for log of total abundance of each sample
asv_table$log.reads <- log(rowSums(asv_table))

# Writing modified data to files
write.csv(function_data, "../data/20151016_Functions_remainder_corrected.csv", row.names = FALSE)



In [None]:
asv_table

## Exercise 1
The random forest code is copied from random_forest_cfl.R and then modified to match the exercises.
Exercise 1 includes the following 8 random forests (4 with ASVs as predictors and 4 with clusters):
1. ASVs and reads as predictors, log of normalised ATP as response. Training data is one replicate; testing data is the other 3.
2. ASVs and reads as predictors, log of normalised xylosidase as response. Training data is one replicate; testing data is the other 3.
3. ASVs and reads as predictors, log of normalised ATP as response. Training and testing data are both drawn from all replicates.
4. ASVs and reads as predictors, log of normalised xylosidase as response. Training and testing data are both drawn from all replicates.
5. Forests 5 to 8: the same as forests 1 to 4, but with clusters instead of ASVs as predictors.
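
The two training/testing schemes can be sketched in base R as follows. This is only an illustration under stated assumptions: the sample ids are made up, the replicate is assumed to be the text after the last dot (as in the `Community.Replicate` ids built later in this notebook), and the 75%/25% random split is one possible reading of "across all replicates", not something the exercise specifies.

```r
# Hypothetical sample ids in "<Community>.<Replicate>" form; values are made up.
sampleid <- c("A.1", "A.2", "A.3", "A.4", "B.1", "B.2", "B.3", "B.4")

# The replicate is taken to be the text after the last dot.
replicate <- sub("^.*\\.", "", sampleid)

# Scheme 1: train on one replicate, test on the remaining three.
train.rep <- "1"
train.id <- sampleid[replicate == train.rep]
test.id  <- sampleid[replicate != train.rep]

# Scheme 2: pool all replicates, then split at random (assumed 75% / 25%).
set.seed(1)
pooled.train <- sample(sampleid, size = round(0.75 * length(sampleid)))
pooled.test  <- setdiff(sampleid, pooled.train)
```

The resulting id vectors can then be used to subset the rows of the predictor table and the response vector before fitting.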

Exercise 1A Steps:
1. Making predictor data frame with the ASV abundances of the final experiment (all 4 replicates) and the logarithm of the total abundance of each sample.
2. Making univariate response data frames. One has the logarithm of normalised ATP7, and one has the logarithm of normalised xylosidase.
3. Conduct 2 univariate random forest regressions, one for each of these response data frames.
4. Identify the most important ASVs contributing to the function as those with high mse_reduction and Gini index. See this tutorial (multi-way importance): https://modeloriented.github.io/randomForestExplainer/articles/randomForestExplainer.html.
5. Check where in the network these ASVs are.
6. Investigate interactions between a set of important ASVs and compare them with your network (see “variable interactions” in the above tutorial). Do the RF interactions correspond to the FlashWeave associations?
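
For the last step, one simple way to compare the two sets of pairs is to treat each pair as unordered and count the overlap. The sketch below assumes both results have already been reduced to two-column tables of ASV names; the tables, their column names (`a`, `b`), and the helper `pair_key` are hypothetical and do not come from randomForestExplainer or FlashWeave.

```r
# Hypothetical inputs: ASV pairs flagged as interacting by the random forest,
# and edges of the FlashWeave network. All names are made up.
rf.pairs <- data.frame(a = c("ASV_1", "ASV_2", "ASV_5"),
                       b = c("ASV_2", "ASV_3", "ASV_1"),
                       stringsAsFactors = FALSE)
fw.edges <- data.frame(a = c("ASV_2", "ASV_1", "ASV_4"),
                       b = c("ASV_1", "ASV_5", "ASV_6"),
                       stringsAsFactors = FALSE)

# Build an order-independent key so that (x, y) and (y, x) count as one pair.
pair_key <- function(a, b) paste(pmin(a, b), pmax(a, b), sep = "--")

rf.key <- pair_key(rf.pairs$a, rf.pairs$b)
fw.key <- pair_key(fw.edges$a, fw.edges$b)

shared <- intersect(rf.key, fw.key)              # pairs found by both methods
frac.shared <- length(shared) / length(rf.key)   # fraction of RF pairs in the network
```

A fraction close to the value expected for randomly chosen pairs would suggest that the RF interactions and the network associations capture different signals.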

Exercise 1B Steps:
1. Make a multivariate response data frame with the logarithm of the following normalised functions: ATP, X, G, P, N.
2. Repeat the steps of 1A, but with a multivariate RF regression.
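
A minimal sketch of building the multivariate response is given below. The data are simulated and only the column names follow this notebook. The `randomForest` package fits one response at a time, so the multivariate fit here uses `rfsrc` from the randomForestSRC package; that choice is an assumption, not something the exercise prescribes, and the fit only runs if that package is installed.

```r
# Toy stand-in for function_data; column names follow this notebook, values are simulated.
set.seed(1)
toy.func <- data.frame(
  log.ATP7.norm = rnorm(20), log.mX7.norm = rnorm(20),
  log.mG7.norm  = rnorm(20), log.mP7.norm = rnorm(20),
  log.mN7.norm  = rnorm(20)
)
# Toy predictors standing in for ASV abundances and log reads.
toy.pred <- data.frame(ASV_1 = rpois(20, 5), ASV_2 = rpois(20, 3),
                       log.reads = log(rpois(20, 1000)))

resp.cols <- c("log.ATP7.norm", "log.mX7.norm", "log.mG7.norm",
               "log.mP7.norm", "log.mN7.norm")
y.multi <- toy.func[, resp.cols]      # multivariate response data frame
rf.data <- cbind(y.multi, toy.pred)   # one row per sample

# Assumption: randomForestSRC is used for the multivariate regression.
if (requireNamespace("randomForestSRC", quietly = TRUE)) {
  f <- as.formula(paste0("Multivar(", paste(resp.cols, collapse = ", "), ") ~ ."))
  rf.multi <- randomForestSRC::rfsrc(f, data = rf.data, ntree = 100)
}
```

Importance measures from a multivariate forest are reported per response, so the comparison with 1A has to be made function by function.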


In [None]:
###################################
#     random_forest_cfl.R  
##################################
# Author: Alberto Pascual-García
# Copyright (c)  Alberto Pascual-García,  2024
# Web:  apascualgarcia.github.io
# 
# Date: 2024-05-13
# Script Name: random_forest.R   


rm(list = ls()) # Clear workspace
######### START EDITING ------------

# --- Random forest parameters
optimization=1 # Should RF parameters be optimized? (=1)
run_RF=1 # Should RF be run (=1) or just read from file (=0)
type_RF="regression" # Should RF be a "classification" or a "regression"
ntree.min=1000 # number of trees, mandatory if optimization = 0 (and only used in that case)
mtry.min="default" # number of variables randomly selected to build the trees. Fix to "default" 
# if you don't have an informed guess
partial.plots=0 # should partial plots be generated (=1), fix to 0 otherwise. It will
# generate plots for the 10 most important variables. This is controlled by
# variable Nsel in the section plots below
select.func="ATP" # determine the function you want to predict. Name should match those from
                # the column name of the function data frame.
select.cond=c("7") # determine the function conditions you want to predict. Names should 
                      #match those used to differentiate the different conditions within the function df.
funcxcond=1 #number of total combinations of function and condition 
log.func=TRUE #logical indicating whether the values of the function should be logged for the analysis


# SET WORKING DIRECTORY -----------------------------

select.class = select.func

# --- Input and output files
file.ASV = "seqtable_readyforanalysis.csv"
file.taxa = "taxa_wsp_readyforanalysis.csv"
file.Meta = "metadata_Time0D-7D-4M_May2022_wJSDpart_ext.csv"
file.func = "20151016_Functions_remainder.csv"

# --- Directories
# this.dir=strsplit(rstudioapi::getActiveDocumentContext()$path, "/src/")[[1]][1] # don't edit, just comment it if problems...
this.dir=strsplit(rstudioapi::getActiveDocumentContext()$path, "/code/")[[1]][1]
dirSrc=paste(this.dir,"/code/",sep="") # Directory where the code is
dirASV=paste(this.dir,"/data/",sep="") # Dir of ASV table
dirMeta=paste(this.dir,"/data/",sep="") # Dir of metadata
dirFunc=paste(this.dir,"/data/",sep="") #Dir of response function data frame
dirOut=paste(this.dir,"/results/fnl_random_forest",sep="") # Dir of output data

# --- Packages

packages <- c("tidyverse", "stringr", "ggplot2",
              "caret", # additional RF functions
              "randomForest", "randomForestExplainer") # list of packages to load

scripts <- c("clean_ASV_table.R") # list of functions to load

# --- Reshaping data parameters
nreads = 10000 # minimum number of reads to consider a sample
exclude_exp = c("4M") # A vector of characters with the experiments that should be excluded
match_exp = TRUE # Set to true if only starting communities that were resurrected should be included
output.label = "Time0D_7D_matched" 

###### STOP EDITING -------------

# INSTALL PACKAGES & LOAD LIBRARIES -----------------
cat("INSTALLING PACKAGES & LOADING LIBRARIES... \n\n", sep = "")

n_packages <- length(packages) # count how many packages are required

new.pkg <- packages[!(packages %in% rownames(installed.packages()))] # determine which packages aren't installed

# install missing packages
if(length(new.pkg)){
  install.packages(new.pkg)
}

# load all required libraries
for(n in seq_len(n_packages)){
  cat("Loading Library #", n, " of ", n_packages, "... Currently Loading: ", packages[n], "\n", sep = "")
  library(packages[n], character.only = TRUE) # load the library by its name
}
# SOURCE FUNCTIONS ---------
setwd(dirSrc)
n_scripts <- length(scripts) # count how many scripts are required

for(n in seq_len(n_scripts)){
  cat("Loading script #", n, " of ", n_scripts, "... Currently Loading: ", scripts[n], "\n", sep = "")
  source(scripts[n]) # source the script from dirSrc
}









# READ INPUT FILES ----------
# --- Read ASVs table
setwd(dirASV)
ASV.table=read.table(file = file.ASV, sep="\t")
colnames(ASV.table)[1:5]
rownames(ASV.table)[1:5]
dim(ASV.table)
head(ASV.table)[1:5,1:5]

# ..... read metadata. Samples present in metadata were those passing the filtering
setwd(dirMeta)
sample_md <-read.table(file = file.Meta, sep="\t", header=TRUE)
head(sample_md)[1:5,1:5]

# --- Read function data frame
setwd(dirFunc) # the function data frame lives in dirFunc
data.func <- read.csv(file = file.func)
head(data.func)[1:5,1:5]

# # CLEAN DATA ---------- 
# 
# # Clean data   ------
# clean.data.list = clean_ASV_table(ASV.table,sample_md,match.exp = T)
# 
# ASV.table = clean.data.list$ASV.table
# sample_md = clean.data.list$sample_md
# 
# head(ASV.table)[1:5,1:5]


# BUILD REGRESSION DATA ---------- 

# .... Build the table of predictors, first ASVs (abundances or presence/absence).
# Here we will have a table with sample ids and ASVs, repeated once per
# function-condition combination (funcxcond);
# we will regress these variables to the response (function value)
id.t7 = grep("7D", sample_md$Experiment)
samples.t7 = as.character(sample_md$sampleid[id.t7])

xIn.tmp = ASV.table[samples.t7, ]
xIn.tmp$counts = rowSums(xIn.tmp) # add the sum of the counts for each sample as an additional predictor
xIn.tmp$counts = log(xIn.tmp$counts) # log it to avoid having too high numbers
xIn.tmp$sampleid = rownames(xIn.tmp) # create a new column, rownames cannot be repeated

xIn = xIn.tmp[rep(seq_len(nrow(xIn.tmp)), funcxcond), ] # we repeat xIn.tmp as many times as combination
                                                        # between function and condition are available

rownames(xIn) = seq(1,dim(xIn)[1]) 
xIn = xIn[,c(dim(xIn)[2],1:(dim(xIn)[2]-1))] # reorder
head(xIn)[1:5,1:5]


# # .... Now add the replicate as an additional envir predictor   ###We consider each replicate as an independent sample
# replicate = c(rep(1, length(id.t0)), rep(2, length(id.t0)),
#               rep(3, length(id.t0)),rep(4, length(id.t0)))
# xIn$replicate = replicate

# We add as additional predictors variables that give information about the function and the conditions
# But first, we have to adapt the response data frame data.func
functions.list <- colnames(data.func)
data.func.x <- data.frame(data.func[,grep(select.func, functions.list), drop = FALSE])

# .... Now add the functional conditions as an additional predictor.
# .... In our case the sole condition is growth time
functions.list <- colnames(data.func.x)
data.func.x <- data.func.x[,grep(select.cond, functions.list), drop = FALSE]
#############################################################################################################

# Change data.func.x into the long format
source(paste0(dirSrc, "reshape_df_to_ggplot.R"))
data.func.x$sampleid <- paste0(data.func$Community, ".", data.func$Replicate)
sample.id <- intersect(data.func.x$sampleid, xIn.tmp$sampleid) 
# IMPORTANT NOTE: Only 1035/1402 samples in xIn.tmp are present in data.func
#which(data.func.x$sampleid == sample.id)
data.func.x <- data.func.x[data.func.x$sampleid %in% sample.id, ]

x = data.func.x$sampleid


ex.df = data.func.x[,which(colnames(data.func.x) == "sampleid"), drop = FALSE]
col.names <- c("x")

for (i in seq(1, dim(data.func.x)[2])) {

  if (colnames(data.func.x)[i] == "sampleid") {
    
  } else {
    col.names <- c(col.names, paste0("y", i))
    ex.df = cbind(ex.df, data.func.x[i])
  }
}

colnames(ex.df) <- col.names

x.vec <- col.names[1]
y.vec <- col.names[-1]

char1 = rep(select.func, each = length(select.cond))
char2 = rep(select.cond, times = length(select.func))
  
xIn.2 = reshape_df_to_ggplot(ex.df,x.vec = x.vec,y.vec = y.vec,
                                              char.list = list(char1,
                                                               char2))

colnames(xIn.2) <- c("sampleid", "y.in", "v1", "v2")


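# Combine all the predictors from xIn and xIn.2.
# Align the rows of xIn with those of xIn.2 by sampleid, so that each row of
# predictors is paired with the response measured on the same sample
# (subsetting with %in% alone keeps the ASV-table order, which may differ).
xIn <- xIn[match(xIn.2$sampleid, xIn$sampleid), ]

xIn <- data.frame(xIn, xIn.2$v1, xIn.2$v2)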

# .... Extract response variable
# # Each sample will have a function response according to the function and condition variables

yIn = xIn.2$y.in
yIn = as.numeric(yIn)

if (log.func) {
  yIn = yIn + 0.0001 # We add a negligible value to all of the samples to prevent
  # the appearance of -Inf values after calculating the log
  yIn = log(yIn)
}

#quantile(yIn)

# ... Finally, drop sampleid names
xIn = subset(xIn, select = -c(sampleid))
head(xIn)[1:5,1:5]

# RANDOM FOREST computation -------------
set.seed(3032024) # today

# --- First estimate the optimal RF parameters:
# the two steps (ntree and mtry) should possibly be iterated
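dir.create(dirOut, showWarnings = FALSE, recursive = TRUE) # no warning if the directory already exists
setwd(dirOut)
mtry.default=sqrt(dim(xIn)[2]) # sqrt(p) is randomForest's default mtry for classification;
                               # for regression its built-in default is p/3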
if(optimization == 1){
  range="large"
  Nrand=25
  if(range == "small"){
    ntree.test=seq(from=100,to=500,by=20) # small range
  }else{
    ntree.test=c(100,250,500, seq(from=1000,to=11000,by=2000)) # large
    ntree.test=c(ntree.test,14000,18000,22000) # very large
    #ntree.test=c(ntree.test,12000,14000,16000,18000,20000) # very large
  }
  mtry.test=mtry.default #300 # if different than default this was fixed after one iteration with the next step below
  #OOB=vector(mode="numeric",length=length(ntree.test))
  OOB=matrix(0,nrow=length(ntree.test),ncol=Nrand)
  i=0
  for(ntree.tmp in ntree.test){
    i=i+1
    for(k in 1:Nrand){
      RF.tmp <-randomForest(y=yIn,x=xIn,
                            importance=T, proximity = T, 
                            ntree=ntree.tmp,mtry = mtry.test)
      
      if (type_RF == "classification") {
        OOB[i,k]=mean(RF.tmp$err.rate[,1])
      } else if (type_RF == "regression") {
        OOB[i,k]=mean(RF.tmp$mse)
      } else {
        warning("type not recognised")
      }
      
    }
  }
  OOB.mean=rowMeans(OOB)
  OOB.std=apply(OOB,1, sd, na.rm = TRUE)
  OOB.df=data.frame(cbind(ntree.test,OOB.mean,OOB.std))
  #rownames(OOB.df)=ntree.test
  OOB.min.id=which.min(OOB.mean)
  ntree.min=OOB.df$ntree.test[OOB.min.id]
  ymin=OOB.df$OOB.mean-OOB.df$OOB.std/sqrt(Nrand)
  ymax=OOB.df$OOB.mean+OOB.df$OOB.std/sqrt(Nrand)
  fileOut=paste("optimization_ntree_Class-",select.class,"_mtry",trunc(mtry.test),
                "_",range,".csv",sep="")
  write.table(OOB.df,file = fileOut,sep="\t",quote = FALSE,row.names = FALSE)
  plotOut=paste("Plot_optimization_ntree_Class-",select.class,"_mtry",trunc(mtry.test),
                "_",range,".pdf",sep="")
  pdf(plotOut)
  g=ggplot()+
    geom_vline(xintercept = ntree.min,linetype = 'dotted', col = 'red')+
    geom_point(data=OOB.df,aes(x=ntree.test,y=OOB.mean))+
    ylab("Mean OOB error")+xlab("Number of trees")+
    scale_y_continuous(trans='log10')+
    geom_errorbar(aes(x=ntree.test,ymin=ymin,ymax=ymax))+
    theme_bw()
  print(g)
  dev.off()
  
  # --- Now we fix the optimal ntree and look for the optimization of mtry
  ntree.in=ntree.min
  #ntree.in = 6000 # same order of magnitude than the minimum.
  mtry.test=seq(from=mtry.default/2,to= dim(xIn)[2],by=mtry.default/2)
  #OOB=vector(mode="numeric",length=length(mtry.test))
  OOB=matrix(0,nrow=length(mtry.test),ncol=Nrand)
  i=0
  for(mtry.tmp in mtry.test){
    i=i+1
    for(k in 1:Nrand){
      RF.tmp <-randomForest(y=yIn,x=xIn,
                            importance=T, proximity = T, 
                            ntree=ntree.in,mtry=mtry.tmp)
      
      if (type_RF == "classification") {
        OOB[i,k]=mean(RF.tmp$err.rate[,1])
      } else if (type_RF == "regression") {
        OOB[i,k]=mean(RF.tmp$mse)
      } else {
        warning("type not recognised")
      }
      
    }
  }
  OOB.mean=rowMeans(OOB)
  OOB.std=apply(OOB,1, sd, na.rm = TRUE)
  OOB.df=data.frame(cbind(mtry.test,OOB.mean,OOB.std))
  #rownames(OOB.df)=ntree.test
  OOB.min.id=which.min(OOB.mean)
  mtry.min=OOB.df$mtry.test[OOB.min.id]
  ymin=OOB.df$OOB.mean-OOB.df$OOB.std/sqrt(Nrand)
  ymax=OOB.df$OOB.mean+OOB.df$OOB.std/sqrt(Nrand)
  fileOut=paste("optimization_mtry_Class-",select.class,"_ntree-",trunc(ntree.min),
                "_",range,".csv",sep="")
  write.table(OOB.df,file = fileOut,sep="\t",quote = FALSE,row.names = FALSE)
  plotOut=paste("Plot_optimization_mtry_Class-",select.class,
                "_ntree",ntree.in,".pdf",sep="")
  pdf(plotOut)
  g=ggplot()+
    geom_point(data=OOB.df,aes(x=mtry.test,y=OOB.mean))+
    ylab("Mean OOB error")+xlab("Number of variables")+
    geom_errorbar(aes(x=mtry.test,ymin=ymin,ymax=ymax))
  print(g)
  dev.off()
}else{ # if we do not optimize
  if(mtry.min == "default"){ # we need to give a value if the user choose a default value
    mtry.min = mtry.default
  }
}

# --- With the optimal parameters run the RF
ntree.in=ntree.min
mtry.in=mtry.min
fileOut=paste("RandForestOut_Class-",select.class,
              "_ntree",trunc(ntree.in),"_mtry",trunc(mtry.in),
              ".RDS",sep="")
if(run_RF == 1){
  RF.out <-randomForest(y=yIn,x=xIn,
                        importance=T, proximity = T, 
                        ntree=ntree.in,mtry=mtry.in)
  # --- Have a look at the output
  # (values are not auto-printed inside this if block, so print explicitly)
  print(RF.out)
  
  if (type_RF == "classification") {
    print(mean(RF.out$err.rate[,1])) # this is the mean OOB
  } else if (type_RF == "regression") {
    print(mean(RF.out$mse)) # this is the mean OOB
  } else {
    warning("type not recognised")
  }
  
  saveRDS(RF.out, file = fileOut)
}else{
  RF.out = readRDS(file = fileOut)
}

RF.out 

# ANALYSE -------------

# --- Extract important variables
setwd(dirOut)
importance.df = measure_importance(RF.out)
ASV.top = important_variables(importance.df)

fileOut=paste("varImpExt_Class-",select.class,
              "_ntree",trunc(ntree.in),"_mtry",trunc(mtry.in),
              ".tsv",sep="")

write.table(importance.df,file=fileOut,sep = "\t",quote=FALSE)

fileOut=paste("varTop_Class-",select.class,
              "_ntree",trunc(ntree.in),"_mtry",trunc(mtry.in),
              ".tsv",sep="")

write.table(ASV.top,file=fileOut,sep = "\t",quote=FALSE)

# Plots ---------

# first and overview, it will generate an html file. Takes time.
explain_forest(RF.out)

# --- Error
plotOut=paste("Plot_Error_Class-",select.class,
              "_ntree",trunc(ntree.in),"_mtry",trunc(mtry.in),
              ".pdf",sep="")
pdf(plotOut)
plot(RF.out)
dev.off()


# --- Variable importance
fileOut=paste("varImp_Class-",select.class,
              "_ntree",trunc(ntree.in),"_mtry",trunc(mtry.in),
              ".tsv",sep="")
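varImp.out=varImp(RF.out)
# drop = FALSE keeps the data frame (and its row names) when there is a single column
varImp.out=varImp.out[which(rowSums(varImp.out) != 0), , drop = FALSE]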
write.table(varImp.out,file=fileOut,sep = "\t",quote=FALSE)
plotOut=paste("Plot_varImp_Class-",select.class,
              "_ntree",trunc(ntree.in),"_mtry",trunc(mtry.in),
              ".pdf",sep="")
pdf(plotOut,width=10)
#varImpPlot(RF.out, type=1,main="")#,xlab="Variable importance",title="")
#dev.off()

# Get variable importance from the model fit
ImpData <- as.data.frame(importance(RF.out))
ImpData$Var.Names <- row.names(ImpData)
# Column names depend on the forest type: classification forests report
# MeanDecreaseAccuracy / MeanDecreaseGini, regression forests %IncMSE / IncNodePurity
if (type_RF == "classification") {
  imp.col = "MeanDecreaseAccuracy"; pur.col = "MeanDecreaseGini"
} else {
  imp.col = "%IncMSE"; pur.col = "IncNodePurity"
}
ImpData$Imp = ImpData[[imp.col]]
ImpData$Purity = ImpData[[pur.col]]
ImpData.sort.idx = sort(ImpData$Imp, 
                        decreasing = TRUE, index.return = T)
quantile(ImpData.sort.idx$x)
Imp.Data.sort = ImpData[ImpData.sort.idx$ix, ]
explained = 50 # keep only variables whose importance exceeds this threshold
id.exp = which(Imp.Data.sort$Imp > explained)
Imp.Data.sort = Imp.Data.sort[id.exp,]

ggplot(Imp.Data.sort, aes(x=Var.Names, y=Imp)) +
  geom_segment( aes(x=Var.Names, xend=Var.Names, y=0, yend=Imp), color="skyblue") +
  geom_point(aes(size = Purity), color="blue", alpha=0.6) +
  labs(y = imp.col, size = pur.col) +
  theme_light() +
  coord_flip() +
  theme(
    legend.position="bottom",
    panel.grid.major.y = element_blank(),
    panel.border = element_blank(),
    axis.ticks.y = element_blank()
  )

dev.off()

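# Variable importance, multiway

plotOut=paste("Plot_accuracyVsGini_Class-",select.class,
              "_ntree",trunc(ntree.in),"_mtry",trunc(mtry.in),
              ".pdf",sep="")

# measure_importance() names its columns by forest type: accuracy_decrease and
# gini_decrease for classification, mse_increase and node_purity_increase for regression
if (type_RF == "classification") {
  x.meas = "accuracy_decrease"; y.meas = "gini_decrease"
  x.lab = "Accuracy decrease"; y.lab = "Gini index decrease"
} else {
  x.meas = "mse_increase"; y.meas = "node_purity_increase"
  x.lab = "MSE increase"; y.lab = "Node purity increase"
}

pdf(plotOut, height =6)
p = plot_multi_way_importance(
  importance_frame = importance.df,
  x_measure = x.meas,
  y_measure = y.meas,
  size_measure = "p_value",
  min_no_of_trees = 0,
  no_of_labels = 10 #,
  #main = "Multi-way importance plot"
)
p = p + xlab(x.lab) +ylab(y.lab)
p
#print(p)
dev.off()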

# --- Partial plots sorted by importance
# ... these plots may take a long time
if(partial.plots == 1){
  Nsel=10 # select only 10 vars
  imp <- importance(RF.out)
  impvar <- rownames(imp)[order(imp[, 1], decreasing=TRUE)]
  
  impvar=impvar[seq_len(min(Nsel, length(impvar)))]
  
  # Interpret partial plots
  # https://stats.stackexchange.com/questions/147763/meaning-of-y-axis-in-random-forest-partial-dependence-plot
  # p=seq(from=0.005, to=0.995, by=0.005) # to understand the plots
  # plot(p,log(p/(1-p))) # plot the logit function
  for (i in seq_along(impvar)) {
    var.lab=str_replace_all(impvar[i],pattern="[[:punct:]]",replacement = "")
    var.lab=str_replace_all(var.lab,pattern = " ",replacement = "_")
    plotOut=paste("Plot_PartialDependence_ntree",trunc(ntree.in),"_mtry",trunc(mtry.in),
                  "_Class-",select.class,"_Var-",var.lab,
                  ".pdf",sep="")
    pdf(file=plotOut,width=12)
    if (type_RF == "classification") {
      # one panel per class level
      op <- par(mfrow=c(2, 3))
      for(level in levels(yIn)){
        partialPlot(RF.out, xIn, impvar[i], xlab=impvar[i],
                    which.class = level,
                    main=paste("Partial Dependence for class", level)) #ylim=c(30, 70))
      }
      par(op)
    } else {
      # regression: yIn is numeric and has no class levels, so draw a single plot
      partialPlot(RF.out, xIn, impvar[i], xlab=impvar[i],
                  main=paste("Partial Dependence on", impvar[i]))
    }
    dev.off()
  }
}


# 
# 
# data(airquality)
# airquality <- na.omit(airquality)
# set.seed(131)
# ozone.rf <- randomForest(Ozone ~ ., airquality, importance=TRUE)
# imp <- importance(ozone.rf)
# impvar <- rownames(imp)[order(imp[, 1], decreasing=TRUE)]
# op <- par(mfrow=c(2, 3))
# for (i in seq_along(impvar)) {
#   partialPlot(ozone.rf, airquality, impvar[i], xlab=impvar[i],
#               main=paste("Partial Dependence on", impvar[i]),
#               ylim=c(30, 70))
# }
# par(op)


# BREAKAGE


## Exercise 2
Same as Exercise 1, but now each predictor is one of the clusters in your network, with its value given by the summed abundance of the cluster's members. Also include the logarithm of the total abundance of each sample.
Compare both exercises: are the results consistent? Does the prediction improve when using the clusters?
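
The cluster predictors can be built by summing ASV columns per cluster. The sketch below uses a made-up ASV table and a hypothetical membership table with columns `asv` and `cluster`; the real cluster file may be organised differently. It also drops clusters with a single member, which is an assumption about how singletons should be handled.

```r
# Toy ASV table: rows are samples, columns are ASVs (made-up counts).
asv <- matrix(c(10, 0, 5,
                 2, 3, 0,
                 0, 7, 1,
                 4, 4, 4),
              nrow = 3,
              dimnames = list(c("S1", "S2", "S3"),
                              c("ASV_1", "ASV_2", "ASV_3", "ASV_4")))

# Hypothetical ASV-to-cluster membership; the real table's columns may differ.
membership <- data.frame(asv = c("ASV_1", "ASV_2", "ASV_3", "ASV_4"),
                         cluster = c("C1", "C1", "C2", "C3"),
                         stringsAsFactors = FALSE)

# Drop clusters that contain a single ASV.
size <- table(membership$cluster)
keep <- membership[membership$cluster %in% names(size)[size > 1], ]

# Sum member abundances per cluster: rowsum() groups rows, so transpose first.
clust <- t(rowsum(t(asv[, keep$asv, drop = FALSE]), group = keep$cluster))

# Add the log of each sample's total abundance (over all ASVs) as a predictor.
x.clust <- data.frame(clust, log.reads = log(rowSums(asv)))
```

`x.clust` then plays the role of the predictor table in the random forest, with one column per retained cluster plus `log.reads`.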

In [None]:
# Look at the distribution of cluster sizes
# Remove clusters with only 1 ASV

# Remove either location or partition
# But first compare including vs. not including these variables by comparing MSE

# Plot mean relative abundance of interesting clusters across starting communities and final communities (x, y) and the regression

# Use Alberto's document for RF

# RF using one final replicate as training, then predict the function of the other 3 replicates
# Then a second RF where all replicates are combined as before
# Do this for both ASVs and clusters

# Compare using presence/absence for RF to predict function with using relative abundance and with using clusters

# Narrative of paper
# Genotype to phenotype
# Structure to function
# Comparing different ways of quantifying structure (composition) for predicting different functions

# Read "Statistically learning the functional landscape of microbial communities" to come up with a narrative