In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
# Input data files are available in the "../input/" directory.
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory
from tqdm import tqdm
tqdm.pandas()
import os
# Any results you write to the current directory are saved as output.
import string
import re    #for regex
import nltk
from nltk.corpus import stopwords
import spacy
from nltk import pos_tag
from nltk.stem.wordnet import WordNetLemmatizer 
from nltk.tokenize import word_tokenize
# Tweet tokenizer does not split at apostophes which is what we want
from nltk.tokenize import TweetTokenizer   
import time
import re



##### Credits: https://www.kaggle.com/christofhenkel/how-to-preprocessing-when-using-embeddings

In [None]:
## Load data
# Read the train/test CSVs from the google-quest-challenge Kaggle input directory
# into module-level DataFrames used by the rest of the notebook.
train = pd.read_csv('/kaggle/input/google-quest-challenge/train.csv')
test = pd.read_csv('/kaggle/input/google-quest-challenge/test.csv')
print("Train shape : ",train.shape)
print("Test shape : ",test.shape)

## Pre-processing functions  
Various helper functions, e.g. for removing special characters and punctuation, and for correcting spellings.

In [None]:
# Ordered (pattern, replacement) rules used by decontract(). Order matters:
# the specific forms ("won't", "can't", "ain't") must run before the generic
# "n't" rule. Both ASCII (') and typographic (’) apostrophes are accepted, and
# each pattern ends at a word boundary so contractions followed by punctuation
# or at the very end of the text are expanded too.
_CONTRACTION_RULES = [
    (re.compile(r"([Ww])on['’]t\b"), "will not"),
    (re.compile(r"([Cc])an['’]t\b"), "can not"),
    (re.compile(r"([Yy])['’]all\b"), "you all"),
    (re.compile(r"([Yy])a['’]ll\b"), "you all"),
    (re.compile(r"([Ii])['’]m\b"), "i am"),
    (re.compile(r"([Aa])in['’]t\b"), "is not"),
    (re.compile(r"n['’]t\b"), " not"),
    (re.compile(r"['’]re\b"), " are"),
    (re.compile(r"['’]d\b"), " would"),
    (re.compile(r"['’]ll\b"), " will"),
    (re.compile(r"['’]t\b"), " not"),
    (re.compile(r"['’]ve\b"), " have"),
]


def decontract(text):
    """
    function to transform short-hand: can't--->can not

    Applies every rule in ``_CONTRACTION_RULES`` in order and returns the
    expanded text. Text after each contraction (space, punctuation or
    nothing) is left untouched.

    :param text: input string
    :return: string with English contractions expanded
    """
    for pattern, replacement in _CONTRACTION_RULES:
        text = pattern.sub(replacement, text)
    return text

# One-pass translation table for clean_text():
#   '/', '-', "'"  -> a single space (so joined words are split apart)
#   '&'            -> ' & '          (kept, but padded so it becomes its own token)
#   the remaining ASCII punctuation and curly quotes -> deleted
_CLEAN_TEXT_TABLE = str.maketrans({
    **dict.fromkeys("/-'", ' '),
    '&': ' & ',
    **dict.fromkeys('?!.,"#$%()*+:;<=>@[\\]^_`{|}~“”’'),
})


def clean_text(x):
    """Split on slashes/hyphens/apostrophes, pad '&', and strip other punctuation.

    The input is converted with ``str()`` first, so non-string cells (e.g.
    numbers) are accepted.
    """
    return str(x).translate(_CLEAN_TEXT_TABLE)

def clean_numbers(x):
    """
    Replace every run of two or more ASCII digits with a length-dependent placeholder.

    Runs of length 2, 3 and 4 become '12', '123' and '1234'; runs of five
    or more digits become '12345'. Single digits are left unchanged.

    Each run is mapped exactly once. Chaining ``re.sub`` calls with digit
    placeholders would let later patterns re-match earlier output
    (e.g. '2019' -> '1234' -> '1212').

    :param x: input string
    :return: string with digit runs normalised
    """
    return re.sub(r'[0-9]{2,}', lambda run: '12345'[:min(len(run.group(0)), 5)], x)

# Substring -> replacement map applied by replace_typical_misspell().
# Entries normalise British spellings, apostrophe-less negations, "WW2"-style
# abbreviations and social-network brand names to a generic phrase.
# Keys are matched as plain substrings (the regex has no word boundaries),
# so e.g. 'colour' also rewrites 'colours' -> 'colors'.
# NOTE(review): 'canceled' -> 'cancelled' goes the British direction, unlike
# the other entries — confirm this is intended.
mispell_dict = {'colour':'color',
                'centre':'center',
                'didnt':'did not',
                'doesnt':'does not',
                'isnt':'is not',
                'shouldnt':'should not',
                'favourite':'favorite',
                'travelling':'traveling',
                'counselling':'counseling',
                'theatre':'theater',
                'canceled':'cancelled',
                'labour':'labor',
                'organisation':'organization',
                'wwii':'world war 2',
                'ww2':'world war 2',
                'citicise':'criticize',
                'instagram': 'social medium',
                'whatsapp': 'social medium',
                'snapchat': 'social medium',
                'facebook': 'social medium',
                'pinterest': 'social medium',
                'linkedin': 'social medium'
                }
