In [None]:
%cd /kaggle/working
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py

In [None]:
!python pytorch-xla-env-setup.py --apt-packages libomp5 libopenblas

In [None]:
#%cd /kaggle/input/googleai4codesource/Google-AI4Code/

In [None]:
#!python -m src.clean --data /kaggle/input/googleai4codemerged/train_all.parquet --output /kaggle/working/train_cleaned.parquet --clear code

In [None]:
#!python -m src.featurize \
#            --data /kaggle/working/train_cleaned.parquet \
#            --output /kaggle/working/transformer_data.parquet \
#            --task transformer \
#            --features_out_path /kaggle/working/transformer_features.json \
#            --num_selected_code_cells 20 \
#            --mode train 

In [None]:
!apt-get install -y libomp5

In [None]:
import torch
import torch.nn as nn

from transformers import AutoModel, AutoTokenizer

# Notebook-wide default device string: "cuda" when a CUDA GPU is visible,
# otherwise "cpu". Read as a global by `to_device` and `train_transformer`.
# NOTE(review): this never selects an XLA/TPU device even though the training
# cell uses torch_xla primitives - confirm which device is actually intended.
device = "cuda" if torch.cuda.is_available() else "cpu"

class TransformersModel(nn.Module):
    """Pretrained transformer encoder followed by a single-output linear head.

    The head consumes the encoder's first-token hidden state concatenated with
    extra per-sample features and produces one scalar per sample.

    Args:
        model_path: Name or path handed to ``AutoModel.from_pretrained``.
        num_features: Number of extra feature columns concatenated to the
            first-token embedding. Defaults to 1, which together with a
            768-wide encoder reproduces the previously hard-coded 769 inputs.
    """

    def __init__(self, model_path, num_features=1):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_path)
        # Size the head from the loaded encoder instead of hard-coding 769, so
        # encoders whose hidden size is not 768 also work.
        hidden_size = self.model.config.hidden_size
        self.top = nn.Linear(hidden_size + num_features, 1)

    def forward(self, ids, mask, fts):
        """Score a batch.

        Args:
            ids: Token ids passed to the encoder as ``input_ids``.
            mask: Attention mask passed to the encoder as ``attention_mask``.
            fts: Extra features, concatenated along dim 1 with the first-token
                embedding (so it must be 2-D with ``num_features`` columns).

        Returns:
            The output of the linear head, one value per row.
        """
        # Element 0 of the encoder output is indexed as (batch, position, hidden);
        # position 0 is kept as the sequence summary.
        hidden_states = self.model(input_ids=ids, attention_mask=mask)[0]
        first_token = hidden_states[:, 0, :]
        combined = torch.cat((first_token, fts), 1)
        return self.top(combined)


In [None]:
import torch_xla
import torch_xla.debug.metrics as met
import torch_xla.distributed.parallel_loader as pl
import torch_xla.utils.utils as xu
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.test.test_utils as test_utils
import warnings
warnings.filterwarnings("ignore")
import gc


In [None]:
import json
import torch
import pickle
import numpy as np

from transformers import AutoModel, AutoTokenizer
from torch.utils.data import Dataset


class XGBrankerDataSet:
    """Thin wrapper around a pickled ``(X_train, y_train, groups)`` triple."""

    def __init__(self, data_path):
        # Only the location is remembered here; nothing is read until load_data().
        self.data_path = data_path

    def load_data(self):
        """Unpickle the file at ``data_path`` and return its three components.

        Returns:
            A tuple ``(X_train, y_train, groups)`` exactly as stored in the file.
        """
        # NOTE: pickle.load can execute arbitrary code; only use on trusted files.
        with open(self.data_path, "rb") as handle:
            payload = pickle.load(handle)
        features, targets, group_sizes = payload
        return features, targets, group_sizes


class TransformersDataset(Dataset):
    """Dataset over four arrays stored back-to-back in a single ``.npy`` file.

    ``load_data`` must be called before indexing or taking ``len``; until then
    the array attributes do not exist.
    """

    def __init__(self, data_path):
        super().__init__()
        self.data_path = data_path

    def load_data(self):
        """Read ``ids``, ``masks``, ``fts`` and ``ranks`` (in that order) from disk."""
        with open(self.data_path, "rb") as handle:
            # The arrays were written sequentially, so consecutive np.load calls
            # on the same handle return them one after another.
            for attribute in ("ids", "masks", "fts", "ranks"):
                setattr(self, attribute, np.load(handle))

    def __getitem__(self, index):
        """Return ``(ids, mask, feature, rank)`` tensors for one sample."""
        token_ids = torch.from_numpy(self.ids[index])
        attention_mask = torch.from_numpy(self.masks[index])
        # The per-sample feature and rank are wrapped in a list before being
        # turned into float tensors.
        feature = torch.FloatTensor([self.fts[index]])
        rank = torch.FloatTensor([self.ranks[index]])
        return (token_ids, attention_mask, feature, rank)

    def __len__(self):
        """Number of samples, taken from the first axis of ``ids``."""
        sample_count = self.ids.shape[0]
        return sample_count


In [None]:
def reduce_fn(vals):
    """Return the arithmetic mean of ``vals`` (used to average per-core values)."""
    running_total = 0
    for value in vals:
        running_total = running_total + value
    return running_total / len(vals)

