Skip to content

Commit

Permalink
tgv/ictv + multiple penalties + 3D
Browse files Browse the repository at this point in the history
  • Loading branch information
uecker committed Jun 1, 2022
1 parent 66d1f4c commit 5393516
Show file tree
Hide file tree
Showing 14 changed files with 373 additions and 98 deletions.
3 changes: 2 additions & 1 deletion Makefile
Expand Up @@ -205,7 +205,7 @@ ISMRM_BASE ?= /usr/local/ismrmrd/
TBASE=show slice crop resize join transpose squeeze flatten zeros ones flip circshift extract repmat bitmask reshape version delta copy casorati vec poly index multicfl
TFLP=scale invert conj fmac saxpy sdot spow cpyphs creal carg normalize cdf97 pattern nrmse mip avg cabs zexp
TNUM=fft fftmod fftshift noise bench threshold conv rss filter mandelbrot wavelet window var std fftrot roistat pol2mask conway morphop
TRECO=pics pocsense sqpics itsense nlinv moba nufft rof tgv sake wave lrmatrix estdims estshift estdelay wavepsf wshfl rtnlinv mobafit
TRECO=pics pocsense sqpics itsense nlinv moba nufft rof tgv ictv sake wave lrmatrix estdims estshift estdelay wavepsf wshfl rtnlinv mobafit
TCALIB=ecalib ecaltwo caldir walsh cc ccapply calmat svd estvar whiten rmfreq ssa bin
TMRI=homodyne poisson twixread fakeksp looklocker upat fovshift
TSIM=phantom traj signal epg sim
Expand Down Expand Up @@ -240,6 +240,7 @@ MODULES_estvar = -lcalib
MODULES_nufft = -lnoncart -liter -llinops
MODULES_rof = -liter -llinops
MODULES_tgv = -liter -llinops
MODULES_ictv = -liter -llinops
MODULES_bench = -lwavelet -llinops
MODULES_phantom = -lsimu -lgeom
MODULES_bart = -lbox -lgrecon -lsense -lnoir -liter -llinops -lwavelet -llowrank -lnoncart -lcalib -lsimu -lsake -ldfwavelet -lnlops -lnetworks -lnn -liter -lmoba -lgeom -lnn -lnlops
Expand Down
1 change: 1 addition & 0 deletions src/grecon/italgo.c
Expand Up @@ -46,6 +46,7 @@ enum algo_t italgo_choose(int nr_penalties, const struct reg_s regs[nr_penalties

case TV:
case TGV:
case ICTV:
case IMAGL1:
case IMAGL2:

Expand Down
142 changes: 89 additions & 53 deletions src/grecon/optreg.c
Expand Up @@ -65,6 +65,7 @@ void help_reg(void)
"-R T:A:B:C\ttotal variation\n"
"-R T:7:0:.01\t3D isotropic total variation with 0.01 regularization.\n"
"-R G:A:B:C\ttotal generalized variation\n"
"-R C:A:B:C\tinfimal convolution TV\n"
"-R L:7:7:.02\tLocally low rank with spatial decimation and 0.02 regularization.\n"
"-R M:7:7:.03\tMulti-scale low rank with spatial decimation and 0.03 regularization.\n"
"-R TF:{graph_path}:lambda\tTensorFlow loss\n"
Expand Down Expand Up @@ -143,6 +144,12 @@ bool opt_reg(void* ptr, char c, const char* optarg)
int ret = sscanf(optarg, "%*[^:]:%d:%d:%f", &regs[r].xflags, &regs[r].jflags, &regs[r].lambda);
assert(3 == ret);
}
else if (strcmp(rt, "C") == 0) {

regs[r].xform = ICTV;
int ret = sscanf(optarg, "%*[^:]:%d:%d:%f", &regs[r].xflags, &regs[r].jflags, &regs[r].lambda);
assert(3 == ret);
}
else if (strcmp(rt, "P") == 0) {

regs[r].xform = LAPLACE;
Expand Down Expand Up @@ -247,21 +254,22 @@ bool opt_reg_init(struct opt_reg_s* ropts)
ropts->r = 0;
ropts->lambda = -1;
ropts->svars = 0;
ropts->sr = 0;

return false;
}


void opt_bpursuit_configure(struct opt_reg_s* ropts, const struct operator_p_s* prox_ops[NUM_REGS], const struct linop_s* trafos[NUM_REGS], const struct linop_s* model_op, const complex float* data, const float eps)
{
int nr_penalties = ropts->r;
int nr_penalties = ropts->r + ropts->sr;
assert(NUM_REGS > nr_penalties);

const struct iovec_s* iov = linop_codomain(model_op);
prox_ops[nr_penalties] = prox_l2ball_create(iov->N, iov->dims, eps, data);
trafos[nr_penalties] = linop_clone(model_op);

ropts->r++;
ropts->sr++;
}

void opt_reg_configure(int N, const long img_dims[N], struct opt_reg_s* ropts, const struct operator_p_s* prox_ops[NUM_REGS], const struct linop_s* trafos[NUM_REGS], unsigned int llr_blk, unsigned int shift_mode, bool use_gpu)
Expand All @@ -287,19 +295,22 @@ void opt_reg_configure(int N, const long img_dims[N], struct opt_reg_s* ropts, c
ropts->r = 1;
}

int nr_penalties = ropts->r;
// compute needed supporting variables

for (int nr = 0; nr < nr_penalties; nr++) {
for (int nr = 0; nr < ropts->r; nr++) {

switch (regs[nr].xform) {

case TGV:

if (0 != ropts->svars)
error("only one TGV term allowed.");
ropts->svars += bitcount(regs[nr].xflags);
ropts->sr++;
break;

case ICTV:

ropts->svars += 2;
ropts->r++;
ropts->svars += 1;
ropts->sr++;
break;

default: ;
Expand All @@ -311,17 +322,20 @@ void opt_reg_configure(int N, const long img_dims[N], struct opt_reg_s* ropts, c

long ext_dims[DIMS];
md_copy_dims(DIMS, ext_dims, img_dims);
ext_dims[BATCH_DIM] += ropts->svars;

ext_dims[BATCH_DIM] += ropts->svars;

int ext_shift = 1;
int nr_penalties = ropts->r;

long blkdims[MAX_LEV][DIMS];
int levels;


for (int nr = 0, nr2 = 0; nr < nr_penalties; nr++, nr2++) {
for (int nr = 0; nr < ropts->r; nr++) {

// fix up regularization parameter

if (-1. == regs[nr].lambda)
regs[nr].lambda = lambda;

Expand All @@ -331,7 +345,7 @@ void opt_reg_configure(int N, const long img_dims[N], struct opt_reg_s* ropts, c
long thresh_dims[N];
long img_strs[N];

assert(nr2 < NUM_REGS);
assert(nr_penalties < NUM_REGS);

switch (regs[nr].xform) {

Expand All @@ -348,8 +362,8 @@ void opt_reg_configure(int N, const long img_dims[N], struct opt_reg_s* ropts, c
}
}

trafos[nr2] = linop_identity_create(DIMS, img_dims);
prox_ops[nr2] = prox_wavelet_thresh_create(DIMS, img_dims, wflags, regs[nr].jflags, minsize, regs[nr].lambda, randshift);
trafos[nr] = linop_identity_create(DIMS, img_dims);
prox_ops[nr] = prox_wavelet_thresh_create(DIMS, img_dims, wflags, regs[nr].jflags, minsize, regs[nr].lambda, randshift);
break;

case NIHTWAV:
Expand All @@ -373,10 +387,10 @@ void opt_reg_configure(int N, const long img_dims[N], struct opt_reg_s* ropts, c
}
}

trafos[nr2] = linop_wavelet_create(N, wflags, img_dims, img_strs, minsize, randshift);
trafos[nr] = linop_wavelet_create(N, wflags, img_dims, img_strs, minsize, randshift);

long wav_dims[DIMS];
md_copy_dims(DIMS, wav_dims, linop_codomain(trafos[nr2])->dims);
md_copy_dims(DIMS, wav_dims, linop_codomain(trafos[nr])->dims);

unsigned int K = (md_calc_size(wxdim, wav_dims) / 100) * regs[nr].k;

Expand All @@ -388,7 +402,7 @@ void opt_reg_configure(int N, const long img_dims[N], struct opt_reg_s* ropts, c

debug_printf(DP_DEBUG3, "]\n");

prox_ops[nr2] = prox_niht_thresh_create(N, wav_dims, K, regs[nr].jflags);
prox_ops[nr] = prox_niht_thresh_create(N, wav_dims, K, regs[nr].jflags);
break;

case NIHTIM:
Expand All @@ -404,34 +418,50 @@ void opt_reg_configure(int N, const long img_dims[N], struct opt_reg_s* ropts, c

debug_printf(DP_INFO, "k = %d%%, actual K = %d\n", regs[nr].k, K);

trafos[nr2] = linop_identity_create(DIMS, img_dims);
prox_ops[nr2] = prox_niht_thresh_create(N, img_dims, K, regs[nr].jflags);
trafos[nr] = linop_identity_create(DIMS, img_dims);
prox_ops[nr] = prox_niht_thresh_create(N, img_dims, K, regs[nr].jflags);

break;

case TV:

debug_printf(DP_INFO, "TV regularization: %f\n", regs[nr].lambda);

trafos[nr2] = linop_grad_create(DIMS, img_dims, DIMS, regs[nr].xflags);
prox_ops[nr2] = prox_thresh_create(DIMS + 1,
linop_codomain(trafos[nr2])->dims,
trafos[nr] = linop_grad_create(DIMS, img_dims, DIMS, regs[nr].xflags);
prox_ops[nr] = prox_thresh_create(DIMS + 1,
linop_codomain(trafos[nr])->dims,
regs[nr].lambda, regs[nr].jflags | MD_BIT(DIMS));
break;

case TGV:

debug_printf(DP_INFO, "TGV regularization: %f\n", regs[nr].lambda);

struct reg2 reg2 = tgvreg(regs[nr].xflags, regs[nr].jflags /*| MD_BIT(DIMS - 1)*/ | MD_BIT(DIMS), regs[nr].lambda, DIMS, img_dims);
struct reg2 reg2 = tgv_reg(regs[nr].xflags, regs[nr].jflags /*| MD_BIT(DIMS - 1)*/ | MD_BIT(DIMS), regs[nr].lambda, DIMS, ext_dims, &ext_shift);

trafos[nr2 + 0] = reg2.linop[0];
trafos[nr2 + 1] = reg2.linop[1];
trafos[nr] = reg2.linop[0];
prox_ops[nr] = reg2.prox[0];

prox_ops[nr2 + 0] = reg2.prox[0];
prox_ops[nr2 + 1] = reg2.prox[1];
trafos[nr_penalties] = reg2.linop[1];
prox_ops[nr_penalties] = reg2.prox[1];

nr2++;
nr_penalties++;

break;

case ICTV:

debug_printf(DP_INFO, "ICTV regularization: %f\n", regs[nr].lambda);

reg2 = ictv_reg(regs[nr].xflags & FFT_FLAGS, regs[nr].xflags & ~FFT_FLAGS, regs[nr].jflags | MD_BIT(DIMS), regs[nr].lambda, DIMS, ext_dims, &ext_shift);

trafos[nr] = reg2.linop[0];
prox_ops[nr] = reg2.prox[0];

trafos[nr_penalties] = reg2.linop[1];
prox_ops[nr_penalties] = reg2.prox[1];

nr_penalties++;

break;

Expand All @@ -453,9 +483,10 @@ void opt_reg_configure(int N, const long img_dims[N], struct opt_reg_s* ropts, c

assert(9 == md_calc_size(DIMS, krn_dims));

trafos[nr2] = linop_conv_create(DIMS, regs[nr].xflags, CONV_TRUNCATED, CONV_SYMMETRIC, img_dims, img_dims, krn_dims, krn);
prox_ops[nr2] = prox_thresh_create(DIMS,
linop_codomain(trafos[nr2])->dims,
trafos[nr] = linop_conv_create(DIMS, regs[nr].xflags, CONV_TRUNCATED, CONV_SYMMETRIC, img_dims, img_dims, krn_dims, krn);

prox_ops[nr] = prox_thresh_create(DIMS,
linop_codomain(trafos[nr])->dims,
regs[nr].lambda, regs[nr].jflags);
break;

Expand Down Expand Up @@ -483,8 +514,8 @@ void opt_reg_configure(int N, const long img_dims[N], struct opt_reg_s* ropts, c

int remove_mean = 0;

trafos[nr2] = linop_identity_create(DIMS, img_dims);
prox_ops[nr2] = lrthresh_create(img_dims, randshift, regs[nr].xflags, (const long (*)[DIMS])blkdims, regs[nr].lambda, false, remove_mean, overlapping_blocks);
trafos[nr] = linop_identity_create(DIMS, img_dims);
prox_ops[nr] = lrthresh_create(img_dims, randshift, regs[nr].xflags, (const long (*)[DIMS])blkdims, regs[nr].lambda, false, remove_mean, overlapping_blocks);
break;

case MLR:
Expand Down Expand Up @@ -520,86 +551,91 @@ void opt_reg_configure(int N, const long img_dims[N], struct opt_reg_s* ropts, c

debug_printf(DP_INFO, "l1 regularization of imaginary part: %f\n", regs[nr].lambda);

trafos[nr2] = linop_rdiag_create(DIMS, img_dims, 0, &(complex float){ 1.i });
prox_ops[nr2] = prox_thresh_create(DIMS, img_dims, regs[nr].lambda, regs[nr].jflags);
trafos[nr] = linop_rdiag_create(DIMS, img_dims, 0, &(complex float){ 1.i });
prox_ops[nr] = prox_thresh_create(DIMS, img_dims, regs[nr].lambda, regs[nr].jflags);
break;

case IMAGL2:

debug_printf(DP_INFO, "l2 regularization of imaginary part: %f\n", regs[nr].lambda);

trafos[nr2] = linop_rdiag_create(DIMS, img_dims, 0, &(complex float){ 1.i });
prox_ops[nr2] = prox_leastsquares_create(DIMS, img_dims, regs[nr].lambda, NULL);
trafos[nr] = linop_rdiag_create(DIMS, img_dims, 0, &(complex float){ 1.i });
prox_ops[nr] = prox_leastsquares_create(DIMS, img_dims, regs[nr].lambda, NULL);
break;

case L1IMG:

debug_printf(DP_INFO, "l1 regularization: %f\n", regs[nr].lambda);

trafos[nr2] = linop_identity_create(DIMS, img_dims);
prox_ops[nr2] = prox_thresh_create(DIMS, img_dims, regs[nr].lambda, regs[nr].jflags);
trafos[nr] = linop_identity_create(DIMS, img_dims);
prox_ops[nr] = prox_thresh_create(DIMS, img_dims, regs[nr].lambda, regs[nr].jflags);
break;

case POS:

debug_printf(DP_INFO, "non-negative constraint\n");

trafos[nr2] = linop_identity_create(DIMS, img_dims);
prox_ops[nr2] = prox_nonneg_create(DIMS, img_dims);
trafos[nr] = linop_identity_create(DIMS, img_dims);
prox_ops[nr] = prox_nonneg_create(DIMS, img_dims);
break;

case L2IMG:

debug_printf(DP_INFO, "l2 regularization: %f\n", regs[nr].lambda);

trafos[nr2] = linop_identity_create(DIMS, img_dims);
prox_ops[nr2] = prox_leastsquares_create(DIMS, img_dims, regs[nr].lambda, NULL);
trafos[nr] = linop_identity_create(DIMS, img_dims);
prox_ops[nr] = prox_leastsquares_create(DIMS, img_dims, regs[nr].lambda, NULL);
break;

case FTL1:

debug_printf(DP_INFO, "l1 regularization of Fourier transform: %f\n", regs[nr].lambda);

trafos[nr2] = linop_fft_create(DIMS, img_dims, regs[nr].xflags);
prox_ops[nr2] = prox_thresh_create(DIMS, img_dims, regs[nr].lambda, regs[nr].jflags);
trafos[nr] = linop_fft_create(DIMS, img_dims, regs[nr].xflags);
prox_ops[nr] = prox_thresh_create(DIMS, img_dims, regs[nr].lambda, regs[nr].jflags);
break;

case TENFL:

debug_printf(DP_INFO, "TensorFlow Loss: %f %s\n", regs[nr].lambda, regs[nr].graph_file);

trafos[nr2] = linop_identity_create(DIMS, img_dims);
trafos[nr] = linop_identity_create(DIMS, img_dims);

const struct nlop_s* tf_ops = nlop_tf_create(1, 1, regs[nr].graph_file, true);

// with one step, this only does one gradient descent step

auto prox_op = prox_nlgrad_create(tf_ops, 1, 1., regs[nr].lambda);

prox_ops[nr2] = op_p_auto_normalize(prox_op, ~0LU, NORM_MAX);
prox_ops[nr] = op_p_auto_normalize(prox_op, ~0LU, NORM_MAX);

operator_p_free(prox_op);

break;
}

if ((0 < ropts->svars) && (TGV != regs[nr].xform)) {
// if there are supporting variables, extract the main variables by default

if ( (0 < ropts->svars)
&& !( (TGV == regs[nr].xform)
|| (ICTV == regs[nr].xform))) {

long pos[DIMS] = { 0 };

trafos[nr2] = linop_chain_FF(
linop_extract_create(DIMS, pos, linop_domain(trafos[nr2])->dims, ext_dims),
trafos[nr2]);
trafos[nr] = linop_chain_FF(
linop_extract_create(DIMS, pos, linop_domain(trafos[nr])->dims, ext_dims),
trafos[nr]);
}

nr2++;
}

assert(ext_shift == 1 + ropts->svars);
assert(nr_penalties == ropts->r + ropts->sr);
}


void opt_reg_free(struct opt_reg_s* ropts, const struct operator_p_s* prox_ops[NUM_REGS], const struct linop_s* trafos[NUM_REGS])
{
int nr_penalties = ropts->r;
int nr_penalties = ropts->r + ropts->sr;

for (int nr = 0; nr < nr_penalties; nr++) {

Expand Down
3 changes: 2 additions & 1 deletion src/grecon/optreg.h
Expand Up @@ -18,7 +18,7 @@ struct linop_s;

struct reg_s {

enum { L1WAV, NIHTWAV, NIHTIM, TV, LLR, MLR, IMAGL1, IMAGL2, L1IMG, L2IMG, FTL1, LAPLACE, POS, TENFL, TGV } xform;
enum { L1WAV, NIHTWAV, NIHTIM, TV, LLR, MLR, IMAGL1, IMAGL2, L1IMG, L2IMG, FTL1, LAPLACE, POS, TENFL, TGV, ICTV } xform;

unsigned int xflags;
unsigned int jflags;
Expand All @@ -35,6 +35,7 @@ struct opt_reg_s {
struct reg_s regs[NUM_REGS];
int r;
int svars;
int sr;
};


Expand Down

0 comments on commit 5393516

Please sign in to comment.