def _get_mispell(mispell_dict):
    """
    to rectify spellings as in identfier mispell_dict
    """
    mispell_re = re.compile('(%s)' % '|'.join(mispell_dict.keys()))
    return mispell_dict, mispell_re

mispellings, mispellings_re = _get_mispell(mispell_dict)

def replace_typical_misspell(text):
    """Swap each substring of *text* that is a key of ``mispellings`` for its mapped value."""
    lookup = mispellings
    return mispellings_re.sub(lambda found: lookup[found.group(0)], text)

# Tokens that remove_useless_punct() deletes outright: assorted non-Latin
# script characters, symbols, and HTML-escape / newline / LaTeX fragments
# (presumably collected from this corpus — confirm before reusing elsewhere).
useless_punct = ['च', '不', 'ঢ়', '平', 'ᠠ', '錯', '判', '∙',
                 '言', 'ς', 'ل', '្', 'ジ', 'あ', '得', '水', 'ь', '◦', '创', 
                 '康', '華', 'ḵ', '☺', '支', '就', '„', '」', '어', '谈', '陈', '团', '腻', '权', 
                 '年', '业', 'マ', 'य', 'ا', '売', '甲', '拼', '˂', 'ὤ', '贯', '亚', 'ि', '放', 'ʻ', 'ទ', 'ʖ', 
                 '點', '્', '発', '青', '能', '木', 'д', '微', '藤', '̃', '僕', '妒', '͜', 'ន', 'ध', '이', '希', '特',
                 'ड', '¢', '滢', 'ส', '나', '女', 'క', '没', '什', 'з', '天', '南', 'ʿ', 'ค', 'も', '凰', '步', '籍', '西',
                 'ำ', '−', 'л', 'ڤ', 'ៃ', '號', 'ص', 'स', '®', 'ʋ', '批', 'រ', '치', '谢', '生', '道', '═', '下', '俄', 'ɖ',
                 '觀', 'வ', '—', 'ی', '您', '♥', '一', 'や', '⊆', 'ʌ', '語', 'ี', '兴', '惶', '瀛', '狐', '⁴', 'प', '臣', 'ద',
                 '―', 'ì', 'ऌ', 'ీ', '自', '信', '健', '受', 'ɨ', '시', 'י', 'ছ', '嬛', '湾', '吃', 'ち', 'ड़', '反', '红', '有',
                 '配', 'ে', 'ឯ', '宮', 'つ', 'μ', '記', '口', '℅ι', 'ो', '狸', '奇', 'о', 'ट', '聖', '蘭', '読', 'ū', '標', '要', 
                 'ត', '识', 'で', '汤', 'ま', 'ʀ', '局', 'リ', '्', 'ไ', '呢', '工', 'ल', '沒', 'τ', 'ិ', 'ö', 'せ', '你', 'ん', 'ュ', 
                 '枚', '部', '大', '罗', 'হ', 'て', '表', '报', '攻', 'ĺ', 'ฉ', '∩', '宝', '对', '字', '文', '这', '∑', '髪', 'り', '่', '능',
                 '罢', '내', '阻', '为', '菲', 'ي', 'न', 'ί', 'ɦ', '開', '†', '茹', '做', '東', 'ত', 'に', 'ت', '晓', '키', '悲', 'સ', 
                 '好', '›', '上', '存', '없', '하', '知', 'ធ', '斯', ' ', '授', 'ł', '傳', '兰', '封', 'ோ', 'و', 'х', 'だ', '人', '太', 
                 '品', '毒', 'ᡳ', '血', '席', '剔', 'п', '蛋', '王', '那', '梦', 'ី', '彩', '甄', 'и', '柏', 'ਨ', '和', '坊', '⌚', '广', 
                 '依', '∫', 'į', '故', 'ś', 'ऊ', '几', '日', 'ک', '音', '×', '”', '▾', 'ʊ', 'ज', 'ด', 'ठ', 'उ', 'る', '清', 'ग', 'ط',
                 'δ', 'ʏ', '官', '∛', '়', '้', '男', '骂', '复', '∂', 'ー', '过', 'য', '以', '短', '翻', 'র', '教', '儀', 'ɛ', '‹', 'へ', 
                 '¾', '合', '学', 'ٌ', '학', '挑', 'ष', '比', '体', 'م', 'س', 'អ', 'ת', '訓', '∀', '迎', 'វ', 'ɔ', '٨', '▒', '化', 'చ', '‛', 
                 'প', 'º', 'น', '업', '说', 'ご', '¸', '₹', '儿', '︠', '게', '骨', 'ท', 'ऋ', 'ホ', '茶', '는', 'જ', 'ุ', '羡', '節', 'ਮ', 
                 'উ', '番', 'ড়', '讲', 'ㅜ', '등', '伟', 'จ', '我', 'ล', 'す', 'い', 'ញ', '看', 'ċ', '∧', 'भ', 'ઘ', 'ั', 'ម', '街', 'ય', 
                 '还', '鰹', 'ខ', 'ు', '訊', 'म', 'ю', '復', '杨', 'ق', 'त', '金', '味', 'ব', '风', '意', '몇', '佬', '爾', '精', '¶', 
                 'ం', '乱', 'χ', '교', 'ה', '始', 'ᠰ', '了', '个', '克', '্', 'ห', '已', 'ʃ', 'わ', '新', '译', '︡', '本', 'ง', 'б', 'け', 
                 'ి', '明', '¯', '過', 'ك', 'ῥ', 'ف', 'ß', '서', '进', 'ដ', '样', '乐', '寧', '€', 'ณ', 'ル', '乡', '子', 'ﬁ', 'ج', '慕',
                 '–', 'ᡵ', 'Ø', '͡', '제', 'Ω', 'ប', '絕', '눈', 'फ', 'ম', 'గ', '他', 'α', 'ξ', '§', 'ஜ', '黎', 'ね', '복', 'π', 'ú', '鸡',
                 '话', '会', 'ক', '八', '之', '북', 'ن', '¦', '가', 'ו', '恋', '地', 'ῆ', '許', '产', 'ॡ', 'ش', '़', '野', 'ή', 'ɒ', '啧',
                 'យ', '᠌', 'ᠨ', 'ب', '皎', '老', '公', '☆', 'व', 'ি', 'ល', 'ر', 'គ', '행', 'ង', 'ο', '让', 'ំ', 'λ', 'خ', 'ἰ', '家',
                 'ট', 'ब', '理', '是', 'め', 'र', '√', '기', 'ν', '玉', '한', '入', 'ד', '别', 'د', 'ะ', '电', 'ા', '♫', 'ع', 'ં', '堵',
                 '嫉', '伊', 'う', '千', '관', '篇', 'क', '非', '荣', '粵', '瑜', '英', '를', '美', '条', '`', '宋', '←', '수', '後', '•',
                 '³', 'ी', '고', '肉', '℃', 'し', '漢', '싱', 'ϵ', '送', 'ه', '落', 'న', 'ក', 'க', 'ℇ', 'た', 'ះ', '中', '射', '♪', '符',
                 'ឃ', '谷', '分', '酱', 'び', 'থ', 'ة', 'г', 'σ', 'と', '楚', '胡', '饭', 'み', '禮', '主', '直', '÷', '夢', 'ɾ', 'চ', '⃗',
                 '統', '高', '顺', '据', 'ら', '頭', 'よ', '最', 'ా', 'ੁ', '亲', 'ស', '花', '≡', '眼', '病', '…', 'の', '發', 'ா', '汝',
                 '★', '氏', 'ร', '景', 'ᡠ', '读', '件', '仲', 'শ', 'お', 'っ', 'پ', 'ᡤ', 'ч', '♭', '悠', 'ं', '六', '也', 'ռ', 'য়', '恐', 
                 'ह', '可', '啊', '莫', '书', '总', 'ষ', 'ք', '̂', '간', 'な', '此', '愛', 'ర', 'ใ', '陳', 'Ἀ', 'ण', '望', 'द', '请', '油',
                 '露', '니', 'ş', '宗', 'ʍ', '鳳', 'अ', '邋', '的', 'ព', '火', 'ा', 'ก', '約', 'ட', '章', '長', '商', '台', '勢', 'さ',
                 '국', 'Î', '簡', 'ई', '∈', 'ṭ', '經', '族', 'ु', '孫', '身', '坑', 'স', '么', 'ε', '失', '殺', 'ž', 'ર', 'が', '手',
                 'ា', '心', 'ਾ', '로', '朝', '们', '黒', '欢', '早', '️', 'া', 'आ', 'ɸ', '常', '快', '民', 'ﷺ', 'ូ', '遢', 'η', '国', 
                 '无', '江', 'ॠ', '「', 'ন', '™', 'ើ', 'ζ', '紫', 'ె', 'я', '“', '♨', '國', 'े', 'อ', '∞', 
                  # '{\n' and '}\n' are two separate tokens (previously fused into one string by a quoting typo).
                  '\n', '{\n', '}\n', "=&gt;", '}\n\n', '-&gt;', '\n\ni', '&lt;','/&gt;\n','{\n\n','\\','|','&','\\n\\n',"\\appendix"]
