In [None]:
from google.colab import drive

# Mount Google Drive so the Drive-hosted competition inputs read below are reachable.
drive.mount(mountpoint='/content/drive')

In [None]:
!pip install transformers
!pip install sentencepiece
!pip install pytorch_lightning

In [None]:
import json
import torch
import torch.nn as nn
import os, glob, re, gc
import pandas as pd
import numpy as np
from tqdm import tqdm
from pathlib import Path
import matplotlib.pyplot as plt
from itertools import permutations 
from collections import defaultdict
from transformers import (T5ForConditionalGeneration,
                          AdamW,
                          T5TokenizerFast as token)

from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

import pytorch_lightning as pl
# Seed the global RNGs via Lightning; generation later uses sampling (do_sample=True).
pl.seed_everything(seed=13)
print(torch.__version__)

Global seed set to 13


1.8.1+cu101


In [None]:
def clean_text(txt):
    """Lower-case ``txt`` and collapse each run of non-alphanumeric characters to one space.

    Non-string inputs are converted with ``str`` first; the result is stripped
    of leading/trailing spaces. Only ASCII letters and digits survive.
    """
    lowered = str(txt).lower()
    spaced = re.sub('[^A-Za-z0-9]+', ' ', lowered)
    return spaced.strip()

def clean_text_dig(txt):
    """Like ``clean_text`` but digits are removed too: only ASCII letters remain, single-space separated."""
    collapsed = re.sub('[^A-Za-z]+', ' ', str(txt).lower())
    return collapsed.strip()

def find_start(title, text):
    r"""Return the character offset of the first whole-word occurrence of ``title`` in ``text``.

    ``title`` is matched literally (regex metacharacters are escaped) and must be
    bounded by ``\b`` word boundaries on both sides.

    Returns 0 when there is no such occurrence, so callers fall back to the
    beginning of ``text``.
    """
    match = re.search(r'\b%s\b' % re.escape(title), text)
    return 0 if match is None else match.start()

