In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the pretrained "gpt2-medium" causal language model and its tokenizer.
gpt_model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
# `device_map` is a model-placement option and has nothing to place on a
# tokenizer, so it is not passed here.
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")

In [3]:
from latent import *
# Wrap the already-loaded model and tokenizer in the project's LatentLM interface.
model = LatentLM(
    "gpt2-medium",
    model=gpt_model,
    tokenizer=tokenizer,
)

In [4]:
# Run a batch of two prompts of different lengths through the model. The
# recorded repr reports a LatentTensor of shape [25, 2, 6]
# (presumably layers x batch x tokens — TODO confirm).
r = model(["How", "What are you up to"])
r

<gpt2-medium.LatentTensor[25, 2, 6] <-> 'How[TOK]'
                                        'What are you up to[TOK]'

In [5]:
# Select layer -1; per the recorded repr this drops the leading dimension
# ([25, 2, 6] -> [2, 6]).
r.layer(-1)

<gpt2-medium.layer(-1).LatentTensor[2, 6] <-> 'How[TOK]'
                                              'What are you up to[TOK]'

In [6]:
# Display `r` again: the recorded shape is still [25, 2, 6] after `.layer(-1)`.
r

<gpt2-medium.LatentTensor[25, 2, 6] <-> 'How[TOK]'
                                        'What are you up to[TOK]'

In [7]:
# Print the completion of `r`; the recorded output shows one extra position
# ([25, 2, 7]) and extended text for each prompt.
print(r.complete())

<gpt2-medium.LatentTensor[25, 2, 7] <-> 'How did[TOK]'
                                        'What are you up to?[TOK]'



In [8]:
# Same completion, shown as the cell's result value instead of being printed.
r.complete()

<gpt2-medium.LatentTensor[25, 2, 7] <-> 'How did[TOK]'
                                        'What are you up to?[TOK]'

In [47]:
# The recorded output is a tensor with grad_fn=<LogSoftmaxBackward0> and one row
# per prompt (presumably next-token log-probabilities — TODO confirm).
r.distribution()

tensor([[ -5.6751,  -8.6013, -13.2165,  ..., -18.4128, -12.4559, -10.6140],
        [ -5.5568, -10.8951, -11.8829,  ..., -19.0959, -19.4955,  -9.3850]],
       grad_fn=<LogSoftmaxBackward0>)

In [48]:
# Layer -1 view of `r`; the recorded repr ([2, 6], original prompt text) shows
# that `r` still holds the uncompleted prompts.
r.layer(-1)

<gpt2-medium.layer(-1).LatentTensor[2, 6] <-> 'How[TOK]'
                                              'What are you up to[TOK]'

In [130]:
# "Hard" prompt: the color is written out in the prompt text itself.
color = "red"
y = model(f"The boat is {color}. The color of the boat is", name='hard_color')
y


<gpt2-medium.LatentTensor[25, 1, 12] <-> 'The boat is red. The color of the boat is[TOK]'

In [131]:
# "Soft" prompt, step 1: run a templated prompt over several colors, take index
# -2 of the result, and pass it through the adapter.
# NOTE(review): `adapter` is only created in the `Adapter` cell further down the
# notebook, so this cell fails on a fresh top-to-bottom run.
colors = ["red", "blue"]
soft_color = model([["The boat is ", colors, "."]], name='soft_color')[-2]
soft_color = adapter(soft_color)

In [132]:
# Recorded shape is [1, 2, 1, 1024]; the last dimension matches the adapter's
# Linear(1024, 1024).
soft_color.shape

[1, 2, 1, 1024]

In [133]:
# "Soft" prompt, step 2: place the adapted latent in front of the question text
# in place of the literal sentence.
x = model([[soft_color, "The color of the boat is"]], name='soft_color')
x

<gpt2-medium.LatentTensor[25, 2, 7] <-> '{Adapter(soft_color)}The color of the boat is[TOK]'
                                        '{Adapter(soft_color)}The color of the boat is[TOK]'

In [9]:
from latent import *

