In [1]:
import os
import random
import json
import pickle	

import torch
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from transformers import GPT2Tokenizer, GPT2LMHeadModel, utils, AutoTokenizer, GPTJForCausalLM

from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

# Move the working directory to the project root (the parent of the kernel's
# starting directory) so the relative `./models/...` cache paths and the
# `src.*` imports used below resolve. The target is computed once and kept in
# PROJECT_ROOT, so re-running this cell does not climb one more level each time.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath(os.pardir)
os.chdir(PROJECT_ROOT)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def _torch_load_trusted(path, map_location=None):
    """Load an object previously written with ``torch.save`` by this notebook.

    The cached tokenizer/model files are whole pickled Python objects, not bare
    state dicts. Recent torch releases default ``torch.load(weights_only=True)``,
    which refuses to unpickle such objects. So ``weights_only=False`` is passed
    explicitly whenever the installed ``torch.load`` accepts that keyword;
    older releases lack it and always unpickle fully.

    SECURITY: unpickling can execute arbitrary code. Only use this on files this
    notebook wrote itself, never on downloaded or otherwise untrusted ``.pt`` files.

    Args:
        path: file to load.
        map_location: forwarded to ``torch.load`` to remap stored tensors
            (e.g. a CUDA-saved checkpoint onto ``"cpu"``).

    Returns:
        Whatever object was saved at ``path``.
    """
    import inspect

    load_kwargs = {"map_location": map_location}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)


def load_or_download_model(model_name="EleutherAI/gpt-j-6B", device="cpu"):
    """Return ``(tokenizer, model, embeddings)`` for a Hugging Face causal LM.

    Each artifact is cached with ``torch.save`` under ``./models/<model_name>/``
    (relative to the current working directory). Later calls load the cached
    file instead of downloading again.

    Args:
        model_name: Hugging Face Hub id, e.g. ``"gpt2"`` or
            ``"EleutherAI/gpt-j-6B"``. A ``/`` in the id yields nested
            cache directories.
        device: device the model and embeddings are moved to. Cached tensors
            are also mapped onto it when loading.

    Returns:
        tokenizer: tokenizer from ``AutoTokenizer`` (or its cached copy).
        GPTmodel: the model, in eval mode, on ``device``.
        embeddings: ``GPTmodel.transformer.wte.weight`` (the input
            token-embedding matrix of GPT-2/GPT-J-style models), detached from
            autograd, on ``device``. Architectures without ``transformer.wte``
            will raise ``AttributeError`` here.
    """
    model_dir = f"./models/{model_name}"
    # exist_ok=True already tolerates an existing directory; no exists() check needed.
    os.makedirs(model_dir, exist_ok=True)

    TOKENIZER_PATH = f"{model_dir}/tokenizer.pt"
    MODEL_PATH = f"{model_dir}/model.pt"
    EMBEDDINGS_PATH = f"{model_dir}/embeddings.pt"

    # Load or Download Tokenizer
    if os.path.exists(TOKENIZER_PATH):
        print(f'Loading {model_name} tokenizer from local storage...')
        tokenizer = _torch_load_trusted(TOKENIZER_PATH)
    else:
        print(f'Downloading {model_name} tokenizer...')
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        torch.save(tokenizer, TOKENIZER_PATH)

    # Load or Download Model
    if os.path.exists(MODEL_PATH):
        print(f'Loading {model_name} model from local storage...')
        # map_location lets a checkpoint pickled on a GPU machine load on CPU.
        GPTmodel = _torch_load_trusted(MODEL_PATH, map_location=device).to(device)
    else:
        print(f'Downloading {model_name} model...')
        # AutoModelForCausalLM picks the architecture from the checkpoint's
        # config (GPT-J -> GPTJForCausalLM, "gpt2" -> GPT2LMHeadModel, ...).
        # Hard-coding GPTJForCausalLM mis-loads non-GPT-J checkpoints.
        from transformers import AutoModelForCausalLM

        GPTmodel = AutoModelForCausalLM.from_pretrained(model_name).to(device)
        torch.save(GPTmodel, MODEL_PATH)

    GPTmodel.eval()

    # Save or Load Embeddings.
    # The model is used frozen (eval mode above). Detach so the embedding matrix
    # does not carry the nn.Parameter's requires_grad=True into downstream
    # probe training.
    if os.path.exists(EMBEDDINGS_PATH):
        print(f'Loading {model_name} embeddings from local storage...')
        embeddings = (
            _torch_load_trusted(EMBEDDINGS_PATH, map_location=device)
            .to(device)
            .detach()
        )
    else:
        embeddings = GPTmodel.transformer.wte.weight.detach().to(device)
        torch.save(embeddings, EMBEDDINGS_PATH)
        print(f"The {model_name} 'embeddings' tensor has been saved.")

    return tokenizer, GPTmodel, embeddings

