In [None]:
!pip install pytorch-lightning==1.2.8 --quiet
!pip install transformers==4.5.1 --quiet

In [None]:
import heapq
import pickle

import numpy as np
import pandas as pd
import pytorch_lightning as pl
import seaborn as sb
import torch
import torch.nn as nn

from pytorch_lightning.metrics.functional import auroc
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from transformers import BertModel, BertTokenizerFast as BertTokenizer, AdamW, get_linear_schedule_with_warmup

In [None]:
# Target label names: the classifier emits one score per entry, in this order.
LABEL_COLUMNS = ["label-1", "label-2"]

In [None]:
# Identifier of the pretrained model loaded for both the encoder and the tokenizer.
BERT_MODEL_NAME = "dccuchile/bert-base-spanish-wwm-uncased"

In [None]:
class CxCommentTagger(pl.LightningModule):
    """Multi-label comment classifier: a pretrained BERT encoder plus a linear head.

    The head has one output per entry of ``LABEL_COLUMNS``; outputs go through
    a sigmoid and are trained with binary cross-entropy.
    """

    def __init__(self, n_training_steps=None, n_warmup_steps=None):
        """Build the encoder, the classification head and the loss.

        Args:
            n_training_steps: Total number of scheduler steps for the linear
                learning-rate schedule. ``None`` means no schedule is used.
            n_warmup_steps: Number of warm-up steps of that schedule.
                ``None`` is treated as 0.
        """
        super().__init__()
        n_classes = len(LABEL_COLUMNS)
        self.bert = BertModel.from_pretrained(BERT_MODEL_NAME, return_dict=True)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes)
        self.n_training_steps = n_training_steps
        self.n_warmup_steps = n_warmup_steps
        self.criterion = nn.BCELoss()

    def forward(self, input_ids, attention_mask, labels=None):
        """Score a batch of tokenised comments.

        Args:
            input_ids: Token ids passed to the BERT encoder.
            attention_mask: Attention mask passed to the BERT encoder.
            labels: Optional float targets with the same shape as the output;
                when given, the BCE loss against them is computed.

        Returns:
            ``(loss, probabilities)``. ``loss`` is the integer ``0`` when
            ``labels`` is ``None``; ``probabilities`` are the sigmoid outputs
            of the classification head.
        """
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.classifier(output.pooler_output)
        output = torch.sigmoid(output)
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output

    def _shared_step(self, batch, stage):
        """Run a forward pass on ``batch`` and log the loss as ``{stage}_loss``."""
        labels = batch["labels"]
        loss, outputs = self(batch["input_ids"], batch["attention_mask"], labels)
        self.log(f"{stage}_loss", loss, prog_bar=True, logger=True)
        return loss, outputs, labels

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; also return predictions and labels."""
        loss, outputs, labels = self._shared_step(batch, "train")
        # Predictions are only consumed for metrics at epoch end, so they are
        # detached here rather than carrying autograd history along with them.
        return {"loss": loss, "predictions": outputs.detach(), "labels": labels}

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss."""
        loss, _, _ = self._shared_step(batch, "val")
        return loss

    def test_step(self, batch, batch_idx):
        """Compute and log the test loss."""
        loss, _, _ = self._shared_step(batch, "test")
        return loss

    def training_epoch_end(self, outputs):
        """Log one ROC-AUC scalar per label over the whole training epoch.

        Args:
            outputs: The dictionaries returned by ``training_step`` during the epoch.
        """
        if not outputs:
            # Nothing was collected this epoch, so there is nothing to score.
            return

        # Concatenating the per-batch tensors along dim 0 yields one row per sample.
        labels = torch.cat([out["labels"].detach().cpu() for out in outputs]).int()
        predictions = torch.cat([out["predictions"].detach().cpu() for out in outputs])

        for i, name in enumerate(LABEL_COLUMNS):
            class_roc_auc = auroc(predictions[:, i], labels[:, i])
            self.logger.experiment.add_scalar(f"{name}_roc_auc/Train", class_roc_auc, self.current_epoch)

    def configure_optimizers(self):
        """Create the AdamW optimizer and, if step counts are known, a linear warm-up schedule.

        Returns:
            The bare optimizer when ``n_training_steps`` is ``None`` (the
            schedule cannot be built without it); otherwise a dict with the
            optimizer and a scheduler that is stepped every training step.
        """
        optimizer = AdamW(self.parameters(), lr=2e-5)

        if self.n_training_steps is None:
            return optimizer

        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.n_warmup_steps or 0,
            num_training_steps=self.n_training_steps
        )

        return dict(
            optimizer=optimizer,
            lr_scheduler=dict(
                scheduler=scheduler,
                interval='step'
            )
        )

