In [1]:
import os
import time
import copy
import random

import torch
import numpy as np
import pandas as pd

from sklearn.model_selection import train_test_split

from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.nn.functional import one_hot

from transformers import AutoTokenizer, AutoModel, AutoModelForSeq2SeqLM

from IPython.display import HTML, display

## Workspace Setup

In [2]:
# ----------------------- #
# --- Workspace Setup --- #
# ----------------------- #

# set project directory
# -----------------------
project_folder = 'drive/MyDrive/Datasci266/w266_project/'
os.chdir(project_folder)

# set device
# -----------------------
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# wrap cell outputs
# -----------------------
def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

## Load Raw Training Data

In [3]:
# -------------------------- #
# --- Load Training Data --- #
# -------------------------- #

df = pd.read_pickle('data/train_filtered.pkl')
df.head()

Unnamed: 0,article,abstract,section_names
0,with significant research efforts being direct...,synaptic memory is considered to be the main ...,introduction\nformalism\nresults and discussio...
1,the open connectome project ( located at http:...,"* _ abstract _ : * in this paper , we present...",introduction\nprocedure\nresults and future work
2,i am grateful to alekos kechris for informing ...,we describe the fundamental constructions and...,acknowledgments
3,set theory was proposed with the intended use ...,"recently , a multi - level fuzzy min max neur...",introduction\nmulti-level fuzzy min-max neural...
4,this work is financially supported by the nati...,we review the direct cp and t violation in th...,acknowledgements


In [4]:
# Define Training and Validation Datasets
# --------------------------------------
print('Size of dataset: ', len(df))

# --- index for validation dataset --- #
np.random.seed(35)
size = .1
val_n = round(len(df)*size)

val_index = np.random.randint(0, len(df), val_n)
val_index

# --- training dataset --- #
df_train = df.drop(val_index)
df_train = df_train['abstract'].to_list()
print('Size of training data: ', len(df_train))
print('Duplicate each training sentence: ')
df_train = [item for item in df_train for _ in range(10)]
print('Size of training data: ', len(df_train))


# --- validation dataset --- #
df_val = df.iloc[val_index]
df_val = df_val['abstract'].to_list()
print('Size of validation data: ', len(df_val))
print('Duplicate each validation sentence: ')
df_val = [item for item in df_val for _ in range(10)]
print('Size of training data: ', len(df_val))


Size of dataset:  2418
Size of training data:  2188
Duplicate each training sentence: 
Size of training data:  21880
Size of validation data:  242
Duplicate each validation sentence: 
Size of training data:  2420


In [5]:
# Test examples
# ------------------

test_text = """
Neurons are the main components of nervous tissue in all animals
except sponges and Placozoa. Non-animals like plants and fungi do
not have nerve cells. Molecular evidence suggests that the ability to
generate electric signals first appeared in evolution some 700 to 800
million years ago, during the Tonian period. Predecessors of neurons
were the peptidergic secretory cells. They eventually gained new gene
modules which enabled cells to create post-synaptic scaffolds and ion
channels that generate fast electrical signals. The ability to generate
electric signals was a key innovation in the evolution of the nervous
system.
"""

test_abstract = """
Human frontocentral event-related potentials (FC-ERPs) are ubiquitous
neural correlates of cognition and control, but their generating
multiscale mechanisms remain mostly unknown. We used the Human
Neocortical Neurosolver(HNN)’s biophysical model of a canonical
neocortical circuit under exogenous thalamic and cortical drive to
simulate the cell and circuit mechanisms underpinning the P2, N2, and
P3 features of the FC-ERP observed after Stop-Signals in the
Stop-Signal task (SST). We demonstrate that a sequence of simulated
external thalamocortical and cortico-cortical drives can produce the
FC-ERP, similar to what has been shown for primary sensory cortices.
We used this model of the FC-ERP to examine likely circuit-mechanisms
underlying FC-ERP features that distinguish between successful and
failed action-stopping. We also tested their adherence to the
predictions of the horse-race model of the SST, with specific
hypotheses motivated by theoretical links between the P3 and Stop
process. These simulations revealed that a difference in P3 onset
between successful and failed Stops is most likely due to a later
arrival of thalamocortical drive in failed Stops, rather than, for
example, a difference in effective strength of the input. In contrast,
the same model predicted that early thalamocortical drives underpinning
the P2 and N2 differed in both strength and timing across stopping
accuracy conditions. Overall, this model generates novel testable
predictions of the thalamocortical dynamics underlying FC-ERP
generation during action-stopping. Moreover, it provides a detailed
cellular and circuit-level interpretation that supports links between
these macroscale signatures and predictions of the behavioral race
model. Significance statement The frontocentral event-related potential
(FC-ERP) is an easily-measurable neural correlate of cognition and
control. However, the cortical dynamics that produce this signature in
humans are complex, limiting the ability of researchers to make
predictions about its underlying mechanisms. In this study, we used the
biophysical model included in the open-source Human Neocortical
Neurosolver software to simulate and evaluate the likely cellular and
circuit mechanisms that underlie the FC-ERP in the Stop-Signal task. We
modeled mechanisms of the FC-ERP during successful and unsuccessful
stopping, generating testable predictions regarding Stop-associated
computations in human frontal cortex. Moreover, the resulting model
parameters provide a starting point for simulating mechanisms of the
FC-ERP and other frontal scalp EEG signatures in other task conditions
and contexts.
"""

# MLM Training Pipeline

In [6]:
# ------------------------- #
# --- Training Pipeline --- #
# ------------------------- #

from transformers import BartForSequenceClassification, BartTokenizer, AdamW, BartForConditionalGeneration
import torch
from torch import nn

In [7]:
# model and tokenizer
# ------------------

model_name = "facebook/bart-base"

tokenizer = BartTokenizer.from_pretrained(
    model_name
)

model = BartForConditionalGeneration.from_pretrained(
    model_name
)

model.to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.72k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/558M [00:00<?, ?B/s]

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50265, 768, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50265, 768, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 768)
      (layers): ModuleList(
        (0-5): 6 x BartEncoderLayer(
          (self_attn): BartSdpaAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (activation_fn): GELUActivation()
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), 

In [8]:
run_test = False

In [9]:
if run_test:
    try:
        test_tokenized_data
    except:
        test_tokenized_data = False

    if not test_tokenized_data:
        test_tokenized_data = tokenizer(
            df_train,
            max_length=512,
            truncation=True,
            padding='max_length',
            return_tensors="pt",
        )

In [10]:
# define Datast class
# ------------------

class ContinuedPretrainData(Dataset):
    def __init__(self, base_data, tokenizer, device, max_len=512):
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.data = []
        self.labels = []
        self.attention_mask = []

        tokenized_examples = tokenizer(
            df_train,
            max_length=512,
            truncation=True,
            padding='max_length',
            return_tensors="pt"
        )

        input_ids = tokenized_examples['input_ids']

        # mask tokens
        mask_indices = torch.rand(input_ids.shape) < 0.15  # 15% probability

        # un-mask padding tokens
        mask_indices[tokenized_examples['attention_mask'] == 0] = False
        mask_indices[:, 0] = False # un-mask first token in sequence
        # mask_indices[input_ids == 0] = False  # un-mask 0 (bos) tokens
        # mask_indices[input_ids == 1] = False  # un-mask 1 (pad) tokens
        mask_indices[input_ids == 2] = False  # un-mask 2 (eos) tokens

        masked_tokens = input_ids.clone()
        masked_tokens[mask_indices] = tokenizer.mask_token_id
        masked_tokens = torch.tensor(masked_tokens)

        # Generate labels from masked tokens
        labels = input_ids.clone()
        labels[~mask_indices] = -100
        labels[mask_indices] = input_ids[mask_indices]
        labels = torch.tensor(labels)

        # send data to device
        self.data = masked_tokens.to(device)
        self.labels = labels.to(device)
        self.attention_mask = tokenized_examples['attention_mask'].to(device)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        return {'inputs': self.data[index],
                'labels': self.labels[index],
                'attention_mask': self.attention_mask[index],
        }

In [11]:
%%capture

# initialize datasets
train_dataset = ContinuedPretrainData(
    df_train,
    tokenizer,
    device
)

val_dataset = ContinuedPretrainData(
    df_val,
    tokenizer,
    device
)

In [12]:
batch_size = 4

# initialize dataloaders
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False
)

