In [None]:
# in case you decide to run on Google Colab (uncomment the next line)
# !pip install -qU datasets lightning peft

In [None]:
import os
import pandas as pd

from typing import Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

from sklearn.metrics import f1_score, precision_score, recall_score

import pytorch_lightning as L

from peft import get_peft_model, LoraConfig

from transformers import AutoModel, AutoTokenizer
from transformers.optimization import get_constant_schedule_with_warmup

# Synchronous CUDA kernel launches give precise tracebacks for device-side
# errors but slow down every training step. Enable only while debugging; the
# variable has to be set before CUDA is initialised.
CUDA_DEBUG = False
if CUDA_DEBUG:
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

# Fall back to CPU so the notebook can still be run end-to-end without a GPU.
ACCELERATOR_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

I stored the ZIP files containing the original ChEMU dataset in my Google Drive, so the following (commented-out) code mounts the drive and unzips the files. Uncomment it when running on Google Colab.

In [None]:
# from google.colab import drive

# drive.mount('/content/drive')

# %cd drive/MyDrive/WEG-VENT

# !unzip -n ./data/chemu.ee.dev.zip -d ./data
# !unzip -n ./data/chemu.ee.train.zip -d ./data

The following lines let you view the TensorBoard logs during training. The `logs` directory is created in the current working directory.

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir logs

### Dataset

