In [None]:
# !pip install accelerate
# !pip install accelerate -U
# !pip install transformers[torch]

In [None]:
import glob
import warnings

data_dir = 'datasets/babylm_10M/*.train'     # glob pattern for the training files; change if needed


def concatenate_text_files(paths, output_path):
    """Concatenate text files, in the given order, into one UTF-8 file.

    A newline is appended after any file whose contents do not already end
    with one, so the last line of one file never fuses with the first line of
    the next.

    Args:
        paths: iterable of input file paths, read as UTF-8.
        output_path: path of the file to create (overwritten if it exists).

    Returns:
        The number of input files written.
    """
    count = 0
    with open(output_path, "w", encoding="utf-8") as outfile:
        for path in paths:
            with open(path, "r", encoding="utf-8") as infile:
                text = infile.read()
            outfile.write(text)
            if text and not text.endswith("\n"):
                outfile.write("\n")
            count += 1
    return count


# glob returns matches in arbitrary (filesystem-dependent) order; sort them so the
# combined file, and therefore the example order, is identical across runs.
file_paths = sorted(glob.glob(data_dir))
if not file_paths:
    warnings.warn(f"No files matched {data_dir!r}; combined_dataset.txt will be empty.")

# Concatenate all text files into one big text file
n_files = concatenate_text_files(file_paths, "combined_dataset.txt")
n_files


In [None]:
from transformers import ElectraTokenizer, ElectraForPreTraining, ElectraConfig, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from torch.utils.data import Dataset, DataLoader
import torch

# Read the combined corpus, one training example per line. Blank lines are dropped:
# they would tokenize to just [CLS] [SEP] and contribute nothing to learn from.
with open("combined_dataset.txt", "r", encoding="utf-8") as file:
    text_data = [line for line in file.read().splitlines() if line.strip()]

# Initialize the tokenizer
tokenizer = ElectraTokenizer.from_pretrained('google/electra-small-discriminator')

# Tokenize without padding. With padding=True in this single call, every line would be
# padded to the longest line in the whole corpus (capped at max_length=512), building an
# N x L tensor per feature. The data collator pads each batch to its own longest
# sequence instead, which is far cheaper in memory and compute.
inputs = tokenizer(text_data, truncation=True, max_length=512)