In [13]:
print(
    'Masked token ID:\t',
    tokenizer.mask_token_id,
    '\n'
    'Length of Dataset:\t',
    len(train_dataset),
)

Masked token ID:	 50264 
Length of Dataset:	 21880


In [14]:
print(
    train_dataset[0]['inputs'][0:5],
    '\n',
    train_dataset[0]['labels'][0:5],
    sep='',
)

tensor([    0, 47621,  3783,    16,  1687], device='cuda:0')
tensor([-100, -100, -100, -100, -100], device='cuda:0')


In [15]:
# next(iter(train_dataloader))

In [16]:
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.00001
)

reinitialize_base_model = False

if reinitialize_base_model:
    model = BartForConditionalGeneration.from_pretrained(
        "facebook/bart-base"
    )
    model.to(device)

def continued_train_loop(
        train_dataloader,
        val_dataloader,
        model,
        optimizer,
        reporting_interval=50,
        save_state=False,
        out_directory='states/',
        out_fname='mlm_model_state',
    ) -> dict:
    """
    Complete a single pass through the training dataloader
    """

    # --- Training Loop --- #
    model.train()  # training mode
    train_epoch_loss = 0
    val_epoch_loss = 0
    train_acc_history = []  # store training accuracies for each batch
    train_loss_history = []  # store losses for each batch
    val_acc_history = {'correct':0, 'total':0}  # store validation accuracies for each batch

    for batch, example in enumerate(train_dataloader):

        inputs = {
            "input_ids": example["inputs"],
            "labels": example["labels"],
            "attention_mask": example["attention_mask"],
        }

        # forward pass
        outputs = model(**inputs)

        # get loss
        loss = outputs.loss

        # backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_epoch_loss += loss.item()

        # calculate accuracy on masked tokens
        mask = (inputs["input_ids"] == tokenizer.mask_token_id)  # identify masked tokens
        masked_logits = outputs.logits[mask]
        predictions = torch.argmax(masked_logits, dim=-1)  # get predicted token index
        flat_labels = inputs["labels"][mask]  # filter labels on masked tokens
        correct = (predictions == flat_labels).sum().item()
        train_acc = correct / len(predictions)  # calculate accuracy
        train_acc = round(train_acc, 5)

        # training progress
        if int(batch + 1) % reporting_interval == 0:
            print(f'\tFinished batches: {str(batch + 1)}')
            print(f'\tTrain Loss: {round(loss.item(), 4)}')
            print(f'\tTrain Acc: {train_acc}')

        train_acc_history.append(train_acc)
        train_loss_history.append(loss.item())

    # Print average training loss
    avg_train_loss = round(train_epoch_loss / len(train_loss_history), 5)
    print(f"Average training loss: {avg_train_loss}")

    # --- Validation Loop --- #
    model.eval()  # evaluation mode
    with torch.no_grad():  # disable gradient calculation

        for batch, example in enumerate(val_dataloader):
            inputs = {
                "input_ids": example["inputs"],
                "labels": example["labels"],
                "attention_mask": example["attention_mask"],
            }
            outputs = model(**inputs)
            loss = outputs.loss

            val_epoch_loss += loss.item()

            # calculate accuracy
            mask = (inputs["input_ids"] == tokenizer.mask_token_id)
            masked_logits = outputs.logits[mask]
            predictions = torch.argmax(masked_logits, dim=-1)
            flat_labels = inputs["labels"][mask]
            correct = (predictions == flat_labels).sum().item()
            # val_acc = correct / len(predictions)
            # val_acc = round(val_acc, 5)

            val_acc_history['correct']+=correct
            val_acc_history['total']+=len(predictions)

    # Print training and validation results
    # avg_val_loss = round(val_epoch_loss / len(val_dataloader), 5)
    # print(f"Average validation loss: {avg_val_loss}")
    avg_val_loss = round(val_epoch_loss / batch, 5)
    print(f"Average validation loss: {avg_val_loss}")
    overall_val_acc = round(val_acc_history['correct'] / val_acc_history['total'], 5)


    # Print final accuracies
    print(f"Final Train Accuracy: {round(train_acc, 5)}")
    print(f"Validation Accuracy: {overall_val_acc}")

    if save_state:
        # torch.save(model.state_dict(), 'states/model_state.pth')
        model.save_pretrained(out_directory+out_fname)

    # Update return dictionary
    return {
        'train_loss_history': train_loss_history,
        'train_acc_history': train_acc_history,
        'val_acc': overall_val_acc,
        'val_loss': avg_val_loss,
    }