# The plain space must never be deleted, otherwise neighbouring words would be glued together.
useless_punct.remove(' ')
# Compiled patterns keyed by the tuple of tokens they were built from, so the
# (large) alternation is compiled once rather than on every call.
_USELESS_PUNCT_PATTERNS = {}


def remove_useless_punct(text, tokens=None):
    """
    to remove punctuation symbols as in identifier useless_punct

    Every token is matched literally. Regex metacharacters such as '|' or a
    backslash are escaped; unescaped, they would have been interpreted as
    regex syntax. Longer tokens are tried before shorter ones because regex
    alternation is leftmost-first; otherwise a multi-character token could
    never win over a single-character token that is its prefix.

    :param text: string to clean
    :param tokens: optional iterable of literal tokens to delete; defaults to
        the module-level ``useless_punct`` list
    :return: ``text`` with every occurrence of every token removed
    """
    if tokens is None:
        tokens = useless_punct
    key = tuple(tokens)
    pattern = _USELESS_PUNCT_PATTERNS.get(key)
    if pattern is None:
        # Empty tokens are skipped; they would only match the empty string.
        literals = sorted((tok for tok in key if tok), key=len, reverse=True)
        pattern = re.compile('|'.join(map(re.escape, literals)))
        _USELESS_PUNCT_PATTERNS[key] = pattern
    return pattern.sub('', text)

