In [None]:
!pip install transformers[sentencepiece] #installs full transformers package
!pip install datasets
!pip install accelerate -U
!pip install evaluate

In [None]:
import getpass
import os

import numpy as np
import torch
import transformers
from datasets import Dataset, DatasetDict
from datasets import load_dataset
from huggingface_hub import notebook_login
from transformers import AutoTokenizer, DataCollatorWithPadding

In [None]:
# Hugging Face access token used for loading and pushing the datasets below.
# It is read from the HF_TOKEN environment variable so that the secret never
# ends up in the saved notebook; if the variable is unset (or empty) the user
# is prompted for it without echoing the input.
auth_token = os.environ.get("HF_TOKEN") or getpass.getpass("Hugging Face access token: ")

In [None]:
# Load the dataset from the Hugging Face Hub, authenticating with auth_token.
dataset_name = "kghanlon/right_as_train"
data = load_dataset(dataset_name, token=auth_token)
# Split semantics as noted by the author:
# train/test are from other/right parties (or non green parties)
# validation is the inference set on left parties/green parties
# Later cells additionally read the "inference_left" and "inference_center" splits.

In [None]:
# Display the loaded dataset object (the cell's last expression is rendered).
data

In [None]:
# Look at the first training example to see which columns a row contains.
data["train"][0]

In [None]:
# Model checkpoint whose tokenizer is used for all tokenization in this notebook.
checkpoint = "FacebookAI/roberta-large"
# Load the tokenizer that belongs to this checkpoint.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Tokenization step.
def tokenize_function(example):
    """Tokenize the ``q_sentence`` field of ``example`` with truncation enabled.

    ``example`` is whatever ``map`` passes in; because the call below uses
    ``batched=True``, ``example["q_sentence"]`` holds a batch of sentences.
    """
    sentences = example["q_sentence"]
    return tokenizer(sentences, truncation=True)

# Run the tokenizer over every split, one batch at a time.
tokenized_datasets = data.map(tokenize_function, batched=True)
# Collator that pads the examples of a batch using this tokenizer.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# Display the tokenized splits and their columns.
tokenized_datasets

In [None]:
# Sanity check: show which token the id 0 decodes to.
tokenizer.decode(0)

In [None]:
# Sanity check: show which token the id 2 decodes to.
tokenizer.decode(2)

In [None]:
# Maps split name -> list of extended input_ids (sentence + context), one entry per row.
# Filled by the context-building code below.
input_ids_with_context = {}

In [None]:
# Add surrounding-sentence context to the input_ids of every split.
#
# Each row's new input_ids are: original input_ids + [sep] + context + [sep],
# where the context is the sentence itself surrounded by as many tokens from
# neighbouring sentences of the same manifesto as fit into the token budget.


