In [2]:
#Importing Important Libraries
import os
import torch
import pandas as pd
from scipy import stats
import numpy as np

from tqdm import tqdm
from collections import OrderedDict, namedtuple
import torch.nn as nn
from torch.optim import lr_scheduler
import joblib

import logging
import transformers
from transformers import AdamW, get_linear_schedule_with_warmup, get_constant_schedule
import sys
from sklearn import metrics, model_selection

import warnings
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.data_parallel as dp
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import warnings

warnings.filterwarnings("ignore")


In [3]:
#Building Helper Functions and Classes

class AverageMeter:
    """
    Tracks the latest value and the running weighted average of a metric.

    Attributes (all set to 0 by ``reset``):
        val:   the value passed to the most recent ``update`` call.
        sum:   running total of ``val * n`` over all updates.
        count: running total of ``n`` over all updates.
        avg:   ``sum / count``, or 0 while ``count`` is still 0.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """
        Record a new observation.

        :param val: the observed value
        :param n: weight of the observation (number of items it stands for)
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # A zero-weight update on an empty meter would otherwise divide by
        # zero; keep the average at 0 until something has been counted.
        self.avg = self.sum / self.count if self.count else 0


In [4]:
class BERTBaseUncased(nn.Module):
    """
    BERT encoder with a single-logit linear head.

    The per-token encoder output is mean-pooled and max-pooled along the
    sequence axis; the two pooled vectors are concatenated, passed through
    dropout and projected to one raw logit per example (no sigmoid applied).
    """
    def __init__(self, bert_path):
        """
        :param bert_path: path or model identifier handed to
            ``transformers.BertModel.from_pretrained``
        """
        super(BERTBaseUncased, self).__init__()
        self.bert_path = bert_path
        self.bert = transformers.BertModel.from_pretrained(self.bert_path)
        self.bert_drop = nn.Dropout(0.3)
        # Mean-pool and max-pool are concatenated, hence twice the encoder
        # width. Reading the width from the config (instead of a literal 768)
        # keeps the head correct for checkpoints of other sizes.
        self.out = nn.Linear(self.bert.config.hidden_size * 2, 1)

    def forward(
            self,
            ids,
            mask,
            token_type_ids
    ):
        """
        :param ids: token id tensor passed positionally to the encoder
        :param mask: tensor passed to the encoder as ``attention_mask``
        :param token_type_ids: tensor passed to the encoder as ``token_type_ids``
        :return: tensor of raw logits with one column per example row
        """
        encoder_outputs = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids)

        # Index instead of tuple-unpacking: a plain tuple and a dict-like
        # ModelOutput both support [0], whereas unpacking a ModelOutput
        # iterates over its keys (strings) and breaks the pooling below.
        sequence_output = encoder_outputs[0]

        # NOTE: both poolings run over every position along dim 1, padding
        # positions included; `mask` is only used inside the encoder.
        apool = torch.mean(sequence_output, 1)
        mpool, _ = torch.max(sequence_output, 1)
        cat = torch.cat((apool, mpool), 1)

        bo = self.bert_drop(cat)
        p2 = self.out(bo)
        return p2


In [5]:
class BERTDatasetTraining:
    """
    Map-style dataset that tokenizes one comment per item and returns
    fixed-length tensors ready for batching.
    """
    def __init__(self, comment_text, targets, tokenizer, max_length):
        """
        :param comment_text: indexable collection of comments (each is passed through ``str``)
        :param targets: indexable collection of labels, aligned with ``comment_text``
        :param tokenizer: object exposing ``encode_plus`` and returning
            ``input_ids``, ``token_type_ids`` and ``attention_mask``
        :param max_length: length every returned sequence is padded/cut to
        """
        self.comment_text = comment_text
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.targets = targets

    def __len__(self):
        return len(self.comment_text)

    def __getitem__(self, item):
        """
        :param item: index of the example
        :return: dict with ``ids``, ``mask``, ``token_type_ids`` (long tensors
            of length ``max_length``) and ``targets`` (float tensor)
        """
        comment_text = str(self.comment_text[item])
        # Collapse every run of whitespace into a single space.
        comment_text = " ".join(comment_text.split())

        inputs = self.tokenizer.encode_plus(
            comment_text,
            None,
            add_special_tokens=True,
            max_length=self.max_length,
        )

        # Clamp to max_length: if the tokenizer did not truncate, the padding
        # length below would be negative, the item would come out longer than
        # max_length and batches could no longer be stacked. This fallback cut
        # is a plain slice, so a trailing special token is dropped in that case.
        ids = inputs["input_ids"][:self.max_length]
        token_type_ids = inputs["token_type_ids"][:self.max_length]
        mask = inputs["attention_mask"][:self.max_length]

        # Right-pad all three lists with zeros up to max_length.
        padding_length = self.max_length - len(ids)
        ids = ids + ([0] * padding_length)
        mask = mask + ([0] * padding_length)
        token_type_ids = token_type_ids + ([0] * padding_length)

        return {
            'ids': torch.tensor(ids, dtype=torch.long),
            'mask': torch.tensor(mask, dtype=torch.long),
            'token_type_ids': torch.tensor(token_type_ids, dtype=torch.long),
            'targets': torch.tensor(self.targets[item], dtype=torch.float)
        }

