In [30]:
import torch
import pandas as pd

In [98]:
class VerafilesDataset(torch.utils.data.Dataset):
    """Map-style dataset backed by a CSV file with ``text`` and ``label`` columns.

    Each item is a dict ``{"text": ..., "label": <int>}`` where the integer is
    the position of the row's label in first-appearance order within the file.

    Args:
        root: Path (or anything ``pandas.read_csv`` accepts) of the CSV file.
        transform: Optional callable applied to the text of every sample
            before it is returned. ``None`` (the default) returns the text
            unchanged.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, root, transform=None):
        super().__init__()

        self.root = root
        self.transform = transform
        self.df = pd.read_csv(root)

        # ``unique()`` keeps first-appearance order, so the label ids are
        # stable for a given file.
        self.idx2label = dict(enumerate(self.df["label"].unique()))
        self.label2idx = {label: idx for idx, label in self.idx2label.items()}

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the sample at positional index ``idx`` as a dict.

        Returns:
            ``{"text": text, "label": label_id}``; ``text`` has been passed
            through ``self.transform`` when one was supplied.
        """
        row = self.df.iloc[idx]

        text = row["text"]
        # Previously the transform was stored but silently ignored.
        if self.transform is not None:
            text = self.transform(text)

        return {
            "text": text,
            "label": self.label2idx[row["label"]],
        }

In [92]:
class XFactDataset(torch.utils.data.Dataset):
    """Map-style dataset backed by a TSV file with ``claim`` and ``label`` columns.

    At construction the table is reduced to at most ``examples`` rows per
    label, grouped label by label in first-appearance order. Each item is a
    dict ``{"text": ..., "label": <int>}`` where ``text`` comes from the
    ``claim`` column.

    Args:
        root: Path (or anything ``pandas.read_csv`` accepts) of the
            tab-separated file.
        examples: Maximum number of rows kept for each distinct label.
        transform: Optional callable applied to the claim text of every
            sample before it is returned. ``None`` (the default) returns the
            text unchanged.
    """

    def __init__(self, root, examples=20, transform=None):
        super().__init__()

        self.root = root
        self.transform = transform
        self.df = pd.read_csv(root, sep="\t")

        # ``unique()`` keeps first-appearance order, so the label ids are
        # stable for a given file.
        self.idx2label = dict(enumerate(self.df["label"].unique()))
        self.label2idx = {label: idx for idx, label in self.idx2label.items()}

        per_label = [
            self.df[self.df["label"] == label].head(examples)
            for label in self.label2idx
        ]
        # ``pd.concat`` raises on an empty list; a header-only file simply
        # yields an empty dataset instead.
        if per_label:
            self.df = pd.concat(per_label, ignore_index=True)

    def __len__(self):
        """Return the number of rows kept after per-label subsampling."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return the sample at positional index ``idx`` as a dict.

        Returns:
            ``{"text": claim, "label": label_id}``; ``claim`` has been passed
            through ``self.transform`` when one was supplied.
        """
        row = self.df.iloc[idx]

        text = row["claim"]
        # Previously the transform was stored but silently ignored.
        if self.transform is not None:
            text = self.transform(text)

        return {
            "text": text,
            "label": self.label2idx[row["label"]],
        }

In [33]:
from transformers import AutoTokenizer

# Hugging Face checkpoint name; the tokenizer loaded from it is the one
# `collate_fn` uses to encode batches.
TOKENIZER_CHECKPOINT = "xlm-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_CHECKPOINT)

In [34]:
def collate_fn(batch):
    """Turn a list of dataset samples into one model-ready batch.

    Args:
        batch: List of dicts, each with a ``"text"`` entry and an integer
            ``"label"`` entry (the items produced by the datasets above).

    Returns:
        Dict with ``"input_ids"`` and ``"attention_mask"`` as produced by the
        module-level ``tokenizer`` (padded to the longest text in the batch,
        truncated to the tokenizer's maximum length, returned as PyTorch
        tensors) and ``"labels"`` as a tensor built from the sample labels.
    """
    texts = [sample["text"] for sample in batch]
    labels = [sample["label"] for sample in batch]

    # Without padding/return_tensors the tokenizer returns ragged Python
    # lists, which cannot be stacked or passed to a model alongside the
    # tensor-valued labels; truncation guards against over-long inputs.
    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        return_tensors="pt",
    )

    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "labels": torch.tensor(labels),
    }

In [99]:
# NOTE(review): the path is spelled "vefafiles.csv" while the class is named
# "Verafiles" — confirm this matches the file on disk.
vf_ds = VerafilesDataset("vefafiles.csv")

# Batches of 8 samples, assembled by the shared `collate_fn`.
vf_dl = torch.utils.data.DataLoader(
    dataset=vf_ds,
    batch_size=8,
    collate_fn=collate_fn,
)

In [100]:
# Uses the dataset's default per-label cap on the number of examples.
xf_ds = XFactDataset(root="xlm_fakenews/english_test.tsv")

# Batches of 8 samples, assembled by the shared `collate_fn`.
xf_dl = torch.utils.data.DataLoader(
    dataset=xf_ds,
    batch_size=8,
    collate_fn=collate_fn,
)