In [1]:
%pip install transformers -U datasets

Note: you may need to restart the kernel to use updated packages.


In [2]:
from datasets import load_dataset
import pandas as pd
import torch
import os
from tqdm import tqdm

data = load_dataset("rajpurkar/squad")
# ("gips-mai/enc_descr")



In [3]:
data

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 87599
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 10570
    })
})

In [4]:
data["train"][0]

{'id': '5733be284776f41900661182',
 'title': 'University_of_Notre_Dame',
 'context': 'Architecturally, the school has a Catholic character. Atop the Main Building\'s gold dome is a golden statue of the Virgin Mary. Immediately in front of the Main Building and facing it, is a copper statue of Christ with arms upraised with the legend "Venite Ad Me Omnes". Next to the Main Building is the Basilica of the Sacred Heart. Immediately behind the basilica is the Grotto, a Marian place of prayer and reflection. It is a replica of the grotto at Lourdes, France where the Virgin Mary reputedly appeared to Saint Bernadette Soubirous in 1858. At the end of the main drive (and in a direct line that connects through 3 statues and the Gold Dome), is a simple, modern stone statue of Mary.',
 'question': 'To whom did the Virgin Mary allegedly appear in 1858 in Lourdes France?',
 'answers': {'text': ['Saint Bernadette Soubirous'], 'answer_start': [515]}}

In [5]:
from datasets import Dataset

def encode_dataset(df, tokenizer, model):
    encodings = []

    for idx, row in tqdm(df.iterrows(), total=len(df), desc="Encoding rows"):
        identifier = row[0]
        discription = row[1]

        # Tokenize input and propagate through model
        inputs = tokenizer(discription, return_tensors='pt', truncation=True).to('cuda')
        with torch.no_grad():
            outputs = model(**inputs)

        # The outputs include the last hidden state, pooler output, and hidden states
        # We are interested in the last hidden state
        encoding = outputs.last_hidden_state.cpu().numpy()  # Convert tensor to numpy
        encodings.append({"id": identifier, "enc": encoding})

    return Dataset.from_list(encodings)

In [6]:
from transformers import RobertaModel, RobertaTokenizer

# init

tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = RobertaModel.from_pretrained('roberta-base').cuda()

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [7]:
# load raw data
df = pd.read_csv('formated_results/results_all_formated.csv')

# encode dataset
enc_data = encode_dataset(df, tokenizer, model)

  identifier = row[0]
  discription = row[1]
  identifier = row[0]
  discription = row[1]
Encoding rows:  45%|████▍     | 89824/201008 [15:13<19:02, 97.35it/s] 

In [None]:
enc_data_dict =  enc_data.train_test_split(test_size=0.3)

In [None]:
# publish to hub

%pip install huggingface_hub

Note: you may need to restart the kernel to use updated packages.


In [None]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
enc_data_dict.push_to_hub("gips-mai/descriptions_enc")

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

CommitInfo(commit_url='https://huggingface.co/datasets/gips-mai/descriptions_enc/commit/b25229adecca31eec7933101f0388db12f0d26ac', commit_message='Upload dataset', commit_description='', oid='b25229adecca31eec7933101f0388db12f0d26ac', pr_url=None, pr_revision=None, pr_num=None)