# Substring -> replacement map consumed by clean_special_chars().
# Mostly accented / look-alike letters mapped to plain ASCII, plus a few
# multi-character entries ("&gt", "&lt", "(not", "});", "\\\\") and symbols
# spelled out as words (">" -> "greater", "$" -> "dollar"; no spaces are added
# around those words).
# Duplicate keys ("ᗯ", "н") are harmless: the later entry wins.
# NOTE(review): 'Γ' is Greek capital gamma but is mapped to 'theta' — confirm.
# NOTE(review): if the "O" key below is plain ASCII 'O', applying this mapping
# lowercases every capital O in the text — confirm the intended character.
letter_mapping = {'\u200b':' ', 'ũ': "u", 'ẽ': 'e', 'é': "e", 'á': "a", 'ķ': 'k', 
                  'ï': 'i', 'Ź': 'Z', 'Ż': 'Z', 'Š': 'S', 'Π': ' pi ', 'Ö': 'O', 
                  'É': 'E', 'Ñ': 'N', 'Ž': 'Z', 'ệ': 'e', '²': '2', 'Å': 'A', 'Ā': 'A',
                  'ế': 'e', 'ễ': 'e', 'ộ': 'o', '⧼': '<', '⧽': '>', 'Ü': 'U', 'Δ': 'delta',
                  'ợ': 'o', 'İ': 'I', 'Я': 'R', 'О': 'O', 'Č': 'C', 'П': 'pi', 'В': 'B', 'Φ': 
                  'phi', 'ỵ': 'y', 'օ': 'o', 'Ľ': 'L', 'ả': 'a', 'Γ': 'theta', 'Ó': 'O', 'Í': 'I',
                  'ấ': 'a', 'ụ': 'u', 'Ō': 'O', 'Ο': 'O', 'Σ': 'sigma', 'Â': 'A', 'Ã': 'A', 'ᗯ': 'w', 
                  'ᕼ': "h", "ᗩ": "a", "ᖇ": "r", "ᗯ": "w", "O": "o", "ᗰ": "m", "ᑎ": "n", "ᐯ": "v", "н": 
                  "h", "м": "m", "o": "o", "т": "t", "в": "b", "υ": "u",  "ι": "i","н": "h", "č": "c", "š":
                  "s", "ḥ": "h", "ā": "a", "ī": "i", "à": "a", "ý": "y", "ò": "o", "è": "e", "ù": "u", "â": 
                  "a", "ğ": "g", "ó": "o", "ê": "e", "ạ": "a", "ü": "u", "ä": "a", "í": "i", "ō": "o", "ñ": "n",
                  "ç": "c", "ã": "a", "ć": "c", "ô": "o", "с": "c", "ě": "e", "æ": "ae", "î": "i", "ő": "o", "å": 
                  "a", "Ä": "A","&gt":" greater than","&lt" :"lesser than", "(not" : "not" , "});":"",">" :"greater","<":"lesser" ,"$":"dollar","\\\\":" ","\\": " "} 
def clean_special_chars(text, mapping=None):
    """
    clean weird / special characters as in identifier letter_mapping

    Replace every key of ``mapping`` found in ``text`` with its value.

    Multi-character keys (e.g. "&gt", "});", a double backslash) are
    substituted first, longest key first, so they are not broken up by the
    single-character rules. Single-character keys are then applied in one
    ``str.translate`` pass, so a replacement value is never re-mapped.

    :param text: string to clean
    :param mapping: optional dict of {substring: replacement}; defaults to the
        module-level ``letter_mapping``
    :return: cleaned string
    """
    if mapping is None:
        mapping = letter_mapping
    multi_char_keys = sorted((k for k in mapping if len(k) > 1), key=len, reverse=True)
    for key in multi_char_keys:
        text = text.replace(key, mapping[key])
    single_char_table = str.maketrans({k: v for k, v in mapping.items() if len(k) == 1})
    return text.translate(single_char_table)

