In [1]:
%load_ext lab_black
%env XLA_PYTHON_CLIENT_PREALLOCATE=false
import jax

jax.config.update("jax_platform_name", "cpu")
jax.config.update("jax_enable_x64", True)
from matplotlib import pyplot as plt
import numpy as np
import os
import pandas as pd
import pickle
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from transformers import (
    BertConfig,
    BfBertForSequenceClassification,
    BertForSequenceClassification,
)
from typing import List, Tuple
import seaborn as sns
import sys
import wandb

import brunoflow as bf
from brunoflow.ad import Node
from brunoflow.net import LogReg
from brunoflow.opt import Adam, cross_entropy_loss, regularize
from preprocessing.datasets import (
    MNIST,
    FFOOM,
    NoisyLinear,
    BreastCancer,
    Bank7,
    BankFull,
    FirstTokenRepeatedOnce,
    Contains1FirstToken,
    ContainsTokenSet,
    FirstTokenRepeatedOnceImmediately,
    Contains1,
    FirstTokenRepeatedImmediately,
    FirstTokenRepeatedLast,
    AdjacentDuplicate,
    BinCountOnes,
)

from utils import catchtime, gpu_memory_usage
from kvq_utils import (
    convert_sentence_to_tokens_and_target_idx,
    find_matching_nodes,
    preprocessgrad_per_parent_per_word_data_per_layer,
    rename_matmul_kvq_nodes,
    summarize_max_grad_kvq,
    plot_max_grad_against_layer_per_word,
)
from entropy_utils import (
    gather_entropies_of_input_ids,
    gather_abs_grad_of_input_ids,
    gather_grad_of_input_ids,
)

env: XLA_PYTHON_CLIENT_PREALLOCATE=false


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Report which device kind JAX is bound to. A missing JAX install is
# tolerated and only reported, so the rest of the notebook can still run.
try:
    import jax
    from jax import numpy as jnp

    primary_device = jax.devices()[0]
    jax_device_kind = primary_device.device_kind
    print(f"Running JAX on {jax_device_kind}")
    # A hard failure on non-NVIDIA devices used to live here; it is
    # intentionally disabled so CPU-only runs are allowed.
except ImportError:
    print("No JAX available to import!")

Running JAX on cpu


In [3]:
# Tell W&B which notebook file this run was launched from.
notebook_path = os.path.join(os.getcwd(), "train_bert_synthetic.ipynb")
os.environ["WANDB_NOTEBOOK_NAME"] = notebook_path

# Snapshot every ALL-CAPS name that already exists *before* the parameter
# cell runs, so those names can later be excluded from the logged config.
excluded_all_caps_params = set(filter(str.isupper, locals().keys()))

In [4]:
# Parameters
# Every ALL-CAPS name assigned in this cell is collected into the W&B run
# config by a later cell, so keep tunable values here and in upper case.

# Data parameters
# Exactly one DATASET_NAME / DATASET_KWARGS_IDENTIFIABLE pair should be active;
# the commented alternatives are kept as ready-made configurations.
# DATASET_NAME = "FirstTokenRepeatedOnce"
# DATASET_NAME = "Contains1FirstToken"
# DATASET_NAME = "FirstTokenRepeatedOnceImmediately"
# DATASET_NAME, DATASET_KWARGS_IDENTIFIABLE = "ContainsTokenSet", {
#     "num_points": 1000,
#     "token_set": [1, 2, 3],
# }
# DATASET_NAME, DATASET_KWARGS_IDENTIFIABLE = "Contains1", {"num_points": 20000}
# DATASET_NAME, DATASET_KWARGS_IDENTIFIABLE = "FirstTokenRepeatedImmediately", {
#     "num_points": 20000
# }
# DATASET_NAME, DATASET_KWARGS_IDENTIFIABLE = "FirstTokenRepeatedLast", {
#     "num_points": 20000
# }
# DATASET_NAME, DATASET_KWARGS_IDENTIFIABLE = "AdjacentDuplicate", {"num_points": 20000}
DATASET_NAME, DATASET_KWARGS_IDENTIFIABLE = "BinCountOnes", {
    "num_points": 1200,
    "num_classes": 2,
    "seq_len": 64,
}
# Extra dataset constructor kwargs; merged over DATASET_KWARGS_IDENTIFIABLE
# when the dataset is built.
DATASET_KWARGS = {}
SEED = 0

BATCH_SIZE = 32
TEST_BATCH_SIZE = 10
VAL_BATCH_SIZE = 10

# Upper bounds on how many examples evaluation / validation will consume.
MAX_TEST_SET_SIZE = 200
MAX_VAL_SET_SIZE = 200

SAMPLING_FRACTION = 1.0

