In [11]:
import torch

from src.init_states import BaseState, ZerosState, TrainableState
from src.init_vars import BaseVar, BasicVar
from src.init_contexts import BaseContext, EmptyContext
from src.embeddings import BaseEmbedding, BasicEmbedding, IdentityEmbedding

from src.encoders import RNNEncoder
from src.decoders import RNNDecoder
from src.encoder_decoder import EncoderDecoder
from src.baselines import BaselineRollout

from src.train import train
from src.generator import UniformGenerator

import torch.optim as optim

## Init states

BaseState

In [None]:
# BaseState is abstract: calling it is expected to raise NotImplementedError.
# The exception is caught and reported so "Restart & Run All" does not stop at
# this demonstration cell; if it is NOT raised, fail loudly instead.
try:
    init_state = BaseState()
    state = init_state()
except NotImplementedError as exc:
    print('NotImplementedError raised (expected):', repr(exc))
else:
    raise AssertionError('BaseState was expected to raise NotImplementedError')

ZerosState

In [5]:
# ZerosState takes no constructor arguments, so no cell/hidden-size config is needed here.
enc_output = None
batch_size = None

init_state = ZerosState()
state = init_state(enc_output, batch_size)
# With these inputs ZerosState returns None (see the output below), not a tensor.
# NOTE(review): presumably None lets the RNN fall back to its default all-zeros
# initial hidden state — TODO confirm against src.init_states.
print(state)

None


TrainableState

In [10]:
# RNN configuration and inputs for the trainable initial state.
cell, hidden_size, num_layers = 'GRU', 4, 1
enc_output, batch_size = None, 2

# NOTE(review): a/b presumably bound the initialisation range of the learned
# state — TODO confirm against src.init_states.
init_state = TrainableState(cell, hidden_size, num_layers, a=-0.8, b=0.8)
state = init_state(enc_output, batch_size)

# Expected layout of the returned state: [num_layers, batch_size, hidden_size]
state.shape

torch.Size([1, 2, 4])

## Init vars

BaseVar

In [None]:
# BaseVar is abstract: calling it is expected to raise NotImplementedError.
# The exception is caught and reported so "Restart & Run All" does not stop at
# this demonstration cell; if it is NOT raised, fail loudly instead.
enc_output = None
formula = None
num_variables = None
variables = None

try:
    init_dec_var = BaseVar()
    var = init_dec_var(enc_output, formula, num_variables, variables)
except NotImplementedError as exc:
    print('NotImplementedError raised (expected):', repr(exc))
else:
    raise AssertionError('BaseVar was expected to raise NotImplementedError')


BasicVar

In [17]:
enc_output = None
formula = None
num_variables = 5
variables = None

init_dec_var = BasicVar()
var = init_dec_var(enc_output, formula, num_variables, variables)
# ::var:: [batch_size=1, seq_len=num_variables, feature_size=1]
assert var.shape == (1, num_variables, 1), var.shape
var.shape

torch.Size([1, 5, 1])

## Init contexts

BaseContext

In [None]:
# BaseContext is abstract: calling it is expected to raise NotImplementedError.
# The exception is caught and reported so "Restart & Run All" does not stop at
# this demonstration cell; if it is NOT raised, fail loudly instead.
enc_output = None
formula = None
num_variables = None
variables = None
batch_size = None

try:
    init_dec_context = BaseContext()
    context = init_dec_context(enc_output, formula, num_variables, variables, batch_size)
except NotImplementedError as exc:
    print('NotImplementedError raised (expected):', repr(exc))
else:
    raise AssertionError('BaseContext was expected to raise NotImplementedError')

EmptyContext

In [10]:
# Inputs for the empty decoder context.
enc_output, formula, variables = None, None, None
num_variables, batch_size = 5, 2

init_dec_context = EmptyContext()
context = init_dec_context(enc_output, formula, num_variables, variables, batch_size)

# Expected layout: [batch_size, feature_size=0], i.e. a zero-width context.
context.shape

torch.Size([2, 0])

## Embeddings

BaseEmbedding

In [None]:
# BaseEmbedding is abstract: calling it is expected to raise NotImplementedError.
# The exception is caught and reported so "Restart & Run All" does not stop at
# this demonstration cell; if it is NOT raised, fail loudly instead.
batch_size = 2
seq_len = 4
feature_size = 5
X = torch.rand((batch_size, seq_len, feature_size))
# ::X:: [batch_size, seq_len, features_size]