def jaccard(str1, str2):
    """Word-level Jaccard similarity of two strings (case-insensitive, whitespace-tokenised).

    Returns len(intersection) / len(union) over the sets of lower-cased tokens.
    When neither string contains any token the union is empty and 0.0 is
    returned instead of raising ZeroDivisionError.
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    union = a | b
    if not union:
        return 0.0
    return len(a & b) / len(union)

def totally_clean_text(txt):
    """Normalise ``txt`` exactly as ``clean_text`` does, then squeeze repeated spaces.

    The squeeze is only a safeguard: the alphanumeric substitution already
    leaves single spaces between tokens.
    """
    normalised = re.sub('[^A-Za-z0-9]+', ' ', str(txt).lower()).strip()
    return re.sub(' +', ' ', normalised)

def make_interval(start: int, txt: str, interval: int) -> str:
    """Return a window of up to ``2 * interval`` words of ``txt`` around a character offset.

    start: character offset into ``txt`` (e.g. a ``re.search(...).start()``
        value); it is turned into a word index by counting the words before it.
    txt: text to cut the window from (split on whitespace).
    interval: number of words kept before the anchor word; the window is
        ``2 * interval`` words long.

    When the anchor word index is at most ``interval`` the window starts at the
    first word. Near the end of the text the window is truncated, not shifted.

    return: the window's words joined by single spaces.
    """
    tokens = txt.split()
    anchor = len(txt[:start].split())
    first = anchor - interval if anchor > interval else 0
    return ' '.join(tokens[first:first + interval * 2])

def count_answer(df: pd.DataFrame)-> dict:
    """Pick a section title that reappears in a paper's text and cut a text window around it.

    df: one paper, one row per section, with ``section_title`` and ``text``
        columns (as built from the paper JSON). It is not modified.

    Only titles with at least 4 words are considered. Each is normalised with
    ``clean_text_dig``; titles that become empty (e.g. all-digit titles) are
    skipped. Among titles that occur as substrings of the cleaned full text,
    the last one wins.

    return: ``{'section_title': <cleaned title>, 'text': <window from
        make_interval(..., 396) around the offset returned by find_start>}``,
        or ``{}`` if no title matched or ``df`` lacks the needed columns.
    """
    if df.empty or not {'section_title', 'text'}.issubset(df.columns):
        return {}
    n_words = df['section_title'].map(lambda title: len(str(title).split()))
    candidates = df.loc[n_words >= 4, 'section_title']
    # Join sections with a space so the last word of one section does not fuse
    # with the first word of the next.
    full_text = clean_text(df['text'].str.cat(sep=' '))
    chosen = None
    for raw_title in candidates:
        title = clean_text_dig(raw_title)
        # An empty title trivially "occurs" in any string; skip it so it cannot
        # overwrite a real match.
        if title and title in full_text:
            chosen = title
    if chosen is None:
        return {}
    start = find_start(chosen, full_text)
    return {'section_title': chosen, 'text': make_interval(start, full_text, 396)}

In [None]:
# Input locations on the mounted Google Drive.
PATH_ORI_TRAIN = '/content/drive/MyDrive/Coleridge_Initiative/input/train.csv'
PATH_ORI_JSON = '/content/drive/MyDrive/Coleridge_Initiative/input/train'
PATH_ORI_JSON_TEST = '/content/drive/MyDrive/Coleridge_Initiative/input/test'
PATH_SUB = '/content/drive/MyDrive/Coleridge_Initiative/input/sample_submission.csv'

train = pd.read_csv(PATH_ORI_TRAIN)
sub = pd.read_csv(PATH_SUB)
# Later cells iterate over `sample_submission['Id']` before the cell that reads
# it from the Kaggle input path; define it here so the notebook runs top to
# bottom on a fresh kernel.
sample_submission = sub.copy()

In [None]:
# For each training paper, keep [count_answer(...)] when a section title was
# found in the text, keyed by that paper's own id.
papers = {}
for json_id in tqdm(train['Id'].unique()):
    with Path(PATH_ORI_JSON, json_id + '.json').open('r') as jsn:
        cur_jsn = json.load(jsn)
    dct = count_answer(pd.DataFrame(cur_jsn))
    if len(dct) > 0:
        papers[json_id] = [dct]

len(papers)

100%|██████████| 14316/14316 [1:04:28<00:00,  3.70it/s]


1

In [None]:
# Add the submission (test) papers. Unlike the training loop, empty results are
# stored too, so every submission id has an entry for the inference loop.
for paper_id in sub['Id']:
    with Path(PATH_ORI_JSON_TEST, paper_id + '.json').open('r') as jsn:
        paper = json.load(jsn)
    papers[paper_id] = [count_answer(pd.DataFrame(paper))]

len(papers)

5391

In [None]:
# def count_answer(df: pd.DataFrame):
#     df['len_title'] = df.section_title.apply(lambda x: len(x.split()))
#     #need more clear title skip error
#     qwest = df[df.len_title >= 4].section_title.values
#     txt = df.text.str.cat()
#     txt = clean_text(txt)
#     count = 0
#     for qw in qwest:
#         qw = clean_text_dig(qw)    
#         if qw in txt:
#             count += 1
#     return count

# tmp = {}
# for k in tqdm(papers.keys()):
#     df = pd.DataFrame(papers[str(k)])
#     count = count_answer(df)
#     if count >= 0:
#         tmp[str(k)] = count  

# count_zero = 0
# count_not = 0
# for k,v in tmp.items():
#     if v == 0:
#         count_zero += 1
#     else:
#         count_not += 1
# print(count_zero, count_not)

In [None]:
# Every known dataset name, in all three label columns, lower-cased.
label_columns = train[['dataset_title', 'dataset_label', 'cleaned_label']]
all_labels = {
    str(value).lower()
    for row in label_columns.itertuples(index=False)
    for value in row
}

print(f'No. different labels: {len(all_labels)}')

No. different labels: 180


In [None]:
# Raw parsed JSON for every training and submission paper, keyed by paper id.
# NOTE(review): in the cells shown, only the submission entries are read
# (literal matching below); the training loop may be avoidable.
papers2 = {}
for paper_id in train['Id'].unique():
    with Path(PATH_ORI_JSON, paper_id + '.json').open('r') as jsn:
        papers2[paper_id] = json.load(jsn)


for paper_id in sample_submission['Id']:
    with Path(PATH_ORI_JSON_TEST, paper_id + '.json').open('r') as jsn:
        papers2[paper_id] = json.load(jsn)

len(papers2)

14316

In [None]:
# Literal-match baseline: a known label is predicted for a submission paper if it
# occurs as a substring of the lower-cased text, raw or fully cleaned.
literal_preds = []

for paper_id in sample_submission['Id']:
    paper = papers2[paper_id]
    text_1 = '. '.join(section['text'] for section in paper).lower()
    text_2 = totally_clean_text(text_1)

    labels = {
        clean_text(label)
        for label in all_labels
        if label in text_1 or label in text_2
    }
    # Sort so the '|'-joined string is reproducible; set iteration order of
    # strings can change between interpreter runs.
    literal_preds.append('|'.join(sorted(labels)))

literal_preds

In [None]:
# Run hyper-parameters. Some later cells hard-code values instead of reading
# them from here (e.g. the 396-token input length in make_pred).
BATCH = 3
EPOCHS = 1

config = dict(
    learning_rate=0.0001,
    architecture="T5",
    model='t5-small',  # alternative: 't5-base'
    dataset="Coleridge Initiative ",
    tex_max_len=396,
    asw_max_len=12,
    batch_size=BATCH,
    epoch=EPOCHS,
    device='cuda',
)

In [None]:
class CI(nn.Module):
    """Thin ``nn.Module`` wrapper around a pretrained/fine-tuned T5 seq2seq model.

    The Hugging Face model is exposed as ``self.model`` so callers can use its
    ``generate`` method directly.
    """

    def __init__(self, config, model_path='../input/ci-model-small/model'):
        """Load the T5 weights.

        config: hyper-parameter dict; stored on the instance, not otherwise read here.
        model_path: directory passed to ``T5ForConditionalGeneration.from_pretrained``;
            defaults to the previously hard-coded Kaggle dataset path.
        """
        super().__init__()
        self.config = config
        self.model = T5ForConditionalGeneration.from_pretrained(model_path, return_dict=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """Run the wrapped model and return ``(loss, logits)``.

        ``labels`` may be omitted for inference; the first element is then
        whatever the wrapped model reports as ``loss`` (presumably ``None``).
        """
        out = self.model(input_ids=input_ids,
                         attention_mask=attention_mask,
                         labels=labels)
        return out.loss, out.logits


def make_pred(question: dict, pre_model, tokenizer,
              max_length: int = 396, gen_max_length: int = 8) -> str:
    """Generate an answer for one question/context pair.

    question: mapping with a ``'question'`` entry and a ``'text'`` entry; the
        two are encoded together as a sentence pair.
    pre_model: object whose ``.model`` provides a Hugging Face style
        ``generate`` method (e.g. a ``CI`` instance).
    tokenizer: tokenizer used for both encoding and decoding.
    max_length: ``max_length`` for the input encoding (pad target).
    gen_max_length: ``max_length`` passed to ``generate``.

    return: the decoded generated sequences joined by a single space.
    """
    encode_test = tokenizer(question['question'],
                            question['text'],
                            max_length=max_length,
                            padding='max_length',
                            # NOTE(review): truncation is disabled, so inputs
                            # longer than max_length are presumably passed whole.
                            truncation=False,
                            return_attention_mask=True,
                            add_special_tokens=True,
                            return_tensors='pt')
    gen_ids = pre_model.model.generate(
        input_ids=encode_test['input_ids'],
        attention_mask=encode_test['attention_mask'],
        # Beam search combined with sampling (do_sample=True): results depend
        # on the RNG state.
        num_beams=5,
        no_repeat_ngram_size=1,
        num_return_sequences=1,
        do_sample=True,
        top_k=0,
        top_p=0.92,
        max_length=gen_max_length,
        repetition_penalty=2.5,
        length_penalty=0.5,
        early_stopping=True,
        use_cache=True,
    )
    decoded = [
        tokenizer.decode(ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        for ids in gen_ids
    ]
    return ' '.join(decoded)
    

def make_interval(start: int, txt: str, interval: int) -> str:
    """Word window around a character offset (redefines the identically-behaving earlier make_interval).

    start: character offset into ``txt``; the number of words before it gives
        the anchor word index.
    txt: text to cut the window from (split on whitespace).
    interval: words kept before the anchor; the window holds up to
        ``2 * interval`` words.

    If the anchor index does not exceed ``interval`` the window starts at word 0;
    at the end of the text the window is truncated rather than shifted.

    return: the selected words joined by single spaces.
    """
    all_words = txt.split()
    word_index = len(txt[:start].split())
    if word_index > interval:
        begin = word_index - interval
    else:
        begin = 0
    window = all_words[begin:begin + 2 * interval]
    return ' '.join(window)

In [None]:
"""
{'d0fa7568-7d8e-4db9-870f-f9c6f668c17b': [
        {'section_title': 'What is this study about?',
        'text': 'This study used data from the National Education Longitudinal Study (NELS:88)...
        }],
},
"""
def makedata(title, text, model, tokenizer, skip_title = True):
    """Predict one cleaned answer string for a (section title, text window) pair.

    The pair is handed to ``make_pred`` as ``{'question': title, 'text': text}``
    and its output is normalised with ``clean_text``. ``skip_title`` is
    accepted but unused.
    """
    prediction = make_pred({'question': title, 'text': text}, model, tokenizer)
    return clean_text(prediction)


# Load the fine-tuned checkpoint and predict cleaned answers for each submission paper.
model = CI(config)
MODEL = config['model']  # NOTE(review): not used below; CI loads from a fixed path.
# The model is never moved to a device (inputs stay on CPU as well), so load the
# weights onto CPU; this also works when the checkpoint was saved from a GPU
# and no GPU is available.
model.load_state_dict(torch.load('../input/ci-model-small/model_check_predict_small.pth',
                                 map_location='cpu'))
model.eval()
tokenizer = token.from_pretrained('../input/ci-model-small/token')

labels = []
for paper_id in sample_submission['Id']:
    label = []
    for dct in papers[paper_id]:
        # Empty dicts (no title found) yield '' here and are skipped.
        title = dct.get('section_title', '')
        txt = dct.get('text', '')
        if title != '' and txt != '':
            y_ = makedata(title, txt, model, tokenizer)
            if y_ != '':
                label.append(y_)
    labels.append(np.unique(label))
    gc.collect()

In [None]:
# Model predictions: np.unique of the cleaned answers for each submission paper.
labels

[array([], dtype=float64),
 array(['isced97 level'], dtype='<U13'),
 array(['and ocraco'], dtype='<U10'),
 array([], dtype=float64)]

In [None]:
def jaccard_similarity(s1, s2):
    """Jaccard similarity of the sets of space-separated tokens of ``s1`` and ``s2``.

    Tokens come from ``str.split(" ")``, so an empty string yields the single
    token ``''``, the union is never empty, and two empty strings score 1.0.
    Repeated words count once (set semantics), e.g. "data data" vs "data" is 1.0.
    """
    tokens1 = set(s1.split(" "))
    tokens2 = set(s2.split(" "))
    return len(tokens1 & tokens2) / len(tokens1 | tokens2)

filtered_bert_labels = []

for lab in labels:
    kept = []
    # Shortest first: a longer prediction is dropped when it is too similar
    # (Jaccard >= 0.75) to one already kept.
    for candidate in sorted(lab, key=len):
        candidate = clean_text(candidate)
        if all(jaccard_similarity(candidate, prev) < 0.75 for prev in kept):
            kept.append(candidate)
    filtered_bert_labels.append('|'.join(kept))

In [None]:
# Filtered model predictions, one '|'-joined string per submission paper.
filtered_bert_labels

['', 'isced97 level', 'and ocraco', '']

In [None]:
# Prefer the literal string-match prediction; use the model's filtered
# prediction only when literal matching produced an empty string.
final_predictions = [
    literal_match or bert_pred
    for literal_match, bert_pred in zip(literal_preds, filtered_bert_labels)
]

In [None]:
# Final prediction strings, in submission-Id order.
final_predictions

['alzheimer s disease neuroimaging initiative adni|adni',
 'common core of data|nces common core of data|trends in international mathematics and science study',
 'sea lake and overland surges from hurricanes|slosh model|noaa storm surge inundation',
 'rural urban continuum codes']

In [None]:
# Kaggle-side copy of the sample submission (an earlier cell reads one from Drive).
# NOTE(review): predictions are assigned positionally below, so this file is
# assumed to list the same Ids in the same order — confirm.
sample_submission_path = '../input/coleridgeinitiative-show-us-the-data/sample_submission.csv'
sample_submission = pd.read_csv(filepath_or_buffer=sample_submission_path)

In [None]:
# Positional assignment: final_predictions must have one entry per row, in the
# same Id order as sample_submission.
sample_submission['PredictionString'] = final_predictions

In [None]:
# Write the submission file (without the index) and display the frame.
sample_submission.to_csv('submission.csv', index=False)
sample_submission

Unnamed: 0,Id,PredictionString
0,2100032a-7c33-4bff-97ef-690822c43466,alzheimer s disease neuroimaging initiative ad...
1,2f392438-e215-4169-bebf-21ac4ff253e1,common core of data|nces common core of data|t...
2,3f316b38-1a24-45a9-8d8c-4e05a42257c6,sea lake and overland surges from hurricanes|s...
3,8e6996b4-ca08-4c0b-bed2-aaf07a4c6a60,rural urban continuum codes