def clean_apostrophes(sentence):
    """Rewrite typographic apostrophe look-alikes (’ ‘ ´ `) as a plain ASCII apostrophe."""
    lookalikes = "’‘´`"
    to_ascii = str.maketrans(lookalikes, "'" * len(lookalikes))
    return sentence.translate(to_ascii)


In [None]:
from tqdm import tqdm
tqdm.pandas()
def build_vocab(sentences, verbose=True):
    """
    Count how often every token occurs across a tokenised corpus.

    :param sentences: iterable of token lists (one list per sentence/document)
    :param verbose: show a tqdm progress bar while counting
    :return: dict mapping each token to its number of occurrences
    """
    counts = {}
    for tokens in tqdm(sentences, disable=not verbose):
        for token in tokens:
            counts[token] = counts.get(token, 0) + 1
    return counts

In [None]:
%%time
import operator 
def check_coverage(vocab, embeddings_index, verbose=True):
    """
    function to check coverage between vocabulary and the word embedding

    Prints the share of distinct words ("vocabulary") and of word occurrences
    ("corpus") that have an embedding. Both ratios are reported as 0% for an
    empty vocabulary instead of dividing by zero.

    :param vocab: dictionary of wordtypes from the corpus (word -> corpus frequency)
    :param embeddings_index: embeddings; any mapping that supports ``word in embeddings_index``
    :param verbose: show a tqdm progress bar while scanning the vocabulary
    :return: list of (word, frequency) pairs for out-of-vocabulary words,
        most frequent first
    """
    oov = {}
    known_types = 0
    known_tokens = 0
    oov_tokens = 0
    words = tqdm(vocab) if verbose else vocab  # vocab holds the unique words of the corpus
    for word in words:
        freq = vocab[word]  # frequency of the word in the corpus
        # Membership test instead of a bare try/except around indexing: only a
        # missing word counts as OOV, and real errors are no longer swallowed.
        if word in embeddings_index:
            known_types += 1
            known_tokens += freq
        else:
            oov[word] = freq
            oov_tokens += freq
    type_ratio = known_types / len(vocab) if vocab else 0.0
    total_tokens = known_tokens + oov_tokens
    token_ratio = known_tokens / total_tokens if total_tokens else 0.0
    print('Found embeddings for {:.2%} of vocabulary'.format(type_ratio))
    print('Found embeddings for {:.2%} of corpus'.format(token_ratio))
    return sorted(oov.items(), key=operator.itemgetter(1), reverse=True)

## Building vocabulary

In [None]:
# Module-level column name; preprocess() below also reads this global.
colname="question_body"
# Whitespace-tokenise the raw (not yet cleaned) column and count word types.
sentences = train[colname].progress_apply(lambda x: x.split()).values
vocab = build_vocab(sentences)
print("vocabulary size for",colname,":",len(vocab))

## The following four steps are applied for each of the pre-processing techniques above:
1. Apply the **pre-processing** method of choice to the **sentence**.  
2. Split the **sentence** into words.  
3. **Build** the vocabulary.  
4. Check the embedding **coverage**.

In [None]:
def preprocess(embeddings_index, column=None):
    """
    Clean one text column of the module-level ``train`` DataFrame in place,
    reporting embedding coverage after every cleaning step.

    Steps, applied cumulatively in this order: decontract, clean_apostrophes,
    clean_special_chars, remove_useless_punct, clean_text, clean_numbers,
    replace_typical_misspell. After each step the column is whitespace-split,
    a vocabulary is built, and ``check_coverage`` prints the coverage. For the
    final report only, the frequent function words 'a', 'to', 'of' and 'and'
    are dropped from the tokens before counting.

    :param embeddings_index: embedding lookup passed to ``check_coverage``
    :param column: column of ``train`` to clean; defaults to the module-level
        ``colname`` global
    :return: the out-of-vocabulary list from the final coverage check,
        most frequent first
    """
    col = colname if column is None else column
    steps = (
        decontract,
        clean_apostrophes,
        clean_special_chars,
        remove_useless_punct,
        clean_text,
        clean_numbers,
        replace_typical_misspell,
    )
    stop_words = {'a', 'to', 'of', 'and'}
    oov = []
    for position, step in enumerate(steps, start=1):
        # Item assignment mutates the global DataFrame; no `global` statement is needed.
        train[col] = train[col].progress_apply(step)
        sentences = train[col].apply(lambda x: x.split())
        if position == len(steps):
            sentences = [[word for word in sentence if word not in stop_words]
                         for sentence in tqdm(sentences)]
        vocab = build_vocab(sentences)
        oov = check_coverage(vocab, embeddings_index)
    return oov

