In [2]:
import torch

from src.init_states import BaseState, ZerosState, TrainableState
from src.init_vars import BaseVar, BasicVar
from src.init_contexts import BaseContext, EmptyContext
from src.embeddings import BaseEmbedding, BasicEmbedding, IdentityEmbedding

from src.encoders import RNNEncoder
from src.decoders import RNNDecoder
from src.encoder_decoder import EncoderDecoder
from src.baselines import BaselineRollout

from src.train import train
from src.generator import UniformGenerator
import src.utils as utils

import torch.optim as optim

## Init states

BaseState

In [None]:
# BaseState is abstract: calling it should raise NotImplementedError.
# The exception is caught so that "Restart & Run All" does not stop here.
# NOTE(review): called with the same (enc_output, batch_size) arguments the
# concrete subclasses below receive — confirm this matches BaseState's signature.
enc_output = None
batch_size = None

init_state = BaseState()
try:
    state = init_state(enc_output, batch_size)
except NotImplementedError:
    print('BaseState raised NotImplementedError (expected)')
else:
    raise AssertionError('BaseState is expected to raise NotImplementedError')

ZerosState

In [5]:
enc_output = None
batch_size = None

init_state = ZerosState()
state = init_state(enc_output, batch_size)
# ZerosState returns None here (see the printed output), not a
# [num_layers, batch_size, hidden_size] tensor. Presumably the RNN then falls
# back to its default all-zeros initial hidden state — TODO confirm in src.init_states.
assert state is None
print(state)  # a bare `state` would display nothing for None, so print explicitly

None


TrainableState

In [13]:
# Learned initial state. The output shows one vector repeated for every batch
# element (identical rows, ExpandBackward0 grad_fn).
cell, hidden_size, num_layers = 'GRU', 4, 1
enc_output, batch_size = None, 2

init_state = TrainableState(cell, hidden_size, num_layers, a=-0.8, b=0.8)
state = init_state(enc_output, batch_size)
state  # [num_layers, batch_size, hidden_size]

tensor([[[ 0.3814,  0.2384, -0.7665,  0.3343],
         [ 0.3814,  0.2384, -0.7665,  0.3343]]], grad_fn=<ExpandBackward0>)

## Init vars

BaseVar

In [None]:
# BaseVar is abstract: calling it should raise NotImplementedError.
# The exception is caught so that "Restart & Run All" continues past this cell.
enc_output = None
formula = None
num_variables = None
variables = None

init_dec_var = BaseVar()
try:
    var = init_dec_var(enc_output, formula, num_variables, variables)
except NotImplementedError:
    print('BaseVar raised NotImplementedError (expected)')
else:
    raise AssertionError('BaseVar is expected to raise NotImplementedError')


BasicVar

In [18]:
enc_output, formula, variables = None, None, None
num_variables = 5

init_dec_var = BasicVar()
var = init_dec_var(enc_output, formula, num_variables, variables)
# The output shows the indices 0 .. num_variables-1 laid out as
# [batch_size=1, seq_len=num_variables, feature_size=1].
var

tensor([[[0],
         [1],
         [2],
         [3],
         [4]]])

## Init contexts

BaseContext

In [None]:
# BaseContext is abstract: calling it should raise NotImplementedError.
# The exception is caught so that "Restart & Run All" continues past this cell.
enc_output = None
formula = None
num_variables = None
variables = None
batch_size = None

init_dec_context = BaseContext()
try:
    context = init_dec_context(enc_output, formula, num_variables, variables, batch_size)
except NotImplementedError:
    print('BaseContext raised NotImplementedError (expected)')
else:
    raise AssertionError('BaseContext is expected to raise NotImplementedError')

EmptyContext

In [10]:
enc_output, formula, variables = None, None, None
num_variables, batch_size = 5, 2

init_dec_context = EmptyContext()
context = init_dec_context(enc_output, formula, num_variables, variables, batch_size)
# Zero-width context: shape [batch_size, feature_size=0].
context.shape

torch.Size([2, 0])

## Embeddings

BaseEmbedding

In [None]:
batch_size = 2
seq_len = 4
feature_size = 5
X = torch.rand((batch_size, seq_len, feature_size))
# ::X:: [batch_size, seq_len, feature_size]

# BaseEmbedding is abstract: calling it should raise NotImplementedError.
# The exception is caught so that "Restart & Run All" continues past this cell.
embedding = BaseEmbedding()
try:
    X = embedding(X)
except NotImplementedError:
    print('BaseEmbedding raised NotImplementedError (expected)')
else:
    raise AssertionError('BaseEmbedding is expected to raise NotImplementedError')

BasicEmbedding

In [14]:
num_labels = 5
embedding_size = 10

