In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell is executed so edits to local source files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the pretrained GPT-2 medium checkpoint and its matching tokenizer.
# `device_map` is a model-placement option; a tokenizer has no weights to
# place, so it is no longer passed to `AutoTokenizer.from_pretrained`.
gpt_model = AutoModelForCausalLM.from_pretrained("gpt2-medium")
tokenizer = AutoTokenizer.from_pretrained("gpt2-medium")

In [602]:
from latent import *
# Wrap the already-loaded GPT-2 model and tokenizer in the project's LatentLM
# so later cells can call `model(...)` directly on prompts.
model = LatentLM("gpt2-medium", model=gpt_model, tokenizer=tokenizer)

In [555]:
# Run a batch of two prompts through the model and select layer -1 of the
# result (presumably the last layer -- confirm against LatentLM).
r = model(["How", "What are you up to"]).layer(-1)
# Bare expression: the notebook displays the object's repr as the cell output.
r

<gpt2-medium.layer(-1).LatentTensor[2, 6] <-> 'How[TOK]'
                                              'What are you up to[TOK]'

In [564]:
# Print the completion of `r`; in the recorded output each prompt has gained
# a token before the trailing [TOK] placeholder.
print(r.complete())

<gpt2-medium.LatentTensor[25, 2, 7] <-> 'How did[TOK]'
                                        'What are you up to?[TOK]'



In [567]:
# Same call as the previous cell, shown through the cell's own display
# instead of print().
r.complete()

<gpt2-medium.LatentTensor[25, 2, 7] <-> 'How did[TOK]'
                                        'What are you up to?[TOK]'

In [590]:
# Distribution for `r`; the recorded output is a tensor with one row per
# prompt whose grad_fn shows it comes from a log-softmax.
r.distribution()

tensor([[ -5.6751,  -8.6013, -13.2165,  ..., -18.4128, -12.4559, -10.6140],
        [ -5.5568, -10.8951, -11.8829,  ..., -19.0959, -19.4955,  -9.3850]],
       grad_fn=<LogSoftmaxBackward0>)

In [592]:
# `r` was already created with .layer(-1); selecting the same layer again
# displays the same repr as the cell that defined `r`.
r.layer(-1)

<gpt2-medium.layer(-1).LatentTensor[2, 6] <-> 'How[TOK]'
                                              'What are you up to[TOK]'

In [593]:
from latent import *

class Adapter(LatentModule):
    """Single linear layer applied to a latent representation.

    Input and output widths are equal, so the adapted latent keeps the same
    last-dimension size as the latent that was passed in.

    Args:
        hidden_size: Width of the latent vectors. The default of 1024 keeps
            the previously hard-coded value (presumably the hidden size of
            the gpt2-medium model used in this notebook -- confirm if the
            base model changes).
    """

    def __init__(self, hidden_size: int = 1024):
        super().__init__()
        # Square projection: hidden_size -> hidden_size.
        self.lin1 = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        """Return ``x`` projected through the linear layer."""
        return self.lin1(x)
    
# Instantiate the adapter and move it to the device reported by the wrapped
# model so both operate on the same device.
adapter = Adapter().to(model.device)

def color_match(color: str):
    """Compare a soft-prompted run against a plain-text run for one color.

    The statement "The boat is {color}." is run through the model and the
    latent at position -2 is passed through ``adapter``. That adapted latent
    is then used in place of the statement text in front of the question
    prompt. The same question is also run with the statement given as
    ordinary text.

    Args:
        color: Color name substituted into the prompts.

    Returns:
        The result of ``==`` between position -1 of the soft-prompted run and
        position -1 of the plain-text run (presumably the objective consumed
        by ``adapter.fit`` -- confirm against LatentModule).
    """
    # Soft prompt: one adapted latent taken from the statement sentence.
    statement = model(f"The boat is {color}.", name='soft_color')
    adapted = adapter(statement[-2])
    soft_run = model(adapted, "The color of the boat is", name='soft_color')

    # Hard prompt: the statement spelled out as plain text.
    hard_run = model(f"The boat is {color}. The color of the boat is", name='hard_color')

    return soft_run[-1] == hard_run[-1]
# Quick check that the function runs end to end for a single color.
color_match("red")

In [608]:
# Baseline: complete the question prompt when the color statement is given
# as plain text. In the recorded output the completion repeats the color.
color = "pink"
model([f"The boat is {color}. The color of the boat is"], name='hard_color').complete()

<gpt2-medium.LatentTensor[25, 1, 13] <-> 'The boat is pink. The color of the boat is pink[TOK]'

In [610]:
# Same steps as the soft-prompt half of color_match, run inline for one color
# and then completed so the generated text can be inspected.
color = "pink"
# Latent at position -2 of the statement sentence.
soft_color = model(f"The boat is {color}.", name='soft_color')[-2]
# Pass that latent through the adapter.
soft_color = adapter(soft_color)
# Use the adapted latent in place of the statement text, then complete.
x = model(soft_color, "The color of the boat is", name='soft_color').complete()
x