### Let us now use GloVe embeddings

In [None]:
import pickle
def load_embeddings(path):
    """Load and return the object pickled at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use this on
    trusted files (here, a Kaggle input dataset).
    """
    with open(path,'rb') as f:
        emb_arr = pickle.load(f)
    return emb_arr
# Pickled GloVe 840B 300d vectors; presumably a word -> vector mapping — confirm.
GLOVE_EMBEDDING_PATH = '/kaggle/input/pickled-glove840b300d-for-10sec-loading/glove.840B.300d.pkl' 
tic = time.time()
glove_embeddings = load_embeddings(GLOVE_EMBEDDING_PATH)
print(f'loaded {len(glove_embeddings)} word vectors in {time.time()-tic}s')
#check coverage with glove embedding
# `vocab` here is the module-level vocabulary of the raw question_body column.
oov = check_coverage(vocab,glove_embeddings)

In [None]:
train = pd.read_csv('/kaggle/input/google-quest-challenge/train.csv')
test = pd.read_csv('/kaggle/input/google-quest-challenge/test.csv')
# Hosts (sites) that appear in train but never in test.
train_only_hosts = set(train['host']).difference(test['host'])
# Boolean mask of the train rows coming from those hosts (displayed as the cell output).
train['host'].isin(train_only_hosts)

In [None]:
train = pd.read_csv('/kaggle/input/google-quest-challenge/train.csv')
test = pd.read_csv('/kaggle/input/google-quest-challenge/test.csv')
# Show the hosts that occur in train but never in test.
print("hosts only in train:", set(train['host']).difference(test['host']))
# NOTE(review): presumably the two hosts dropped below are among the train-only
# hosts printed above — confirm.
to_del_indices = train[(train['host']=='meta.christianity.stackexchange.com')|(train['host']=='rpg.stackexchange.com')].index
# Delete these row indexes from dataFrame
print("before deletion, train.shape:", train.shape)
train.drop(to_del_indices , inplace=True)
print("after deletion, train.shape:", train.shape)

print("Test shape : ",test.shape)
# preprocess() cleans train[colname]; the global colname is still "question_body" at this point.
preprocess(glove_embeddings)

In [None]:
# Reload the raw data (the previous cell modified train in place) and repeat
# the same host filtering before cleaning the "answer" column.
train = pd.read_csv('/kaggle/input/google-quest-challenge/train.csv')
test = pd.read_csv('/kaggle/input/google-quest-challenge/test.csv')
to_del_indices = train[(train['host']=='meta.christianity.stackexchange.com')|(train['host']=='rpg.stackexchange.com')].index

# Delete these row indexes from dataFrame
print("before deletion, train.shape:", train.shape)
train.drop(to_del_indices , inplace=True)
print("after deletion, train.shape:", train.shape)

print("Test shape : ",test.shape)
# Rebinding the global colname switches the column that preprocess() cleans.
colname="answer"
preprocess(glove_embeddings)

Found GloVe embeddings for 49.51% of "question_body" vocabulary  
Found GloVe embeddings for 92.11% of "question_body" corpus  
Found GloVe embeddings for 52.33% of "answer" vocabulary  
Found GloVe embeddings for 94.36% of "answer" corpus  

## Loading fastText crawl embeddings

In [None]:
%%time 
from gensim.models import KeyedVectors
# Despite its name, news_path points at the fastText *crawl* 300d-2M vectors (text format).
news_path = '/kaggle/input/fasttext-crawl-300d-2m/crawl-300d-2M.vec'
embeddings_index = KeyedVectors.load_word2vec_format(news_path, binary=False)
#checking coverage with fasttext embedding
# `vocab` is the module-level vocabulary built from the raw question_body column;
# preprocess() only builds local vocabularies, so this vocab is still uncleaned.
oov = check_coverage(vocab,embeddings_index)#without preprocessing vocabulary

## Loading fastText wiki-news embeddings

In [None]:
import pickle
def load_embeddings(path):
    """Load and return the object pickled at ``path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code; only use this on
    trusted files (here, a Kaggle input dataset).
    """
    with open(path, 'rb') as f:
        return pickle.load(f)
FASTTEXT_EMBEDDING_PATH = '/kaggle/input/fasttext-wikinews/wiki-news-300d-1M.pickle' 
tic = time.time()
fasttext_embeddings = load_embeddings(FASTTEXT_EMBEDDING_PATH)
# Report the size of the embeddings just loaded (not the GloVe ones).
print(f'loaded {len(fasttext_embeddings)} word vectors in {time.time()-tic}s')
# check coverage with the fastText wiki-news embedding
oov = check_coverage(vocab,fasttext_embeddings)

In [None]:
train = pd.read_csv('/kaggle/input/google-quest-challenge/train.csv')
test = pd.read_csv('/kaggle/input/google-quest-challenge/test.csv')
to_del_indices = train[(train['host']=='meta.christianity.stackexchange.com')|(train['host']=='rpg.stackexchange.com')].index