In [19]:
run_single_epoch = False

if run_single_epoch:
    model_results = continued_train_loop(
            train_dataloader,
            val_dataloader,
            model,
            optimizer,
            save_state=True
        )

	Finished batches: 50
	Train Loss: 6.9703
	Train Acc: 0.06818
	Finished batches: 100
	Train Loss: 5.7556
	Train Acc: 0.14815
	Finished batches: 150
	Train Loss: 6.7847
	Train Acc: 0.0625
	Finished batches: 200
	Train Loss: 6.2932
	Train Acc: 0.11735
	Finished batches: 250
	Train Loss: 6.1768
	Train Acc: 0.11554
	Finished batches: 300
	Train Loss: 5.7844
	Train Acc: 0.11976
	Finished batches: 350
	Train Loss: 5.5548
	Train Acc: 0.14583
	Finished batches: 400
	Train Loss: 5.7066
	Train Acc: 0.13171
	Finished batches: 450
	Train Loss: 5.5744
	Train Acc: 0.20979
	Finished batches: 500
	Train Loss: 5.5584
	Train Acc: 0.14706
	Finished batches: 550
	Train Loss: 5.2021
	Train Acc: 0.22
	Finished batches: 600
	Train Loss: 5.3196
	Train Acc: 0.24031
	Finished batches: 650
	Train Loss: 4.9804
	Train Acc: 0.26056
	Finished batches: 700
	Train Loss: 4.9782
	Train Acc: 0.22907
	Finished batches: 750
	Train Loss: 5.0407
	Train Acc: 0.24521
	Finished batches: 800
	Train Loss: 5.0173
	Train Acc: 0.190

Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


