In [1]:
import torch
import sys; sys.path.insert(0, "../program_synthesis")
import arguments
import models
import datasets

In [2]:
# Training command lines for the two model variants being compared.
# They differ only in --karel-trace-enc and in the --model_dir name.
vanilla_cmd = " ".join([
    "--dataset karel",
    "--model_type karel-lgrl-ref",
    "--karel-mutate-ref",
    "--karel-mutate-n-dist 1,2,3",
    "--karel-trace-enc none",
    "--karel-refine-dec edit",
    "--num_placeholders 0",
    "--debug_every_n=1000",
    "--eval_every_n=10000000",
    "--keep_every_n=10000",
    "--log_interval=100",
    "--batch_size 64",
    "--num_epochs 50",
    "--max_beam_trees 1",
    "--optimizer sgd",
    "--gradient-clip 1",
    "--lr 1",
    "--lr_decay_steps 100000",
    "--lr_decay_rate 0.5",
    "--model_dir logdirs/vanilla,trace_enc==none,batch_size==64,lr==1,lr_decay_steps=100000",
])
trace_cmd = " ".join([
    "--dataset karel",
    "--model_type karel-lgrl-ref",
    "--karel-mutate-ref",
    "--karel-mutate-n-dist 1,2,3",
    "--karel-trace-enc aggregate:conv_all_grids=True",
    "--karel-refine-dec edit",
    "--num_placeholders 0",
    "--debug_every_n=1000",
    "--eval_every_n=10000000",
    "--keep_every_n=10000",
    "--log_interval=100",
    "--batch_size 64",
    "--num_epochs 50",
    "--max_beam_trees 1",
    "--optimizer sgd",
    "--gradient-clip 1",
    "--lr 1",
    "--lr_decay_steps 100000",
    "--lr_decay_rate 0.5",
    "--model_dir logdirs/aggregate-with-io,trace_enc==aggregate:conv_all_grids=True,batch_size==64,lr==1,lr_decay_steps=100000",
])

In [3]:
def parse(cmd):
    """Turn a training command-line string into a parsed argument namespace.

    The string is tokenised on whitespace and fed to the parser returned by
    ``arguments.get_arg_parser("", "train")``. The result then has ``cuda``
    forced to ``False`` and is passed through
    ``arguments.backport_default_args`` and ``datasets.set_vocab`` before
    being returned.

    Args:
        cmd: Space-separated command-line flags, e.g. ``"--dataset karel"``.

    Returns:
        The parsed namespace, after the adjustments described above.
    """
    # split() with no separator collapses runs of whitespace and ignores
    # leading/trailing whitespace; split(" ") would emit empty-string tokens
    # there, which the parser would see as stray positional arguments.
    a = arguments.get_arg_parser("", "train").parse_args(cmd.split())
    a.cuda = False # not actually using this

    arguments.backport_default_args(a)
    # NOTE(review): presumably the source of the "Loaded vocab ..." line seen
    # in the cell outputs - confirm against datasets.set_vocab.
    datasets.set_vocab(a)
    return a

In [4]:
# Parse both command lines, vanilla first and then trace (same order as before).
vanilla_a, trace_a = (parse(cmd) for cmd in (vanilla_cmd, trace_cmd))

In [5]:
def model_size(a):
    """Return the total number of elements across all parameters of a model.

    The model is built with ``models.get_model(a)``; the count is taken over
    ``.model.parameters()`` of the returned object.
    """
    built = models.get_model(a)
    total = 0
    for param in built.model.parameters():
        total += param.nelement()
    return total

In [6]:
# Parameter count of the trace-encoder model; this is the target that the
# vanilla model's hidden size is tuned to match in the cells below.
trace_size = model_size(trace_a)

