In [1]:
import struct

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm.notebook import tqdm
from transformers import GPT2TokenizerFast 
import zmq

In [2]:
device = torch.device('cuda')

In [3]:
g = torch.Generator(device=device).manual_seed(1337)

In [4]:
SEQUENCE_LENGTH = 1024
BATCH_SIZE=3

In [5]:
# Prefixing with _ to signify global.
_text_tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
_text_tokenizer.pad_token = _text_tokenizer.eos_token
TOKENIZATION_DEFAULTS = [
    ("max_length", SEQUENCE_LENGTH),
    ("truncation", True),
    ("padding", "max_length"),
    ("return_tensors", "pt"),
]
def tokenize_text(text, device=device, **kwargs):
    kwargs = dict(TOKENIZATION_DEFAULTS + list(kwargs.items()))
    return _text_tokenizer(text, **kwargs).to(device)

In [6]:
sample_text_1 = """
<INST>
(appendo a b '(1 2 3 4))
results in order of how similar in length are a and b
</INST>
<STREAM>
  mplus
    pause
      #s(state ((#s(var a2 388) #s(var a1 390) . #s(var a2 391)) (#s(var res 389) 4) (#s(var a1 387) . 3) (#s(var a2 385) #s(var a1 387) . #s(var a2 388)) (#s(var res 386) 3 4) (#s(var a1 384) . 2) (#s(var a2 382) #s(var a1 384) . #s(var a2 385)) (#s(var res 383) 2 3 4) (#s(var a1 381) . 1) (#s(var a 379) #s(var a1 381) . #s(var a2 382)) (#s(var #f 0) #s(var a 379) #s(var b 380))) () () ())
      == #s(var res 389) (#s(var a1 390) . #s(var res 392))
    bind
      #f
      == #s(var res 389) (#s(var a1 390) . #s(var res 392))
</STREAM>
"""

In [7]:
sample_text_2 = """
  mplus
    mplus
      pause
        #s(state ((#s(var a2 388)) (#s(var res 389) 4) (#s(var a1 387) . 3) (#s(var a2 385) #s(var a1 387) . #s(var a2 388)) (#s(var res 386) 3 4) (#s(var a1 384) . 2) (#s(var a2 382) #s(var a1 384) . #s(var a2 385)) (#s(var res 383) 2 3 4) (#s(var a1 381) . 1) (#s(var a 379) #s(var a1 381) . #s(var a2 382)) (#s(var #f 0) #s(var a 379) #s(var b 380))) () () ())
        == #s(var b 380) #s(var res 389)
      bind
        #f
        == #s(var b 380) #s(var res 389)
    bind
      mplus
        pause
          #s(state ((#s(var a2 388) #s(var a1 390) . #s(var a2 391)) (#s(var res 389) 4) (#s(var a1 387) . 3) (#s(var a2 385) #s(var a1 387) . #s(var a2 388)) (#s(var res 386) 3 4) (#s(var a1 384) . 2) (#s(var a2 382) #s(var a1 384) . #s(var a2 385)) (#s(var res 383) 2 3 4) (#s(var a1 381) . 1) (#s(var a 379) #s(var a1 381) . #s(var a2 382)) (#s(var #f 0) #s(var a 379) #s(var b 380))) () () ())
          == #s(var res 389) (#s(var a1 390) . #s(var res 392))
        bind
          #f
          == #s(var res 389) (#s(var a1 390) . #s(var res 392))
      (#<procedure:appendo> appendo #s(var a2 391) #s(var b 380) #s(var res 392))
"""

In [8]:
tokenize_text([sample_text_1, sample_text_2])

{'input_ids': tensor([[  198,    27, 38604,  ..., 50256, 50256, 50256],
        [  198,   220,   285,  ..., 50256, 50256, 50256]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]], device='cuda:0')}

In [9]:
EMBEDDING_DIM = 768
_text_embedding = nn.Embedding(_text_tokenizer.vocab_size, EMBEDDING_DIM)
embed_text = _text_embedding

In [10]:
device = torch.device('cuda')

In [11]:
def rand(device=device, generator=g):
    return torch.rand((1,), device=device, generator=generator).item()

In [12]:
def randint(low, high, device=device):
    return torch.randint(low, high, (1,), device=device, generator=g).item()
def rand01():
    return randint(0, 1)

In [13]:
from transformers import GPT2TokenizerFast, GPT2Config, GPT2Model

In [14]:
def init_model():
    configuration = GPT2Config(
        n_layer=6,
        n_head=6,
        n_embd=EMBEDDING_DIM
    )
    model = GPT2Model(configuration)
    return model


def init_optimizer(params):
    optimizer = torch.optim.AdamW(params)
    return optimizer

In [15]:
OUT_DIM = 2
linear = nn.Linear(SEQUENCE_LENGTH * EMBEDDING_DIM, OUT_DIM, device=device)

In [16]:
action_model = init_model().to(device)
opt = init_optimizer(list(action_model.parameters()) + list(linear.parameters()))

In [17]:
out = action_model(tokenize_text([sample_text_1, sample_text_2])["input_ids"])

