In [1]:
import torch
from torch.utils.data import Dataset as TorchDataset
from datasets import load_dataset
import numpy as np
import random
from transformers import EsmConfig, EsmForMaskedLM, EsmTokenizer, DataCollatorForLanguageModeling, TrainingArguments, Trainer, EarlyStoppingCallback, EvalPrediction
from transformers.data.data_collator import DataCollatorMixin, _torch_collate_batch
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
from transformers.trainer import has_length
from transformers.utils import is_datasets_available
from transformers.trainer_pt_utils import LengthGroupedSampler, RandomSampler
import argparse
from torch import nn

In [None]:
def parse_args():

    parser = argparse.ArgumentParser()

    # Data arguments
    parser.add_argument('--data_path', type=str, required=True, help='Dataset path')
    parser.add_argument('--tokenizer_path', type=str, required=True, help='Tokenizer path')
    parser.add_argument('--save_path', type=str, default='./best_model', help='Save Path')
    parser.add_argument('--weight_path', type=str, default='./best_model', help='Weight Save Path')

    # Model arguments
    parser.add_argument('--hidden_size', type=int, default=512, help='Embedding dimension')
    parser.add_argument('--num_hidden_layers', type=int, default=24, help='Number of transformer blocks')
    parser.add_argument('--seed', type=int, default=338, help='Random seed for reproducibility')
    parser.add_argument('--model_path', type=str, required=True, help='Model path')

    # Training arguments
    parser.add_argument('--wandb', action='store_true', help='Use Weights & Biases for logging')
    parser.add_argument('--batch_size', type=int, default=32, help='Batch size')
    parser.add_argument('--num_epochs', type=int, default=1, help='Number of epochs')
    parser.add_argument('--learning_rate', type=float, default=1e-4, help='Learning rate')
    parser.add_argument('--mlm_probability', type=str, choices=['0.05', '0.15', '0.3', '0.5', 'mixed'], default='mixed',
                        help='MLM probability')
    parser.add_argument('--patience', type=int, default=3, help='Early stopping patience')
    parser.add_argument('--group_by_length', action='store_true', default=True, help='Group by length for SortedTrainer')
    parser.add_argument('--grad_accum', type=int, default=1, help='Gradient Accumulation')
    parser.add_argument('--evaluation_strategy', type=str, default='steps', help='Evaluation strategy to use')
    parser.add_argument('--lr_scheduler_type', type=str, default='cosine', help='Type of learning rate scheduler')
    parser.add_argument('--optim', type=str, default='adamw_torch', help='Optimizer')
    parser.add_argument('--log_path', type=str, help='Path for logging')
    parser.add_argument('--logging_steps', type=int, default=10, help='Step interval for logging')
    parser.add_argument('--eval_steps', type=int, default=50, help='Step interval for evaluation')
    parser.add_argument('--weight_decay', type=float, default=0.01, help='Weight decay rate')
    parser.add_argument('--warmup_steps', type=int, default=10, help='Number of warmup steps')
    parser.add_argument('--effective_batch_size', type=int, default=1000000, help='Effective batch size for training')
    parser.add_argument('--save_total_limit', type=int, default=5, help='Maximum number of checkpoints to keep')
    parser.add_argument('--load_best_model_at_end', action='store_true', help='Load the best model at the end of training')
    parser.add_argument('--greater_is_better', action='store_true', help='Determines if a greater metric signifies a better model')

    # Output arguments
    parser.add_argument('--output_dir', type=str, default='output', help='Output directory')

    args = parser.parse_args()

    return args

args = parse_args()


In [2]:
class args:
    data_path = 'nikraf/uniref128-256AA'
    tokenizer_path = 'facebook/esm2_t30_150M_UR50D'
    model_path = 'facebook/esm2_t6_8M_UR50D'
    save_path = './best_model'
    weight_path = './best_model'
    hidden_size = 512
    num_hidden_layers = 24
    seed = 338
    wandb = True
    batch_size = 1
    num_epochs = 1
    lr = 1e-4
    mlm_probability = 'mixed'
    patience = 3
    group_by_length = True
    grad_accum = 1
    evaluation_strategy = 'steps'
    lr_scheduler_type = 'cosine'
    optim = 'adamw_torch'
    log_path = './mlmlog.txt'
    logging_steps = 10
    eval_steps = 50
    weight_decay = 0.01
    warmup_steps = 10
    effective_batch_size = 1000000
    valid_size = 5000
    max_length = 512
    fp16 = False
    save_total_limit = 5
    load_best_model_at_end = True
    greater_is_better = True
    output_dir = 'output'

