In [None]:
import ast

import pandas as pd
from datasets import load_dataset
from transformers import AutoModel, AutoTokenizer, DataCollatorWithPadding

In [None]:
# Location of the TED passages file; it is tab-separated and read with the
# `csv` builder (sep='\t') further down.
data_path: str = 'data/ted/en.tsv'

In [None]:
# Checkpoint identifier shared by the encoder and its tokenizer.
model_name: str = 'microsoft/deberta-v3-xsmall'

In [None]:
# `output_hidden_states=True` makes each forward pass also return the
# per-layer hidden states, which the extraction loop reads via
# `model(**batch)['hidden_states']`.
model = AutoModel.from_pretrained(model_name, output_hidden_states=True)
# Slow (Python) tokenizer. `return_tensors` is deliberately NOT passed here:
# it is an argument of each `tokenizer(...)` call, not a loading option, so
# giving it to `from_pretrained` had no effect on what the tokenizer returns.
# Conversion to torch tensors happens later in the DataLoader's collator.
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

In [None]:
# Each row's 'sentences' cell is stored as a Python-literal list string;
# `ast.literal_eval` turns it back into a list without executing code.
dataset = load_dataset(
    'csv',
    data_files=data_path,
    sep='\t',
    converters={'sentences': ast.literal_eval},
)

In [None]:
def preprocess_dataset(dataset, tokenizer, split='train',
                       drop_columns=('sub_sentences',), max_length=256,
                       padding=True):
    """Join each passage's sentences into one text and tokenize it.

    Args:
        dataset: Mapping of split name -> dataset (e.g. the ``DatasetDict``
            returned by ``load_dataset``). The selected split must have a
            ``'sentences'`` column holding a list of strings per row.
        tokenizer: Callable tokenizer; called as
            ``tokenizer(texts, padding=..., max_length=..., truncation=True)``.
        split: Name of the split to preprocess.
        drop_columns: Column name(s) removed before processing. Names that are
            not present in the split are ignored.
        max_length: Passages are truncated to at most this many tokens.
        padding: Forwarded to the tokenizer. With ``True`` every batched
            ``map`` call is padded to the longest passage in that call, so row
            lengths can differ between calls; the collator must tolerate that.

    Returns:
        The tokenized split, containing only the tokenizer's output columns.
    """
    split_dataset = dataset[split]

    if isinstance(drop_columns, str):
        drop_columns = [drop_columns]
    # Only drop columns that actually exist, so files without them still work.
    columns_to_drop = [name for name in drop_columns
                       if name in split_dataset.column_names]
    if columns_to_drop:
        split_dataset = split_dataset.remove_columns(columns_to_drop)

    def concatenate_sentences(example):
        # Replace the sentence list with a single space-separated passage.
        example['sentences'] = ' '.join(example['sentences'])
        return example

    split_dataset = split_dataset.map(concatenate_sentences,
                                      desc='Concatenating passage sentences.')

    def tokenize_dataset(examples):
        return tokenizer(examples['sentences'], padding=padding,
                         max_length=max_length, truncation=True)

    # Remove the raw text columns so only model inputs remain.
    return split_dataset.map(
        tokenize_dataset,
        batched=True,
        remove_columns=split_dataset.column_names,
        desc="Running tokenizer on dataset",
    )

In [None]:
# Rebinds `dataset` from the raw split mapping to the tokenized 'train' split.
# NOTE: this cell is not idempotent. Re-running it on its own fails because
# `dataset` is no longer the raw mapping, so re-run the loading cell first.
dataset = preprocess_dataset(dataset=dataset, tokenizer=tokenizer)

In [None]:
from transformers import default_data_collator
from torch.utils.data import DataLoader
import torch

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU so the
# notebook still runs on CPU-only machines (the check was commented out before).
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# Re-pad every DataLoader batch to its own longest sequence. Tokenization pads
# rows only within each batched `map` call, so two rows in one DataLoader batch
# may still differ in length, and `default_data_collator` cannot stack those
# into a single tensor. Padding here keeps the cell correct for any
# `batch_size` and with `shuffle=True`.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, return_tensors='pt')

# shuffle=False keeps passages in dataset order.
dataloader = DataLoader(dataset, shuffle=False, collate_fn=data_collator, batch_size=8)

In [None]:
# Display the loaded model's configuration for inspection.
model.config

In [None]:
model.to(device)
model.eval()  # inference mode (disables dropout)
with torch.no_grad():
    for batch in dataloader:
        # Put inputs on the same device as the model; this used to be hard-coded
        # to 'cuda' and ignored the `device` setting.
        batch = {key: tensor.to(device) for key, tensor in batch.items()}
        # Tuple with one (batch_size, seq_len, hidden_size) tensor per hidden
        # state: the embedding output plus one per layer. seq_len <= 256
        # (the tokenizer's max_length).
        hidden_states = model(**batch)['hidden_states']
        # Real (non-padding) token count per passage, converted to Python ints
        # once so the slicing below does not index with device tensors.
        non_pad_tokens = batch['attention_mask'].sum(dim=1).tolist()
        # -> (batch_size, num_hidden_states, seq_len, hidden_size)
        hidden_states = torch.stack(hidden_states, dim=1)
        for batch_idx, num_tokens in enumerate(non_pad_tokens):
            # Keeping the first `num_tokens` positions assumes right padding.
            # NOTE(review): this is overwritten on every iteration, so only the
            # last passage survives the loop; store or aggregate it here.
            passage_hidden_states = hidden_states[batch_idx, :, :num_tokens, :]
            

In [None]:
# Shape of the last passage sliced in the loop above:
# (number of stacked hidden states, non-padding tokens, hidden size).
passage_hidden_states.shape

In [None]:
# Stacked hidden states of the final batch from the loop above:
# (batch_size, number of stacked hidden states, padded seq_len, hidden_size).
hidden_states