Average validation loss: 2.36896
Average validation loss: 2.3694
Final Train Accuracy: 0.48598
Validation Accuracy: 0.56384


In [18]:
assert False

AssertionError: 

# Training Loop

## Setup

In [20]:
# --- Load Previous Epoch Results --- #
all_epoch_results = pd.read_pickle(
    'states/mlm_state_checkpoints/all_epoch_results.pkl'
)

all_epoch_results

In [None]:
# --- Setup for Training Loop --- #

epoch_results = {}
start_epoch = 0
num_epochs = 10

save_all = True
load_checkpoint = True

# --- Filepaths --- #

checkpoint_path = 'states/mlm_state_checkpoints/'

if save_all:
    out_folder = 'states/mlm_state_checkpoints/'
else:
    out_folder = 'states/'

# --- Get Last Checkpoint Model --- #
folder_names = [
    name for name in os.listdir('states/mlm_state_checkpoints/')
    if os.path.isdir(os.path.join('states/mlm_state_checkpoints/', name))
]
previous_runs = [
    int(file.split('_')[0]) for file in folder_names if file.split('_')[0].isdigit()
]

if load_checkpoint:
    checkpoint_fname = [
        name for name in folder_names
        if name.split('_')[0] == str(max(previous_runs))
    ]
    checkpoint_fname = checkpoint_fname[0]

    print(f'Loading checkpoint from folder: {checkpoint_fname}')

    model = BartForConditionalGeneration.from_pretrained(
        checkpoint_path+checkpoint_fname,
    )
    model.to(device)
    start_epoch = max(previous_runs)

    print(f'Next loop will start at epoch {start_epoch+1}')


## Run Training Loop

In [None]:
for epoch in range(start_epoch, num_epochs):
    run = epoch + 1
    print(f'Running epoch {run}')
    print('Time:',time.strftime("%H:%M"))

    if save_all:
        save_run = True
    else:
        if run == num_epochs:
            save_run = True
        else:
            save_run = False

    out_name = f'{run}_model_state'

    epoch_results[run] = continued_train_loop(
        train_dataloader,
        val_dataloader,
        model,
        optimizer,
        save_state=save_run,
        out_fname=out_name,
        out_directory=out_folder,
    )

    interim_epoch_results = pd.DataFrame(epoch_results).T
    interim_epoch_results.to_pickle(
        out_folder+'interim_epoch_results.pkl'
    )