# Delete these row indexes from dataFrame
print("before deletion, train.shape:", train.shape)
train.drop(to_del_indices , inplace=True)
print("after deletion, train.shape:", train.shape)

print("Test shape : ",test.shape)
# NOTE(review): cleans train[colname]; when cells run top to bottom, colname
# was last set to "answer" — confirm that is the intended column.
preprocess(fasttext_embeddings)

In [None]:
import torch
import transformers
import torch.nn as nn
import pandas as pd
import numpy as np
from sklearn import model_selection
from transformers import AdamW,get_linear_schedule_with_warmup
import torch_xla.core.xla_model as xm
# Correct module name is xla_multiprocessing (the "multioprocessing" spelling does not exist).
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.distributed.parallel_loader as pl
from scipy import stats
import warnings
# Silence library warnings for cleaner training logs.
warnings.filterwarnings('ignore')

In [None]:
class BERTBaseUncased(nn.Module):
    """Pretrained BERT encoder followed by dropout and a linear head.

    The head maps BERT's pooled output to ``num_targets`` raw logits
    (one per target column; 30 by default).
    """

    def __init__(self, bert_path, num_targets=30, dropout=0.3):
        """
        :param bert_path: model id or local path for ``BertModel.from_pretrained``
        :param num_targets: number of output logits
        :param dropout: dropout probability applied to the pooled output
        """
        super().__init__()
        self.bert_path = bert_path
        self.bert = transformers.BertModel.from_pretrained(bert_path)
        self.bert_drop = nn.Dropout(dropout)
        # Size the head from the loaded model's config instead of hard-coding 768.
        self.out = nn.Linear(self.bert.config.hidden_size, num_targets)

    def forward(self, ids, mask, token_type_ids):
        outputs = self.bert(ids, attention_mask=mask, token_type_ids=token_type_ids)
        # Index 1 is the pooled output for both the legacy tuple return and the
        # newer ModelOutput return. Tuple-unpacking a ModelOutput yields its keys.
        pooled = outputs[1]
        return self.out(self.bert_drop(pooled))
    
class BERTDatasetTraining:
    """Map-style dataset that tokenises (question, answer) pairs for BERT.

    Each item encodes ``title + " " + body`` as the first segment and the
    answer as the second, right-pads to ``max_len`` with zeros, and returns
    long tensors ``ids``, ``mask`` and ``token_type_ids`` plus float
    ``targets``. ``targets`` is indexed as ``targets[item, :]``, so it must be
    a 2-D array (rows aligned with the text columns).
    """

    def __init__(self, qtitle, qbody, answer, targets, tokenizer, max_len):
        self.qtitle = qtitle
        self.qbody = qbody
        self.answer = answer
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.targets = targets

    def __len__(self):
        return len(self.answer)

    def __getitem__(self, item):
        question_title = str(self.qtitle[item])
        question_body = str(self.qbody[item])
        answer = str(self.answer[item])
        inputs = self.tokenizer.encode_plus(
            question_title + " " + question_body,
            answer,
            add_special_tokens=True,
            max_length=self.max_len,
        )
        # Clip defensively so every item has exactly max_len positions and batches stack.
        ids = inputs["input_ids"][:self.max_len]
        token_type_ids = inputs["token_type_ids"][:self.max_len]
        mask = inputs["attention_mask"][:self.max_len]
        padding = [0] * (self.max_len - len(ids))
        return {
            "ids": torch.tensor(ids + padding, dtype=torch.long),
            "mask": torch.tensor(mask + padding, dtype=torch.long),
            "token_type_ids": torch.tensor(token_type_ids + padding, dtype=torch.long),
            "targets": torch.tensor(self.targets[item, :], dtype=torch.float),
        }


# The functions below are module-level; previously they were mis-indented
# inside the dataset class, which left `run` undefined for xmp.spawn.
def loss_fn(outputs, targets):
    """Mean binary cross-entropy between raw logits and targets."""
    return nn.BCEWithLogitsLoss()(outputs, targets)


def train_loop_fn(data_loader, model, optimizer, device, scheduler=None):
    """Run one training epoch over ``data_loader``, stepping the optimizer via ``xm``."""
    model.train()
    for bi, d in enumerate(data_loader):
        ids = d["ids"].to(device, dtype=torch.long)
        mask = d["mask"].to(device, dtype=torch.long)
        token_type_ids = d["token_type_ids"].to(device, dtype=torch.long)
        targets = d["targets"].to(device, dtype=torch.float)

        optimizer.zero_grad()
        outputs = model(ids=ids, mask=mask, token_type_ids=token_type_ids)
        loss = loss_fn(outputs, targets)
        loss.backward()
        xm.optimizer_step(optimizer)
        if scheduler is not None:
            scheduler.step()
        if bi % 10 == 0:
            xm.master_print(f"bi={bi},loss={loss}")


