/
0_demonstration.R
77 lines (62 loc) · 2.15 KB
/
0_demonstration.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
## Higher-Order Total Variation (HOTV)
## A. Okuno (ISM, okuno@ism.ac.jp)
## August 3rd, 2023
## ----------------------------------------------------
working_dir <- getwd()
## ----------------------------------------------------
## dependencies
## library() stops with an error if the package is missing;
## require() would only warn and return FALSE, letting the
## script continue in a broken state.
library("progress")
## directory layout: everything hangs off the current working directory
dir <- list(
  main = (tmp <- working_dir),
  scripts = paste0(tmp, "/A0_scripts"),
  data = paste0(tmp, "/A1_data"),
  computed = paste0(tmp, "/A2_computed")
)
## create the directories if absent (suppress "already exists" warnings);
## invisible() discards the list of logical results lapply returns
invisible(lapply(dir, function(z) dir.create(z, showWarnings = FALSE)))
## number of datasets generated
n_seeds <- 1
## loading scripts -- these presumably define the helpers used below
## (theta.init, VRSGD, f); verify against A0_scripts contents
source(paste0(dir$scripts, "/gen_data.R"))
source(paste0(dir$scripts, "/functions.R"))
## ------------------------
## dataset specification
## ------------------------
## loading quadratic datasets
## (spell out TRUE: the shorthand T is an ordinary variable and can be
## reassigned, silently breaking the call)
df <- read.csv(paste0(dir$data, "/2_seed1.csv"), header = TRUE)
## keep the first N observations; observed response = signal y plus noise e
N <- 100
x <- df$x[1:N]
y <- (df$y + df$e)[1:N]
## --------------
## NN training
## --------------
## hyper-parameter settings passed to VRSGD()
constants <- list(
  L = 50,                ## number of hidden units
  lambda = 0,            ## reg. coef. for beta (ridge/weight decay)
  eta = c(0, 0, 1e-5),   ## reg. coef. for gamma (k-TV); presumably one
                         ## entry per variation order k -- confirm in VRSGD
  n = 5,                 ## num. of subsampling for alpha
  m = 5,                 ## num. of subsampling for gamma
  lr0 = 1e-3,            ## initial learning rate
  lr_dr = 0.9,           ## decay rate
  lr_it = 25,            ## decay interval
  lr_period = 1e3,       ## period of cyclic decay
  n_itr = 1e4            ## num. of SGD iterations
)
## initialization of the parameter (theta.init comes from the sourced scripts)
theta0 <- theta.init(L = constants$L, sd = 1, sdx = sd(x))
## Variation-regularized SGD; monitor_loss=TRUE records the loss trajectory
## used by the plots below
.sgd <- VRSGD(theta0 = theta0, x = x, y = y, constants = constants,
              monitor_loss = TRUE)
## --------
## plot
## --------
xl <- c(-1/2, 1/2)
yl <- range(y)
## prediction grid over the plotting window
xx <- seq(xl[1], xl[2], length.out = 100)
## fitted values (f is defined in the sourced scripts)
yp <- f(xx, .sgd$theta)
## two panels side by side; save the old settings so we can restore them
op <- par(mfrow = c(1, 2))
## [plot 1] monitored loss functions (of SGD), log-scaled y axis
plot(.sgd$monitor$itr, .sgd$monitor$loss, type = "l",
     xlab = "t", ylab = "loss", log = "y")
## [plot 2] training data and predictions
plot(x, y, xlim = xl, ylim = yl, xlab = "x", ylab = "y")
## overlay the fitted curve with lines() rather than par(new=TRUE) plus a
## second plot(), which re-draws a hidden coordinate system on top
lines(xx, yp, col = "blue", lwd = 2)
## restore the graphics parameters changed above
par(op)