Tout d'abord on crée la fonction

In [13]:
import numpy as np
import random
import torch
torch.set_default_dtype(torch.float64)
from scipy.sparse.linalg import LinearOperator
from scipy.special import softmax

class Problem:
    def __init__(self, U, V, mu, nu, q, epsilon):
        self.U = U  # (nb_receivers, nb_states, nb_actions)
        self.V = V  # (nb_states, nb_actions)
        self.mu = mu  # (nb_states,)
        self.nu = nu  # (nb_receivers, nb_states)
        self.q = q  # (nb_receivers, nb_messages, nb_actions)
        self.epsilon = epsilon  # (nb_receivers,)
        self.nb_receivers, self.nb_states, self.nb_actions = U.shape
        self.nb_messages = q.shape[1]
        self.size = self.nb_receivers * self.nb_states * self.nb_messages
        self.shape = (self.nb_receivers, self.nb_states, self.nb_messages)
        self.check()
    
    def check(self):
        for receiver_idx in range(self.nb_receivers):
            self.debug_shape(self.U[receiver_idx], [self.nb_states, self.nb_actions])
            self.debug_shape(self.V, [self.nb_states, self.nb_actions])
            self.debug_shape(self.mu, [self.nb_states])
            self.debug_shape(self.nu[receiver_idx], [self.nb_states])
            self.debug_shape(self.q[receiver_idx], [self.nb_messages, self.nb_actions])

    def debug_shape(self, vect, target_shape):
        if list(vect.shape) != target_shape:
            print(f"Found vector of size {vect.shape}, expected {target_shape}")
            assert False
    
    def verbose(self, pi):
        print(f"We have (state, message, action) = ({self.nb_states}, {self.nb_messages}, {self.nb_actions})")
        self.debug_shape(pi, [self.nb_states, self.nb_messages])
        for receiver_idx in range(self.nb_receivers):
            theta = self.compute_theta(self.compute_g(pi, receiver_idx), receiver_idx)
            self.debug_shape(theta, [self.nb_messages, self.nb_actions])
            print(f"We have 1 = {pi.sum(axis=1)}")
            print(f"We have 1 = {theta.sum(axis=1)}")
            print("Theta", theta)
        print("Objective", self.objective(pi, range(self.nb_receivers)))
    
    def compute_g(self, pi, receiver_idx):
        denominator = (pi * self.nu[receiver_idx][:, None]).sum(axis=0)
        self.debug_shape(denominator, [self.nb_messages])
        g = (pi[:, :, None] * self.nu[receiver_idx][:, None, None] * self.U[receiver_idx][:, None, :]).sum(axis=0)
        self.debug_shape(g, [self.nb_messages, self.nb_actions])
        return g / denominator[:, None]

    def compute_theta(self, g, receiver_idx):
        max_g, _ = g.max(axis=1)
        exp = torch.exp((g - max_g[:, None]) / self.epsilon[receiver_idx])
        self.debug_shape(exp, [self.nb_messages, self.nb_actions])
        theta = self.q[receiver_idx] * exp
        denom = theta.sum(axis=1)
        return theta / denom[:, None]

    def objective(self, pi, receivers_batch):
        total_objective = 0
        for receiver_idx in receivers_batch:
            g = self.compute_g(pi, receiver_idx)
            theta = self.compute_theta(g, receiver_idx)
            total_objective += (theta[None, :, :] * pi[:, :, None] * self.mu[:, None, None] * self.V[:, None, :]).sum()
        return total_objective / len(receivers_batch)

    def value(self, x, receivers_batch=None):
        if receivers_batch is None:
            receivers_batch = range(self.nb_receivers)
        x = x.reshape(self.nb_states, self.nb_messages)
        x = torch.from_numpy(x).requires_grad_(True)
        f = self.objective(x, receivers_batch)
        f.backward()
        df = x.grad
        return -f.item(), -df.numpy()

    def project(self, x):
        x = torch.from_numpy(x).reshape(self.nb_states, self.nb_messages)
        x_projected = torch.zeros_like(x)
        for i in range(x.shape[0]):
            row = x[i, :]
            sorted_row, _ = torch.sort(row, descending=True)
            cumulative_sum = torch.cumsum(sorted_row, dim=0)
            rho = torch.nonzero(sorted_row * torch.arange(1, len(row) + 1) > (cumulative_sum - 1), as_tuple=False).max()
            theta = (cumulative_sum[rho] - 1) / (rho + 1)
            x_projected[i, :] = torch.clamp(row - theta, min=0)
        return x_projected.numpy()

    def project_tangent(self, x, d):
        d2 = d - d.mean(axis=1)[:, None]
        d2[(x == 0) * (d2 < 0)] = 0.
        d2[(x == 1) * (d2 > 0)] = 0.
        return d2


In [14]:


def dot(a, b):
    return (a * b).sum()

