In [1]:
!git clone https://github.com/ryokamoi/wice.git

%pip install -q datasets transformers[torch]

%cd wice/data/entailment_retrieval/claim

Cloning into 'wice'...
remote: Enumerating objects: 158, done.[K
remote: Counting objects: 100% (133/133), done.[K
remote: Compressing objects: 100% (103/103), done.[K
remote: Total 158 (delta 29), reused 131 (delta 27), pack-reused 25[K
Receiving objects: 100% (158/158), 20.63 MiB | 2.90 MiB/s, done.
Resolving deltas: 100% (32/32), done.
Updating files: 100% (101/101), done.
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m4.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m7.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m261.4/261.4 kB[0m [31m9.8 MB/s[0m eta [36m0:00:00[0m
[?25h/content/wice/data/entailment_retrieval/claim


In [2]:
from datasets import Dataset, load_from_disk
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from transformers import (
    BertConfig,
    BertForSequenceClassification,
    BertTokenizerFast,
    BertPreTrainedModel
    get_linear_schedule_with_warmup,
)
from tqdm.auto import tqdm
from typing import Any, List, Dict, Tuple

dir_path = ""
data_path = ""

# Span-BERT


https://www.kaggle.com/code/anasofiauzsoy/tweet-sentiment-extraction-with-tf2-spanbert


In [3]:
def claim_spans(row_data: pd.Series) -> Tuple[List, List]:
    """
    Extracts start and end character indices of supporting sentences for each row.

    Parameters:
    - row_data (pd.Series): A row from the dataset containing "supporting_sentences" and "evidence" columns.

    Returns:
    - tuple: Two lists containing start and end character indices of supporting sentences.
    """
    indices = sorted(set(x_i for x in row_data["supporting_sentences"] for x_i in x))
    start_idx, end_idx = [], []
    for idx in indices:
        start = len(" ".join(row_data["evidence"][:idx]))
        start_idx.append(start)
        end_idx.append(start + len(row_data["evidence"][idx]))

    return start_idx, end_idx

def load_data(filename: str) -> pd.DataFrame:
    """
    Loads data from a JSON file and preprocesses it.

    Parameters:
    - filename (str): Path to the JSON file.

    Returns:
    - pd.DataFrame: Processed DataFrame containing relevant information.
    """
    df_temp = pd.read_json(filename, lines=True)

    # drop unnecessary columns and get the character indices for the spans
    df_temp[["start_idx", "end_idx"]] = df_temp.apply(
        claim_spans, axis=1, result_type="expand"
    )
    df_temp.drop(columns=["label", "supporting_sentences", "meta"],
                 inplace=True)

    return df_temp

def extract_spans(arr: list) -> Tuple[np.ndarray, np.ndarray]:
    """
    Extracts continuous spans from a list of indices.

    Parameters:
    - arr (list): List of integers representing indices.

    Returns:
    - tuple: Two arrays containing start and end indices of continuous spans.
    """
    start_span = []
    end_span = []
    if arr:
      current_span = [arr[0]]

      for i in range(1, len(arr)):
          if arr[i] == arr[i - 1] + 1:
              current_span.append(arr[i])
          else:
              # End of the current span, start a new one
              start_span.append(current_span[0])
              end_span.append(current_span[-1])
              current_span = [arr[i]]

      # Add the last span
      start_span.append(current_span[0])
      end_span.append(current_span[-1])

    return np.array(start_span, dtype=int), np.array(end_span, dtype=int)

def preprocess_wice_examples(evidence: List, start_char_idx: List, end_char_idx: List, max_len: int) -> Dict:
    """
    Preprocesses examples for the WICE model.

    Parameters:
    - evidence (list): List of evidence sentences.
    - start_char_idx (list): List of start character indices.
    - end_char_idx (list): List of end character indices.
    - max_len (int): Maximum length of the input.

    Returns:
    - dict: Preprocessed data including input IDs, token type IDs, attention mask, start token indices, and end token indices.
    """
    evidence_text = " ".join(evidence)

    # Mark the character indexes in text that are in answer
    is_char_in_span = [0] * len(evidence_text)
    for start_idx, end_idx in zip(start_char_idx, end_char_idx):
        for idx in range(start_idx, end_idx):
            is_char_in_span[idx] = 1

    # Tokenize text
    tokenized_text = tokenizer.encode_plus(
        evidence_text,
        return_offsets_mapping=True,
        max_length=max_len,
        truncation=True,
        padding=True
    )

    # Find tokens through the character offsets for each token in the original text
    span_token_idx = []
    for idx, (start, end) in enumerate(tokenized_text.offset_mapping):
        if sum(is_char_in_span[start:end]) > 0:
            span_token_idx.append(idx)

    # Find start and end token index for tokens from answer
    start_token_idx, end_token_idx = extract_spans(span_token_idx)

    return {
        "input_ids": np.array(tokenized_text.input_ids),
        "token_type_ids": np.array(tokenized_text.token_type_ids),
        "attention_mask": np.array(tokenized_text.attention_mask),
        "start_token_idx": np.array(start_token_idx),
        "end_token_idx": np.array(end_token_idx),
    }

def create_and_process_dataset(data: pd.DataFrame, max_len: int) -> Dataset:
    """
    Creates and processes a dataset.

    Parameters:
    - data (pd.DataFrame): DataFrame containing the dataset.
    - max_len (int): Maximum length of the input.

    Returns:
    - Dataset: Processed dataset.
    """
    preprocessed_data = []
    for index, row in tqdm(data.iterrows(), total=data.shape[0]):
        preprocessed_example = preprocess_wice_examples(row['evidence'], row['start_idx'], row['end_idx'], max_len)
        preprocessed_data.append(preprocessed_example)
    preprocessed_df = pd.DataFrame(preprocessed_data)
    return Dataset.from_pandas(preprocessed_df)

def pad_token_idx(train: Dataset, dev: Dataset) -> Tuple[Dataset, Dataset, int]:
    """
    Pads token indices in training and development datasets.

    Parameters:
    - train (Dataset): Training dataset.
    - dev (Dataset): Development dataset.

    Returns:
    - Tuple: Padded training and development datasets along with the maximum padding length.
    """
    def padding(row, max_len):
            pad_sequence = [-100] * max(0, max_len - len(row['start_token_idx']))
            row['start_token_idx'] = row['start_token_idx'] + pad_sequence
            row['end_token_idx'] = row['end_token_idx'] + pad_sequence
            return row

    max_pad_length = max(max(len(lst) for lst in train['start_token_idx']),
                         max(len(lst) for lst in dev['start_token_idx']))
    train = train.map(lambda x: padding(x, max_pad_length))
    dev = dev.map(lambda x: padding(x, max_pad_length))
    return train, dev, max_pad_length

In [4]:
##### Setup BERT model #####
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "SpanBERT/spanbert-base-cased"
model = BertForSequenceClassification.from_pretrained(model_name).to(device)
tokenizer = BertTokenizerFast.from_pretrained(model_name)

##### Set up static variables #####
save_dir = os.path.join(dir_path, "models", "bert", "wice_classifier")
max_len = 512 # BERT max token length
batch_size = 8
num_epochs = 1

# Reproducibility
seed_val = 42
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

config.json:   0%|          | 0.00/413 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/215M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at SpanBERT/spanbert-base-cased and are newly initialized: ['bert.pooler.dense.weight', 'classifier.weight', 'classifier.bias', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

In [5]:
##### Setup WiCE DATA #####
df_train = load_data("train.jsonl")
df_dev = load_data("dev.jsonl")

train_dataset = create_and_process_dataset(df_train, max_len)
dev_dataset = create_and_process_dataset(df_dev, max_len)
train_dataset, dev_dataset, n_spans = pad_token_idx(train_dataset, dev_dataset)
train_dataset.set_format('torch')
dev_dataset.set_format('torch')

train_dataloader = DataLoader(train_dataset, batch_size=batch_size)

df_dev.head()

  0%|          | 0/1260 [00:00<?, ?it/s]

  0%|          | 0/349 [00:00<?, ?it/s]

Map:   0%|          | 0/1260 [00:00<?, ? examples/s]

Map:   0%|          | 0/349 [00:00<?, ? examples/s]

Unnamed: 0,claim,evidence,start_idx,end_idx
0,Arnold is currently the publisher and editoria...,[(meta data) TITLE: About Us – Media Play News...,"[116, 1648, 1810]","[180, 1716, 1960]"
1,"The Tozzer library itself holds over 260,000 v...",[(meta data) TITLE: Mission & History | About ...,[],[]
2,He appeared in the 2016 Grammy-nominated docum...,[(meta data) TITLE: Steve Aoki 'I'll Sleep Whe...,"[0, 166, 404, 446, 534, 1063, 2528, 3332, 3498]","[103, 208, 419, 533, 859, 1256, 2929, 3497, 3750]"
3,"This further decreased to 41.3 % in 2016 , mai...",[(meta data) TITLE: Behind the God-swapping in...,"[188, 2492, 7402]","[240, 2598, 7564]"
4,"The band's third album, ""Emotive"", was release...",[(meta data) TITLE: A Perfect Circle Album Pre...,"[129, 602, 3643]","[177, 720, 3691]"


In [10]:
class WiceModel(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.spanbert = model # Base SpanBERT model

        # Added layers for predicting start and end index logits
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_states, config.num_labels)
        self.init_weights()

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, labels=None):
        outputs = self.spanbert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        sequence_output = outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        outputs = (logits, ) + outputs[:]
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            active_loss = attention_mask.view(-1) == 1
            active_logits = logits.view(-1, self.num_labels)
            active_labels = torch.where(
                active_loss, labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(labels)
            )
            loss = loss_fct(active_logits, active_labels)
            outputs = (loss,) + outputs

        return outputs



In [9]:
config = BertConfig.from_pretrained(
        model_args.config_name if model_args.config_name else model_args.model_name_or_path,
        num_labels=num_labels,
        label2id=label2id,
        id2label=id2label,
        finetuning_task=data_args.task_name,
        cache_dir=model_args.cache_dir,
        revision=model_args.model_revision,
        use_auth_token=True if model_args.use_auth_token else None,
    )



loss_function = nn.CrossEntropyLoss()
optimizer = AdamW(model.parameters(), lr=5e-5)

model = WiceModel(max_len)
model.to(device)

WiceModel(
  (base_model): WiceModel(
    (base_model): WiceModel(
      (base_model): BertForSequenceClassification(
        (bert): BertModel(
          (embeddings): BertEmbeddings(
            (word_embeddings): Embedding(28996, 768, padding_idx=0)
            (position_embeddings): Embedding(512, 768)
            (token_type_embeddings): Embedding(2, 768)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (encoder): BertEncoder(
            (layer): ModuleList(
              (0-11): 12 x BertLayer(
                (attention): BertAttention(
                  (self): BertSelfAttention(
                    (query): Linear(in_features=768, out_features=768, bias=True)
                    (key): Linear(in_features=768, out_features=768, bias=True)
                    (value): Linear(in_features=768, out_features=768, bias=True)
                    (dropout): Dropout(p=0.1, inplace=F

In [14]:
# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for batch in train_dataloader:
        input_ids = batch['input_ids']
        token_type_ids = batch['token_type_ids']
        attention_mask = batch['attention_mask']
        start_labels = batch['start_token_idx']
        end_labels = batch['end_token_idx']

        optimizer.zero_grad()
        start_probs, end_probs = model(input_ids=input_ids, attn_mask=attention_mask)

        start_loss = loss_function(start_probs, start_labels)
        end_loss = loss_function(end_probs, end_labels)
        total_loss = start_loss + end_loss
        total_loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {total_loss.item()}')


TypeError: ignored

In [None]:
# Save the trained model if needed
torch.save(model.state_dict(), 'wice_model.pth')