try:
    embedding = BaseEmbedding()
    X = embedding(X)
except NotImplementedError as exc:
    print('NotImplementedError raised (expected):', repr(exc))
else:
    raise AssertionError('BaseEmbedding was expected to raise NotImplementedError')

BasicEmbedding

In [14]:
num_labels = 5
embedding_size = 10

batch_size = 2
seq_len = 4
# torch.randint's `high` is exclusive, so high=num_labels draws from every valid
# label 0..num_labels-1 (including the last one).
X = torch.randint(low=0, high=num_labels, size=(batch_size, seq_len, 1), dtype=torch.long)
# ::X:: [batch_size, seq_len, features_size=1]

embedding = BasicEmbedding(num_labels, embedding_size)
X = embedding(X)
# ::X:: [batch_size, seq_len, num_features=embedding_size]
assert X.shape == (batch_size, seq_len, embedding_size), X.shape
X.shape

torch.Size([2, 4, 10])

IdentityEmbedding

In [17]:
# Random float input; the identity embedding should leave its shape untouched.
batch_size, seq_len, feature_size = 2, 4, 5
X = torch.rand(batch_size, seq_len, feature_size)
# ::X:: [batch_size, seq_len, features_size]

embedding = IdentityEmbedding()
X = embedding(X)

# Output layout is the same as the input: [batch_size, seq_len, features_size]
X.shape

torch.Size([2, 4, 5])

## Encoders

In [18]:
cell = 'GRU'
embedding_size = 7
hidden_size = 16
num_layers = 1
dropout = 0

batch_size = 2
seq_len = 5

enc_embedding = IdentityEmbedding()
X = torch.rand((batch_size, seq_len, embedding_size))
# ::X:: [batch_size, seq_len, features_size=embedding_size]

encoder = RNNEncoder(cell=cell,
                     embedding=enc_embedding,
                     embedding_size=embedding_size,
                     hidden_size=hidden_size,
                     num_layers=num_layers,
                     dropout=dropout)
encoder.eval()
output, state = encoder(X)
# Note: X is passed batch-first, but `output` comes back sequence-first.
# ::output:: [seq_len, batch_size, hidden_size]
# ::state:: [num_layers, batch_size, hidden_size]
print(output.shape)
print(state.shape)

torch.Size([5, 2, 16])


## Decoders

In [19]:
cell = 'GRU'
hidden_size = 16
num_layers = 1
dropout = 0
clip_logits_c = 0

batch_size = 2
seq_len = 3
embedding_size = 8
num_variables = 5
assignment_emb = BasicEmbedding(num_labels=3, embedding_size=embedding_size)
variable_emb = BasicEmbedding(num_labels=num_variables, embedding_size=embedding_size)
# Three embedding_size-wide inputs are built below (var, a_prev, context).
# NOTE(review): presumably the decoder concatenates their features, hence
# 3 * embedding_size — TODO confirm against src.decoders.
input_size = embedding_size * 3

# randint's `high` is exclusive: use num_variables so every label accepted by
# variable_emb (0..num_variables-1) can be drawn.
var = torch.randint(0, num_variables, [batch_size, seq_len, 1])
# ::var:: [batch_size, seq_len, feature_size=1]
# NOTE(review): only labels {0, 1} are drawn although assignment_emb has 3 labels;
# presumably the third label is reserved (e.g. a start token) — TODO confirm.
a_prev = torch.randint(0, 2, [batch_size, seq_len, 1])
# ::a_prev:: [batch_size, seq_len, feature_size=1]
context = torch.rand([batch_size, embedding_size])
# ::context:: [batch_size, feature_size]
state = torch.rand([num_layers, batch_size, hidden_size])
# ::state:: [num_layers, batch_size, hidden_size]
X = (var, a_prev, context)

decoder = RNNDecoder(input_size=input_size,
                     cell=cell,
                     assignment_emb=assignment_emb,
                     variable_emb=variable_emb,
                     hidden_size=hidden_size,
                     num_layers=num_layers,
                     dropout=dropout,
                     clip_logits_c=clip_logits_c)
