In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from archehr import PROJECT_DIR
from archehr.data.dataset import QADataset
from archehr.data.utils import load_data, make_query_sentence_pairs

In [92]:
import torch

class QADatasetEmbedding(torch.utils.data.Dataset):
    def __init__(self, data, tokenizer, model, device=torch.device('cpu')):
        super(QADatasetEmbedding, self).__init__()
        self.data = data
        self.tokenizer = tokenizer
        self.model = model.to(device)
        self.device = device
        self.translate_dict = {
            u: k
            for k, u in enumerate(set([i['label'] for i in data]))
        }

    @property
    def emb_size(self):
        return self[0][0].size()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        query, sentence = item['query']

        # make the encoding
        encoding = self.tokenizer(
            query,
            sentence,
            padding=False,
            truncation=False,
            return_tensors='pt'
        )
        encoding.to(self.device)

        # make the embedding
        with torch.no_grad():
            embedding = self.model(**encoding)
        
        return embedding.logits.squeeze(0).to(device), self.translate_dict[item['label']]

In [105]:
def do_eval(model, dataloader, device, loss, target, progress_bar=None):
    """
    Evaluate the model on the validation set.
    
    Args:
        model: The model to evaluate.
        dataloader: The dataloader for the validation set.
        device: The device to use for evaluation.
    
    Returns:
        The average loss and accuracy on the validation set.
    """
    model.eval()
    val_loss = 0
    correct = 0
    total = 0
    tp = 0
    fp = 0
    fn = 0

    with torch.no_grad():
        for batch, labels in dataloader:
            # Move inputs and labels to device
            batch = batch.to(device)
            labels = labels.to(device)

            # Forward pass
            with torch.no_grad():
                outputs = model(batch)

            # Compute the loss
            l = loss(outputs, labels)
            val_loss += l.item()

            # Compute accuracy
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Compute true positives & false positives & false negatives
            tp += sum((labels == target) & (predicted == target)).item()
            fp += sum((labels != target) & (predicted == target)).item()
            fn += sum((labels == target) & (predicted != target)).item()

    # Compute metrics
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = (
        2 * (precision * recall) / (precision + recall) 
        if (precision + recall) > 0 else 0
    )

    avg_loss = val_loss / len(dataloader)
    accuracy = correct / total

    if progress_bar is not None:
        progress_bar.set_postfix(
            loss=f"{avg_loss:.4f}",
            acc=f"{accuracy:.1%}",
            ppv=f"{precision:.1%}",
            rec=f"{recall:.1%}",
            f1=f"{f1:.1%}",
        )

    output_dict = {
        'loss': avg_loss,
        'acc': accuracy,
        'ppv': precision,
        'rec': recall,
        'f1': f1,
    }

    return output_dict


In [51]:
from torch import Tensor
from typing import Optional, Callable

class Mlp(nn.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


## Use nli-deberta-v3-base

In [107]:
# Load the data & the model
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the data
data_path = PROJECT_DIR / "data" / "1.1" / "dev"
data = load_data(data_path)
n_cases = len(data)

# Split train / val
data_train = data[:int(0.8 * n_cases)]
data_val = data[int(0.8 * n_cases):]

# Load the model & tokenizer
model_name = "cross-encoder/nli-deberta-v3-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Remove the last layer
model.classifier = nn.Identity()

In [101]:
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")

# Make the pairs
# Make the pairs
pairs_train = make_query_sentence_pairs(data_train)
pairs_val = make_query_sentence_pairs(data_val)

# Make embedding datasets
emb_train = QADatasetEmbedding(pairs_train, tokenizer, model, device=device)
emb_val = QADatasetEmbedding(pairs_val, tokenizer, model, device=device)

# Make embedding dataloaders
train_loader = DataLoader(emb_train, batch_size=128, shuffle=True)
val_loader = DataLoader(emb_val, batch_size=128,)

mlp = Mlp(
    emb_train.emb_size.numel(),
    out_features=len(emb_train.translate_dict)
)
mlp.to(device)

Using: cuda


Mlp(
  (fc1): Linear(in_features=768, out_features=768, bias=True)
  (act): GELU(approximate='none')
  (fc2): Linear(in_features=768, out_features=3, bias=True)
  (drop): Dropout(p=0.0, inplace=False)
)

In [102]:
# Make optimizer and loss
optimizer = torch.optim.AdamW(mlp.parameters())
loss = nn.CrossEntropyLoss()

In [106]:
from tqdm import tqdm

num_epochs = 100

mlp.train()
for epoch in (progress_bar := tqdm(range(num_epochs))):
    for batch, labels in train_loader:
        # Move inputs and labels to device
        batch = batch.to(device)
        labels = labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = mlp(batch)

        # Backward pass and optimization
        l = loss(outputs, labels)
        l.backward()
        optimizer.step()

    if epoch % 10 == 0:
        do_eval(
            mlp,
            val_loader,
            device,
            loss,
            target=dataset_val.translate_dict['essential'],
            progress_bar=progress_bar,    
        )

100%|████████████████████████████████████████████| 100/100 [27:22<00:00, 16.43s/it, acc=49.6%, f1=23.9%, loss=1.4556, ppv=25.4%, rec=22.7%]


## Use other cross-encoder

In [111]:
model = AutoModelForSequenceClassification.from_pretrained('cross-encoder/ms-marco-MiniLM-L12-v2')
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/ms-marco-MiniLM-L12-v2')

# Remove the last layer
model.classifier = nn.Identity()

model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 384, padding_idx=0)
      (position_embeddings): Embedding(512, 384)
      (token_type_embeddings): Embedding(2, 384)
      (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=384, out_features=384, bias=True)
              (key): Linear(in_features=384, out_features=384, bias=True)
              (value): Linear(in_features=384, out_features=384, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=384, out_features=384, bias=True)
              (LayerNorm): LayerNorm((384,), eps=1e

In [112]:
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using: {device}")

# Make the pairs
# Make the pairs
pairs_train = make_query_sentence_pairs(data_train)
pairs_val = make_query_sentence_pairs(data_val)

# Make embedding datasets
emb_train = QADatasetEmbedding(pairs_train, tokenizer, model, device=device)
emb_val = QADatasetEmbedding(pairs_val, tokenizer, model, device=device)

# Make embedding dataloaders
train_loader = DataLoader(emb_train, batch_size=128, shuffle=True)
val_loader = DataLoader(emb_val, batch_size=128,)

mlp = Mlp(
    emb_train.emb_size.numel(),
    out_features=len(emb_train.translate_dict)
)
mlp.to(device)

Using: cuda


Mlp(
  (fc1): Linear(in_features=384, out_features=384, bias=True)
  (act): GELU(approximate='none')
  (fc2): Linear(in_features=384, out_features=3, bias=True)
  (drop): Dropout(p=0.0, inplace=False)
)

In [113]:
from tqdm import tqdm

num_epochs = 100

mlp.train()
for epoch in (progress_bar := tqdm(range(num_epochs))):
    for batch, labels in train_loader:
        # Move inputs and labels to device
        batch = batch.to(device)
        labels = labels.to(device)

        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        outputs = mlp(batch)

        # Backward pass and optimization
        l = loss(outputs, labels)
        l.backward()
        optimizer.step()

    if epoch % 10 == 0:
        do_eval(
            mlp,
            val_loader,
            device,
            loss,
            target=dataset_val.translate_dict['essential'],
            progress_bar=progress_bar,    
        )

100%|███████████████████████████████████████████████| 100/100 [10:54<00:00,  6.54s/it, acc=10.2%, f1=0.0%, loss=1.0931, ppv=0.0%, rec=0.0%]
