In [None]:
# Largest per-step training batch; bigger requested batch sizes are reached via
# gradient accumulation in `training_function`. The second value is the fixed
# batch size of the validation DataLoader built in `get_dataloaders`.
MAX_GPU_BATCH_SIZE, EVAL_BATCH_SIZE = 16, 32


def get_dataloaders(accelerator: Accelerator, batch_size: int = 16):
    """
    Creates the train and validation `DataLoader`s for the GLUE MRPC task,
    using "bert-base-cased" as the tokenizer.

    Args:
        accelerator (`Accelerator`):
            An `Accelerator` object. Its `distributed_type` and `mixed_precision`
            settings decide how each batch is padded.
        batch_size (`int`, *optional*, defaults to 16):
            The batch size for the train DataLoader. The validation DataLoader
            always uses `EVAL_BATCH_SIZE`.

    Returns:
        `tuple`: `(train_dataloader, eval_dataloader)`.
    """
    tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
    datasets = load_dataset("glue", "mrpc")

    def tokenize_function(examples):
        # max_length=None => use the model max length (it's actually the default)
        return tokenizer(
            examples["sentence1"], examples["sentence2"], truncation=True, max_length=None
        )

    # Apply the tokenizer to every split of the dataset, letting the main
    # process run first.
    with accelerator.main_process_first():
        tokenized_datasets = datasets.map(
            tokenize_function,
            batched=True,
            remove_columns=["idx", "sentence1", "sentence2"],
        )

    # Rename the 'label' column to 'labels', the name transformers models expect.
    tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

    def collate_fn(examples):
        # On TPU it's best to pad everything to the same length or training will
        # be very slow. `padding="longest"` ignores `max_length`, so fixed-length
        # padding has to be requested explicitly.
        # NOTE(review): `tokenizer.pad` does not truncate, so this assumes no
        # example exceeds 128 tokens — confirm for the dataset in use.
        if accelerator.distributed_type == DistributedType.TPU:
            padding, max_length = "max_length", 128
        else:
            padding, max_length = "longest", None
        # When using mixed precision we want round multiples of 8/16
        if accelerator.mixed_precision == "fp8":
            pad_to_multiple_of = 16
        elif accelerator.mixed_precision != "no":
            pad_to_multiple_of = 8
        else:
            pad_to_multiple_of = None

        return tokenizer.pad(
            examples,
            padding=padding,
            max_length=max_length,
            pad_to_multiple_of=pad_to_multiple_of,
            return_tensors="pt",
        )

    # Tokenized examples can differ in length; `collate_fn` pads each batch to a
    # common length.
    train_dataloader = DataLoader(
        tokenized_datasets["train"],
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=batch_size,
        drop_last=True,
    )
    eval_dataloader = DataLoader(
        tokenized_datasets["validation"],
        shuffle=False,
        collate_fn=collate_fn,
        batch_size=EVAL_BATCH_SIZE,
        # NOTE(review): the ragged final batch is dropped only for fp8 runs,
        # presumably because fp8 needs full batches — confirm.
        drop_last=(accelerator.mixed_precision == "fp8"),
    )

    return train_dataloader, eval_dataloader


def training_function(hparams=None):
    """
    Fine-tune "bert-base-cased" on GLUE MRPC and print validation metrics after each epoch.

    Args:
        hparams (`dict`, *optional*):
            Hyper-parameters with keys "lr", "num_epochs", "seed" and "batch_size".
            Defaults to the module-level `config` dict, so the function can still
            be launched with no arguments (e.g. `notebook_launcher(training_function)`).
    """
    if hparams is None:
        hparams = config

    # Initialize accelerator
    accelerator = Accelerator()
    # Sample hyper-parameters for learning rate, batch size, seed and a few other HPs
    lr = hparams["lr"]
    num_epochs = int(hparams["num_epochs"])
    seed = int(hparams["seed"])
    batch_size = int(hparams["batch_size"])

    metric = evaluate.load("glue", "mrpc")

    # If the batch size is too big we use gradient accumulation
    gradient_accumulation_steps = 1
    if batch_size > MAX_GPU_BATCH_SIZE and accelerator.distributed_type != DistributedType.TPU:
        gradient_accumulation_steps = batch_size // MAX_GPU_BATCH_SIZE
        batch_size = MAX_GPU_BATCH_SIZE

    set_seed(seed)
    train_dataloader, eval_dataloader = get_dataloaders(accelerator, batch_size)
    # Build the model after `set_seed` so the seed also controls new weight initialization.
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", return_dict=True)

    # Optional with the default `device_placement=True`, but if tensors are placed
    # manually this must happen before the optimizer is created (required on TPU).
    model = model.to(accelerator.device)
    optimizer = AdamW(params=model.parameters(), lr=lr)

    # Scheduler length is based on the un-prepared (full) dataloader.
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=(len(train_dataloader) * num_epochs) // gradient_accumulation_steps,
    )

    # Objects come back in the same order they were passed in.
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
    )

    # `prepare` may shard the dataloader across processes, so the per-process
    # number of optimizer updates is recomputed for the progress bar.
    num_update_steps = num_epochs * len(train_dataloader) // gradient_accumulation_steps
    # One progress bar per machine instead of one per process.
    progress_bar = tqdm(range(num_update_steps), disable=not accelerator.is_local_main_process)

    for epoch in range(num_epochs):
        model.train()
        for step, batch in enumerate(train_dataloader):
            # Batches are already on the right device (`device_placement=True`).
            outputs = model(**batch)
            # Scale so the accumulated gradient matches a single large batch.
            loss = outputs.loss / gradient_accumulation_steps
            accelerator.backward(loss)
            if (step + 1) % gradient_accumulation_steps == 0:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)

        model.eval()
        for batch in eval_dataloader:
            with torch.no_grad():
                outputs = model(**batch)
            predictions = outputs.logits.argmax(dim=-1)
            # Gather from all processes so the metric covers the whole validation
            # set rather than only this process's shard.
            predictions, references = accelerator.gather_for_metrics(
                (predictions, batch["labels"])
            )
            metric.add_batch(predictions=predictions, references=references)

        eval_metric = metric.compute()
        # accelerator.print only prints on the main process.
        accelerator.print(f"epoch {epoch}:", eval_metric)

    progress_bar.close()

# Hyper-parameters read by `training_function`.
config = {"lr": 2e-5, "num_epochs": 3, "seed": 42, "batch_size": 8}

from accelerate import notebook_launcher

# A notebook cell also runs as "__main__", so this still launches there, but
# merely importing this module no longer starts a training run.
if __name__ == "__main__":
    notebook_launcher(training_function, num_processes=1)
        