In [16]:
from datasets import load_dataset
from torch.utils.data import DataLoader

# Load the "yelp_review_full" dataset through `datasets.load_dataset`.
# NOTE(review): presumably downloaded/cached on first use — confirm before running offline.
dataset = load_dataset("yelp_review_full")
# Wrap the "train" split in a DataLoader yielding 4 examples per batch;
# `train_loader` is iterated by a later cell of this notebook.
train_loader = DataLoader(dataset["train"], batch_size=4)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import argparse
import datetime
import functools
import json
import os
import random
from typing import Optional, Sequence, Tuple, Type

import numpy as np
import wandb

import torch
import torch.distributed as dist
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP
)
import torch.nn as nn
from torch.utils.data import DataLoader, Subset

from sequence_models.samplers import (
    SortishSampler,
    ApproxBatchSampler,
    ClusteredSortishSampler,
)

from evodiff.utils import Tokenizer
from dayhoff.collators import LMCollator, OAMaskCollator
from dayhoff.constants import MSA_ALPHABET_PLUS, TaskType
from dayhoff.datasets import UniRefDataset
from dayhoff.model import (
    ARDiffusionModel,
    OrderAgnosticDiffusionModel,
)
from dayhoff.model import create_model


# Single-process defaults: one worker (rank 0 of a world of size 1) whose
# device is the CUDA device with the same index as its local rank.
WORLD_SIZE = 1
RANK = 0
LOCAL_RANK = 0
DEVICE = torch.device("cuda", LOCAL_RANK)


def is_amlt() -> bool:
    """Report whether the `AMLT_OUTPUT_DIR` environment variable is set (to any value)."""
    return "AMLT_OUTPUT_DIR" in os.environ


def load_config_and_model(
    config_fpath: str,
) -> Tuple[dict, Tokenizer, nn.Module, Type[nn.Module]]:
    """Reads a JSON experiment config and builds the tokenizer and task-wrapped model

    Parameters:
    -----------
    config_fpath: str
        Location of the JSON experiment config on disk

    Returns:
    --------
    config: dict
        The parsed config, with its "task" entry lower-cased and stripped
    tokenizer: Tokenizer
        A tokenizer built over MSA_ALPHABET_PLUS
    model: nn.Module
        The model returned by `create_model`, wrapped in the task-specific module
        that returns the appropriate loss and metrics
    block: Type[nn.Module]
        The block class returned by `create_model`. It is used repeatedly in the
        module and should not be split by any sharding.

    Raises:
    -------
    ValueError
        If the task is neither OADM nor LM
    """
    with open(config_fpath, "r") as handle:
        config = json.load(handle)

    # normalise the task name once and reuse it for both the config and the enum
    task_name = config["task"].lower().strip()
    config["task"] = task_name
    tokenizer = Tokenizer(MSA_ALPHABET_PLUS)
    task = TaskType(task_name)

    # build the un-wrapped model together with its repeated block class
    backbone, block = create_model(
        task,
        config["model_type"],
        config["model_config"],
        tokenizer.mask_id.item(),
    )

    # wrap the model according to the task and hand everything back
    aux_loss_weight = config.get("aux_loss_weight", 0.0)
    if task == TaskType.OADM:
        wrapped = OrderAgnosticDiffusionModel(
            backbone, tokenizer.pad_id, aux_loss_weight=aux_loss_weight
        )
        return config, tokenizer, wrapped, block
    if task == TaskType.LM:
        wrapped = ARDiffusionModel(backbone, aux_loss_weight=aux_loss_weight)
        return config, tokenizer, wrapped, block
    raise ValueError(f"Unknown task: {config['task']}")