In [None]:
import sys
import logging
import argparse
import torch
import numpy as np

from torch.utils.data import DataLoader
from transformers import AdamW, get_linear_schedule_with_warmup

from tqdm import tqdm


def to_device(data):
    """Move an ``(inputs, target)`` pair onto the notebook-level ``device``.

    Every tensor in ``data[0]`` is cast to ``torch.long``; ``data[1]`` is cast
    to ``torch.float``. Returns ``(list_of_inputs, target)``.
    """
    moved_inputs = []
    for tensor in data[0]:
        moved_inputs.append(tensor.to(device, dtype=torch.long))
    moved_target = data[1].to(device, dtype=torch.float)
    return moved_inputs, moved_target


def train_transformer(
    data_path,
    output_model_path,
    model_name_or_path,
    accumulation_steps,
    batch_size,
    epochs,
    n_workers,
):
    """Fine-tune ``TransformersModel`` with an L1 loss on an XLA device.

    Args:
        data_path: File read by ``TransformersDataset.load_data``.
        output_model_path: Where the model ``state_dict`` is written after
            every epoch (each epoch overwrites the previous file).
        model_name_or_path: Passed to ``TransformersModel`` to pick the encoder.
        accumulation_steps: Number of batches whose gradients are accumulated
            before one optimizer/scheduler step. Values below 1 are treated as 1.
        batch_size: Per-replica batch size for the ``DataLoader``.
        epochs: Number of passes over the (sharded) dataset.
        n_workers: ``num_workers`` for the ``DataLoader``.
    """
    np.random.seed(0)
    accumulation_steps = max(1, int(accumulation_steps))

    # The notebook-level `device` only distinguishes "cuda"/"cpu"; the XLA
    # primitives used below (ParallelLoader, optimizer_step) need the XLA device.
    device = xm.xla_device()
    world_size = xm.xrt_world_size()

    train_data = TransformersDataset(
        data_path,
    )
    train_data.load_data()

    # Shard the dataset so each replica sees a distinct slice.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_data,
        num_replicas=world_size,  # how many devices (TPU cores) take part
        rank=xm.get_ordinal(),  # which device (core) this process is
        shuffle=True,
    )

    train_loader = DataLoader(
        train_data,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=n_workers,
    )

    model = TransformersModel(model_name_or_path)
    model = model.to(device)
    xm.master_print('done loading model')

    # Optimizer: biases and LayerNorm parameters are excluded from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.01,
        },
        {
            "params": [
                p for n, p in param_optimizer if any(nd in n for nd in no_decay)
            ],
            "weight_decay": 0.0,
        },
    ]

    lr = 3e-5 * world_size  # scale the learning rate with the replica count

    # len(train_loader) is already the number of batches this replica sees per
    # epoch, so it must not be divided by batch_size or world size again. The
    # scheduler advances once per optimizer step, i.e. once per accumulation
    # window (the last, possibly partial, window also steps).
    batches_per_epoch = len(train_loader)
    steps_per_epoch = (batches_per_epoch + accumulation_steps - 1) // accumulation_steps
    num_train_optimization_steps = steps_per_epoch * epochs

    optimizer = AdamW(
        optimizer_grouped_parameters, lr=lr, correct_bias=False
    )  # To reproduce BertAdam specific behavior set correct_bias=False

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_optimization_steps,
    )  # PyTorch scheduler

    criterion = torch.nn.L1Loss()

    for e in range(epochs):
        gc.collect()
        model.train()
        # Without set_epoch the sampler reuses the same shuffle every epoch.
        train_sampler.set_epoch(e)
        para_loader = pl.ParallelLoader(train_loader, [device])
        tbar = tqdm(para_loader.per_device_loader(device), total=batches_per_epoch)

        # Start every epoch with clean gradients; they are cleared again only
        # after an optimizer step so that accumulation actually accumulates.
        optimizer.zero_grad()

        for idx, data in enumerate(tbar):
            # The parallel loader was given `device`, so no extra .to(device)
            # calls are made on the batch here.
            ids, mask, fts, target = data

            pred = model(ids, mask, fts)
            loss = criterion(pred, target)

            # Scale so the accumulated gradient is the mean over the window.
            (loss / accumulation_steps).backward()

            is_last_batch = idx == batches_per_epoch - 1
            if (idx + 1) % accumulation_steps == 0 or is_last_batch:
                xm.optimizer_step(optimizer, barrier=True)
                scheduler.step()
                optimizer.zero_grad()

            # Report the unscaled loss averaged over all replicas.
            loss_reduced = xm.mesh_reduce('loss_reduce', loss.detach(), reduce_fn)

            tbar.set_description(
                f"Epoch {e + 1} Loss: {loss_reduced} lr: {scheduler.get_last_lr()}"
            )

        # xm.save is the XLA-aware replacement for torch.save for tensors that
        # live on an XLA device.
        xm.save(model.state_dict(), output_model_path)

# Launch training; arguments are spelled out by keyword so each value's role is
# visible at the call site.
train_transformer(
    data_path='/kaggle/input/googleai4codemerged/transformer_data.npy',
    output_model_path='/kaggle/working/trained_mode.bin',
    model_name_or_path='microsoft/codebert-base',
    accumulation_steps=4,
    batch_size=10,
    epochs=5,
    n_workers=8,
)