batch_size = 2
seq_len = 4
# torch.randint's `high` is exclusive, so high=num_labels samples every label
# 0 .. num_labels-1 (high=num_labels-1 would never produce the last label).
X = torch.randint(low=0, high=num_labels, size=(batch_size, seq_len, 1), dtype=torch.long)
# ::X:: [batch_size, seq_len, feature_size=1]

embedding = BasicEmbedding(num_labels, embedding_size)
X = embedding(X)
# ::X:: [batch_size, seq_len, feature_size=embedding_size]
X.shape

torch.Size([2, 4, 10])

IdentityEmbedding

In [17]:
batch_size, seq_len, feature_size = 2, 4, 5
X = torch.rand((batch_size, seq_len, feature_size))  # [batch_size, seq_len, feature_size]

# The output shows IdentityEmbedding leaves the shape unchanged.
embedding = IdentityEmbedding()
X = embedding(X)
X.shape  # [batch_size, seq_len, feature_size]

torch.Size([2, 4, 5])

## Encoders

In [18]:
cell = 'GRU'
embedding_size = 7
hidden_size = 16
num_layers = 1

batch_size = 2
seq_len = 5

enc_embedding = IdentityEmbedding()
X = torch.rand((batch_size, seq_len, embedding_size))
# ::X:: [batch_size, seq_len, feature_size=embedding_size]

encoder = RNNEncoder(cell = cell,
                     embedding = enc_embedding,
                     embedding_size = embedding_size,
                     hidden_size = hidden_size,
                     num_layers = num_layers,
                     dropout = 0)
encoder.eval()
output, state = encoder(X)
# ::output:: [seq_len, batch_size, hidden_size]  (seq-first, unlike the batch-first input)
# ::state:: [num_layers, batch_size, hidden_size]
# Show both shapes, as the decoder cell does, so the documented state shape is visible too.
print(output.shape)
print(state.shape)

torch.Size([5, 2, 16])


## Decoders

In [19]:
cell = 'GRU'
hidden_size = 16
num_layers = 1
dropout = 0
clip_logits_c = 0

batch_size = 2
seq_len = 3
embedding_size = 8
num_assignment_labels = 3
num_variable_labels = 5
assignment_emb = BasicEmbedding(num_labels=num_assignment_labels, embedding_size=embedding_size)
variable_emb = BasicEmbedding(num_labels=num_variable_labels, embedding_size=embedding_size)
# var, a_prev and context each contribute embedding_size features
# (presumably concatenated inside the decoder), hence 3 * embedding_size.
input_size = embedding_size * 3

# torch.randint's upper bound is exclusive: sample every valid variable index
# 0 .. num_variable_labels-1 (randint(0, 4) never produced index 4).
var = torch.randint(0, num_variable_labels, [batch_size, seq_len, 1])
# ::var:: [batch_size, seq_len, feature_size=1]
# Only 0/1 assignments are sampled. NOTE(review): label 2 of assignment_emb is
# presumably a reserved token — confirm.
a_prev = torch.randint(0, 2, [batch_size, seq_len, 1])
# ::a_prev:: [batch_size, seq_len, feature_size=1]
context = torch.rand([batch_size, embedding_size])
# ::context:: [batch_size, feature_size]
state = torch.rand([num_layers, batch_size, hidden_size])
# ::state:: [num_layers, batch_size, hidden_size]
X = (var, a_prev, context)

decoder  = RNNDecoder(input_size = input_size,
                      cell = cell,
                      assignment_emb = assignment_emb,
                      variable_emb = variable_emb,
                      hidden_size = hidden_size,
                      num_layers = num_layers,
                      dropout = dropout,
                      clip_logits_c = clip_logits_c)
decoder.eval()
output, state = decoder(X, state)
# ::output:: [batch_size, seq_len, 2]
# ::state:: [num_layers, batch_size, hidden_size]
print(output.shape)
print(state.shape)

torch.Size([2, 3, 2])
torch.Size([1, 2, 16])


## Encoder-Decoder

In [6]:
# Decoder-only policy network: no encoder, init_dec_var / init_dec_context left
# as None, and a trainable initial decoder state.
# Modules are built in the same order as before, so parameter initialisation
# draws from the RNG identically.
num_variables = 5
variables = None

cell, hidden_size, num_layers = 'GRU', 16, 1
dropout, clip_logits_c = 0, 0

embedding_size = 8
assignment_emb = BasicEmbedding(num_labels=3, embedding_size=embedding_size)
variable_emb = BasicEmbedding(num_labels=num_variables, embedding_size=embedding_size)
# Variable + assignment embeddings only (presumably no context features here).
input_size = 2 * embedding_size

init_dec_state = TrainableState(cell, hidden_size, num_layers, a=-0.8, b=0.8)

