# Journalism Guidance + LM

## Requirements

In [None]:
# Install Hugging Face Transformers together with its `sentencepiece` extra.
# NOTE(review): this install is unpinned, while a later cell pins transformers==4.3.2;
# whichever cell runs last decides the version that is actually imported.
!pip install transformers[sentencepiece]

Collecting sentencepiece!=0.1.92,>=0.1.91 (from transformers[sentencepiece])
  Downloading sentencepiece-0.1.99-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m11.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.1.99


In [None]:
# Install NLTK and the py-readability-metrics package, then fetch NLTK's `punkt` data.
!pip install nltk

!pip install py-readability-metrics
!python -m nltk.downloader punkt

Collecting py-readability-metrics
  Downloading py_readability_metrics-1.4.5-py3-none-any.whl (26 kB)
Installing collected packages: py-readability-metrics
Successfully installed py-readability-metrics-1.4.5
[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


In [None]:
# Install the lexicalrichness package.
# NOTE(review): it is not imported in the cells visible here; presumably it is used when
# computing the stylometric features -- confirm before removing.
!pip install lexicalrichness

Collecting lexicalrichness
  Downloading lexicalrichness-0.5.1.tar.gz (97 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m97.8/97.8 kB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: lexicalrichness
  Building wheel for lexicalrichness (setup.py) ... [?25l[?25hdone
  Created wheel for lexicalrichness: filename=lexicalrichness-0.5.1-py3-none-any.whl size=15414 sha256=b680e061b9f54a91f5edf31042d6a77c7b84e6f41add6bcf60592a0b39c5469f
  Stored in directory: /root/.cache/pip/wheels/cd/ba/80/d4dabc1bf242a672ffc00226a2303a7471bb841c0872b2c212
Successfully built lexicalrichness
Installing collected packages: lexicalrichness
Successfully installed lexicalrichness-0.5.1


In [None]:
# Pinned environment for the detector code.
# Requirement specifiers containing `>` or `<` must be quoted: unquoted, the shell treats
# `>=5.4` as an output redirection, installs the *unconstrained* package and writes pip's
# output into a file literally named `=5.4`.
# NOTE(review): several pins (e.g. numpy 1.19.4, torch 1.7.0) look older than the Python 3.10
# runtime shown in the earlier cell output -- confirm they still resolve on your runtime.
!pip install numpy==1.19.4
!pip install "PyYAML>=5.4"
!pip install spacy==2.2.4
!pip install torch==1.7.0
!pip install torchtext==0.3.1
!pip install tqdm==4.53.0
!pip install pandas==1.1.5
!pip install transformers==4.3.2
!pip install fire==0.4.0
!pip install requests==2.23.0
!pip install tensorboard==2.4.1
!pip install download==0.3.5
!pip install "nltk>=3.6.6"

# !pip install py-readability-metrics
!python -m nltk.downloader punkt
# !pip install lexicalrichness

## Utils

In [None]:
import sys
from functools import reduce

from torch import nn
import torch.distributed as dist


def summary(model: nn.Module, file=sys.stdout):
    """Print a ``repr``-style tree of ``model`` annotated with parameter counts.

    Every module line is suffixed with the number of parameters registered on
    that module and on all of its children.

    Args:
        model: Module to describe.
        file: Where to write the text. Either an open text stream, a path
            string (the file is created/truncated, written and closed again),
            or ``None`` to skip printing altogether. The counts are wrapped in
            ANSI colour codes only when ``file`` is ``sys.stdout``.

    Returns:
        int: Number of parameters registered on ``model`` and its sub-modules.
        A parameter registered on several modules is counted once per module.
    """
    # Decide on colouring up front, before ``file`` may be swapped for a handle.
    colourise = file is sys.stdout

    def describe(module):
        """Return ``(text, parameter_count)`` for ``module`` and its children."""
        # A child slot can legitimately hold None; mirror what nn.Module's own
        # repr shows instead of crashing on ``None.extra_repr()``.
        if module is None:
            return 'None', 0

        # We treat the extra repr like the sub-module, one item per line.
        extra_repr = module.extra_repr()
        # An empty string would otherwise be split into [''].
        extra_lines = extra_repr.split('\n') if extra_repr else []

        child_lines = []
        total_params = 0
        for key, child in module._modules.items():
            child_str, child_params = describe(child)
            child_str = nn.modules.module._addindent(child_str, 2)
            child_lines.append('(' + key + '): ' + child_str)
            total_params += child_params

        # numel() also handles 0-dim (scalar) parameters, for which a reduce()
        # over the empty shape would raise.
        for param in module._parameters.values():
            if param is not None:
                total_params += param.numel()

        lines = extra_lines + child_lines
        main_str = module._get_name() + '('
        if lines:
            # Simple one-liner info, which most builtin Modules will use.
            if len(extra_lines) == 1 and not child_lines:
                main_str += extra_lines[0]
            else:
                main_str += '\n  ' + '\n  '.join(lines) + '\n'
        main_str += ')'

        if colourise:
            main_str += ', \033[92m{:,}\033[0m params'.format(total_params)
        else:
            main_str += ', {:,} params'.format(total_params)
        return main_str, total_params

    string, count = describe(model)

    if file is None:
        return count

    if isinstance(file, str):
        # We opened it, so we close it.
        with open(file, 'w') as handle:
            print(string, file=handle)
    else:
        print(string, file=file)
        file.flush()

    return count


def grad_norm(model: nn.Module):
    """Return the global L2 norm of all gradients currently stored on ``model``.

    Parameters without a gradient (frozen weights, or weights that did not take
    part in the last backward pass) are skipped instead of raising.

    Args:
        model: Module whose ``.grad`` tensors are measured.

    Returns:
        float: Square root of the summed squared per-parameter gradient norms;
        ``0.0`` when no parameter has a gradient.
    """
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is None:
            continue
        param_norm = p.grad.detach().norm(2)
        total_norm += param_norm.item() ** 2
    return total_norm ** 0.5


def distributed():
    """Tell whether a ``torch.distributed`` process group is up and running."""
    if not dist.is_available():
        return False
    return dist.is_initialized()

## Data Loader

In [None]:
import json
from typing import List

import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import PreTrainedTokenizer

import re
import unicodedata

import nltk
from nltk.corpus import stopwords
from nltk.tag import pos_tag
# from pycontractions import Contractions
# Fetch the NLTK resources for the tools used by PreProcess below:
# the perceptron tagger (pos_tag), the stop-word lists (stopwords) and WordNet (WordNetLemmatizer).
nltk.download('averaged_perceptron_tagger')
nltk.download('stopwords')
nltk.download('wordnet')


# Contraction -> expansion lookup used by PreProcess.expand_contractions.
# Keys are case-sensitive; a few pronoun forms are listed in both cases.
CONTRACTION_MAP = {
    "ain't": "is not",
    "aren't": "are not",
    "can't": "cannot",
    "can't've": "cannot have",
    "'cause": "because",
    "could've": "could have",
    "couldn't": "could not",
    "couldn't've": "could not have",
    "didn't": "did not",
    "doesn't": "does not",
    "don't": "do not",
    "hadn't": "had not",
    "hadn't've": "had not have",
    "hasn't": "has not",
    "haven't": "have not",
    "he'd": "he would",
    "he'd've": "he would have",
    "he'll": "he will",
    "he'll've": "he will have",  # was "he he will have" (duplicated word)
    "he's": "he is",
    "how'd": "how did",
    "how'd'y": "how do you",
    "how'll": "how will",
    "how's": "how is",
    "I'd": "I would",
    "I ain't": "I am not",
    "I'd've": "I would have",
    "I'll": "I will",
    "I'll've": "I will have",
    "I'm": "I am",
    "I've": "I have",
    "i'd": "i would",
    "i'd've": "i would have",
    "i'll": "i will",
    "i'll've": "i will have",
    "i'm": "i am",
    "i've": "i have",
    "isn't": "is not",
    "it'd": "it would",
    "it'd've": "it would have",
    "it'll": "it will",
    "it'll've": "it will have",
    "it's": "it is",
    "let's": "let us",
    "ma'am": "madam",
    "mayn't": "may not",
    "might've": "might have",
    "mightn't": "might not",
    "mightn't've": "might not have",
    "must've": "must have",
    "mustn't": "must not",
    "mustn't've": "must not have",
    "needn't": "need not",
    "needn't've": "need not have",
    "o'clock": "of the clock",
    "oughtn't": "ought not",
    "oughtn't've": "ought not have",
    "shan't": "shall not",
    "sha'n't": "shall not",
    "shan't've": "shall not have",
    "she'd": "she would",
    "she'd've": "she would have",
    "she'll": "she will",
    "she'll've": "she will have",
    "she's": "she is",
    "should've": "should have",
    "shouldn't": "should not",
    "shouldn't've": "should not have",
    "so've": "so have",
    "so's": "so as",
    "that'd": "that would",
    "that'd've": "that would have",
    "that's": "that is",
    "there'd": "there would",
    "there'd've": "there would have",
    "there's": "there is",
    "they'd": "they would",
    "they'd've": "they would have",
    "they'll": "they will",
    "they'll've": "they will have",
    "they're": "they are",
    "they've": "they have",
    "to've": "to have",
    "wasn't": "was not",
    "we'd": "we would",
    "we'd've": "we would have",
    "we'll": "we will",
    "we'll've": "we will have",
    "we're": "we are",
    "we've": "we have",
    "weren't": "were not",
    "what'll": "what will",
    "what'll've": "what will have",
    "what're": "what are",
    "what's": "what is",
    "what've": "what have",
    "when's": "when is",
    "when've": "when have",
    "where'd": "where did",
    "where's": "where is",
    "where've": "where have",
    "who'll": "who will",
    "who'll've": "who will have",
    "who's": "who is",
    "who've": "who have",
    "why's": "why is",
    "why've": "why have",
    "will've": "will have",
    "won't": "will not",
    "won't've": "will not have",
    "would've": "would have",
    "wouldn't": "would not",
    "wouldn't've": "would not have",
    "y'all": "you all",
    "y'all'd": "you all would",
    "y'all'd've": "you all would have",
    "y'all're": "you all are",
    "y'all've": "you all have",
    "you'd": "you would",
    "you'd've": "you would have",
    "you'll": "you will",
    "you'll've": "you will have",
    "you're": "you are",
    "you've": "you have",
}


class PreProcess:
    """Configurable text-normalisation pipeline.

    Each ``*_norm`` constructor flag switches one step of :meth:`fit` on or off.
    The list-based helper methods take a list of strings and return a list of
    cleaned strings of the same length; :meth:`fit` wraps a single document.
    """

    # Compiled once: anything that looks like a markup tag, and runs of non-ASCII characters.
    _TAG_PATTERN = re.compile('<.*?>')
    _NON_ASCII_PATTERN = re.compile(r'[^\x00-\x7F]+')

    def __init__(self, lowercase_norm=False, period_norm=False, special_chars_norm=False, accented_norm=False, contractions_norm=False,
                 stemming_norm=False, lemma_norm=False, stopword_norm=False, proper_norm=False):
        """Store which normalisation steps :meth:`fit` should apply (all off by default)."""
        self.lowercase_norm = lowercase_norm
        self.period_norm = period_norm
        self.special_chars_norm = special_chars_norm
        self.accented_norm = accented_norm
        self.contractions_norm = contractions_norm
        self.stemming_norm = stemming_norm
        self.lemma_norm = lemma_norm
        self.stopword_norm = stopword_norm
        self.proper_norm = proper_norm

    def lowercase_normalization(self, data):
        """Return the string ``data`` lower-cased."""
        return data.lower()

    def period_remove(self, data):
        """Return the string ``data`` with every ``.`` replaced by a space."""
        return data.replace(".", " ")

    def special_char_remove(self, data, remove_digits=False):
        """Strip markup tags, non-ASCII characters, backslashes and commas; turn hyphens into spaces.

        Args:
            data: List of strings.
            remove_digits: Accepted for interface compatibility; currently unused.

        Returns:
            list[str]: One cleaned string per input, whitespace-collapsed at the
            token level (tokens are re-joined with single spaces).
        """
        cleaned = []
        for words in self.tokenization(data):
            # join() leaves no trailing blank (the old ``sentence.rstrip()`` discarded its result).
            sentence = " ".join(words)
            sentence = self._TAG_PATTERN.sub('', sentence)
            sentence = self._NON_ASCII_PATTERN.sub('', sentence)
            sentence = sentence.replace("\\", "").replace("-", " ").replace(",", "")
            cleaned.append(sentence)
        return cleaned

    def accented_word_normalization(self, data):
        """Fold accented characters to ASCII via NFKD decomposition.

        Characters without an ASCII decomposition are dropped.

        Args:
            data: List of strings.

        Returns:
            list[str]: One ASCII-only string per input.
        """
        folded = []
        for words in self.tokenization(data):
            sentence = " ".join(words)
            ascii_bytes = unicodedata.normalize('NFKD', sentence).encode('ascii', 'ignore')
            folded.append(ascii_bytes.decode('utf-8', 'ignore'))
        return folded

    def expand_contractions(self, data, pycontrct=False):
        """Expand contractions listed in ``CONTRACTION_MAP`` and drop remaining apostrophes.

        Matching is case-insensitive and prefers the longest contraction
        (``can't've`` before ``can't``). The case of the first matched
        character is carried over to the expansion.

        Args:
            data: List of strings.
            pycontrct: Accepted for interface compatibility; currently unused.

        Returns:
            list[str]: One expanded string per input.
        """
        contraction_mapping = CONTRACTION_MAP

        # Case-insensitive fallback; the first spelling of a key wins.
        lowered_mapping = {}
        for key, value in contraction_mapping.items():
            lowered_mapping.setdefault(key.lower(), value)

        # Longest alternative first, otherwise "can't" would shadow "can't've".
        ordered_keys = sorted(contraction_mapping, key=len, reverse=True)
        contractions_pattern = re.compile(
            '({})'.format('|'.join(re.escape(key) for key in ordered_keys)),
            flags=re.IGNORECASE | re.DOTALL)

        def expand_match(contraction):
            match = contraction.group(0)
            expanded = contraction_mapping.get(match) or lowered_mapping.get(match.lower())
            if expanded is None:
                # Case-folding quirk matched something we cannot look up: leave it alone.
                return match
            if match[0].lower() == expanded[0].lower():
                # Same leading letter: keep the writer's capitalisation.
                return match[0] + expanded[1:]
            # Different leading character ("ain't" -> "is not", "'cause" -> "because"):
            # only transfer the case, never the character itself.
            if match[0].isupper():
                return expanded[0].upper() + expanded[1:]
            return expanded

        expanded_data = []
        for words in self.tokenization(data):
            sentence = " ".join(words)
            expanded_text = contractions_pattern.sub(expand_match, sentence)
            expanded_data.append(expanded_text.replace("'", ""))
        return expanded_data

    def stemming(self, data):
        """Apply NLTK's Porter stemmer to every whitespace token of every string in ``data``."""
        stemmer = nltk.stem.PorterStemmer()
        return [" ".join(stemmer.stem(word) for word in words)
                for words in self.tokenization(data)]

    def lemmatization(self, data):
        """Apply NLTK's WordNet lemmatizer to every whitespace token of every string in ``data``."""
        lemma = nltk.stem.WordNetLemmatizer()
        return [" ".join(lemma.lemmatize(word) for word in words)
                for words in self.tokenization(data)]

    def stopword_remove(self, data):
        """Remove English stop words (case-insensitive match against NLTK's list).

        Each kept token is prefixed with a space, so a non-empty result starts
        with a blank; this is kept as-is because downstream tokenisation sees it.
        """
        stop_words = set(stopwords.words('english'))
        return ["".join(" " + word for word in words if word.lower() not in stop_words)
                for words in self.tokenization(data)]

    def remove_proper_nouns(self, data):
        """Remove every token that ``pos_tag`` labels ``NNP`` anywhere in its sentence.

        Each kept token is prefixed with a space, so a non-empty result starts
        with a blank; this is kept as-is because downstream tokenisation sees it.
        """
        common_words = []
        for words in self.tokenization(data):
            proper_nouns = {word for word, pos in pos_tag(words) if pos == 'NNP'}
            common_words.append("".join(" " + word for word in words if word not in proper_nouns))
        return common_words

    def tokenization(self, data):
        """Split every string in ``data`` on runs of whitespace.

        ``str.split()`` is what NLTK's ``WhitespaceTokenizer`` documentation
        recommends using instead of the tokenizer itself.
        """
        return [text.split() for text in data]

    def fit(self, data):
        """Run the enabled normalisation steps on one document.

        Order: special characters, contractions, accents, stemming, proper
        nouns, stop words, lemmatisation, lower-casing, period removal.

        Args:
            data: The document; it is converted with ``str()`` first.

        Returns:
            str: The normalised text.
        """
        data = [str(data)]

        # NOTE(review): special_char_remove deletes non-ASCII characters, so with both flags
        # enabled accented letters are gone before accent folding runs. Order kept unchanged.
        if self.special_chars_norm:
            data = self.special_char_remove(data, remove_digits=False)

        # The flag used to be accepted but silently ignored.
        if self.contractions_norm:
            data = self.expand_contractions(data)

        if self.accented_norm:
            data = self.accented_word_normalization(data)

        if self.stemming_norm:
            data = self.stemming(data)

        if self.proper_norm:
            data = self.remove_proper_nouns(data)

        if self.stopword_norm:
            data = self.stopword_remove(data)

        if self.lemma_norm:
            data = self.lemmatization(data)

        data = data[0]

        if self.lowercase_norm:
            data = self.lowercase_normalization(str(data))

        if self.period_norm:
            data = self.period_remove(str(data))

        return data


class EncodedDataset(Dataset):
    """Labelled dataset yielding tokenised text plus a stylometric feature vector.

    Args:
        texts: Raw documents.
        labels: One label per document; returned as ``int``.
        stylo_features: One feature row per document. Rows may be NumPy
            arrays (anything exposing ``tolist()``) or plain sequences.
        tokenizer: Tokenizer called with ``padding='max_length'`` and ``truncation=True``.
        max_sequence_length: Passed to the tokenizer as ``max_length``.
        min_sequence_length: Stored but not used by this class.
    """

    def __init__(self, texts: List[str], labels: List[int], stylo_features: List[List], tokenizer: PreTrainedTokenizer,
                 max_sequence_length: int = None, min_sequence_length: int = None):
        self.texts = texts
        self.labels = labels
        self.stylo_features = stylo_features
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.min_sequence_length = min_sequence_length
        # The cleaning configuration never changes, so build it once instead of per item.
        self.preprocessor = PreProcess(special_chars_norm=True, lowercase_norm=True, period_norm=True,
                                       proper_norm=True, accented_norm=True)

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask, stylometric_features, label)`` for one document."""
        text = self.preprocessor.fit(self.texts[index])
        label = self.labels[index]

        feature_row = self.stylo_features[index]
        # Accept NumPy rows as well as the plain lists the type hint advertises.
        style_feat = feature_row.tolist() if hasattr(feature_row, 'tolist') else list(feature_row)

        padded_sequences = self.tokenizer(text, padding='max_length', max_length=self.max_sequence_length,
                                          truncation=True)

        return (torch.tensor(padded_sequences['input_ids']),
                torch.tensor(padded_sequences['attention_mask']),
                torch.tensor(style_feat),
                int(label))


class EncodeEvalData(Dataset):
    """Unlabelled dataset yielding tokenised text plus a stylometric feature vector.

    Args:
        input_texts: Raw documents.
        stylo_features: One feature row per document. Rows may be NumPy
            arrays (anything exposing ``tolist()``) or plain sequences.
        tokenizer: Tokenizer called with ``padding='max_length'`` and ``truncation=True``.
        max_sequence_length: Passed to the tokenizer as ``max_length``.
        min_sequence_length: Stored but not used by this class.
    """

    def __init__(self, input_texts: List[str], stylo_features: List[List], tokenizer: PreTrainedTokenizer,
                 max_sequence_length: int = None, min_sequence_length: int = None):

        self.input_texts = input_texts
        self.tokenizer = tokenizer
        self.max_sequence_length = max_sequence_length
        self.min_sequence_length = min_sequence_length
        self.stylo_features = stylo_features
        # The cleaning configuration never changes, so build it once instead of per item.
        self.preprocessor = PreProcess(special_chars_norm=True, lowercase_norm=True, period_norm=True,
                                       proper_norm=True, accented_norm=True)

    def __len__(self):
        return len(self.input_texts)

    def __getitem__(self, index):
        """Return ``(input_ids, attention_mask, stylometric_features)`` for one document."""
        text = self.preprocessor.fit(self.input_texts[index])

        feature_row = self.stylo_features[index]
        # Accept NumPy rows as well as the plain lists the type hint advertises.
        stylo_features = feature_row.tolist() if hasattr(feature_row, 'tolist') else list(feature_row)

        padded_sequences = self.tokenizer(text, padding='max_length', max_length=self.max_sequence_length,
                                          truncation=True)

        return (torch.tensor(padded_sequences['input_ids']),
                torch.tensor(padded_sequences['attention_mask']),
                torch.tensor(stylo_features))


## Model Code

### Base AI Detector Model

In [None]:
import torch
from torch.nn import Softmax
from torch.nn import CrossEntropyLoss, MSELoss
from typing import Optional, Tuple

from transformers import RobertaForSequenceClassification

from transformers.modeling_outputs import SequenceClassifierOutput

from dataclasses import dataclass

from torch.nn.functional import normalize

@dataclass
class SequenceClassifierOutputWithLastLayer(SequenceClassifierOutput):
    """Sequence-classification output that additionally carries the encoder's last hidden state.

    Returned by ``RobertaForFusion.forward`` when ``return_dict`` is enabled.
    Every field defaults to ``None``.
    """

    # Loss computed by the model when labels were supplied, else None.
    loss: Optional[torch.FloatTensor] = None
    # Classification scores; RobertaForFusion stores its softmax output here, not raw logits.
    logits: torch.FloatTensor = None
    # First element of the encoder output (the per-token hidden states of the final layer).
    last_hidden_state: torch.FloatTensor = None
    # Taken from the encoder output's `hidden_states` attribute.
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    # Taken from the encoder output's `attentions` attribute.
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class RobertaForFusion(RobertaForSequenceClassification):
    """RoBERTa sequence classifier that also hands back the encoder's last hidden state.

    The scores returned to the caller are passed through a softmax over
    dimension 1; the optional loss is still computed on the raw head output.
    """
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)

        self.soft_max = Softmax(dim=1)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Targets for the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. With :obj:`config.num_labels == 1` a regression loss (Mean-Square) is
            computed, with :obj:`config.num_labels > 1` a classification loss (Cross-Entropy).

        Returns a :class:`SequenceClassifierOutputWithLastLayer` when ``return_dict`` is enabled, otherwise a
        tuple ``(loss?, softmax_scores, *encoder_extras)`` where the loss is present only if labels were given.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        logits = self.classifier(sequence_output)

        # The loss, when requested, is taken on the un-normalised head output.
        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # Single output: regression.
                loss = MSELoss()(logits.view(-1), labels.view(-1))
            else:
                loss = CrossEntropyLoss()(logits.view(-1, self.num_labels), labels.view(-1))

        softmax_logits = self.soft_max(logits)

        if return_dict:
            return SequenceClassifierOutputWithLastLayer(
                loss=loss,
                logits=softmax_logits,
                last_hidden_state=sequence_output,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        output = (softmax_logits,) + encoder_outputs[2:]
        if loss is None:
            return output
        return (loss,) + output



### J-Guard

In [None]:
class FusedClassifier(torch.nn.Module):
    """Two-class head on top of a frozen language model fused with hand-crafted features.

    The language-model embedding and the custom feature vector are
    concatenated, L2-normalised, passed through ``guidance_head`` and then
    through ``classification_head``, whose final layer is a softmax.

    Args:
        lm: Language model; calling it must yield a mapping with a
            ``"last_hidden_state"`` entry. Its parameters are frozen here.
        device: Device the language model and both heads are moved to.
        FUSED_INPUT_SIZE: Width of the concatenated (embedding + custom features) vector.
    """

    def __init__(self, lm, device, FUSED_INPUT_SIZE):
        super().__init__()

        # The language model lives on the target device alongside the heads.
        self.lm = lm
        self.lm.to(device)

        guidance_layers = [
            nn.Linear(FUSED_INPUT_SIZE, 1024),
            nn.ReLU(),
            nn.Linear(1024, 256),
        ]
        self.guidance_head = nn.Sequential(*guidance_layers).to(device)

        classification_layers = [
            nn.Linear(256, 32),
            nn.ReLU(),
            nn.Linear(32, 2),
            nn.Softmax(dim=-1),
        ]
        self.classification_head = nn.Sequential(*classification_layers).to(device)

        # The LM is already pre-trained: exclude its weights from gradient computation.
        # NOTE(review): freezing does not switch the LM to eval mode; calling .train() on this
        # module also puts the LM in training mode -- confirm that is intended.
        for weight in self.lm.parameters():
            weight.requires_grad = False

    def forward(self, data, custom_features):
        """Return class probabilities for a batch.

        Args:
            data: Sequence ``[input_ids, attention_mask]`` or
                ``[input_ids, attention_mask, labels]``; labels, when present,
                are forwarded to the language model.
            custom_features: Tensor concatenated to the LM embedding along the last dimension.

        Returns:
            Output of ``classification_head`` (softmax over two classes).
        """
        input_ids, attention_mask = data[0], data[1]
        if len(data) >= 3:
            lm_output = self.lm(input_ids, attention_mask=attention_mask, labels=data[2])
        else:
            lm_output = self.lm(input_ids, attention_mask=attention_mask)

        # Hidden state at the final sequence position, cut off from the autograd graph.
        # NOTE(review): the datasets in this notebook pad to max_length, so the final position
        # is presumably a padding token for shorter texts -- confirm this pooling is intended.
        lm_emb_output = lm_output["last_hidden_state"][:, -1, :].detach()

        # Append the manual features to the LM features, then L2-normalise.
        fused = normalize(torch.cat((lm_emb_output, custom_features), dim=-1))

        return self.classification_head(self.guidance_head(fused))

## Train Code

### J-Guard Training

In [None]:
"""Training code for the detector model"""

import argparse
import pandas as pd
import os
import subprocess
import sys
from itertools import count
from multiprocessing import Process

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel
from torch.optim import Adam
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from tqdm import tqdm
# from transformers import *
from transformers import RobertaTokenizer

from types import SimpleNamespace

# Seed PyTorch's global random number generator.
torch.manual_seed(1000)

<torch._C.Generator at 0x7b4a696f88b0>

In [None]:
def setup_distributed(port=29500):
    """Initialise ``torch.distributed`` when more than one CUDA device is usable.

    Args:
        port: Value written to ``MASTER_PORT`` when launched under MPI.

    Returns:
        tuple: ``(rank, world_size)``; ``(0, 1)`` when distributed support,
        CUDA, or a second GPU is missing (no process group is created then).
    """
    multi_gpu = dist.is_available() and torch.cuda.is_available() and torch.cuda.device_count() > 1
    if not multi_gpu:
        return 0, 1

    launched_by_mpi = 'MPIR_CVAR_CH3_INTERFACE_HOSTNAME' in os.environ
    if not launched_by_mpi:
        # Rank and world size come from the environment.
        dist.init_process_group(backend="nccl", init_method="env://")
        return dist.get_rank(), dist.get_world_size()

    # MPI launch: take rank and size from the MPI communicator.
    from mpi4py import MPI
    communicator = MPI.COMM_WORLD
    mpi_rank = communicator.Get_rank()
    mpi_size = communicator.Get_size()

    os.environ["MASTER_ADDR"] = '127.0.0.1'
    os.environ["MASTER_PORT"] = str(port)

    dist.init_process_group(backend="nccl", world_size=mpi_size, rank=mpi_rank)
    return mpi_rank, mpi_size


def load_datasets(text_dir, stylo_dir, dataset_name, imp_feat, tokenizer, batch_size,
                  max_sequence_length, random_sequence_length=True):
    """Build the training and validation loaders for one dataset.

    Expected files (this is the layout of the previously commented-out paths;
    the old ``pd.read_csv('')`` placeholders could never succeed):

    * ``<text_dir>/<dataset_name>/CSV/train.csv`` and ``.../test.csv`` with
      ``text`` and ``label`` columns;
    * ``<stylo_dir>/<dataset_name>_train_feature.csv`` and
      ``..._test_feature.csv`` containing at least the ``imp_feat`` columns,
      row-aligned with the corresponding text file.

    Args:
        text_dir: Root directory of the text CSV files.
        stylo_dir: Directory of the stylometric feature CSV files.
        dataset_name: Dataset sub-directory / file prefix.
        imp_feat: Column selection applied to the feature frames.
        tokenizer: Tokenizer handed to ``EncodedDataset``.
        batch_size: Batch size of the training loader (validation uses 1).
        max_sequence_length: Passed on to ``EncodedDataset``.
        random_sequence_length: When true, ``min_sequence_length`` is 10, else ``None``.

    Returns:
        tuple: ``(train_loader, validation_loader)``.

    Raises:
        ValueError: If a text file and its feature file differ in row count.
    """
    csv_dir = os.path.join(text_dir, dataset_name, "CSV")
    data_train = pd.read_csv(os.path.join(csv_dir, "train.csv"))
    data_test = pd.read_csv(os.path.join(csv_dir, "test.csv"))

    stylo_feat_train = pd.read_csv(os.path.join(stylo_dir, dataset_name + "_train_feature.csv"))
    stylo_feat_test = pd.read_csv(os.path.join(stylo_dir, dataset_name + "_test_feature.csv"))

    # Texts and features are paired by position, so the row counts have to agree.
    for split, texts, feats in (("train", data_train, stylo_feat_train), ("test", data_test, stylo_feat_test)):
        if len(texts) != len(feats):
            raise ValueError("{} split: {} texts but {} stylometric feature rows".format(
                split, len(texts), len(feats)))

    Sampler = DistributedSampler if distributed() and dist.get_world_size() > 1 else RandomSampler

    min_sequence_length = 10 if random_sequence_length else None

    train_dataset = EncodedDataset(data_train.text.values, data_train.label.values,
                                   stylo_feat_train[imp_feat].values, tokenizer,
                                   max_sequence_length, min_sequence_length)
    train_loader = DataLoader(train_dataset, batch_size, sampler=Sampler(train_dataset), num_workers=0)

    validation_dataset = EncodedDataset(data_test.text.values, data_test.label.values,
                                        stylo_feat_test[imp_feat].values, tokenizer,
                                        max_sequence_length, min_sequence_length)
    validation_loader = DataLoader(validation_dataset, batch_size=1, sampler=Sampler(validation_dataset))

    return train_loader, validation_loader


def accuracy_sum(logits, labels):
    """Count how many predictions agree with ``labels``.

    When ``logits`` has the shape of ``labels`` plus a trailing axis of size
    2, class 1 is predicted where the second score exceeds the first;
    otherwise class 1 is predicted where the score is positive.

    Returns:
        float: Number of correct predictions (a sum, not a ratio).
    """
    two_column = tuple(logits.shape) == tuple(labels.shape) + (2,)
    if two_column:
        # 2-d outputs: compare the two class scores.
        picked = torch.lt(logits[..., 0], logits[..., 1])
    else:
        picked = torch.gt(logits, 0)
    classification = picked.long().flatten()

    assert classification.shape == labels.shape
    return classification.eq(labels).float().sum().item()


def train(model: nn.Module, optimizer, device: str, loader: DataLoader, desc='Train'):
    """Run one training epoch.

    The model is expected to return class *probabilities* (``FusedClassifier``
    ends in a softmax). The loss is therefore the negative log-likelihood of
    those probabilities; feeding them to ``CrossEntropyLoss`` would apply a
    second softmax on top.

    Args:
        model: Called as ``model(data=[texts, masks, labels], custom_features=...)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        loader: Yields ``(texts, masks, custom_features, labels)`` batches.
        desc: Progress-bar caption.

    Returns:
        dict: Epoch *sums* under ``"train/accuracy"`` (correct predictions),
        ``"train/epoch_size"`` (examples seen) and ``"train/loss"`` (batch loss
        weighted by batch size).
    """
    model.train()

    train_accuracy = 0
    train_epoch_size = 0
    train_loss = 0

    loss_fct = nn.NLLLoss()
    tiny = 1e-12  # keeps log() finite if a probability underflows to zero

    with tqdm(loader, desc=desc, disable=distributed() and dist.get_rank() > 0) as loop:
        for texts, masks, custom_features, labels in loop:

            texts, masks, custom_features, labels = texts.to(device), masks.to(device), custom_features.to(device), labels.to(device)
            batch_size = texts.shape[0]

            optimizer.zero_grad()
            predict_label = model(data=[texts, masks, labels], custom_features=custom_features)

            loss = loss_fct(torch.log(predict_label.clamp_min(tiny)), labels)

            loss.backward()
            optimizer.step()

            batch_accuracy = accuracy_sum(predict_label, labels)
            train_accuracy += batch_accuracy
            train_epoch_size += batch_size
            train_loss += loss.item() * batch_size

            loop.set_postfix(loss=loss.item(), acc=train_accuracy / train_epoch_size)

    return {
        "train/accuracy": train_accuracy,
        "train/epoch_size": train_epoch_size,
        "train/loss": train_loss
    }


def validate(model: nn.Module, device: str, loader: DataLoader, votes=1, desc='Validation'):
    """Evaluate ``model`` on ``loader``, optionally averaging several passes ("votes").

    The loader is consumed ``votes`` times up front; batch ``i`` of every pass
    is then scored and the per-pass outputs and losses are averaged.

    The model is expected to return class *probabilities* (``FusedClassifier``
    ends in a softmax), so the loss is the negative log-likelihood of those
    probabilities rather than ``CrossEntropyLoss``, which would apply a second
    softmax.

    NOTE(review): grouping by position assumes every pass yields the batches in
    the same order. With a shuffling sampler and ``votes > 1`` the grouped
    batches would not hold the same examples -- confirm before raising ``votes``.

    Args:
        model: Called as ``model(data=[texts, masks, labels], custom_features=...)``.
        device: Device the batch tensors are moved to.
        loader: Yields ``(texts, masks, custom_features, labels)`` batches; must support ``len()``.
        votes: Number of passes over ``loader``.
        desc: Progress-bar caption.

    Returns:
        dict: Sums under ``"validation/accuracy"``, ``"validation/epoch_size"``
        and ``"validation/loss"`` (loss weighted by batch size).
    """
    model.eval()

    validation_accuracy = 0
    validation_epoch_size = 0
    validation_loss = 0

    loss_fct = nn.NLLLoss()
    tiny = 1e-12  # keeps log() finite if a probability underflows to zero
    quiet = distributed() and dist.get_rank() > 0

    records = [record for v in range(votes) for record in tqdm(loader, desc=f'Preloading data ... {v}',
                                                               disable=quiet)]
    # Regroup: one entry per batch position, holding that batch from every pass.
    records = [[records[v * len(loader) + i] for v in range(votes)] for i in range(len(loader))]

    with tqdm(records, desc=desc, disable=quiet) as loop, torch.no_grad():
        for example in loop:
            losses = []
            logit_votes = []

            for texts, masks, custom_features, labels in example:

                texts, masks, custom_features, labels = texts.to(device), masks.to(device), custom_features.to(device), labels.to(device)
                batch_size = texts.shape[0]

                predict_label = model(data=[texts, masks, labels], custom_features=custom_features)

                loss = loss_fct(torch.log(predict_label.clamp_min(tiny)), labels)
                losses.append(loss)
                logit_votes.append(predict_label)

            loss = torch.stack(losses).mean(dim=0)
            logits = torch.stack(logit_votes).mean(dim=0)

            # ``labels`` and ``batch_size`` are those of the last pass for this position.
            batch_accuracy = accuracy_sum(logits, labels)
            validation_accuracy += batch_accuracy
            validation_epoch_size += batch_size
            validation_loss += loss.item() * batch_size

            loop.set_postfix(loss=loss.item(), acc=validation_accuracy / validation_epoch_size)

    return {
        "validation/accuracy": validation_accuracy,
        "validation/epoch_size": validation_epoch_size,
        "validation/loss": validation_loss
    }


def _all_reduce_dict(d, device):
    # wrap in tensor and use reduce to gpu0 tensor
    output_d = {}
    for (key, value) in sorted(d.items()):
        tensor_input = torch.tensor([[value]]).to(device)
        # torch.distributed.all_reduce(tensor_input)
        output_d[key] = tensor_input.item()
    return output_d


def run(params):
    """Train the fused classifier and checkpoint the best-validating epoch.

    Builds a RoBERTa backbone wrapped in ``FusedClassifier``, trains it with
    Adam, validates after every epoch, saves the model whenever validation
    accuracy improves, and stops early once ``earlystop_epochs`` consecutive
    epochs bring no improvement.

    Args:
        params: namespace of settings. Attributes read here: ``device``,
            ``large``, ``load_from_checkpoint`` (plus ``checkpoint_dir`` and
            ``checkpoint_name`` when it is true), ``FUSED_INPUT_SIZE``,
            ``text_dir``, ``stylo_dir``, ``dataset_name``, ``imp_feat``,
            ``batch_size``, ``max_sequence_length``, ``learning_rate``,
            ``weight_decay``, ``max_epochs`` and ``out_dir``.
    """
    import torch.distributed as dist

    rank, world_size = setup_distributed()

    # Honour an explicitly configured device; otherwise use one GPU per rank
    # and fall back to CPU. (Previously `device` was undefined whenever
    # params.device was set.)
    if params.device is None:
        device = f'cuda:{rank}' if torch.cuda.is_available() else 'cpu'
    else:
        device = params.device

    print('rank:', rank, 'world_size:', world_size, 'device:', device)

    # Non-zero ranks wait here until rank 0 reaches its own barrier below,
    # i.e. after rank 0 has built the model (presumably so the pretrained
    # weights are fetched only once -- TODO confirm).
    if distributed() and rank > 0:
        dist.barrier()

    model_name = 'roberta-large' if params.large else 'roberta-base'
    # tokenization_utils.logger.setLevel('ERROR')
    tokenizer = RobertaTokenizer.from_pretrained(model_name)
    lm = RobertaForFusion.from_pretrained(model_name).to(device)

    # Load the model from checkpoints
    if params.load_from_checkpoint:
        # NOTE(review): the weights go into `lm`, whereas the checkpoint
        # written further down stores the whole FusedClassifier -- confirm
        # that the checkpoint referenced here is an LM-only one.
        checkpoint_path = (params.checkpoint_dir + '{}.pt').format(params.checkpoint_name)
        # map_location places the tensors on this process's device for both
        # the CPU and the GPU case.
        checkpoint = torch.load(checkpoint_path, map_location=device)
        lm.load_state_dict(checkpoint['model_state_dict'])

    model = FusedClassifier(lm=lm, device=device, FUSED_INPUT_SIZE=params.FUSED_INPUT_SIZE)

    if rank == 0:
        summary(model)
        if distributed():
            dist.barrier()

    if world_size > 1:
        model = DistributedDataParallel(model, [rank], output_device=rank, find_unused_parameters=True)

    train_loader, validation_loader = load_datasets(
        params.text_dir, params.stylo_dir, params.dataset_name, params.imp_feat,
        tokenizer, params.batch_size, params.max_sequence_length)

    optimizer = Adam(model.parameters(), lr=params.learning_rate, weight_decay=params.weight_decay)
    # No epoch limit means "train until early stopping triggers".
    epoch_loop = count(1) if params.max_epochs is None else range(1, params.max_epochs + 1)

    best_validation_accuracy = 0
    without_progress = 0
    earlystop_epochs = 3

    for epoch in epoch_loop:
        if world_size > 1:
            train_loader.sampler.set_epoch(epoch)
            validation_loader.sampler.set_epoch(epoch)

        train_metrics = train(model, optimizer, device, train_loader, f'Epoch {epoch}')
        validation_metrics = validate(model, device, validation_loader)

        combined_metrics = _all_reduce_dict({**validation_metrics, **train_metrics}, device)

        # The metric dicts carry sums; divide by the sample counts to get means.
        combined_metrics["train/accuracy"] /= combined_metrics["train/epoch_size"]
        combined_metrics["train/loss"] /= combined_metrics["train/epoch_size"]
        combined_metrics["validation/accuracy"] /= combined_metrics["validation/epoch_size"]
        combined_metrics["validation/loss"] /= combined_metrics["validation/epoch_size"]

        # Progress is tracked on every rank so all ranks take the same
        # early-stop decision; only rank 0 writes the checkpoint.
        if combined_metrics["validation/accuracy"] > best_validation_accuracy:
            without_progress = 0
            best_validation_accuracy = combined_metrics["validation/accuracy"]

            if rank == 0:
                # Unwrap DistributedDataParallel so the saved keys have no
                # "module." prefix.
                model_to_save = model.module if hasattr(model, 'module') else model
                torch.save(dict(
                        epoch=epoch,
                        model_state_dict=model_to_save.state_dict(),
                        optimizer_state_dict=optimizer.state_dict()
                    ),
                    os.path.join(params.out_dir, params.dataset_name+"_roberta_fusion_jstylo.pt")
                )
        else:
            # Count only epochs that failed to improve, so training stops
            # after exactly `earlystop_epochs` stale epochs.
            without_progress += 1

        if without_progress >= earlystop_epochs:
            break


if __name__ == '__main__':

    # Stylometric feature columns that are fused with the RoBERTa embedding.
    imp_list = ['mean_word_count_sent', 'mean_sent_count_para', 'apos_mean_count', 'apos_mean_count_para','comma_mean_count', 'comma_mean_count_para', 'hash_mean_count', 'hash_mean_count_para', 'wc_lead_sent', 'wc_lead_para', 'num_count', 'passive_sent_count',
       'past_tense_count', 'temp_inconsis']

    # Training configuration; the empty path/name entries must be filled in
    # before running.
    params = {"max_sequence_length": 512,
            "learning_rate" : 2e-5,
            "batch_size" : 8,
            "val_batch_size" : 16,
            "max_epochs" : 10,
            "epoch_size" : None,
            "device" : None,
            "large" : False,
            "seed" : 1024,
            "checkpoint_dir": "",
            "load_from_checkpoint": False,
            "weight_decay": 0,
            "FUSED_INPUT_SIZE": 768 + len(imp_list),
            "hidden_size" : 768,
            "num_features" : 2,
            "num_labels": 2,
            "classifier_dropout" : 0.5,
            "out_dir": "",
            "text_dir": "",
            "stylo_dir":"" ,
            "dataset_name": "",
            "imp_feat":imp_list,
            }
    params = SimpleNamespace(**params)

    # The GPU count is queried in a child interpreter (presumably so CUDA is
    # not initialised in this process before workers are started -- TODO confirm).
    nproc = int(subprocess.check_output([sys.executable, '-c', "import torch;"
                                         "print(torch.cuda.device_count() if torch.cuda.is_available() else 1)"]))

    # Dataset names to train on; used by both the multi-GPU and the
    # single-process path.
    filter_models = ["TT_GPT4"]

    if nproc > 1:
        print(f'Launching {nproc} processes ...', file=sys.stderr)

        os.environ["MASTER_ADDR"] = '127.0.0.1'
        os.environ["MASTER_PORT"] = str(29500)
        os.environ['WORLD_SIZE'] = str(nproc)
        # The variable OpenMP reads is OMP_NUM_THREADS (plural).
        os.environ['OMP_NUM_THREADS'] = str(1)

        for model in filter_models:
            params.dataset_name = model
            subprocesses = []

            for i in range(nproc):
                # Each worker inherits the environment as it is when started.
                os.environ['RANK'] = str(i)
                os.environ['LOCAL_RANK'] = str(i)
                # run() takes the namespace as its single positional argument;
                # passing it as `kwargs` fails because it is not a mapping.
                process = Process(target=run, args=(params,))
                process.start()
                subprocesses.append(process)

            for process in subprocesses:
                process.join()
    else:
        for model in filter_models:
            params.dataset_name = model

            print("--------------------------------------------------------------")
            print("Model running on:")
            print(params)
            run(params)
            print()
            print("--------------------------------------------------------------")

## Evaluation

In [None]:
import math
import torch
import argparse
from tqdm import tqdm
import pandas as pd
import random
import time

from torch.utils.data import DataLoader


from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from matplotlib import pyplot

from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

import decimal


from transformers import *


def float_range(start, stop, step):
    """Yield floats from ``start`` (inclusive) up to ``stop`` (exclusive).

    The running value is accumulated as a ``decimal.Decimal`` and converted
    to ``float`` only when yielded.

    Args:
        start: first value; int, float or Decimal.
        stop: exclusive upper bound.
        step: positive increment.

    Raises:
        ValueError: on first iteration if ``step`` is not positive (the loop
            would otherwise never terminate).
    """
    increment = decimal.Decimal(step)
    if increment <= 0:
        raise ValueError("step must be positive")

    # Converting up front lets a float `start` work too; `float += Decimal`
    # is a TypeError.
    current = decimal.Decimal(start)
    while current < stop:
        yield float(current)
        current += increment


def calculate_program_metrics(far, pd):
    """Print and return operating-point metrics read off a ROC curve.

    Args:
        far: false-alarm rates of the ROC points, in increasing order.
        pd: detection rates for the same points. (Inside this function the
            name shadows the module-level pandas alias.)

    Returns:
        Tuple ``(pd_at_far, pd_at_eer, far_at_eer)``:
        * ``pd_at_far`` -- ``pd`` of the last point whose ``far`` is not
          above 0.1 (0.0 when even the first point is above it).
        * ``pd_at_eer`` / ``far_at_eer`` -- midpoint between the first point
          where ``pd > 1 - far`` and the point before it (0.0 when no such
          point exists).
    """
    far_threshold = 0.1

    pd_at_far = 0.0
    pd_at_eer = 0.0
    far_at_eer = 0.0

    for i in range(len(far)):
        if far[i] > far_threshold:
            # With i == 0 there is no point at or below the threshold, and
            # pd[i - 1] would silently wrap around to the last ROC point.
            if i > 0:
                pd_at_far = pd[i - 1]
            break
    else:
        # The curve never exceeds the threshold: its last point qualifies.
        if len(pd) > 0:
            pd_at_far = pd[-1]

    for i in range(len(far)):
        if pd[i] > 1 - far[i]:
            # Clamp so a crossing at the very first point averages that point
            # with itself instead of wrapping to the end of the curve.
            previous = max(i - 1, 0)
            pd_at_eer = (pd[previous] + pd[i]) / 2
            far_at_eer = (far[previous] + far[i]) / 2
            break

    print("pD @ 0.1 FAR = %.3f" % (pd_at_far))
    print("pD @ EER = %.3f" % (pd_at_eer))
    print("FAR @ EER = %.3f" % (far_at_eer))

    return pd_at_far, pd_at_eer, far_at_eer



class GeneratedTextDetection:
    """
    Inference wrapper around a trained fused detector (RoBERTa backbone plus
    stylometric features).

    The constructor loads the tokenizer, the backbone and the classifier
    checkpoint named by ``args``; ``evaluate`` scores a list of texts.
    """

    def __init__(self, args):
        """Seed torch and load the detector described by ``args``.

        Args:
            args: namespace with at least ``init_method``, ``large``,
                ``device``, ``FUSED_INPUT_SIZE``, ``check_point``,
                ``known_model_name`` and ``max_sequence_length``.

        Raises:
            ValueError: if ``args.init_method`` is not supported.
        """
        torch.manual_seed(1000)

        self.args = args

        # Load the model from checkpoints
        self.init_dict = self._init_detector()

    def _init_detector(self):
        """Build tokenizer and model and restore the checkpoint weights.

        Returns:
            dict with the loaded objects under ``"kn_model"`` and
            ``"kn_tokenizer"``; the ``unk_*`` and ``attr_*`` slots stay None.

        Raises:
            ValueError: if ``args.init_method`` is anything but ``"fused"``.
        """
        init_dict = {"kn_model": None, "kn_tokenizer": None,
                    "unk_model": None, "unk_tokenizer": None,
                   "attr_model": None, "attr_tokenizer": None, }

        # Fail loudly here; previously an unknown method made this return
        # None and surfaced later as an unrelated TypeError in evaluate().
        if self.args.init_method != "fused":
            raise ValueError(
                "Unsupported init_method {!r}; only 'fused' is implemented".format(self.args.init_method))

        model_name = 'roberta-large' if self.args.large else 'roberta-base'
        tokenization_utils.logger.setLevel('ERROR')
        tokenizer = RobertaTokenizer.from_pretrained(model_name)
        lm = RobertaForFusion.from_pretrained(model_name).to(self.args.device)

        print("self.args.fused_input_size: ", self.args.FUSED_INPUT_SIZE)
        # Use the configured size rather than a literal 782 so a different
        # feature list or backbone does not need a code change.
        model = FusedClassifier(lm=lm, device=self.args.device,
                                FUSED_INPUT_SIZE=self.args.FUSED_INPUT_SIZE)

        # Plain concatenation: str.format on the whole path would choke on
        # braces inside the directory name.
        checkpoint_path = self.args.check_point + self.args.known_model_name + '.pt'
        if self.args.device != "cpu":
            print("The model loaded will be: ")
            print(checkpoint_path)

        # map_location covers the CPU and the GPU case with one code path.
        checkpoint = torch.load(checkpoint_path, map_location=self.args.device)
        model.load_state_dict(checkpoint['model_state_dict'])

        init_dict["kn_model"] = model
        init_dict["kn_tokenizer"] = tokenizer
        return init_dict

    def evaluate(self, input_text, stylo_feat):
        """
           Run the detector over the inputs and collect scores and predictions.

           Args:
               input_text: the texts to score.
               stylo_feat: stylometric features for the same texts
                   (presumably row-aligned with ``input_text`` -- the pairing
                   is done by EncodeEvalData; verify there).

           Returns:
               dict with
               * ``"cls"`` -- predicted class index per example (argmax),
               * ``"prob_score"`` -- per-example model outputs for column 0
                 (``"cls_0"``) and column 1 (``"cls_1"``),
               * ``"LLR_score"`` -- left empty,
               * ``"generator"`` -- left as None.
        """

        # Encapsulate the inputs
        eval_dataset = EncodeEvalData(input_text, stylo_feat, self.init_dict["kn_tokenizer"], self.args.max_sequence_length)
        eval_loader = DataLoader(eval_dataset)

        # Dictionary will contain all the scores and evidences generated by the model
        results = {"cls": [], "LLR_score": [], "prob_score": {"cls_0": [], "cls_1": []}, "generator": None}

        model = self.init_dict["kn_model"]
        device = self.args.device

        # Set eval mode
        model.eval()

        with torch.no_grad():
            for texts, masks, custom_features in eval_loader:
                texts = texts.to(device)
                masks = masks.to(device)
                custom_features = custom_features.to(device)

                disc_out = model(data=[texts, masks], custom_features=custom_features)

                # NOTE(review): the columns are stored as "prob_score" exactly
                # as the model emits them; whether they are probabilities or
                # raw logits depends on FusedClassifier -- confirm.
                results["prob_score"]["cls_0"].extend(disc_out[:, 0].tolist())
                results["prob_score"]["cls_1"].extend(disc_out[:, 1].tolist())

                _, predicted = torch.max(disc_out, 1)
                results["cls"].extend(predicted.tolist())

        return results



def main():
    """Evaluate each saved fused detector and report its metrics.

    For every name in ``filter_models`` this loads the matching checkpoint,
    scores the test set, prints confusion counts, accuracy, recall,
    precision, F1, ROC AUC and the ROC operating-point metrics, and shows the
    ROC curve next to the chance diagonal.

    Ratios whose denominator is zero (e.g. no positive labels, or no positive
    predictions) are reported as 0 instead of raising ZeroDivisionError.
    """

    # Stylometric feature columns fed to the classifier next to the text.
    imp_list = ['mean_word_count_sent', 'mean_sent_count_para', 'apos_mean_count', 'apos_mean_count_para','comma_mean_count', 'comma_mean_count_para', 'hash_mean_count', 'hash_mean_count_para', 'wc_lead_sent', 'wc_lead_para', 'num_count', 'passive_sent_count',
       'past_tense_count', 'temp_inconsis']

    args = {"max_sequence_length": 512,
            "learning_rate" : 2e-5,
            "batch_size" : 8,
            "val_batch_size" : 16,
            "max_epochs" : 10,
            "epoch_size" : None,
            "device" : None,
            "large" : False,
            "seed" : 1024,
            "check_point": "",
            "load_from_checkpoint": False,
            "weight_decay": 0,
            "FUSED_INPUT_SIZE": 768 + len(imp_list),
            "hidden_size" : 768,
            "num_features" : 2,
            "num_labels": 2,
            "classifier_dropout" : 0.5,
            "out_dir": "",
            "text_dir": "",
            "stylo_dir":"" ,
            "dataset_name": "",
            "imp_feat":imp_list,
            "known_model_name":"",
            "init_method":"fused"
            }
    args = SimpleNamespace(**args)

    if args.device is None:
        args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

    filter_models = ["TT_GPT4"] # Enter saved model checkpoint file name

    for model in filter_models:
        args.dataset_name = model

        # Same file stem the training cell uses when saving its checkpoint.
        args.known_model_name = args.dataset_name+"_roberta_fusion_jstylo"

        print("--------------------------------------------------------------")
        print("Model running on: ", args.dataset_name)

        artifact = GeneratedTextDetection(args)

        test_data = pd.read_csv("") # Read the testing data
        stylo_feat = pd.read_csv("") # Read the testing stylometric features data

        # NOTE(review): assumes the rows of both CSVs describe the same
        # documents in the same order -- confirm against the data export.
        results = artifact.evaluate(test_data.text.values.tolist(), stylo_feat[imp_list].values)

        y = test_data.label.tolist()
        predictions = results["cls"]
        predict_prob = results["prob_score"]['cls_1']

        # Pair by position; indexing the result lists with DataFrame index
        # labels only works for a default RangeIndex.
        if not (len(y) == len(predictions) == len(predict_prob)):
            raise ValueError(
                "Got %d predictions and %d scores for %d test rows"
                % (len(predictions), len(predict_prob), len(y)))

        tp = 0
        tn = 0
        fn = 0
        fp = 0

        for label, predicted in zip(y, predictions):
            correct = predicted == label
            if label == 1:
                if correct:
                    tp += 1
                else:
                    fn += 1
            elif label == 0:
                if correct:
                    tn += 1
                else:
                    fp += 1

        total = tp + tn + fp + fn
        accuracy = float(tp + tn) / total if total else 0.0
        recall = float(tp) / (tp + fn) if (tp + fn) else 0.0
        precision = float(tp) / (tp + fp) if (tp + fp) else 0.0
        f1_score = (2 * precision * recall / (precision + recall)
                    if (precision + recall) else 0.0)

        print('TP: %d' % (
            tp))
        print('TN: %d' % (
            tn))
        print('FP: %d' % (
            fp))
        print('FN: %d' % (
            fn))

        print('Accuracy of the discriminator: %d %%' % (
                100 * accuracy))
        print('Recall of the discriminator: %d %%' % (
            100 * recall))
        print('Precision of the discriminator: %d %%' % (
            100 * precision))
        print('f1_score of the discriminator: %d %%' % (
            100 * f1_score))

        # calculate scores
        lr_auc = roc_auc_score(y, predict_prob)

        # summarize scores
        print("\n")
        print(" ----- Extra Metrics -----")
        print()
        print('Classifier: ROC AUC=%.3f' % (lr_auc))

        # calculate roc curves
        lr_fpr, lr_tpr, _ = roc_curve(y, predict_prob)

        calculate_program_metrics(lr_fpr, lr_tpr)

        # Chance diagonal sampled with as many points as the ROC curve.
        eq_fpr = list(float_range(0, 1, 1 / len(lr_fpr)))
        eq_tpr = list(eq_fpr)

        # plot the roc curve for the model
        pyplot.plot(lr_fpr, lr_tpr, marker='.', label='Roberta')
        pyplot.plot(eq_fpr, eq_tpr, marker='.', label='Random Chance')
        # axis labels
        pyplot.xlabel('Probability of False Alarm')
        pyplot.ylabel('Probability of Detection')
        # show the legend
        pyplot.legend()
        # show the plot
        pyplot.show()

        print()
        print("--------------------------------------------------------------")


if __name__ == "__main__":
    main()
