In [None]:
import numpy as np
import matplotlib.pyplot as plt


def sqrt_tp(probs, etas, eps=1e-3):
    """Mean over rows of ``sqrt(row-mean of probs * etas + eps)``.

    Computes ``mean_r sqrt(mean_c (probs * etas)[..., r, c] + eps)``: the
    elementwise product is averaged over the last axis, offset by ``eps``,
    square-rooted, and then averaged over the (new) last axis.

    Parameters
    ----------
    probs : np.ndarray
        Array whose trailing two axes are reduced (e.g. shape ``(..., R, C)``);
        broadcast against ``etas``.
    etas : np.ndarray
        Array broadcastable against ``probs`` (e.g. shape ``(R, C)``).
    eps : float, optional
        Offset added inside the square root. Defaults to ``1e-3``, the value
        previously hard-coded. NOTE(review): presumably keeps the sqrt away
        from 0 — confirm.

    Returns
    -------
    np.ndarray or np.float64
        The broadcast product with its trailing two axes reduced.
    """
    # Average the weighted mass over the last axis, one value per row.
    tp = np.mean(probs * etas, axis=-1)
    sqrt_tp_values = np.sqrt(tp + eps)
    return np.mean(sqrt_tp_values, axis=-1)


def recall(probs, etas):
    """Sum over rows of the square root of each row's summed ``probs * etas``.

    Computes ``sum_r sqrt(sum_c (probs * etas)[..., r, c])``, reducing the
    trailing two axes of the broadcast product.

    NOTE(review): despite the name, this is not a ratio (no normalisation by
    a positive count) — confirm the intended definition.
    """
    row_totals = np.sum(probs * etas, axis=-1)
    return np.sum(np.sqrt(row_totals), axis=-1)


def score_plot(etas, ax, measure, resolution=50, ticks=5):
    """Plot ``measure`` over a grid of 2x2 probability matrices.

    ``xx`` and ``yy`` each sweep ``[0, 1]`` independently. Each grid point
    yields the matrix ``[[xx, yy], [1 - xx, 1 - yy]]``. Column 0
    (``xx``, ``1 - xx``) is instance 1 and column 1 (``yy``, ``1 - yy``) is
    instance 2, so every column sums to 1. NOTE(review): rows are presumably
    the two labels — confirm.

    ``measure(probs, etas)`` is evaluated on the whole grid at once. The
    result is drawn with ``ax.imshow``, and the best score and its matrix
    are printed.

    Parameters
    ----------
    etas : np.ndarray
        Passed through unchanged to ``measure``.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    measure : callable
        ``measure(probs, etas)``. ``probs`` has shape
        ``(resolution, resolution, 2, 2)``; the result is expected to be an
        image-like array of shape ``(resolution, resolution)``.
    resolution : int, optional
        Grid points per axis. Defaults to 50, the previous implicit value.
    ticks : int, optional
        Number of tick marks per axis (default 5).

    Returns
    -------
    The score grid returned by ``measure`` (the plotted image).
    """
    a = np.linspace(0, 1, resolution)
    # meshgrid: xx varies along columns (x axis), yy along rows (y axis).
    xx, yy = np.meshgrid(a, a)
    zz = 1 - xx  # instance 1 = (xx, zz)
    aa = 1 - yy  # instance 2 = (yy, aa)

    # (R, R, 4) -> (R, R, 2, 2) holding [[xx, yy], [zz, aa]]. Derive the grid
    # shape from xx rather than hard-coding it, so any resolution works.
    result = np.stack([xx, yy, zz, aa], axis=-1).reshape(xx.shape + (2, 2))

    scores = measure(result, etas)
    ax.imshow(scores)

    # imshow places pixel i at coordinate i: put ticks at pixel indices and
    # label them with the probability value at that index.
    tick_positions = np.linspace(0, resolution - 1, ticks)
    tick_labels = np.linspace(0, 1, ticks)
    ax.set(xticks=tick_positions, xticklabels=tick_labels)
    ax.set(yticks=tick_positions, yticklabels=tick_labels)
    ax.set(xlabel="instance 1: probs[0, 0]", ylabel="instance 2: probs[0, 1]")

    # One pass: argmax returns the first maximum in flattened (row-major)
    # order, which also indexes the flattened matrices in `result`.
    flat_scores = np.ravel(scores)
    best = int(np.argmax(flat_scores))
    print(f"Best score: {flat_scores[best]}")
    print(f"Best solution: {result.reshape((-1, 2, 2))[best]}")
    return scores


# Compare the sqrt_tp score landscape for each eta matrix, side by side.
etas_cases = [
    np.array([[0.7, 0.4], [0, 0.3]]),
    np.array([[0.7, 0.4], [0.3, 0]]),
]

# squeeze=False keeps `axes` 2-D even for a single case, so the loop is uniform.
fig, axes = plt.subplots(1, len(etas_cases), squeeze=False)
for ax, case_etas in zip(axes[0], etas_cases):
    ax.set_title(f"etas = {case_etas.tolist()}", fontsize="small")
    score_plot(case_etas, ax, sqrt_tp)

# The Frank-Wolfe cell below reads `etas`; keep it bound to the last case,
# as before.
etas = etas_cases[-1]

plt.show()

In [None]:
from src.frank_wolfe import frank_wolfe
import torch

def sqrt_tp_C(C):
    """Return ``torch.sqrt(C[:, 0] + 1e-3)``, one value per row of ``C``.

    Selects index 0 along the second axis of ``C`` and applies the same
    ``1e-3`` offset as the NumPy ``sqrt_tp``. Unlike ``sqrt_tp``, it does not
    average the result down to a scalar.

    NOTE(review): presumably column 0 holds a true-positive-like term; the
    layout of ``C`` is defined by ``frank_wolfe`` and is not visible here —
    confirm.
    """
    first_column = C[:, 0]
    return torch.sqrt(first_column + 1e-3)

# Run Frank-Wolfe with the torch objective sqrt_tp_C. `etas` is the value left
# by the plotting cell above (the last eta matrix), passed as both of the first
# two positional arguments.
# NOTE(review): confirm passing `etas` twice is intended. The roles of those
# parameters are defined in src.frank_wolfe, which is not visible here. Also
# confirm that frank_wolfe accepts a NumPy array rather than a torch tensor.
classifiers, classifiers_weights, meta = frank_wolfe(etas, etas, sqrt_tp_C, max_iters=100, k=1)