In [None]:
def train_loop(
    wrapped_model,
    loader,  # training data loader
    device,
    eval_train_data,
    eval_train_labels,
    eval_test_data=None,  # None when eval is on MTEB or no train/test split
    eval_test_labels=None,
    eval_every_epochs=True,  # bool, {True, False}
    eval_every_batches=0,  # int, evaluate every N batches; 0 disables it
    eval_function=KNNEval,
    pooler=mean_pool,
    eval_rep=None,  # representation to evaluate, if None it is the same used by pooler
    dist_metric="euclidean",
    mteb_saving_path=None,
    mteb_tasks=None,
    n_epochs=1,
    lr=2e-5,
    scale=20.0,  # we multiply similarity score by this scale value, it is the inverse of the temperature
):
    """Train ``wrapped_model`` on (anchor, positive) batches, evaluating periodically.

    For every batch the anchor and positive inputs are moved to ``device``,
    passed through ``wrapped_model.get_outputs`` and reduced with ``pooler``.
    The optimizer set up, the loss computation and the construction of the
    returned evaluation table live in sections of this function that are
    elided (``...``) here.

    Args:
        wrapped_model: Object exposing ``.model`` (switched to training mode
            at the start of every epoch) and ``.get_outputs(input_ids=...,
            attention_mask=...)``.
        loader: Iterable of batches. Each batch is indexed as
            ``batch[0][0]`` / ``batch[0][1]`` (anchor ids / attention mask)
            and ``batch[1][0]`` / ``batch[1][1]`` (positive ids / mask).
        device: Device the batch tensors are moved to.
        eval_every_epochs: If non-zero/True, evaluate at the end of each epoch.
        eval_every_batches: Evaluate every this many batches; ``0`` disables
            per-batch evaluation.
        eval_function: Callable returning a mapping of metric name to value.
            NOTE(review): it is currently invoked without arguments -- the
            evaluation data parameters are not forwarded yet; confirm.
        pooler: Callable ``pooler(outputs, attention_mask)`` producing the
            pooled representation.
        n_epochs: Number of passes over ``loader``.

        ``eval_train_data``, ``eval_train_labels``, ``eval_test_data``,
        ``eval_test_labels``, ``eval_rep``, ``dist_metric``,
        ``mteb_saving_path``, ``mteb_tasks``, ``lr`` and ``scale`` are not
        referenced in the visible part of the function; presumably they are
        consumed by the elided set up / loss / evaluation code.

    Returns:
        ``losses`` when no evaluation is requested, otherwise the tuple
        ``(losses, df_training_eval_results)``. Both objects are created in
        the elided sections.
    """

    ## training set up
    ...

    # evaluation results are only tracked when some evaluation is requested
    track_eval = (eval_every_epochs != 0) or (eval_every_batches != 0)
    if track_eval:
        training_eval_results = defaultdict(list)

    def _run_eval():
        # run one evaluation and append each metric to its history list
        eval_results = eval_function()
        for metric_name, metric_value in eval_results.items():
            training_eval_results[metric_name].append(metric_value)

    ## training
    for epoch in range(n_epochs):
        wrapped_model.model.train()  # make sure model is in training mode
        # initialize the dataloader loop with tqdm (tqdm == progress bar)
        loop = tqdm(loader, leave=True)
        for i_batch, batch in enumerate(loop):
            ## train
            # zero all gradients on each new step
            optim.zero_grad()
            # prepare batches and move all to the active device
            # anchors: one per batch element
            anchor_ids = batch[0][0].to(device)
            anchor_mask = batch[0][1].to(device)
            # positives: the positive pair of each anchor, same length as the anchors
            pos_ids = batch[1][0].to(device)
            pos_mask = batch[1][1].to(device)

            # get hidden state
            a = wrapped_model.get_outputs(input_ids=anchor_ids, attention_mask=anchor_mask)
            p = wrapped_model.get_outputs(input_ids=pos_ids, attention_mask=pos_mask)

            # get the pooled vectors
            a = pooler(a, anchor_mask)
            p = pooler(p, pos_mask)

            # loss
            ...

            ## evaluation every `eval_every_batches` batches
            # (i_batch + 1) so that the first evaluation happens after N
            # completed batches rather than right after the very first one
            if eval_every_batches != 0 and (i_batch + 1) % eval_every_batches == 0:
                _run_eval()
                # the evaluation may have switched the model to eval mode;
                # re-assert training mode for the rest of the epoch
                wrapped_model.model.train()

        ## evaluation at the end of the epoch
        if eval_every_epochs != 0:
            _run_eval()

    if not track_eval:
        return losses

    ...
    return losses, df_training_eval_results