* the base of this modell https://github.com/openai/gpt-2-output-dataset
* you should run this notebook from the root

In [2]:
from transformers import RobertaForSequenceClassification, RobertaTokenizer
import torch

In [3]:
# evaluate the ipnut 'query' with a pytorch model, in 'device'
def eval(query, tokenizer, model, device):
    
    tokens = tokenizer.encode(query)
    all_tokens = len(tokens)
    #RoBERTa is based on BERT so it can't handle longer sequences then 512
    tokens = tokens[:tokenizer.max_len - 2]
    used_tokens = len(tokens)
    # add begin of the sentence, and end of sentenc tokens
    tokens = torch.tensor([tokenizer.bos_token_id] + tokens + [tokenizer.eos_token_id]).unsqueeze(0)
    mask = torch.ones_like(tokens)

    with torch.no_grad():
        tokens = tokens.to(device)
        attention_mask=mask.to(device)
        logits = model(tokens, attention_mask)[0]
        probs = logits.softmax(dim=-1)

    fake, real = probs.detach().cpu().flatten().numpy().tolist()

    if(fake>real):
        return("fake")
    else:
        return("real")


In [4]:
checkpoint = "./models/base/detector-base.pt"

if torch.cuda.is_available():
    device='cuda'
else:
    device='cpu'

# load the model 
data = torch.load(checkpoint, map_location='cpu')
model_name = 'roberta-large' if data['args']['large'] else 'roberta-base'
model = RobertaForSequenceClassification.from_pretrained(model_name)
tokenizer = RobertaTokenizer.from_pretrained(model_name)
model.load_state_dict(data['model_state_dict'])
model.eval()
model = model.to(device)

### Testing

In [11]:
# if machine is true we read the machine vs real task, else we read the fake vs real task

ef read_data(machine):
    if(machine):
        with open("../test_data/gpt2_generated.txt", "r") as gpt2, open("../test_data/grover_generated.txt", "r") as grover:
            X_test = [line for line in gpt2]+[line for line in grover]
        
        with open("../test_data/x_test.txt") as data, open("../test_data/y_test.txt") as label_file:
            index = 0
            labels = [label.strip() for label in label_file]
            for i, line in enumerate(data):
                if(index==180):
                    break
                if(labels[i]=="0"):
                    X_test.append(line)
                    index += 1
        Y_test = ["fake" if i < 180 else "real" for i in range(360)]
    else:
        with open("../test_data/x_test.txt") as f:
            X_test = [line for line in f]
        with open("../test_data/y_test.txt") as f:
            Y_test = ["fake" if line.strip() == "1" else "real" for line in f]
    return X_test, Y_test

In [21]:
X_test, Y_test = read_data(False)

In [22]:
%%capture
from sklearn.metrics import classification_report

result = []
for article in X_test:
    result.append(eval(article, tokenizer, model, device))

In [23]:
# Fake vs real news
print(classification_report(Y_test, result))

              precision    recall  f1-score   support

        fake       0.60      0.12      0.20      4947
        real       0.52      0.92      0.66      5053

    accuracy                           0.53     10000
   macro avg       0.56      0.52      0.43     10000
weighted avg       0.56      0.53      0.44     10000



In [20]:
# Machine vs real news
# print(classification_report(Y_test, result))

              precision    recall  f1-score   support

        fake       0.90      0.54      0.68       180
        real       0.67      0.94      0.78       180

    accuracy                           0.74       360
   macro avg       0.79      0.74      0.73       360
weighted avg       0.79      0.74      0.73       360



In [18]:
# Only GPT-2 generated, and real news
#print(classification_report(Y_test[:100]+Y_test[180:], result))

              precision    recall  f1-score   support

        fake       0.90      0.98      0.94       100
        real       0.99      0.94      0.96       180

    accuracy                           0.95       280
   macro avg       0.94      0.96      0.95       280
weighted avg       0.96      0.95      0.95       280

