In [1]:
import matplotlib.pyplot as plt
import pyGMs as gm
import numpy as np
import torch
import pyro
import pyro.infer
import pyro.optim
import pyro.distributions as dist
import pyro.poutine as poutine

import pandas as pd

# Seed every RNG the notebook uses (Pyro, NumPy, PyTorch) so MCMC results are reproducible.
# Looping over the setters also keeps the cell from echoing the Generator returned by
# torch.manual_seed.
seed = 123
for _seed_rng in (pyro.set_rng_seed, np.random.seed, torch.manual_seed):
    _seed_rng(seed)

<torch._C.Generator at 0x133e97970>

In [8]:
# Load the inputs for the experiment.
#  - golden_sample: samples from mcmc_samples_20000.pt; later indexed as
#    golden_sample[f"X{i}"], so presumably a dict of per-team sample tensors.
#  - dataset: the matches passed to `model` (pairs of team indices).
#  - teams: the team list; only its length is used below.
# NOTE(review): torch.load unpickles its input — only load trusted files.
golden_sample = torch.load('mcmc_samples_20000.pt')
dataset = torch.load("dataset.pt")
teams = torch.load("teams.pt")

# Teams are referred to by their integer position in `teams`.
teams_number = [team_idx for team_idx in range(len(teams))]

In [21]:
def construct_ranking(samples: dict, teams) -> dict:
    """Rank teams by the mean of their sampled strengths.

    Args:
        samples: mapping from site name ``f"X{team}"`` to that team's samples
            (any object with a ``.mean()`` method, e.g. a tensor or ndarray).
        teams: iterable of team identifiers used to build the site names.

    Returns:
        dict mapping each team to its 0-based rank; rank 0 is the team with the
        largest mean. Teams with equal means keep their relative input order
        (``sorted`` is stable, also with ``reverse=True``).
    """
    def mean_strength(team):
        return samples[f"X{team}"].mean()

    ranking = {}
    for position, team in enumerate(sorted(teams, key=mean_strength, reverse=True)):
        ranking[team] = position
    return ranking

In [48]:
def model(matches, n_teams):
    """Pairwise-comparison (Bradley–Terry style) model of match outcomes.

    Each team ``i`` gets a latent strength ``X{i} ~ Normal(0, 2)``. For every
    match ``m`` the outcome is Bernoulli with logit ``X[m[0]] - X[m[1]]`` and is
    always observed as 1, i.e. ``m[0]`` is treated as the winner
    (NOTE(review): assumes dataset pairs are stored as (winner, loser) — confirm).

    Args:
        matches: sequence of index pairs ``(m[0], m[1])`` into ``range(n_teams)``
            (e.g. a list of int tuples or an (M, 2) integer tensor).
        n_teams: number of teams; latent sites are named ``X0 .. X{n_teams-1}``.

    Implementation notes:
        * The per-team sites keep their individual names because downstream
          code reads ``samples[f"X{i}"]``.
        * The likelihood is one vectorized site ``"w"`` inside a
          ``pyro.plate`` rather than one ``pyro.sample`` per match. The joint
          log-density is identical, but NUTS no longer builds M separate sites
          and distributions on every model evaluation. Observed sites are not
          returned by ``MCMC.get_samples()``, so the rename from ``w{i}`` to
          ``w`` does not affect the collected samples.
    """
    strengths = [pyro.sample(f"X{i}", dist.Normal(0, 2)) for i in range(n_teams)]

    if len(matches) == 0:
        return  # nothing observed: only the prior is sampled

    # Stack on the last dim so any leading batch dims of the samples are kept.
    X = torch.stack(strengths, dim=-1)
    pairs = torch.tensor([[int(m[0]), int(m[1])] for m in matches], dtype=torch.long)

    with pyro.plate("matches", len(matches)):
        pyro.sample("w",
                    dist.Bernoulli(logits=X[..., pairs[:, 0]] - X[..., pairs[:, 1]]),
                    obs=torch.ones(len(matches)))

In [61]:
# Fit the model with NUTS at several sample budgets and keep (budget, samples)
# pairs for the comparison plot. Warm-up is n // 5 steps, so n=1 gets no
# warm-up at all and n=5 gets a single step.
# NOTE(review): the n=10 run reports acc. prob=0.000 in the output; the
# warm-up may be too short to tune the step size — consider a minimum warm-up.
mcmc_runs = []

for n in (1, 5, 10, 50, 100, 200):
    n_warmup = n // 5
    nuts_kernel = pyro.infer.NUTS(model, jit_compile=False)
    mcmc = pyro.infer.MCMC(
        nuts_kernel,
        num_samples=n,
        warmup_steps=n_warmup,
        num_chains=1,
    )
    mcmc.run(dataset, len(teams))

    posterior = mcmc.get_samples()
    mcmc_runs.append((n, posterior))

Sample: 100%|██████████| 1/1 [00:01,  1.41s/it, step size=5.00e-01, acc. prob=1.000]
Sample: 100%|██████████| 6/6 [00:04,  1.29it/s, step size=5.00e-01, acc. prob=0.616]
Sample: 100%|██████████| 12/12 [00:02,  5.68it/s, step size=7.19e+00, acc. prob=0.000]
Sample: 100%|██████████| 60/60 [01:34,  1.57s/it, step size=1.91e-01, acc. prob=0.888]
Sample:  90%|█████████ | 108/120 [01:29,  1.35it/s, step size=6.68e-01, acc. prob=0.470]

In [None]:
# Compare the team ranking from each short MCMC run against the ranking from
# the reference samples loaded from mcmc_samples_20000.pt. Ranks are 0-based
# (rank 0 = highest mean strength); points on the dashed y = x line are teams
# that the short run ranked exactly like the reference.
golden_ranking = construct_ranking(golden_sample, teams_number)

X = [golden_ranking[t] for t in teams_number]

# Size the grid from the number of completed runs: 6 runs give the original
# 2x3 / 15x10 layout, while an interrupted or extended sweep no longer leaves
# blank panels or raises IndexError.
n_cols = 3
n_rows = max(1, (len(mcmc_runs) + n_cols - 1) // n_cols)
fig, axs = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows),
                        squeeze=False)

max_rank = len(teams_number) - 1
for i, (n, samples) in enumerate(mcmc_runs):
    ranking = construct_ranking(samples, teams_number)
    Y = [ranking[t] for t in teams_number]

    ax = axs[i // n_cols, i % n_cols]
    ax.plot(X, Y, 'o', alpha=0.5)
    ax.plot([0, max_rank], [0, max_rank], 'k--', linewidth=1)  # perfect agreement
    ax.set_title(f"n={n}")
    ax.set_xlabel("rank (reference samples)")
    ax.set_ylabel(f"rank ({n} samples)")

# Hide grid cells that have no run to show.
for unused_ax in axs.flat[len(mcmc_runs):]:
    unused_ax.set_visible(False)

fig.tight_layout()
plt.show()