In [88]:
from insi import Probe, Probes, Cortex 
import torch 
import torch.nn as nn 
import torch.nn.functional as F 
from torchvision import datasets, transforms
from dataclasses import dataclass
from nanogpt import GPT
import pickle

In [89]:
@dataclass
class GPTConfig:
    """Hyperparameter bundle passed to the ``GPT`` constructor in this notebook.

    The meaning of each field is defined by the ``nanogpt`` module, not here;
    the per-field notes are reviewer assumptions based on the field names and
    should be confirmed against that module.
    """

    # presumably the context length in tokens -- TODO confirm
    block_size: int = 64
    # presumably the number of distinct token ids -- TODO confirm
    vocab_size: int = 65
    # presumably the number of transformer blocks -- TODO confirm
    n_layer: int = 1
    # presumably the number of attention heads per block -- TODO confirm
    n_head: int = 4
    # presumably the embedding width -- TODO confirm
    n_embd: int = 128
    # presumably toggles bias terms inside the model's layers -- TODO confirm
    bias: bool = False
    
gptconf = GPTConfig()

# Build the model from the default config, then restore the saved weights.
model = GPT(gptconf)
# map_location='cpu' lets a checkpoint that was saved from a CUDA run load on a
# machine without a GPU; load_state_dict then copies the values into the
# parameters the freshly built model already owns.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust (consider weights_only=True if the installed torch version supports it).
model.load_state_dict(torch.load('saved/3000e.pth', map_location='cpu'))

# Load the vocabulary mappings stored in meta.pkl.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only point this at a meta.pkl you produced or otherwise trust.
with open("../data/shakespeare_char/meta.pkl", 'rb') as f:
    meta = pickle.load(f)
stoi, itos = meta['stoi'], meta['itos']


def encode(s):
    """Return a list holding ``stoi[c]`` for every element ``c`` of ``s``.

    A character that is missing from ``stoi`` propagates the lookup error
    (``KeyError`` if ``stoi`` is a plain dict -- presumably it is).
    """
    return [stoi[c] for c in s]


def decode(l):
    """Return the ``itos`` entries for the ids in ``l`` joined into one string."""
    return ''.join([itos[i] for i in l])

def gen(model, start='\n', max_new_tokens=11, temperature=0.8, top_k=200):
    """Sample a continuation of ``start`` from ``model``.

    Args:
        model: object exposing ``eval()`` and
            ``generate(x, max_new_tokens=..., temperature=..., top_k=...)``.
        start: prompt text; it is run through ``encode`` and given a leading
            batch dimension of size 1 before being passed to ``generate``.
        max_new_tokens: forwarded to ``model.generate``.
        temperature: forwarded to ``model.generate``.
        top_k: forwarded to ``model.generate``.

    Returns:
        The ``(y, logits)`` pair returned by ``model.generate``.

    Side effects:
        Calls ``model.eval()`` and does not restore the previous mode.

    The defaults reproduce the values that used to be hard-coded, so
    ``gen(model)`` behaves exactly as before.
    """
    model.eval()
    prompt_ids = torch.tensor(encode(start), dtype=torch.long)
    # Indexing with None prepends the batch axis: (T,) -> (1, T).
    x = prompt_ids[None, ...]
    y, logits = model.generate(
        x,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=top_k,
    )
    return y, logits

def objective(pred):
    """Cross-entropy of the predicted logits against the token id of "Q".

    Args:
        pred: indexable whose first item is the logits tensor. The logits must
            hold exactly one score per class once flattened (the recorded run
            prints shape ``[1, 1, 65]``).

    Returns:
        The scalar loss from ``F.cross_entropy`` for the single target class.

    Raises:
        RuntimeError: from ``view`` if the logits contain more than one
            position, i.e. cannot be reshaped to ``(num_classes,)``.
    """
    logits = pred[0]
    # Debug trace; kept because the recorded output of the tuning cell shows it.
    print("logits:", logits.shape)
    # Take the class count from the tensor instead of repeating the vocab size
    # as a literal, so the objective keeps working if vocab_size changes.
    num_classes = logits.size(-1)
    # Create the target on the logits' device so the loss also works when the
    # model has been moved off the CPU.
    target = torch.tensor(encode("Q")[0], dtype=torch.long, device=logits.device)
    return F.cross_entropy(logits.view(num_classes), target)

# Build 100 random token-id sequences (1-D integer tensors of length
# gptconf.block_size with values in [0, gptconf.vocab_size)); they are passed
# to cortex.tune further down.
# A dedicated, seeded generator makes the inputs identical across kernel
# restarts without touching torch's global RNG state.
# NOTE: `input` shadows the builtin; the name is kept because a later cell
# refers to it.
input_rng = torch.Generator().manual_seed(0)
input = [
    torch.randint(gptconf.vocab_size, (gptconf.block_size,), generator=input_rng)
    for _ in range(100)
]

number of parameters: 0.21M


In [90]:
# One Probe(discrete=False) per integer key 0 .. num_probes - 1.
num_probes = 64
probes = dict((idx, Probe(discrete=False)) for idx in range(num_probes))

# Wrap the dict in a Probes collection, then build the Cortex from the
# collection, the loaded model and the objective function.
probes_collection = Probes(probes)
cortex = Cortex(probes_collection, model, objective)

In [91]:
# Run a single tuning epoch over the random inputs at learning rate 0.1;
# "emb" is passed as the first_layer argument.
cortex.tune(
    epochs=1,
    lr=0.1,
    input=input,
    first_layer="emb",
)

logits: torch.Size([1, 1, 65])
0 tensor(4) torch.Size([64]) tensor(5) 65
1 tensor(56) torch.Size([64]) tensor(57) 65
2 tensor(54) torch.Size([64]) tensor(55) 65
3 tensor(26) torch.Size([64]) tensor(27) 65
4 tensor(2) torch.Size([64]) tensor(3) 65
5 tensor(31) torch.Size([64]) tensor(32) 65
6 tensor(21) torch.Size([64]) tensor(22) 65
7 tensor(16) torch.Size([64]) tensor(17) 65
8 tensor(11) torch.Size([64]) tensor(12) 65
9 tensor(25) torch.Size([64]) tensor(26) 65
10 tensor(32) torch.Size([64]) tensor(33) 65
11 tensor(6) torch.Size([64]) tensor(7) 65
12 tensor(29) torch.Size([64]) tensor(30) 65
13 tensor(55) torch.Size([64]) tensor(56) 65
14 tensor(3) torch.Size([64]) tensor(4) 65
15 tensor(32) torch.Size([64]) tensor(33) 65
16 tensor(58) torch.Size([64]) tensor(59) 65
17 tensor(63) torch.Size([64]) tensor(64) 65
18 tensor(49) torch.Size([64]) tensor(50) 65
19 tensor(41) torch.Size([64]) tensor(42) 65
20 tensor(33) torch.Size([64]) tensor(34) 65
21 tensor(38) torch.Size([64]) tensor(39) 