In [6]:
# Locations, column selection and seed gathered in one place so they can be tuned together.
BERT_PATH = "../input/bert-base-multilingual-uncased/"
JIGSAW_DIR = "../input/jigsaw-multilingual-toxic-comment-classification/"
TEXT_AND_LABEL_COLS = ["comment_text", "toxic"]
TRAIN_SAMPLE_SIZE = 200000
SEED = 42  # fixes both shuffles below so a re-run selects the same rows

# Loading the pretrained multilingual uncased BERT model
mx = BERTBaseUncased(bert_path=BERT_PATH)

# Importing the two training sources; missing cells are replaced by the string "none"
df_train1 = pd.read_csv(JIGSAW_DIR + "jigsaw-toxic-comment-train.csv", usecols=TEXT_AND_LABEL_COLS).fillna("none")
df_train2 = pd.read_csv(JIGSAW_DIR + "jigsaw-unintended-bias-train.csv", usecols=TEXT_AND_LABEL_COLS).fillna("none")

# Concatenating the two sources, shuffling, and keeping a fixed-size subsample
df_train_full = pd.concat([df_train1, df_train2], axis=0).reset_index(drop=True)
df_train = (
    df_train_full
    .sample(frac=1, random_state=SEED)
    .reset_index(drop=True)
    .head(TRAIN_SAMPLE_SIZE)
)

# Validation dataset
df_valid = pd.read_csv(JIGSAW_DIR + "validation.csv", usecols=TEXT_AND_LABEL_COLS)

# NOTE: the validation rows are appended to the training frame, so any score
# later computed on df_valid is measured on data the model was trained on.
df_train = pd.concat([df_train, df_valid], axis=0).reset_index(drop=True)
df_train = df_train.sample(frac=1, random_state=SEED).reset_index(drop=True)

In [7]:
#Defining Training Functions

