In [1]:
import os
import torch
import mlflow
import mlflow_env
from gpt import GPTLanguageModel
from autoencoder import Autoencoder
from gpt_params import tokenizer
from torch.nn import functional as F
from gpt_params import transformer_experiment
from autoencoder_params import autoencoder_experiment, device


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# MLflow run ids that select which logged checkpoints this notebook loads.
transformer_run_id = "1631cdf63904427fb5833afa9372b625"
autoencoder_run_id = "1ab8c069b70c474da0efa51bb993dae0"

# Restore the language model onto `device` and switch it to evaluation mode.
gpt = GPTLanguageModel.load_from_mlflow(
    transformer_experiment,
    transformer_run_id,
    device,
)
gpt.eval()

# Restore the autoencoder onto the same device.
autoencoder = Autoencoder.load_from_mlflow(
    autoencoder_experiment,
    autoencoder_run_id,
    device,
)

# Kept as the final expression so the cell displays the module summary.
autoencoder.eval()

Downloading artifacts: 100%|██████████| 1/1 [00:01<00:00,  1.70s/it]
  model = torch.load(local_model_path, map_location=device)



GPT loaded from MLflow:

Metrics:
  cross_entropy_loss_train: 4.2946
  cross_entropy_loss_eval: 4.7007
  interval_time: 5.6415

Parameters:
  Dataset: wikitext-wikitext-103-v1
  subsets_max_size: 20
  num_training_subsets: 50
  epochs: 1
  learning_rate: 0.001
  batch_size: 768
  optimizer: AdamW
  context_length: 40
  embedding_dim: 128
  num_of_attention_heads: 8
  num_of_blocks: 1
  vocab_size: 50257
  dropout: 0


Downloading artifacts: 100%|██████████| 1/1 [00:00<00:00,  2.84it/s]


Autoencoder loaded from MLflow:

Metrics:
  loss_train: 0.0104
  loss_eval: 0.0105
  recon_loss_eval: 0.0025
  norm_loss_eval: 0.008
  acts_eval: 278.0312

Parameters:
  lasso_lambda: 0.0001
  learning_rate: 0.0001
  num_epochs: 3
  batch_size: 64
  transformer_model_run_id: 1631cdf63904427fb5833afa9372b625
  sparse_dimension_factor: 16
  num_training_subsets: 30
  subsets_max_size: 20



  model = torch.load(local_model_path, map_location=device)


Autoencoder(
  (encoder): Linear(in_features=128, out_features=2048, bias=True)
  (decoder): Linear(in_features=2048, out_features=128, bias=True)
  (relu): ReLU()
)

Forcing a neuron

In [5]:
# Baseline: sample a continuation of the prompt from the unmodified model.
prompt_ids = tokenizer.encode("The united states of america")

# `[None]` prepends a size-1 axis, matching `.unsqueeze(0)`.
idx = torch.tensor(prompt_ids, dtype=torch.long)[None].to(device)
out = gpt.generate(idx, 100)

generated_ids = out.squeeze().tolist()
print(tokenizer.decode(generated_ids))

The united states of america in its Norfolk Iowa in May 1921 . He also ended it in 527 , leaving Benton expedised the territory . Gilligan considered making a major corridor to protect the purpose of which he was called for each defence , which he was allowed to coach Jeff Johnson in a restructuring centre coaching youth to address that folded and a fortune despite feelings cheered high for this in charge for winning time . A yardnings [ Gatmore was alumni and coach and future captain TV Guide . Eisenhower bought a youth to meet L


In [3]:
def biased_generate(idx, max_new_tokens, neuron_idx, activation):
    """Sample tokens from ``gpt`` while pinning one autoencoder unit to a fixed value.

    Each step crops ``idx`` to its last ``gpt.context_length`` positions, calls
    ``gpt.embed``, passes the result through ``autoencoder``, overwrites unit
    ``neuron_idx`` of the encoded tensor with ``activation`` at every position,
    decodes it with ``autoencoder.decode`` and feeds that to ``gpt.unembed``.
    The next token is drawn with ``torch.multinomial`` from the softmax of the
    logits at the last position, so the output varies between calls unless the
    torch RNG is seeded by the caller.

    The loop runs under ``torch.no_grad()``: this is inference only, and
    autograd could otherwise record the operations of every generated step.

    Args:
        idx: Tensor of token ids. It is sliced as ``idx[:, -k:]`` and extended
            along dim 1 (presumably ``(batch, time)`` -- confirm against
            ``gpt.embed``).
        max_new_tokens: Number of sampling steps to run.
        neuron_idx: Index into the last axis of the encoded tensor; must
            satisfy ``0 <= neuron_idx < autoencoder.dim_rala``.
        activation: Value written into the selected unit.

    Returns:
        ``idx`` with ``max_new_tokens`` sampled token ids appended along dim 1.

    Raises:
        ValueError: If ``neuron_idx`` is outside the valid range.
    """
    if not 0 <= neuron_idx < autoencoder.dim_rala:
        raise ValueError(f"Invalid neuron index: {neuron_idx}")

    # Inference only: skip autograd bookkeeping for the whole sampling loop.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Only the most recent `context_length` tokens are fed to the model.
            idx_cond = idx[:, -gpt.context_length :]
            x = gpt.embed(idx_cond)
            encoded, _ = autoencoder(x)
            # Force the chosen unit to the requested value at every position.
            encoded[:, :, neuron_idx] = activation
            x = autoencoder.decode(encoded)
            logits = gpt.unembed(x)
            # Keep only the prediction for the final position.
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
    return idx

In [29]:
# Steered sample: same pipeline as the baseline, but with one autoencoder unit
# (index 1511) held at 1.5 during generation.
prompt_ids = tokenizer.encode("There people born in Israel are")

# `[None]` prepends a size-1 axis, matching `.unsqueeze(0)`.
idx = torch.tensor(prompt_ids, dtype=torch.long)[None].to(device)
out = biased_generate(idx, 100, 1511, 1.5)

generated_ids = out.squeeze().tolist()
print(tokenizer.decode(generated_ids))

There people born in Israel are silent and in fact , but cited Britto Target Entertainment . Holland is one of these religions of their country called " vicious " or either ’ s values " . 


 = = Identity manifestation of the Gawixandaccioh , according to writer Paulise Leigh , blopper won as " You Really Got Your Man " ( Dr that the goal never sort of a sonic imaginary , and excitement ) . Things and Matt Patfellow : " I 'm gonna worry about ever " .