def eval_loop_fn(data_loader, model, device):
    """Predict over ``data_loader`` without gradients.

    :return: (outputs, targets) as 2-D numpy arrays, stacked over batches
    """
    model.eval()
    fin_targets = []
    fin_outputs = []
    with torch.no_grad():
        for d in data_loader:
            ids = d["ids"].to(device, dtype=torch.long)
            mask = d["mask"].to(device, dtype=torch.long)
            token_type_ids = d["token_type_ids"].to(device, dtype=torch.long)
            targets = d["targets"].to(device, dtype=torch.float)

            outputs = model(ids=ids, mask=mask, token_type_ids=token_type_ids)
            fin_targets.append(targets.cpu().detach().numpy())
            fin_outputs.append(outputs.cpu().detach().numpy())
    return np.vstack(fin_outputs), np.vstack(fin_targets)


def run(index, data_dir="/kaggle/input/google-quest-challenge", bert_path="bert-base-uncased"):
    """Per-process training entry point launched by ``xmp.spawn``.

    ``index`` is the process index passed by ``xmp.spawn``. It is unused here;
    ranks come from ``xm.get_ordinal()``.

    :param data_dir: directory containing train.csv and sample_submission.csv
        (same Kaggle input directory as the rest of the notebook)
    :param bert_path: Hugging Face model id for tokenizer and model
    """
    MAX_LEN = 512
    TRAIN_BATCH_SIZE = 32
    EPOCHS = 20

    dfx = pd.read_csv(f"{data_dir}/train.csv").fillna("none")
    df_train, df_valid = model_selection.train_test_split(dfx, random_state=42, test_size=0.1)
    df_train = df_train.reset_index(drop=True)
    df_valid = df_valid.reset_index(drop=True)

    # Targets are every submission-template column except the id.
    # NOTE(review): previously read "submission.csv"; the competition input is
    # expected to ship "sample_submission.csv" — confirm.
    sample = pd.read_csv(f"{data_dir}/sample_submission.csv")
    target_cols = list(sample.drop("qa_id", axis=1).columns)
    train_targets = df_train[target_cols].values
    valid_targets = df_valid[target_cols].values

    tokenizer = transformers.BertTokenizer.from_pretrained(bert_path)

    train_dataset = BERTDatasetTraining(
        qtitle=df_train.question_title.values,
        qbody=df_train.question_body.values,
        answer=df_train.answer.values,
        targets=train_targets,
        tokenizer=tokenizer,
        max_len=MAX_LEN,
    )
    train_sampler = torch.utils.data.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True,
    )
    train_data_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=TRAIN_BATCH_SIZE,
        sampler=train_sampler,
    )

    valid_dataset = BERTDatasetTraining(
        qtitle=df_valid.question_title.values,
        qbody=df_valid.question_body.values,
        answer=df_valid.answer.values,
        targets=valid_targets,
        tokenizer=tokenizer,
        max_len=MAX_LEN,
    )
    valid_sampler = torch.utils.data.DistributedSampler(
        valid_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False,  # validation order does not need shuffling
    )
    valid_data_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=4,  ###change
        sampler=valid_sampler,
    )

    device = xm.xla_device()
    lr = 3e-5 * xm.xrt_world_size()  ###change: base LR multiplied by the number of processes
    num_train_steps = int(len(train_dataset) / TRAIN_BATCH_SIZE / xm.xrt_world_size() * EPOCHS)
    # The model must live on the XLA device that the batches are moved to.
    model = BERTBaseUncased(bert_path).to(device)

    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps,
    )

    for epoch in range(EPOCHS):
        # Reseed the distributed shuffle so each epoch sees a different order.
        train_sampler.set_epoch(epoch)
        para_loader = pl.ParallelLoader(train_data_loader, [device])
        train_loop_fn(para_loader.per_device_loader(device), model, optimizer, device, scheduler)
        para_loader = pl.ParallelLoader(valid_data_loader, [device])
        o, t = eval_loop_fn(para_loader.per_device_loader(device), model, device)

        # Mean per-column Spearman correlation; NaN (e.g. a constant column) counts as 0.
        # NOTE(review): computed on this process's validation shard only.
        spear = []
        for jj in range(t.shape[1]):
            coef, _ = stats.spearmanr(t[:, jj], o[:, jj])
            spear.append(np.nan_to_num(coef))
        spear = np.mean(spear)
        xm.master_print(f"epoch={epoch},spearman={spear}")
        xm.save(model.state_dict(), "model.bin")
if __name__ == "__main__":
    # Launch one `run` process per TPU core (8 here). `run` is defined in this
    # notebook's __main__, which freshly spawned interpreters cannot re-import,
    # so the workers are forked instead.
    xmp.spawn(run, nprocs=8, start_method='fork')