encoder = None
decoder = RNNDecoder(
    input_size=input_size,
    cell=cell,
    assignment_emb=assignment_emb,
    variable_emb=variable_emb,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dropout=dropout,
    clip_logits_c=clip_logits_c,
)

policy_network = EncoderDecoder(
    encoder=encoder,
    decoder=decoder,
    init_dec_var=None,
    init_dec_context=None,
    init_dec_state=init_dec_state,
)

utils.params_summary(policy_network)

decoder.assignment_embedding.embedding.weight torch.Size([8, 3])
decoder.assignment_embedding.embedding.bias torch.Size([8])
decoder.variable_embedding.embedding.weight torch.Size([8, 5])
decoder.variable_embedding.embedding.bias torch.Size([8])
decoder.rnn.weight_ih_l0 torch.Size([48, 16])
decoder.rnn.weight_hh_l0 torch.Size([48, 16])
decoder.rnn.bias_ih_l0 torch.Size([48])
decoder.rnn.bias_hh_l0 torch.Size([48])
decoder.dense_out.weight torch.Size([2, 16])
decoder.dense_out.bias torch.Size([2])
init_dec_state.h torch.Size([1, 1, 16])


## Train

In [15]:
# Sample one random SAT instance: n = 5 variables, k = 3 literals per clause,
# clause/variable ratio r = 4.2 (the output shows m = 21 = 4.2 * 5 clauses).
sat_gen = UniformGenerator(min_n=5, max_n=5,
                           min_k=3, max_k=3,
                           min_r=4.2, max_r=4.2)

n, r, m, formula = sat_gen.generate_formula()

print(f'n: {n}\nr: {r}\nm: {m}')
print(formula)

n: 5
r: 4.2
m: 21
[[-5, -2, -3], [-3, -4, -1], [-3, -4, 1], [-5, -2, 4], [4, -3, 5], [-5, -2, -1], [-4, -1, -2], [-4, -5, -1], [-3, -5, -1], [-3, 1, -5], [1, 3, 4], [4, 3, 5], [2, 1, 3], [-2, -3, -5], [5, -1, 4], [1, -4, -5], [-4, -3, 2], [-2, -1, 5], [3, -2, 4], [-4, -1, -5], [-4, 2, 5]]


In [17]:
# Train the decoder-only policy on the formula sampled in the previous cell.
# num_variables follows the generated formula's `n` instead of a hard-coded 5,
# so the variable embedding stays valid if the generator's settings change.
num_variables = n
variables = None
num_episodes = 10
accumulation_steps = 1

cell = 'GRU'
hidden_size = 16
num_layers = 1
dropout = 0
clip_logits_c = 0

lr = 1e-3

embedding_size = 8
assignment_emb = BasicEmbedding(num_labels=3, embedding_size=embedding_size)
variable_emb = BasicEmbedding(num_labels=num_variables, embedding_size=embedding_size)
input_size = embedding_size * 2

encoder = None
decoder = RNNDecoder(input_size = input_size,
                     cell = cell,
                     assignment_emb = assignment_emb,
                     variable_emb = variable_emb,
                     hidden_size = hidden_size,
                     num_layers = num_layers,
                     dropout = dropout,
                     clip_logits_c = clip_logits_c)
init_dec_state = TrainableState(cell, hidden_size, num_layers)

policy_network = EncoderDecoder(encoder=encoder,
                                decoder=decoder,
                                init_dec_var=None,
                                init_dec_context=None,
                                init_dec_state=init_dec_state)

optimizer = optim.Adam(policy_network.parameters(), lr=lr)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

baseline = None  # alternative: BaselineRollout(-1)
entropy_weight = 0
clip_val = 1
verbose = 2

history_loss, history_num_sat = train(formula,
                                    num_variables,
                                    variables,
                                    num_episodes,
                                    accumulation_steps,
                                    policy_network,
                                    optimizer,
                                    device,
                                    baseline,
                                    entropy_weight,
                                    clip_val,
                                    verbose)

Episode [1/10], Mean Loss 59.7945,  Mean num sat 19.0000
Episode [2/10], Mean Loss 80.7731,  Mean num sat 19.0000
Episode [3/10], Mean Loss 44.6474,  Mean num sat 18.0000
Episode [4/10], Mean Loss 75.6491,  Mean num sat 18.0000
Episode [5/10], Mean Loss 62.8202,  Mean num sat 18.0000
Episode [6/10], Mean Loss 70.7435,  Mean num sat 20.0000
Episode [7/10], Mean Loss 50.4938,  Mean num sat 18.0000
Episode [8/10], Mean Loss 63.2537,  Mean num sat 17.0000
Episode [9/10], Mean Loss 61.8325,  Mean num sat 18.0000
Episode [10/10], Mean Loss 50.0508,  Mean num sat 18.0000