# Model parameters
TRAIN_TORCH = True
HIDDEN_SIZE = 32
NUM_EPOCHS = 30
LEARNING_RATE = 0.001
L1_WEIGHT = 0.0
L2_WEIGHT = 0.0
DROPOUT_PROB = 0
NUM_ATTENTION_HEADS = 4
NUM_HIDDEN_LAYERS = 1
# NUM_HIDDEN_LAYERS = 2
ATTENTION_PROBS_DROPOUT_PROB = 0
# ATTENTION_PROBS_DROPOUT_PROB = 0.2
MAX_POSITION_EMBEDDINGS = 512
OVERWRITE_MODEL = True
MODEL_CONFIG_PATH = "../brunoflow/models/bert/config-toy.json"
VALIDATE_DURING_TRAINING = False
# NOTE(review): the training loop iterates range(num_epochs), i.e. epochs are
# 0-indexed, so with NUM_EPOCHS = 30 the entry 30 below is never reached.
EPOCHS_FOR_VALIDATION = [0, 1, 2, 3, 4, 5, 10, 20, 30]

# Analysis parameters
COMPUTE_ENTROPY = True

# Run parameters
PM_RUN_ID = "run_id"
PROJECT_NAME = "bauer-bert-synthetic"
# GROUP_NAME = "trial"
# GROUP_NAME = None

TAGS = ["train_torch", "trial", "bincountones"]

In [5]:
# Build the dataset: the class is looked up by name in this notebook's own
# namespace, and DATASET_KWARGS overrides DATASET_KWARGS_IDENTIFIABLE on clashes.
dataset_cls = getattr(sys.modules[__name__], DATASET_NAME)
dataset_kwargs = dict(DATASET_KWARGS_IDENTIFIABLE)
dataset_kwargs.update(DATASET_KWARGS)
dataset = dataset_cls(**dataset_kwargs)

# Directory layout: data/<dataset name>/<data id>/<sampling fraction>-<seed>/inputs
data_id = f"{dataset.get_name()}"
data_dir = os.path.join("data", DATASET_NAME, data_id, f"{SAMPLING_FRACTION}-{SEED}")
input_dir = os.path.join(data_dir, "inputs")
train_data_path = os.path.join(input_dir, "train.pt")
test_data_path = os.path.join(input_dir, "test.pt")
val_data_path = os.path.join(input_dir, "val.pt")
val_data_df_path = os.path.join(input_dir, "val_data.csv")

# Model id: the core hyperparameters are always present; the regularisation and
# dropout suffixes are appended only when the corresponding value is non-zero.
model_id_parts = [
    f"Bert-hs{HIDDEN_SIZE}-numheads{NUM_ATTENTION_HEADS}-bs{BATCH_SIZE}-lr{LEARNING_RATE}-n{NUM_EPOCHS}"
]
if L1_WEIGHT != 0:
    model_id_parts.append(f"-l1_weight{L1_WEIGHT}")
if L2_WEIGHT != 0:
    model_id_parts.append(f"-l2_weight{L2_WEIGHT}")
if DROPOUT_PROB != 0:
    model_id_parts.append(f"-dropoutprob{DROPOUT_PROB}")
model_id = "".join(model_id_parts)

model_dir = os.path.join(data_dir, "models", model_id)

print(f"Data dir: {data_dir}")
print(f"Model dir: {model_dir}")

# Overrides applied on top of the JSON model config when the model is created.
model_config_kwargs = {
    "hidden_size": HIDDEN_SIZE,
    "num_attention_heads": NUM_ATTENTION_HEADS,
    "num_hidden_layers": NUM_HIDDEN_LAYERS,
    "attention_probs_dropout_prob": ATTENTION_PROBS_DROPOUT_PROB,
    "max_position_embeddings": MAX_POSITION_EMBEDDINGS,
}

if DATASET_NAME == "BinCountOnes":
    # Use the maximum possible number of labels (even when num_classes is
    # lower) so that models have the same number of connections regardless of
    # the number of classes.
    model_config_kwargs["num_labels"] = dataset.seq_len

# Seed torch and numpy
np.random.seed(SEED)
torch.manual_seed(SEED)

Data dir: data/BinCountOnes/BinCountOnes-num_classes2-seqlen64-num_points1200-vs20/1.0-0
Model dir: data/BinCountOnes/BinCountOnes-num_classes2-seqlen64-num_points1200-vs20/1.0-0/models/Bert-hs32-numheads4-bs32-lr0.001-n30


In [6]:
# Collect the ALL-CAPS parameters defined by this notebook (upper-case names
# that existed before the parameter cell are filtered out) for the W&B config.
params_to_log = dict(
    filter(
        lambda item: item[0].isupper() and item[0] not in excluded_all_caps_params,
        locals().items(),
    )
)

In [7]:
# Start the W&B run; the notebook's ALL-CAPS parameters become the run config.
# (Grouping via `group=GROUP_NAME` is currently disabled.)
wandb_init_kwargs = {
    "project": PROJECT_NAME,
    "config": params_to_log,
    "tags": TAGS,
}
run = wandb.init(**wandb_init_kwargs)
print(dict(wandb.config))

