Commit c878edf (1 parent: d2969d6). Showing 7 changed files with 1,606 additions and 0 deletions.
@@ -0,0 +1,38 @@
from pathlib import Path

DATA_FOLDER = "/work/data"


def get_config():
    return {
        "batch_size": 8,  # Size of each batch for training
        "num_epochs": 2,  # Number of training epochs
        "lr": 10**-4,  # Learning rate for the optimizer
        "seq_len": 400,  # Maximum sequence length for the model
        "d_model": 512,  # Dimensionality of the token embeddings
        "datasource": "opus_books",  # Source of the training data
        "lang_src": "en",  # Source language code (e.g., English)
        "lang_tgt": "fr",  # Target language code (e.g., French)
        "model_folder": "weights",  # Directory to store model weights
        "model_basename": "tmodel_",  # Base name for the saved model files
        "preload": None,  # Preloaded model weights if available, otherwise None
        "tokenizer_file": "tokenizer_{0}.json",  # Filename pattern for saving the tokenizer
        "experiment_name": "runs/tmodel",  # Name for the experiment, used for logging
    }


def get_weights_file_path(config, epoch: str):
    model_folder = f"{DATA_FOLDER}/{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}{epoch}.pt"
    return f"{model_folder}/{model_filename}"


# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    model_folder = f"{DATA_FOLDER}/{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}*"
    weights_files = list(Path(model_folder).glob(model_filename))
    if len(weights_files) == 0:
        return None
    weights_files.sort()
    return str(weights_files[-1])
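A minimal usage sketch for the helpers above (not part of the commit). The import path `config` is an assumption, since this file's name is not shown in the diff.

# Hypothetical usage; the module name "config" is an assumption.
from config import get_config, get_weights_file_path, latest_weights_file_path

config = get_config()

# Builds /work/data/opus_books_weights/tmodel_02.pt for epoch "02".
print(get_weights_file_path(config, epoch="02"))

# Returns None when no matching weights exist; otherwise the lexicographically
# last match, so zero-padded epoch strings keep the ordering correct.
print(latest_weights_file_path(config))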
john_toolbox/train/transformers/from_scratch/dataset.py
281 additions & 0 deletions
@@ -0,0 +1,281 @@
import logging
from typing import Any

import torch
from torch.utils.data import Dataset

LOGGER = logging.getLogger(__name__)


class BilingualDataset(Dataset):
""" | ||
A Dataset class for handling bilingual data suitable for transformer architectures. | ||
It processes source and target language pairs for sequence-to-sequence tasks, such as | ||
machine translation. The class prepares the data by tokenizing, adding special tokens | ||
(SOS, EOS, PAD), and creating attention masks. | ||
Parameters | ||
---------- | ||
ds : Dataset | ||
The original dataset containing pairs of sentences in two languages. | ||
tokenizer_src : Tokenizer | ||
The tokenizer for the source language. | ||
tokenizer_tgt : Tokenizer | ||
The tokenizer for the target language. | ||
src_lang : str | ||
The source language code (e.g., 'en' for English). | ||
tgt_lang : str | ||
The target language code (e.g., 'fr' for French). | ||
seq_len : int | ||
The fixed sequence length for the model input. Longer sequences will be truncated, | ||
and shorter ones will be padded. | ||
Attributes | ||
---------- | ||
ds : Dataset | ||
Stores the original dataset. | ||
tokenizer_src : Tokenizer | ||
Tokenizer for the source language. | ||
tokenizer_tgt : Tokenizer | ||
Tokenizer for the target language. | ||
src_lang : str | ||
Source language code. | ||
tgt_lang : str | ||
Target language code. | ||
seq_len : int | ||
Methods | ||
------- | ||
__len__() | ||
Returns the length of the dataset. | ||
__getitem__(index) | ||
Returns a preprocessed item from the dataset at the specified index. | ||
Examples | ||
-------- | ||
>>> from transformers import BertTokenizer | ||
>>> tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') | ||
>>> bilingual_data = BilingualDataset(ds, tokenizer, tokenizer, 'en', 'fr', 128) | ||
>>> print(bilingual_data[0]) | ||
""" | ||

    def __init__(
        self,
        ds,
        tokenizer_src,
        tokenizer_tgt,
        src_lang,
        tgt_lang,
        seq_len,
    ) -> None:
        """
        Initializes the BilingualDataset object by setting up tokenizers and special tokens.
        Parameters are the same as described in the class documentation.
        """
        super().__init__()
        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang
        self.seq_len = seq_len

        # start of sequence
        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        # end of sequence
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        # padding token: sentences do not all contain the same number of words, so pad tokens
        # are added to bring every sequence to the same length
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        """
        Returns the length of the dataset.

        Returns
        -------
        int
            The number of items in the dataset.
        """
        return len(self.ds)

    def __getitem__(self, index) -> Any:
        """
        Retrieves an item from the dataset and preprocesses it for transformer input.
        The method tokenizes the source and target sentences, adds special tokens (SOS, EOS),
        pads the sequences to a fixed length, and creates attention masks.

        Parameters
        ----------
        index : int
            The index of the item to be retrieved from the dataset.

        Returns
        -------
        dict
            A dictionary containing the preprocessed data for transformer input.

        Raises
        ------
        ValueError
            If the source sentence is longer than `seq_len - 2` or the target sentence
            is longer than `seq_len - 1`.
        """

        # Extract the source and target sentences from the dataset
        src_target_pair = self.ds[index]
        src_text = src_target_pair["translation"][self.src_lang]
        tgt_text = src_target_pair["translation"][self.tgt_lang]

        # Tokenize the source and target sentences
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # Calculate the number of padding tokens for the encoder input.
        # The '2' accounts for both the SOS and EOS tokens added to the sequence.
        # This ensures that the total length of the encoder input equals 'seq_len'.
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2
        # Calculate the number of padding tokens for the decoder input.
        # The '1' accounts for the SOS token added to the start of the sequence.
        # The EOS token is not included in the decoder input, as it is used as part of the output label.
        # This ensures that the total length of the decoder input equals 'seq_len'.
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1
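        # Worked example (illustrative numbers, not from the commit): with seq_len = 400
        # and a 10-token source sentence, enc_num_padding_tokens = 400 - 10 - 2 = 388;
        # with a 12-token target sentence, dec_num_padding_tokens = 400 - 12 - 1 = 387.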

        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            LOGGER.error(len(enc_input_tokens))
            LOGGER.error(len(dec_input_tokens))
            LOGGER.error(self.seq_len)
            raise ValueError("Sentence too long.")

        # Concatenate SOS, EOS, and padding tokens with the tokenized source sentence
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Concatenate SOS and padding tokens with the tokenized target sentence for decoder input
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Create the label for training by adding EOS and padding tokens to the target tokens
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Ensure all tensors have the same sequence length as defined
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        # Create the encoder mask for the attention mechanism.
        # This mask is a binary tensor indicating where the padding tokens are.
        # The mask has '1's where tokens are not padding and '0's where they are padding.
        # This allows the transformer's attention mechanism to ignore padding tokens.
        #
        # Example:
        # If the encoder input is [SOS, 15, 234, 67, EOS, PAD, PAD, PAD] (assuming 'seq_len' is 8),
        # then the encoder_mask would be [1, 1, 1, 1, 1, 0, 0, 0].
        # This indicates to the model that it should only attend to the first five tokens.
        encoder_mask = (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int()

        # Create the decoder mask for the attention mechanism in the transformer.
        # This mask serves two purposes:
        # 1. It masks out padding tokens (similar to the encoder mask).
        # 2. It prevents the decoder from 'seeing' future tokens in the sequence,
        #    ensuring that the prediction for each token only depends on previous tokens.
        #
        # The mask is created by combining a padding mask and a causal mask.
        # Example:
        # If the decoder input is [SOS, 56, 78, 102, PAD, PAD, PAD, PAD] (assuming 'seq_len' is 8),
        # then the padding mask would be [1, 1, 1, 1, 0, 0, 0, 0].
        # The causal mask for this length would be a lower triangular matrix of size 8x8,
        # allowing each token to attend only to itself and preceding tokens.
        # The final decoder mask is the combination of these two masks:
        # tensor([[[1, 0, 0, 0, 0, 0, 0, 0],
        #          [1, 1, 0, 0, 0, 0, 0, 0],
        #          [1, 1, 1, 0, 0, 0, 0, 0],
        #          [1, 1, 1, 1, 0, 0, 0, 0],
        #          [1, 1, 1, 1, 0, 0, 0, 0],
        #          [1, 1, 1, 1, 0, 0, 0, 0],
        #          [1, 1, 1, 1, 0, 0, 0, 0],
        #          [1, 1, 1, 1, 0, 0, 0, 0]]], dtype=torch.int32)
        decoder_mask = (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(
            0
        ).int() & causal_mask(size=decoder_input.size(0))

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": encoder_mask,  # (1, 1, seq_len)
            "decoder_mask": decoder_mask,  # (1, 1, seq_len) & (1, seq_len, seq_len) -> (1, seq_len, seq_len)
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }


def causal_mask(size):
    """
    Generates a causal mask to prevent attention to future tokens in a sequence.
    This function creates a mask for use in the self-attention mechanism of a transformer's
    decoder, ensuring that each position in the sequence can only attend to itself and
    positions before it. This maintains the autoregressive property of the decoder.

    Parameters
    ----------
    size : int
        The size of the sequence.

    Returns
    -------
    torch.Tensor
        A boolean tensor of shape (1, size, size) that is True at and below the diagonal
        (positions that may be attended to) and False above it.

    Example
    -------
    For a size of 4, the returned causal mask would be:
    [[1, 0, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 0],
     [1, 1, 1, 1]]
    The function uses `torch.triu` to create an upper triangular matrix, which is then inverted.
    """

    # Create an upper triangular matrix using torch.triu.
    # The 'diagonal=1' parameter sets the elements above the main diagonal to 1
    # (the main diagonal elements are set to 0).
    # This matrix represents which positions should initially be ignored (set to 1).
    # Example output for size=4:
    # [[0., 1., 1., 1.],
    #  [0., 0., 1., 1.],
    #  [0., 0., 0., 1.],
    #  [0., 0., 0., 0.]]
    mask = torch.triu(torch.ones(1, size, size), diagonal=1).type(torch.int)

    # Invert the mask to create the causal mask.
    # After inversion, the positions that were 0 (current and past positions) become True,
    # allowing the model to attend to these positions.
    # The inverted (causal) mask for size=4 would be:
    # [[1, 0, 0, 0],
    #  [1, 1, 0, 0],
    #  [1, 1, 1, 0],
    #  [1, 1, 1, 1]]
    # Here, '1' (True) indicates positions that are allowed to be attended to.
    return mask == 0
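Below is a minimal usage sketch (not part of the commit) showing how BilingualDataset might be exercised end to end. It assumes the Hugging Face tokenizers package for a toy word-level tokenizer shared by both languages; the sentence pair, special-token list, and seq_len value are illustrative only.

from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.trainers import WordLevelTrainer

from john_toolbox.train.transformers.from_scratch.dataset import BilingualDataset

# Toy corpus standing in for the opus_books dataset; illustrative only.
ds = [{"translation": {"en": "the cat sleeps", "fr": "le chat dort"}}]

# Train one tiny word-level tokenizer and reuse it for both languages.
tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
tokenizer.pre_tokenizer = Whitespace()
trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"])
sentences = [pair["translation"][lang] for pair in ds for lang in ("en", "fr")]
tokenizer.train_from_iterator(sentences, trainer=trainer)

dataset = BilingualDataset(ds, tokenizer, tokenizer, "en", "fr", seq_len=16)
sample = dataset[0]

print(sample["encoder_input"].shape)  # torch.Size([16])
print(sample["encoder_mask"].shape)   # torch.Size([1, 1, 16])
print(sample["decoder_mask"].shape)   # torch.Size([1, 16, 16])
print(sample["label"].shape)          # torch.Size([16])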