In [1]:
import numpy as np
import plotly.graph_objs as go
import torch
import pickle
from transformer_modules import MTTokenizer, MTStatePredictor, generate_causal_mask

In [2]:
# Pick the compute device once; later cells move the model and tensors onto it.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(
    "Cuda is available. Using GPU."
    if device == "cuda"
    else "Cuda not available. Using CPU."
)

Cuda is available. Using GPU.


In [3]:
def plot_mt(state):
    """Show a 2-D grid of token ids as stacked rings in a 3-D scatter plot.

    Every row of ``state`` becomes one ring of radius 1.0 in the x-y plane.
    The entries of a row are spread evenly around that ring, and successive
    rows are stacked 0.5 apart along the z axis. Points are coloured by token
    id: id 2 is drawn black; ids 0, 1 and 3 are drawn red.

    Args:
        state: 2-D torch tensor of integer token ids in 0..3 (this notebook
            passes tensors reshaped to [20, 13]). It may live on any device
            and may carry autograd history.

    Returns:
        None. The figure is displayed with ``fig.show()``.

    Raises:
        KeyError: if ``state`` contains an id that is not 0, 1, 2 or 3.
    """
    radius = 1.0
    layer_spacing = 0.5
    color_map = {0: 'red', 1: 'red', 2: 'black', 3: 'red'}

    # detach() lets tensors that still require grad be converted to numpy.
    token_ids = state.detach().cpu().numpy()

    xs, ys, zs, cs = [], [], [], []
    for layer, ring in enumerate(token_ids):
        # Derive the angular step from the actual row width rather than
        # assuming 13 points per ring, so other grid widths plot correctly.
        points_per_ring = len(ring)
        for i, tid in enumerate(ring):
            theta = 2 * np.pi * i / points_per_ring
            xs.append(radius * np.cos(theta))
            ys.append(radius * np.sin(theta))
            zs.append(layer * layer_spacing)
            cs.append(color_map[int(tid)])

    fig = go.Figure(data=[
        go.Scatter3d(
            x=xs, y=ys, z=zs,
            mode='markers',
            # Colours are explicit CSS names, so no colorscale/colorbar is
            # configured: plotly only applies those to numeric colour arrays.
            marker=dict(size=6, color=cs)
        )
    ])
    fig.update_layout(scene=dict(aspectmode='data'))
    fig.show()

In [4]:
# Choose which saved checkpoint to evaluate.
mamba = True
if mamba:
    path = "saved_models/FINAL_1layer_mamba_model.pt"
    print("Using mamba based model")
else:
    path = "saved_models/FINAL_24layer_transformer_model.pt"
    print("Using transformer based model")

# map_location remaps the checkpoint's tensors onto the selected device, so a
# checkpoint saved from a GPU session still loads when only the CPU is available.
# NOTE(review): the checkpoint stores a whole pickled module under 'MODEL';
# torch.load unpickles arbitrary objects, so only load trusted files. Newer
# PyTorch releases default to weights_only=True and may need weights_only=False
# here - confirm against the installed version.
model = torch.load(path, map_location=device)['MODEL'].to(device)

# NOTE(review): pickle.load can execute arbitrary code; this file is assumed to
# be a locally produced, trusted artifact.
data_path = "processed_split_sequences/test_sequences_1.pkl"
with open(data_path, 'rb') as f:
    data = pickle.load(f)

tokenizer = MTTokenizer()
vocab_size = len(tokenizer.vocab)

Using mamba based model


In [38]:
# Run one forward pass on the first test sample and keep the arg-max token ids.
sample = data[0]

model.eval()

with torch.no_grad():
    # Element 0 of the sample is handed to the model as metadata; the
    # remaining elements form the sequence that gets tokenized.
    meta_data = sample[0]
    seqs = sample[1:]

    meta_data = torch.tensor(meta_data, dtype=torch.float32).to(device)

    seqs = tokenizer.encode(seqs)
    seqs = seqs.to(device).unsqueeze(0)  # leading axis of size 1 (presumably the batch axis)

    # The model is fed every step except the last one.
    inputs = seqs[:, :-1]

    outputs = model(inputs, meta_data)
    print(outputs.shape)

    # Index of the largest value along the last axis of the model output.
    _, predicted_ids = outputs.max(dim=-1)
    predicted_ids = predicted_ids.squeeze(0)

    # Drop the leading axis and a trailing singleton axis (if present).
    seqs = seqs.squeeze(0).squeeze(-1)

torch.Size([1, 2379, 260, 4])


In [39]:
# Inspect the shape of the squeezed token tensor produced by the inference cell.
seqs.shape

torch.Size([2379, 260])

In [40]:
# Inspect the shape of the predicted token ids; it should match seqs.shape.
predicted_ids.shape

torch.Size([2379, 260])

In [41]:
# Repeat of the seqs shape inspection (same expression as the earlier check).
seqs.shape

torch.Size([2379, 260])

In [42]:
# Take one row (the last, by default) from the predictions and from the
# tokenized sequence, and fold each 260-long row into a 20 x 13 grid.
i = -1
last_pred = torch.reshape(predicted_ids[i], (20, 13))
last_state = torch.reshape(seqs[i], (20, 13))

In [43]:
# Display the selected row of the tokenized sequence as a 20 x 13 grid.
last_state

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 2, 2, 1, 0, 0, 0, 0, 0],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 3, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0],
        [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0],
        [3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')

In [44]:
# Display the model's predicted ids for the same row, for comparison with last_state.
last_pred

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 2, 2, 1, 0, 0, 0, 0, 0],
        [2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 2, 2, 2, 2, 2, 2, 1, 0, 0, 0],
        [0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 2, 1, 0, 0, 0, 0, 0, 0],
        [3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]], device='cuda:0')

In [45]:
# 3-D ring plot of the predicted grid.
plot_mt(last_pred)

In [46]:
# 3-D ring plot of the grid taken from the tokenized sequence.
plot_mt(last_state)