# Load GPT-2 onto the CPU; any copies already cached under ./models/gpt2 are reused.
tokenizer, GPTmodel, embeddings = load_or_download_model(
    model_name="gpt2",
    device="cpu",
)


Loading gpt2 tokenizer from local storage...
Loading gpt2 model from local storage...
Loading gpt2 embeddings from local storage...


In [5]:
from src.letter_token_utils import (
    get_token_strings,
    get_all_rom_tokens,
    get_distinct_letters_dict
)

# `token_strings` is indexed by token id (see the comprehension below).
token_strings = get_token_strings(tokenizer)
_, all_rom_token_indices = get_all_rom_tokens(token_strings)
letter_presence_dict = get_distinct_letters_dict(all_rom_token_indices, token_strings)

# All-Roman token indices whose text, ignoring leading whitespace, has more
# than two characters.
all_rom_token_gt2_indices = [
    token_index
    for token_index in all_rom_token_indices
    if len(token_strings[token_index].lstrip()) > 2
]


There are 50257 tokens.
There are 46893 all-Roman tokens.
There are 1 all-Roman tokens with 14 distinct letters


In [6]:
from src.train_letter_presence_probes import train_letter_presence_probes

# Probe training presumably draws on random initialisation/shuffling (not
# visible here). Seed every RNG this notebook imports so the learned probe
# weights are reproducible under Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Train the letter-presence probes on the embedding rows of the selected
# all-Roman (>2-character) tokens.
probe_weights_tensor = train_letter_presence_probes(
    embeddings,
    all_rom_token_gt2_indices,
    token_strings,
)



_________________________________________________

A: epoch 1/100, Loss: 0.6692088716030121
A: epoch 2/100, Loss: 0.6250839204788208
A: epoch 3/100, Loss: 0.5937851204872131
A: epoch 4/100, Loss: 0.5708102705478668
A: epoch 5/100, Loss: 0.5530942339897156
A: epoch 6/100, Loss: 0.539606367468834
A: epoch 7/100, Loss: 0.5281799081563949
A: epoch 8/100, Loss: 0.5189728199243545
A: epoch 9/100, Loss: 0.5111996406316757
A: epoch 10/100, Loss: 0.5047116018533707
A: epoch 11/100, Loss: 0.49903610146045685
A: epoch 12/100, Loss: 0.49412901484966276
A: epoch 13/100, Loss: 0.4903116943836212
A: epoch 14/100, Loss: 0.48618496334552763
A: epoch 15/100, Loss: 0.48261413991451263
A: epoch 16/100, Loss: 0.47991536390781403
A: epoch 17/100, Loss: 0.47741268944740295
A: epoch 18/100, Loss: 0.4750845518112183
A: epoch 19/100, Loss: 0.47277259922027587
A: epoch 20/100, Loss: 0.47100847434997556
A: epoch 21/100, Loss: 0.46891233706474306
A: epoch 22/100, Loss: 0.46760645985603333
A: epoch 23/100, Loss: 0