-
Notifications
You must be signed in to change notification settings - Fork 9
/
linear-learning-problem.jl
170 lines (150 loc) · 4.97 KB
/
linear-learning-problem.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# Linear learning problem data types and constructors #########################
"""
    LinearProblem{T<:Real} <: AbstractLearningProblem

Abstract supertype for linear potential inference problems, parameterised by the real
number type `T`.
"""
abstract type LinearProblem{T<:Real} <: AbstractLearningProblem end
"""
struct UnivariateLinearProblem{T<:Real} <: LinearProblem{T}
iv_data::Vector
dv_data::Vector
β::Vector{T}
β0::Vector{T}
σ::Vector{T}
Σ::Symmetric{T,Matrix{T}}
end
A UnivariateLinearProblem is a linear problem in which there is only 1 type of independent variable / dependent variable. Typically, that means we are either only fitting energies or only fitting forces. When this is the case, the solution is available analytically and the standard deviation, σ, and covariance, Σ, of the coefficients, β, are computable.
"""
struct UnivariateLinearProblem{T<:Real} <: LinearProblem{T}
iv_data::Vector
dv_data::Vector
β::Vector{T}
β0::Vector{T}
σ::Vector{T}
Σ::Symmetric{T,Matrix{T}}
end
# Compact one-line display: the element type `T`, the coefficients `β` and the standard
# deviation `σ`. `T` is interpolated so the actual type parameter is printed rather than
# the literal letter "T".
function Base.show(io::IO, u::UnivariateLinearProblem{T}) where {T}
    return print(io, "UnivariateLinearProblem{$(T), $(u.β), $(u.σ)}")
end
"""
struct CovariateLinearProblem{T<:Real} <: LinearProblem{T}
e::Vector
f::Vector{Vector{T}}
B::Vector{Vector{T}}
dB::Vector{Matrix{T}}
β::Vector{T}
β0::Vector{T}
σe::Vector{T}
σf::Vector{T}
Σ::Symmetric{T,Matrix{T}}
end
A CovariateLinearProblem is a linear problem in which we are fitting energies and forces using both descriptors and their gradients (B and dB, respectively). When this is the case, the solution is not available analytically and must be solved using some iterative optimization proceedure. In the end, we fit the model coefficients, β, standard deviations corresponding to energies and forces, σe and σf, and the covariance Σ.
"""
struct CovariateLinearProblem{T<:Real} <: LinearProblem{T}
e::Vector
f::Vector{Vector{T}}
B::Vector{Vector{T}}
dB::Vector{Matrix{T}}
β::Vector{T}
β0::Vector{T}
σe::Vector{T}
σf::Vector{T}
Σ::Symmetric{T,Matrix{T}}
end
# Compact one-line display: the element type `T`, the coefficients `β` and the energy and
# force standard deviations `σe` / `σf`. `T` is interpolated so the actual type parameter
# is printed rather than the literal letter "T".
function Base.show(io::IO, u::CovariateLinearProblem{T}) where {T}
    return print(io, "CovariateLinearProblem{$(T), $(u.β), $(u.σe), $(u.σf)}")
end
"""
function LinearProblem(
ds::DataSet;
T = Float64
)
Construct a LinearProblem by detecting if there are energy descriptors and/or force descriptors and construct the appropriate LinearProblem (either Univariate, if only a single type of descriptor, or Covariate, if there are both types).
"""
function LinearProblem(
ds::DataSet
)
d_flag, descriptors, energies = try
true, sum.(get_values.(get_local_descriptors.(ds))), get_values.(get_energy.(ds))
catch
false, 0.0, 0.0
end
fd_flag, force_descriptors, forces = try
true,
[reduce(vcat, get_values(get_force_descriptors(dsi))) for dsi in ds],
get_values.(get_forces.(ds))
catch
false, 0.0, 0.0
end
if d_flag & ~fd_flag
dim = length(descriptors[1])
β = zeros(dim)
β0 = zeros(1)
p = UnivariateLinearProblem(
descriptors,
energies,
β,
β0,
[1.0],
Symmetric(zeros(dim, dim)),
)
elseif ~d_flag & fd_flag
dim = length(force_descriptors[1][1])
β = zeros(dim)
β0 = zeros(1)
force_descriptors = [reduce(hcat, fi) for fi in force_descriptors]
p = UnivariateLinearProblem(
force_descriptors,
[reduce(vcat, fi) for fi in forces],
β,
β0,
[1.0],
Symmetric(zeros(dim, dim)),
)
elseif d_flag & fd_flag
dim_d = length(descriptors[1])
dim_fd = length(force_descriptors[1][1])
if (dim_d != dim_fd)
error("Descriptors and Force Descriptors have different dimension!")
else
dim = dim_d
end
β = zeros(dim)
β0 = zeros(1)
forces = [reduce(vcat, fi) for fi in forces]
force_descriptors = [reduce(hcat, fi) for fi in force_descriptors]
p = CovariateLinearProblem(
energies,
[reduce(vcat, fi) for fi in forces],
descriptors,
force_descriptors,
β,
β0,
[1.0],
[1.0],
Symmetric(zeros(dim, dim)),
)
else
error("Either no (Energy, Descriptors) or (Forces, Force Descriptors) in DataSet")
end
p
end
# Linear learning functions common to OLS and WLS implementations #############
"""
    learn!(lb::InteratomicPotentials.LinearBasisPotential, ds::DataSet, args...)

Learning dispatch function, common to the ordinary and weighted least squares
implementations.

Builds a `LinearProblem` from `ds`, forwards it with `args...` to the problem-level
`learn!` method, then copies the resulting `β` and `β0` into `lb` (resizing `lb.β` to the
problem's coefficient length first). Returns the linear problem.
"""
function learn!(
    lb::InteratomicPotentials.LinearBasisPotential,
    ds::DataSet,
    args...
)
    problem = LinearProblem(ds)
    learn!(problem, args...)

    # lb.β may have a different length than the problem's coefficients, so it is resized
    # before the element-wise copy.
    coeffs = problem.β
    resize!(lb.β, length(coeffs))
    lb.β .= coeffs
    lb.β0 .= problem.β0
    return problem
end