In [1]:
from nnsight import LanguageModel

# When True, traces are executed remotely (this flag is passed as `remote=` to `model.trace`).
REMOTE = True

# GPT-2 XL checkpoint, wrapped by nnsight's LanguageModel so its internals can be traced.
model = LanguageModel("openai-community/gpt2-xl")

In [2]:
from utils import get_request

# Logit-lens readout: push a hidden state through the model's final layer norm
# (`transformer.ln_f`) and unembedding (`lm_head`) to get vocabulary logits.
decoder = lambda x : model.lm_head(model.transformer.ln_f(x))


# Record the intervention graph for this prompt; `remote=REMOTE` selects remote execution.
with model.trace("The Eiffel Tower is in the city of", remote=REMOTE) as tracer:
    # layer index -> saved proxies for that layer's top-1 probabilities and token ids.
    results = {}
    
    for i, layer in enumerate(model.transformer.h):
        # Element 0 of the block's output tuple is taken as the hidden states
        # (assumed (batch, seq, hidden) — TODO confirm).
        output = layer.output[0]

        decoded = decoder(output)

        # softmax over the last axis, then max over it: (values, indices) = top-1
        # probability and token id per position.
        # NOTE(review): this `.save()` also ships the whole (probs, tokens) tuple back
        # (it appears as its own saved entry in the downloaded result), duplicating the
        # `.tolist()` saves below — likely removable to shrink the payload.
        probs, tokens = decoded.softmax(-1).max(-1).save()

        results[i] = {
            "probs" : probs.tolist().save(),
            "tokens" : tokens.tolist().save()
        }

        # NOTE(review): this unconditional `break` stops after layer 0, so `results`
        # holds a single layer; remove it to collect every layer.
        break

    # Build a request object from the traced graph via the project helper in `utils`;
    # later cells serialize it and submit it manually.
    request = get_request(tracer)

664e8cc0002c2d9aa91d0eea - RECEIVED: Your job has been received and is waiting approval.
664e8cc0002c2d9aa91d0eea - APPROVED: Your job was approved and is waiting to be run.
664e8cc0002c2d9aa91d0eea - COMPLETED: Your job has been completed.


Downloading result: 100%|██████████| 1.75k/1.75k [00:00<00:00, 32.7MB/s]


In [5]:
import json

# Persist the serialized request as a real JSON document.
# `model_dump_json()` already returns a JSON *string*; passing that string to
# `json.dump` double-encoded it (the file held one quoted string) and made `indent`
# a no-op. Dump a JSON-compatible dict instead so `indent` applies.
with open("eiffel_tower.json", "w", encoding="utf-8") as f:
    json.dump(request.model_dump(mode="json"), f, indent=4)

In [21]:
import io

import requests
import socketio
import torch
from nnsight import CONFIG
from tqdm import tqdm

# Submit `request` over HTTP while listening on a websocket for status updates, then
# download and deserialize the result once the job reports COMPLETED.
with socketio.SimpleClient(reconnection_attempts=10) as sio:

    sio.connect(
        f"wss://{CONFIG.API.HOST}",
        socketio_path="/ws/socket.io",
        transports=["websocket"],
        wait_timeout=10,
    )

    # Give request session ID so server knows to respond via websockets to us.
    request.session_id = sio.sid

    url = f"https://{CONFIG.API.HOST}/request"

    response = requests.post(
        url,
        json=request.model_dump(exclude={"id", "received"}),
        headers={"ndif-api-key": CONFIG.API.APIKEY},
    )
    # Fail fast: a rejected submission would otherwise leave the loop below waiting forever.
    response.raise_for_status()

    # Each websocket message is a list; status updates are the dict entries in it.
    job_id = None
    while job_id is None:
        message = sio.receive()

        for payload in message:
            if not isinstance(payload, dict):
                continue
            print(payload)
            status = payload.get("status")
            if status == "COMPLETED":
                job_id = payload["id"]
                break
            # NOTE(review): assumes failures are reported with status "ERROR" — confirm
            # against the server API. Without this check a failed job hangs the loop.
            if status == "ERROR":
                raise RuntimeError(
                    f"Job {payload.get('id')} failed: {payload.get('description')}"
                )

result_bytes = io.BytesIO()

