# Estimating Proportions by Gradient Descent

## Scenario

Suppose I have two boxes (A and B), each of which has a bunch of small beads in it. Peeking inside, it looks like there are 3 different colors of beads (red, orange, and yellow), but the two boxes have very different mixes of colors.

Each box has a lever on it. When I push the lever, a bead comes out of the box. (We can assume it's a random one, and we'll put the bead back in the box it came from so we don't lose beads.)

My friend suggests we play a game: they'll pick a box and press the lever a few times; I have to guess what color beads are going to come out. But I complain that I'm never going to be able to guess 100% correctly, since the boxes have mixtures of beads in them. So here's what they propose: I can spread out my one guess among the different colors, e.g., 0.5 for red and 0.25 each for orange and yellow--as long as they add up to 1. Okay...sounds good?

Even though there's no way I could count the number of each color bead in each box (way too many!), I think I can do well at this game after a few rounds. What do you think?

## Setup

In [None]:
import torch
from torch import tensor
import matplotlib.pyplot as plt
%matplotlib inline
torch.manual_seed(0);

Define the true number of beads of each of the 3 colors in each box (one row per box). The true proportions are these counts divided by each box's total.

In [None]:
boxes = tensor([
    [600, 550, 350],
    [100, 1300, 100]
]).float()
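To see what I'm secretly trying to learn, here's a small sketch that turns those counts into proportions. It's plain Python so the arithmetic is visible; `box_counts`, `proportions`, and `true_props` are names I made up for illustration, and I'm assuming row 0 is box A and row 1 is box B.

```python
# Bead counts copied from the notebook's `boxes` tensor.
# Assumption: row 0 is box A, row 1 is box B.
box_counts = [
    [600, 550, 350],
    [100, 1300, 100],
]

def proportions(counts):
    """Turn raw bead counts into fractions that sum to 1."""
    total = sum(counts)
    return [c / total for c in counts]

true_props = [proportions(counts) for counts in box_counts]
for name, props in zip("AB", true_props):
    print(name, [round(p, 3) for p in props])
# A [0.4, 0.367, 0.233]
# B [0.067, 0.867, 0.067]
```

In the game I never get to see these numbers; I only see the beads that come out.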

Here's how the friend picks a box: a fair coin flip. We get to see which box they pick.

In [None]:
def pick_box():
    return int(torch.rand(1) < .5)
pick_box()

1

In [None]:
def draw_beads(box, num_beads):
    return torch.multinomial(boxes[box], num_beads, replacement=True)
example_beads = draw_beads(box=0, num_beads=10); example_beads

tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 0])
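Ten beads don't tell me much, but lots of draws should. Here's a rough sketch of that idea in plain Python: `draw_beads_py` is a hypothetical stand-in for `draw_beads` (it uses `random.choices` instead of `torch.multinomial`), and the 20,000 draws and the seed are arbitrary choices of mine.

```python
import random

rng = random.Random(0)          # fixed seed so the run is repeatable
counts_a = [600, 550, 350]      # box 0 from the notebook's `boxes`

def draw_beads_py(counts, num_beads):
    """Hypothetical stand-in for draw_beads: sample color indices with replacement,
    weighting each color by how many beads of that color are in the box."""
    return rng.choices(range(len(counts)), weights=counts, k=num_beads)

num_draws = 20000
draws = draw_beads_py(counts_a, num_draws)

# Fraction of draws that came up each color, next to the true fraction.
observed = [draws.count(c) / num_draws for c in range(len(counts_a))]
true = [c / sum(counts_a) for c in counts_a]
for c in range(len(counts_a)):
    print(c, round(observed[c], 3), round(true[c], 3))
```

The observed fractions should land close to the true ones, which is why watching the beads is enough to get good at the game.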

## My guesses

In [None]:
my_guesses = torch.ones((2, 3)) / 3
def get_guess(box):
    guesses_for_box = my_guesses[box]
    return guesses_for_box

example_guess = get_guess(0); example_guess

tensor([0.3333, 0.3333, 0.3333])

## My score

In [None]:
example_guess[example_beads]

tensor([0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333, 0.3333,
        0.3333])

In [None]:
def score_guesses(guess, beads):
    return guess[beads].mean()
score_guesses(example_guess, example_beads)

tensor(0.3333)
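It helps to know what score to expect on average. Since each bead of color c adds `guess[c]` to the mean, the expected score is the sum over colors of (true proportion) × (my guess for that color). Here's a sketch of that calculation for box 0 in plain Python; `expected_score` and the three candidate guesses are my own illustrative names.

```python
def expected_score(guess, props):
    """Average score I'd expect: each color contributes (how often it appears) * (my guess for it)."""
    return sum(g * p for g, p in zip(guess, props))

props_a = [600 / 1500, 550 / 1500, 350 / 1500]   # true proportions of box 0

uniform = [1 / 3, 1 / 3, 1 / 3]    # the starting guess
matching = props_a                 # guess exactly the true proportions
all_on_mode = [1.0, 0.0, 0.0]      # put everything on the most common color

for name, guess in [("uniform", uniform), ("matching", matching), ("all_on_mode", all_on_mode)]:
    print(name, round(expected_score(guess, props_a), 3))
# uniform 0.333
# matching 0.349
# all_on_mode 0.4
```

Two things stand out. The uniform guess scores 1/3 no matter what's in the box. And with this particular score, piling my whole guess onto the most common color beats guessing the true proportions, so doing well at this version of the game isn't quite the same thing as estimating the proportions.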

## Learning to play the game

In [None]:
my_guesses = torch.ones((2, 3)) / 3.0
my_guesses.requires_grad_()

scores = []
for i in range(50):
    box = pick_box()                       # friend picks a box
    my_guess = get_guess(box)              # I make a guess
    assert (my_guess > 0).all()
    assert (my_guess.sum() - 1.0).abs() < .01

    beads = draw_beads(box, 10)            # friend draws a bunch of beads
    score = score_guesses(my_guess, beads) # friend computes my score
    scores.append(score.item())

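    # I figure out how I should have guessed differently.
    # NOTE: this raw update is what trips the assertions above on a later round:
    # it steps *down* the score gradient with an implicit step size of 1, never
    # zeroes my_guesses.grad, and does nothing to keep each row positive and summing to 1.
    score.backward()
    my_guesses.data -= my_guesses.grad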

plt.plot(scores)


AssertionError: ignored
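Why does the assertion fire? Here's one update of the loop traced by hand for a single box, in plain Python. The particular `beads` draw is made up for illustration; the gradient formula follows from the score being the mean of `guess[beads]`.

```python
# One update step of the loop, traced by hand for a single box.
guess = [1 / 3, 1 / 3, 1 / 3]
beads = [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]   # hypothetical draw of 10 beads

# score = mean(guess[b] for b in beads), so
# d(score) / d(guess[c]) = (number of beads of color c) / (number of beads)
grad = [beads.count(c) / len(beads) for c in range(3)]

# The loop's update: subtract the whole gradient (step size 1).
new_guess = [g - d for g, d in zip(guess, grad)]

print([round(d, 3) for d in grad])        # [0.1, 0.9, 0.0]
print([round(g, 3) for g in new_guess])   # [0.233, -0.567, 0.333]
print(round(sum(new_guess), 3))           # 0.0
```

The gradient entries always sum to 1, so after one update the row sums to 0 instead of 1. And since some color must show up at least 4 times in 10 beads, at least one entry drops below 0. Either way, the next time that box is picked, the checks on `my_guess` fail.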

In [None]:
torch.stack([get_guess(box=0), get_guess(box=1)])
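One possible way to repair the loop, sketched under my own assumptions rather than as the intended solution: store unconstrained numbers (logits) for each box and pass them through a softmax to get the guess, so it is always positive and sums to 1. I also swap the score for the mean log of the guess at each drawn bead and step *up* its gradient. To keep the sketch free of dependencies it's plain Python with the gradient written out by hand (for this score and a softmax guess it is "observed frequencies minus current guess"); the learning rate, step count, and all names here are illustrative.

```python
import math
import random

rng = random.Random(0)                               # fixed seed for repeatability
box_counts = [[600, 550, 350], [100, 1300, 100]]     # from the notebook's `boxes`

def softmax(logits):
    """Map any real numbers to positive values that sum to 1."""
    biggest = max(logits)                            # subtract the max for numerical stability
    exps = [math.exp(z - biggest) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

logits = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]          # softmax of zeros is the uniform guess
learning_rate = 0.1                                  # assumed value, not from the notebook

for step in range(3000):
    box = rng.randrange(2)                                        # friend picks a box
    beads = rng.choices(range(3), weights=box_counts[box], k=10)  # friend draws 10 beads
    guess = softmax(logits[box])                                  # my guess for that box

    # Gradient of mean(log(guess[b]) for b in beads) with respect to the logits:
    # (fraction of beads of each color) - (current guess for that color).
    freqs = [beads.count(c) / len(beads) for c in range(3)]
    for c in range(3):
        logits[box][c] += learning_rate * (freqs[c] - guess[c])   # step up, not down

learned = [softmax(row) for row in logits]
for row in learned:
    print([round(p, 2) for p in row])
```

With the log score, the guess that does best on average is the true proportions, so the learned rows should end up near 0.40 / 0.37 / 0.23 for box 0 and 0.07 / 0.87 / 0.07 for box 1, give or take some noise from the random draws.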