In [None]:
from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer
import torch
from datasets import load_dataset

# Load the dataset first: the label vocabulary fixes the size of the
# classification head, so it must be known before the model is created.
# NOTE(review): "your_dataset_name" is a placeholder. Everything below assumes
# a DatasetDict with "train" and "validation" splits, each with an
# "input_text" column (one sequence per row) and a "labels" column (one label
# per residue of that sequence) -- confirm against the real dataset.
dataset = load_dataset("your_dataset_name")
splits = ("train", "validation")
raw_labels = {split: list(dataset[split]["labels"]) for split in splits}

# Convert the labels to numerical values. The vocabulary is the sorted union
# over all splits, so the mapping is deterministic and also covers labels that
# only occur in the validation split.
label_vocab = set()
for split_labels in raw_labels.values():
    for example_labels in split_labels:
        label_vocab.update(example_labels)
label_list = sorted(label_vocab)
label_map = {label: i for i, label in enumerate(label_list)}

# Load the pre-trained ESM-2 model and tokenizer.
# ESM-2 checkpoints on the Hub follow the esm2_t<layers>_<params>_UR50D naming
# scheme (8M up to 15B parameters); this is the 650M-parameter checkpoint.
model_name = "facebook/esm2_t33_650M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# num_labels sizes the token-classification head to the label vocabulary;
# without it the head falls back to the configuration's default label count.
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=len(label_list),
)

# Label value that PyTorch's cross-entropy loss skips by default.
IGNORE_INDEX = -100


def _align_labels(label_ids, special_tokens_mask):
    """Spread per-residue label ids over the positions of a tokenized sequence.

    Positions flagged in ``special_tokens_mask`` (special and padding tokens)
    receive ``IGNORE_INDEX``; every other position consumes the next label id.
    Surplus label ids (sequence was truncated) are dropped, and missing ones
    are filled with ``IGNORE_INDEX``, so the result always has exactly one
    entry per token.
    """
    remaining = iter(label_ids)
    return [
        IGNORE_INDEX if is_special else next(remaining, IGNORE_INDEX)
        for is_special in special_tokens_mask
    ]


# Tokenize each split separately and build the matching label lists, both
# keyed by split name.
tokenized_inputs = {}
labels = {}
for split in splits:
    encoded = tokenizer(
        list(dataset[split]["input_text"]),
        truncation=True,
        padding=True,
        return_special_tokens_mask=True,
    )
    special_masks = encoded["special_tokens_mask"]
    # The mask is only needed for label alignment; keep it out of the model
    # inputs.
    tokenized_inputs[split] = {
        key: value
        for key, value in encoded.items()
        if key != "special_tokens_mask"
    }
    labels[split] = [
        _align_labels([label_map[label] for label in example_labels], mask)
        for example_labels, mask in zip(raw_labels[split], special_masks)
    ]

# Create a PyTorch dataset
class MyDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing tokenizer output columns with label rows.

    ``inputs`` is a mapping from field name to an indexable column of
    per-example values; ``labels`` is an indexable collection with one entry
    per example. Each item is a ``(features, target)`` tuple of tensors.
    """

    def __init__(self, inputs, labels):
        self.labels = labels
        self.inputs = inputs

    def __len__(self):
        # The number of examples is defined by the label collection.
        return len(self.labels)

    def __getitem__(self, idx):
        # Convert row ``idx`` of every input column into a tensor.
        features = {}
        for name, column in self.inputs.items():
            features[name] = torch.tensor(column[idx])
        target = torch.tensor(self.labels[idx])
        return features, target

# Build one MyDataset per split; both `tokenized_inputs` and `labels` are
# indexed by split name here.
train_dataset, eval_dataset = (
    MyDataset(tokenized_inputs[split_name], labels[split_name])
    for split_name in ("train", "validation")
)

# Training configuration handed to the Trainer.
# NOTE(review): newer transformers releases spell `evaluation_strategy` as
# `eval_strategy` -- confirm against the installed version before renaming.
train_args = TrainingArguments(
    # Output and logging locations.
    output_dir="./results",
    logging_dir="./logs",
    logging_steps=10,
    push_to_hub=False,
    # Evaluation schedule.
    evaluation_strategy="epoch",
    # Optimisation hyperparameters.
    num_train_epochs=3,
    learning_rate=1e-5,
    weight_decay=0.01,
    # Per-device batch sizes.
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
)

def collate_batch(data):
    """Stack a list of ``(features, target)`` items into one batch dict.

    Every item is expected to carry ``input_ids`` and ``attention_mask``
    tensors in its feature dict; all tensors for a given field must share a
    shape so they can be stacked along a new leading dimension.
    """
    input_ids = [features["input_ids"] for features, _ in data]
    attention_masks = [features["attention_mask"] for features, _ in data]
    targets = [target for _, target in data]
    return {
        "input_ids": torch.stack(input_ids),
        "attention_mask": torch.stack(attention_masks),
        "labels": torch.stack(targets),
    }


# Instantiate a Trainer object and fine-tune the model
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=collate_batch,
)

trainer.train()