In [None]:
class CHEMUDataset(Dataset):
    """Token-classification dataset built from ChEMU brat-style annotation files.

    Each sample is one ``<name>.txt`` document paired with its ``<name>.ann``
    file, both read from the directory ``f"{data_dir_path}_{split}"``. Entity
    spans (``T`` lines) are turned into per-token labels through
    ``label_to_idx``. Tokens outside any entity get label ``0``, and the two
    special tokens placed at the ends of every document get label ``-1``.
    Relation lines (``R``) are parsed but not used to build the targets.
    """

    # "0" is the label for tokens outside any entity span.
    label_to_idx = {
        "0": 0,
        "STARTING_MATERIAL": 1,
        "REAGENT_CATALYST": 2,
        "REACTION_PRODUCT": 3,
        "SOLVENT": 4,
        "OTHER_COMPOUND": 5,
        "TIME": 6,
        "TEMPERATURE": 7,
        "YIELD_PERCENT": 8,
        "YIELD_OTHER": 9,
        "EXAMPLE_LABEL": 10,
        "REACTION_STEP": 11,
        "WORKUP": 12,
    }

    def __init__(
        self,
        tokenizer="bert-base-uncased",
        max_length=512,
        split="train",
        data_dir_path="../data/ee",
        verbose=False,
    ):
        """Load and tokenize every annotated document of one split.

        Args:
            tokenizer: Model id or path passed to ``AutoTokenizer.from_pretrained``.
            max_length: Maximum sequence length. Longer documents are kept here
                and truncated in :meth:`collate_fn`.
            split: ``"train"`` or ``"dev"``.
            data_dir_path: Path prefix; files are read from
                ``f"{data_dir_path}_{split}"``.
            verbose: If True, report documents longer than ``max_length``.

        Raises:
            ValueError: If ``split`` is not ``"train"`` or ``"dev"``.
        """
        if split not in ("train", "dev"):
            raise ValueError("split must be one of 'train' or 'dev'")

        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.max_length = max_length
        self.split = split
        self.verbose = verbose
        self.data_dir_path = data_dir_path

        self.files = sorted(
            os.path.splitext(name)[0]
            for name in os.listdir(self._split_dir())
            if name.endswith(".txt")
        )
        print(f"Found {len(self.files)} files.")

        self.tokens = []
        self.targets = []
        self.attention_masks = []
        for file in self.files:
            try:
                file_tokens, file_attention_mask, file_targets = self.load_file(file)
            except Exception as e:
                # Best effort: report and skip documents that cannot be parsed
                # (missing .ann, overlapping or unknown entities, ...).
                print(f"Error loading file {file}: {e}")
                continue
            if self.verbose and len(file_tokens) > self.max_length:
                print(
                    f"File {file} exceeds max length ({len(file_tokens)} > {self.max_length})"
                )
            self.tokens.append(file_tokens)
            self.targets.append(file_targets)
            self.attention_masks.append(file_attention_mask)

    def _split_dir(self):
        """Directory that holds this split's ``.txt``/``.ann`` files."""
        return f"{self.data_dir_path}_{self.split}"

    def load_file(self, file):
        """Read one document and its annotations and convert them to tensors.

        Args:
            file: File stem (name without extension) inside the split directory.

        Returns:
            ``(tokens, attention_mask, targets)``: 1-D tensors of equal length
            with dtypes ``long``, ``int`` and ``long``. The mask is all ones.
        """
        base_path = os.path.join(self._split_dir(), file)
        # newline="" keeps "\r\n" intact so character offsets in the .ann file
        # still line up if a document uses Windows line endings.
        with open(f"{base_path}.txt", encoding="utf-8", newline="") as f:
            text = f.read()
        with open(f"{base_path}.ann", encoding="utf-8") as f:
            annotations = f.readlines()

        ner_ann, ee_ann = self.generate_dataframe(annotations)
        tokens, targets = self.structure_data(text, ner_ann, ee_ann)
        attention_mask = torch.ones(len(tokens), dtype=torch.int)
        return (
            torch.tensor(tokens, dtype=torch.long),
            attention_mask,
            torch.tensor(targets, dtype=torch.long),
        )

    def generate_dataframe(self, annotations):
        """Parse brat ``.ann`` lines into entity and relation DataFrames.

        ``T`` lines become rows of ``ner_ann`` with columns ``id``, ``label``,
        ``start``, ``end`` and ``text``. ``label`` is mapped through
        ``label_to_idx`` and rows are sorted by ``start``. ``R`` lines become
        rows of ``ee_ann`` with columns ``id``, ``Arg1`` and ``Arg2``. All other
        lines are ignored.

        Raises:
            ValueError: If there are no entities, an entity label is not in
                ``label_to_idx``, or two entity spans overlap.
        """
        entities = []
        relations = []
        for line in annotations:
            if line.startswith("T"):
                ann_id, type_and_span, span_text = line.split("\t")[:3]
                label, start, end = type_and_span.split(" ")[:3]
                entities.append(
                    {
                        "id": ann_id,
                        "label": label,
                        "start": int(start),
                        "end": int(end),
                        "text": span_text.strip(),
                    }
                )
            elif line.startswith("R"):
                fields = line.split("\t")
                args = fields[1].split(" ")
                relations.append(
                    {
                        "id": fields[0],
                        "Arg1": args[1].split(":")[1],
                        "Arg2": args[2].split(":")[1].replace("\n", ""),
                    }
                )

        if not entities:
            raise ValueError("No entity annotations found.")

        ner_ann = pd.DataFrame(entities)
        ee_ann = pd.DataFrame(relations)

        # An unmapped label would become NaN and then an invalid class index.
        unknown = sorted(set(ner_ann["label"]) - set(self.label_to_idx))
        if unknown:
            raise ValueError(f"Unknown entity labels: {unknown}")
        ner_ann["label"] = ner_ann["label"].map(self.label_to_idx)

        ner_ann = ner_ann.sort_values(by=["start"]).reset_index(drop=True)
        starts = ner_ann["start"].to_numpy()
        ends = ner_ann["end"].to_numpy()
        if (starts[1:] < ends[:-1]).any():
            raise ValueError("Overlapping entities found in the annotations.")

        return ner_ann, ee_ann

    def structure_data(self, text, ner_ann, ee_ann):
        """Tokenize ``text`` piecewise so that every token inherits its span's label.

        The text is cut at entity boundaries. Each gap between entities is
        tokenized and labelled ``0``; each entity span is tokenized and
        labelled with its class index. The special tokens the tokenizer adds
        to an empty string are placed at both ends with label ``-1``.
        ``ee_ann`` is accepted but not used.

        Returns:
            ``(tokens, targets)``: two Python lists of equal length.
        """
        # Assumes the tokenizer adds exactly two special tokens (e.g. BERT's
        # [CLS] ... [SEP]).
        bos_token, eos_token = self.tokenizer("").input_ids

        def encode(fragment):
            return self.tokenizer(fragment, add_special_tokens=False).input_ids

        tokens = [bos_token]
        targets = [-1]
        cursor = 0
        for start, end, label in zip(
            ner_ann["start"].to_numpy(),
            ner_ann["end"].to_numpy(),
            ner_ann["label"].to_numpy(),
        ):
            gap_ids = encode(text[cursor:start])
            tokens.extend(gap_ids)
            targets.extend([0] * len(gap_ids))

            span_ids = encode(text[start:end])
            tokens.extend(span_ids)
            targets.extend([int(label)] * len(span_ids))
            cursor = end

        tail_ids = encode(text[cursor:])
        tokens.extend(tail_ids)
        targets.extend([0] * len(tail_ids))

        tokens.append(eos_token)
        targets.append(-1)
        return tokens, targets

    def __len__(self):
        return len(self.tokens)

    def __getitem__(self, idx):
        """Return ``(tokens, attention_mask, targets)`` for document ``idx``."""
        return self.tokens[idx], self.attention_masks[idx], self.targets[idx]

    def collate_fn(self, batch):
        """Pad a list of ``(tokens, attention_mask, targets)`` samples into a batch.

        Tokens are padded with the tokenizer's pad id, masks with 0 and targets
        with -1. All three are then cut to ``max_length`` columns, so longer
        documents lose their trailing tokens, including the closing special
        token.
        """
        tokens, attention_masks, targets = zip(*batch)
        pad = torch.nn.utils.rnn.pad_sequence
        tokens = pad(tokens, batch_first=True, padding_value=self.tokenizer.pad_token_id)
        targets = pad(targets, batch_first=True, padding_value=-1)
        attention_masks = pad(attention_masks, batch_first=True, padding_value=0)

        return {
            "input_ids": tokens[:, : self.max_length],
            "attention_mask": attention_masks[:, : self.max_length],
            "labels": targets[:, : self.max_length],
        }

    def get_class_weights(self):
        """Inverse-frequency class weights over all non-ignored (``!= -1``) targets.

        Returns:
            A float tensor with one entry per class in ``label_to_idx``. The
            weights of classes that occur are normalised to a mean of 1.
            Classes that never occur get weight 0; they cannot appear as
            targets, and this avoids the ``inf``/``nan`` that ``1 / 0`` would
            produce.
        """
        all_targets = torch.cat(self.targets)
        all_targets = all_targets[all_targets != -1]
        # minlength keeps one weight per class even if the last classes are absent.
        class_counts = torch.bincount(all_targets, minlength=len(self.label_to_idx))
        present = class_counts > 0
        inverse = 1.0 / class_counts[present].float()
        class_weights = torch.zeros(len(class_counts))
        class_weights[present] = inverse / inverse.mean()
        return class_weights

