In [1]:
import random
from itertools import chain

import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import trange

from models import YNet
from factor_clique import FactorClique

In [2]:
# Embedding and layer widths used to build the network.
INT_DIM = 32       # width of the left-pipe input/output
RULE_DIM = 64      # width of the right-pipe input/output (one entry per rule bit)
HIDDEN_DIM = 256   # width of every hidden layer

# Seed every RNG this notebook draws from so that Restart & Run All reproduces
# the same weight initialisation, clique factors and Bernoulli rule samples.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# YNet takes three lists of layer sizes (left pipe, right pipe, output pipe)
# followed by the activation class. The output pipe consumes
# INT_DIM + RULE_DIM features and emits a single logit per row.
ynet = YNet([INT_DIM, HIDDEN_DIM, HIDDEN_DIM, INT_DIM],
            [RULE_DIM, HIDDEN_DIM, HIDDEN_DIM, RULE_DIM],
            [INT_DIM + RULE_DIM, HIDDEN_DIM, HIDDEN_DIM, 1],
            nn.ELU,
            )
# Module.to moves the parameters in place and returns the module, so leaving it
# as the last expression both places the model on `device` and displays it.
ynet.to(device)

YNet(
  (left_pipe): MLP(
    (network): Sequential(
      (0): Linear(in_features=32, out_features=256, bias=True)
      (1): ELU(alpha=1.0)
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ELU(alpha=1.0)
      (4): Linear(in_features=256, out_features=32, bias=True)
    )
  )
  (right_pipe): MLP(
    (network): Sequential(
      (0): Linear(in_features=64, out_features=256, bias=True)
      (1): ELU(alpha=1.0)
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ELU(alpha=1.0)
      (4): Linear(in_features=256, out_features=64, bias=True)
    )
  )
  (out_pipe): MLP(
    (network): Sequential(
      (0): Linear(in_features=96, out_features=256, bias=True)
      (1): ELU(alpha=1.0)
      (2): Linear(in_features=256, out_features=256, bias=True)
      (3): ELU(alpha=1.0)
      (4): Linear(in_features=256, out_features=1, bias=True)
    )
  )
)

In [None]:
# Training cell: for each outer iteration, sample a batch of factor cliques,
# estimate a per-clique rule distribution with a gradient-free inner loop, then
# take one Adam step on ground-truth counterexamples using a rule drawn from
# that distribution.

# clique parameters
min_factor = 11
max_factor = 50
max_val = 1000

# num sample parameters
num_cliques = 10
num_examples = 5
num_samples = 10
# each clique contributes its examples followed by its samples to a batch
rows_per_clique = num_examples + num_samples

# training parameters
num_inner_iters = 20
num_outer_iters = 100000
inner_lr = 1e-2  # NOTE(review): not referenced anywhere in this cell
outer_lr = 1e-3

# used for inner loop
max_score = 2

# Loop-invariant objects are built once instead of on every outer iteration.
loss_fn = nn.BCEWithLogitsLoss()
unreduced_loss_fn = nn.BCEWithLogitsLoss(reduction='none')
# Per clique: ones for the examples, zeros for the samples; repeated for every
# clique so the layout is clique-major, matching how x and y are assembled below.
target = torch.tile(
    torch.cat((torch.ones(num_examples), torch.zeros(num_samples))),
    (num_cliques,),
).to(device)
# you can't start these values at 0 because then you'll just copy the first sample over and over
# again: this setup will start with prob of 0.5 but will approach the ideal prob over time
numerator_fill = rows_per_clique * max_score / 2
denominator_fill = rows_per_clique * max_score

losses = []
optimizer = optim.Adam(ynet.parameters(), lr=outer_lr)
for _ in trange(num_outer_iters):
    # was hard-coded to 11; use the configured lower bound
    factors = [random.randint(min_factor, max_factor) for _ in range(num_cliques)]
    cliques = [FactorClique(factor, max_val) for factor in factors]
    examples = [clique.generate_examples(num_examples) for clique in cliques]

    # Bernoulli parameters per (clique, rule bit), kept as numerator / denominator.
    # Created directly on `device` so the in-place updates below never mix devices.
    rule_numerator = torch.full(
        (num_cliques, RULE_DIM), numerator_fill, dtype=torch.float, device=device
    )
    rule_denominator = torch.full(
        (num_cliques, RULE_DIM), denominator_fill, dtype=torch.float, device=device
    )

    # no backprop during inner loop
    with torch.no_grad():
        for _ in range(num_inner_iters):
            samples = [clique.generate_samples(num_samples) for clique in cliques]
            # presumably encode_samples does not depend on the clique instance
            # (cliques[0] is used for every clique's rows) — verify against FactorClique
            x = cliques[0].encode_samples(list(chain(*[a + b for a, b in zip(examples, samples)])))
            rule_sample = torch.bernoulli(rule_numerator / rule_denominator)
            # one copy of each clique's rule per row belonging to that clique
            y = torch.repeat_interleave(rule_sample, rows_per_clique, dim=0)
            # Tensor.to returns a new tensor; the result must be rebound
            x, y = x.to(device), y.to(device)
            preds = ynet(x, y).squeeze()
            loss = unreduced_loss_fn(preds, target)
            # flip and shift so that better results mean higher scores
            score = torch.clamp(max_score - loss, min=0)
            # Rows are clique-major, so group as (num_cliques, rows_per_clique)
            # and sum within each clique to get one score per clique.
            score = score.reshape((num_cliques, -1)).sum(dim=1, keepdim=True)
            rule_numerator += score * rule_sample
            rule_denominator += score

    # in the outer loop, we're at test time, so we use ground truth counterexamples instead of random samples
    counter = [clique.generate_counterexamples(num_samples) for clique in cliques]
    x = cliques[0].encode_samples(list(chain(*[a + b for a, b in zip(examples, counter)])))
    rule_sample = torch.bernoulli(rule_numerator / rule_denominator)
    y = torch.repeat_interleave(rule_sample, rows_per_clique, dim=0)
    x, y = x.to(device), y.to(device)
    preds = ynet(x, y).squeeze()
    loss = loss_fn(preds, target)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    losses.append(loss.item())

100%|██████████| 1000/1000 [00:42<00:00, 23.47it/s]
