In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from datasets import load_dataset

In [2]:
def preprocess_function(batch):
    responses = [str(response) if response else "" for response in batch["s1"]]
    labels = [str(label) if label else "" for label in batch["s2"]]
    return tokenizer(responses, labels, padding="max_length", max_length=512, truncation=False)

In [16]:
device = torch.device("cpu") #"cuda:0" if torch.cuda.is_available() else
model_name = 'roberta-base'
path_to_fine_tuned_model = './roberta_finetuned_models/checkpoint-2500_roberta'
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(path_to_fine_tuned_model).to(device)



In [24]:
test_data = load_dataset('json', data_files={'test':'blocksworld_topn.jsonl'})
test_data = test_data['test'].map(preprocess_function, batched=True)
test_data.set_format(type='torch', columns=['input_ids', 'attention_mask'])

# Initialize DataLoader for test dataset
batch_size = 128  # Define the batch size for inference
test_dataloader = DataLoader(test_data, batch_size=batch_size)

In [26]:
# get the predictions and probabilities
predictions = []
probabilities = []
model.eval()
for batch in test_dataloader:
    with torch.no_grad():
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        probs = torch.softmax(logits, dim=-1)
        predictions.extend(torch.argmax(probs, dim=-1).cpu().numpy())
        probabilities.extend(probs.cpu().numpy())
        
        print(predictions)
        print(probabilities)

[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]
[array([9.9997330e-01, 2.6747693e-05], dtype=float32), array([9.9997211e-01, 2.7922026e-05], dtype=float32), array([1.1189043e-04, 9.9988806e-01], dtype=float32), array([9.999721e-01, 2.795240e-05], dtype=float32), array([9.9997139e-01, 2.8605536e-05], dtype=float32), array([9.9997282e-01, 2.7151438e-05], dtype=float32), array([9.9997151e-01, 2.8447696e-05], dtype=float32), array([9.9991524e-01, 8.4726162e-05], dtype=float32), array([9.999664e-01, 3.361464e-05], dtype=float32), array([9.999597e-01, 4.025226e-05], dtype=float32), array([9.9996269e-01, 3.7311496e-05], dtype=float32), array([9.9996293e-01, 3.71

In [27]:
predictions

[0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 0,
 1,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 1,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 0]