<a href="https://colab.research.google.com/github/keenanpepper/epsilon-rnns/blob/main/Train_RNN_on_RRXOR.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Clone only when the checkout is missing, so re-running this cell does not error out.
![ -d epsilon-transformers ] || git clone https://github.com/keenanpepper/epsilon-transformers.git

Cloning into 'epsilon-transformers'...
remote: Enumerating objects: 4947, done.[K
remote: Counting objects: 100% (757/757), done.[K
remote: Compressing objects: 100% (360/360), done.[K
remote: Total 4947 (delta 421), reused 575 (delta 388), pack-reused 4190[K
Receiving objects: 100% (4947/4947), 206.68 MiB | 22.86 MiB/s, done.
Resolving deltas: 100% (2779/2779), done.


In [2]:
# Explicit line magic (does not rely on automagic being enabled).
%cd epsilon-transformers

/content/epsilon-transformers


In [3]:
# Refresh remote-tracking branches from the cloned remote.
!git fetch origin

In [4]:
# NOTE(review): a branch name is a moving target; pin a commit hash here for reproducibility.
!git switch hackathon-prep

Branch 'hackathon-prep' set up to track remote branch 'hackathon-prep' from 'origin'.
Switched to a new branch 'hackathon-prep'


In [5]:
# %pip (rather than !pip) installs into the environment of the running kernel.
%pip install -e .

Obtaining file:///content/epsilon-transformers
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting wandb (from epsilon_transformers==0.1)
  Downloading wandb-0.17.0-py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.7/6.7 MB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
Collecting transformer-lens (from epsilon_transformers==0.1)
  Downloading transformer_lens-2.0.0-py3-none-any.whl (144 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m145.0/145.0 kB[0m [31m19.9 MB/s[0m eta [36m0:00:00[0m
Collecting black (from epsilon_transformers==0.1)
  Downloading black-24.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)
[2K 

In [1]:
import pathlib
import torch
import torch as t

from epsilon_transformers.training.configs.training_configs import TrainConfig, OptimizerConfig, ProcessDatasetConfig, PersistanceConfig, LoggingConfig
from epsilon_transformers.training.configs.model_configs import RawModelConfig
from epsilon_transformers.training.train import train_model
from epsilon_transformers.analysis.activation_analysis import get_beliefs_for_transformer_inputs

import numpy as np

from sklearn.linear_model import LinearRegression
from sklearn.decomposition import PCA

import plotly.express as px
from plotly.subplots import make_subplots
from plotly import graph_objects as go

In [2]:
import os

# Never hardcode credentials in a notebook: they leak through git history and shared copies.
# Provide WANDB_API_KEY via the environment, a .env file, or a Colab secret named WANDB_API_KEY.
if "WANDB_API_KEY" not in os.environ:
    try:
        from google.colab import userdata  # only importable inside Colab
    except ImportError:
        pass  # outside Colab: rely on the environment / .env file instead
    else:
        os.environ["WANDB_API_KEY"] = userdata.get("WANDB_API_KEY")

In [3]:
from epsilon_transformers.training.configs.base_config import Config

class RNNConfig(Config):
    # Hyperparameters for RNNPredictor: the torch.nn.RNN body and its linear readout head.
    input_size: int  # feature width fed to torch.nn.RNN; tokens are one-hot encoded, so this should equal the vocabulary size (2 here)
    hidden_size: int  # hidden-state width of every RNN layer (also the input width of the linear head)
    output_size: int  # number of logits produced by the linear head
    num_layers: int  # number of stacked RNN layers
    nonlinearity: str  # passed straight through to torch.nn.RNN(nonlinearity=...); this notebook uses "relu"

In [4]:
class RNNPredictor(t.nn.Module):
    """Next-token predictor: a stacked torch.nn.RNN over one-hot tokens plus a linear readout.

    Tokens are one-hot encoded to ``config.input_size`` features; the linear head produces
    ``config.output_size`` logits per position.
    """

    config: "RNNConfig"
    rnn: t.nn.RNN
    linear: t.nn.Linear

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rnn = t.nn.RNN(
            config.input_size,
            config.hidden_size,
            config.num_layers,
            nonlinearity=config.nonlinearity,
            batch_first=True,
        )
        self.linear = t.nn.Linear(config.hidden_size, config.output_size)

    def forward(self, x, hidden, target_data, collect_hidden_states=False):
        """Run the RNN over a batch of token sequences.

        Args:
            x: integer token ids of shape (batch, seq) with values in [0, config.input_size).
            hidden: initial hidden state of shape (num_layers, batch, hidden_size), or None
                for an all-zeros initial state.
            target_data: next-token ids of shape (batch, seq); only used for the loss and
                ignored when ``collect_hidden_states`` is True.
            collect_hidden_states: if True, return the hidden state of every layer at every
                position instead of computing a loss.

        Returns:
            If ``collect_hidden_states``: ``(states, hidden)`` where ``states`` has shape
            (seq, num_layers, batch, hidden_size) and ``hidden`` is the final hidden state.
            Otherwise: ``(output, hidden, loss)`` where ``output`` is the last layer's output
            of shape (batch, seq, hidden_size) and ``loss`` is the cross-entropy of the logits.
        """
        x = t.nn.functional.one_hot(x, self.config.input_size).float()

        if collect_hidden_states:
            # nn.RNN only returns the *last* layer's output per position, so feed one token at
            # a time and record the full (layer, batch, dim) hidden state after each step.
            # The caller's initial `hidden` is honoured (None means zeros).
            all_hiddens = []
            for step_input in t.split(x, 1, dim=1):
                _, hidden = self.rnn(step_input, hidden)
                all_hiddens.append(hidden)
            return t.stack(all_hiddens, dim=0), hidden

        output, hidden = self.rnn(x, hidden)
        logits = self.linear(output)
        loss = t.nn.functional.cross_entropy(
            logits.reshape(-1, self.config.output_size), target_data.reshape(-1)
        )
        return output, hidden, loss

In [5]:
from epsilon_transformers.training.configs.training_configs import Log

import dotenv
import wandb
import os

class RNNTrainConfig(Config):
    # Full configuration for one RNN training run: model, optimizer, data, checkpointing,
    # logging, plus the sequence length, RNG seed and verbosity used by train_model.
    rnnConfig: RNNConfig
    optimizer: OptimizerConfig
    dataset: ProcessDatasetConfig
    persistance: PersistanceConfig
    logging: LoggingConfig
    sequence_length: int
    seed: int
    verbose: bool

    def init_logger(self) -> Log:
        """Prepare the configured logging backends and return ``self.logging.init()``.

        With wandb enabled: loads a ``.env`` file, logs in using ``WANDB_API_KEY`` and starts
        a run whose config is this object's ``model_dump()``.

        Raises:
            ValueError: wandb is enabled but ``WANDB_API_KEY`` is not set.
            NotImplementedError: local logging was requested (``self.logging.local`` set).
        """
        if self.logging.wandb:
            dotenv.load_dotenv()
            api_key = os.environ.get("WANDB_API_KEY")
            if api_key is None:
                raise ValueError(
                    "To use wandb, set your API key as the environment variable `WANDB_API_KEY`"
                )
            wandb.login(key=api_key)
            wandb.init(config=self.model_dump(), project=self.logging.project_name)

        if self.logging.local is not None:
            raise NotImplementedError()

        return self.logging.init()

In [6]:
from epsilon_transformers.persistence import Persister

import random
from tqdm import tqdm
from torch.utils.data import DataLoader

def _set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

def _calculate_tokens_trained(
    batch_size: int,
    sequence_len: int,
    batch_idx: int,
) -> int:
    tokens_per_batch = batch_size * sequence_len
    total_tokens_trained = (batch_idx + 1) * tokens_per_batch
    return total_tokens_trained

def _check_if_action_batch(
    perform_action_every_n_tokens: int,
    batch_size: int,
    sequence_len: int,
    batch_idx: int,
) -> bool:
    tokens_per_batch = batch_size * sequence_len
    assert (
        perform_action_every_n_tokens >= tokens_per_batch
    ), "perform_action_every_n_tokens must be greater than or equal to tokens_per_batch"
    perform_action_every_n_batches = perform_action_every_n_tokens // tokens_per_batch
    return (batch_idx + 1) % perform_action_every_n_batches == 0

def _evaluate_model(
    model: RNNPredictor,
    eval_dataloader: DataLoader,
    device: torch.device,
    log: Log
) -> Log:
    """Record the test loss of every eval batch into ``log`` (without gradients); return ``log``."""
    n_layers = model.config.num_layers
    width = model.config.hidden_size
    with torch.no_grad():
        for inputs, targets in tqdm(eval_dataloader, desc="Eval Loop"):
            # Zero initial state sized from the batch actually received.
            h0 = t.zeros((n_layers, inputs.shape[0], width), device=device)
            _, _, batch_loss = model(inputs.to(device), h0, targets.to(device))
            log.update_metrics(train_or_test="test", loss=batch_loss.item())
    return log

def _evaluate_log_and_persist(
    dataset_config: ProcessDatasetConfig,
    persister: Persister,
    model: RNNPredictor,
    verbose: bool,
    log: Log,
    device: torch.device,
    tokens_trained: int,
    sequence_length: int
):
    """Evaluate on the held-out split, flush the log, then checkpoint the model.

    Builds a fresh test dataloader, records test losses into ``log``, optionally prints it,
    persists and resets ``log``, and saves ``model`` tagged with ``tokens_trained``.
    Returns the (now reset) ``log``.
    """
    _evaluate_model(
        model=model,
        eval_dataloader=dataset_config.to_dataloader(sequence_length=sequence_length, train=False),
        device=device,
        log=log,
    )

    if verbose:
        print(f"This is the log\n{log}")

    log.persist()
    log.reset()
    persister.save_model(model, tokens_trained)
    return log

def train_model(config: RNNTrainConfig) -> tuple[RNNPredictor, Log]:
    """Train an RNNPredictor on ``config.dataset`` with periodic evaluation and checkpointing.

    Every ``config.persistance.checkpoint_every_n_tokens`` tokens (rounded down to whole
    batches) the model is evaluated on the test split, the log is persisted and a checkpoint is
    saved; the same happens once more after the last batch.

    Returns:
        ``(model, log)``: the trained model (left in eval mode) and the Log used for training.
    """
    device = torch.device(
        "mps"
        if torch.backends.mps.is_available()
        else ("cuda" if torch.cuda.is_available() else "cpu")
    )

    _set_random_seed(config.seed)

    model = RNNPredictor(config.rnnConfig).to(device)
    optimizer = config.optimizer.from_model(model=model, device=device)
    train_dataloader = config.dataset.to_dataloader(
        sequence_length=config.sequence_length, train=True
    )

    persister = config.persistance.init()
    log = config.init_logger()

    def eval_and_checkpoint(tokens_trained: int) -> None:
        # Switch to eval mode, evaluate, persist/reset the log and save a checkpoint.
        model.eval()
        _evaluate_log_and_persist(
            dataset_config=config.dataset,
            persister=persister,
            model=model,
            log=log,
            verbose=config.verbose,
            device=device,
            tokens_trained=tokens_trained,
            sequence_length=config.sequence_length,
        )

    # Defined up front so the final checkpoint still works if the dataloader yields no batches.
    tokens_trained_so_far = 0
    model.train()
    for batch_idx, (input_data, target_data) in enumerate(tqdm(train_dataloader, desc="Train Loop")):
        input_data, target_data = input_data.to(device), target_data.to(device)
        # Size the zero initial state from the actual batch (as the eval loop does) so a short
        # final batch does not cause a hidden-state shape mismatch.
        hidden = t.zeros(
            (config.rnnConfig.num_layers, input_data.shape[0], config.rnnConfig.hidden_size),
            device=device,
        )
        _, _, loss = model(input_data, hidden, target_data)
        log.update_metrics(train_or_test="train", loss=loss.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tokens_trained_so_far = _calculate_tokens_trained(
            batch_size=config.dataset.batch_size,
            sequence_len=config.sequence_length,
            batch_idx=batch_idx,
        )

        if _check_if_action_batch(
            perform_action_every_n_tokens=config.persistance.checkpoint_every_n_tokens,
            batch_size=config.dataset.batch_size,
            batch_idx=batch_idx,
            sequence_len=config.sequence_length,
        ):
            eval_and_checkpoint(tokens_trained_so_far)
            log.reset()
            model.train()

    eval_and_checkpoint(tokens_trained_so_far)

    config.logging.close()
    return model, log

In [None]:
# Model: 3-layer ReLU RNN over a binary vocabulary.
model_config = RNNConfig(
    input_size=2,
    hidden_size=32,
    output_size=2,
    num_layers=3,
    nonlinearity="relu",
)

optimizer_config = OptimizerConfig(optimizer_type='sgd', learning_rate=0.02, weight_decay=0)

dataset_config = ProcessDatasetConfig(
    process='rrxor',
    batch_size=512,
    num_tokens=400_000_000,
    test_split=0.0001,
)

persistance_config = PersistanceConfig(
    location='local',
    collection_location=pathlib.Path('/content/epsilon-transformers/rnn-rrxor-test'),
    checkpoint_every_n_tokens=500_000,
)

train_config = RNNTrainConfig(
    rnnConfig=model_config,
    optimizer=optimizer_config,
    dataset=dataset_config,
    persistance=persistance_config,
    logging=LoggingConfig(project_name="rnn-rrxor-test", wandb=True),
    sequence_length=6,
    seed=42,
    verbose=True,
)

train_model(train_config)

In [8]:
# Same preference order as train_model; falls back to CPU instead of assuming CUDA exists.
device = t.device(
    "mps" if t.backends.mps.is_available() else ("cuda" if t.cuda.is_available() else "cpu")
)

In [9]:
# Fresh (untrained) model with the training architecture; weights are loaded next.
model = RNNPredictor(config=model_config)

In [10]:
# Checkpoint written during training into the persistence directory; the file name is
# presumably the tokens-trained count passed to persister.save_model — confirm.
CHECKPOINT_PATH = pathlib.Path('/content/epsilon-transformers/rnn-rrxor-test/986867712.pt')
# weights_only=True limits unpickling to tensors/plain containers (all a state_dict needs),
# so loading a checkpoint cannot execute arbitrary pickled code.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device, weights_only=True))

<All keys matched successfully>

In [11]:
# nn.Module.to moves parameters in place and returns the same module.
model = model.to(device=device)

In [12]:
from epsilon_transformers.process.processes import RRXOR

In [13]:
# The RRXOR generative process whose belief geometry we probe for.
process = RRXOR()
print(str(process))

rrxor Process
Number of states: 5
Vocabulary length: 2
Transition matrix shape: (2, 5, 5)


In [14]:
MSP_DEPTH = 11  # tree depth passed to derive_mixed_state_presentation
mixed_state_tree = process.derive_mixed_state_presentation(depth=MSP_DEPTH)
MSP_transition_matrix = mixed_state_tree.build_msp_transition_matrix()

In [15]:
# To place belief states in the simplex we need every MSP tree path and its belief state.
paths_and_beliefs = mixed_state_tree.paths_and_belief_states
tree_paths, tree_beliefs = paths_and_beliefs

In [16]:
# The MSP states are the distinct beliefs in the tree; round to 5 decimals so numerically
# equal beliefs compare (and hash) equal.
msp_beliefs = [tuple(map(lambda p: round(p, 5), belief)) for belief in tree_beliefs]
n_unique = len(set(msp_beliefs))
print(f"Number of Unique beliefs: {n_unique} out of {len(msp_beliefs)}")

Number of Unique beliefs: 36 out of 1723


In [17]:
# Give each distinct belief an integer index (used below as its label and colour).
msp_belief_index = {belief: idx for idx, belief in enumerate(set(msp_beliefs))}

belief_keys = list(msp_belief_index)
for i in range(5):
    ith_belief = belief_keys[i]
    print(f"{ith_belief} is indexed as {msp_belief_index[ith_belief]}")

(0.5, 0.25, 0.0, 0.0, 0.25) is indexed as 0
(0.33333, 0.0, 0.33333, 0.16667, 0.16667) is indexed as 1
(0.0, 0.0, 1.0, 0.0, 0.0) is indexed as 2
(0.0, 0.0, 0.5, 0.5, 0.0) is indexed as 3
(0.0, 0.66667, 0.0, 0.0, 0.33333) is indexed as 4


In [18]:
def run_visualization_pca(beliefs):
    """Fit and return a 3-component PCA on ``beliefs`` (used to project beliefs for plots)."""
    return PCA(n_components=3).fit(beliefs)

def visualize_ground_truth_simplex_3d(beliefs, belief_labels, pca):
    """Show a 3D scatter of ``beliefs`` projected by the fitted ``pca``, coloured by label."""
    palette = (
        px.colors.qualitative.Light24
        + px.colors.qualitative.Dark24
        + px.colors.qualitative.Plotly
    )
    projected = pca.transform(beliefs)
    fig = px.scatter_3d(
        projected,
        x=0,
        y=1,
        z=2,
        color=list(map(str, belief_labels)),
        color_discrete_sequence=palette,
    )
    fig.update_layout(width=400, height=400)
    fig.update_traces(marker={'size': 1})
    fig.show()

# Fit the visualization PCA on the distinct ground-truth beliefs and plot them, labelled by index.
ground_truth_beliefs = list(msp_belief_index.keys())
ground_truth_labels = list(msp_belief_index.values())
vis_pca = run_visualization_pca(ground_truth_beliefs)
visualize_ground_truth_simplex_3d(ground_truth_beliefs, ground_truth_labels, vis_pca)

In [132]:
SEQUENCE_LENGTH: int = 6  # should match the sequence_length the model was trained with

In [133]:
# Keep only the tree paths of exactly SEQUENCE_LENGTH tokens; these become the model inputs.
input_paths = [path for path in tree_paths if len(path) == SEQUENCE_LENGTH]
model_inputs = torch.tensor(input_paths, dtype=torch.long, device=device)

# print the first few input sequences
print(model_inputs[:5])

tensor([[1, 0, 1, 1, 0, 0],
        [1, 0, 1, 1, 0, 1],
        [1, 1, 1, 1, 0, 1],
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 0, 1, 0]], device='cuda:0')


In [134]:
# Ground-truth belief (and its MSP index) at every position of every input sequence.
model_input_beliefs, model_input_belief_indices = get_beliefs_for_transformer_inputs(
    model_inputs, msp_belief_index, tree_paths, tree_beliefs
)
beliefs_shape, indices_shape = model_input_beliefs.shape, model_input_belief_indices.shape
print(f"Model Input Beliefs: {beliefs_shape}, Model Input Belief Indices: {indices_shape}")

Model Input Beliefs: torch.Size([52, 6, 5]), Model Input Belief Indices: torch.Size([52, 6])


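We need the hidden state of every layer at every sequence position. `torch.nn.RNN` only returns the last layer's outputs per position, so instead of ~~TorchLens~~ we use the for-loop path in `RNNPredictor.forward` (`collect_hidden_states=True`). That path feeds one token at a time and records the full hidden state after each step.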

In [148]:
# Analysis-only pass: no_grad avoids building an autograd graph we never backprop through.
# target_data is ignored when collect_hidden_states=True, so a dummy tensor is passed.
with t.no_grad():
    hidden_states, _ = model(
        model_inputs, None, t.zeros_like(model_inputs, device=device), collect_hidden_states=True
    )

In [149]:
# (seq, layer, batch, hidden)
hidden_states.size()

torch.Size([6, 3, 52, 32])

In [150]:
# One slice per layer, keeping a length-1 layer axis so the same rearrange pattern applies.
hidden_states0, hidden_states1, hidden_states2 = (
    hidden_states[:, layer:layer + 1] for layer in range(3)
)

In [151]:
from einops import rearrange

In [152]:
# Move batch first and concatenate the layer axis into the feature axis.
LAYER_CONCAT_PATTERN = "seq layer batch i -> batch seq (layer i)"
hidden_states_reshaped = rearrange(hidden_states, LAYER_CONCAT_PATTERN)
hidden_states_reshaped0, hidden_states_reshaped1, hidden_states_reshaped2 = (
    rearrange(states, LAYER_CONCAT_PATTERN)
    for states in (hidden_states0, hidden_states1, hidden_states2)
)

In [153]:
# This works as-is with the understanding that "activations" are the RNN hidden states.

def run_activation_to_beliefs_regression(activations, ground_truth_beliefs, regression=None):
    """Linearly regress ground-truth belief states onto activations.

    Args:
        activations: array or tensor of shape (batch, n_ctx, d_model).
        ground_truth_beliefs: array or tensor of shape (batch, n_ctx, belief_dim).
        regression: optional unfitted estimator exposing ``fit``/``predict``
            (e.g. an sklearn ``Ridge``); defaults to ``LinearRegression()``.

    Returns:
        ``(regression, belief_predictions)``: the fitted estimator and its in-sample
        predictions reshaped to (batch, n_ctx, belief_dim).
    """
    # batch and context dimensions must line up
    assert activations.shape[0] == ground_truth_beliefs.shape[0]
    assert activations.shape[1] == ground_truth_beliefs.shape[1]

    batch_size, n_ctx, d_model = activations.shape
    belief_dim = ground_truth_beliefs.shape[-1]
    # reshape (not view): .view fails on non-contiguous tensors, and ndarray.view takes a
    # dtype rather than a shape, so numpy inputs would break.
    activations_flattened = activations.reshape(-1, d_model)  # [batch * n_ctx, d_model]
    ground_truth_beliefs_flattened = ground_truth_beliefs.reshape(-1, belief_dim)  # [batch * n_ctx, belief_dim]

    if regression is None:
        regression = LinearRegression()
    regression.fit(activations_flattened, ground_truth_beliefs_flattened)

    belief_predictions = regression.predict(activations_flattened)  # [batch * n_ctx, belief_dim]
    belief_predictions = belief_predictions.reshape(batch_size, n_ctx, belief_dim)

    return regression, belief_predictions

In [154]:
beliefs_cpu = model_input_beliefs.cpu()


def _regress(states):
    # Fit beliefs from detached CPU copies of the given hidden states.
    return run_activation_to_beliefs_regression(states.detach().cpu(), beliefs_cpu)


regressions, belief_predictions = _regress(hidden_states_reshaped)     # all layers
regressions0, belief_predictions0 = _regress(hidden_states_reshaped0)  # layer 0 only
regressions1, belief_predictions1 = _regress(hidden_states_reshaped1)  # layer 1 only
regressions2, belief_predictions2 = _regress(hidden_states_reshaped2)  # layer 2 only

In [182]:
def _add_belief_traces(fig, projected, row, col, belief_idx, point_indices, color):
    """Add one belief's center-of-mass marker and prediction cloud to subplot (row, col).

    Plots PCA component 0 on x and component 2 on y; does nothing if no point has this belief.
    """
    relevant_data = projected[point_indices]
    if len(relevant_data) == 0:
        return
    center_of_mass = np.mean(relevant_data, axis=0)
    fig.add_trace(go.Scatter(x=[center_of_mass[0]], y=[center_of_mass[2]],
                             mode='markers',
                             marker=dict(size=10, color=color, opacity=1),
                             name=f'Belief {belief_idx} Center of Mass'),
                  row=row, col=col)
    fig.add_trace(go.Scatter(x=relevant_data[:, 0], y=relevant_data[:, 2],
                             mode='markers',
                             marker=dict(size=4, color=color, opacity=.2),
                             name=f'Belief {belief_idx}'),
                  row=row, col=col)


# Project every set of regression predictions into the ground-truth PCA basis.
belief_dim = belief_predictions.shape[-1]  # number of process states (5 for RRXOR)
belief_predictions_pca = vis_pca.transform(belief_predictions.reshape(-1, belief_dim))
belief_predictions_pca0 = vis_pca.transform(belief_predictions0.reshape(-1, belief_dim))
belief_predictions_pca1 = vis_pca.transform(belief_predictions1.reshape(-1, belief_dim))
belief_predictions_pca2 = vis_pca.transform(belief_predictions2.reshape(-1, belief_dim))
model_input_belief_indices_flattened = model_input_belief_indices.reshape(-1).cpu().numpy()

beliefs_2d = vis_pca.transform(list(msp_belief_index.keys()))
colors = px.colors.qualitative.Light24 + px.colors.qualitative.Dark24 + px.colors.qualitative.Plotly

# 3x3 grid: ground truth middle-left, single-layer predictions down the middle column,
# all-layer predictions middle-right; the remaining panels stay empty.
fig = make_subplots(rows=3, cols=3,
                    specs=[[{'type': 'scatter'} for _ in range(3)] for _ in range(3)])

# Ground-truth beliefs
fig.add_trace(go.Scatter(x=beliefs_2d[:, 0], y=beliefs_2d[:, 2],
                         mode='markers',
                         marker=dict(size=10, color=[colors[i] for i in list(msp_belief_index.values())], opacity=1),
                         name='Beliefs'),
              row=2, col=1)

# (projected predictions, row, col); traces are added in this order for each belief.
prediction_panels = [
    (belief_predictions_pca2, 1, 2),  # layer 2 only
    (belief_predictions_pca1, 2, 2),  # layer 1 only
    (belief_predictions_pca0, 3, 2),  # layer 0 only
    (belief_predictions_pca, 2, 3),   # all layers
]
for b in msp_belief_index.values():
    relevant_indices = np.where(model_input_belief_indices_flattened == b)[0]
    for projected, row, col in prediction_panels:
        _add_belief_traces(fig, projected, row, col, b, relevant_indices, colors[b])

# fix x and y limits to [-0.85, 0.85] on the middle-row panels
for col in (1, 2, 3):
    fig.update_xaxes(range=[-.85, .85], row=2, col=col)
    fig.update_yaxes(range=[-.85, .85], row=2, col=col)

# Panels plot PCA component 0 (x) against component 2 (y); titles use 1-based numbering.
fig.update_layout(title='2D PCA Projection of Beliefs', title_x=0.45,
                  xaxis_title='PCA Dimension 1', yaxis_title='PCA Dimension 3',
                  width=1100, height=1000,
                  annotations=[
                      dict(text="Ground Truth", x=0.1, y=0.65, showarrow=False, xref="paper", yref="paper"),
                      dict(text="Layer 2 Only", x=0.5, y=1.03, showarrow=False, xref="paper", yref="paper"),
                      dict(text="Layer 1 Only", x=0.5, y=0.65, showarrow=False, xref="paper", yref="paper"),
                      dict(text="Layer 0 Only", x=0.5, y=0.27, showarrow=False, xref="paper", yref="paper"),
                      dict(text="All Layers", x=0.9, y=0.65, showarrow=False, xref="paper", yref="paper"),
                  ])

In [None]:
# Render the PCA comparison figure built above (default renderer).
fig.show(renderer=None)

In [None]:
# Duplicate fig.show() removed: the previous cell already renders this figure.