In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import itertools
from typing import Sequence
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
import functools
from datetime import datetime

from torch.utils.tensorboard import SummaryWriter

import torchvision

import os
import sys
# When the absolute working directory contains "notebooks", add the parent
# directory to the import path — presumably so the project-local `traces`
# package imported below resolves when the notebook is run from that folder.
if "notebooks" in os.path.abspath('.'):
    sys.path.append('../')
from traces import mtbench_mixtral_utils

In [None]:
# Load every available trace. The result is used as a mapping further down
# (`.values()` is called on it).
traces = mtbench_mixtral_utils.load_all()

In [None]:
# Take the first loaded trace as a sample and display its `num_tokens`.
trace = list(traces.values())[0]
trace.num_tokens

In [None]:
# Expert selections of the sample trace; reused by the scratch cells near the
# end of the notebook (where it is cast with `astype`, so it is array-like).
x = trace.experts_per_token()

In [None]:
@functools.lru_cache(maxsize=None)
def _expert_pair_ranks(total_num_experts: int) -> dict:
    """Map each sorted expert pair to its rank in `itertools.combinations` order."""
    pairs = itertools.combinations(range(total_num_experts), 2)
    return {pair: rank for rank, pair in enumerate(pairs)}


def encode_selected_experts(selected_experts: Sequence[int], total_num_experts: int) -> int:
    """Encode an unordered pair of expert ids as a single class index.

    The index is the position of the sorted pair in
    `itertools.combinations(range(total_num_experts), 2)`, so it lies in
    `[0, total_num_experts * (total_num_experts - 1) / 2)`.

    Args:
        selected_experts: The two selected expert ids, in any order.
        total_num_experts: Total number of experts to choose from.

    Returns:
        The class index of the pair.

    Raises:
        ValueError: If `selected_experts` is not a pair of two distinct ids in
            `range(total_num_experts)`.
    """
    combination = tuple(sorted(selected_experts))
    # The table is built once per `total_num_experts` and cached; this function
    # is called once per (token, layer) entry, so rebuilding and linearly
    # searching the combination list on every call is wasted work.
    try:
        return _expert_pair_ranks(total_num_experts)[combination]
    except KeyError:
        raise ValueError(
            '{} is not a pair of distinct experts out of {}'.format(
                combination, total_num_experts)) from None


def create_dataset(
    traces: Sequence[mtbench_mixtral_utils.QueryTrace],
    num_past_experts: int,
    num_experts_to_predict: int,
    step_size: int,
    offset: int = 0,
    device: "str | torch.device" = 'mps',
) -> TensorDataset:
    """Build (history window, future window) samples from expert-selection traces.

    Each trace's expert selections are encoded to one class index per entry
    (see `encode_selected_experts`), flattened to a 1-D sequence and one-hot
    encoded. A sliding window then emits, for every start position `i`, the
    `num_past_experts` entries before `i` as the input and the
    `num_experts_to_predict` entries starting at `i` as the target.

    Args:
        traces: Traces to draw samples from; each must provide
            `experts_per_token()` and `num_experts`.
        num_past_experts: Length of the input (history) window.
        num_experts_to_predict: Length of the target window.
        step_size: Stride between consecutive window start positions.
        offset: Extra shift applied to the first window start position.
            Expected to be non-negative.
        device: Device the resulting tensors are moved to. Defaults to 'mps'.

    Returns:
        A `TensorDataset` of inputs shaped
        `(num_samples, num_past_experts, num_combinations)` and targets shaped
        `(num_samples, num_experts_to_predict, num_combinations)`, both float32.
    """
    X = []  # Each X is a window of the previously selected experts (one-hot).
    Y = []  # Each Y is the window of experts that followed it (one-hot).
    for trace in traces:
        num_experts = int(trace.num_experts)
        experts_per_token = trace.experts_per_token()
        # Collapse the last axis (the selected expert ids) into one class index.
        experts_encoded = np.apply_along_axis(
            lambda x: encode_selected_experts(x, num_experts), -1,
            experts_per_token)
        experts_encoded = experts_encoded.reshape((-1,))
        # Number of unordered expert pairs, C(num_experts, 2).
        num_combinations = num_experts * (num_experts - 1) // 2
        # `one_hot` requires an int64 index tensor regardless of the platform's
        # default NumPy integer width.
        experts_torch = torch.from_numpy(experts_encoded).to(torch.int64)
        experts_one_hot = nn.functional.one_hot(experts_torch, num_combinations).to(torch.float32)

        # The last start position whose target window still fits is
        # `len - num_experts_to_predict`; the `+ 1` keeps it inside the range
        # instead of silently dropping the final window of every trace.
        last_start = experts_one_hot.shape[0] - num_experts_to_predict
        for i in range(num_past_experts + offset, last_start + 1, step_size):
            X.append(experts_one_hot[i - num_past_experts : i])
            Y.append(experts_one_hot[i : i + num_experts_to_predict])
    X = torch.stack(X)
    Y = torch.stack(Y)
    return TensorDataset(X.to(device), Y.to(device))

