In [1]:
import torch
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaTokenizer, RobertaForSequenceClassification

# Load the saved model
model_path = "Models/roberta_classifier.pth"
model = RobertaForSequenceClassification.from_pretrained('roberta-base', num_labels=3)
model.load_state_dict(torch.load(model_path))
model.eval()

# Ensure the model is on the correct device
device = torch.device("mps")
model.to(device)

# Custom dataset class for inference
class CustomDataset(Dataset):
    def __init__(self, texts, tokenizer, max_len):
        self.texts = texts
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]

        encoding = self.tokenizer.encode_plus(
            text,
            max_length=self.max_len,
            add_special_tokens=True,
            truncation=True,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten()
        }

# Load data
df = pd.read_csv("Data/persuasion_data.csv")  
arguments = df["argument"].tolist()

# Tokenizer
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Create dataset and dataloader for inference
max_len_tokenizer = 512
inference_dataset = CustomDataset(arguments, tokenizer, max_len_tokenizer)
inference_loader = DataLoader(inference_dataset, batch_size=4, shuffle=False)

# Make predictions
predictions = []

with torch.no_grad():
    for batch in inference_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=-1)
        preds = preds + 1  # Convert predictions to 1, 2, 3
        predictions.extend(preds.cpu().numpy())

# Add predictions to DataFrame
df["predictions"] = predictions

# Save the DataFrame with predictions
#output_path = "Data/arguments_to_annotate_preds_classification.csv"
#output_path = "Data/persuasion_data_preds_classification.csv"
#df.to_csv(output_path, index=False)

df.head()


  from .autonotebook import tqdm as notebook_tqdm
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Unnamed: 0,worker_id,claim,argument,source,prompt_type,rating_initial,rating_final,persuasiveness_metric,predictions
0,PQVTZECGNK3K,Governments and technology companies must do m...,It's time for governments and tech companies t...,Claude 2,Expert Writer Rhetorics,7 - Strongly support,7 - Strongly support,0,2
1,3KTT9HNPV9WX,Governments and technology companies must do m...,"In today's hyper-connected world, our personal...",Claude 3 Haiku,Expert Writer Rhetorics,7 - Strongly support,7 - Strongly support,0,2
2,M76GMRF46C69,Cultured/lab-grown meats should be allowed to ...,The future of food must include cultured/lab-g...,Claude 2,Compelling Case,3 - Somewhat oppose,5 - Somewhat support,2,2
3,3W4KKCTPTP7R,Social media companies should be required to l...,Social media companies should be required to l...,Claude 2,Compelling Case,3 - Somewhat oppose,6 - Support,3,2
4,QQDKMRY3HRXJ,Employers should be allowed to monitor employe...,Allowing employers to monitor employees throug...,Claude 3 Opus,Logical Reasoning,5 - Somewhat support,5 - Somewhat support,0,2


In [9]:
df["predictions"].value_counts

<bound method IndexOpsMixin.value_counts of 0       2
1       2
2       2
3       2
4       2
       ..
3934    2
3935    2
3936    2
3937    2
3938    2
Name: predictions, Length: 3939, dtype: int64>