print('Time:',time.strftime("%H:%M"))

## Save epoch Results

In [None]:
assert False

In [None]:
rewrite_results = False

# average values for histories for training acc / training loss
def avg_list(row):
    return sum(row)/len(row)

interim_epoch_results['avg_train_acc'] = interim_epoch_results['train_acc_history'].apply(avg_list)
interim_epoch_results['avg_train_loss'] = interim_epoch_results['train_loss_history'].apply(avg_list)

if rewrite_results:
    interim_epoch_results.to_pickle('states/mlm_state_checkpoints/epoch_results.pkl')
else:
    all_epoch_results = pd.read_pickle('states/mlm_state_checkpoints/all_epoch_results.pkl')
    all_epoch_results = pd.concat([all_epoch_results, interim_epoch_results])
    all_epoch_results.to_pickle('states/mlm_state_checkpoints/all_epoch_results.pkl')

all_epoch_results[
    ['val_acc', 'val_loss', 'avg_train_acc', 'avg_train_loss']
]

# End

# Troubleshooting

In [None]:
assert False

In [None]:
# # --- old method --- #

# # filter out short examples
# test_input_ids = test_tokenized_data['input_ids'][test_tokenized_data['attention_mask'][:, 512 - 1] > 0]

# # mask tokens
# test_mask_indices = torch.rand(test_input_ids.shape) < 0.15  # 15% probability
# test_mask_indices[:, 0] = False  # ensure first token is not masked
# test_mask_indices[:, -1] = False  # ensure last token is not masked
# masked_tokens = test_input_ids.clone()
# masked_tokens[test_mask_indices] = tokenizer.mask_token_id
# masked_tokens = torch.tensor(masked_tokens)

# # generate test_labels from masked tokens
# test_labels = test_input_ids.clone()
# test_labels[~test_mask_indices] = -100
# test_labels[test_mask_indices] = test_input_ids[test_mask_indices]
# test_labels = torch.tensor(test_labels)

# --------------------------------------------------------------------------- #
# --------------------------------------------------------------------------- #

# --- new method --- #

# Keep all examples (no filtering)
test_input_ids = test_tokenized_examples['input_ids']

# Mask tokens with conditions
test_mask_indices = torch.rand(test_input_ids.shape) < 0.15  # 15% probability
test_mask_indices[test_tokenized_examples['attention_mask'] == 0] = False  # Exclude padding tokens
test_mask_indices[:, 0] = False  # Ensure first token is not masked
# test_mask_indices[test_input_ids == 0] = False  # Exclude tokens with value 0
# test_mask_indices[test_input_ids == 1] = False  # Exclude tokens with value 1
test_mask_indices[test_input_ids == 2] = False  # Exclude tokens with value 2

masked_tokens = test_input_ids.clone()
masked_tokens[test_mask_indices] = tokenizer.mask_token_id
masked_tokens = torch.tensor(masked_tokens)

# Generate test_labels from masked tokens
test_labels = test_input_ids.clone()
test_labels[~test_mask_indices] = -100
test_labels[test_mask_indices] = test_input_ids[test_mask_indices]
test_labels = torch.tensor(test_labels)

In [None]:
print(
    tokenizer.eos_token,
    tokenizer('<s>Hi</s>'),
    tokenizer('Hi'),
    sep='\n'
)

In [None]:
# test example
# -> https://huggingface.co/transformers/v3.0.2/model_doc/bart.html#bartforsequenceclassification

tmp_inputs = tokenizer("Hello, my dog is cute", return_tensors="pt").to(device)
tmp_labels = torch.tensor([1]).unsqueeze(0).to(device)  # Batch size 1

outputs = model(**tmp_inputs, labels=tmp_labels)
loss, logits = outputs[:2]

print(
    tmp_inputs,
    tmp_labels,
    loss,
    sep='\n\n'
)

In [None]:
single_batch = next(iter(dataloader))
display(single_batch)

In [None]:
outputs = model(input_ids=single_batch['inputs'], labels=single_batch['labels'])
loss = outputs.loss  # Access the loss output
logits = outputs.logits  # Access the logits output