In [None]:
class CxCommentsDataset(Dataset):
    """Dataset that tokenises the ``comment_text`` column of a DataFrame on access.

    Items carry the row's index label, the raw text and the padded encoder
    inputs. No label tensor is returned.
    """

    def __init__(
            self,
            data: pd.DataFrame,
            tokenizer: BertTokenizer,
            max_token_len: int = 128
    ):
        """Store the frame, the tokenizer and the fixed sequence length.

        Args:
            data: Frame with a ``comment_text`` column; its index supplies the item ids.
            tokenizer: Tokenizer exposing ``encode_plus``.
            max_token_len: Length every sequence is padded or truncated to.
        """
        self.data, self.tokenizer, self.max_token_len = data, tokenizer, max_token_len

    def __len__(self):
        """Return the number of rows in the underlying frame."""
        return len(self.data)

    def __getitem__(self, index: int):
        """Return the id, text and encoded tensors for the row at position ``index``."""
        text = self.data.iloc[index].comment_text

        tokens = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        # With return_tensors='pt' the tensors have a leading batch axis; flatten drops it.
        return {
            "call_id": self.data.index[index],
            "comment_text": text,
            "input_ids": tokens["input_ids"].flatten(),
            "attention_mask": tokens["attention_mask"].flatten(),
        }

In [None]:
# Sequence length every comment is padded or truncated to before scoring.
MAX_TOKEN_COUNT = 512
# A label is assigned when its predicted probability is strictly above this value.
THRESHOLD = 0.5

In [None]:
def _mask_unselected(probabilities, threshold):
    """Zero every probability that is not selected as a label for its row.

    A label is selected when its probability is strictly above ``threshold``.
    Rows with no label above the threshold fall back to their highest-scoring
    label(s), so such rows still get a label.
    """
    upper, lower = 1, 0
    selected = np.where(probabilities > threshold, upper, lower)
    unlabeled = selected.sum(axis=1) == 0
    if unlabeled.any():
        row_max = probabilities[unlabeled].max(axis=1, keepdims=True)
        selected[unlabeled] = np.where(probabilities[unlabeled] >= row_max, upper, lower)
    return selected * probabilities


def predict(df, checkpoint_path="checkpoint-file.ckpt", threshold=None):
    """Score every comment in ``df`` with the trained tagger.

    Args:
        df: Frame with a ``comment_text`` column; its index supplies the ids.
        checkpoint_path: Lightning checkpoint to load the model from.
        threshold: Probability above which a label is assigned. Defaults to
            the module-level ``THRESHOLD``.

    Returns:
        ``(ids, scores)``: ``ids`` is the list of index labels in row order and
        ``scores`` is an array with one row per comment and one column per
        label, holding the predicted probability for assigned labels and 0
        elsewhere.
    """
    if len(df) == 0:
        # Nothing to score: skip loading the model (stacking an empty list would fail).
        return [], np.zeros((0, len(LABEL_COLUMNS)))
    if threshold is None:
        threshold = THRESHOLD

    # Load model from a checkpoint file
    tm = CxCommentTagger.load_from_checkpoint(checkpoint_path=checkpoint_path)
    tm.eval()
    tm.freeze()
    # Load tokenizer
    tk = BertTokenizer.from_pretrained(BERT_MODEL_NAME)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    trained_model = tm.to(device)

    dataset = CxCommentsDataset(df, tk, max_token_len=MAX_TOKEN_COUNT)
    # Predict each row in the dataset
    idx = []
    predictions = []
    with torch.no_grad():
        for item in tqdm(dataset):
            _, prediction = trained_model(
                item["input_ids"].unsqueeze(dim=0).to(device),
                item["attention_mask"].unsqueeze(dim=0).to(device)
            )
            # Move each result off the device right away instead of holding all of them there.
            predictions.append(prediction.flatten().cpu())
            idx.append(item["call_id"])
    predictions = torch.stack(predictions).numpy()

    return idx, _mask_unselected(predictions, threshold)

In [None]:
# Comments to score: ';'-separated, Latin-1 encoded, indexed by the id column kept as text.
df = pd.read_csv(
    "file-to-predict.csv",
    sep=';',
    encoding='latin_1',
    index_col='id-col',
    dtype={'id-col': str},
)

In [None]:
# Preview the first rows to sanity-check how the file was parsed.
df.head()

In [None]:
# `predictions` is the two-element tuple returned by predict(): (ids, per-label scores).
predictions = predict(df)

In [None]:
# Turn each row of scores into an (id, "label: score;label: score") pair,
# leaving out labels whose score is zero.
call_ids, score_rows = predictions

preds = []
for call_id, score_row in zip(call_ids, score_rows):
    kept = [
        f'{label}: {prediction}'
        for label, prediction in zip(LABEL_COLUMNS, score_row)
        if prediction != 0
    ]
    preds.append((call_id, ";".join(kept)))

In [None]:
# Persist the (id, labels) pairs. The context manager closes the file even if dumping fails.
file_name = 'predictions.pkl'
with open(file_name, "wb") as open_file:
    pickle.dump(preds, open_file)

In [None]:
# Split the (id, labels) pairs into parallel lists and wrap them in a frame indexed by id.
idxs = [call_id for call_id, _ in preds]
preds_ = [label_text for _, label_text in preds]
df_preds = pd.DataFrame(index=idxs, data=preds_)

# Attach the predicted labels to the input rows by matching on the index, then export.
result = df.merge(df_preds, left_index=True, right_index=True)
result.to_csv('predictions.csv')