In [None]:
# Install whichever packages in `list.of.packages` are not usable yet, then
# attach every package in the list.
#
# Arguments:
#   list.of.packages  character vector of package names.
# Returns NULL invisibly; called for its side effects (install + attach).
load_or_install <- function(list.of.packages) {
  # requireNamespace() is the documented way to ask "is this package
  # usable?"; installed.packages() is slow and R's help page advises
  # against using it for this check.
  is_available <- vapply(list.of.packages, requireNamespace, logical(1),
                         quietly = TRUE)
  new.packages <- list.of.packages[!is_available]
  if (length(new.packages) > 0) {
    install.packages(new.packages)
  }
  for (pkg in list.of.packages) {
    library(pkg, character.only = TRUE)
  }
  invisible(NULL)
}

In [None]:
library('ggplot2')
library('reshape2')

In [None]:
# Simulation parameters: number of draws per class, the three class means,
# and the variance shared by every class.
N <- 100
mus <- c(1, 5, 20)
sig2 <- 1

In [None]:
# Per-class sample sizes: a balanced design with N draws for each class.
Ns <- rep.int(N, 3)
Ns

In [None]:
# Simulate one Gaussian sample per class, all sharing the variance sig2.
# The sd is derived from sig2 (it used to be hard-coded to 1, which silently
# ignored sig2); with sig2 = 1 the draws are unchanged.
g1 <- data.frame(group = 1, x = rnorm(Ns[1], mus[1], sqrt(sig2)))
g2 <- data.frame(group = 2, x = rnorm(Ns[2], mus[2], sqrt(sig2)))
g3 <- data.frame(group = 3, x = rnorm(Ns[3], mus[3], sqrt(sig2)))
d <- rbind(g1, g2, g3)
d$group <- as.factor(d$group)
# Shuffle the rows so the classes are not stored in contiguous blocks.
d <- d[sample(nrow(d)), ]

In [None]:
# First rows of the shuffled simulated data.
head(d, n = 6L)

In [None]:
# Per-class kernel density estimates, with every raw observation drawn as a
# point on y = 0.
ggplot(d, aes(x = x, colour = group, group = group)) +
  geom_density() +
  geom_point(aes(y = 0))

In [None]:
# Work through the LDA estimates by hand for one target class. The variable is
# called `target_class` (not `class`) so base::class() is not shadowed.
target_class <- 1

subd <- d[d$group == target_class, ]
head(subd)

In [None]:
# mu_hat: sample mean of x within the target class
muhat <- mean(subd$x)
muhat

In [None]:
# pi_hat: fraction of all observations that belong to the target class
pi_hat <- mean(d$group == target_class)
pi_hat

In [None]:
# Within-class sample variances, one per class (in factor-level order)
vars <- vapply(split(d$x, d$group), var, numeric(1))
vars

In [None]:
# Pooled within-class variance: each class variance weighted by its degrees of
# freedom (n_k - 1), divided by n - K. Class sizes are counted from the data
# with the same split as `vars`, so the two always line up.
n_k <- lengths(split(d$x, d$group))
pooled_var <- sum((n_k - 1) * vars) / (sum(n_k) - length(n_k))
pooled_var

In [None]:
# Marginal variance of x, for comparison; unlike pooled_var it also absorbs
# the spread between the class means.
var(d$x)

In [None]:
x0 <- 1

In [None]:
# Unnormalised posterior score of the target class at x0:
# pi_hat * N(x0; muhat, pooled_var)
dnorm(x0, mean = muhat, sd = sqrt(pooled_var)) * pi_hat