def ls_wolfe(x, function, step, descent, f, df, batch):
    """
    Wolfe line search for stochastic gradient descent.
    """
    step_min, step_max = 0., np.inf
    scal = dot(df, descent)
    if scal > 0:
        print('WARNING with scal', scal)
    step2 = step
    eps1, eps2 = 1.e-4, 0.9  # Wolfe condition parameters
    i = 0
    while i < 100:
        i += 1
        x2 = function.project(x + step2 * descent)
        f2, df2 = function.value(x2, batch)
        if dot(x2 - x, df) >= 0:
            print('We have a problem', dot(x2 - x, df), dot(descent, df))
        if f2 > f + eps1 * dot(x2 - x, df):  # step is too big, decrease it
            step_max = step2
            step2 = 0.5 * (step_min + step_max)
        else:
            if dot(df2, x2 - x) < eps2 * dot(df, x2 - x):  # step is too small, increase it
                step_min = step2
                step2 = min(0.5 * (step_min + step_max), 2 * step_min)
            else:
                return x2, f2, df2, step2
    print('We do not exit Wolfe')
    return x2, f2, df2, step2


def dot(a,b) :
    return (a*b).sum()

def ls_wolfe(x,function,step,descent,f,df,batch) :
    step_min,step_max=0.,np.inf
    scal=dot(df,descent)
    if scal > 0 :
        print('WARNING with scal',scal)
    step2=step
    eps1,eps2=1.e-4,0.9
    i=0
    while i<100 :
        i=i+1
        x2=function.project(x+step2*descent)
        f2,df2=function.value(x2,batch)
        if dot(x2-x,df) >=0 :
            print('We have a problem',dot(x2-x,df),dot(descent,df))
        if f2>f+eps1*dot(x2-x,df) : # step is too big, decrease it
            step_max=step2
            step2=0.5*(step_min+step_max)
        else :
            if dot(df2,x2-x) < eps2*dot(df,x2-x) : # step is too small, increase it
                step_min=step2
                step2=min(0.5*(step_min+step_max),2*step_min)
            else :
                return x2,f2,df2,step2
    print('We do not exit Wolfe')
    print(f2>f+eps1*step2*scal,dot(df2,descent) < eps2*scal)
    return x2,f2,df2,step2




def optimize(function,itermax = 5000,tol=1.e-6,batch_size = 100,verbose=True):
    np.random.seed(42)
    receivers = list(range(function.nb_receivers))
    x = np.random.randn(function.nb_states, function.nb_messages)
    x=function.project(x)
    np.random.seed(None)
    list_costs=[]
    list_grads=[]
    nbiter = 0
    batch = np.random.choice(receivers, size=batch_size, replace=False)
    f,df=function.value(x,batch)
    df_tangent=function.project_tangent(x,-df)
    norm_grad=np.linalg.norm(df_tangent)
    err=2*tol
    if verbose :
        print('iter={:4d} f={:1.3e} df={:1.3e}'.format(nbiter,f,err))
    list_costs.append(f)
    list_grads.append(norm_grad)
    while (err > tol) and (nbiter < itermax):
        descent=-df
        x_old=np.copy(x)
        x,f,df,step = ls_wolfe(x, function,1., descent,f,df,batch)
        batch = np.random.choice(receivers, size=batch_size, replace=False)
        norm_grad = np.linalg.norm(function.project_tangent(x,-df))
        list_costs.append(f)
        list_grads.append(norm_grad)
        err=norm_grad
        nbiter+=1
        if verbose :
            print('iter={:4d} f={:1.3e} err={:1.3e} s={:1.3e}'.format(nbiter,f,err,step))
        if (err <= tol):
            if verbose : print("Success !!! Algorithm converged !!!")
            return x,list_costs,list_grads
    if verbose : print("FAILED to converge")



On crée 1000 variations d'un receiver de base

In [15]:
alpha = 0.7
beta = 0.3

# Initialisation des matrices de base
nb_receivers = 1000
nb_states = 2
nb_actions = 2
nb_messages = 2

U_base = torch.tensor([[1.0, 0.0], [0.0, 1.0]], dtype=torch.float64)  # Matrice U de base
nu_base = torch.tensor([alpha, 1 - alpha], dtype=torch.float64)  # Distribution de probabilité de base pour nu
q_base = torch.tensor([beta, 1 - beta], dtype=torch.float64)  # Une seule ligne pour q de base
epsilon_base = 0.01  # Valeur de base pour epsilon

mu = torch.tensor([alpha, 1 - alpha])

# Fonction pour ajouter du bruit à une matrice tout en conservant les propriétés des distributions
def add_noise_to_distribution(base_vector, noise_level):
    noisy_vector = base_vector + noise_level * torch.randn_like(base_vector)
    noisy_vector = torch.clamp(noisy_vector, min=1e-8)  # Évite les valeurs négatives
    return noisy_vector / noisy_vector.sum()  # Normalisation pour conserver les distributions

# Génération des variations
noise_level_U = 0.01
noise_level_nu = 0.01
noise_level_q = 0.01
noise_level_epsilon = 0.01