In [3]:
dataset = load_dataset(args.data_path)['train']


class Dataset(TorchDataset):
    
    def __init__(self, dataset):
        self.seqs = dataset['seqs']
        self.lengths = [len(seq) for seq in self.seqs]

    def __len__(self):
        return len(self.seqs)
    
    def __avg__(self):
        return sum(self.lengths) / len(self.lengths)

    def __getitem__(self, idx):
        seq = self.seqs[idx]
        return seq
    
train_dataset = Dataset(dataset)


In [4]:
class EsmForMLM(EsmForMaskedLM):

    def __init__(self, model_path, hidden_size, num_hidden_layers, seed):
        torch.manual_seed(seed)
        self.model_path = model_path
        config = EsmConfig.from_pretrained(model_path)
        config.hidden_size = hidden_size
        config.num_hidden_layers = num_hidden_layers
        config.num_attention_heads = config.hidden_size // 32
        super().__init__(config)

        self.apply(self._init_weights)

    def _init_weights(self, module):

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


In [5]:
class DataCollatorForMixedMLM(DataCollatorMixin):
    
    def __init__(self, tokenizer: EsmTokenizer, mlm: bool = True, return_tensors: str = 'pt', pad_to_multiple_of: Optional[int] = None):
        self.tokenizer = tokenizer
        self.mlm = mlm
        self.return_tensors = return_tensors
        self.pad_to_multiple_of = pad_to_multiple_of


    def torch_call(self, input): # input is list of examples from dataset
        batch = self.tokenizer(input, return_tensors='pt', padding='longest', truncation=False, add_special_tokens=False)
        if self.mlm:
            batch['input_ids'], labels = self.torch_mask_tokens(batch['input_ids'])
            batch['labels'] = labels # the keys here, need to match the keys in the model
        return batch # in here are input_ids, attention_mask, and labels as keys in dictionary

    def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 90% MASK, 5% random, 5% original.
        """
        labels = inputs.clone()
        #probability_matrix = torch.full(labels.shape, random.uniform(self.min_prob, self.max_prob))
        probability_matrix = torch.normal(mean=0.3, std=0.12, size=labels.shape)
        probability_matrix = torch.clamp(probability_matrix, min=0.0, max=1.0)

        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens
        # 85% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.90)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)
        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced # 0.5 because half of remaining are random
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]
        # The rest of the time (5% of the time) we keep the masked input tokens unchanged
        return inputs, labels
    

class DataCollatorForMLM(DataCollatorForLanguageModeling):

    def __init__(self, tokenizer, mlm_probability=0.15, **kwargs):
        super().__init__(tokenizer=tokenizer, mlm=True, mlm_probability=mlm_probability, **kwargs)

    def torch_mask_tokens(self, inputs: Any, special_tokens_mask: Optional[Any] = None) -> Tuple[Any, Any]:
        """
        Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original.
        """

        labels = inputs.clone()
        # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(labels.shape, 0.9)).bool() & masked_indices
        inputs[indices_replaced] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels


In [6]:
if args.effective_batch_size > args.batch_size: # effective batch_size is in tokens
# for MLM, probably 1e6 tokens and learning rate of 1e-4 is good
# or 1e5 tokens and lr of 1e-5
    num_devices = torch.cuda.device_count() if torch.cuda.device_count() > 1 else 1 # for cpu
    avg_length = train_dataset.__avg__()
    args.grad_accum = int((args.effective_batch_size / avg_length) / (args.batch_size * num_devices))
    args.grad_accum = args.grad_accum if args.grad_accum > 0 else 1
    if args.grad_accum == 1:
        args.effective_batch_size = avg_length * args.batch_size * num_devices

    print('\n-----Batching Summary-----\n')
    print(f'Number of devices: {num_devices}')
    print(f'Average sequence length: {avg_length}')
    print(f'Local batch size: {args.batch_size} seqs')
    print(f'Gradient accumulation: {args.grad_accum}')
    print(f'Effective batch size: {int(args.effective_batch_size)} tokens')
else:
    args.grad_accum = 1


-----Batching Summary-----

Number of devices: 1
Average sequence length: 184.8241525771167
Local batch size: 1 seqs
Gradient accumulation: 5410
Effective batch size: 1000000 tokens


In [7]:
def compute_metrics(p: EvalPrediction):
    logits = p.predictions[0] if isinstance(p.predictions, tuple) else p.predictions
    labels = p.label_ids
    logits = np.array(logits)
    labels = np.array(labels)
    preds = np.argmax(logits, axis=-1)
    valid_indices = (labels != -100)
    valid_preds = preds[valid_indices]
    valid_labels = labels[valid_indices]
    accuracy = np.mean(valid_preds == valid_labels)
    return {'mlm_accuracy': accuracy}

def log_metrics(config, metrics, header=None): # need a log_path in args, needs to be txt file
    def log_nested_dict(d, parent_key=''):
        filtered_results = {}
        for k, v in d.items():
            new_key = f'{parent_key}_{k}' if parent_key else k
            if isinstance(v, dict):
                filtered_results.update(log_nested_dict(v, new_key))
            elif 'runtime' not in k or 'second' not in k:
                filtered_results[new_key] = round(v, 5) if isinstance(v, (float, int)) else v
        return filtered_results

    filtered_results = log_nested_dict(metrics)

    with open(config.log_path, 'a') as f:
        if header is not None:
            f.write(header + '\n')
        for k, v in filtered_results.items():
            f.write(f'{k}: {v}\n')
        f.write('\n')

In [10]:
def get_data_collator(mlm_probability, tokenizer):

    if mlm_probability == 'mixed':
        return DataCollatorForMixedMLM(tokenizer=tokenizer)
    else:
        mlm_probability = float(mlm_probability)
        return DataCollatorForMLM(tokenizer=tokenizer, mlm_probability=mlm_probability)

model = EsmForMLM(args.model_path,hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers, seed=args.seed)
tokenizer = EsmTokenizer.from_pretrained(args.tokenizer_path)
data_collator = get_data_collator(args.mlm_probability, tokenizer)


In [11]:
class SortedTrainer(Trainer):
    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
        if self.train_dataset is None or not has_length(self.train_dataset):
            return None
        if self.args.group_by_length:
            if is_datasets_available() and isinstance(self.train_dataset, TorchDataset):
                lengths = self.train_dataset.lengths # this requires your dataset has a self.lengths with the lengths in it
            else:
                lengths = None
            model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
            return LengthGroupedSampler(
                self.args.train_batch_size * self.args.gradient_accumulation_steps,
                dataset=self.train_dataset,
                lengths=lengths,
                model_input_name=model_input_name,
            )
        else:
            return RandomSampler(self.train_dataset)

training_args = TrainingArguments(
    report_to='wandb' if args.wandb else None,
    output_dir=args.save_path,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    gradient_accumulation_steps=args.grad_accum,
    logging_steps=args.logging_steps,
    evaluation_strategy='steps',
    eval_steps=args.eval_steps,
    num_train_epochs=args.num_epochs,
    weight_decay=args.weight_decay,
    warmup_steps=args.warmup_steps,
    lr_scheduler_type='cosine',
    learning_rate=args.lr,
    optim='adamw_torch',
    seed=args.seed,
    data_seed=args.seed,
    save_steps=args.eval_steps,
    save_total_limit=3,
    load_best_model_at_end=True,
    greater_is_better=False,
    fp16=args.fp16,
    group_by_length=True,
)


trainer = SortedTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=args.patience)]  # usually 3 - 5
)

trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mnikraf99[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011277777777932999, max=1.0…

  0%|          | 0/90 [00:00<?, ?it/s]

KeyboardInterrupt: 