In [None]:
# Unnormalised LDA posterior score of one class at point(s) x0:
#   pi_hat_k * dnorm(x0; mu_hat_k, pooled_var)
# where mu_hat_k is the class mean of x, pi_hat_k the class proportion and
# pooled_var the pooled within-class variance.
#
# Arguments:
#   x0     value(s) at which to evaluate the score.
#   class  class label to score (compared against data$group).
#   data   data frame with a `group` column and a numeric `x` column;
#          defaults to the global `d`.
# Returns a numeric vector the same length as x0.
delta_lda_c <- function(x0, class, data = d) {
  muhat <- mean(data$x[data$group == class])
  pi_hat <- mean(data$group == class)
  # Class sizes and variances both come from the same split of `data` (not
  # from the global Ns), so they agree for any number of classes; drop = TRUE
  # ignores unused factor levels, which would otherwise contribute NA.
  by_class <- split(data$x, data$group, drop = TRUE)
  n_k <- lengths(by_class)
  vars <- vapply(by_class, var, numeric(1))
  pooled_var <- sum((n_k - 1) * vars) / (sum(n_k) - length(by_class))
  dnorm(x0, mean = muhat, sd = sqrt(pooled_var)) * pi_hat
}

In [None]:
# Score of x0 under class 1.
delta_lda_c(x0 = x0, class = 1)

In [None]:
# Score of x0 under class 2.
delta_lda_c(x0 = x0, class = 2)

In [None]:
# Score of x0 under class 3.
delta_lda_c(x0 = x0, class = 3)

In [None]:
# Predict the class of x0 as the index (1, 2 or 3) of the class whose
# delta_lda_c() score is largest.
lda_pred <- function(x0) {
  scores <- unlist(lapply(seq_len(3), function(k) delta_lda_c(x0, k)))
  which.max(scores)
}

In [None]:
# Classify three probe points.
lda_pred(1)

In [None]:
lda_pred(7)

In [None]:
lda_pred(20)

In [None]:
# Evaluation grid covering all three simulated classes.
x_seq <- seq(from = -2, to = 25, length.out = 500)

In [None]:
# Predicted class plus the three class scores at every grid point.
df <- data.frame(x = x_seq, y_pred = sapply(x_seq, lda_pred))
df$y_pred <- factor(df$y_pred)
df$c1 <- vapply(x_seq, delta_lda_c, numeric(1), class = 1)
df$c2 <- vapply(x_seq, delta_lda_c, numeric(1), class = 2)
df$c3 <- vapply(x_seq, delta_lda_c, numeric(1), class = 3)
head(df)

In [None]:
# Long format: one row per (grid point, class score).
mdf <- melt(df, id.vars = c("x", "y_pred"))

In [None]:
# Relabel the melted levels c1/c2/c3 as the class labels 1/2/3.
levels(mdf$variable) <- c(1, 2, 3)

In [None]:
# Class densities again, for comparison with the score curves below.
ggplot(d, aes(x = x, colour = group, group = group)) +
  geom_density() +
  geom_point(aes(y = 0))

In [None]:
options(repr.plot.width = 10, repr.plot.height = 4, repr.plot.res = 100)
# Points on y = 0 coloured by predicted class; curves are the per-class
# scores, coloured by class.
ggplot(mdf, aes(x = x, y = 0, colour = y_pred, group = y_pred)) +
  geom_point() +
  geom_line(
    aes(x = x, y = value, group = variable, colour = variable),
    inherit.aes = FALSE
  )

In [None]:
library(MASS)
help("lda")

In [None]:
# Hand-rolled predictions for every training point.
my_preds <- sapply(d$x, lda_pred)
my_preds

In [None]:
# Reference fit from MASS::lda on the same data.
mod <- lda(group ~ ., data = d)

In [None]:
mod

In [None]:
# Hand-computed class mean and prior from earlier, to compare with the fit.
muhat

In [None]:
pi_hat

In [None]:
mod_preds <- predict(mod)$class
mod_preds

In [None]:
# TRUE when MASS::lda and lda_pred agree on every training point.
all(mod_preds == my_preds)

In [None]:
head(predict(mod)$posterior)