def get_dataloader(
    config: dict, tokenizer: Tokenizer, args: argparse.Namespace
) -> DataLoader:
    """Builds the training dataloader for the configured dataset and task

    Parameters:
    -----------
    config: dict
        The experiment config. Reads "dataset", "task" and "max_len"; full runs
        also read "bucket_size", "max_tokens" and "max_batch_size". The optional
        keys "pad_to_multiple_of", "flip_prob", "fim_prob" and
        "swap_bos_eos_on_flip" are forwarded to the collator.
    tokenizer: Tokenizer
        The tokenizer handed to the collator
    args: argparse.Namespace
        Parsed arguments; `data_root` (may be None) and `mini_run` are read

    Returns:
    --------
    dl_train: DataLoader
        The training dataloader. With `args.mini_run` it serves 100 randomly chosen
        examples (drawn from indices 0..999) one at a time; otherwise it uses
        length-aware approximate batching.

    Raises:
    -------
    ValueError
        If the task is not "oadm" or "lm", or if a full run is requested for a
        dataset whose name contains neither "uniref50" nor "uniref90"
    """
    # the default data location depends on whether we are running under AMLT
    if is_amlt():
        data_top_dir = args.data_root or "/ddn/evodiff/"
    else:
        data_top_dir = args.data_root or "/data1/data/"

    dataset = config["dataset"]
    # the trailing separator is kept on purpose: data_dir is passed on as-is
    data_dir = os.path.join(data_top_dir, dataset + "/")

    if config["task"] == "oadm":
        collator = OAMaskCollator(
            tokenizer=tokenizer,
            pad_to_multiple_of=config.get("pad_to_multiple_of", None),
        )
    elif config["task"] == "lm":
        collator = LMCollator(
            tokenizer=tokenizer,
            pad_to_multiple_of=config.get("pad_to_multiple_of", None),
            flip_prob=config.get("flip_prob", 0.0),
            fim_prob=config.get("fim_prob", 0.0),
            swap_bos_eos_on_flip=config.get("swap_bos_eos_on_flip", True),
        )
    else:
        raise ValueError(f"Unknown task: {config['task']}")

    # load the dataset
    ds_train = UniRefDataset(data_dir, "train", max_len=config["max_len"])
    train_idx = ds_train.indices

    # create the dataloader
    if args.mini_run:
        # debugging subset: 100 distinct examples drawn from the first 1000 indices
        candidate_indices = np.arange(0, 1000)
        train_indices = np.sort(
            np.random.choice(candidate_indices, 100, replace=False)
        )
        return DataLoader(
            dataset=Subset(ds_train, train_indices),
            shuffle=True,
            batch_size=1,
            num_workers=1,
            collate_fn=collator,
        )

    # full run: sequence lengths (capped at max_len) drive the sortish batching
    metadata = np.load(os.path.join(data_dir, "lengths_and_offsets.npz"))
    len_train = np.minimum(metadata["ells"][train_idx], config["max_len"])
    if "uniref50" in dataset:
        train_sortish_sampler = SortishSampler(
            len_train, config["bucket_size"], num_replicas=WORLD_SIZE, rank=RANK
        )
    elif "uniref90" in dataset:
        with open(os.path.join(data_dir, "clustered_splits.json")) as f:
            clusters = json.load(f)["train"]
        train_sortish_sampler = ClusteredSortishSampler(
            len_train,
            clusters,
            config["bucket_size"],
            num_replicas=WORLD_SIZE,
            rank=RANK,
        )
    else:
        # previously fell through and failed later with an UnboundLocalError
        raise ValueError(
            f"Unsupported dataset {dataset!r}: expected a name containing "
            "'uniref50' or 'uniref90'"
        )

    train_sampler = ApproxBatchSampler(
        train_sortish_sampler,
        config["max_tokens"],
        config["max_batch_size"],
        len_train,
        batch_mult=8,
    )

    dl_train = DataLoader(
        dataset=ds_train,
        batch_sampler=train_sampler,
        num_workers=8,
        collate_fn=collator,
        pin_memory=True,
    )

    return dl_train


def seed_everything(seed: int) -> None:
    """Seed Python's `random`, `PYTHONHASHSEED`, NumPy and PyTorch (CPU and CUDA) with `seed`."""
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    # the remaining generators all take the seed as their only argument
    for reseed in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        reseed(seed)




In [19]:
# Positional arguments: the config path is required; the output directory and
# the data root are optional.
parser = argparse.ArgumentParser()
parser.add_argument("config_fpath")
parser.add_argument(
    "out_fpath",
    nargs="?",
    type=str,
    default=os.getenv("AMLT_OUTPUT_DIR", "/tmp") + "/",
)
parser.add_argument("data_root", nargs="?", type=str, default=None)

# Run options.
# --mini_run: pass when running on a subset of the data
parser.add_argument("--mini_run", action="store_true")
# --checkpoint_freq: expressed in steps
parser.add_argument("--checkpoint_freq", default=2000, type=int)
# --random_seed: the value handed to seed_everything
parser.add_argument("--random_seed", default=0, type=int)
parser.add_argument("--dtype", default="bfloat16", type=str)
parser.add_argument("--verbose", action="store_true")
parser.add_argument("--no_wandb", action="store_true")
parser.add_argument("--last_step", type=int, default=-1)