def _run():
    """
    Train and evaluate the classifier on the XLA device of the calling process.

    Reads the module-level ``df_train``, ``df_valid`` and ``mx``. Each epoch
    trains on this process's shard of ``df_train``, evaluates on this
    process's shard of ``df_valid``, saves the weights to ``model.bin`` and
    prints the AUC through ``xm.master_print``.
    """
    def loss_fn(outputs, targets):
        """Binary cross-entropy on raw logits; targets are reshaped to a single column."""
        return nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))

    def forward_batch(model, batch, device):
        """Move one batch dict to ``device``, run the model, return (outputs, targets)."""
        ids = batch["ids"].to(device, dtype=torch.long)
        mask = batch["mask"].to(device, dtype=torch.long)
        token_type_ids = batch["token_type_ids"].to(device, dtype=torch.long)
        targets = batch["targets"].to(device, dtype=torch.float)
        outputs = model(
            ids=ids,
            mask=mask,
            token_type_ids=token_type_ids
        )
        return outputs, targets

    def train_loop_fn(data_loader, model, optimizer, device, scheduler=None):
        """Run one pass over ``data_loader``, stepping the optimizer (and scheduler) per batch."""
        model.train()
        for bi, d in enumerate(data_loader):
            optimizer.zero_grad()
            outputs, targets = forward_batch(model, d, device)

            loss = loss_fn(outputs, targets)
            if bi % 10 == 0:
                xm.master_print(f'bi={bi}, loss={loss}')

            loss.backward()
            xm.optimizer_step(optimizer)
            if scheduler is not None:
                scheduler.step()

    def eval_loop_fn(data_loader, model, device):
        """Return (outputs, targets) as Python lists collected over ``data_loader``."""
        model.eval()
        fin_targets = []
        fin_outputs = []
        # Inference only: skip autograd bookkeeping for the forward passes.
        with torch.no_grad():
            for bi, d in enumerate(data_loader):
                outputs, targets = forward_batch(model, d, device)
                fin_targets.extend(targets.cpu().detach().numpy().tolist())
                fin_outputs.extend(outputs.cpu().detach().numpy().tolist())

        return fin_outputs, fin_targets

    MAX_LEN = 192
    TRAIN_BATCH_SIZE = 64
    VALID_BATCH_SIZE = 16
    EPOCHS = 2

    # Creating tokenizer for text preprocessing
    tokenizer = transformers.BertTokenizer.from_pretrained("../input/bert-base-multilingual-uncased/", do_lower_case=True)

    def make_loader(frame, batch_size, shuffle, drop_last):
        """Build (dataset, sampler, loader) for ``frame``, sharded across the XLA replicas."""
        dataset = BERTDatasetTraining(
            comment_text=frame.comment_text.values,
            targets=frame.toxic.values,
            tokenizer=tokenizer,
            max_length=MAX_LEN
        )
        sampler = torch.utils.data.distributed.DistributedSampler(
            dataset,
            num_replicas=xm.xrt_world_size(),
            rank=xm.get_ordinal(),
            shuffle=shuffle)
        loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=sampler,
            drop_last=drop_last,
            num_workers=1
        )
        return dataset, sampler, loader

    # Building the training and validation loaders
    train_dataset, train_sampler, train_data_loader = make_loader(
        df_train, TRAIN_BATCH_SIZE, shuffle=True, drop_last=True)
    valid_dataset, valid_sampler, valid_data_loader = make_loader(
        df_valid, VALID_BATCH_SIZE, shuffle=False, drop_last=False)

    # Training algorithm
    device = xm.xla_device()
    model = mx.to(device)

    # Biases and LayerNorm parameters are excluded from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.001},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}]

    # The learning rate is scaled by the number of replicas.
    lr = 0.4 * 1e-5 * xm.xrt_world_size()
    num_train_steps = int(len(train_dataset) / TRAIN_BATCH_SIZE / xm.xrt_world_size() * EPOCHS)
    xm.master_print(f'num_train_steps = {num_train_steps}, world_size={xm.xrt_world_size()}')

    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )

    for epoch in range(EPOCHS):
        # Without set_epoch the DistributedSampler reuses the same shuffle
        # order on every epoch.
        train_sampler.set_epoch(epoch)

        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler=scheduler)

        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        o, t = eval_loop_fn(para_loader.per_device_loader(device), model, device)

        # Saving the model in a binary file
        xm.save(model.state_dict(), "model.bin")
        # NOTE: o/t come from this process's validation shard only, and the
        # value is printed via xm.master_print, so the reported AUC is not
        # aggregated across replicas.
        auc = metrics.roc_auc_score(np.array(t) >= 0.5, o)
        xm.master_print(f'AUC = {auc}')

In [8]:
def _mp_fn(rank, flags):
    """
    Worker entry point handed to ``xmp.spawn``.

    :param rank: first positional argument supplied by the spawner (unused here)
    :param flags: the FLAGS dict forwarded through ``args`` (unused here)
    """
    torch.set_default_tensor_type('torch.FloatTensor')
    # _run() has no return statement, so there is nothing to keep from the call.
    _run()


FLAGS = {}
# Start 8 forked worker processes with _mp_fn as their entry point.
xmp.spawn(
    _mp_fn,
    args=(FLAGS,),
    nprocs=8,
    start_method='fork',
)

num_train_steps = 812, world_size=8
bi=0, loss=0.564293622970581
bi=10, loss=0.3402901291847229
bi=20, loss=0.31182992458343506
bi=30, loss=0.24555177986621857
bi=40, loss=0.3227238655090332
bi=50, loss=0.20644468069076538
bi=60, loss=0.24016498029232025
bi=70, loss=0.3656253516674042
bi=80, loss=0.3177357017993927
bi=90, loss=0.33056870102882385
bi=100, loss=0.2262754738330841
bi=110, loss=0.23070532083511353
bi=120, loss=0.27551010251045227
bi=130, loss=0.21741044521331787
bi=140, loss=0.25732097029685974
bi=150, loss=0.24563854932785034
bi=160, loss=0.20598378777503967
bi=170, loss=0.22359442710876465
bi=180, loss=0.19511005282402039
bi=190, loss=0.19456921517848969
bi=200, loss=0.177840918302536
bi=210, loss=0.310518354177475
bi=220, loss=0.24306896328926086
bi=230, loss=0.19314469397068024
bi=240, loss=0.2609172463417053
bi=250, loss=0.2205863893032074
bi=260, loss=0.18660475313663483
bi=270, loss=0.22155281901359558
bi=280, loss=0.15475620329380035
bi=290, loss=0.2210752367973327