U = torch.stack([U_base + noise_level_U * torch.randn_like(U_base) for _ in range(nb_receivers)])
nu = torch.stack([add_noise_to_distribution(nu_base, noise_level_nu) for _ in range(nb_receivers)])
q = torch.stack([add_noise_to_distribution(q_base, noise_level_q).expand(nb_messages, -1) for _ in range(nb_receivers)])
epsilon = torch.tensor([epsilon_base + noise_level_epsilon * np.random.randn() for _ in range(nb_receivers)])
epsilon = torch.clamp(epsilon, min=1e-8)  # Assurez-vous que epsilon reste positif

# Vérifications
print("Exemple de U :", U[0])
print("Exemple de nu (somme = 1) :", nu[0], "Somme =", nu[0].sum())
print("Exemple de q (lignes identiques, somme = 1) :", q[0], "Somme des lignes =", q[0].sum(dim=1))
print("Exemple de epsilon :", epsilon[:10])

V = torch.tensor([[0.0, 1.0],
                  [0.0, 1.0]])



Exemple de U : tensor([[ 1.0084, -0.0143],
        [ 0.0013,  1.0145]])
Exemple de nu (somme = 1) : tensor([0.6897, 0.3103]) Somme = tensor(1.)
Exemple de q (lignes identiques, somme = 1) : tensor([[0.3015, 0.6985],
        [0.3015, 0.6985]]) Somme des lignes = tensor([1., 1.])
Exemple de epsilon : tensor([1.0000e-08, 6.1013e-03, 1.5562e-02, 3.6889e-03, 1.3474e-02, 2.4862e-02,
        2.4202e-03, 2.3388e-02, 3.0227e-03, 4.6717e-03])


On optimise le "vrai" max (moyenne sur tous les Receivers) pour comparer ensuite avec l'optimisation sur des batchs

In [22]:
# Création de l'objet Problem avec les données générées
P = Problem(U, V, mu, nu, q, epsilon)

x,costs,grad=optimize(P,tol=1.e-4,verbose=True,batch_size = 1000)

iter=   0 f=-4.091e-01 df=2.000e-04
iter=   1 f=-5.482e-01 err=4.075e-01 s=5.000e-01
iter=   2 f=-5.512e-01 err=7.591e-01 s=1.250e-01
iter=   3 f=-5.568e-01 err=1.687e-01 s=3.125e-02
iter=   4 f=-5.575e-01 err=7.286e-02 s=6.250e-02
iter=   5 f=-5.575e-01 err=4.994e-02 s=6.250e-02
iter=   6 f=-5.575e-01 err=3.042e-02 s=6.250e-02
iter=   7 f=-5.576e-01 err=2.033e-02 s=6.250e-02
iter=   8 f=-5.576e-01 err=1.292e-02 s=6.250e-02
iter=   9 f=-5.576e-01 err=8.510e-03 s=6.250e-02
iter=  10 f=-5.576e-01 err=5.485e-03 s=6.250e-02
iter=  11 f=-5.576e-01 err=3.588e-03 s=6.250e-02
iter=  12 f=-5.576e-01 err=2.325e-03 s=6.250e-02
iter=  13 f=-5.576e-01 err=1.516e-03 s=6.250e-02
iter=  14 f=-5.576e-01 err=9.845e-04 s=6.250e-02
iter=  15 f=-5.576e-01 err=6.410e-04 s=6.250e-02
iter=  16 f=-5.576e-01 err=4.167e-04 s=6.250e-02
iter=  17 f=-5.576e-01 err=2.712e-04 s=6.250e-02
iter=  18 f=-5.576e-01 err=1.763e-04 s=6.250e-02
iter=  19 f=-5.576e-01 err=1.147e-04 s=6.250e-02
iter=  20 f=-5.576e-01 err=7.462e

In [23]:
x

array([[0.61758566, 0.38241434],
       [0.        , 1.        ]])

In [24]:
P.value(x)

(-0.557565229324088,
 array([[-3.37761417e-13,  1.05524736e-04],
        [-7.15813323e-12, -5.57605583e-01]]))

In [33]:
P = Problem(U, V, mu, nu, q, epsilon)

x,costs,grad=optimize(P,tol=1.e-1,verbose=True,batch_size = 100)

iter=   0 f=-4.091e-01 df=2.000e-01
iter=   1 f=-5.374e-01 err=1.685e+00 s=6.250e-01
iter=   2 f=-5.517e-01 err=3.517e-01 s=3.125e-02
iter=   3 f=-5.546e-01 err=6.108e-01 s=1.250e-01
iter=   4 f=-5.547e-01 err=1.567e-01 s=3.125e-02
iter=   5 f=-5.564e-01 err=2.060e-01 s=6.250e-02
iter=   6 f=-5.580e-01 err=4.068e-02 s=6.250e-02
Success !!! Algorithm converged !!!


In [34]:
x

array([[0.61619069, 0.38380931],
       [0.        , 1.        ]])

In [35]:
P.value(x)

(-0.5575115139988321,
 array([[-3.37761417e-13,  7.85272586e-02],
        [-7.15813323e-12, -5.87651007e-01]]))