**ACHTUNG** This tutorial is graded. Please upload your work on Moodle before the 15th of February.


In [1]:
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

pio.templates.default = "plotly_white"


# Exercise 1: soft-thresholding


Define a function to compute the soft-thresholding operator, that is, the proximal operator of $\tau\|\cdot\|_1$ (the $\ell_1$-norm scaled by $\tau > 0$).


In [2]:
def soft_thresholding(x, tau: float):
  """
  Compute the soft-thresholding operator, i.e. the proximal operator of tau * ||.||_1.

  Input: x is either a scalar, a vector or an image (anything numpy can turn into an array)
         tau is a strictly positive scalar
  The function applies, entry-wise,
      sign(x) * max(|x| - tau, 0),
  so entries with |x| <= tau become zero and the others are shrunk towards zero by tau.
  Output: a scalar, a vector or an image (with the same dimensions as the input x);
          a scalar input gives a (numpy) scalar output.
  """
  x = np.asarray(x)
  # x - clip(x, -tau, tau) equals sign(x) * max(|x| - tau, 0) entry-wise. It works for
  # scalars and arrays alike (a Python `if x < -tau` does not work on arrays) and gives
  # exactly 0 (never -0.0) inside the dead zone [-tau, tau].
  return x - np.clip(x, -tau, tau)


Plot the soft-thresholding function in $[-2,2]$ for $\tau=0.5$ using the function defined above.


In [3]:
tau = 0.5  # threshold parameter, reused by all the plots below
x_range = np.linspace(-2, 2, 500)  # evaluation grid on [-2, 2], reused by the plots below

fig = go.Figure()
fig.add_trace(
  go.Scatter(
    x=x_range,
    y=[soft_thresholding(x, tau) for x in x_range],
    mode="lines",
    name="soft thresholding",
  )
)
fig.update_layout(
  title=f"Soft thresholding (tau={tau})",
  xaxis_title="x",
  yaxis_title="soft thresholding(x)",
)
fig.show()


# Exercise 2: hard-thresholding


Define a function to compute the hard-thresholding operator, that is, the proximal operator of $\tau\|\cdot\|_0$, where the $\ell_0$ "norm" $\|x\|_0$ counts the non-zero entries of $x$ (it is not a true norm).


In [4]:
def hard_thresholding(x, tau):
  """
  Compute the hard-thresholding operator, i.e. the proximal operator of tau * ||.||_0.

  Input: x is either a scalar, a vector or an image (anything numpy can turn into an array)
         tau is a strictly positive scalar
  The function computes the hard thresholding of x with parameter tau: entries with
  x**2 < 2 * tau (i.e. |x| < sqrt(2 * tau)) are set to zero, the others are kept unchanged.
  Output: a scalar, a vector or an image (with the same dimensions as the input x);
          a scalar input gives a (numpy) scalar output.
  """
  x = np.asarray(x)
  # Keeping an entry costs tau in the l0 term, zeroing it costs x**2 / 2 in the quadratic
  # term, hence the threshold x**2 < 2 * tau. At equality both choices are minimizers;
  # x is kept, as in the original scalar implementation.
  out = np.where(x**2 < 2 * tau, 0, x)
  # np.where always returns an array; unwrap 0-d results so scalar in -> scalar out.
  return out[()] if out.ndim == 0 else out


Plot the hard-thresholding function in $[-2,2]$ for $\tau=0.5$ using the function defined above.


In [5]:
# Note: mode="lines" joins consecutive samples, so the jumps of the operator at
# |x| = sqrt(2 * tau) show up as near-vertical segments.
fig = go.Figure()
fig.add_trace(
  go.Scatter(
    x=x_range,
    y=[hard_thresholding(x, tau) for x in x_range],
    mode="lines",
    name="hard thresholding",
  )
)
fig.update_layout(
  title=f"Hard thresholding (tau={tau})",
  xaxis_title="x",
  yaxis_title="hard thresholding(x)",
)
fig.show()


# Exercise 3: non-negativity constraints


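Plot the proximal operator of $\tau \|\cdot\|_1+{\chi}_{\ge 0}(\cdot)$, where ${\chi}_{\ge 0}$ is the indicator function of the non-negative numbers ($0$ on $[0, +\infty)$ and $+\infty$ elsewhere), in $[-2,2]$ for $\tau=0.5$ using the functions defined above.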


In [6]:
def non_negative_constraint(x, tau=0.0):
  """
  Project x onto the non-negative numbers, optionally after l1 soft-thresholding.

  Input: x is either a scalar, a vector or an image (anything numpy can turn into an array)
         tau is a non-negative scalar (default 0)
  The function computes, entry-wise, max(x - tau, 0):
    - with tau = 0 (default) this is the projection onto {x >= 0}, i.e. the proximal
      operator of the indicator chi_{>=0};
    - with tau > 0 it is the proximal operator of tau * ||.||_1 + chi_{>=0}.
  Output: a scalar, a vector or an image (with the same dimensions as the input x)
  """
  # np.maximum works element-wise; the builtin max(x, 0) raises on numpy arrays.
  return np.maximum(np.asarray(x) - tau, 0)


# The proximal operator of tau * ||.||_1 + chi_{>=0} is max(x - tau, 0): soft-threshold
# with parameter tau, then project the result onto the non-negative numbers.
# (Plotting max(x, 0) alone would ignore the l1 term.)
fig = go.Figure()
fig.add_trace(
  go.Scatter(
    x=x_range,
    y=[non_negative_constraint(soft_thresholding(x, tau)) for x in x_range],
    mode="lines",
    name="prox of tau*l1 + non-negativity",
  )
)
fig.update_layout(
  title=f"Proximal operator of tau*||.||_1 + non-negativity constraint (tau={tau})",
  xaxis_title="x",
  yaxis_title="prox(x)",
)
fig.show()


# Exercise 4: elastic net


Compute the proximal operator of the elastic net functional
$f(x) = \tau \|x\|_1 + \frac{\lambda}{2} \|x\|_2^2$.


In [7]:
def prox_elastic_net(x, tau, lambd):
  """
  Compute the proximal operator of the elastic net penalty
      f(y) = tau * ||y||_1 + lambd / 2 * ||y||_2^2.

  Input: x is either a scalar, a vector or an image (anything numpy can turn into an array)
         tau is a strictly positive scalar (weight of the l1 term)
         lambd is a strictly positive scalar (weight of the squared l2 term)
  The optimality condition 0 in y - x + tau * subdiff|y| + lambd * y gives, entry-wise,
      prox_f(x) = soft_thresholding(x, tau) / (1 + lambd).
  Output: a scalar, a vector or an image (with the same dimensions as the input x)
  """
  x = np.asarray(x)
  # Soft-thresholding with threshold tau, written inline so that this function works on
  # arrays as well as scalars: x - clip(x, -tau, tau) == sign(x) * max(|x| - tau, 0).
  shrunk = x - np.clip(x, -tau, tau)
  # The quadratic term only rescales the shrunk value by 1 / (1 + lambd).
  return shrunk / (1 + lambd)


Plot the proximal operator of the elastic net in $[-2,2]$ for $\tau=0.5$ (and $\lambda=0.5$) using the function defined above.


In [8]:
lambd = 0.5  # lambda parameter passed to prox_elastic_net for this plot

fig = go.Figure()
fig.add_trace(
  go.Scatter(
    x=x_range,
    y=[prox_elastic_net(x, tau, lambd) for x in x_range],
    mode="lines",
    name="prox elastic net",
  )
)
fig.update_layout(
  title=f"Proximal operator of the elastic net penalty (tau={tau}, lambda={lambd})",
  xaxis_title="x",
  yaxis_title="prox_elastic_net(x, tau, lambda)",
)
fig.show()
