In [1]:
# Download the PyTorch/XLA environment setup script from the pytorch/xla repository
# and run it, also installing the apt packages libomp5 and libopenblas-dev.
# NOTE(review): the script is fetched from the `master` branch, so its behavior is
# not pinned to a fixed version.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas-dev

  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100  5116  100  5116    0     0   3311      0  0:00:01  0:00:01 --:--:--  3309
Updating... This may take around 2 minutes.
Updating TPU runtime to pytorch-dev20200515 ...
Found existing installation: torch 1.5.0
Uninstalling torch-1.5.0:
  Successfully uninstalled torch-1.5.0
Found existing installation: torchvision 0.6.0a0+35d732a
Uninstalling torchvision-0.6.0a0+35d732a:
Done updating TPU runtime
  Successfully uninstalled torchvision-0.6.0a0+35d732a
Copying gs://tpu-pytorch/wheels/torch-nightly+20200515-cp37-cp37m-linux_x86_64.whl...
\ [1 files][ 91.0 MiB/ 91.0 MiB]                                                
Operation completed over 1 objects/91.0 MiB.                                     
Copying gs://tpu-pytorch/wheels/torch_xla-nightly+20200515-cp37-cp37m-linux_x86_64.whl...
\ [1 files][119.5 MiB/119.5 MiB]              

In [3]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
# Print the full path of every file found anywhere under /kaggle/input.
for root, _subdirs, names in os.walk('/kaggle/input'):
    for name in names:
        full_path = os.path.join(root, name)
        print(full_path)

        
import torch
import transformers
import torch.nn as nn
import torch
from tqdm import tqdm
import transformers
from sklearn import model_selection
from sklearn import metrics
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup

/kaggle/input/jigsaw-multilingual-toxic-comment-classification/validation-processed-seqlen128.csv
/kaggle/input/jigsaw-multilingual-toxic-comment-classification/jigsaw-toxic-comment-train-processed-seqlen128.csv
/kaggle/input/jigsaw-multilingual-toxic-comment-classification/jigsaw-unintended-bias-train.csv
/kaggle/input/jigsaw-multilingual-toxic-comment-classification/validation.csv
/kaggle/input/jigsaw-multilingual-toxic-comment-classification/test-processed-seqlen128.csv
/kaggle/input/jigsaw-multilingual-toxic-comment-classification/jigsaw-toxic-comment-train.csv
/kaggle/input/jigsaw-multilingual-toxic-comment-classification/test.csv
/kaggle/input/jigsaw-multilingual-toxic-comment-classification/jigsaw-unintended-bias-train-processed-seqlen128.csv
/kaggle/input/jigsaw-multilingual-toxic-comment-classification/sample_submission.csv
/kaggle/input/bertbasemultilingualuncased/bert-base-multilingual-uncased/vocab.txt
/kaggle/input/bertbasemultilingualuncased/bert-base-multilingual-uncased



In [2]:
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl


In [26]:
# Token length every example is padded to (and the max_length passed to the tokenizer).
MAX_LEN = 128
# Batch sizes handed to the training and validation DataLoaders.
TRAIN_BATCH_SIZE = 128
VALID_BATCH_SIZE = 64
# Number of passes over the training data.
EPOCHS = 10
# NOTE(review): not referenced anywhere in the visible code; presumably intended
# for gradient accumulation -- confirm or remove.
ACCUMULATION = 2
# Directory holding the pretrained multilingual BERT files (used for both tokenizer and model).
BERT_PATH = "../input/bertbasemultilingualuncased/bert-base-multilingual-uncased"
# File the model's state dict is saved to during training.
MODEL_PATH = "model.bin"
# Shared tokenizer instance used by BERTDataset.
TOKENIZER = transformers.BertTokenizer.from_pretrained(BERT_PATH, do_lower_case=True)


In [5]:
class BERTDataset:
    """Map-style dataset that turns raw comments into fixed-length BERT inputs.

    Args:
        comment_text: indexable sequence of comments.
        target: indexable sequence of labels, aligned with ``comment_text``.
        tokenizer: optional tokenizer exposing ``encode_plus``; defaults to the
            notebook-level ``TOKENIZER``.
        max_len: optional fixed sequence length; defaults to the notebook-level
            ``MAX_LEN``.

    Each item is a dict with the keys ``ids``, ``mask``, ``token_type_ids``
    (long tensors of length ``max_len``) and ``target`` (float tensor).
    """

    def __init__(self, comment_text, target, tokenizer=None, max_len=None):
        self.comment_text = comment_text
        self.target = target
        # Only fall back to the notebook globals when no override is supplied.
        self.tokenizer = TOKENIZER if tokenizer is None else tokenizer
        self.max_len = MAX_LEN if max_len is None else max_len

    def __len__(self):
        """Return the number of comments in the dataset."""
        return len(self.comment_text)

    def __getitem__(self, item):
        """Tokenize, truncate and zero-pad the comment at index ``item``."""
        # str() guards against non-string entries (e.g. NaN cells read from a
        # CSV), which would otherwise fail on .split().
        comment_text = str(self.comment_text[item])
        # Collapse any run of whitespace into a single space.
        comment_text = " ".join(comment_text.split())

        inputs = self.tokenizer.encode_plus(
            comment_text,
            None,
            add_special_tokens=True,
            max_length=self.max_len
        )

        # Slice defensively: if the tokenizer did not truncate, a negative
        # padding length would silently pad nothing and produce tensors of
        # unequal length that cannot be stacked into a batch.
        ids = list(inputs['input_ids'])[:self.max_len]
        mask = list(inputs['attention_mask'])[:self.max_len]
        token_type_ids = list(inputs['token_type_ids'])[:self.max_len]

        padding_len = self.max_len - len(ids)
        ids = ids + ([0] * padding_len)
        mask = mask + ([0] * padding_len)
        token_type_ids = token_type_ids + ([0] * padding_len)

        return {
            "ids": torch.tensor(ids, dtype=torch.long),
            "mask": torch.tensor(mask, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids, dtype=torch.long),
            "target": torch.tensor(self.target[item], dtype=torch.float)
        }


