-
Notifications
You must be signed in to change notification settings - Fork 0
/
convolutiveica.m
335 lines (284 loc) · 11.8 KB
/
convolutiveica.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
function [S_cica, A_tau, S_noise, A_noise] = convolutiveica(X,L,A,...
    sr,d_row,d_col,N_row,N_col,...
    d_max,frames_ROI,varargin)
%CONVOLUTIVEICA Iteratively unmix crosstalking channels with convolutive ICA.
%
%[S_cica, A_tau, S_noise, A_noise] = convolutiveica(X,L,A,...
%    sr,d_row,d_col,N_row,N_col,...
%    d_max,frames_ROI,'name',value,...)
%
% Every iteration
%   1. moves the components rejected by checkfornoisycomponents from X/A
%      into S_noise/A_noise,
%   2. estimates the channel crosstalk matrix (channelcrosstalkwiener),
%   3. groups channels by hierarchical clustering of that matrix, limits
%      each group to max_cluster_size channels and drops groups that were
%      already unmixed in an earlier iteration,
%   4. unmixes every remaining group with convolutive ICA (cicaarpro) and
%      updates the corresponding rows of X and columns of A.
% The function returns once no group is left to unmix, once max_iter cICA
% iterations have been run, or right after the first noise-removal pass
% when do_cICA is false.
%
% Inputs:
%   X          (D,T) array containing D channels.
%   L          order of the convolutive mixing model.
%   A          (N,D,L+M+1) mixing matrices identified by a preceding
%              dimensionality reduction from N sensors to D channels
%              (e.g. by instantaneous ICA); size(A,3) must equal L+M+1.
%   sr         sampling rate in kHz (can be overridden by option 'sr').
%   d_row, d_col, N_row, N_col
%              passed together with A to spatialdistance; presumably the
%              sensor pitch and grid size per row/column (TODO confirm).
%   d_max      passed with the spatial distances to channelcrosstalkwiener;
%              presumably a maximal sensor distance (TODO confirm).
%   frames_ROI column indices of X used to estimate crosstalk and to fit
%              cICA; the resulting unmixing is applied to all columns.
%
% Optional name/value pairs [default]:
%   'plotting'         [0]     plot the clusters before/after each cICA step
%   'min_skewness'     [0.2]   passed to checkfornoisycomponents
%   'thrFactor'        [5]     passed to checkfornoisycomponents
%   'min_no_peaks'     [2]     passed to checkfornoisycomponents
%   'sr'                       overrides the positional argument sr
%   'min_corr'         [0.15]  threshold for hierarchicalclustering
%   'M'                [0]     order of the source filters Hlambda (cicaarpro)
%   'maxlags'          [L+M]   maximal lag passed to channelcrosstalkwiener
%   'max_cluster_size' [2]     maximal number of channels unmixed jointly
%   'max_iter'         [5]     maximal number of cICA iterations
%   'do_cICA'          [true]  if false, only noisy components are removed
%   't_s', 't_jitter', 'coin_thr'   accepted for compatibility, unused here
%
% Outputs:
%   S_cica   remaining (non-noisy) sources, one per row.
%   A_tau    (N,D',L+M+1) updated mixing matrices belonging to S_cica.
%   S_noise  rows removed as noisy, stacked in order of removal.
%   A_noise  mixing matrices of the removed rows, concatenated along dim 2.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Default arguments
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
plotting = 0;
min_skewness = 0.2;
min_corr = 0.15;
M = 0;
max_cluster_size = 2;
max_iter = 5;
% Resolved to L+M only after parsing, so that a user-supplied 'M' is
% honoured (previously the default was fixed before 'M' could change).
maxlags = [];
min_no_peaks = 2;
t_s = 0.5; %#ok<NASGU> kept for backward compatibility
t_jitter = 1; %#ok<NASGU> kept for backward compatibility
coin_thr = 0.5; %#ok<NASGU> kept for backward compatibility
thrFactor = 5;
do_cICA = true;
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Optional arguments
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
if rem(length(varargin),2) == 1
    error('Optional parameters should always go by pairs');
end
for i = 1:2:(length(varargin)-1)
    if ~ischar(varargin{i})
        error(['Unknown type of optional parameter name (parameter' ...
            ' names must be strings).']);
    end
    value = varargin{i+1};
    switch varargin{i}
        case 'maxlags'
            maxlags = value;
        case 'plotting'
            plotting = value;
        case 'min_skewness'
            min_skewness = value;
        case 'min_corr'
            min_corr = value;
        case 'M'
            M = value;
        case 'max_cluster_size'
            max_cluster_size = value;
        case 'max_iter'
            max_iter = value;
        case 'sr'
            sr = value;
        case 'thrFactor'
            thrFactor = value;
        case 'min_no_peaks'
            min_no_peaks = value;
        case 't_s'
            t_s = value; %#ok<NASGU>
        case 't_jitter'
            t_jitter = value; %#ok<NASGU>
        case 'coin_thr'
            coin_thr = value; %#ok<NASGU>
        case 'do_cICA'
            do_cICA = value;
        otherwise
            % Hmmm, something wrong with the parameter string
            error(['Unrecognized parameter: ''' varargin{i} '''']);
    end
end
if isempty(maxlags)
    maxlags = L + M;
end
% A subcluster can never shrink below one channel, so a smaller limit
% would make the cluster-size reduction below loop forever.
if max_cluster_size < 1
    error('max_cluster_size must be at least 1.');
end
if (size(A,3) - 1) ~= L+M
    error('Mismatch of order L+M and number of mixing matrices A.')
end

iteration_no = 0;
% touched(i).IDs: channel combinations already unmixed by cICA that need
% not be processed again (as long as none of their members is removed).
touched = struct('IDs',{});
S_noise = []; A_noise = [];
% Only relevant if the loop body never runs (max_iter < 0); otherwise the
% outputs are assigned right before the return inside the loop.
S_cica = X;
A_tau = A;
% The pass with iteration_no == max_iter+1 performs a final noise removal
% and always leaves through the return statement at the end of the loop.
while iteration_no <= max_iter
    iteration_no = iteration_no + 1;
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Remove components
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % keep is presumably a logical mask over the rows of X (TODO confirm).
    keep = checkfornoisycomponents(X,min_skewness,thrFactor, ...
        min_no_peaks,sr,plotting);
    S_noise = [S_noise;X(~keep,:)];
    A_noise = cat(2,A_noise,A(:,~keep,:));
    X = X(keep,:);
    A = A(:,keep,:);
    % A combination that lost a member is forgotten, because its remaining
    % channels may have absorbed energy from the removed component(s) and
    % are worth unmixing again. Surviving combinations are re-indexed.
    delIDs = find(~keep);
    touched = touched(cellfun(@(x) ~any(ismember(x,delIDs)),{touched.IDs}));
    for i = 1:length(touched)
        for j = 1:length(touched(i).IDs)
            oldID = touched(i).IDs(j);
            touched(i).IDs(j) = oldID - nnz(delIDs < oldID);
        end
    end
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Estimate crosstalk and group channels
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    SD = spatialdistance(A, d_row, d_col, N_row, N_col);
    SM = channelcrosstalkwiener(X(:,frames_ROI), maxlags, SD, d_max);
    T = hierarchicalclustering(SM,min_corr,'plotting',plotting);
    counts = localClusterCounts(T);
    fprintf('Identified %g clusters showing crosstalk.\n',...
        length(nonzeros(counts - 1)));
    fprintf('The largest cluster contains %g channels.\n',max(counts));
    % only clusters with more than one member are candidates for cICA:
    cluster_ids = find(counts>=2);
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Reduce cluster size
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Oversized clusters are cut down to their most strongly crosstalking
    % subcluster; the other members become singleton clusters because, with
    % single-linkage clustering, they need not have anything in common.
    if max(counts) > max_cluster_size
        fprintf('Only the %g channels showing the strongest\n',max_cluster_size);
        fprintf('crosstalk within each cluster will be unmixed.\n');
        for i = 1:length(cluster_ids)
            members = find(T == cluster_ids(i))';
            if length(members) > max_cluster_size
                cl_i = localStrongestSubcluster(members, ...
                    SM(members,members), min_corr, max_cluster_size);
                cl_rem = setdiff(members,cl_i);
                for k = 1:length(cl_rem)
                    T(cl_rem(k)) = max(T) + 1;
                end
            end
        end
    end
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Skip clusters that were already touched
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % One flag per cluster. (The former any(equality') collapsed to a
    % scalar when exactly one combination had been touched.)
    N_total = length(cluster_ids);
    is_touched = false(N_total,1);
    for i = 1:N_total
        c = find(T == cluster_ids(i));
        c = c(:);
        for j = 1:length(touched)
            if isequal(c,sort(touched(j).IDs(:)))
                is_touched(i) = true;
                break;
            end
        end
    end
    cluster_ids_touched = cluster_ids(is_touched);
    N_touched = length(cluster_ids_touched);
    % Split already touched clusters into singletons (the first member keeps
    % the old label) so that T only holds multi-member clusters for cICA.
    for i = 1:N_touched
        cl_i_touched = find(T == cluster_ids_touched(i));
        for k = 2:length(cl_i_touched)
            T(cl_i_touched(k)) = max(T) + 1;
        end
    end
    fprintf('%g of the %g total clusters were already touched\n',...
        N_touched,N_total);
    counts = localClusterCounts(T);
    cluster_ids = find(counts>=2);
    if plotting
        localPlotClusters(X, T, 'Clusterwise crosstalk');
    end
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    % Convolutive ICA
    %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
    if ~isempty(cluster_ids) && iteration_no <= max_iter && do_cICA
        fprintf('Iteration %g...\n',iteration_no);
        for i = 1:length(cluster_ids)
            fprintf('CICAAR for unmixing cluster %g,',i);
            fprintf(' containing channels:\n');
            % no semicolon on purpose: echoes the channel indices
            cl_i = find(T == cluster_ids(i))'
            % this is the most time consuming step:
            tic;
            [e,ll,bic,invA0,Atau,Hlambda] = cicaarpro(X(cl_i,frames_ROI),L,M); %#ok<ASGLU>
            toc;
            fprintf('BIC=%f, e=%f\n',bic,e);
            % Convolve the source filters [1; Hlambda] with the estimated
            % mixing filters [A0, Atau] to get L+M+1 effective taps. Sizes
            % come from A0, so an empty Atau (L=0) or Hlambda (M=0) works.
            A0 = pinv(invA0);
            [n_out,n_src] = size(A0);
            A0Atau = cat(3,A0,Atau);
            H0Hlambda = [ones(1,n_src); Hlambda];
            Aeff = zeros(n_out,n_src,L+M+1);
            for n = 1:n_out
                for m = 1:n_src
                    Aeff(n,m,:) = conv(squeeze(A0Atau(n,m,:)),...
                        H0Hlambda(:,m));
                end
            end
            % Mixing matrix update: convolution of the old mixing filters
            % with Aeff, truncated to the first L+M+1 taps.
            Aold = A(:,cl_i,:);
            A(:,cl_i,:) = 0;
            for tau = 1:L+M+1
                for k = 1:tau
                    A(:,cl_i,tau) = A(:,cl_i,tau) + ...
                        Aold(:,:,k)*Aeff(:,:,tau + 1 - k);
                end
            end
            X(cl_i,:) = cicaarprosep(invA0,Atau,X(cl_i,:));
            touched(length(touched)+1).IDs = cl_i;
        end
        if plotting
            localPlotClusters(X, T, 'Remaining Clusterwise crosstalk');
        end
    else
        S_cica = X;
        A_tau = A;
        % Report the actual stop reason (previously convergence in exactly
        % max_iter iterations was reported as hitting the limit).
        if ~do_cICA
            fprintf('Convolutive ICA skipped (do_cICA is false).\n');
        elseif isempty(cluster_ids)
            fprintf('Converged in %g iterations.\n',iteration_no -1);
        else
            fprintf('Maximum number of iterations reached (%g)\n',max_iter);
        end
        return;
    end
end
end

function counts = localClusterCounts(T)
% Number of members of each cluster label 1:max(T) in the label vector T.
counts = arrayfun(@(x) sum(T == x), 1:max(T));
end

function cl = localStrongestSubcluster(members, sm, min_corr, max_cluster_size)
% Raise the clustering threshold in steps of 0.01 above min_corr until the
% largest subcluster of MEMBERS (similarities SM) has at most
% max_cluster_size channels, and return that subcluster.
d_sm = 0;
cl = members;
while length(cl) > max_cluster_size
    d_sm = d_sm + 0.01;
    t = hierarchicalclustering(sm, min_corr + d_sm, 'plotting', 0);
    t_counts = localClusterCounts(t);
    % first label with the most members (same choice as max's index output)
    ind_max = find(t_counts == max(t_counts), 1);
    cl = members(t == ind_max);
end
end

function localPlotClusters(X, T, ttl)
% Plot the channels of every cluster label 1:max(T) in its own subplot.
% The title goes into the figure name because subplot would delete an
% axes title set beforehand.
figure('Name', ttl);
n_clusters = max(T);
pltsize = ceil(sqrt(n_clusters));
for i = 1:n_clusters
    subplot(pltsize,pltsize,i);
    plot(X(T == i,:)');
end
end