# Stand-in for the command line, since the notebook has no real argv to parse.
fake_args = [
    "jamba3B-uniref50.json",  # config_fpath
    "/tmp/output/",  # out_fpath
    "../data/",  # data_root
    # "--mini_run",
    "--no_wandb",
]

args = parser.parse_args(fake_args)

In [20]:
# Announce the (single-process) topology and make the run reproducible.
print(
    f"Starting job on rank {RANK} with local rank {LOCAL_RANK} and world size {WORLD_SIZE}"
)
seed_everything(args.random_seed)

# Process-group initialisation is left disabled in this notebook:
# dist.init_process_group(backend="nccl")

# get the config, tokenizer, and model
if args.verbose:
    print("Initializing model...", RANK)
config, tokenizer, model, blk_types = load_config_and_model(args.config_fpath)
if RANK == 0:
    # only rank 0 starts a wandb run; --no_wandb selects mode="disabled"
    wandbmode = "disabled" if args.no_wandb else "online"
    wandb.init(config=config, mode=wandbmode)
if args.verbose:
    print("Done initializing model.", RANK)

# store the command line args in the config and dump to disk
config["dtype"] = args.dtype
config["random_seed"] = args.random_seed
config["world_size"] = WORLD_SIZE
if RANK == 0:
    os.makedirs(args.out_fpath, exist_ok=True)
    with open(os.path.join(args.out_fpath, "config.json"), "w") as f:
        json.dump(config, f)

# training dtype (a KeyError here means --dtype is not one of the names below)
dtype = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}[args.dtype]

padding_idx = tokenizer.pad_id  # PROTEIN_ALPHABET.index(PAD)
if RANK == 0:
    print("Using {} as padding index".format(padding_idx))
    print("Using {} as masking index".format(tokenizer.mask_id))
    # count only parameters that require gradients, so the number matches the
    # "trainable" wording of the message
    print(
        f"Model has {sum(p.numel() for p in model.parameters() if p.requires_grad)} trainable parameters."
    )
if args.verbose:
    print("Moving and sharding model...", RANK)
# set the default device
torch.cuda.set_device(LOCAL_RANK)


if args.verbose:
    print("Initializing data...", RANK)
dl_train = get_dataloader(config, tokenizer, args)

Starting job on rank 0 with local rank 0 and world size 1




Using 30 as padding index
Using 28 as masking index
Model has 2981431616 trainable parameters.


In [23]:
# Create an iterator over the training dataloader returned by `get_dataloader`.
dl_train_iter = iter(dl_train)

In [24]:
# Pull a single batch from the iterator so it can be inspected in later cells.
dl_train_iter_obs = next(dl_train_iter)

In [30]:
# Shapes of the first two elements of the batch; the recorded output shows
# both as torch.Size([1000, 80]).
dl_train_iter_obs[0].shape,dl_train_iter_obs[1].shape

(torch.Size([1000, 80]), torch.Size([1000, 80]))

In [None]:
# Shape of the first batch element only.
# NOTE(review): this cell has no execution count and repeats part of the
# previous cell's check — consider removing it.
dl_train_iter_obs[0].shape

In [18]:
# Print only the first batch of the Yelp `train_loader`, then stop iterating.
# NOTE(review): printing the whole batch writes the full review texts into the
# notebook output — consider showing a truncated view instead.
for batch in train_loader:
    print(batch)
    break

{'label': tensor([4, 1, 3, 3]), 'text': ["dr. goldberg offers everything i look for in a general practitioner.  he's nice and easy to talk to without being patronizing; he's always on time in seeing his patients; he's affiliated with a top-notch hospital (nyu) which my parents have explained to me is very important in case something happens and you need surgery; and you can get referrals to see specialists without having to see him first.  really, what more do you need?  i'm sitting here trying to think of any complaints i have about him, but i'm really drawing a blank.", "Unfortunately, the frustration of being Dr. Goldberg's patient is a repeat of the experience I've had with so many other doctors in NYC -- good doctor, terrible staff.  It seems that his staff simply never answers the phone.  It usually takes 2 hours of repeated calling to get an answer.  Who has time for that or wants to deal with it?  I have run into this problem with many other doctors and I just don't get it.  Yo