In [None]:
# MASS::lda posterior probabilities on the evaluation grid, in long format.
mod_df <- data.frame(
  cbind(x = x_seq, predict(mod, newdata = data.frame(x = x_seq))$posterior)
)
mmod_df <- melt(mod_df, id.vars = "x")
levels(mmod_df$variable) <- 1:3

In [None]:
options(repr.plot.width = 10, repr.plot.height = 3, repr.plot.res = 100)
# Predicted-class points on y = 0 from lda_pred, overlaid with the MASS::lda
# posterior curves.
ggplot(mdf, aes(x = x, y = 0, colour = y_pred, group = y_pred)) +
  geom_point() +
  geom_line(
    data = mmod_df,
    mapping = aes(x = x, y = value, group = variable, colour = variable),
    inherit.aes = FALSE
  )

In [None]:
head(df)

In [None]:
# Divide each row's three class scores by their sum, turning them into
# posterior probabilities P(class = k | x) that sum to one.
df[, c("c1", "c2", "c3")] <-
  df[, c("c1", "c2", "c3")] / rowSums(df[, c("c1", "c2", "c3")])

In [None]:
mdf <- melt(df, id.vars = c("x", "y_pred"))

In [None]:
# Relabel the melted levels c1/c2/c3 as the class labels 1/2/3.
levels(mdf$variable) <- c(1, 2, 3)

In [None]:
options(repr.plot.width = 10, repr.plot.height = 4, repr.plot.res = 100)
# Hand-computed posteriors (normalised scores in mdf) overlaid with the
# MASS::lda posteriors (mmod_df) on the same grid. The `+` after the first
# geom_line() was missing, so the MASS layer was a separate expression that
# never reached the plot. The MASS curves are dashed so both stay visible
# where they coincide.
ggplot(data = mdf, mapping = aes(x = x, y = 0, color = y_pred, group = y_pred)) +
  geom_point() +
  geom_line(
    mapping = aes(x = x, y = value, group = variable, color = variable),
    inherit.aes = FALSE
  ) +
  geom_line(
    mapping = aes(x = x, y = value, group = variable, color = variable),
    inherit.aes = FALSE, data = mmod_df, linetype = "dashed"
  )

In [None]:
# Linear discriminant function of one class at point(s) x0:
#   delta_k(x0) = x0 * mu_k / s^2 - mu_k^2 / (2 * s^2) + log(pi_k)
# where mu_k is the class mean of x, pi_k the class proportion and s^2 the
# pooled within-class variance. It equals log(pi_k * dnorm(x0; mu_k, s^2))
# minus terms that do not depend on the class, so its argmax over classes
# matches that of the density-based score.
#
# Arguments:
#   x0     value(s) at which to evaluate the discriminant.
#   class  class label (compared against data$group).
#   data   data frame with a `group` column and a numeric `x` column;
#          defaults to the global `d`.
# Returns a numeric vector the same length as x0.
delta_lda_c2 <- function(x0, class, data = d) {
  muhat <- mean(data$x[data$group == class])
  pi_hat <- mean(data$group == class)
  # Sizes and variances from the same split of `data` (not the global Ns);
  # drop = TRUE ignores unused factor levels.
  by_class <- split(data$x, data$group, drop = TRUE)
  n_k <- lengths(by_class)
  vars <- vapply(by_class, var, numeric(1))
  pooled_var <- sum((n_k - 1) * vars) / (sum(n_k) - length(by_class))
  muhat * x0 / pooled_var - muhat^2 / (2 * pooled_var) + log(pi_hat)
}
# Predict the class of x0 as the index (1, 2 or 3) of the class with the
# largest linear discriminant, delta_lda_c2() (not the density-based
# delta_lda_c(), which this function previously called by mistake).
lda_pred2 <- function(x0) {
  deltas <- sapply(1:3, function(k) delta_lda_c2(x0, k))
  which.max(deltas)
}