In [6]:
class BERTBaseUncased(nn.Module):
    """BERT encoder followed by mean+max pooling, dropout and a 1-unit linear head.

    ``forward`` returns the raw output of the linear layer (no sigmoid applied).
    """

    def __init__(self):
        super(BERTBaseUncased, self).__init__()
        self.bert = transformers.BertModel.from_pretrained(BERT_PATH)
        self.bert_drop = nn.Dropout(0.3)
        # The head consumes the concatenation of the mean- and max-pooled
        # vectors, hence twice the encoder width. Reading the width from the
        # loaded config avoids hard-coding 768.
        self.out = nn.Linear(self.bert.config.hidden_size * 2, 1)

    def forward(self, ids, mask, token_type_ids):
        """Encode the batch and return one logit per example.

        Args:
            ids: token id tensor passed to the encoder.
            mask: attention mask passed to the encoder.
            token_type_ids: segment id tensor passed to the encoder.
        """
        bert_outputs = self.bert(
            ids,
            attention_mask=mask,
            token_type_ids=token_type_ids
        )
        # Index rather than unpack into exactly two names: indexing works both
        # when the encoder returns a plain tuple and when it returns an output
        # object with more (or differently shaped) fields.
        o1 = bert_outputs[0]

        # Pool over dimension 1. The attention mask is not applied here, so
        # every position along that dimension contributes to both poolings.
        mean_pooling = torch.mean(o1, 1)
        max_pooling, _ = torch.max(o1, 1)
        cat = torch.cat((mean_pooling, max_pooling), 1)

        bo = self.bert_drop(cat)
        output = self.out(bo)
        return output


In [28]:
def loss_fn(outputs, target):
    """Binary cross-entropy on logits, with ``target`` reshaped to a single column."""
    column_target = target.view(-1, 1)
    criterion = nn.BCEWithLogitsLoss()
    return criterion(outputs, column_target)



def train_fn(data_loader, model, optimizer, device, scheduler):
    """Run one training pass over ``data_loader``.

    For every batch: move the tensors to ``device``, compute the loss with
    ``loss_fn``, backpropagate, step the optimizer through
    ``xm.optimizer_step`` and then step the scheduler. The loss is reported
    via ``xm.master_print`` on every 10th batch.
    """
    model.train()
    for step, batch in enumerate(data_loader):
        # Put each field on the target device with the dtype used downstream.
        input_ids = batch['ids'].to(device, dtype=torch.long)
        attention_mask = batch['mask'].to(device, dtype=torch.long)
        segment_ids = batch['token_type_ids'].to(device, dtype=torch.long)
        labels = batch['target'].to(device, dtype=torch.float)

        optimizer.zero_grad()
        logits = model(ids=input_ids, mask=attention_mask, token_type_ids=segment_ids)

        batch_loss = loss_fn(logits, labels)
        batch_loss.backward()
        xm.optimizer_step(optimizer)
        scheduler.step()

        if step % 10 != 0:
            continue
        xm.master_print("bi = {}, loss = {}".format(step, batch_loss))
        


def eval_fn(data_loader, model, device):
    """Run ``model`` over ``data_loader`` without gradients.

    Returns:
        ``(outputs, targets)`` -- two Python lists accumulated over all
        batches: the sigmoid of the model outputs and the batch targets, each
        converted with ``.tolist()``.
    """
    model.eval()
    collected_outputs, collected_targets = [], []
    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch['ids'].to(device, dtype=torch.long)
            attention_mask = batch['mask'].to(device, dtype=torch.long)
            segment_ids = batch['token_type_ids'].to(device, dtype=torch.long)
            labels = batch['target'].to(device, dtype=torch.float)

            logits = model(ids=input_ids, mask=attention_mask, token_type_ids=segment_ids)

            # Bring results back to host memory as plain Python lists.
            collected_targets += labels.cpu().detach().numpy().tolist()
            probabilities = torch.sigmoid(logits)
            collected_outputs += probabilities.cpu().detach().numpy().tolist()
    return collected_outputs, collected_targets

