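# RoBERTa BCE

Fine-tuning de `roberta-base` pour la classification multi-label de commentaires toxiques (Jigsaw 2019, 7 labels) avec une perte BCE (`BCEWithLogitsLoss`).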

## Variables d'environnement

In [17]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
# GPU ids available on the machine: 0 and 1. Only GPU 1 is exposed to this
# process, where it is then seen as device 0. This must run before CUDA is
# initialised, which is why it lives in its own cell before the imports.

## Importation

In [18]:
import os
import time
import datetime
from typing import Any, Union, Dict, List
import uuid
import json

import pandas as pd
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torchtext
import nltk
import sklearn
import transformers
import torchmetrics as tm
from torchmetrics import MetricCollection, Metric, Accuracy, Precision, Recall, AUROC, HammingDistance, F1, ROC, AUC, PrecisionRecallCurve


from loguru import logger
from tqdm.auto import tqdm
tqdm.pandas()  # registers tqdm's `progress_apply` on pandas objects

# import warnings
# warnings.filterwarnings("ignore")

## Constantes

In [25]:
# Prefix of the session name (identifier spelled CUSTOME_NAME; used below in SESSION_NAME)
CUSTOME_NAME = "roberta-bce"

# Dataset
DATA_DIR_PATH = os.path.abspath("../../data")
TRAIN_DATASET_PATH = os.path.join(DATA_DIR_PATH, "jigsaw2019-train.csv")
TEST_DATASET_PATH = os.path.join(DATA_DIR_PATH, "jigsaw2019-test.csv")
# Target columns of the multi-label classification
LABEL_LIST = ['toxicity', 'severe_toxicity', 'obscene', 'sexual_explicit',
            'identity_attack', 'insult', 'threat']
# Identity columns of the Jigsaw dataset
IDENTITY_LIST = ['male', 'female', 'transgender', 'other_gender', 'heterosexual',
                'homosexual_gay_or_lesbian', 'bisexual','other_sexual_orientation',
                'christian', 'jewish', 'muslim', 'hindu','buddhist', 'atheist',
                'other_religion', 'black', 'white', 'asian', 'latino',
                'other_race_or_ethnicity', 'physical_disability',
                'intellectual_or_learning_disability',
                'psychiatric_or_mental_illness','other_disability']
# Subset of IDENTITY_LIST
SELECTED_IDENTITY_LIST = ['male', 'female', 'black', 'white', 'homosexual_gay_or_lesbian',
                    'christian', 'jewish', 'muslim', 'psychiatric_or_mental_illness']

# Session: every run writes into its own timestamped directory
SESSION_DIR_PATH = os.path.abspath("../../session")
SESSION_DATETIME = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S-%f")
SESSION_NAME = f"{CUSTOME_NAME}_{SESSION_DATETIME}"
CURRENT_SESSION_DIR_PATH = os.path.join(SESSION_DIR_PATH, SESSION_NAME)
# Create the session directory
os.makedirs(CURRENT_SESSION_DIR_PATH, exist_ok=True)

# File layout inside `CURRENT_SESSION_DIR_PATH`
LOG_FILE_NAME = f"{SESSION_NAME}.loguru.log"
MODEL_FILE_NAME = f"{SESSION_NAME}.model"
TEST_FILE_NAME = f"{SESSION_NAME}.test.csv"
METRIC_FILE_NAME = f"{SESSION_NAME}.metric.json"
LOG_FILE_PATH = os.path.join(CURRENT_SESSION_DIR_PATH, LOG_FILE_NAME)
MODEL_FILE_PATH = os.path.join(CURRENT_SESSION_DIR_PATH, MODEL_FILE_NAME)
TEST_FILE_PATH = os.path.join(CURRENT_SESSION_DIR_PATH, TEST_FILE_NAME)
METRIC_FILE_PATH = os.path.join(CURRENT_SESSION_DIR_PATH, METRIC_FILE_NAME)

# CUDA: first visible GPU if any, CPU otherwise
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Logging

In [23]:
# Register the session log file as a loguru sink (all levels from TRACE up).
# The sink id is kept so that re-executing this cell replaces the previous
# sink instead of adding a second one, which would write every line twice.
if "LOG_SINK_ID" in globals():
    try:
        logger.remove(LOG_SINK_ID)
    except ValueError:
        # Deliberate best-effort: the sink was already removed elsewhere.
        pass
LOG_SINK_ID = logger.add(LOG_FILE_PATH, level="TRACE")
logger.info(f"{SESSION_NAME=}")
logger.info(f"{TRAIN_DATASET_PATH=}")
logger.info(f"{TEST_DATASET_PATH=}")
logger.info(f"{CURRENT_SESSION_DIR_PATH=}")

2022-03-24 18:36:51.105 | INFO     | __main__:<cell line: 2>:2 - SESSION_NAME='roberta-bce_2022-03-24T18-36-33-957856'
2022-03-24 18:36:51.108 | INFO     | __main__:<cell line: 3>:3 - TRAIN_DATASET_PATH='/work2/home/ing1/corentin/hatespeech-detection-models/data/jigsaw2019-train.csv'
2022-03-24 18:36:51.110 | INFO     | __main__:<cell line: 4>:4 - TEST_DATASET_PATH='/work2/home/ing1/corentin/hatespeech-detection-models/data/jigsaw2019-test.csv'
2022-03-24 18:36:51.112 | INFO     | __main__:<cell line: 5>:5 - CURRENT_SESSION_DIR_PATH='/work2/home/ing1/corentin/hatespeech-detection-models/session/roberta-bce_2022-03-24T18-36-33-957856'


## Vérifier la cohérence de l'architecture et l'accès aux ressources

In [21]:
logger.info(f"Checking consistency...")

# Check that both datasets are reachable; stop early otherwise
if not os.path.exists(TRAIN_DATASET_PATH):
    logger.critical(f"Train dataset does not exist !")
    raise RuntimeError("Train dataset does not exist !")
if not os.path.exists(TEST_DATASET_PATH):
    logger.critical(f"Test dataset does not exist !")
    raise RuntimeError("Test dataset does not exist !")
logger.success("Datasets are reachable")

# Check GPU access; this notebook refuses to run on CPU only
GPU_IS_AVAILABLE = torch.cuda.is_available()
GPU_COUNT = torch.cuda.device_count()
logger.info(f"{GPU_IS_AVAILABLE=}")
logger.info(f"{GPU_COUNT=}")
if not GPU_IS_AVAILABLE:
    logger.critical("GPU and CUDA are not available !")
    raise RuntimeError("GPU and CUDA are not available !")
logger.success("GPU and CUDA are available")
logger.info(f"{device=}")
for gpu_id in range(GPU_COUNT):
    # Query each visible device (device 0 was previously queried every time)
    gpu_name = torch.cuda.get_device_name(gpu_id)
    logger.info(f"GPU {gpu_id} : {gpu_name}")

2022-03-24 18:34:14.978 | INFO     | __main__:<cell line: 1>:1 - Checking consistency...
2022-03-24 18:34:14.983 | SUCCESS  | __main__:<cell line: 10>:10 - Datasets are reachable
2022-03-24 18:34:14.986 | INFO     | __main__:<cell line: 15>:15 - GPU_IS_AVAILABLE=True
2022-03-24 18:34:14.988 | INFO     | __main__:<cell line: 16>:16 - GPU_COUNT=1
2022-03-24 18:34:14.990 | SUCCESS  | __main__:<cell line: 20>:20 - GPU and CUDA are available
2022-03-24 18:34:14.992 | INFO     | __main__:<cell line: 21>:21 - device=device(type='cuda')
2022-03-24 18:34:14.994 | INFO     | __main__:<cell line: 22>:24 - GPU 0 : NVIDIA TITAN X (Pascal)


## Dataset

In [47]:
# Load the training set; the first CSV column becomes the DataFrame index.
train_df = pd.read_csv(TRAIN_DATASET_PATH, index_col=0)

In [48]:
# Remplacer toutes les colonnes correspondantes aux labels par 1 ou 0
# si la probabilité est supérieure ou égale à 0.5 ou non
# Binarise every label column: 1 when the annotated probability is >= 0.5,
# 0 otherwise (re-running this cell leaves 0/1 values unchanged).
train_df[LABEL_LIST] = train_df[LABEL_LIST].ge(0.5).astype(int)

In [50]:
class JigsawDataset(Dataset):
    """Map-style dataset over a Jigsaw DataFrame for multi-label classification.

    Each item is a dict with:
      - ``index``: positional row number of the item in ``data_df`` (not its
        index label),
      - ``ids`` / ``mask``: token ids and attention mask of ``comment_text``,
        padded/truncated to ``max_length`` tokens (1-D tensors),
      - ``labels``: float tensor with one value per column of ``LABEL_LIST``.

    Args:
        data_df: DataFrame with a ``comment_text`` column and the ``LABEL_LIST`` columns.
        tokenizer: tokenizer exposing ``encode_plus`` (Hugging Face API).
        max_length: number of tokens per sequence (default 128).
    """

    def __init__(self, data_df, tokenizer, max_length=128):
        self.data = data_df
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Fetch the row once instead of calling .iloc twice
        row = self.data.iloc[index]
        comment = row["comment_text"]
        label = torch.tensor(row[LABEL_LIST].tolist(), dtype=torch.float)

        token_list, attention_mask = self.text_to_token_and_mask(comment)

        return dict(index=index, ids=token_list, mask=attention_mask, labels=label)

    def text_to_token_and_mask(self, input_text):
        """Tokenize ``input_text`` and return ``(token_ids, attention_mask)`` as 1-D tensors."""
        # Use the tokenizer passed to the constructor. The previous version read
        # the notebook-global `tokenizer` and ignored the constructor argument.
        tokenization_dict = self.tokenizer.encode_plus(input_text,
                                add_special_tokens=True,
                                max_length=self.max_length,
                                padding='max_length',
                                truncation=True,
                                return_attention_mask=True,
                                return_tensors='pt')
        # return_tensors='pt' yields [1, max_length]; flatten to [max_length]
        token_list = tokenization_dict["input_ids"].flatten()
        attention_mask = tokenization_dict["attention_mask"].flatten()
        return (token_list, attention_mask)

## Model

In [None]:
def set_lr(optim, lr):
    '''
    Overwrite the learning rate of every parameter group of `optim`.

    The optimizer is modified in place and also returned, so the call can be chained.
    '''
    for param_group in optim.param_groups:
        param_group.update(lr=lr)
    return optim

In [None]:
# Transformer class and functions for models and predictions

class TransformerClassifierStack(nn.Module):
    """Multi-label classification head on top of a Hugging Face encoder.

    The first-token vectors of the last `nb_stacked_layers` hidden layers of
    `tr_model` are concatenated and fed to a two-layer MLP
    (dropout -> linear -> tanh -> dropout -> linear) that outputs one logit per label.

    Args:
        tr_model: encoder exposing `config.hidden_size` and returning
            `hidden_states` when called with `output_hidden_states=True`.
        nb_labels: number of output logits.
        dropout_prob: dropout probability used before both linear layers.
        freeze: if True, the encoder parameters are not trained.
        nb_stacked_layers: number of last hidden layers to concatenate (default 4).
    """

    def __init__(self, tr_model, nb_labels, dropout_prob=0.4, freeze=False, nb_stacked_layers=4):
        super().__init__()
        self.tr_model = tr_model
        self.nb_stacked_layers = nb_stacked_layers

        # Features of the last `nb_stacked_layers` encoder layers, stacked
        self.hidden_dim = tr_model.config.hidden_size * nb_stacked_layers

        # Hidden linear layer of the classifier
        self.dropout = nn.Dropout(dropout_prob)
        self.hl = nn.Linear(self.hidden_dim, self.hidden_dim)

        # Output layer: one logit per label
        self.last_l = nn.Linear(self.hidden_dim, nb_labels)

        # Freeze the encoder parameters if requested
        for param in self.tr_model.parameters():
            param.requires_grad = not freeze

        # Xavier init of the classifier weights
        torch.nn.init.xavier_uniform_(self.hl.weight)
        torch.nn.init.xavier_uniform_(self.last_l.weight)

    def forward(self, ids, mask):
        """
        Args:
            ids: token ids, [batch_size, padded_seq_len]
            mask: attention mask, [batch_size, padded_seq_len]; keeps
                self-attention away from padded positions
        Returns:
            raw logits (no sigmoid applied), [batch_size, nb_labels]
        """
        tr_output = self.tr_model(input_ids=ids,
                                  attention_mask=mask,
                                  output_hidden_states=True)

        # One tensor per layer, each [batch_size, padded_seq_len, hidden_size]
        hidden_states = tr_output['hidden_states']

        # First-token vector of each of the last layers, most recent layer first
        # features_vec = [batch_size, hidden_size * nb_stacked_layers]
        features_vec = torch.cat(
            [hidden_states[-depth][:, 0, :] for depth in range(1, self.nb_stacked_layers + 1)],
            dim=-1,
        )

        x = self.dropout(features_vec)
        x = self.hl(x)

        # x = [batch_size, hidden_dim]
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.last_l(x)

        # x = [batch_size, nb_labels]
        return x

def load_roberta_model(nb_labels):
    '''
    Build a fresh RoBERTa classifier for fine-tuning (no hate-speech checkpoint).

    The pretrained `roberta-base` encoder is wrapped in a
    TransformerClassifierStack with `nb_labels` outputs.

    Returns:
        (model, tokenizer)
    '''
    checkpoint_name = 'roberta-base'
    logger.info(f"transformers.RobertaTokenizer : roberta-base")
    logger.info(f"transformers.AutoModel : roberta-base")
    roberta_tokenizer = transformers.RobertaTokenizer.from_pretrained(checkpoint_name)
    encoder = transformers.AutoModel.from_pretrained(checkpoint_name)
    return TransformerClassifierStack(encoder, nb_labels), roberta_tokenizer


def load_roberta_pretrained(path, nb_labels, lr=2e-5):
    '''
    Load RoBERTa from a checkpoint (already trained on hate-speech tasks).

    The checkpoint must be a dict with a 'state_dict' entry (classifier
    weights) and an 'optimizer_dict' entry (AdamW state). The learning rate
    stored in the optimizer state is replaced by `lr`.

    Returns:
        (model, tokenizer, optimizer); the model is on the CPU.
    '''
    tokenizer = transformers.RobertaTokenizer.from_pretrained('roberta-base')
    tr_model = transformers.AutoModel.from_pretrained('roberta-base')
    model = TransformerClassifierStack(tr_model, nb_labels)

    # map_location='cpu' lets a checkpoint saved on a GPU be loaded on any
    # machine (including CPU-only ones). The model built above is on the CPU,
    # and Optimizer.load_state_dict casts its state to the parameters' device.
    # SECURITY: torch.load unpickles the file; only load trusted checkpoints.
    loaded = torch.load(path, map_location='cpu')
    model.load_state_dict(loaded['state_dict'])

    optimizer = transformers.AdamW(model.parameters(), lr=lr)
    optimizer.load_state_dict(loaded['optimizer_dict'])
    # The saved state carries its own lr; override it with the requested one
    optimizer = set_lr(optimizer, lr)

    return model, tokenizer, optimizer

def preds_fn(batch, model, device):
    '''
    Run `model` on one batch and return its raw output.

    The 'ids' and 'mask' tensors of `batch` are moved to `device` and passed
    positionally to `model`.
    '''
    ids, mask = (batch[key].to(device) for key in ('ids', 'mask'))
    return model(ids, mask)

In [None]:
# Load the model
# Load the tokenizer and the classifier (pretrained roberta-base encoder +
# freshly initialised head), with one output per label of LABEL_LIST
model, tokenizer = load_roberta_model(nb_labels=len(LABEL_LIST))

## Hyperparamètres

In [None]:
BATCH_SIZE = 128
LR=1e-4
PIN_MEMORY = True
# 0: batches are loaded in the main process
NUM_WORKERS = 0
# Batches prefetched per worker; only used by DataLoader worker processes
# (NUM_WORKERS > 0)
PREFETCH_FACTOR = 2
NUM_EPOCHS = 1
# Record the hyperparameters of this session in the log file
logger.info(f"{BATCH_SIZE=}")
logger.info(f"{LR=}")
logger.info(f"{PIN_MEMORY=}")
logger.info(f"{NUM_WORKERS=}")
logger.info(f"{PREFETCH_FACTOR=}")
logger.info(f"{NUM_EPOCHS=}")

In [None]:
# DataLoader options shared by the train and validation loaders.
# `prefetch_factor` only applies to worker processes: recent PyTorch versions
# raise a ValueError when it is given while num_workers == 0.
dataloader_kwargs = dict(batch_size=BATCH_SIZE,
                         num_workers=NUM_WORKERS,
                         pin_memory=PIN_MEMORY)
if NUM_WORKERS > 0:
    dataloader_kwargs["prefetch_factor"] = PREFETCH_FACTOR

train_dataset = JigsawDataset(train_df, tokenizer)
train_dataloader = DataLoader(train_dataset, shuffle=True, **dataloader_kwargs)

# Pseudo-validation: this is a subset of the training set.
# It only tells whether training went well; do NOT use it to tune
# hyperparameters. The size is capped at the training-set size so that
# sample() cannot fail, and the seed makes the subset reproducible.
VALIDATION_SIZE = 50_000
VALIDATION_SAMPLE_SEED = 0
validation_df = train_df.sample(n=min(VALIDATION_SIZE, len(train_df)),
                                random_state=VALIDATION_SAMPLE_SEED)
validation_dataset = JigsawDataset(validation_df, tokenizer)
# No shuffling: validation metrics are accumulated over the whole set anyway
validation_dataloader = DataLoader(validation_dataset, shuffle=False, **dataloader_kwargs)

# The model outputs logits, hence BCEWithLogitsLoss
criterion = torch.nn.BCEWithLogitsLoss()
logger.info(criterion)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR)
logger.info(optimizer)

model.to(device)
criterion.to(device)

## Metric

### Variantes de Hamming Loss

In [None]:
class HammingLossWithoutThreshold(Metric):
    """Mean absolute error between targets and (unthresholded) predictions.

    Instead of binarising predictions like the usual Hamming loss, the absolute
    difference |target - pred| is averaged over every (sample, class) cell seen
    since the last reset. `preds` should therefore be probabilities.
    Expects `preds` and `target` of shape [n_samples, num_classes].
    """

    def __init__(self, num_classes=1, dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.num_classes = num_classes

        # Running sum of |target - preds|. It must be a floating tensor: with an
        # integer default (torch.tensor(0)), `self.total += <float tensor>` raises
        # because in-place ops cannot cast a float result back to an integer tensor.
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        # Running number of samples (rows)
        self.add_state("nbr_sample", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate the absolute errors of one batch.

        Raises:
            AttributeError: if the number of columns of `preds` is not `num_classes`.
        """
        current_nbr_sample, current_nbr_category = preds.shape
        if current_nbr_category != self.num_classes:
            raise AttributeError("`num_classes` != `current_nbr_category` detected in `pred` parameter")

        current_loss_per_pred = torch.absolute(target - preds)
        self.total += current_loss_per_pred.sum()
        self.nbr_sample += current_nbr_sample

    def compute(self):
        """Mean absolute error over all accumulated cells."""
        return self.total / (self.num_classes * self.nbr_sample)

In [None]:
class RebalancedHammingLossWithoutThreshold(Metric):
    """Class-rebalanced variant of the Hamming loss without threshold.

    For each class, the mean absolute error |target - pred| is computed
    separately over samples whose target is 1 (positives) and over samples
    whose target is 0 (negatives). The class score is the mean of these two
    values, so both groups weigh the same whatever the class imbalance.
    If a class has only positives or only negatives, its score is the mean
    error of the group that exists. This avoids the previous 0 * inf = NaN.

    average="macro" returns the mean of the class scores (a scalar);
    average=None returns one score per class.
    Expects `preds` and `target` of shape [n_samples, num_classes].
    """

    def __init__(self, num_classes=1, average="macro", dist_sync_on_step=False):
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.num_classes = num_classes

        # average = "macro" or None
        self.average = average

        # Number of positive (1) and negative (0) targets per class
        self.add_state(
            "number_positive",
            default=torch.zeros(num_classes, dtype=torch.long),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "number_negative",
            default=torch.zeros(num_classes, dtype=torch.long),
            dist_reduce_fx="sum",
        )

        # Sum of |target - preds| over positive / negative targets, per class
        self.add_state(
            "hamming_loss_positive",
            default=torch.zeros(num_classes),
            dist_reduce_fx="sum",
        )
        self.add_state(
            "hamming_loss_negative",
            default=torch.zeros(num_classes),
            dist_reduce_fx="sum",
        )

        self.add_state("nbr_sample", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        """Accumulate counts and absolute errors of one batch.

        Raises:
            AttributeError: if the number of columns of `preds` is not `num_classes`.
        """
        current_nbr_sample, current_nbr_category = preds.shape
        if current_nbr_category != self.num_classes:
            raise AttributeError(
                "`num_classes` != `current_nbr_category` detected in `pred` parameter"
            )

        # Number of positives (1) and negatives (0) per class
        current_number_positive = target.sum(axis=0)
        self.number_positive += current_number_positive.int()
        self.number_negative += (current_nbr_sample - current_number_positive).int()

        self.nbr_sample += current_nbr_sample

        # Vectorised over classes (replaces the per-class Python loop).
        # torch.where rather than multiplying by a mask, so a non-finite error
        # in the other group cannot leak in as NaN.
        absolute_error = torch.absolute(target - preds)
        no_error = torch.zeros_like(absolute_error)
        self.hamming_loss_positive += torch.where(target == 1, absolute_error, no_error).sum(dim=0)
        self.hamming_loss_negative += torch.where(target == 0, absolute_error, no_error).sum(dim=0)

    def compute(self):
        """Rebalanced score per class, or their mean when average == "macro"."""
        has_positive = self.number_positive > 0
        has_negative = self.number_negative > 0

        # clamp(min=1) only prevents 0/0; empty groups are zeroed by torch.where
        mean_error_positive = torch.where(
            has_positive,
            self.hamming_loss_positive / self.number_positive.clamp(min=1),
            torch.zeros_like(self.hamming_loss_positive),
        )
        mean_error_negative = torch.where(
            has_negative,
            self.hamming_loss_negative / self.number_negative.clamp(min=1),
            torch.zeros_like(self.hamming_loss_negative),
        )

        # Average over the groups actually observed: 2 in the normal case, which
        # gives (positive mean + negative mean) / 2; 1 when one group is empty;
        # 0 (-> NaN) when no sample was seen at all.
        nb_groups = has_positive.int() + has_negative.int()
        rebalanced_hamming_loss_per_class = (mean_error_positive + mean_error_negative) / nb_groups

        if self.average == "macro":
            return rebalanced_hamming_loss_per_class.mean()
        return rebalanced_hamming_loss_per_class


### Instanciation des métriques

In [None]:
num_classes = len(LABEL_LIST)
# Metric name -> metric object, grouped below in a MetricCollection
# NOTE(review): `compute_on_step` is deprecated in newer torchmetrics releases;
# confirm against the installed version.
train_metric_dict = dict()

# AUROC averaged over classes (macro)
auroc_macro = AUROC(num_classes=num_classes, compute_on_step=True, average="macro")
train_metric_dict["auroc_macro"] = auroc_macro

# AUROC per class
auroc_per_class = AUROC(num_classes=num_classes, compute_on_step=True, average=None)
train_metric_dict["auroc_per_class"] = auroc_per_class

# Hamming Distance without Threshold
hamming_loss_woutt = HammingLossWithoutThreshold(num_classes=num_classes)
train_metric_dict["hamming_loss_without_threshold"] = hamming_loss_woutt

# Rebalanced Hamming Distance without Threshold, macro average
rebalanced_hamming_loss_woutt_macro = RebalancedHammingLossWithoutThreshold(
    num_classes=num_classes, average="macro"
)
train_metric_dict[
    "rebalanced_hamming_loss_without_threshold_macro"
] = rebalanced_hamming_loss_woutt_macro

# Rebalanced Hamming Distance without Threshold, per class
rebalanced_hamming_loss_woutt_per_class = RebalancedHammingLossWithoutThreshold(
    num_classes=num_classes, average=None
)
train_metric_dict[
    "rebalanced_hamming_loss_without_threshold_per_class"
] = rebalanced_hamming_loss_woutt_per_class

In [None]:
# Group the metrics so they are updated and computed together
train_metric = MetricCollection(train_metric_dict)
train_metric.to(device)

# Copy of the collection, used for validation
validation_metric = train_metric.clone()
validation_metric.to(device)

### Export metrics

In [None]:
def serialize(object_to_serialize: Any, ensure_ascii: bool = True) -> str:
    """
    Serialize any object, i.e. convert an object to a JSON string.

    Values that `json` cannot encode natively are converted as follows:
    tensors and numpy arrays become (nested) lists, numpy scalars become
    Python scalars, dates/times become ISO-8601 strings, objects with a
    `__dict__` are serialized through their attributes, anything else via `str()`.

    Args:
        object_to_serialize (Any): The object to serialize
        ensure_ascii (bool, optional): If True (the default), all non-ASCII
            characters are escaped in the output; if False, they are written as-is.
    Returns:
        str: string of the serialized object (JSON)
    """

    def dumper(obj: Any) -> Union[str, Dict, List, int, float, bool]:
        """
        Fallback called by json.dumps for objects it cannot encode natively.
        Its return value is then encoded by json.dumps (recursively if needed).
        Args:
            obj (Any): The object to serialize
        Returns:
            A JSON-encodable replacement for `obj`
        """
        if isinstance(obj, torch.Tensor):
            # detach(): tensors that require grad cannot be converted otherwise
            return obj.detach().cpu().tolist()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # numpy scalar (e.g. np.int64) -> plain Python scalar
            return obj.item()
        if isinstance(obj, (datetime.date, datetime.time)):
            # also covers datetime.datetime (subclass of date)
            return obj.isoformat()
        if hasattr(obj, "__dict__"):
            return obj.__dict__
        return str(obj)

    return json.dumps(object_to_serialize, default=dumper, ensure_ascii=ensure_ascii)

In [None]:
def export_metric(metric_collection, **kwargs):
    """
    Append the current values of a MetricCollection to METRIC_FILE_PATH as one JSON line.

    Args:
        metric_collection: MetricCollection whose `compute()` result is exported
        **kwargs: extra fields added to the JSON line (they override metric
            entries with the same name)
    """
    with open(METRIC_FILE_PATH, "a") as metric_file:
        json_line = serialize({**metric_collection.compute(), **kwargs})
        metric_file.write(json_line + "\n")
    logger.success("Metrics are exported !")

## Entraînement

In [None]:
def train_epoch(epoch_id=None):
    """Run one training epoch over `train_dataloader`.

    Uses the notebook-level `model`, `criterion`, `optimizer`, `train_metric`
    and `device`. The progress bar shows the batch loss and the batch macro AUROC.
    """
    model.train()
    logger.info(f"START EPOCH {epoch_id=}")
    # Metric states accumulate across calls: start each epoch from a clean state
    train_metric.reset()

    progress = tqdm(train_dataloader, desc='training batch...', leave=False)
    for batch_id, batch in enumerate(progress):
        logger.trace(f"{batch_id=}")
        token_list_batch = batch["ids"].to(device)
        attention_mask_batch = batch["mask"].to(device)
        label_batch = batch["labels"].to(device)

        # Predict: logits of shape [batch_size, nb_labels]. No squeeze(): it
        # would drop the batch dimension when the last batch has one sample.
        logit_batch = model(token_list_batch, attention_mask_batch)

        # Loss (BCEWithLogitsLoss works on logits)
        loss = criterion(logit_batch.to(torch.float32), label_batch.to(torch.float32))

        # Metrics work on probabilities (the Hamming variants compare them
        # directly to the 0/1 targets) and integer targets. detach() keeps the
        # metric states out of the autograd graph.
        probability_batch = torch.sigmoid(logit_batch.detach())
        train_metrics_collection_dict = train_metric(probability_batch, label_batch.int())
        logger.trace(train_metrics_collection_dict)

        # Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Update progress bar description
        progress_description = "Train Loss : {loss:.4f} - Train AUROC : {acc:.4f}"
        auroc_macro_value = float(train_metrics_collection_dict["auroc_macro"])
        progress_description = progress_description.format(loss=loss.item(), acc=auroc_macro_value)
        progress.set_description(progress_description)

    logger.info(f"END EPOCH {epoch_id=}")

In [None]:
@torch.no_grad()
def valid_epoch(epoch_id=None):
    """Evaluate `model` on `validation_dataloader` and export the metrics.

    The validation metrics and the mean batch loss are appended to the metric
    file through `export_metric`.

    Returns:
        dict of the computed validation metrics.
    """
    model.eval()
    logger.info(f"START VALIDATION {epoch_id=}")
    validation_metric.reset()

    loss_list = []

    progress = tqdm(validation_dataloader, desc="valid batch...", leave=False)
    for batch_id, batch in enumerate(progress):
        logger.trace(f"{batch_id=}")

        token_list_batch = batch["ids"].to(device)
        attention_mask_batch = batch["mask"].to(device)
        label_batch = batch["labels"].to(device)

        # Predict: logits [batch_size, nb_labels] (no squeeze, see train_epoch)
        logit_batch = model(token_list_batch, attention_mask_batch)

        # Loss
        loss = criterion(
            logit_batch.to(torch.float32),
            label_batch.to(torch.float32),
        )
        loss_list.append(loss.item())

        # Metrics: accumulate only (no per-batch compute), on probabilities
        # and integer targets
        validation_metric.update(torch.sigmoid(logit_batch), label_batch.int())

    loss_mean = np.mean(loss_list)
    metric_values = validation_metric.compute()
    logger.trace(metric_values)
    logger.info(f"END VALIDATION {epoch_id=}")
    export_metric(validation_metric, epoch_id=epoch_id, loss=loss_mean)
    return metric_values

In [None]:
progress = tqdm(range(1, NUM_EPOCHS + 1), desc='training epoch...', leave=True)
for epoch in progress:
    # Train
    train_epoch(epoch_id=epoch)

    # Validation
    valid_metrics_dict = valid_epoch(epoch_id=epoch)

    # Save a checkpoint in the format read back by `load_roberta_pretrained`
    # ({'state_dict': ..., 'optimizer_dict': ...}) instead of pickling the
    # whole model object.
    torch.save(
        {"state_dict": model.state_dict(), "optimizer_dict": optimizer.state_dict()},
        MODEL_FILE_PATH,
    )

## Evaluation

In [55]:
# Scratch check: `sort_index()` reorders rows by their index labels (6, 4 -> 4, 6).
# NOTE(review): exploratory cell, candidate for removal.
pd.DataFrame([[1,0,1],[0.5,0.2,0.3]], columns=["t", "a", "c"], index=[6,4]).sort_index()

Unnamed: 0,t,a,c
4,0.5,0.2,0.3
6,1.0,0.0,1.0


In [57]:
import gc

# Free the training DataFrames before loading the test set. Each name is
# removed independently: before, if `train_df` was already gone, the
# NameError skipped `del validation_df` and that frame was never released.
for dataframe_name in ("train_df", "validation_df"):
    if dataframe_name in globals():
        del globals()[dataframe_name]
    else:
        logger.warning(f"{dataframe_name} already deleted")
# NOTE: `train_dataset` / `validation_dataset` still reference these frames
# through their `.data` attribute, so the memory is only reclaimed once those
# objects are released too.
gc.collect();



In [58]:
# Load the test set; the first CSV column becomes the DataFrame index.
test_df = pd.read_csv(TEST_DATASET_PATH, index_col=0)

Unnamed: 0,id,comment_text,split,created_date,publication_id,parent_id,article_id,rating,funny,wow,...,white,asian,latino,other_race_or_ethnicity,physical_disability,intellectual_or_learning_disability,psychiatric_or_mental_illness,other_disability,identity_annotator_count,toxicity_annotator_count
3,7084460,arresting man resisting arrest cop suckers see...,test,2016-11-01 16:53:33.561631+00,13,,149218,approved,0,0,...,,,,,,,,,0,76
10,7141509,alternative facts go check people like idea ta...,test,2017-01-30 02:53:48.012277+00,21,919529.0,164687,approved,1,0,...,,,,,,,,,0,72
11,7077814,whine sore loser artster enjoy agony,test,2016-12-03 00:17:42.300700+00,54,649753.0,154126,approved,0,0,...,,,,,,,,,0,80
38,7147990,rarely opportunity agree bennet much case righ...,test,2017-09-13 16:37:16.990602+00,102,,377304,approved,1,0,...,,,,,,,,,0,9
42,7008066,law every freedom asss,test,2017-07-09 07:03:44.153492+00,54,5556167.0,353158,approved,0,0,...,,,,,,,,,0,10


In [None]:
# Test loader: batches follow the row order of `test_df` (no shuffling).
test_dataset = JigsawDataset(test_df, tokenizer)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

In [None]:
@torch.no_grad()
def evaluation(model):
    """Predict every row of `test_dataloader` and export the results to TEST_FILE_PATH.

    The CSV has one column per label of LABEL_LIST, holding sigmoid
    probabilities. Rows are indexed by the index labels of the dataset's
    DataFrame, so they can be joined back with `test_df`.
    """
    model.eval()
    logger.info(f"START EVALUATION")

    position_list = []
    probability_list = []

    progress = tqdm(test_dataloader, desc='test batch...', leave=False)
    for batch_id, batch in enumerate(progress):
        logger.trace(f"{batch_id=}")
        token_list_batch = batch["ids"].to(device)
        attention_mask_batch = batch["mask"].to(device)

        # Predict: logits [batch_size, nb_labels] -> probabilities
        logit_batch = model(token_list_batch, attention_mask_batch)
        position_list.append(batch["index"])
        probability_list.append(torch.sigmoid(logit_batch).cpu())

    logger.info(f"END EVALUATION")
    positions = torch.cat(position_list).tolist() if position_list else []
    probabilities = (torch.cat(probability_list).tolist()
                     if probability_list else [])
    # `index` in each batch is the positional row number in the dataset's
    # DataFrame; map it back to that DataFrame's own index labels (the raw
    # positions, previously exported as floats, did not match test_df's index).
    original_index = test_dataloader.dataset.data.index[positions]
    prediction_test_df = pd.DataFrame(probabilities,
                                      columns=LABEL_LIST,
                                      index=original_index)
    prediction_test_df.to_csv(TEST_FILE_PATH)
    logger.success(f"Test predictions exported !")

In [None]:
# Predict on the test set and export the predictions to TEST_FILE_PATH
evaluation(model)