### Data Module

In [None]:
class CHEMUDataModule(L.LightningDataModule):
    """LightningDataModule that builds ``CHEMUDataset`` splits and their DataLoaders.

    ``CHEMUDataset`` only accepts the ``"train"`` and ``"dev"`` splits, so the
    ``"test"`` stage is served from the dev split.

    Args:
        tokenizer: Tokenizer id or path forwarded to ``CHEMUDataset``.
        max_length: Maximum sequence length forwarded to ``CHEMUDataset``.
        batch_size: Batch size of every DataLoader.
        num_workers: Number of DataLoader worker processes.
        data_dir: Path prefix forwarded to ``CHEMUDataset`` as ``data_dir_path``;
            split files are read from ``f"{data_dir}_{split}"``. The default
            equals ``CHEMUDataset``'s own default.
        pin_memory: Forwarded to every DataLoader.
        drop_last: Forwarded to the training DataLoader only.
    """

    def __init__(
        self,
        tokenizer: str = "bert-base-uncased",
        max_length: int = 512,
        batch_size: int = 8,
        num_workers: int = 4,
        data_dir: str = "../data/ee",
        pin_memory: bool = True,
        drop_last: bool = False,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data_dir = data_dir
        self.pin_memory = pin_memory
        self.drop_last = drop_last

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def prepare_data(self):
        """Download the tokenizer files (if not cached yet)."""
        AutoTokenizer.from_pretrained(self.tokenizer)

    def _build_dataset(self, split: str):
        """Create a ``CHEMUDataset`` for ``split`` with this module's settings."""
        return CHEMUDataset(
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            split=split,
            data_dir_path=self.data_dir,
        )

    def setup(self, stage: Optional[str] = None):
        """Create the datasets needed for ``stage`` (all of them when ``None``).

        Datasets that already exist are reused. A manual ``setup()`` call
        followed by the Trainer's own call therefore does not tokenize the
        corpus twice.
        """
        if stage in ("fit", None) and self.train_dataset is None:
            self.train_dataset = self._build_dataset("train")
        if stage in ("fit", "validate", None) and self.val_dataset is None:
            self.val_dataset = self._build_dataset("dev")
        if stage in ("test", None) and self.test_dataset is None:
            # Test uses the dev split as well: reuse it when it is already built.
            self.test_dataset = (
                self.val_dataset
                if self.val_dataset is not None
                else self._build_dataset("dev")
            )

    def _make_loader(self, dataset, shuffle: bool, drop_last: bool):
        """Wrap ``dataset`` in a DataLoader using the dataset's own collate function."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=dataset.collate_fn,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            drop_last=drop_last,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True, drop_last=self.drop_last)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False, drop_last=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False, drop_last=False)

    def get_class_weights(self):
        """Return the training split's class weights.

        Raises:
            RuntimeError: If the training dataset has not been built yet.
        """
        # Compare with None explicitly: a Dataset with zero documents is falsy.
        if self.train_dataset is None:
            raise RuntimeError("Call setup() before accessing class weights.")
        return self.train_dataset.get_class_weights()


### Model

In [None]:
class V1Model(nn.Module):
    """Token classifier: a Hugging Face encoder followed by a bias-free linear head.

    ``forward`` returns unnormalised per-token logits of shape
    ``(batch_size, seq_length, num_classes)``, which is the input
    ``nn.CrossEntropyLoss`` expects. Pass ``return_probas=True`` to get softmax
    probabilities instead.

    Args:
        base_model: Model id or path passed to ``AutoModel.from_pretrained``.
        num_classes: Number of output classes.
        lora_config: Optional PEFT config. If given, the encoder is wrapped
            with ``get_peft_model``.
        head_init_std: Std of the normal init of the head weights. The default
            0.0 initialises the head to all zeros.
    """

    def __init__(
        self,
        base_model: str,
        num_classes: int,
        lora_config=None,
        head_init_std: float = 0.0,
    ):
        super().__init__()

        # Kept for backward compatibility; nothing in this class moves the
        # weights to this device.
        self.device = torch.device(ACCELERATOR_DEVICE)
        self.num_classes = num_classes

        self.embedding_model = AutoModel.from_pretrained(base_model)

        # Size the head from the encoder config instead of hard-coding 768,
        # so encoders with other hidden sizes work too.
        hidden_size = self.embedding_model.config.hidden_size
        self.classifier_head = nn.Linear(hidden_size, num_classes, bias=False)
        nn.init.normal_(self.classifier_head.weight, mean=0.0, std=head_init_std)

        # Only used when probabilities are explicitly requested in forward().
        self.final_layer_activation = nn.Softmax(dim=-1)

        if lora_config:
            self.embedding_model = get_peft_model(self.embedding_model, lora_config)

    def forward(self, text_input_ids, attn_mask, return_probas: bool = False):
        """Score every token.

        Args:
            text_input_ids: Token ids for the encoder.
            attn_mask: Attention mask for the encoder.
            return_probas: If True, apply softmax over the class dimension.

        Returns:
            Logits (default) or probabilities, of shape
            ``(batch_size, seq_length, num_classes)``.
        """
        encoder_out = self.embedding_model(
            input_ids=text_input_ids, attention_mask=attn_mask
        )
        logits = self.classifier_head(encoder_out.last_hidden_state)
        if return_probas:
            return self.final_layer_activation(logits)
        # Raw logits: softmax-ing here and again inside CrossEntropyLoss would
        # squash the loss and its gradients.
        return logits

### Manual optimization (two optimizers, two LR schedulers)

In [None]:
class DoubleScheduler(L.LightningModule):
    """Manual-optimisation LightningModule with separate optimisers for head and encoder.

    The classifier head and the encoder (``model.embedding_model``) each get
    their own RAdam optimiser and constant LR schedule. This lets them train
    with different learning rates (``head_lr`` vs ``checkpoint_lr``).

    Args:
        model: Token classifier exposing ``classifier_head``,
            ``embedding_model`` and ``num_classes``.
        len_train_dataloader: Number of training batches; stored as
            ``len_dataset`` but not used by this class.
        batch_size: Stored for reference; not used by this class.
        head_lr: Learning rate of the classifier head.
        checkpoint_lr: Learning rate of the encoder parameters.
        class_weights: Optional per-class weights for the cross-entropy loss.
    """

    def __init__(
        self,
        model: V1Model,
        len_train_dataloader: int,
        batch_size: int = 16,
        head_lr: float = 1e-3,
        checkpoint_lr: float = 2e-5,
        class_weights: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.model = model
        self.len_dataset = len_train_dataloader
        self.batch_size = batch_size
        self.head_lr = head_lr
        self.checkpoint_lr = checkpoint_lr
        # Label -1 marks special/padding tokens; they are excluded from the loss.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-1, weight=class_weights)
        self.automatic_optimization = False  # both optimisers are stepped by hand
        # Non-ignored predictions/targets collected over one validation epoch.
        self._val_preds = []
        self._val_targets = []

    def _forward_and_loss(self, batch):
        """Run the model on ``batch`` and return ``(scores, loss)``.

        ``scores`` has shape ``(batch_size, seq_length, num_classes)`` and is
        passed straight to ``nn.CrossEntropyLoss``. That loss expects
        unnormalised logits, so this assumes the model returns logits.
        """
        scores = self.model(batch["input_ids"], batch["attention_mask"])
        loss = self.loss_fct(
            scores.reshape(-1, self.model.num_classes), batch["labels"].reshape(-1)
        )
        return scores, loss

    def training_step(self, batch, batch_idx):
        _, loss = self._forward_and_loss(batch)
        self.log("train_loss", loss, prog_bar=True)

        head_optimizer, checkpoint_optimizer = self.optimizers()
        head_optimizer.zero_grad()
        checkpoint_optimizer.zero_grad()
        self.manual_backward(loss)
        head_optimizer.step()
        checkpoint_optimizer.step()

        head_scheduler, checkpoint_scheduler = self.lr_schedulers()
        head_scheduler.step()
        checkpoint_scheduler.step()

        self.log(
            "checkpoint_lr",
            checkpoint_optimizer.param_groups[0]["lr"],
            prog_bar=True,
            on_step=True,
        )
        self.log(
            "head_lr", head_optimizer.param_groups[0]["lr"], prog_bar=True, on_step=True
        )

        return loss

    def on_validation_epoch_start(self):
        self._val_preds.clear()
        self._val_targets.clear()

    def validation_step(self, batch, batch_idx):
        scores, loss = self._forward_and_loss(batch)
        self.log("val_loss", loss, prog_bar=True)

        # Keep only positions that count for the metrics (label != -1).
        targets = batch["labels"].reshape(-1)
        preds = torch.argmax(scores, dim=-1).reshape(-1)
        keep = targets != -1
        self._val_preds.append(preds[keep].cpu())
        self._val_targets.append(targets[keep].cpu())

        return loss

    def on_validation_epoch_end(self):
        """Log macro F1/precision/recall over the whole validation epoch.

        Averaging per-batch macro scores is not the same as the macro score
        of the full dev set (class supports differ between batches), so the
        metrics are computed once here. Assumes a single device; under DDP
        each rank would only see its own shard.
        """
        targets = torch.cat(self._val_targets).numpy() if self._val_targets else []
        preds = torch.cat(self._val_preds).numpy() if self._val_preds else []
        self._val_preds.clear()
        self._val_targets.clear()

        if len(targets) > 0:
            f1 = f1_score(targets, preds, average="macro", zero_division=0)
            precision = precision_score(
                targets, preds, average="macro", zero_division=0
            )
            recall = recall_score(targets, preds, average="macro", zero_division=0)
        else:
            f1 = precision = recall = 0.0

        self.log("val_f1", float(f1), prog_bar=True)
        self.log("val_precision", float(precision), prog_bar=True)
        self.log("val_recall", float(recall), prog_bar=True)

    def configure_optimizers(self):
        """One RAdam + constant schedule for the head, one for the encoder."""
        head_optimizer = optim.RAdam(
            self.model.classifier_head.parameters(), lr=self.head_lr
        )
        checkpoint_optimizer = optim.RAdam(
            self.model.embedding_model.parameters(), lr=self.checkpoint_lr
        )

        head_scheduler = get_constant_schedule_with_warmup(
            head_optimizer, num_warmup_steps=0
        )
        checkpoint_scheduler = get_constant_schedule_with_warmup(
            checkpoint_optimizer, num_warmup_steps=0
        )

        return [head_optimizer, checkpoint_optimizer], [
            head_scheduler,
            checkpoint_scheduler,
        ]

### Automatic optimization (single optimizer, no LR scheduler)

In [None]:
class SingleScheduler(L.LightningModule):
    """Automatic-optimisation LightningModule with a single RAdam optimiser.

    The optimiser has two parameter groups: the encoder at ``checkpoint_lr``
    (group 0) and the classifier head at ``head_lr`` (group 1). Previously
    ``head_lr`` was accepted but silently ignored. No LR scheduler is used.

    Args:
        model: Token classifier exposing ``classifier_head``,
            ``embedding_model`` and ``num_classes``.
        len_train_dataloader: Number of training batches; stored as
            ``len_dataset`` but not used by this class.
        batch_size: Stored for reference; not used by this class.
        head_lr: Learning rate of the classifier head.
        checkpoint_lr: Learning rate of the encoder parameters.
        class_weights: Optional per-class weights for the cross-entropy loss.
    """

    def __init__(
        self,
        model: V1Model,
        len_train_dataloader: int,
        batch_size: int = 16,
        head_lr: float = 1e-3,
        checkpoint_lr: float = 2e-5,
        class_weights: Optional[torch.Tensor] = None,
    ):
        super().__init__()
        self.model = model
        self.len_dataset = len_train_dataloader
        self.batch_size = batch_size
        self.head_lr = head_lr
        self.checkpoint_lr = checkpoint_lr
        # Label -1 marks special/padding tokens; they are excluded from the loss.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-1, weight=class_weights)
        # Non-ignored predictions/targets collected over one validation epoch.
        self._val_preds = []
        self._val_targets = []

    def _forward_and_loss(self, batch):
        """Run the model on ``batch`` and return ``(scores, loss)``.

        ``scores`` has shape ``(batch_size, seq_length, num_classes)`` and is
        passed straight to ``nn.CrossEntropyLoss``. That loss expects
        unnormalised logits, so this assumes the model returns logits.
        """
        scores = self.model(batch["input_ids"], batch["attention_mask"])
        loss = self.loss_fct(
            scores.reshape(-1, self.model.num_classes), batch["labels"].reshape(-1)
        )
        return scores, loss

    def training_step(self, batch, batch_idx):
        _, loss = self._forward_and_loss(batch)

        self.log("train_loss", loss, prog_bar=True)
        # Param group 0 is the encoder group (see configure_optimizers).
        self.log(
            "lr",
            self.trainer.optimizers[0].param_groups[0]["lr"],
            prog_bar=True,
            on_step=True,
        )

        return loss

    def on_validation_epoch_start(self):
        self._val_preds.clear()
        self._val_targets.clear()

    def validation_step(self, batch, batch_idx):
        scores, loss = self._forward_and_loss(batch)
        self.log("val_loss", loss, prog_bar=True)

        # Mask out the ignored index (-1) before collecting metric inputs.
        targets = batch["labels"].reshape(-1)
        preds = torch.argmax(scores, dim=-1).reshape(-1)
        keep = targets != -1
        self._val_preds.append(preds[keep].cpu())
        self._val_targets.append(targets[keep].cpu())

        return loss

    def on_validation_epoch_end(self):
        """Log macro F1/precision/recall over the whole validation epoch.

        The metrics are computed once over all collected tokens rather than
        averaged per batch. Assumes a single device; under DDP each rank
        would only see its own shard.
        """
        targets = torch.cat(self._val_targets).numpy() if self._val_targets else []
        preds = torch.cat(self._val_preds).numpy() if self._val_preds else []
        self._val_preds.clear()
        self._val_targets.clear()

        if len(targets) > 0:
            f1 = f1_score(targets, preds, average="macro", zero_division=0)
            precision = precision_score(
                targets, preds, average="macro", zero_division=0
            )
            recall = recall_score(targets, preds, average="macro", zero_division=0)
        else:
            f1 = precision = recall = 0.0

        self.log("val_f1", float(f1), prog_bar=True)
        self.log("val_precision", float(precision), prog_bar=True)
        self.log("val_recall", float(recall), prog_bar=True)

    def configure_optimizers(self):
        """Single RAdam optimiser with separate learning rates for encoder and head."""
        return optim.RAdam(
            [
                {
                    "params": self.model.embedding_model.parameters(),
                    "lr": self.checkpoint_lr,
                },
                {"params": self.model.classifier_head.parameters(), "lr": self.head_lr},
            ],
            lr=self.checkpoint_lr,
        )

### Setup & Train

In [None]:
# Build the data module and create its datasets right away: the next cell
# needs the training split for the class weights and the dataloader length.
data_module = CHEMUDataModule(
    tokenizer="bert-base-uncased",
    batch_size=8,
    max_length=512,
    num_workers=2,
)

data_module.prepare_data()  # fetch tokenizer files if they are not cached
data_module.setup()  # stage=None -> train, dev and test datasets

In [None]:
# Seed Python, NumPy, torch and dataloader-worker RNGs (LoRA adapter init,
# shuffling) so that "Restart & Run All" reproduces the same run.
SEED = 42
L.seed_everything(SEED, workers=True)

model = V1Model(
    base_model="bert-base-uncased",
    num_classes=len(CHEMUDataset.label_to_idx),
    lora_config=LoraConfig(
        r=64,
        lora_alpha=128,
    ),
)

# Alternative: a single optimizer with automatic optimization. To use it,
# uncomment this block and pass `lightning` to `trainer.fit` at the bottom.
# lightning = SingleScheduler(
#     model=model,
#     len_train_dataloader=len(data_module.train_dataloader()),
#     batch_size=data_module.batch_size,
#     head_lr=5e-4,
#     checkpoint_lr=5e-4,
#     class_weights=data_module.get_class_weights(),
# )

double_lightning = DoubleScheduler(
    model=model,
    len_train_dataloader=len(data_module.train_dataloader()),
    batch_size=data_module.batch_size,
    head_lr=1e-3,
    checkpoint_lr=2e-4,
    class_weights=data_module.get_class_weights(),
)


def get_logger_name():
    """Build a TensorBoard run name from the data, model and optimiser settings.

    Reads the module-level ``model``, ``data_module`` and ``double_lightning``.
    When the encoder is not wrapped with PEFT, the LoRA part of the name is
    replaced by ``-nolora`` instead of raising.
    """
    peft_config = getattr(model.embedding_model, "peft_config", {}).get("default")
    if peft_config is not None:
        lora_tag = f"-alpha{peft_config.lora_alpha}-r{peft_config.r}"
    else:
        lora_tag = "-nolora"
    return (
        f"chemud-{model.__class__.__name__}"
        f"-{data_module.tokenizer}"
        f"-clr{double_lightning.checkpoint_lr}"
        f"-hlr{double_lightning.head_lr}"
        f"-bs{data_module.batch_size}"
        f"{lora_tag}"
        f"-num_classes{len(CHEMUDataset.label_to_idx)}"
    )


trainer = L.Trainer(
    max_epochs=10,
    accelerator=ACCELERATOR_DEVICE,
    devices=1,
    logger=L.loggers.TensorBoardLogger("logs/", name=get_logger_name()),
    # callbacks=[
    #     L.callbacks.ModelCheckpoint(
    #         dirpath="checkpoints/",
    #         filename="chemud-{epoch:02d}-{val_loss:.2f}",
    #         monitor="val_loss",
    #         mode="min",
    #         save_top_k=1,
    #     ),
    # ],
)

# trainer.fit(lightning, datamodule=data_module)
trainer.fit(double_lightning, datamodule=data_module)