In [None]:
# One sample per stride of 32 entries: the previous 32 encoded selections are
# the input and the following 32 are the target.
dataset = create_dataset(
    traces.values(),
    num_past_experts=32,
    num_experts_to_predict=32,
    step_size=32,
)

# Seeded generator so the 80/10/10 split is reproducible across runs.
generator = torch.Generator().manual_seed(42)
training_set, validation_set, test_set = torch.utils.data.random_split(
    dataset, lengths=[0.8, 0.1, 0.1], generator=generator)

BATCH_SIZE = 32
# Only the training loader shuffles; validation and test keep a fixed order.
training_loader = DataLoader(training_set, shuffle=True, batch_size=BATCH_SIZE)
validation_loader = DataLoader(validation_set, batch_size=BATCH_SIZE)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE)

In [None]:
class MlpPredictor(nn.Module):
    """MLP mapping a flattened input window to per-position class probabilities.

    The input is flattened per sample, passed through a `torchvision.ops.MLP`,
    reshaped to `output_shape` and normalised with a softmax over the last axis.

    Args:
        input_shape: Shape of one input sample (without the batch dimension).
        output_shape: Shape of one output sample (without the batch dimension).
        hidden_dims: Widths of the hidden layers, in order.
    """

    def __init__(self, input_shape: tuple[int], output_shape: tuple[int], hidden_dims: Sequence[int]):
        super().__init__()
        self.output_shape = output_shape
        flat_in = np.prod(input_shape)
        flat_out = np.prod(output_shape)
        # Hidden widths followed by the flattened output width.
        layer_widths = (*hidden_dims, flat_out)
        self.mlp = torchvision.ops.MLP(flat_in, layer_widths)
        print(flat_in, hidden_dims, flat_out)

    def forward(self, x):
        """Return probabilities shaped `(batch, *output_shape)`; the last axis sums to 1."""
        num_samples = x.shape[0]
        flat = x.reshape((num_samples, -1))
        scores = self.mlp(flat).reshape((num_samples, *self.output_shape))
        return torch.softmax(scores, dim=-1)
        

model = MlpPredictor((32, 28), (32, 28), [32]).to('mps')