class Adapter(LatentModule):
    """Single linear layer applied to latent vectors.

    Args:
        hidden_size: Width of both the input and the output vectors. Defaults
            to 1024, the value that was previously hard-coded.
    """

    def __init__(self, hidden_size: int = 1024):
        super().__init__()
        self.lin1 = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Return ``x`` passed through the linear layer ``lin1``."""
        return self.lin1(x)
    
# Instantiate the adapter and move it to the device reported by `model.device`.
adapter = Adapter().to(model.device)

def color_match(colors: List[str]):
    """Build the objective comparing the soft-prompt and hard-prompt latents.

    The soft path runs "The boat is <color>." through the model, passes index -2
    of the result through the global ``adapter``, and feeds that in front of the
    question. The hard path runs the full sentence plus question as plain text.

    Args:
        colors: Color names substituted into the prompt templates.

    Returns:
        The result of ``==`` between index -1 of layer -1 of the soft-path and
        hard-path outputs.
    """
    # Soft path: encode the statement, adapt it, then ask the question.
    adapted = adapter(model([["The boat is ", colors, "."]], name='soft_color')[-2])
    soft_out = model([[adapted, "The color of the boat is"]], name='soft_color')

    # Hard path: statement and question as ordinary text.
    hard_out = model([["The boat is ", colors, ". The color of the boat is"]], name='hard_color')

    soft_last = soft_out.layer(-1)[-1]
    hard_last = hard_out.layer(-1)[-1]
    return soft_last == hard_last
# Smoke test with a single color; the recorded result is a LatentEqualityObjective.
color_match(["red"])

LatentEqualityObjective:
 - <gpt2-medium.layer(-1).LatentTensor[1, 1] <-> ' is[TOK]'

 - <gpt2-medium.layer(-1).LatentTensor[1, 1] <-> ' is[TOK]'

In [135]:
# Hard-prompt baseline for one color; the recorded completion appends " pink".
color = "pink"
model([f"The boat is {color}. The color of the boat is"], name='hard_color').complete()

Epoch:   0%|                                                                                                                            | 0/10000 [18:03<?, ?it/s]
Epoch:   0%|                                                                                                                            | 0/10000 [17:29<?, ?it/s]
Epoch 1/100, loss=0.3156:   0%|                                                                                            | 3/10000 [16:40<925:41:42, 333.35s/it]


<gpt2-medium.LatentTensor[25, 1, 13] <-> 'The boat is pink. The color of the boat is pink[TOK]'

In [136]:
# Soft-prompt counterpart of the hard-prompt baseline: adapt the statement's
# latent and complete the question from it. The recorded completion appends
# " the" rather than the color.
# NOTE(review): this indexes [-1] on a plain-string prompt, whereas
# `color_match` indexes [-2] on a templated prompt — confirm both select the
# intended position.
color = "pink"
soft_color = model(f"The boat is {color}.", name='soft_color')[-1]
soft_color = adapter(soft_color)
x = model([[soft_color, "The color of the boat is"]], name='soft_color')
x.complete()

<gpt2-medium.LatentTensor[25, 1, 15] <-> '{Adapter(soft_color)}The color of the boat is the[TOK]'

In [11]:
# Read one color name per line from the working list.
# A comprehension is used so no loop variable leaks into the notebook-global
# `color` that other cells assign and read; blank lines are skipped so they do
# not become empty-string "colors".
with open("working_colors.txt") as f:
    colors = [name for name in map(str.strip, f) if name]
print(colors)

['aqua', 'black', 'blue', 'brown', 'cardinal', 'champagne', 'charcoal', 'chocolate', 'cinnamon', 'coral', 'corn', 'cream', 'cyan', 'denim', 'ecru', 'emerald', 'eggplant', 'gold', 'goldenrod', 'green', 'grey', 'indigo', 'ivory', 'khaki', 'lime', 'mustard', 'olive', 'orange', 'peach', 'pear', 'pink', 'puce', 'pumpkin', 'purple', 'red', 'rose', 'salmon', 'silver', 'smalt', 'tomato', 'violet', 'white', 'yellow']


In [17]:
# The first 30 colors are used for fitting; the rest are passed as `test`.
train = colors[:30]
test = colors[30:]
adapter.fit(
    color_match,
    train,
    epochs=100,
    lr=1e-4,
    loss_fct="crossentropy",
    test=test,
)



Epoch 0/100:   0%|                                                                                                                        | 0/375 [00:00<?, ?it/s][A[A

Epoch 1/100, loss=0.1251:   0%|                                                                                                           | 0/375 [00:03<?, ?it/s][A[A

Epoch 1/100, loss=0.1251:   0%|▎                                                                                                  | 1/375 [00:03<22:33,  3.62s/it][A[A

Epoch 1/100, loss=0.4098:   0%|▎                                                                                                  | 1/375 [00:06<22:33,  3.62s/it][A[A

Epoch 1/100, loss=0.4098:   1%|▌                                                                                                  | 2/375 [00:06<19:44,  3.18s/it][A[A

Epoch 1/100, loss=0.8521:   1%|▌                                                                                                  | 2/375 [00:09<19:

Loss: 1.7388 Train accuracy: 56.67% Test accuracy: 84.62%




Epoch 2/100, loss=1.6750:   1%|█                                                                                                  | 4/375 [00:20<17:34,  2.84s/it][A[A

Epoch 2/100, loss=1.6750:   1%|█▎                                                                                                 | 5/375 [00:20<29:06,  4.72s/it][A[A

Epoch 2/100, loss=1.7953:   1%|█▎                                                                                                 | 5/375 [00:22<29:06,  4.72s/it][A[A

Epoch 2/100, loss=1.7953:   2%|█▌                                                                                                 | 6/375 [00:22<25:05,  4.08s/it][A[A

Epoch 2/100, loss=2.0655:   2%|█▌                                                                                                 | 6/375 [00:25<25:05,  4.08s/it][A[A

Epoch 2/100, loss=2.0655:   2%|█▊                                                                                                 | 7/375 [00:25<22:

Loss: 2.7532 Train accuracy: 56.67% Test accuracy: 84.62%




Epoch 3/100, loss=2.5834:   2%|██                                                                                                 | 8/375 [00:37<21:18,  3.48s/it][A[A

Epoch 3/100, loss=2.5834:   2%|██▍                                                                                                | 9/375 [00:37<30:39,  5.03s/it][A[A

Epoch 3/100, loss=2.6034:   2%|██▍                                                                                                | 9/375 [00:40<30:39,  5.03s/it][A[A

Epoch 3/100, loss=2.6034:   3%|██▌                                                                                               | 10/375 [00:40<27:02,  4.45s/it][A[A

Epoch 3/100, loss=2.7650:   3%|██▌                                                                                               | 10/375 [00:43<27:02,  4.45s/it][A[A

Epoch 3/100, loss=2.7650:   3%|██▊                                                                                               | 11/375 [00:43<24:

Loss: 3.3088 Train accuracy: 56.67% Test accuracy: 84.62%




Epoch 4/100, loss=3.0728:   3%|███▏                                                                                              | 12/375 [00:54<21:28,  3.55s/it][A[A

Epoch 4/100, loss=3.0728:   3%|███▍                                                                                              | 13/375 [00:54<30:36,  5.07s/it][A[A

Epoch 4/100, loss=3.0329:   3%|███▍                                                                                              | 13/375 [00:57<30:36,  5.07s/it][A[A

Epoch 4/100, loss=3.0329:   4%|███▋                                                                                              | 14/375 [00:57<26:31,  4.41s/it][A[A

Epoch 4/100, loss=3.1321:   4%|███▋                                                                                              | 14/375 [01:00<26:31,  4.41s/it][A[A

Epoch 4/100, loss=3.1321:   4%|███▉                                                                                              | 15/375 [01:00<23:

Loss: 3.5896 Train accuracy: 56.67% Test accuracy: 84.62%




Epoch 5/100, loss=3.3207:   4%|████▏                                                                                             | 16/375 [01:12<21:51,  3.65s/it][A[A

Epoch 5/100, loss=3.3207:   5%|████▍                                                                                             | 17/375 [01:12<31:35,  5.29s/it][A[A

Epoch 5/100, loss=3.2508:   5%|████▍                                                                                             | 17/375 [01:15<31:35,  5.29s/it][A[A

Epoch 5/100, loss=3.2508:   5%|████▋                                                                                             | 18/375 [01:15<27:22,  4.60s/it][A[A

Epoch 5/100, loss=3.3178:   5%|████▋                                                                                             | 18/375 [01:18<27:22,  4.60s/it][A[A

Epoch 5/100, loss=3.3178:   5%|████▉                                                                                             | 19/375 [01:18<24:

Loss: 3.7284 Train accuracy: 56.67% Test accuracy: 84.62%




Epoch 6/100, loss=3.4413:   5%|█████▏                                                                                            | 20/375 [01:30<21:46,  3.68s/it][A[A

Epoch 6/100, loss=3.4413:   6%|█████▍                                                                                            | 21/375 [01:30<31:23,  5.32s/it][A[A

Epoch 6/100, loss=3.3534:   6%|█████▍                                                                                            | 21/375 [01:33<31:23,  5.32s/it][A[A

Epoch 6/100, loss=3.3534:   6%|█████▋                                                                                            | 22/375 [01:33<27:34,  4.69s/it][A[A

Epoch 6/100, loss=3.4031:   6%|█████▋                                                                                            | 22/375 [01:36<27:34,  4.69s/it][A[A

Epoch 6/100, loss=3.4031:   6%|██████                                                                                            | 23/375 [01:36<25:

Loss: 3.7791 Train accuracy: 56.67% Test accuracy: 84.62%




Epoch 7/100, loss=3.4856:   6%|██████▎                                                                                           | 24/375 [01:48<22:16,  3.81s/it][A[A

Epoch 7/100, loss=3.4856:   7%|██████▌                                                                                           | 25/375 [01:48<31:14,  5.35s/it][A[A

KeyboardInterrupt: 