In [18]:
linear(out.last_hidden_state.flatten(1)).softmax(dim=1)

tensor([[0.7033, 0.2967],
        [0.7661, 0.2339]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

In [19]:
F.softmax(linear(out.last_hidden_state.flatten(1)), dim=1)

tensor([[0.7033, 0.2967],
        [0.7661, 0.2339]], device='cuda:0', grad_fn=<SoftmaxBackward0>)

In [20]:
out = action_model(tokenize_text([sample_text_1])["input_ids"])
out = linear(out.last_hidden_state.flatten(1))
out = F.tanh(out)
out = F.softmax(out, dim=1)
out.argmax(dim=1).item()

0

In [21]:
ɛ = 0.1
terminated = False
if rand() < ɛ:
    action = rand01()
else:
    out = action_model(tokenize_text([sample_text_1])["input_ids"])
    out = linear(out.last_hidden_state.flatten(1))
    out = F.tanh(out)
    out = F.softmax(out, dim=1)
    action = out.argmax(dim=1).item()
action

0

In [22]:
initial_action_prompt = """
<INST>
(appendo a b '(1 2 3 4))
results in order of how similar in length are a and b
</INST>
<STREAM>
{stream}
</STREAM>
"""
def make_action_prompt(s):
    return initial_action_prompt.format(stream=s)
make_action_prompt("""
mplus
  conj
    #t
    (== 1 1)
""")

"\n<INST>\n(appendo a b '(1 2 3 4))\nresults in order of how similar in length are a and b\n</INST>\n<STREAM>\n\nmplus\n  conj\n    #t\n    (== 1 1)\n\n</STREAM>\n"

In [23]:
initial_value_prompt = """
<INST>
(appendo a b '(1 2 3 4))
results in order of how similar in length are a and b
</INST>
<STREAM>
{stream}
</STREAM>
<ACTION>
{action}
</ACTION>
"""
def make_value_prompt(s, a):
    return initial_value_prompt.format(stream=s, action=a)
make_value_prompt("""
mplus
  conj
    #t
    (== 1 1)
""", 1)

"\n<INST>\n(appendo a b '(1 2 3 4))\nresults in order of how similar in length are a and b\n</INST>\n<STREAM>\n\nmplus\n  conj\n    #t\n    (== 1 1)\n\n</STREAM>\n<ACTION>\n1\n</ACTION>\n"

In [24]:
state = """
mplus
  conj
    #t
    (== 1 1)
"""
out = action_model(tokenize_text([make_action_prompt(state)])["input_ids"])
out = linear(out.last_hidden_state.flatten(1))
out = F.tanh(out)
out = F.softmax(out, dim=1)
action = out.argmax(dim=1).item()
out, action, out[0][action]

(tensor([[0.4762, 0.5238]], device='cuda:0', grad_fn=<SoftmaxBackward0>),
 1,
 tensor(0.5238, device='cuda:0', grad_fn=<SelectBackward0>))

In [25]:
def calc_reward(observation):
    if "(1 2)" in observation:
        return 1
    elif "(1 2 3)" in observation:
        return 0.5
    elif "(2 3 4)" in observation:
        return 0.5
    else:
        return 0

In [26]:
def zmq_init():
    context = zmq.Context()
    socket = context.socket(zmq.PAIR)
    return context, socket

def run_server(ctx, sock, port=5555):
    print(f"Binding to port {port}")
    while True:
        message = sock.recv().decode("utf-8")
        if message.startswith("Success: "):
            pass
        print(message.decode("utf-8"))
        choice = input("Choose a path [0, 1]: ")
        try:
            choice = int(choice)
        except ValueError:
            print(f"Received {choice} but expected an integer. Try again.")
            continue
        socket.send(struct.pack("!i", choice))

In [27]:
def make_observe(sock):
    sock.RCVTIMEO = 50
    poller = zmq.Poller()
    poller.register(sock, zmq.POLLIN)
    def observe():
        result = []
        while poller.poll(50):
            events = poller.poll(50)
            messages = []
            for event in events:
                messages.append(event[0].recv().decode("utf-8"))
            result.append("\n".join(messages))
        return "\n".join(result)
    return observe

def make_act(sock):
    def act(action):
        sock.send(struct.pack("!i", action))
    return act

In [29]:
ctx, sock = zmq_init()
sock.connect("tcp://127.0.0.5:5555")

<SocketContext(connect='tcp://127.0.0.5:5555')>

In [30]:
poller = zmq.Poller()
poller.register(sock, zmq.POLLIN)

In [31]:
poller.poll(50)

[]

### Actions, Observations, and Rewards, from reinforcement learning

In [430]:
act = make_act(sock)
observe = make_observe(sock)

In [698]:
print(observe())

Stream:
mplus
  pause
    #s(state ((#s(var #f 0) #s(var a 230) #s(var b 231))) () () ())
    (#<procedure:appendo> appendo #s(var a 230) #s(var b 231) (1 2 3 4))
  bind
    #f
    (#<procedure:appendo> appendo #s(var a 230) #s(var b 231) (1 2 3 4))


Which path do you want to take? [0, 1]



In [701]:
# Take an action (0 or 1... left or right)
act(0)
# Take an observation
print(observe())

Solution: (() (1 2 3 4))


Stream:
mplus
  pause
    #s(state ((#s(var res 234) 2 3 4) (#s(var a1 232) . 1) (#s(var a 230) #s(var a1 232) . #s(var a2 233)) (#s(var #f 0) #s(var a 230) #s(var b 231))) () () ())
    == (1 2 3 4) (1 2 3 4)
  bind
    #f
    == (1 2 3 4) (1 2 3 4)


Which path do you want to take? [0, 1]



In [None]:
act(0)
print(observe())

In [None]:
path = [0, 0, 0]
for p in path:
    act(p)
    print(observe())

In [702]:
path = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]

In [703]:
print(observe())

Stream:
mplus
  pause
    #s(state ((#s(var #f 0) #s(var a 235) #s(var b 236))) () () ())
    (#<procedure:appendo> appendo #s(var a 235) #s(var b 236) (1 2 3 4))
  bind
    #f
    (#<procedure:appendo> appendo #s(var a 235) #s(var b 236) (1 2 3 4))


Which path do you want to take? [0, 1]



In [704]:
for p in path[:-1]:
    act(p)
    observe()
act(path[-1])
print(observe())

Solution: ((1 2) (3 4))


Stream:
mplus
  pause
    #s(state ((#s(var b 236) 1 2 3 4) (#s(var a 235)) (#s(var #f 0) #s(var a 235) #s(var b 236))) () () ())
    conj
      == () ()
      == (1 2 3 4) (1 2 3 4)
  mplus
    pause
      #s(state ((#s(var b 236) 2 3 4) (#s(var a2 238)) (#s(var res 239) 2 3 4) (#s(var a1 237) . 1) (#s(var a 235) #s(var a1 237) . #s(var a2 238)) (#s(var #f 0) #s(var a 235) #s(var b 236))) () () ())
      conj
        == () ()
        == (2 3 4) (2 3 4)
    pause
      #s(state ((#s(var res 245) 4) (#s(var a1 243) . 3) (#s(var a2 241) #s(var a1 243) . #s(var a2 244)) (#s(var res 242) 3 4) (#s(var a1 240) . 2) (#s(var a2 238) #s(var a1 240) . #s(var a2 241)) (#s(var res 239) 2 3 4) (#s(var a1 237) . 1) (#s(var a 235) #s(var a1 237) . #s(var a2 238)) (#s(var #f 0) #s(var a 235) #s(var b 236))) () () ())
      conj
        (#<procedure:appendo> appendo #s(var a2 244) #s(var b 236) (4))
        == (3 . #s(var a2 244)) (3 . #s(var a2 244))
        == (3 4) (3 4)




In [576]:
act(0)
print(observe())
act(0)
print(observe())




In [None]:
path

In [34]:
OUT_DIM = 2
linear = nn.Linear(SEQUENCE_LENGTH * EMBEDDING_DIM, OUT_DIM, device=device)

In [35]:
action_model = init_model().to(device)
opt = init_optimizer(list(action_model.parameters()) + list(linear.parameters()))

In [37]:
α = 0.1
ɛ = 0.1
γ = 0.9

opt.zero_grad()
out = action_model(tokenize_text([make_action_prompt("")])["input_ids"])
out = linear(out.last_hidden_state.flatten(1))
out = F.tanh(out)
out = F.softmax(out, dim=1)
action = out.argmin(dim=1).item()
act(action)
observation = observe()

if observation.startswith("Success: "):
    terminated = True
    reward = calc_reward(observation)
    y = torch.ones(out.shape, device=device)
    y[0][action] = reward
    loss = ((y - out)**2).mean(1) 
    loss.backward()
    opt.step()
else:
    opt.zero_grad()
    terminated = False
    out = action_model(tokenize_text([make_action_prompt(observation)])["input_ids"])
    out = linear(out.last_hidden_state.flatten(1))
    out = F.tanh(out)
    out = F.softmax(out, dim=1)
    action = out.argmin(dim=1).item()    
    reward = out[0][action]

while not terminated:
    prev_reward = reward
    prev_action = action
    
    reward = out[0][action]
    reward = (1 - α) * prev_reward + α * γ * reward
    y = torch.ones(out.shape, device=device)
    y[0][action] = reward
    loss = ((y - out)**2).mean(1)
    loss.backward()
    opt.step()
    
    if rand() < ɛ:
        action = rand01()        

    act(action)
    observation = observe()
    
    if observation.startswith("Success: "):
        terminated = True
        reward = calc_reward(observation)
        reward = (1 - α) * prev_reward + α * γ * reward
        y = torch.ones(out.shape, device=device)
        y[0][action] = reward  
        loss = ((y - out)**2).mean(1)
        loss.backward()
        opt.step()
    else:
        out = action_model(tokenize_text([make_action_prompt(observation)])["input_ids"])
        out = linear(out.last_hidden_state.flatten(1))
        out = F.tanh(out)
        out = F.softmax(out, dim=1)
        action = out.argmin(dim=1).item()


RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
y_0