/
GetICNet.m
41 lines (37 loc) · 1.29 KB
/
GetICNet.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
function net = GetICNet(v_fStrain,m_fYtrain)
% GETICNET Build and train the interference cancellation network
%
% Syntax
% -------------------------------------------------------
% net = GetICNet(v_fStrain,m_fYtrain)
%
% INPUT:
% -------------------------------------------------------
% v_fStrain - training labels
% m_fYtrain - training inputs (channel outputs + interference);
%             the number of rows is used as the network input size
%
% OUTPUT:
% -------------------------------------------------------
% net - trained neural network, as returned by TrainICNet

% Architecture constants
nHidden  = 60;
nOutputs = 2;                    % binary constellations
nFeatures = size(m_fYtrain,1);   % input size = number of rows of the inputs

% Workaround (original author's note: "Nir"): make the LSTM act like a
% perceptron with sigmoid activation. The recurrent weights are set to zero
% and given a learn-rate factor and L2 factor of 0.
recurrentLayer = lstmLayer(nHidden, ...
    'OutputMode','last', ...
    'RecurrentWeightsLearnRateFactor',0, ...
    'RecurrentWeightsL2Factor',0);
recurrentLayer.RecurrentWeights = zeros(4*nHidden,nHidden);

% Layers following the LSTM: fully connected -> ReLU -> fully connected
% -> softmax -> classification output
classifierHead = [
    fullyConnectedLayer(floor(nHidden/2));
    reluLayer;
    fullyConnectedLayer(nOutputs);
    softmaxLayer;
    classificationLayer];

% Full layer stack: sequence input, LSTM, then the classifier head
layers = [
    sequenceInputLayer(nFeatures);
    recurrentLayer;
    classifierHead];

% Train network. The trailing 0 is passed straight to TrainICNet; its
% meaning is defined there (not visible in this file).
net = TrainICNet(v_fStrain,m_fYtrain,layers,0);