# Installation and Imports

In [1]:
%%capture
!pip install wandb
!pip install cohere
!pip install datasets

In [2]:
!wandb login

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
[34m[1mwandb[0m: Paste an API key from your profile and hit enter, or press ctrl+c to quit: 
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


In [3]:
# Mount Google Drive; the data, embedding and checkpoint paths used in later
# cells all live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# Required Imports
import os
import random
import time
import pickle
import numpy as np
import torch
import tqdm
import wandb
from torch.utils.data import Dataset, DataLoader
from torch.distributions import Categorical
from transformers import AutoModel, AutoTokenizer, T5EncoderModel
import cohere
from google.colab import userdata

* 'allow_population_by_field_name' has been renamed to 'populate_by_name'
* 'smart_union' has been removed


# Preparing the dataset for model training

In [5]:
import pandas as pd
import json

# Path of the pre-processed SQuAD JSON file stored on the mounted Drive.
json_file = "/content/drive/MyDrive/Research: Elan<>Krupa/processed_squad.json"

# `data` is the parsed JSON; later cells index it as a list of per-document dicts.
with open(json_file, "r") as f:
  data = json.load(f)
# Quick sanity check: number of entries and the container type.
len(data), type(data)

(442, list)

In [6]:
data[0].keys()

dict_keys(['title', 'context', 'qas', 'sentences'])

In [7]:
data[0]

{'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary. As at most other universities, Notre Dame\'s students run a number of news media outlets. The nine student-run outlets include three newspapers, both a radio and television station, and several magazines and journals. Begun as a one-page journal in Sep

In [8]:
def find_sentence_by_index(data, index):
    """Locate the sentence whose span contains ``index``.

    Args:
        data: Mapping of sentence key -> dict holding ``start_index`` and
            ``end_index``.
        index: Position to look up. Both span bounds are treated as inclusive.

    Returns:
        ``(key, end_index)`` of the first matching sentence in iteration
        order, or ``None`` when no span contains ``index``.
    """
    matches = (
        (sentence_key, span["end_index"])
        for sentence_key, span in data.items()
        if span['start_index'] <= index <= span['end_index']
    )
    # next() with a default reproduces the "fall through to None" behaviour.
    return next(matches, None)


def flattern_data(data_raw):
  """Flatten per-document question dicts into one list of examples.

  Args:
      data_raw: Iterable of dicts (one per document) whose values are the
          per-question example records.

  Returns:
      A list containing every value of every dict in ``data_raw``, in order.
  """
  # Iterate the argument itself; the previous version read the notebook
  # global `data_ready` and silently ignored whatever was passed in.
  return [question_data
          for doc_data in data_raw
          for question_data in doc_data.values()]


def make_model_ready_data(data):
  """Tokenize every document and build one training record per question.

  Args:
      data: List of document dicts with ``title``, ``context``, ``qas`` and
          ``sentences`` (sentence key -> ``start_index`` / ``end_index``).

  Returns:
      A list with one dict per document, mapping the question position to a
      record with keys ``text``, ``title``, ``sentence_end``, ``question`` and
      ``answer_index``. All questions of a document share the same ``text``
      and ``sentence_end`` list objects.
  """
  tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')

  def _sentence_end_flags(n_tokens):
      # All zeros except a 1 on the final token of the sentence.
      flags = [0] * n_tokens
      flags[-1] = 1
      return flags

  data_ready = []
  for doc in data:
      para = doc["context"]
      sentence_spans = doc["sentences"]

      # Slice each sentence out of the paragraph (end_index is exclusive here).
      sentences = [para[span["start_index"]: span["end_index"]]
                   for span in sentence_spans.values()]

      # One (1 x n_tokens) id tensor per sentence, plus matching end markers.
      text_tokens = [tokenizer(sen, return_tensors="pt")["input_ids"] for sen in sentences]
      sent_end_tokens = [_sentence_end_flags(len(tokens[0])) for tokens in text_tokens]

      dict_d = {}
      for i, qas in enumerate(doc["qas"]):
          question = qas["question"]
          q_tokens = tokenizer(question, return_tensors="pt")["input_ids"]
          # Only the first listed answer position is used.
          answer_index = qas["answers"]["answer_start_full_para"][0]
          answer_sent_index, answer_end_index = find_sentence_by_index(sentence_spans, answer_index)
          dict_d[i] = {
              "text": text_tokens,
              "title": doc["title"],
              "sentence_end": sent_end_tokens,
              "question": {"raw": question, "question_token": q_tokens},
              "answer_index": {"answer_sent_index": int(answer_sent_index),
                               "answer_end_index": answer_end_index},
          }
      data_ready.append(dict_d)
  return data_ready


# Build the model-ready records from the raw documents, then flatten them into
# a single list with one entry per question.
data_ready = make_model_ready_data(data)
data_ready_flat = flattern_data(data_ready)
# Sanity check: number of examples and the keys of one record.
len(data_ready_flat), data_ready_flat[0].keys()

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/2.54k [00:00<?, ?B/s]

spiece.model:   0%|          | 0.00/792k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

(10014,
 dict_keys(['text', 'title', 'sentence_end', 'question', 'answer_index']))

In [9]:
# Dataset Builders
def data_collater(batch):
    """Collate example records into parallel per-example lists.

    Args:
        batch: List of records with keys ``sentence_end``, ``text``,
            ``question``, ``title`` and ``answer_index``.

    Returns:
        ``(questions, sentence_ends, texts, titles, answer_idx)`` where each
        element is a list with one entry per example:
        raw question string, flat 0/1 sentence-end tensor, flat token-id
        tensor, title string, and a 1-element tensor holding the flat position
        of the answer sentence's end marker.
    """
    batch_questions, batch_sentence_ends, batch_text = [], [], []
    batch_title, batch_answer_idx = [], []

    for example in batch:
        # Concatenate the per-sentence markers / token ids into flat sequences.
        flat_ends = torch.tensor(
            [flag for sentence_flags in example['sentence_end'] for flag in sentence_flags]
        )
        flat_text = torch.cat(example['text'], dim=-1).flatten()

        # The n-th non-zero marker is the end of the n-th sentence.
        end_positions = flat_ends.nonzero()
        answer_idx = end_positions[example['answer_index']['answer_sent_index']]

        batch_questions.append(example['question']['raw'])
        batch_sentence_ends.append(flat_ends)
        batch_text.append(flat_text)
        batch_title.append(example["title"])
        batch_answer_idx.append(answer_idx)

    return batch_questions, batch_sentence_ends, batch_text, batch_title, batch_answer_idx


class ListDataset(Dataset):
    """Map-style dataset that exposes an in-memory sequence unchanged."""

    def __init__(self, data):
        # Keep a reference to the caller's sequence; nothing is copied.
        self.data = data

    def __getitem__(self, index):
        """Return the record stored at ``index``."""
        return self.data[index]

    def __len__(self):
        """Return the number of stored records."""
        return len(self.data)


### Get Wiki Embedding data

We use a subset of the embeddings from Cohere/wikipedia-22-12-en-embeddings ([link](https://huggingface.co/datasets/Cohere/wikipedia-22-12-en-embeddings)) as distractor documents for the retrieval process. The subset is stored as a dictionary keyed by `title` and saved as a pickle file.

In [10]:
# NOTE: pickle.load can execute arbitrary code; only load files you produced yourself.
# wiki_docs maps an article title to a list of records; later cells read the
# 'emb' and 'text' fields of each record.
with open('/content/drive/MyDrive/Research: Elan<>Krupa/wikipedia-22-12-en-embeddings_squad_doc_subset_4000000.pkl', 'rb') as f:
    wiki_docs = pickle.load(f)

In [11]:
# Stack the embeddings of every record under a title into one tensor:
# doc_embs[title] has one row per record, in the same order as wiki_docs[title].
doc_embs = {
    title: torch.from_numpy(np.array([doc['emb'] for doc in docs]))
    for title, docs in wiki_docs.items()
}
doc_embs['University of Notre Dame'].shape

torch.Size([99, 768])

# Modeling

In [12]:
def sample_until_answer_divergence(sentence_ends, text, answer_idx, model, greedy=True, window=400):
    """Chunk ``text`` until the chunk that reaches the answer, then branch.

    Chunks are produced left to right. At each step the model scores every
    sentence end inside the next ``window`` tokens and one of them is chosen as
    the chunk end. As soon as a chosen chunk end is not before ``answer_idx``,
    a second, different sentence end is recorded as the alternative choice and
    sampling stops.

    Args:
        sentence_ends: 1-D 0/1 tensor marking sentence-end tokens.
        text: 1-D tensor of token ids, aligned with ``sentence_ends``.
        answer_idx: Position of the answer sentence's end marker.
        model: Callable ``model(input_ids, [sentence_ends_idx])`` returning a
            list whose first item holds one logit per sentence end.
        greedy: If True the first choice is the arg-max and the alternative is
            sampled; otherwise both are sampled without replacement.
        window: Maximum number of tokens shown to the model per step
            (defaults to the previously hard-coded 400).

    Returns:
        ``(sample1, sample2, divergence)``. The two lists share every chunk end
        before the branch and differ in their last element. ``divergence`` is
        ``(input_ids, sentence_ends_idx, x1, x2)`` for the branching step.
        ``([], [], None)`` is returned when a step offers fewer than two
        sentence ends; ``divergence`` is also ``None`` if the text is exhausted
        before reaching the answer.
    """
    current_idx = 0
    sample1_single, sample2_single = [], []
    divergence = None
    while current_idx < len(text):
        text_input = text[current_idx:current_idx + window]
        sentence_ends_input = sentence_ends[current_idx:current_idx + window]  # e.g. [0,0,0,1,0,0,0,1]
        sentence_ends_idx = torch.nonzero(sentence_ends_input).squeeze(-1)  # e.g. [3, 7]
        logits = model(text_input.reshape(1, -1), [sentence_ends_idx])[0]

        # One probability per candidate sentence end (same shape as sentence_ends_idx).
        probs = torch.nn.functional.softmax(logits, dim=-1)
        if len(probs) < 2:
            # Two distinct choices are required to form a preference pair.
            return [], [], None

        # Draw two distinct candidates in a single call.
        candidates = torch.multinomial(probs, 2, replacement=False)
        if greedy:
            x1 = torch.argmax(probs, dim=-1)
            # Alternative = first sampled candidate that differs from the arg-max.
            x2 = candidates[1] if candidates[0] == x1 else candidates[0]
        else:
            x1, x2 = candidates

        chunk_end = current_idx + sentence_ends_idx[x1] + 1
        if chunk_end < answer_idx:
            # Chunk ends before the answer: both samples share this boundary.
            sample1_single.append(chunk_end)
            sample2_single.append(chunk_end)
            current_idx = chunk_end
        else:
            # Chunk reaches the answer: record both alternatives and stop.
            sample1_single.append(chunk_end)
            chunk_end2 = current_idx + sentence_ends_idx[x2] + 1
            sample2_single.append(chunk_end2)
            divergence = (text_input.reshape(1, -1), sentence_ends_idx, x1, x2)
            break
    return sample1_single, sample2_single, divergence


def _complete_chunking_greedily(chunk_ends, sentence_ends, text, model, window=400):
    """Extend ``chunk_ends`` in place with arg-max chunk ends until ``text`` is consumed."""
    current_idx = int(chunk_ends[-1])
    while current_idx < len(text):
        text_input = text[current_idx:current_idx + window]
        sentence_ends_input = sentence_ends[current_idx:current_idx + window]
        sentence_ends_idx = torch.nonzero(sentence_ends_input).squeeze(-1)
        if len(sentence_ends_idx) == 0:
            # No sentence end inside the window (a sentence longer than the
            # window): arg-max over an empty tensor would raise, so cut at the
            # end of the window to keep making progress.
            chunk_end = sentence_ends_idx.new_tensor(int(current_idx) + len(text_input))
        else:
            logits = model(text_input.reshape(1, -1), [sentence_ends_idx])[0]
            probs = torch.nn.functional.softmax(logits, dim=-1)
            x = torch.argmax(probs, dim=-1)
            chunk_end = current_idx + sentence_ends_idx[x] + 1
        chunk_ends.append(chunk_end)
        current_idx = chunk_end


def sample_chunking_with_divergence(sentence_ends, text, answer_idx, model, device):
    """Produce two complete chunkings of ``text`` that differ at the answer chunk.

    The shared prefix and the branching step come from
    ``sample_until_answer_divergence`` (greedy on the first attempt, sampled on
    retries). Both branches are then completed greedily to the end of the text.
    Everything runs under ``torch.no_grad()``.

    Args:
        sentence_ends: 1-D 0/1 tensor marking sentence-end tokens.
        text: 1-D tensor of token ids.
        answer_idx: Position of the answer sentence's end marker.
        model: Chunk-scoring model, called as ``model(input_ids, [sentence_ends_idx])``.
        device: Unused; kept so existing callers keep working.

    Returns:
        ``(sample1, sample2, divergence)`` on success, or ``None`` when no
        divergence was found after 51 attempts (a message is printed). Callers
        must check for ``None`` before unpacking.
    """
    with torch.no_grad():
        divergence = None
        count = 0
        while divergence is None:
            sample1_single, sample2_single, divergence = sample_until_answer_divergence(
                sentence_ends, text, answer_idx, model, greedy=(count == 0))

            count += 1
            # Give up only if this attempt failed too; previously a successful
            # final attempt was discarded as well.
            if divergence is None and count > 50:
                print(f'Failed to sample from {sentence_ends}, {text}, {answer_idx}')
                return

        # Finish both branches with greedy (arg-max) chunk ends.
        _complete_chunking_greedily(sample1_single, sentence_ends, text, model)
        _complete_chunking_greedily(sample2_single, sentence_ends, text, model)

    return sample1_single, sample2_single, divergence

### Cohere embedding creation

In [22]:
# Cohere client; the API key is read from the Colab secret named 'CO_API_PAID'.
co = cohere.Client(f"{userdata.get('CO_API_PAID')}")
# Embedding model name passed to co.embed by get_cohere_emb.
embedding_model_id = "multilingual-22-12"
# Module-level tokenizer used by the chunk decoding / token-counting helpers.
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-base')

In [23]:
def get_cohere_emb(text_list):
    """Embed ``text_list`` with the module-level Cohere client.

    Args:
        text_list: List of strings to embed.

    Returns:
        The response object of ``co.embed`` for model ``embedding_model_id``.
    """
    return co.embed(texts=text_list, model=embedding_model_id)


def get_samples_emb(samples, title, text, answer_idx):
    """Decode each chunk to text, embed all chunks, and locate the answer chunk.

    Args:
        samples: Chunk end positions (exclusive), in increasing order.
        title: Document title, prepended to every decoded chunk.
        text: 1-D tensor of token ids for the whole document.
        answer_idx: Tensor position of the answer sentence's end marker.

    Returns:
        ``(org_text, chunk_emb, result_index)``: the decoded chunk strings, the
        Cohere embedding response for them, and the index of the chunk that
        contains ``answer_idx`` (``None`` if no chunk contains it).
    """
    org_text = []
    start_index = torch.tensor(0, device=answer_idx.device)
    result_index = None
    for sample_index, sample in enumerate(samples):
        # Chunk covers token positions [start_index, sample).
        chunk_token = text[start_index: sample]
        org_chunk_text = title + " " + tokenizer.decode(chunk_token)
        org_text.append(org_chunk_text)

        # Lower bound is inclusive: position start_index belongs to this
        # chunk (the previous strict `<` missed it).
        if start_index <= answer_idx < sample:
            result_index = sample_index

        start_index = sample

    chunk_emb = get_cohere_emb(org_text)
    return org_text, chunk_emb, result_index

def prep_samples_for_emb(samples, title, text, answer_idx):
    """Decode each chunk to text and locate the chunk containing the answer.

    Args:
        samples: Chunk end positions (exclusive), in increasing order.
        title: Document title, prepended to every decoded chunk.
        text: 1-D tensor of token ids for the whole document.
        answer_idx: Tensor position of the answer sentence's end marker.

    Returns:
        ``(org_text, result_index)``: the decoded chunk strings and the index
        of the chunk that contains ``answer_idx`` (``None`` if none does).
    """
    org_text = []
    start_index = torch.tensor(0, device=answer_idx.device)
    result_index = None
    for sample_index, sample in enumerate(samples):
        # Chunk covers token positions [start_index, sample).
        chunk_token = text[start_index: sample]
        org_chunk_text = title + " " + tokenizer.decode(chunk_token)
        org_text.append(org_chunk_text)

        # Lower bound is inclusive: position start_index belongs to this
        # chunk (the previous strict `<` missed it).
        if start_index <= answer_idx < sample:
            result_index = sample_index

        start_index = sample

    return org_text, result_index

In [24]:
# Functions to get embedding for the samples.

def get_random_wiki(doc_embs, wiki_counts):
  """Pick one random embedding from each of ``wiki_counts`` random titles.

  Args:
      doc_embs: Mapping of title -> sequence of embeddings.
      wiki_counts: Number of distinct titles to sample.

  Returns:
      A list with ``wiki_counts`` embeddings, one per sampled title.
  """
  # Titles are drawn first (without replacement), then one entry per title.
  chosen_titles = random.sample(list(doc_embs.keys()), wiki_counts)
  return [random.choice(doc_embs[chosen]) for chosen in chosen_titles]


def add_wiki_sample_emb(sample_emb, doc_embs, wiki_counts=7):
  """Append random distractor embeddings to the chunk embeddings.

  Args:
      sample_emb: Embeddings of the document's own chunks.
      doc_embs: Mapping of title -> embeddings to draw distractors from.
      wiki_counts: Number of distractor embeddings to add.

  Returns:
      ``(all_embedding, mapper)``: the chunk rows followed by the distractor
      rows, and a parallel list labelling each row ``"Sample"`` or ``"Wiki"``.
  """
  sample_embedding = torch.from_numpy(np.array(sample_emb))
  wiki_embedding = torch.from_numpy(np.array(get_random_wiki(doc_embs, wiki_counts)))

  # Row labels, in the same order as the concatenated rows below.
  mapper = ["Sample"] * len(sample_embedding) + ["Wiki"] * len(wiki_embedding)

  all_embedding = torch.cat((sample_embedding, wiki_embedding), dim=0)
  return all_embedding, mapper

## Retrieval and get the raw chunk data

In [25]:
# get all chunks up to and including answer chunk
# tokenize and aggregate to get number of tokens before winning chunk.

def get_chunk_tokenize(top_k, mapper, result_index, sample_raw, wiki_docs, title):
    """Collect retrieved texts up to the answer chunk and tokenize them.

    Args:
        top_k: Retrieved row indices, best first.
        mapper: Per-row labels, ``"Sample"`` or ``"Wiki"``.
        result_index: Row index of the chunk containing the answer.
        sample_raw: Decoded chunk strings, indexed by row.
        wiki_docs: Mapping of title -> list of records with a ``"text"`` field.
        title: Title used to look up the text of ``"Wiki"`` rows.

    Returns:
        ``(chunk_list, chunk_tokens)``: the texts in retrieval order up to and
        including the answer chunk (all of ``top_k`` if it is never reached),
        and the token ids of those texts joined with newlines.
    """
    chunk_list = []
    for position in map(int, top_k):
        source = mapper[position]
        if source == "Sample":
            chunk_list.append(sample_raw[position])
        elif source == "Wiki":
            # Offset of this row within the block of "Wiki" rows.
            wiki_offset = position - mapper.index("Wiki")
            # NOTE(review): the distractor embeddings are drawn by
            # get_random_wiki from random titles and random rows, but the text
            # here is read from wiki_docs[title][wiki_offset], so it may not be
            # the text that belongs to the retrieved embedding. Confirm.
            chunk_list.append(wiki_docs[title][wiki_offset]["text"])

        if position == result_index:
            # Reached the chunk that holds the answer.
            break

    context_text_until_result = '\n'.join(chunk_list)
    chunk_tokens = tokenizer(context_text_until_result, return_tensors="pt")["input_ids"][0]
    return chunk_list, chunk_tokens

## Backpropagation

In [26]:
# After getting preference data, re run point where preference data diverges and train with ranking objective.

def train_pref(divergences, text_mask, preferences, model, loss_fn, opt):
    """Run one optimizer step on a batch of preference pairs.

    The model is re-run on the inputs of each branching step, the logits of
    the two alternative choices are gathered, and ``loss_fn`` is applied to
    those two logits with ``preferences`` as the target.

    Args:
        divergences: ``(input_ids, sentence_ends_idx, x1, x2)`` batched over
            examples.
        text_mask: Attention mask for ``input_ids``.
        preferences: Per-example index (0 or 1) of the preferred alternative.
        model: Chunk-scoring model returning a list of per-example logits.
        loss_fn: Loss taking ``(B x 2 logits, targets)``.
        opt: Optimizer for ``model``.

    Returns:
        ``(loss, average_entropy)`` where ``average_entropy`` is the mean
        entropy of the per-example softmax distributions, each divided by the
        log of its length.
    """
    opt.zero_grad()

    # One logits tensor per example.
    logits_list = model(divergences[0], divergences[1], attention_mask=text_mask)

    # Column 0 = first choice (x1), column 1 = alternative (x2).
    sampled_positions = torch.stack([divergences[2], divergences[3]], dim=-1)
    sampled_positions = sampled_positions.to(logits_list[0].device)

    logits_stacked, logits_mask = padded_stack(logits_list)
    # Both sampled positions must point at real (unmasked) logits.
    assert torch.all(torch.gather(logits_mask, -1, sampled_positions))
    logits_sampled = torch.gather(logits_stacked, -1, sampled_positions)

    # Entropy of each distribution, normalised by log(number of entries).
    entropies = []
    for example_logits in logits_list:
        dist = torch.nn.functional.softmax(example_logits, dim=-1)
        entropies.append(Categorical(probs = dist).entropy() / torch.log(torch.tensor(dist.shape[0])))
    average_entropy = torch.tensor(entropies).mean()

    loss = loss_fn(logits_sampled, preferences)
    loss.backward()
    opt.step()
    return loss, average_entropy


In [27]:
def save_checkpoint(model, optimizer, rundir, epoch, step):
    """Write model and optimizer state to ``rundir``.

    The file is named ``checkpoint_<epoch>_<step zero-padded to 6 digits>.pth``
    and stores the two state dicts together with ``epoch`` and ``step``.
    """
    target_path = os.path.join(rundir, f'checkpoint_{epoch}_{step:0{6}}.pth')
    payload = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'step': step,
    }
    torch.save(payload, target_path)

def load_checkpoint(model, optimizer, rundir, step, epoch=None):
    """Restore model and optimizer state from a checkpoint in ``rundir``.

    Looks for ``checkpoint_<epoch>_<step zero-padded to 6 digits>.pth``, the
    name written by ``save_checkpoint``. The previous version looked for
    ``checkpoint_<step>`` (no epoch, no extension) and so never found a file.

    Args:
        model: Module whose ``load_state_dict`` receives ``model_state_dict``.
        optimizer: Optimizer whose ``load_state_dict`` receives
            ``optimizer_state_dict``.
        rundir: Directory containing the checkpoint files.
        step: Step number of the checkpoint.
        epoch: Epoch of the checkpoint. If ``None``, the highest epoch that
            has a checkpoint for ``step`` is used.

    Raises:
        FileNotFoundError: If no matching checkpoint file exists.
    """
    prefix = 'checkpoint_'
    if epoch is None:
        suffix = f'_{step:0{6}}.pth'
        epochs_found = []
        for filename in os.listdir(rundir):
            if filename.startswith(prefix) and filename.endswith(suffix):
                epoch_part = filename[len(prefix):-len(suffix)]
                if epoch_part.isdigit():
                    epochs_found.append(int(epoch_part))
        if not epochs_found:
            raise FileNotFoundError(f'No checkpoint for step {step} found in {rundir}')
        epoch = max(epochs_found)

    loadfile = os.path.join(rundir, f'checkpoint_{epoch}_{step:0{6}}.pth')
    loaded_checkpoint = torch.load(loadfile)
    model.load_state_dict(loaded_checkpoint['model_state_dict'])
    optimizer.load_state_dict(loaded_checkpoint['optimizer_state_dict'])
    print(f'Checkpoint loaded from {loadfile}')

In [28]:
def padded_stack(sequences):
    """Stack variable-length sequences into one zero-padded tensor.

    Args:
        sequences: List of tensors that may differ in their first dimension.

    Returns:
        ``(padded, mask)``: ``padded`` is the batch-first stack with padding
        positions set to 0; ``mask`` is a boolean tensor of the same shape that
        is True at real elements and False at padding.
    """
    padded_sequences = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True, padding_value=0)

    # Build the mask from the true lengths. The previous version padded with
    # -1 and compared against -1, which also masked (and zeroed) genuine -1
    # values in the data, e.g. a logit that happens to equal -1.
    lengths = torch.tensor([len(seq) for seq in sequences], device=padded_sequences.device)
    positions = torch.arange(padded_sequences.shape[1], device=padded_sequences.device)
    mask = positions.unsqueeze(0) < lengths.unsqueeze(1)  # B x max_len

    # Broadcast over any trailing feature dimensions so the shapes match.
    trailing_dims = (1,) * (padded_sequences.dim() - 2)
    mask = mask.reshape(mask.shape + trailing_dims).expand_as(padded_sequences).contiguous()

    return padded_sequences, mask

In [29]:
class ChunkerModel(torch.nn.Module):
    """T5 encoder with a linear head that scores sentence-end tokens."""

    def __init__(self):
        super().__init__()
        self.encoder = T5EncoderModel.from_pretrained('google/flan-t5-base')
        # Maps each 768-dimensional encoder state to a single logit.
        self.classifier = torch.nn.Linear(768,1)

    def forward(self, input_ids, sentence_ends_idx, attention_mask=None):
        """Score the sentence-end positions of every sequence in the batch.

        Args:
            input_ids: B x seq_length token ids.
            sentence_ends_idx: B index tensors (list of tensors or a 2-D
                tensor); entry ``i`` selects the positions to score in
                sequence ``i``.
            attention_mask: Optional mask forwarded to the encoder.

        Returns:
            A list of B 1-D tensors, one logit per selected position.
        """
        hidden = self.encoder(input_ids, attention_mask=attention_mask)['last_hidden_state']

        preds = []
        for batch_pos in range(len(hidden)):
            # Encoder states at this sequence's sentence-end positions.
            end_states = hidden[batch_pos][sentence_ends_idx[batch_pos]]
            preds.append(self.classifier(end_states).squeeze(-1))
        return preds


def set_seeds(seed: int=1234):
    """Seed every random number generator the notebook draws from.

    Seeds Python's ``random`` module, NumPy's legacy global generator and
    PyTorch (CPU and all CUDA devices). Previously only PyTorch was seeded, so
    the ``random.sample`` / ``random.choice`` calls used for distractor
    selection were not reproducible.

    Args:
        seed (int, optional): Random seed to set. Defaults to 1234.
    """
    random.seed(seed)
    np.random.seed(seed)
    # Seed for general torch operations.
    torch.manual_seed(seed)
    # Seed for CUDA operations on every visible GPU (safe without a GPU).
    torch.cuda.manual_seed_all(seed)


# Directory on the mounted Drive where preference samples are cached.
PREFERENCE_DATA_DIR = '/content/drive/MyDrive/Research: Elan<>Krupa/preference_data'


def save_preference_data(chunk, chunk_count):
    """Save one preference sample to ``PREFERENCE_DATA_DIR``.

    The file name is ``chunk_count`` zero-padded to six digits plus ``.pth``.
    """
    filename = f'{chunk_count:0{6}}.pth'
    torch.save(chunk, os.path.join(PREFERENCE_DATA_DIR, filename))

# def get_num_chunks_cached():
#     chunkfilelist = os.listdir(PREFERENCE_DATA_DIR)
#     chunklist = [int(os.path.splitext(f)[0]) for f in chunkfilelist if os.path.splitext(f)[1] == '.pth']
#     if chunklist:
#         current = max(chunklist) + 1
#     else:
#         current = 0
#     return current

def load_preference_data(current):
    """Load the cached preference sample numbered ``current``.

    Reads ``<current zero-padded to six digits>.pth`` from
    ``PREFERENCE_DATA_DIR`` and returns the stored object.
    """
    filename = f'{current:0{6}}.pth'
    return torch.load(os.path.join(PREFERENCE_DATA_DIR, filename))


def main():
    """Train the chunker model from retrieval-based preference pairs.

    For every example two alternative chunkings are sampled, each chunking is
    embedded together with random distractor embeddings, and the chunking that
    needs fewer tokens to reach the answer chunk in the top-k retrieval is
    preferred. The model is then updated on the branching step of each pair.

    Reads the notebook globals ``data_ready_flat``, ``doc_embs`` and
    ``wiki_docs``. Logs to Weights & Biases and writes checkpoints to RUNDIR.
    """
    RUNDIR = '/content/drive/MyDrive/Research: Elan<>Krupa/run_testing'
    LR = 0.01
    BATCH_SIZE = 12 #24
    EPOCHS = 1
    hparams = {
        'lr': LR,
        'batch': BATCH_SIZE,
        'epochs': EPOCHS
    }
    set_seeds(1234)
    # Make sure checkpoints have somewhere to go before training starts.
    os.makedirs(RUNDIR, exist_ok=True)

    wandb.init(
        # set the wandb project where this run will be logged
        project="smart_chunker_8_25",
        # track hyperparameters and run metadata
        config=hparams
    )

    LOSS_WINDOW = 20  # number of recent losses averaged for reporting
    CACHE_PREFERENCE_DATA = False

    losses = []

    dataset = ListDataset(data_ready_flat)
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=data_collater)

    model = ChunkerModel()
    # Fall back to CPU instead of failing when no GPU is available.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=LR)
    loss_fn = torch.nn.CrossEntropyLoss()

    chunk_count = -1 # start at -1 so that it immediately increments to zero
    last_step = None

    try:
        for epoch in range(EPOCHS):
            for i, batch in tqdm.tqdm(enumerate(dataloader), total=len(dataloader)):
                last_step = i
                chunking_samples = []
                for question, sentence_ends, text, title, answer_idx in zip(*batch):
                    chunk_count += 1
                    if CACHE_PREFERENCE_DATA:
                        try:
                            chunk_sample = load_preference_data(chunk_count)
                            chunking_samples.append(chunk_sample)
                            continue
                        except FileNotFoundError:
                            pass

                    # Wiki titles use spaces where the dataset titles use underscores.
                    title_sample = title.replace("_", " ")
                    if title_sample not in doc_embs: # Extra step for missing docs.
                        continue

                    sentence_ends, text, answer_idx = sentence_ends.to(device), text.to(device), answer_idx.to(device)
                    sampled = sample_chunking_with_divergence(sentence_ends, text, answer_idx, model, device)
                    if sampled is None or sampled[2] is None:
                        # No preference pair could be sampled for this example.
                        continue
                    sample1_single, sample2_single, divergence = sampled

                    chunk_sample = {"title": title,
                                    "text": text,
                                    "sample1": sample1_single,
                                    "sample2": sample2_single,
                                    "answer_idx": answer_idx,
                                    "divergence": divergence,
                                    "question_raw": question}

                    sample1_raw, result_index1 = prep_samples_for_emb(sample1_single, title_sample, text, answer_idx)
                    sample2_raw, result_index2 = prep_samples_for_emb(sample2_single, title_sample, text, answer_idx)

                    # Embed both chunkings and the question in a single API call.
                    to_emb = sample1_raw + sample2_raw + [title_sample + " "+ question]
                    embs = get_cohere_emb(to_emb)
                    question_emb = embs.embeddings[-1:]
                    sample1_emb = embs.embeddings[0:len(sample1_raw)]
                    sample2_emb = embs.embeddings[len(sample1_raw):-1]
                    assert len(sample2_emb) == len(sample2_raw)

                    sample1_all_emb, mapper1 = add_wiki_sample_emb(sample1_emb, doc_embs)
                    sample2_all_emb, mapper2 = add_wiki_sample_emb(sample2_emb, doc_embs)
                    q_embedding = torch.tensor(np.array(question_emb))

                    # Compute dot score between query embedding and document embeddings.
                    with torch.no_grad():
                        dot_scores1 = torch.mm(q_embedding, sample1_all_emb.transpose(0,1))
                        top_k1 = torch.topk(dot_scores1, k=min(10, dot_scores1.shape[-1])).indices.flatten()
                        dot_scores2 = torch.mm(q_embedding, sample2_all_emb.transpose(0,1))
                        top_k2 = torch.topk(dot_scores2, k=min(10, dot_scores2.shape[-1])).indices.flatten()

                        # identify the number of tokens until correct document (inclusive).
                        chunk_list1, chunk_tokens1 = get_chunk_tokenize(top_k1, mapper1, result_index1, sample1_raw, wiki_docs, title_sample)
                        chunk_list2, chunk_tokens2 = get_chunk_tokenize(top_k2, mapper2, result_index2, sample2_raw, wiki_docs, title_sample)

                    chunk_sample['tokens_to_answer1'] = len(chunk_tokens1)
                    chunk_sample['tokens_to_answer2'] = len(chunk_tokens2)

                    # Prefer the chunking that reaches the answer with fewer tokens.
                    # Ties are labelled 1, as before.
                    if len(chunk_tokens1) < len(chunk_tokens2):
                        chunk_sample['preference'] = 0
                    else:
                        chunk_sample['preference'] = 1

                    chunking_samples.append(chunk_sample)
                    if CACHE_PREFERENCE_DATA:
                        if len(chunk_tokens1) == len(chunk_tokens2):
                            print('No difference deteceted, skipping save')
                        else:
                            # save preference data.
                            save_preference_data(chunk_sample, chunk_count)

                if not chunking_samples:
                    # Every example was skipped; nothing to average or train on.
                    print('Effective batch: 0, skipping update')
                    continue

                average_tokens_to_ans1 = sum([chunk["tokens_to_answer1"] for chunk in chunking_samples])/ len(chunking_samples)
                average_tokens_to_ans2 = sum([chunk["tokens_to_answer2"] for chunk in chunking_samples]) / len(chunking_samples)
                print("average_tokens_to_ans1", average_tokens_to_ans1)
                print("average_tokens_to_ans2", average_tokens_to_ans2)

                # Batch the branching steps of all examples for one training step.
                divergences_text, divergences_text_mask = padded_stack([chunk['divergence'][0].flatten() for chunk in chunking_samples])
                print(f'Effective batch: {len(divergences_text)}')
                divergences_sentence_ends_idx, divergences_sentence_ends_idx_mask = padded_stack([chunk['divergence'][1] for chunk in chunking_samples])
                divergences_x1 = torch.tensor([chunk['divergence'][2] for chunk in chunking_samples])
                divergences_x2 = torch.tensor([chunk['divergence'][3] for chunk in chunking_samples])
                divergences_batched = (divergences_text, divergences_sentence_ends_idx, divergences_x1, divergences_x2)
                preferences_batched = torch.tensor([chunk['preference'] for chunk in chunking_samples], device=device)
                loss, sampling_entropy = train_pref(divergences_batched, divergences_text_mask, preferences_batched, model, loss_fn, opt)

                # Report the mean over the most recent LOSS_WINDOW steps.
                losses.append(loss.detach().cpu().numpy())
                losses = losses[-LOSS_WINDOW:]
                print(f'loss: {np.mean(losses)}')
                wandb.log({"loss": np.mean(losses), "average_tokens_to_ans1": average_tokens_to_ans1, "average_tokens_to_ans2": average_tokens_to_ans2,
                           "sampling_entropy_normalized": sampling_entropy})

                if i % 500 == 0:
                    save_checkpoint(model, opt, RUNDIR, epoch, i)
                    print('Checkpoint saved')

            # End-of-epoch checkpoint (skipped if the dataloader was empty).
            if last_step is not None:
                save_checkpoint(model, opt, RUNDIR, epoch, last_step)
                print('Checkpoint saved')
    finally:
        # Close the W&B run once, after all epochs, even if training is
        # interrupted. It used to be closed inside the epoch loop.
        wandb.finish()


# Main Process

In [30]:
main()

VBox(children=(Label(value='0.001 MB of 0.012 MB uploaded\r'), FloatProgress(value=0.10300925925925926, max=1.…

  0%|          | 0/835 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (576 > 512). Running this sequence through the model will result in indexing errors


average_tokens_to_ans1 358.0
average_tokens_to_ans2 233.58333333333334
Effective batch: 12
loss: 0.6933563351631165


  0%|          | 1/835 [00:38<8:55:41, 38.54s/it]

Checkpoint saved
average_tokens_to_ans1 250.63636363636363
average_tokens_to_ans2 310.3636363636364
Effective batch: 11


  0%|          | 2/835 [00:47<4:51:35, 21.00s/it]

loss: 0.693266749382019
average_tokens_to_ans1 379.8
average_tokens_to_ans2 346.9
Effective batch: 10


  0%|          | 3/835 [00:56<3:35:43, 15.56s/it]

loss: 0.693255603313446
average_tokens_to_ans1 343.90909090909093
average_tokens_to_ans2 284.90909090909093
Effective batch: 11


  0%|          | 4/835 [01:08<3:14:35, 14.05s/it]

loss: 0.693230390548706
average_tokens_to_ans1 231.45454545454547
average_tokens_to_ans2 216.0
Effective batch: 11


  1%|          | 5/835 [01:19<2:59:19, 12.96s/it]

loss: 0.6931735277175903
average_tokens_to_ans1 290.44444444444446
average_tokens_to_ans2 365.1111111111111
Effective batch: 9


  1%|          | 6/835 [01:28<2:40:16, 11.60s/it]

loss: 0.6931214332580566
average_tokens_to_ans1 311.0
average_tokens_to_ans2 257.6363636363636
Effective batch: 11


  1%|          | 7/835 [01:39<2:38:57, 11.52s/it]

loss: 0.6931212544441223
average_tokens_to_ans1 188.75
average_tokens_to_ans2 200.125
Effective batch: 8


  1%|          | 8/835 [01:51<2:40:29, 11.64s/it]

loss: 0.6930993795394897
average_tokens_to_ans1 258.0
average_tokens_to_ans2 259.8333333333333
Effective batch: 12


  1%|          | 9/835 [02:01<2:34:26, 11.22s/it]

loss: 0.693077027797699
average_tokens_to_ans1 297.25
average_tokens_to_ans2 337.5833333333333
Effective batch: 12


  1%|          | 10/835 [02:11<2:30:32, 10.95s/it]

loss: 0.6929630041122437
average_tokens_to_ans1 263.2857142857143
average_tokens_to_ans2 307.57142857142856
Effective batch: 7


  1%|▏         | 11/835 [02:18<2:10:21,  9.49s/it]

loss: 0.6925337314605713
average_tokens_to_ans1 301.44444444444446
average_tokens_to_ans2 315.0
Effective batch: 9


  1%|▏         | 12/835 [02:25<2:00:32,  8.79s/it]

loss: 0.6924245357513428
average_tokens_to_ans1 257.3333333333333
average_tokens_to_ans2 291.22222222222223
Effective batch: 9


  2%|▏         | 13/835 [02:33<1:56:54,  8.53s/it]

loss: 0.6921459436416626
average_tokens_to_ans1 252.9090909090909
average_tokens_to_ans2 313.27272727272725
Effective batch: 11


  2%|▏         | 14/835 [02:44<2:05:51,  9.20s/it]

loss: 0.6919243931770325
average_tokens_to_ans1 232.83333333333334
average_tokens_to_ans2 193.08333333333334
Effective batch: 12


  2%|▏         | 15/835 [02:55<2:15:22,  9.91s/it]

loss: 0.6919695138931274
average_tokens_to_ans1 273.09090909090907
average_tokens_to_ans2 259.09090909090907
Effective batch: 11


  2%|▏         | 16/835 [03:06<2:20:04, 10.26s/it]

loss: 0.6918067336082458


  2%|▏         | 16/835 [03:08<2:40:25, 11.75s/it]


KeyboardInterrupt: 