Loaded vocab ../program_synthesis/datasets/../data/karel/word.vocab: 43
LGRLRefineKarel(
  (code_encoder): CodeEncoder(
    (embed): Embedding(43, 256)
    (augment_with_trace): AugmentWithTrace(
      (grid_enc): GridEncoder(
        (initial_conv): Conv2d(45, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (blocks): ModuleList(
          (0): Sequential(
            (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (1): ReLU()
            (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (4): ReLU()
            (5): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
          (1): Sequential(
            (0): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (1): ReLU()
            (2): Conv2d(64, 64, kernel_size=(3, 3), stride

In [7]:
def vhs(size):
    """Parameter count of the vanilla model at a given ``--karel-hidden-size``.

    Appends the hidden-size flag to ``vanilla_cmd``, parses the resulting
    command line and returns ``model_size`` of the parsed arguments.
    """
    cmd = "{} --karel-hidden-size {!s}".format(vanilla_cmd, size)
    return model_size(parse(cmd))

In [8]:
# Halving-step search for a vanilla hidden size whose parameter count is
# close to trace_size: start at 256 and move by 128, 64, ..., 1.
s = 256
search_window = 128
while search_window > 0:
    val = vhs(s)
    # Too few parameters -> grow the hidden size; otherwise shrink it.
    s = s + search_window if val < trace_size else s - search_window
    search_window = search_window // 2
# Note: the final value of s is the result of the last +/-1 step and has not
# itself been evaluated inside the loop.

Loaded vocab ../program_synthesis/datasets/../data/karel/word.vocab: 43
LGRLRefineKarel(
  (code_encoder): CodeEncoder(
    (embed): Embedding(43, 256)
    (augment_with_trace): DoNotAugmentWithTrace()
    (encoder): LSTM(256, 256, num_layers=2, batch_first=True, bidirectional=True)
  )
  (decoder): LGRLSeqRefineEditDecoder(
    (op_embed): Embedding(90, 256)
    (last_token_embed): Embedding(43, 256)
    (decoder): LSTM(1536, 256, num_layers=2)
    (out): Linear(in_features=256, out_features=90, bias=False)
  )
  (encoder): LGRLTaskEncoder(
    (input_encoder): Sequential(
      (0): Conv2d(15, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
    )
    (output_encoder): Sequential(
      (0): Conv2d(15, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
    )
    (block_1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), 

LGRLRefineKarel(
  (code_encoder): CodeEncoder(
    (embed): Embedding(43, 324)
    (augment_with_trace): DoNotAugmentWithTrace()
    (encoder): LSTM(324, 324, num_layers=2, batch_first=True, bidirectional=True)
  )
  (decoder): LGRLSeqRefineEditDecoder(
    (op_embed): Embedding(90, 324)
    (last_token_embed): Embedding(43, 324)
    (decoder): LSTM(1944, 324, num_layers=2)
    (out): Linear(in_features=324, out_features=90, bias=False)
  )
  (encoder): LGRLTaskEncoder(
    (input_encoder): Sequential(
      (0): Conv2d(15, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
    )
    (output_encoder): Sequential(
      (0): Conv2d(15, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
    )
    (block_1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU()
      (4): Conv2d(64, 64, kernel_size=

In [9]:
# Ratio of the vanilla model's parameter count at hidden size s to the trace
# model's parameter count; a value near 1.0 means the sizes are matched.
vhs(s) / trace_size

Loaded vocab ../program_synthesis/datasets/../data/karel/word.vocab: 43
LGRLRefineKarel(
  (code_encoder): CodeEncoder(
    (embed): Embedding(43, 327)
    (augment_with_trace): DoNotAugmentWithTrace()
    (encoder): LSTM(327, 327, num_layers=2, batch_first=True, bidirectional=True)
  )
  (decoder): LGRLSeqRefineEditDecoder(
    (op_embed): Embedding(90, 327)
    (last_token_embed): Embedding(43, 327)
    (decoder): LSTM(1962, 327, num_layers=2)
    (out): Linear(in_features=327, out_features=90, bias=False)
  )
  (encoder): LGRLTaskEncoder(
    (input_encoder): Sequential(
      (0): Conv2d(15, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
    )
    (output_encoder): Sequential(
      (0): Conv2d(15, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
    )
    (block_1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), 

1.0014341053685454

In [11]:
# Absolute parameter counts: (vanilla model at hidden size s, trace model).
vhs(s), trace_size

Loaded vocab ../program_synthesis/datasets/../data/karel/word.vocab: 43
LGRLRefineKarel(
  (code_encoder): CodeEncoder(
    (embed): Embedding(43, 327)
    (augment_with_trace): DoNotAugmentWithTrace()
    (encoder): LSTM(327, 327, num_layers=2, batch_first=True, bidirectional=True)
  )
  (decoder): LGRLSeqRefineEditDecoder(
    (op_embed): Embedding(90, 327)
    (last_token_embed): Embedding(43, 327)
    (decoder): LSTM(1962, 327, num_layers=2)
    (out): Linear(in_features=327, out_features=90, bias=False)
  )
  (encoder): LGRLTaskEncoder(
    (input_encoder): Sequential(
      (0): Conv2d(15, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
    )
    (output_encoder): Sequential(
      (0): Conv2d(15, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
    )
    (block_1): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU()
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), 

(22021552, 21990016)