def loss_fn(predicted_probs: torch.Tensor, target_probs: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy between predicted and target distributions.

    `MlpPredictor.forward` already returns softmax probabilities with the class
    axis last. `torch.nn.CrossEntropyLoss` would be wrong on both counts: it
    expects unnormalised logits (it applies log-softmax itself) and treats
    dim 1 as the class axis for inputs with more than two dimensions.

    Args:
        predicted_probs: Predicted probabilities, classes on the last axis.
        target_probs: Target distribution (e.g. one-hot), same shape.

    Returns:
        Scalar loss averaged over every leading position.
    """
    # Clamp so that an exactly-zero probability cannot produce -inf / NaN.
    log_probs = torch.log(predicted_probs.clamp_min(1e-12))
    return -(target_probs * log_probs).sum(dim=-1).mean()


optimizer = torch.optim.Adam(model.parameters())

In [None]:
def train_one_epoch(epoch_index, tb_writer, report_every: int = 1000):
    """Run one pass over `training_loader`, updating `model` in place.

    Uses the notebook-level `model`, `optimizer`, `loss_fn` and
    `training_loader`. Every `report_every` batches the mean loss and the
    prediction accuracy over those batches are printed and written to
    `tb_writer` under 'Loss/train' and 'Accuracy/train'.

    Args:
        epoch_index: Zero-based epoch number, used to compute the global step.
        tb_writer: Object exposing `add_scalar(tag, value, step)`.
        report_every: Number of batches between two reports.

    Returns:
        The mean per-batch loss of the last report, or, when the epoch was too
        short to produce any report, the mean loss over all batches seen.

    Raises:
        ValueError: If `report_every` is smaller than 1.
    """
    if report_every < 1:
        raise ValueError('report_every must be >= 1, got {}'.format(report_every))

    running_loss = 0.
    last_loss = 0.
    epoch_loss = 0.
    num_batches = 0
    num_reports = 0
    num_correct_predictions = 0
    num_predictions = 0
    # Follow whatever device the model lives on instead of assuming one.
    device = next(model.parameters()).device

    for i, (inputs, labels) in enumerate(training_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        # Zero your gradients for every batch!
        optimizer.zero_grad()
        # Make predictions for this batch
        outputs = model(inputs)
        # Compute the loss and its gradients
        loss = loss_fn(outputs, labels)
        loss.backward()
        # Adjust learning weights
        optimizer.step()

        # Prediction accuracy: for every (sample, position) read the label
        # value at the predicted class. `gather` keeps sample and position
        # aligned and works for any batch size, including a short last batch.
        predicted_class = torch.argmax(outputs, dim=-1, keepdim=True)
        hits = labels.gather(-1, predicted_class) > 0.9
        num_correct_predictions += int(hits.sum())
        num_predictions += hits.numel()

        # Gather data and report
        batch_loss = loss.item()
        running_loss += batch_loss
        epoch_loss += batch_loss
        num_batches += 1
        if i % report_every == report_every - 1:
            last_loss = running_loss / report_every  # loss per batch
            print('  batch {} loss: {}'.format(i + 1, last_loss))
            tb_x = epoch_index * len(training_loader) + i + 1
            tb_writer.add_scalar('Loss/train', last_loss, tb_x)
            # Divide by the number of predictions actually made, not by an
            # estimate derived from the size of the latest batch.
            prediction_accuracy = num_correct_predictions / max(num_predictions, 1)
            print('  prediction {} accuracy: {}'.format(i + 1, prediction_accuracy))
            tb_writer.add_scalar('Accuracy/train', prediction_accuracy, tb_x)
            running_loss = 0.
            num_correct_predictions = 0
            num_predictions = 0
            num_reports += 1

    # An epoch shorter than `report_every` batches would otherwise report a
    # training loss of exactly 0.
    if num_reports == 0 and num_batches > 0:
        last_loss = epoch_loss / num_batches

    return last_loss

In [None]:
# Train for EPOCHS epochs, validating after each one and checkpointing the
# model whenever the validation loss improves.
timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
writer = SummaryWriter('runs/expert_predictor_{}'.format(timestamp))
epoch_number = 0

EPOCHS = 3

best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Make sure gradient tracking is on, and do a pass over the data
    model.train(True)
    avg_loss = train_one_epoch(epoch_number, writer)

    running_vloss = 0.0
    num_vbatches = 0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vinputs, vlabels in validation_loader:
            voutputs = model(vinputs)
            # `.item()` keeps the accumulator a plain float, so the values that
            # are printed, logged and stored in `best_vloss` are numbers rather
            # than device tensors.
            running_vloss += loss_fn(voutputs, vlabels).item()
            num_vbatches += 1

    # Count the batches explicitly instead of relying on a leaked loop index,
    # which is undefined (or stale) when the validation loader is empty.
    avg_vloss = running_vloss / num_vbatches if num_vbatches else float('nan')
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Log the running loss averaged per batch
    # for both training and validation
    writer.add_scalars('Training vs. Validation Loss',
                    { 'Training' : avg_loss, 'Validation' : avg_vloss },
                    epoch_number + 1)
    writer.flush()

    # Track best performance, and save the model's state. A NaN validation
    # loss never compares as smaller, so nothing is saved in that case.
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = 'model_{}_{}'.format(timestamp, epoch_number)
        torch.save(model.state_dict(), model_path)

    epoch_number += 1

In [None]:
# Number of samples in the training split.
len(training_set)

In [None]:
# Presumably the chance-level accuracy of a uniform guess: 28 matches the size
# of the model's last output axis — TODO confirm intent.
1/28

In [None]:
# Scratch check of the encoder: cast the sample trace's expert selections to
# uint32 and collapse the last axis into one class index per entry.
# NOTE(review): this overwrites `x`, and a later cell rebinds `x` to a torch
# tensor, so the cell only works while `x` still holds the original array.
x = x.astype(np.uint32)
out = np.apply_along_axis(lambda x: encode_selected_experts(x, trace.num_experts), -1, x)
out

In [None]:
# Scratch: a tensor of 16 ones, used by the next cell's device-transfer check.
# NOTE(review): rebinds `x`, which previously held the expert-selection array.
x = torch.ones(16)

In [None]:
# Scratch: move the tensor to the 'mps' device to check that the transfer works.
y = x.to('mps')

In [None]:
# Shape of the encoded selections once flattened to one dimension.
out.reshape((-1, )).shape

In [None]:
# Scratch: one-hot encode the encoded selections and display the device of the
# result. No explicit class count is passed here, unlike in `create_dataset`.
out_torch = torch.from_numpy(out)
out_one_hot = nn.functional.one_hot(out_torch)
out_one_hot.device

In [None]:
# Shape of the first entry along the leading axis of the one-hot tensor.
out_one_hot[0].shape