In [None]:
import os
import torch
import pandas as pd
from longformer.longformer import Longformer, LongformerConfig
from longformer.sliding_chunks import pad_to_window_size
from transformers import RobertaTokenizer

In [6]:
# Base directory (relative path) for the files this notebook loads: the
# 'longformer-base-4096/' checkpoint folder and 'example_data.csv'.
# The trailing slash is required because later cells build paths with plain
# string concatenation (root_dir + '<name>').
root_dir = '../../data/longformer/patentsview/'

In [None]:
# Read the Longformer configuration from the local checkpoint directory.
config = LongformerConfig.from_pretrained(f'{root_dir}longformer-base-4096/')

# Pick the attention implementation. Accepted values:
#   'n2'             -> regular n^2 attention
#   'tvm'            -> custom CUDA kernel implementation of the sliding window attention
#   'sliding_chunks' -> PyTorch implementation of the sliding window attention
config.attention_mode = 'sliding_chunks'

In [3]:
# Tokenizer is loaded under the name 'roberta-base'; the model weights come
# from the local checkpoint directory and reuse the config prepared earlier.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
model = Longformer.from_pretrained(
    f'{root_dir}longformer-base-4096/',
    config=config,
)

# Raise the tokenizer's length limit to the size of the model's
# position-embedding table, as reported by the loaded config.
tokenizer.model_max_length = model.config.max_position_embeddings

In [7]:
# Load the example records that sit next to the checkpoint under root_dir.
df = pd.read_csv(f'{root_dir}example_data.csv')

In [10]:
# Use the value in the first row, second column of the example data as input.
SAMPLE_TEXT = df.iloc[0, 1]

# Nesting the token ids in an outer list gives a 2-D tensor: a batch of size 1.
# NOTE(review): no truncation is requested here; presumably a text longer than
# the model's maximum sequence length would fail downstream -- confirm.
input_ids = torch.tensor([tokenizer.encode(SAMPLE_TEXT)])

# The TVM kernel does not run on CPU. When `config.attention_mode = 'tvm'`,
# uncomment the next line to move both the model and the inputs to the GPU.
# model = model.cuda(); input_ids = input_ids.cuda()

In [11]:
# Attention mask codes -- 0: no attention, 1: local attention, 2: global attention.
# Start with local attention on every position; ones_like keeps the mask on the
# same device and with the same shape as input_ids.
attention_mask = torch.ones_like(input_ids, dtype=torch.long)

# Optionally mark task-specific positions as global attention, for example:
#   classification: the <s> token
#   QA: question tokens
# attention_mask[:, [1, 4, 21,]] =  2

In [13]:
# 'sliding_chunks' attention needs the sequence length padded to the nearest
# multiple of 512. Both tensors are replaced by their padded versions; the
# extra input positions are filled with the tokenizer's pad token id.
input_ids, attention_mask = pad_to_window_size(
    input_ids,
    attention_mask,
    config.attention_window[0],
    tokenizer.pad_token_id,
)

In [16]:
# Inference only: run the forward pass with autograd disabled. Without this,
# the returned tensors carry a grad_fn and keep the whole computation graph
# alive, which costs memory and is not needed just to read the embeddings.
with torch.no_grad():
    output = model(input_ids, attention_mask=attention_mask)

In [27]:
# Show the first item of the first model output for the single example in the
# batch (a 2-D tensor).
# NOTE(review): presumably output[0] holds per-token hidden states shaped
# (batch, seq_len, hidden), and identical trailing rows correspond to the
# padding positions added before the forward pass -- confirm.
output[0][0]

tensor([[-0.0717,  0.0684, -0.0172,  ..., -0.0503, -0.0049, -0.0549],
        [-0.1447, -0.0408, -0.1058,  ..., -0.2911,  0.0543,  0.0252],
        [-0.0500,  0.1325,  0.1291,  ..., -0.3074,  0.2530,  0.0613],
        ...,
        [-0.0236,  0.0741, -0.0145,  ..., -0.0990, -0.0409, -0.0745],
        [-0.0236,  0.0741, -0.0145,  ..., -0.0990, -0.0409, -0.0745],
        [-0.0236,  0.0741, -0.0145,  ..., -0.0990, -0.0409, -0.0745]],
       grad_fn=<SelectBackward0>)

In [None]:
#Github Dev Test