In [24]:
def run():
    """Fine-tune the BERT classifier on the Jigsaw data and checkpoint the best epoch.

    Steps:
      1. Read the two training CSVs (``comment_text`` and ``toxic`` columns) and
         the validation CSV, and wrap them in ``BERTDataset``.
      2. Build distributed samplers / data loaders sized by the XLA world size.
      3. Create the model, an AdamW optimizer (no weight decay on bias and
         LayerNorm parameters) and a linear schedule without warmup.
      4. For ``EPOCHS`` epochs: train, evaluate, compute ROC AUC on the
         validation outputs, and save the state dict to ``MODEL_PATH`` whenever
         the score improves on the best seen so far.

    Returns nothing; the saved checkpoint is the result.
    """
    df1 = pd.read_csv('../input/jigsaw-multilingual-toxic-comment-classification/jigsaw-toxic-comment-train.csv', usecols=['comment_text', 'toxic'])
    df2 = pd.read_csv('../input/jigsaw-multilingual-toxic-comment-classification/jigsaw-unintended-bias-train.csv', usecols=['comment_text', 'toxic'])

    df_train = pd.concat([df1, df2], axis=0)
    df_train = df_train.reset_index(drop=True)

    df_valid = pd.read_csv("../input/jigsaw-multilingual-toxic-comment-classification/validation.csv")

    train_dataset = BERTDataset(
        comment_text=df_train.comment_text.values,
        target=df_train.toxic.values
    )

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )

    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TRAIN_BATCH_SIZE,
        num_workers=1,
        sampler=train_sampler,
        drop_last=True
    )

    valid_dataset = BERTDataset(
        comment_text=df_valid.comment_text.values,
        target=df_valid.toxic.values
    )

    # Evaluation does not benefit from shuffling; outputs and targets are
    # collected pairwise, so the order does not affect the metric.
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )

    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=VALID_BATCH_SIZE,
        num_workers=1,
        sampler=valid_sampler
    )

    device = xm.xla_device()
    model = BERTBaseUncased()
    model.to(device)

    # Split parameters into two groups so that bias / LayerNorm parameters are
    # excluded from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.001,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    num_train_steps = int(len(df_train) / TRAIN_BATCH_SIZE / xm.xrt_world_size() * EPOCHS)
    # Scale the base learning rate by the number of replicas and actually pass
    # it to the optimizer (the scaled value was previously computed but unused).
    lr = 3e-5 * xm.xrt_world_size()
    optimizer = AdamW(optimizer_parameters, lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=num_train_steps
    )

    best_auc = 0
    for epoch in range(EPOCHS):
        # Without set_epoch the sampler reuses the same shuffle every epoch.
        train_sampler.set_epoch(epoch)

        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler)

        # Iterate the per-device loader, exactly as done for training; the
        # ParallelLoader object itself is not what eval_fn should loop over.
        valid_para_loader = pl.ParallelLoader(valid_data_loader, [device])
        outputs, targets = eval_fn(valid_para_loader.per_device_loader(device), model, device)

        # Binarize the validation targets before scoring.
        targets = np.array(targets) >= 0.5
        # The reported score is ROC AUC; the message text is kept as-is.
        auc = metrics.roc_auc_score(targets, outputs)
        print("accuracy_score = {accuracy}".format(accuracy=auc))

        if auc > best_auc:
            xm.save(model.state_dict(), MODEL_PATH)
            # Remember the best score so later, worse epochs do not overwrite
            # the checkpoint.
            best_auc = auc

In [9]:
import torch_xla.distributed.xla_multiprocessing as xmp

In [29]:
def _mp(rank, flags):
    """Per-process entry point: set the default tensor type, then call ``run``.

    ``rank`` and ``flags`` are accepted but not used in the body.
    """
    torch.set_default_tensor_type("torch.FloatTensor")
    run()


# Launch _mp in a single forked process, passing an empty dict as its `flags` argument.
xmp.spawn(_mp, args=({}, ), nprocs=1, start_method='fork')

here
bi = 0, loss = 0.6949639916419983
bi = 10, loss = 0.5543463230133057
bi = 20, loss = 0.2979893088340759
bi = 30, loss = 0.3130144476890564
bi = 40, loss = 0.2549452781677246
bi = 50, loss = 0.29970675706863403
bi = 60, loss = 0.244351327419281
bi = 70, loss = 0.38717958331108093
bi = 80, loss = 0.356704980134964
bi = 90, loss = 0.32659977674484253
bi = 100, loss = 0.34495478868484497
bi = 110, loss = 0.2458876669406891
bi = 120, loss = 0.2737136781215668
bi = 130, loss = 0.21093273162841797
bi = 140, loss = 0.2990308701992035
bi = 150, loss = 0.2106209546327591
bi = 160, loss = 0.2089516520500183


KeyboardInterrupt: 