[34m[1mwandb[0m: Currently logged in as: [33mkdu[0m ([33methz-rycolab[0m). Use [1m`wandb login --relogin`[0m to force relogin
wandb: ERROR Failed to sample metric: Not Supported


{'DATASET_NAME': 'BinCountOnes', 'DATASET_KWARGS_IDENTIFIABLE': {'num_points': 1200, 'num_classes': 2, 'seq_len': 64}, 'DATASET_KWARGS': {}, 'SEED': 0, 'BATCH_SIZE': 32, 'TEST_BATCH_SIZE': 10, 'VAL_BATCH_SIZE': 10, 'MAX_TEST_SET_SIZE': 200, 'MAX_VAL_SET_SIZE': 200, 'SAMPLING_FRACTION': 1.0, 'TRAIN_TORCH': True, 'HIDDEN_SIZE': 32, 'NUM_EPOCHS': 30, 'LEARNING_RATE': 0.001, 'L1_WEIGHT': 0.0, 'L2_WEIGHT': 0.0, 'DROPOUT_PROB': 0, 'NUM_ATTENTION_HEADS': 4, 'NUM_HIDDEN_LAYERS': 1, 'ATTENTION_PROBS_DROPOUT_PROB': 0, 'MAX_POSITION_EMBEDDINGS': 512, 'OVERWRITE_MODEL': True, 'MODEL_CONFIG_PATH': '../brunoflow/models/bert/config-toy.json', 'VALIDATE_DURING_TRAINING': False, 'EPOCHS_FOR_VALIDATION': [0, 1, 2, 3, 4, 5, 10, 20, 30], 'COMPUTE_ENTROPY': True, 'PM_RUN_ID': 'run_id', 'PROJECT_NAME': 'bauer-bert-synthetic', 'TAGS': ['train_torch', 'trial', 'bincountones']}


### Data Retrieval and Preprocessing

In [8]:
def torch_subset_to_df(torch_data: torch.utils.data.dataset.Subset):
    """Convert an iterable of tensor rows into plain Python records.

    Each row of ``torch_data`` is an iterable of tensors; every tensor is
    turned into a (possibly nested) Python list / scalar via ``.numpy().tolist()``.

    Returns:
        A list with one tuple per row. Despite the function name this is not a
        DataFrame; callers wrap the result in ``pd.DataFrame`` themselves.
    """
    records = []
    for row in torch_data:
        converted = tuple(element.numpy().tolist() for element in row)
        records.append(converted)
    return records

In [9]:
# Reuse the cached splits only when all three files exist; otherwise all three
# splits are rebuilt from `dataset` and written to disk.
split_paths = (train_data_path, val_data_path, test_data_path)
if all(map(os.path.isfile, split_paths)):
    print(
        f"Loading cached train and test sets from {train_data_path} and {test_data_path}."
    )
    train_data, val_data, test_data = map(torch.load, split_paths)
else:
    train_data = dataset.get_train_data()
    val_data = dataset.get_val_data()
    test_data = dataset.get_test_data()

    # Persist the freshly built splits
    os.makedirs(input_dir, exist_ok=True)
    torch.save(train_data, train_data_path)
    torch.save(val_data, val_data_path)
    torch.save(test_data, test_data_path)

    # Also export the validation split as a CSV.
    # Note: val_data_df is only defined on this (cache-miss) path.
    val_rows = torch_subset_to_df(val_data)
    val_data_df = pd.DataFrame(val_rows, columns=["sentence", "label"])
    val_data_df.to_csv(val_data_df_path)

# Per-split DataLoader options (only the batch size differs).
train_kwargs = dict(batch_size=BATCH_SIZE)
val_kwargs = dict(batch_size=VAL_BATCH_SIZE)
test_kwargs = dict(batch_size=TEST_BATCH_SIZE)

train_loader = torch.utils.data.DataLoader(train_data, **train_kwargs)
val_loader = torch.utils.data.DataLoader(val_data, **val_kwargs)
test_loader = torch.utils.data.DataLoader(test_data, **test_kwargs)

Loading cached train and test sets from data/BinCountOnes/BinCountOnes-num_classes2-seqlen64-num_points1200-vs20/1.0-0/inputs/train.pt and data/BinCountOnes/BinCountOnes-num_classes2-seqlen64-num_points1200-vs20/1.0-0/inputs/test.pt.


In [10]:
# Show where the validation CSV and the model directory are located on disk.
print(val_data_df_path, model_dir)

data/BinCountOnes/BinCountOnes-num_classes2-seqlen64-num_points1200-vs20/1.0-0/inputs/val_data.csv data/BinCountOnes/BinCountOnes-num_classes2-seqlen64-num_points1200-vs20/1.0-0/models/Bert-hs32-numheads4-bs32-lr0.001-n30


In [11]:
# After loading/preprocessing your dataset, log it as an artifact to W&B
print(f"Logging datasets to w&b run {wandb.run}.")
# The artifact is named after data_id and gets the input_dir directory added to it.
artifact = wandb.Artifact(name=data_id, type="dataset")
artifact.add_dir(local_path=input_dir)
# Last expression of the cell: its return value is shown as the cell output.
run.log_artifact(artifact)

[34m[1mwandb[0m: Adding directory to artifact (./data/BinCountOnes/BinCountOnes-num_classes2-seqlen64-num_points1200-vs20/1.0-0/inputs)... Done. 0.0s


Logging datasets to w&b run <wandb.sdk.wandb_run.Run object at 0x7f54c87d1eb0>.


<wandb.sdk.wandb_artifacts.Artifact at 0x7f54c85c7a90>

In [12]:
# Peek at the first batch of the test loader as a sanity check.
examples = iter(test_loader)
# Use the builtin next(): the Python 3 iterator protocol defines __next__, and
# a `.next()` method is not guaranteed to exist on DataLoader iterators.
example_data, example_targets = next(examples)
example_data, example_targets

(tensor([[14, 11,  9,  4,  1,  1,  1,  4,  9, 17,  1, 15, 15, 14, 17,  8,  1,  3,
          12,  8, 12,  5, 15,  1,  1, 13, 13, 10,  1,  1, 16,  1,  1, 16,  9,  1,
          11,  9,  1, 10,  2, 17,  1, 18, 18,  1,  1,  8,  6,  3,  1,  9, 19,  1,
          12,  1,  1,  2,  3, 15,  9,  9,  1,  1],
         [ 1,  1,  1,  3,  6, 18,  1, 16,  1,  1,  1,  2,  1, 12,  1,  8,  9,  1,
           1, 16,  1, 12, 19,  5,  8,  1, 14,  1,  5,  9,  2,  1,  1,  1,  1, 19,
           6, 12,  1, 14,  2,  4,  2,  1, 13,  1,  7,  1,  1,  1,  1,  8,  2, 10,
           1, 10,  1,  7, 12, 14,  1, 17,  1,  1],
         [ 1, 11,  1,  9,  1,  5,  1, 13, 13, 14,  1, 14,  1,  1,  1,  1,  2,  1,
           7,  1,  3,  1,  1,  8,  4,  1, 15,  1,  1, 17,  1,  1,  1, 14, 10,  5,
           1,  1,  5, 15,  1,  1,  1,  9,  1,  1,  1,  1, 14,  8,  5,  1,  1, 18,
           1,  1, 16,  1,  1,  1,  7,  1,  1,  1],
         [ 1,  1, 18,  3,  1, 18,  1,  3, 17, 13,  2,  7,  1, 13,  7,  7,  1,  1,
          19,  7, 13, 10,  

### Define training loop

In [13]:
def load_bf_model(config, model_path):
    """Build a BfBertForSequenceClassification from ``config`` and restore its weights.

    Args:
        config: model configuration passed to the BfBertForSequenceClassification constructor.
        model_path: path of a file readable by ``torch.load`` whose content is
            accepted by ``load_state_dict`` (the notebook saves torch state dicts there).

    Returns:
        The freshly constructed model with the stored state loaded.
    """
    restored_model = BfBertForSequenceClassification(config)
    saved_state = torch.load(model_path)
    restored_model.load_state_dict(saved_state)
    return restored_model

In [14]:
def validation(
    model,
    optimizer,
    val_loader,
    epoch,
    batch,
    compute_entropy=False,
    max_val_set_size=None,
):
    """Evaluate a brunoflow model on the validation loader and log the metrics.

    Batches are consumed until the loader is exhausted or, when
    ``max_val_set_size`` is given, until at least that many examples have been
    processed (the cap is checked after each whole batch).

    Args:
        model: callable as ``model(inputs, labels=labels)`` returning an object
            with a ``logits`` node; must provide ``eval()`` and ``train()``.
        optimizer: only used when ``compute_entropy`` is true, to reset the
            accumulated values via ``zero_gradients()`` before each backprop.
        val_loader: iterable of ``(inputs, labels)`` pairs exposing ``.numpy()``.
        epoch: recorded verbatim in the returned metrics.
        batch: recorded verbatim in the returned metrics.
        compute_entropy: when true, also backpropagates each batch loss and
            accumulates entropy, gradient and absolute gradient of the inputs.
        max_val_set_size: optional cap on the number of examples to evaluate.

    Returns:
        ``{"val": {...}}`` with per-example means for ``loss``, ``grad``,
        ``abs_grad`` and ``entropy`` (the last three are ``None`` unless
        ``compute_entropy``), plus ``epoch``, ``batch`` and ``accuracy``.

    Raises:
        ValueError: if ``val_loader`` yields no examples.
    """
    # Initialize accumulators (bc we will need to break validation into many batches for computing entropy, etc)
    total_correct = 0
    total_loss = 0
    total_test_points = 0
    total_entropy = 0
    total_grad = 0
    total_abs_val_grad = 0

    model.eval()

    # Loop through each batch in the val_loader
    for inputs, labels in val_loader:
        inputs = inputs.numpy()
        labels = labels.numpy()

        # Apply model to inputs
        bert_outputs = model(inputs, labels=labels)
        logits = bert_outputs.logits
        num_correct_in_batch = int(np.sum(np.argmax(logits.val, axis=1) == labels))
        loss: bf.Node = cross_entropy_loss(logits, labels, reduction="sum")

        if compute_entropy:
            # Compute and accumulate entropy, grad, abs_val_grad for the batch in the validation set
            optimizer.zero_gradients()
            # NOTE(review): backprop is run in train mode, as in the original
            # code — presumably required by brunoflow; confirm.
            model.train()
            loss.backprop(values_to_compute=("abs_val_grad", "entropy", "grad"))
            model.eval()
            entropy_per_example_per_token: np.ndarray = gather_entropies_of_input_ids(
                model=model, input_ids=inputs
            )  # shape: (bs, seq_len)

            abs_val_grads_per_example_per_token = gather_abs_grad_of_input_ids(
                model, inputs
            )  # shape=(len(input_ids), hidden_sz)
            grads_per_example_per_token = gather_grad_of_input_ids(
                model, inputs
            )  # shape=(len(input_ids), hidden_sz)

            total_entropy += np.sum(entropy_per_example_per_token)  # shape: ()
            total_abs_val_grad += np.sum(abs_val_grads_per_example_per_token)
            total_grad += np.sum(grads_per_example_per_token)

        # Accumulate the numeric value rather than the Node itself: summing
        # Nodes would chain every batch's computation graph into one
        # ever-growing expression and keep all of them alive.
        total_loss += loss.val
        total_correct += num_correct_in_batch
        total_test_points += len(labels)
        # `>=` so that evaluation stops as soon as the cap is reached instead
        # of running one additional batch.
        if max_val_set_size is not None and total_test_points >= max_val_set_size:
            break

    if total_test_points == 0:
        raise ValueError("val_loader yielded no examples; cannot compute validation metrics.")

    # Compute mean statistics across entire validation cohort
    accuracy = total_correct / total_test_points
    mean_loss = total_loss / total_test_points
    if compute_entropy:
        mean_entropy = total_entropy / total_test_points
        mean_grad = total_grad / total_test_points
        mean_abs_val_grad = total_abs_val_grad / total_test_points
    else:
        mean_entropy = mean_grad = mean_abs_val_grad = None

    val_metrics = {
        "val": {
            "loss": mean_loss,
            "grad": mean_grad,
            "abs_grad": mean_abs_val_grad,
            "entropy": mean_entropy,
            "epoch": epoch,
            "batch": batch,
            "accuracy": accuracy,
        }
    }

    if wandb.run is not None:
        wandb.log(val_metrics)

    # The model is always left in train mode, regardless of the mode it was in on entry.
    model.train()

    return val_metrics

In [15]:
def train_torch_model(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    train_loader: torch.utils.data.DataLoader,
    model_dir: str,
    overwrite_model=False,
    num_epochs=NUM_EPOCHS,
    l1_weight=0,
    l2_weight=0,
    validation_params_and_loader=dict(),
):
    """Train a torch model, or restore it from ``model_dir`` if already trained.

    Args:
        model: torch model called as ``model(inputs, labels=labels)``; its
            output must expose ``.loss``. ``model.config`` is read when
            validation is enabled.
        optimizer: torch optimizer already bound to ``model``'s parameters.
        train_loader: loader yielding ``(inputs, labels)`` batches.
        model_dir: directory for ``model.pt`` and ``losses.csv`` (created if missing).
        overwrite_model: when false and ``model.pt`` exists, the stored weights
            are loaded into ``model`` and no training happens.
        num_epochs: number of passes over ``train_loader``.
        l1_weight, l2_weight: forwarded to ``regularize``.
        validation_params_and_loader: empty dict to skip validation; otherwise
            must contain ``val_loader``, ``compute_entropy``,
            ``epochs_for_validation`` and ``max_val_set_size``. The dict is
            only read, never mutated.

    Returns:
        ``(model, model_path)`` on both the training and the cached path.

    Raises:
        ValueError: if ``validation_params_and_loader`` is non-empty but lacks
            a required key.
    """
    required_val_keys = {
        "val_loader",
        "compute_entropy",
        "epochs_for_validation",
        "max_val_set_size",
    }
    if validation_params_and_loader and not required_val_keys.issubset(
        set(validation_params_and_loader.keys())
    ):
        raise ValueError(
            f"Included a nonempty validation_params_and_loader with keys {set(validation_params_and_loader.keys())}, but didn't contain all of the required keys: {required_val_keys}."
        )

    model.train()
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, "model.pt")

    # Load model if already trained
    if os.path.isfile(model_path):
        if overwrite_model:
            print(f"Retraining and overwriting model at {model_path}.")
        else:
            print(f"Loading trained model from {model_path}.")
            # model.pt holds a state_dict (see the torch.save calls below), so
            # it is loaded *into* the given model. This also keeps `optimizer`
            # attached to the very model object that is returned.
            model.load_state_dict(torch.load(model_path))
            # Same return shape as the training path, so callers can always
            # unpack (model, model_path).
            return model, model_path

    # Train the model
    losses = dict()
    n_total_steps = len(train_loader)
    for epoch in range(num_epochs):
        for i, (inputs, labels) in enumerate(train_loader):
            # Forward pass
            bert_outputs = model(inputs, labels=labels)
            unregularized_loss = bert_outputs.loss
            # NOTE(review): the regularised loss is only recorded in losses.csv;
            # gradients come from the unregularised loss below, so l1_weight /
            # l2_weight do not influence training. Confirm whether intended.
            loss = unregularized_loss + regularize(
                model=model, l1_weight=l1_weight, l2_weight=l2_weight
            )

            # Clear gradients left over from the previous minibatch before backpropagating.
            optimizer.zero_grad()
            unregularized_loss.backward()
            optimizer.step()

            # Plain float for logging/printing (no autograd graph attached).
            unregularized_loss_value = unregularized_loss.item()

            if wandb.run is not None:
                wandb.log(
                    {
                        "train": {
                            "unregularized_loss": unregularized_loss_value,
                            "epoch": epoch,
                            "batch": i,
                        }
                    }
                )

            if i % 100 == 0:
                print(
                    f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{n_total_steps}], "
                    f"Cross-Entropy Loss: {unregularized_loss_value:.4f}, "
                )

                train_step_num = i + n_total_steps * epoch
                # NOTE(review): this stores the loss object itself, so the CSV
                # receives its string representation — confirm a plain number
                # is not wanted here.
                losses[train_step_num] = loss

                # Rewrite the full loss history to CSV each time it grows.
                losses_df = pd.DataFrame([losses]).melt(
                    var_name="steps", value_name="loss"
                )
                losses_df.to_csv(os.path.join(model_dir, "losses.csv"))

            if (
                validation_params_and_loader
                and epoch in validation_params_and_loader["epochs_for_validation"]
                and i == 0
            ):
                print(f"Starting validation at epoch {epoch}")
                # Note: if you add/change the keys of validation_params_and_loader, you'll need to change the check at the start of this fct
                val_loader = validation_params_and_loader["val_loader"]
                compute_entropy = validation_params_and_loader["compute_entropy"]
                max_val_set_size = validation_params_and_loader["max_val_set_size"]
                # Validation runs on a brunoflow copy of the model: the current
                # weights are written to disk and reloaded into a BF model.
                torch.save(model.state_dict(), model_path)
                bf_model = load_bf_model(model.config, model_path=model_path)
                bf_optimizer = Adam(bf_model.parameters(), step_size=LEARNING_RATE)
                val_metrics = validation(
                    model=bf_model,
                    optimizer=bf_optimizer,
                    val_loader=val_loader,
                    epoch=epoch,
                    batch=i,
                    compute_entropy=compute_entropy,
                    max_val_set_size=max_val_set_size,
                )
                # Entropy is None when compute_entropy is false; formatting
                # None with ':.4f' would raise a TypeError.
                val_entropy = val_metrics["val"]["entropy"]
                entropy_text = "n/a" if val_entropy is None else f"{val_entropy:.4f}"
                print(
                    f"VALIDATION @ Epoch [{epoch+1}/{num_epochs}], Step [{i}/{n_total_steps}], "
                    f"Accuracy: {val_metrics['val']['accuracy']:.4f}, "
                    f"Cross-Entropy Loss: {val_metrics['val']['loss']:.4f}, "
                    f"Entropy: {entropy_text}"
                )

    # Save model
    torch.save(model.state_dict(), model_path)

    # Log to W&B
    if wandb.run is not None:
        print(f"Logging model to w&b run {wandb.run}.")
        artifact = wandb.Artifact(name="model", type="model")
        artifact.add_file(local_path=model_path)
        # Use the active run that was just checked instead of a notebook-global `run`.
        wandb.run.log_artifact(artifact)

    return model, model_path

### Define evaluation

In [16]:
def evaluate_model(
    model,
    model_dir,
    test_loader,
    optimizer,
    compute_entropy=False,
    overwrite_metrics=False,
    max_test_set_size: int = None,
):
    """Compute test accuracy, loss and (optionally) entropy for a brunoflow model.

    Results are cached in ``<model_dir>/metrics.csv``; an existing file is
    returned as-is unless ``overwrite_metrics`` is true.

    Args:
        model: callable as ``model(inputs, labels=labels)`` returning an object
            with a ``logits`` node; must provide ``eval()`` and ``train()``.
        model_dir: directory holding ``metrics.csv``.
        test_loader: iterable of ``(inputs, labels)`` pairs exposing ``.numpy()``.
        optimizer: only used when ``compute_entropy`` is true
            (``zero_gradients()`` before each backprop).
        compute_entropy: when true, backpropagates each batch loss and
            accumulates the entropy of the input ids.
        overwrite_metrics: recompute even if ``metrics.csv`` already exists.
        max_test_set_size: optional cap; evaluation stops after the batch that
            reaches it.

    Returns:
        A one-row DataFrame with columns ``accuracy``, ``loss`` and ``entropy``
        (``entropy`` is ``None`` unless ``compute_entropy``).

    Raises:
        ValueError: if ``test_loader`` yields no examples.
    """
    model.eval()
    metrics_path = os.path.join(model_dir, "metrics.csv")
    if os.path.isfile(metrics_path):
        if overwrite_metrics:
            print(f"Recomputing and overwriting metrics at {metrics_path}.")
        else:
            print(f"Loading precomputed metrics from {metrics_path}.")
            # The file is written with its index, so read the first column back
            # as the index; otherwise the cached frame would carry an extra
            # "Unnamed: 0" column that the freshly computed frame does not have.
            return pd.read_csv(metrics_path, index_col=0)

    total_correct = 0
    total_loss = 0
    total_test_points = 0
    total_entropy = 0

    for inputs, labels in test_loader:
        inputs = inputs.numpy()
        labels = labels.numpy()
        bert_outputs = model(inputs, labels=labels)
        logits = bert_outputs.logits
        num_correct_in_batch = int(np.sum(np.argmax(logits.val, axis=1) == labels))
        loss: bf.Node = cross_entropy_loss(logits, labels, reduction="sum")

        if compute_entropy:
            optimizer.zero_gradients()
            # NOTE(review): backprop is run in train mode, as in the original
            # code — presumably required by brunoflow; confirm.
            model.train()
            loss.backprop(values_to_compute=("abs_val_grad", "entropy"))
            model.eval()
            entropy_per_example_per_token: np.ndarray = gather_entropies_of_input_ids(
                model=model, input_ids=inputs
            )  # shape: (bs, seq_len)
            total_entropy += np.sum(entropy_per_example_per_token)  # shape: ()

        # Accumulate the numeric value rather than the Node: summing Nodes
        # would keep every batch's computation graph alive.
        total_loss += loss.val
        total_correct += num_correct_in_batch
        total_test_points += len(labels)
        # `>=` so that evaluation stops as soon as the cap is reached instead
        # of running one additional batch.
        if max_test_set_size is not None and total_test_points >= max_test_set_size:
            break

    if total_test_points == 0:
        raise ValueError("test_loader yielded no examples; cannot compute metrics.")

    accuracy = total_correct / total_test_points
    mean_entropy = total_entropy / total_test_points if compute_entropy else None
    mean_loss = total_loss / total_test_points

    metrics = {"accuracy": accuracy, "loss": mean_loss, "entropy": mean_entropy}
    metrics_df = pd.DataFrame([metrics])

    metrics_df.to_csv(metrics_path)

    # Log to W&B
    if wandb.run is not None:
        print(f"Logging test metrics to w&b run {wandb.run}.")
        wandb.log({"test": metrics})

    return metrics_df

### Train and evaluate models

In [17]:
# Memory profiles are written to a "mem_profiles" sub-directory of the model
# directory; makedirs also creates model_dir itself if it does not exist yet.
profile_dir = os.path.join(model_dir, "mem_profiles")
os.makedirs(profile_dir, exist_ok=True)

In [18]:
def _save_jax_memory_profile(profile_name: str) -> None:
    """Best-effort dump of the JAX device memory profile into ``profile_dir``.

    Profiling is optional, so failures are reported (with the actual error)
    instead of aborting the run. Only ``Exception`` is caught, so that
    KeyboardInterrupt / SystemExit still propagate.
    """
    try:
        jax.profiler.save_device_memory_profile(
            os.path.join(profile_dir, profile_name)
        )
    except Exception as err:
        print(f"Can't save jax memory profile {profile_name}: {err!r}")


# Initialize model
with catchtime() as t:
    config = BertConfig.from_pretrained(
        pretrained_model_name_or_path=MODEL_CONFIG_PATH,
        **model_config_kwargs,
    )
    config.to_json_file(os.path.join(model_dir, "config.json"))

    # Torch model when training with torch, brunoflow model otherwise.
    model = (
        BfBertForSequenceClassification(config=config)
        if not TRAIN_TORCH
        else BertForSequenceClassification(config=config)
    )


init_model_time = t.time
init_model_nvidia_mem_used = gpu_memory_usage()
_save_jax_memory_profile("toy_bert_init.prof")

# Train model
with catchtime() as t:
    if TRAIN_TORCH:
        validation_params_and_loader = {
            "val_loader": val_loader,
            "compute_entropy": COMPUTE_ENTROPY,
            "max_val_set_size": MAX_VAL_SET_SIZE,
            "epochs_for_validation": EPOCHS_FOR_VALIDATION,
        }
        optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
        model, model_path = train_torch_model(
            model,
            optimizer,
            train_loader,
            model_dir,
            overwrite_model=OVERWRITE_MODEL,
            num_epochs=NUM_EPOCHS,
            l1_weight=L1_WEIGHT,
            l2_weight=L2_WEIGHT,
            # An empty dict disables validation during training.
            validation_params_and_loader=validation_params_and_loader
            if VALIDATE_DURING_TRAINING
            else dict(),
        )
    else:
        optimizer = Adam(model.parameters(), step_size=LEARNING_RATE)
        # NOTE(review): `train_model` is not defined in the visible part of this
        # notebook; this branch raises NameError unless it is defined elsewhere.
        model, model_path = train_model(
            model,
            optimizer,
            train_loader,
            model_dir,
            overwrite_model=OVERWRITE_MODEL,
            num_epochs=NUM_EPOCHS,
            l1_weight=L1_WEIGHT,
            l2_weight=L2_WEIGHT,
        )
train_model_time = t.time
train_model_nvidia_mem_used = gpu_memory_usage()
_save_jax_memory_profile("toy_bert_train_model.prof")

# Evaluate model
if TRAIN_TORCH:  # Load a BF model if model is currently a torch model
    model = load_bf_model(config, model_path=model_path)
    optimizer = Adam(model.parameters(), step_size=LEARNING_RATE)

with catchtime() as t:
    metrics = evaluate_model(
        model,
        model_dir,
        test_loader,
        optimizer=optimizer,
        compute_entropy=COMPUTE_ENTROPY,
        # Metrics are recomputed whenever the model itself is retrained.
        overwrite_metrics=OVERWRITE_MODEL,
        max_test_set_size=MAX_TEST_SET_SIZE,
    )
eval_model_time = t.time
eval_model_nvidia_mem_used = gpu_memory_usage()
_save_jax_memory_profile("toy_bert_eval_model.prof")
print(metrics)

Time: 0.012 seconds
Retraining and overwriting model at data/BinCountOnes/BinCountOnes-num_classes2-seqlen64-num_points1200-vs20/1.0-0/models/Bert-hs32-numheads4-bs32-lr0.001-n30/model.pt.
Epoch [1/30], Step [0/27], Cross-Entropy Loss: 4.1664, 
Epoch [2/30], Step [0/27], Cross-Entropy Loss: 2.8623, 
Epoch [3/30], Step [0/27], Cross-Entropy Loss: 1.1996, 
Epoch [4/30], Step [0/27], Cross-Entropy Loss: 0.8474, 
Epoch [5/30], Step [0/27], Cross-Entropy Loss: 0.7770, 
Epoch [6/30], Step [0/27], Cross-Entropy Loss: 0.7501, 
Epoch [7/30], Step [0/27], Cross-Entropy Loss: 0.7357, 
Epoch [8/30], Step [0/27], Cross-Entropy Loss: 0.7267, 
Epoch [9/30], Step [0/27], Cross-Entropy Loss: 0.7205, 
Epoch [10/30], Step [0/27], Cross-Entropy Loss: 0.7161, 
Epoch [11/30], Step [0/27], Cross-Entropy Loss: 0.7128, 
Epoch [12/30], Step [0/27], Cross-Entropy Loss: 0.7102, 
Epoch [13/30], Step [0/27], Cross-Entropy Loss: 0.7081, 
Epoch [14/30], Step [0/27], Cross-Entropy Loss: 0.7065, 
Epoch [15/30], Step [0

: 

: 

In [None]:
# Log wall-clock time and GPU memory usage for each stage (init / train / eval).
wandb.log(
    {
        "init_model_time": init_model_time,
        "train_model_time": train_model_time,
        "eval_model_time": eval_model_time,
        "init_model_nvidia_mem_used": init_model_nvidia_mem_used,
        # Measured right after training; previously computed but never logged.
        "train_model_nvidia_mem_used": train_model_nvidia_mem_used,
        "eval_model_nvidia_mem_used": eval_model_nvidia_mem_used,
    }
)

In [None]:
# Upload the memory profiles collected in profile_dir as a W&B artifact.
print(f"Logging profiling info to w&b run {wandb.run}.")
# The artifact name is derived from data_id, with a "_profiles" suffix.
artifact = wandb.Artifact(name=f"{data_id}_profiles", type="profiles")
artifact.add_dir(local_path=profile_dir)
# Last expression of the cell: its return value is shown as the cell output.
run.log_artifact(artifact)

[34m[1mwandb[0m: Adding directory to artifact (./data/BinCountOnes/BinCountOnes-num_classes2-seqlen12-num_points1200-vs20/1.0-0/models/Bert-hs32-numheads4-bs32-lr0.001-n5/mem_profiles)... Done. 0.0s


Logging profiling info to w&b run <wandb.sdk.wandb_run.Run object at 0x7f37af69e8b0>.


<wandb.sdk.wandb_artifacts.Artifact at 0x7f369814b4c0>

In [None]:
# Mark the active W&B run as finished.
wandb.finish()

0,1
eval_model_time,▁
init_model_time,▁
train_model_time,▁

0,1
eval_model_nvidia_mem_used,1881 MiB
eval_model_time,9.71724
init_model_nvidia_mem_used,87 MiB
init_model_time,0.00361
train_model_time,62.42895