with requests.get(f"https://{CONFIG.API.HOST}/result/{job_id}", stream=True) as stream:
    stream.raise_for_status()

    # Total size of incoming data, if the server reports it (tqdm accepts total=None).
    content_length = stream.headers.get("Content-Length")
    total_size = float(content_length) if content_length is not None else None

    with tqdm(
        total=total_size,
        unit="B",
        unit_scale=True,
        desc="Downloading result",
    ) as progress_bar:
        # chunk_size=None so server determines chunk size.
        for data in stream.iter_content(chunk_size=None):
            progress_bar.update(len(data))
            result_bytes.write(data)

# Move cursor to beginning of bytes.
result_bytes.seek(0)

# NOTE(security): `torch.load` may unpickle arbitrary objects (depending on the torch
# version's `weights_only` default); only load results from a trusted server.
out = torch.load(result_bytes, map_location="cpu")

{'id': '664e8ecbb6a68fc7ea14c0c9', 'status': 'APPROVED', 'description': 'Your job was approved and is waiting to be run.', 'received': '2024-05-23 00:33:15.477050', 'session_id': 'ebBzUjECyKK05lshABpX'}
{'id': '664e8ecbb6a68fc7ea14c0c9', 'status': 'COMPLETED', 'description': 'Your job has been completed.', 'received': '2024-05-23 00:33:15.477050', 'session_id': 'ebBzUjECyKK05lshABpX'}


Downloading result: 100%|██████████| 1.75k/1.75k [00:00<00:00, 16.4MB/s]


In [22]:
# Display the deserialized result: the job id plus every value saved during the trace,
# keyed by internal proxy names (e.g. 'proxy_call_3').
out

{'id': '664e8ecbb6a68fc7ea14c0c9',
 'saves': {'proxy_call_3': (tensor([[0.0017, 0.0533, 0.0912, 0.0529, 0.4096, 0.0132, 0.0217, 0.0253, 0.6327,
            0.0304]], requires_grad=True),
   tensor([[10562, 26483, 22402,  1098,  8765,   783,  1502,   976,  6745,   262]])),
  'proxy_call_4': [[0.0016897746827453375,
    0.05329510569572449,
    0.09117957204580307,
    0.052862223237752914,
    0.40957528352737427,
    0.013241843320429325,
    0.021667778491973877,
    0.025325456634163857,
    0.6326640844345093,
    0.030422642827033997]],
  'proxy_call_5': [[10562,
    26483,
    22402,
    1098,
    8765,
    783,
    1502,
    976,
    6745,
    262]]}}

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import numpy as np


def _unwrap(saved):
    """Return the concrete payload of a saved nnsight proxy, or `saved` itself if it is plain data."""
    # NOTE(review): assumes saved proxies expose their result via `.value` after the
    # trace (nnsight 0.2-style API) — confirm. Plain lists pass through unchanged.
    return getattr(saved, "value", saved)


# Probabilities lie in [0, 1]; map them onto a diverging colormap.
norm = mcolors.Normalize(vmin=0, vmax=1)
# `plt.cm.get_cmap` (i.e. `matplotlib.cm.get_cmap`) was removed in Matplotlib 3.9;
# `pyplot.get_cmap` remains available.
cmap = plt.get_cmap("coolwarm")

# Build one table row per layer: decoded top-1 token text, coloured by its probability.
decoded_tokens = {}
row_labels = []
cell_text = []
cell_colours = []
for layer_idx, layer_result in results.items():
    # [0] selects the first (only) prompt in the batch.
    token_ids = _unwrap(layer_result["tokens"])[0]
    token_probs = _unwrap(layer_result["probs"])[0]

    decoded_tokens[layer_idx] = [model.tokenizer.decode([token]) for token in token_ids]

    pairs = list(zip(decoded_tokens[layer_idx], token_probs))
    row_labels.append(f"layer {layer_idx}")
    cell_text.append([token for token, _ in pairs])
    cell_colours.append([cmap(norm(prob)) for _, prob in pairs])

fig, ax = plt.subplots()

table = ax.table(
    cellText=cell_text,
    cellColours=cell_colours,
    rowLabels=row_labels,
    loc="center",
    cellLoc="center",
)

# A table needs no axes.
ax.axis("off")
ax.set_title("Top predicted token by layer (colour = probability)")

# Adjust layout to make the table fit better
plt.subplots_adjust(left=0.2, right=0.8, top=0.8, bottom=0.2)

plt.show()