In [None]:
# Rebuild the grid table using the linear discriminants instead of the
# density-based scores.
df <- data.frame(x = x_seq, y_pred = factor(sapply(x_seq, lda_pred2)))
df$c1 <- vapply(x_seq, delta_lda_c2, numeric(1), class = 1)
df$c2 <- vapply(x_seq, delta_lda_c2, numeric(1), class = 2)
df$c3 <- vapply(x_seq, delta_lda_c2, numeric(1), class = 3)

mdf <- melt(df, id.vars = c("x", "y_pred"))
levels(mdf$variable) <- c(1, 2, 3)

In [None]:
# Thick lines: one linear discriminant per class; small points on y = 0:
# predicted class. The y-axis is zoomed to [-50, 50].
ggplot(mdf, aes(x = x, y = 0, colour = y_pred, group = y_pred)) +
  geom_point(shape = 1, size = 1/10) +
  geom_line(aes(y = value, group = variable, colour = variable), lwd = 2) +
  coord_cartesian(ylim = c(-50, 50))

# LDA with more than one predictor ($p > 1$)

In [None]:
load_or_install("palmerpenguins")
# Keep only rows with no missing value in any column.
penguins <- penguins[complete.cases(penguins), ]
head(penguins)

In [None]:
# The two bill measurements plus the class label.
d <- penguins[, c("bill_length_mm", "bill_depth_mm", "species")]
head(d)

In [None]:
# LDA with both bill measurements as predictors.
mod <- lda(species ~ ., data = d)
mod

In [None]:
# Fit MASS::lda on two predictors and plot its decision regions on a grid,
# with the training points overlaid.
#
# Arguments:
#   v1, v2   names of the two predictor columns (x and y axes).
#   df       data frame containing `species`, v1 and v2
#            (default: the global `penguins`).
#   N        grid points per axis, giving an N x N prediction grid.
#   scaleit  if TRUE, centre and scale v1 and v2 with scale() before fitting.
#   fmla     model formula as a string, passed to lda().
# Returns a ggplot object.
plot_fit <- function(v1, v2, df = penguins, N = floor(sqrt(10000)),
                     scaleit = FALSE, fmla = "species~.") {
  train_df <- df[, c("species", v1, v2)]
  if (scaleit) {
    train_df[, c(v1, v2)] <- scale(train_df[, c(v1, v2)])
  }

  mod <- lda(formula = as.formula(fmla), data = train_df)

  # Regular grid spanning the observed range of each predictor.
  r1 <- range(train_df[[v1]])
  r2 <- range(train_df[[v2]])
  s1 <- seq(r1[1], r1[2], length.out = N)
  s2 <- seq(r2[1], r2[2], length.out = N)

  p_df <- expand.grid(v1 = s1, v2 = s2)
  colnames(p_df) <- c(v1, v2)
  p_df$species <- predict(mod, newdata = p_df)$class

  # aes_string() is deprecated in ggplot2; the .data pronoun selects columns
  # whose names are held in strings. labs() keeps the plain column names as
  # axis titles.
  ggplot(
    data = p_df,
    mapping = aes(x = .data[[v1]], y = .data[[v2]],
                  fill = species, shape = species)
  ) +
    geom_tile() +
    geom_point(data = train_df, size = 5) +
    labs(x = v1, y = v2)
}

In [None]:
options(repr.plot.width = 10, repr.plot.height = 3, repr.plot.res = 100)
# Decision regions on the two bill measurements.
plot_fit("bill_length_mm", "bill_depth_mm", df = penguins)

In [None]:
# Same, with flipper length in place of bill length.
plot_fit("flipper_length_mm", "bill_depth_mm", df = penguins)

In [None]:
options(repr.plot.width = 10, repr.plot.height = 5, repr.plot.res = 100)
# Polynomial terms: LDA stays linear in the transformed features, so the
# boundaries are curved in the original bill_length / bill_depth space.
plot_fit(
  "bill_length_mm", "bill_depth_mm", df = penguins,
  fmla = "species~I(bill_length_mm^5)+I(bill_depth_mm^3)+I(bill_depth_mm^2)"
)