In [8]:
# Read one color name per line from colors.txt.
# A comprehension keeps the loop variable out of the notebook namespace
# (other cells set their own global `color`), and blank lines are skipped so
# no empty-string entries end up in the list. The encoding is given
# explicitly so the result does not depend on the platform default.
with open("colors.txt", encoding="utf-8") as f:
    colors = [line.strip() for line in f if line.strip()]
print(colors)

['alizarin', 'amaranth', 'amber', 'amethyst', 'apricot', 'aqua', 'aquamarine', 'asparagus', 'auburn', 'azure', 'beige', 'bistre', 'black', 'blue', 'blue-green', 'blue-violet', 'bondi-blue', 'brass', 'bronze', 'brown', 'buff', 'burgundy', 'camouflage-green', 'caput-mortuum', 'cardinal', 'carmine', 'carrot-orange', 'celadon', 'cerise', 'cerulean', 'champagne', 'charcoal', 'chartreuse', 'cherry-blossom-pink', 'chestnut', 'chocolate', 'cinnabar', 'cinnamon', 'cobalt', 'copper', 'coral', 'corn', 'cornflower', 'cream', 'crimson', 'cyan', 'dandelion', 'denim', 'ecru', 'emerald', 'eggplant', 'falu-red', 'fern-green', 'firebrick', 'flax', 'forest-green', 'french-rose', 'fuchsia', 'gamboge', 'gold', 'goldenrod', 'green', 'grey', 'han-purple', 'harlequin', 'heliotrope', 'hollywood-cerise', 'indigo', 'ivory', 'jade', 'kelly-green', 'khaki', 'lavender', 'lawn-green', 'lemon', 'lemon-chiffon', 'lilac', 'lime', 'lime-green', 'linen', 'magenta', 'magnolia', 'malachite', 'maroon', 'mauve', 'midnight-bl

In [None]:
# The first 100 colors train the adapter; everything after that is held out
# and reported as test accuracy during fitting.
train = colors[:100]
test = colors[100:]
adapter.fit(
    color_match,
    train,
    epochs=100,
    lr=1e-4,
    loss_fct="mse",
    test=test,
)

Epoch 1/100, loss=0.6019:   1%|█                                                                                                        | 100/10000 [01:19<2:11:38,  1.25it/s]

Loss: 0.6019 Train accuracy: 15.00% Test accuracy: 24.24%


Epoch 2/100, loss=0.4452:   2%|██                                                                                                       | 200/10000 [03:09<2:09:57,  1.26it/s]

Loss: 0.4452 Train accuracy: 28.00% Test accuracy: 37.88%


Epoch 3/100, loss=0.4221:   3%|███▏                                                                                                     | 300/10000 [04:59<2:08:37,  1.26it/s]

Loss: 0.4221 Train accuracy: 30.00% Test accuracy: 37.88%


Epoch 4/100, loss=0.3378:   4%|████▏                                                                                                    | 400/10000 [06:50<2:07:45,  1.25it/s]

Loss: 0.3378 Train accuracy: 30.00% Test accuracy: 36.36%


Epoch 5/100, loss=0.2669:   5%|█████▎                                                                                                   | 500/10000 [08:40<2:05:47,  1.26it/s]

Loss: 0.2669 Train accuracy: 31.00% Test accuracy: 37.88%


Epoch 6/100, loss=0.2517:   6%|██████▎                                                                                                  | 600/10000 [10:31<2:10:20,  1.20it/s]

Loss: 0.2517 Train accuracy: 30.00% Test accuracy: 39.39%


Epoch 7/100, loss=0.2809:   7%|███████▎                                                                                                 | 700/10000 [12:21<2:04:21,  1.25it/s]

Loss: 0.2809 Train accuracy: 30.00% Test accuracy: 37.88%


Epoch 8/100, loss=0.2162:   8%|████████▍                                                                                                | 800/10000 [14:12<2:02:41,  1.25it/s]

Loss: 0.2162 Train accuracy: 30.00% Test accuracy: 37.88%


Epoch 9/100, loss=0.1922:   9%|█████████▍                                                                                               | 900/10000 [16:02<2:00:30,  1.26it/s]

Loss: 0.1922 Train accuracy: 30.00% Test accuracy: 39.39%


Epoch 10/100, loss=0.1810:  10%|██████████▎                                                                                            | 1000/10000 [17:52<1:59:26,  1.26it/s]

Loss: 0.1810 Train accuracy: 29.00% Test accuracy: 37.88%


Epoch 11/100, loss=0.1881:  11%|███████████▎                                                                                           | 1100/10000 [19:42<1:58:09,  1.26it/s]

Loss: 0.1881 Train accuracy: 30.00% Test accuracy: 37.88%


Epoch 12/100, loss=0.1428:  12%|████████████▎                                                                                          | 1200/10000 [21:32<1:57:37,  1.25it/s]