print(loss)

In [None]:
assert False

In [None]:
def continued_train_loop(
        train_dataloader,
        val_dataloader,
        model,
        optimizer,
        reporting_interval=50,
        save_state=False
    ) -> dict:

    # --- Training Loop --- #
    model.train() # training mode
    train_epoch_loss = 0
    val_epoch_loss = 0

    for batch, example in enumerate(train_dataloader):

        inputs = {
            "input_ids": example["inputs"],
            "labels": example["labels"],
            "attention_mask": example["attention_mask"],
        }

        # forward pass
        outputs = model(**inputs)

        # get loss
        # loss = nn.functional.masked_cross_entropy_loss(outputs.logits, inputs["labels"])
        # loss = nn.CrossEntropyLoss()(outputs.logits, inputs["labels"])
        # mask_positions = batch['mask_positions']
        loss = outputs.loss

        # backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_epoch_loss += loss.item()

        # training progress
        if int(batch + 1) % reporting_interval == 0:
            print('\tFinished batches: ', str(batch + 1))
            print('\tCurrent average loss: ', train_epoch_loss/batch)

    # --- Validation Loop --- #
    model.eval()  # evaluation mode
    with torch.no_grad():  # disable gradient calculation

        for batch, example in enumerate(val_dataloader):
            inputs = {
                "input_ids": example["inputs"],
                "labels": example["labels"],
                "attention_mask": example["attention_mask"],
            }
            outputs = model(**inputs)
            loss = outputs.loss

            val_epoch_loss += loss.item()

    # Print training and validation loss
    print(f"Average training loss: {round(train_epoch_loss/batch, 5)}")
    print(f"Average validation loss: {round(val_epoch_loss / len(val_dataloader), 5)}")

    if save_state:
        torch.save(model.state_dict(), 'states/model_state.pth')

    return {'state':model.state_dict(), }

def continued_train_loop(
        train_dataloader,
        val_dataloader,
        model,
        optimizer,
        reporting_interval=50,
        save_state=False
    ) -> dict:

    # --- Training Loop --- #
    model.train()  # training mode
    train_epoch_loss = 0
    val_epoch_loss = 0
    val_correct = 0
    train_acc_history = []  # store training accuracies for each batch

    loss_history = []  # store losses for each batch

    for batch, example in enumerate(train_dataloader):

        inputs = {
            "input_ids": example["inputs"],
            "labels": example["labels"],
            "attention_mask": example["attention_mask"],
        }

        # forward pass
        outputs = model(**inputs)

        # get loss
        loss = outputs.loss
        loss_history.append(loss.item())  # store batch loss

        # backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        train_epoch_loss += loss.item()

        # get indices of masked tokens
        mask = (inputs["input_ids"] == tokenizer.mask_token_id)  # identify masked tokens

        # get logits and predictions for masked tokens
        masked_logits = outputs.logits[mask]
        predictions = torch.argmax(masked_logits, dim=-1)  # get predicted token index

        # filter labels on masked tokens
        flat_labels = inputs["labels"][mask]

        # calculate accuracy
        correct = (predictions == flat_labels).sum().item()
        train_acc = correct / len(predictions)

        # training progress
        if int(batch + 1) % reporting_interval == 0:
            print(f'\tFinished batches: {str(batch + 1)} | Train Loss: {loss.item():.5f} | Train Acc: {train_acc:.5f}')

        train_acc_history.append(train_acc)

    # --- Validation Loop --- #
    model.eval()  # evaluation mode
    with torch.no_grad():  # disable gradient calculation

        for batch, example in enumerate(val_dataloader):
            inputs = {
                "input_ids": example["inputs"],
                "labels": example["labels"],
                "attention_mask": example["attention_mask"],  # include attention mask for validation
            }
            outputs = model(**inputs)
            loss = outputs.loss

            val_epoch_loss += loss.item()

            # calculate accuracy (validation loop)
            mask = (inputs["input_ids"] == tokenizer.mask_token_id)
            masked_logits = outputs.logits[mask]
            predictions = torch.argmax(masked_logits, dim=-1)

            flat_labels = inputs["labels"].view(-1)
            masked_labels = flat_labels[mask]

            val_correct += (predictions == masked_labels).sum().item()

    # Print training and validation results
    train_acc = train_correct / (len(train_dataloader.dataset) * mask.sum().item())  # consider only masked tokens (final calculation)
    val_acc = val_correct / (len(val_dataloader.dataset) * mask.sum().item())  # consider only masked tokens (final calculation)
    print(f"Average training loss: {round(train_epoch_loss/batch, 5)} | Train Accuracy: {train_acc:.5f}")
    print(f"Average validation loss: {round(val_epoch_loss / len(val_dataloader), 5)} | Validation Accuracy: {val_acc:.5f}")

    if save_state:
        torch.save(model.state_dict(), 'states/model_state.pth')

    # Update return dictionary
    return {
        'state': model.state_dict(),
        'train_loss_history': loss_history,
        'train_acc_history': train_acc_history,
        'val_acc': val_acc
    }