# Create a PyTorch dataset
class TextDataset(Dataset):
    """Map-style dataset over pre-tokenized encodings.

    Args:
        encodings: mapping from feature name (e.g. ``input_ids``,
            ``attention_mask``) to per-example rows, either a 2-D tensor or a
            list of lists (as returned by a Hugging Face tokenizer). Must
            contain an ``input_ids`` entry, which defines the dataset length.

    Each item is a dict mapping every feature name to a freshly copied tensor
    for that example, so callers may modify it without touching ``encodings``.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        item = {}
        for key, val in self.encodings.items():
            row = val[idx]
            # torch.tensor() on an existing tensor copies it but emits a
            # "To copy construct from a tensor..." UserWarning on every call;
            # clone().detach() is the recommended warning-free copy. Plain
            # lists still go through torch.tensor().
            item[key] = row.clone().detach() if torch.is_tensor(row) else torch.tensor(row)
        return item

    def __len__(self):
        # Subscript instead of attribute access so plain dicts work as well as
        # BatchEncoding objects.
        return len(self.encodings["input_ids"])

dataset = TextDataset(inputs)

# Seed torch before building the model so the random weight initialization is
# reproducible. Trainer only seeds when it is constructed, i.e. after the model already
# exists. 42 matches TrainingArguments' default `seed`.
torch.manual_seed(42)

# Create a configuration for the model
config = ElectraConfig(
    vocab_size=len(tokenizer),  # Must cover every id the tokenizer can emit (30522 for google/electra-small-discriminator)
    embedding_size=128,  # Embedding size
    hidden_size=256,  # Size of the encoder layers and the pooler layer
    num_hidden_layers=12,  # Number of hidden layers in the Transformer encoder
    num_attention_heads=4,  # Number of attention heads for each attention layer in the Transformer encoder
    intermediate_size=1024,  # The size of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder
    hidden_act="gelu",  # The non-linear activation function (function or string) in the encoder and pooler
    hidden_dropout_prob=0.1,  # The dropout probability for all fully connected layers in the embeddings, encoder, and pooler
    attention_probs_dropout_prob=0.1,  # The dropout ratio for the attention probabilities
    max_position_embeddings=512,  # Must be >= the tokenizer's max_length (512) used when tokenizing
)

# Initialize the model with the configuration (randomly initialized, not pretrained)
model = ElectraForPreTraining(config)

# Replaced-token-detection (RTD) collator for the ELECTRA discriminator.
#
# ElectraForPreTraining scores every position with BCEWithLogitsLoss against `labels`,
# which must be 1 where a token was replaced and 0 where it is original.
# DataCollatorForLanguageModeling(mlm=True) instead emits MLM labels (original token
# ids, or -100 for unmasked positions). Fed into a binary cross-entropy, those make
# the loss negative and meaningless, as in the logged loss falling from -24 to -373.
#
# Simplification: real ELECTRA samples replacements from a small MLM generator trained
# jointly with the discriminator. Here replacements are drawn uniformly from the
# non-special vocabulary, which yields correct labels but an easier detection task.
class ReplacedTokenDetectionCollator:
    """Pad a list of examples and corrupt a random subset of their tokens.

    Args:
        tokenizer: tokenizer used to pad batches and to identify special-token ids.
        replace_prob: probability that each eligible (non-special, non-padding)
            token is replaced by a random non-special vocabulary token.

    Calling the collator returns the padded batch with corrupted ``input_ids`` and
    binary ``labels`` (1 = token differs from the original, 0 = unchanged).
    """

    def __init__(self, tokenizer, replace_prob=0.15):
        self.tokenizer = tokenizer
        self.replace_prob = replace_prob
        special = set(tokenizer.all_special_ids)
        self.special_ids = torch.tensor(sorted(special))
        # Replacement candidates exclude special tokens, so a corrupted sequence never
        # gains a stray [CLS]/[SEP]/[PAD]/[MASK] in the middle of the text.
        self.replacement_ids = torch.tensor(
            [i for i in range(len(tokenizer)) if i not in special]
        )

    def __call__(self, examples):
        batch = self.tokenizer.pad(examples, return_tensors="pt")
        original = batch["input_ids"]
        eligible = ~torch.isin(original, self.special_ids)
        if "attention_mask" in batch:
            eligible &= batch["attention_mask"].bool()
        replace = eligible & (torch.rand(original.shape) < self.replace_prob)
        sampled = self.replacement_ids[
            torch.randint(len(self.replacement_ids), original.shape)
        ]
        corrupted = torch.where(replace, sampled, original)
        batch["input_ids"] = corrupted
        # Label by comparing with the original rather than using `replace` directly: a
        # random draw can reproduce the original token, and that position is unchanged.
        batch["labels"] = (corrupted != original).long()
        return batch


# Initialize the data collator
data_collator = ReplacedTokenDetectionCollator(tokenizer=tokenizer, replace_prob=0.15)

# Initialize the training arguments
training_args = TrainingArguments(
    output_dir='./results',          # output directory for checkpoints and trainer_state.json
    num_train_epochs=10,             # ignored while max_steps > 0 (max_steps below takes precedence)
    per_device_train_batch_size=8,   # batch size per device during training
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    logging_steps=10,                # Log every X updates steps
    log_level='info',                # Set the logger to the 'info' level
    log_level_replica='info',        # Set the logger of the replicas to the 'info' level
    max_steps=100000,                # total training steps; overrides num_train_epochs
    seed=42,                         # same as the default, stated explicitly for reproducibility
)

# Initialize the trainer
trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=dataset,               # training dataset
    data_collator=data_collator,         # builds each batch (padding + labels) from dataset items
)

# Train the model
trainer.train()

# Save the model and tokenizer together so electra_dir/ can be reloaded with from_pretrained.
model.save_pretrained('electra_dir/')
tokenizer.save_pretrained('electra_dir/')
# Save the trainer state (global step, log history) as trainer_state.json in output_dir.
trainer.save_state()

loading file vocab.txt from cache at /root/.cache/huggingface/hub/models--google--electra-small-discriminator/snapshots/fa8239aadc095e9164941d05878b98afe9b953c3/vocab.txt
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at /root/.cache/huggingface/hub/models--google--electra-small-discriminator/snapshots/fa8239aadc095e9164941d05878b98afe9b953c3/tokenizer_config.json
loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--google--electra-small-discriminator/snapshots/fa8239aadc095e9164941d05878b98afe9b953c3/tokenizer.json
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--google--electra-small-discriminator/snapshots/fa8239aadc095e9164941d05878b98afe9b953c3/config.json
Model config ElectraConfig {
  "_name_or_path": "google/electra-small-discriminator",
  "architectures": [
    "ElectraForPreTraining"
  ],
  "attention_probs

Step,Training Loss
10,-24.0697
20,-28.3953
30,-44.066
40,-78.1366
50,-136.2806
60,-157.0625
70,-179.9317
80,-305.2004
90,-316.3105
100,-373.1202


Saving model checkpoint to ./results/tmp-checkpoint-500
Configuration saved in ./results/tmp-checkpoint-500/config.json
Model weights saved in ./results/tmp-checkpoint-500/model.safetensors
  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Saving model checkpoint to ./results/tmp-checkpoint-1000
Configuration saved in ./results/tmp-checkpoint-1000/config.json
Model weights saved in ./results/tmp-checkpoint-1000/model.safetensors
  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Saving model checkpoint to ./results/tmp-checkpoint-1500
Configuration saved in ./results/tmp-checkpoint-1500/config.json
Model weights saved in ./results/tmp-checkpoint-1500/model.safetensors
  return {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
Saving model checkpoint to ./results/tmp-checkpoint-2000
Configuration saved in ./results/tmp-checkpoint-2000/config.json
Model weights saved in ./results/tmp-checkpoint-2000/model.safetenso

Step,Training Loss
10,-24.0697
20,-28.3953
30,-44.066
40,-78.1366
50,-136.2806
60,-157.0625
70,-179.9317
80,-305.2004
90,-316.3105
100,-373.1202