decoder.eval()
output, state = decoder(X, state)
# ::output:: [batch_size, seq_len, 2]
# ::state:: [num_layers, batch_size, hidden_size]
print(output.shape)
print(state.shape)

torch.Size([2, 3, 2])
torch.Size([1, 2, 16])


## Train

In [3]:
# Build a SAT generator whose min/max ranges coincide, so every formula has the
# same parameters (the output below shows n = 5, r = 4.2, m = 21 clauses of
# 3 literals each).
# TODO(review): no random seed is set, so the formula (and the training run
# below) differs on every execution.
sat_gen = UniformGenerator(min_n=5, max_n=5,
                           min_k=3, max_k=3,
                           min_r=4.2, max_r=4.2)

# Draw one random formula together with its parameters.
n, r, m, formula = sat_gen.generate_formula()

print(f'n: {n}\nr: {r}\nm: {m}')
print(formula)

n: 5
r: 4.2
m: 21
[[4, 1, 3], [-3, -4, -2], [4, 3, 5], [1, 5, -2], [-1, -5, 4], [-3, -1, 2], [1, -3, -2], [4, -1, 5], [-2, -5, 4], [2, -1, -4], [-2, -3, 5], [-4, -2, 5], [-2, -5, 4], [-5, 1, 4], [1, -2, -4], [1, -5, -3], [1, -5, -2], [4, 5, -3], [-3, 5, 4], [4, -5, -2], [-3, -2, 1]]


In [4]:
# Train a decoder-only policy on the single `formula` generated in the previous
# cell. This cell relies on `formula` and `n` from that cell.
num_variables = n  # keep in sync with the generated formula instead of hard-coding it
variables = None
num_episodes = 10
accumulation_steps = 1

# Decoder architecture
cell = 'GRU'
hidden_size = 16
num_layers = 1
dropout = 0
clip_logits_c = 0

lr = 1e-3

embedding_size = 8
assignment_emb = BasicEmbedding(num_labels=3, embedding_size=embedding_size)
variable_emb = BasicEmbedding(num_labels=num_variables, embedding_size=embedding_size)
# No decoder context is configured below (init_dec_context=None), so only two
# embedding_size-wide inputs remain.
# NOTE(review): presumably variable + assignment embeddings — TODO confirm.
input_size = embedding_size * 2

encoder = None  # no encoder: the policy is decoder-only
decoder = RNNDecoder(input_size=input_size,
                     cell=cell,
                     assignment_emb=assignment_emb,
                     variable_emb=variable_emb,
                     hidden_size=hidden_size,
                     num_layers=num_layers,
                     dropout=dropout,
                     clip_logits_c=clip_logits_c)

policy_network = EncoderDecoder(encoder=encoder,
                                decoder=decoder,
                                init_dec_var=None,
                                init_dec_context=None,
                                init_dec_state=None)

optimizer = optim.Adam(policy_network.parameters(), lr=lr)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): the meaning of the -1 argument is not visible here; the original
# comment suggests `baseline = None` as an alternative — confirm in src.baselines.
baseline = BaselineRollout(-1)
entropy_weight = 0
clip_val = 1
verbose = 2

history_loss, history_num_sat = train(formula,
                                      num_variables,
                                      variables,
                                      num_episodes,
                                      accumulation_steps,
                                      policy_network,
                                      optimizer,
                                      device,
                                      baseline,
                                      entropy_weight,
                                      clip_val,
                                      verbose)

Episode [1/10], Mean Loss -3.1729,  Mean num sat 19.0000
Episode [2/10], Mean Loss -23.9046,  Mean num sat 13.0000
Episode [3/10], Mean Loss 0.0000,  Mean num sat 20.0000
Episode [4/10], Mean Loss -3.4134,  Mean num sat 19.0000
Episode [5/10], Mean Loss -3.5093,  Mean num sat 19.0000
Episode [6/10], Mean Loss -3.3867,  Mean num sat 19.0000
Episode [7/10], Mean Loss 3.5746,  Mean num sat 21.0000
Episode [8/10], Mean Loss -3.9134,  Mean num sat 19.0000
Episode [9/10], Mean Loss -2.9852,  Mean num sat 19.0000
Episode [10/10], Mean Loss -11.3217,  Mean num sat 17.0000