In [None]:
# define Datast class
# ------------------

class ContinuedPretrainData(Dataset):
    def __init__(self, base_data, tokenizer, device, max_len=512):
        self.max_len = max_len
        self.tokenizer = tokenizer
        self.data = []
        self.labels = []
        # self.mask_positions = [] # Depracated

        tokenized_examples = tokenizer(
            df_train,
            max_length=512,
            truncation=True,
            padding='max_length',
            return_tensors="pt"
        )

        # filter out short examples
        input_ids = tokenized_examples['input_ids'][tokenized_examples['attention_mask'][:, max_len - 1] > 0]

        # mask tokens
        masked_indices = torch.rand(input_ids.shape) < 0.15  # 15% probability
        masked_indices[:, 0] = False  # ensure first token is not masked
        masked_indices[:, -1] = False  # ensure last token is not masked
        masked_tokens = input_ids.clone()
        masked_tokens[masked_indices] = tokenizer.mask_token_id
        masked_tokens = torch.tensor(masked_tokens)

        # generate labels from masked tokens
        labels = input_ids.clone()
        labels[~masked_indices] = -100
        labels[masked_indices] = input_ids[masked_indices]
        labels = torch.tensor(labels)

        # masked positions (depracated)
        # mask_positions = input_ids.clone()
        # mask_positions[~masked_indices] = False
        # mask_positions[masked_indices] = True
        # mask_positions = torch.tensor(mask_positions)

        # send data to device
        self.data = masked_tokens.to(device)
        self.labels = labels.to(device)
        # self.mask_positions = mask_positions.to(device)

    def __len__(self):
        return self.data.shape[0]

    def __getitem__(self, index):
        return {'inputs': self.data[index],
                'labels': self.labels[index],
                # 'mask_positions': self.mask_positions[index],
        }

In [None]:
# define training loop
# ------------------

optimizer = AdamW(model.parameters(), lr=2e-5)

steps = 0
# batch_size = 4 # this doesn't work - yields an error
batch_size = 2048

# training loop
for epoch in range(2):

    train_dataset = ContinuedPretrainData(df_train, tokenizer, device)

    dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True
    )

    for batch in dataloader:

        inputs = {
            "input_ids": batch["inputs"],
            "labels": batch["labels"],
        }

        # Forward pass
        outputs = model(**inputs)

        # Loss calculation
        loss = nn.CrossEntropyLoss()(outputs.logits, inputs["labels"])

        # Backward pass and optimization
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Print training progress (optional)
        if steps % 100 == 0:
            print(f"Epoch: {epoch+1}/{3}, Step: {steps}, Loss: {loss.item()}")
            steps += 1


In [None]:
train_dataset[0]

In [None]:
num_zeros = []
num_twos = []
for i in range(0,len(train_dataset)):
    # count the number of times 0 appears in the example
    num_zeros += [torch.sum(train_dataset[i]['inputs'] == 0).item()]
    num_twos += [torch.sum(train_dataset[i]['inputs'] == 2).item()]

len(num_zeros)


In [None]:
num_zeros = np.array(num_zeros)
num_twos = np.array(num_twos)
sum(num_zeros == 1)
sum(num_twos == 1)

In [None]:
# decode
print(
    tokenizer.decode(train_dataset[3]['inputs']),
    sep='\n',
)