def build_context_input_ids(target_data, target_context_size=201, start=0, sep=2, verbose=False):
    """Build context-extended ``input_ids`` for every row of one split.

    Parameters
    ----------
    target_data
        A tokenized split providing the columns ``input_ids``, ``manifesto_id``
        and ``q_sentence_nr``. Rows are assumed to be stored in document order
        (neighbouring sentences in neighbouring rows) -- TODO confirm.
    target_context_size : int
        Token budget for the context segment (excluding its two ``sep`` tokens).
    start, sep : int
        Ids of the special tokens that are stripped from every sentence before
        it is used as context; ``sep`` is also used to delimit the context.
    verbose : bool
        If true, print ``"check for i = ", i`` for every row whose context
        could not be extended any further before the budget was reached.

    Returns
    -------
    list[list[int]]
        One list per row: the row's original ``input_ids`` followed by
        ``[sep] + context + [sep]``.

    Notes
    -----
    * A row at distance ``j`` counts as a neighbour only if it exists, has the
      same ``manifesto_id`` and its ``q_sentence_nr`` differs by exactly ``j``.
      Indices below zero are rejected explicitly; a negative index would
      otherwise wrap around and read rows from the end of the split.
    * Neighbour tokens are sliced to the remaining budget *before* the special
      tokens are removed. A slice that contains a special token therefore
      contributes one token fewer than the budget, which is why the loop stops
      once at most one token of budget is left.
    * NOTE(review): nothing here checks that sentence + context still fits the
      model's maximum sequence length -- confirm this is handled downstream.
    """
    # Read the needed columns once; indexing the dataset row by row inside the
    # loop would materialise a complete row on every single access.
    all_input_ids = list(target_data["input_ids"])
    manifesto_ids = list(target_data["manifesto_id"])
    sentence_nrs = list(target_data["q_sentence_nr"])
    n_rows = len(all_input_ids)
    special_ids = (start, sep)

    def strip_special(tokens):
        # Drop every start/sep token from a token list.
        return [tok for tok in tokens if tok not in special_ids]

    def is_neighbour(i, k):
        # Row k must exist, belong to the same manifesto as row i, and its
        # sentence number must be offset from row i's by the row distance.
        return (
            0 <= k < n_rows
            and manifesto_ids[k] == manifesto_ids[i]
            and sentence_nrs[i] == sentence_nrs[k] + (i - k)
        )

    extended_input_ids = []
    for i, sentence_ids in enumerate(all_input_ids):
        # The context always starts out as the sentence itself.
        context = strip_special(sentence_ids)
        if len(context) >= target_context_size:
            # The sentence alone already fills the budget: just truncate it.
            context = context[:target_context_size]
        else:
            # Grow the context symmetrically: distance 1, then 2, ...
            j = 1
            while True:
                tokens_left = target_context_size - len(context)
                extended = False
                if is_neighbour(i, i - j):
                    extended = True
                    # Take tokens from the END of the preceding sentence and
                    # put them in front of the context.
                    context[:0] = strip_special(all_input_ids[i - j][-tokens_left:])
                    tokens_left = target_context_size - len(context)
                if is_neighbour(i, i + j):
                    extended = True
                    # Take tokens from the START of the following sentence and
                    # append them to the context.
                    context.extend(strip_special(all_input_ids[i + j][:tokens_left]))
                    tokens_left = target_context_size - len(context)
                if not extended:
                    # No neighbour on either side at this distance: stop with
                    # a context that is shorter than the budget.
                    if verbose:
                        print("check for i = ", i)
                    break
                if tokens_left <= 1:
                    # Budget used up; the cut (if any) is taken from the end.
                    context = context[:target_context_size]
                    break
                j += 1
        # Second segment: sep + context + sep, appended to the untouched sentence.
        extended_input_ids.append(sentence_ids + [sep] + context + [sep])
    return extended_input_ids


target_context_size = 201  # tokens
# Special-token ids the tokenizer wraps around each sentence
# (0 and 2 for the RoBERTa tokenizer used in this notebook).
start = tokenizer.cls_token_id
sep = tokenizer.sep_token_id

# Same processing for every split.
for split_name in ["test", "train", "validation", "inference_left", "inference_center"]:
    input_ids_with_context[split_name] = build_context_input_ids(
        tokenized_datasets[split_name],
        target_context_size=target_context_size,
        start=start,
        sep=sep,
    )

In [None]:
# Display the tokenized splits again for reference before the new datasets are built from them.
tokenized_datasets

In [None]:
# Rebuild every split with the context-extended input_ids and a matching attention mask.
new_datasets = DatasetDict()
# adjust the list of splits as necessary (e.g. only ["train", "test"]):
for part in ["train", "test", "validation", "inference_left", "inference_center"]:
    source_split = tokenized_datasets[part]
    extended_ids = input_ids_with_context[part]
    # The mask has a 1 for every token of the extended sequence.
    full_mask = [[1] * len(ids) for ids in extended_ids]
    replacements = {"input_ids": extended_ids, "attention_mask": full_mask}
    # Keep the original column order; swap in the two rebuilt columns and
    # copy every other column unchanged.
    columns = {
        name: replacements[name] if name in replacements else source_split[name]
        for name in source_split.features
    }
    # Create the new Dataset for this split from the assembled columns.
    new_datasets[part] = Dataset.from_dict(columns)


In [None]:
# Display the rebuilt splits.
new_datasets

In [None]:
# Push the modified dataset to the Hugging Face Hub as a private repository
# named after the source dataset with a "_context" suffix.
new_datasets.push_to_hub(dataset_name+"_context", private=True, token = auth_token)

In [None]:
# Sanity check: decode one extended example back to text to inspect the
# sentence + context layout.
tokenizer.decode(new_datasets["inference_left"]["input_ids"][1])

In [None]:
# This is how to load the pushed dataset again later (same name with the "_context" suffix):
data_loaded = load_dataset(dataset_name+"_context", token=auth_token)