In [96]:
import os
import torch
from torch import nn
import pandas as pd
from tqdm.auto import tqdm
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader,WeightedRandomSampler
from torch.optim import AdamW,Adam
import torch.nn.functional as F
import re

In [8]:
from transformers import AutoModel,AutoTokenizer,BertModel
from kobert_tokenizer import KoBERTTokenizer

In [10]:
# Hugging Face model id of the pretrained encoder that BertClassifier loads.
model_path = 'skt/kobert-base-v1' # NOTE(review): previously tagged "multilingual"; that looks stale for a KoBERT id — confirm

In [83]:
class BertClassifier(nn.Module):
    """Binary text classifier: a pretrained encoder followed by a small head.

    The head is dropout -> Linear(hidden, 1024) -> dropout -> Linear(1024, 1)
    -> sigmoid and is applied to the hidden state of the first token of each
    sequence. No activation is inserted between the two linear layers; adding
    one would change what previously saved weights compute.

    Args:
        dropout: Drop probability used by both dropout layers.
    """

    def __init__(self, dropout=0.5):
        super().__init__()

        # `model_path` is the module-level checkpoint id defined in an earlier cell.
        self.bert = AutoModel.from_pretrained(model_path)
        # Read the encoder width from its config rather than hard-coding it;
        # 768 stays as the fallback for configs that do not expose hidden_size.
        hidden_size = getattr(self.bert.config, 'hidden_size', 768)
        self.dropout1 = nn.Dropout(dropout)
        self.linear1 = nn.Linear(hidden_size, 1024)
        self.dropout2 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(1024, 1)

    def forward(self, input_id, mask):
        """Return one sigmoid score per sequence.

        Args:
            input_id: Token ids, passed to the encoder as ``input_ids``.
            mask: Attention mask, passed to the encoder as ``attention_mask``.

        Returns:
            Tensor with one row per sequence and a single column holding the
            sigmoid output for that sequence.
        """
        last_hidden_state, _ = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        # Only the first token feeds the result, so slice it out *before* the
        # head instead of running the head over every position and discarding
        # all but one of the outputs.
        first_token = last_hidden_state[:, 0, :]
        hidden = self.linear1(self.dropout1(first_token))
        logit = self.linear2(self.dropout2(hidden))
        return torch.sigmoid(logit)

In [84]:
# Tokenizer and encoder must come from the same checkpoint, so reuse `model_path`
# instead of repeating the model id.
tokenizer = KoBERTTokenizer.from_pretrained(model_path)
classifier = BertClassifier()
# The classifier is only used for inference in this notebook. nn.Module starts in
# training mode, which would leave dropout active and make predictions random.
classifier.eval()
# map_location='cpu' lets a checkpoint that was saved from a GPU load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file — only load checkpoints you trust.
classifier.load_state_dict(torch.load('/home/se/paper/classifier/kobert_base_model/model.pt', map_location='cpu'))

<All keys matched successfully>

In [85]:
def preprocess(context,null_char='',space_char=' '):
    """Strip bracketed asides and punctuation from a text before tokenization.

    Steps, in order:
      1. every ``(...)`` span is replaced by ``null_char``;
      2. every ``[...]`` span is replaced by ``null_char``;
      3. every character that is neither a word character nor whitespace is
         replaced by ``space_char``;
      4. runs of two or more newlines are collapsed into a single newline.

    Letter case is left unchanged, and runs of spaces are not collapsed.

    Args:
        context: Text to clean.
        null_char: Replacement for removed bracketed spans.
        space_char: Replacement for punctuation characters.

    Returns:
        The cleaned string.
    """
    without_parens = re.sub(r'\([^)]*\)', null_char, context)
    without_brackets = re.sub(r'\[[^]]*\]', null_char, without_parens)
    without_punctuation = re.sub(r'[^\w\s]', space_char, without_brackets)
    # The previous version called `.upper()` at this point and discarded the
    # result (str is immutable), so case was never changed. The dead call is
    # removed rather than "fixed".
    # NOTE(review): if upper-casing was intended, it must match the
    # preprocessing the classifier was trained with — confirm before changing.
    return re.sub(r'[\n]{2,}', '\n', without_punctuation)

In [86]:
def predict(contents):
    """Classify one text with the module-level `tokenizer` and `classifier`.

    The text is cleaned with `preprocess`, tokenized (padded/truncated to 512
    tokens) and passed through the classifier.

    Args:
        contents: Raw text of a single article.

    Returns:
        The classifier's sigmoid output rounded with the built-in `round`,
        i.e. 0 or 1.
    """
    model_input = tokenizer(preprocess(contents),padding='max_length', max_length = 512, truncation=True,return_tensors="pt")
    mask = model_input['attention_mask']
    input_ids = model_input['input_ids']

    # Inference must run in eval mode: in training mode dropout is active and
    # the same text can get different labels on different calls. The previous
    # mode is restored afterwards so that calling predict() has no lasting
    # effect on the model.
    was_training = classifier.training
    classifier.eval()
    try:
        # No gradients are needed for prediction; skip building the autograd graph.
        with torch.no_grad():
            label = classifier(input_ids,mask)
    finally:
        classifier.train(was_training)
    return round(label.item())

In [87]:
# Smoke test of predict() on a single article.
# NOTE(review): `target` is not defined in any visible cell — this relies on leftover kernel state.
predict(target.contents)

1

In [3]:
# Directory containing the CSV inputs (absolute, machine-specific path).
data_path = '/home/se/paper/legacy/data'

In [7]:
# Preview of the match table.
# NOTE(review): in the visible notebook `match` is only loaded in a later cell, so this
# cell depends on execution order.
match.head()

Unnamed: 0,date,state_game,home,home_score,home_result,home_time,away,away_score,away_result,away_time,stadium,home_pitcher,away_pitcher
0,20200505,종료,KIA,2,패,18:30,키움,11,승,18:30,광주,양현종,브리검
1,20200505,종료,LG,8,승,18:30,두산,2,패,18:30,잠실,차우찬,알칸타라
2,20200505,종료,삼성,0,패,18:30,NC,4,승,18:30,대구,백정현,루친스키
3,20200505,종료,KT,2,패,18:30,롯데,7,승,18:30,수원,데스파이네,스트레일리
4,20200505,종료,SSG,0,패,18:30,한화,3,승,18:30,인천,킹엄,서폴드


In [6]:
# Preview of the article table.
# NOTE(review): in the visible notebook `article` is only loaded in a later cell, so this
# cell depends on execution order.
article.head()

Unnamed: 0,team,date,time,image,contents
0,키움,20160101,오전 09:16,https://imgnews.pstatic.net/image/421/2016/01/...,"'단일 경기사용구·비디오판독·돔구장' 새해, 달라지는 것들 2016시즌 처음으로 ..."
1,키움,20160101,오전 05:51,https://imgnews.pstatic.net/image/109/2016/01/...,넥센 히어로즈의 2016시즌은 유망주들에게 달려 있다. 넥센은 2016년 새해를 ...
2,키움,20160103,오후 03:09,https://imgnews.pstatic.net/image/396/2016/01/...,"염경엽·이택근 ‘원숭이띠 리더’, 넥센 이끄는 힘 〔정정욱 기자〕 2016년 ‘원숭..."
3,키움,20160103,오후 12:26,https://imgnews.pstatic.net/image/477/2016/01/...,창단 후 8년을 살던 목동에서 고척동으로 이사 가며 뚜껑을 '득템'했다. 그런데 ...
4,키움,20160103,오전 06:04,https://imgnews.pstatic.net/image/109/2016/01/...,넥센 히어로즈 외야수 허정협이 팀의 장타력 약화를 막을 수 있을까. 넥센은 최근 ...


In [88]:
from sklearn.preprocessing import LabelEncoder
class StatDataset(Dataset):
    """Per-match feature vectors built from pitcher stats and article predictions.

    A match row is kept only when both its home and away pitcher have a row in
    `stat` (matched on name and team) and both teams have at least one row in
    `article` for the match date. For a kept match the feature vector is

        home pitcher stats + away pitcher stats
        + [mean prediction over home articles, mean prediction over away articles]

    and the target is the encoded `home_result`.

    Args:
        match: DataFrame with columns ``home_pitcher``, ``home``,
            ``away_pitcher``, ``away``, ``home_result`` and ``date``.
        stat: DataFrame with columns ``name`` and ``team``; the values from the
            third column onward of the first matching row are used as features
            (assumes the first two columns are identifiers — TODO confirm
            against the stats file).
        article: DataFrame with columns ``team``, ``date`` and ``contents``.
        le: Fitted encoder whose ``transform`` maps result strings to integers.
        predict_fn: Optional callable mapping one article text to a number.
            Defaults to the module-level ``predict``.

    Attributes:
        stat: Float tensor with one feature row per kept match.
        label: Long tensor with one encoded result per kept match.
    """

    def __init__(self, match,stat,article,le,predict_fn=None):
        if predict_fn is None:
            # Resolved at call time so the notebook-level predict() is the default.
            predict_fn = predict

        features = []
        results = []

        columns = ('home_pitcher', 'home', 'away_pitcher', 'away', 'home_result', 'date')
        rows = zip(*(match[column].tolist() for column in columns))
        for home_pitcher, home_team, away_pitcher, away_team, home_result, date in rows:
            home_player_stat = stat.loc[(stat['name'] == home_pitcher) & (stat['team'] == home_team)]
            away_player_stat = stat.loc[(stat['name'] == away_pitcher) & (stat['team'] == away_team)]
            if len(home_player_stat) == 0 or len(away_player_stat) == 0:
                continue

            home_article = article.loc[(article['team'] == home_team) & (article['date'] == date)]
            away_article = article.loc[(article['team'] == away_team) & (article['date'] == date)]
            # Check both sides before scoring so no model call is spent on a
            # match that is going to be dropped anyway.
            if len(home_article) == 0 or len(away_article) == 0:
                continue

            # First matching stats row of each pitcher, without its first two columns.
            vector = home_player_stat.values.tolist()[0][2:] + away_player_stat.values.tolist()[0][2:]
            vector.append(self._mean_prediction(home_article, predict_fn))
            vector.append(self._mean_prediction(away_article, predict_fn))

            features.append(vector)
            results.append(home_result)

        self.label = torch.tensor(list(le.transform(results)),dtype=torch.long)
        # Explicit float dtype so the feature tensor does not depend on type inference.
        self.stat = torch.tensor(features,dtype=torch.float32)

    @staticmethod
    def _mean_prediction(articles, predict_fn):
        """Average `predict_fn` over the ``contents`` column of a non-empty frame."""
        contents = articles['contents'].tolist()
        return sum(predict_fn(text) for text in contents) / len(contents)

    def __len__(self):
        """Number of matches kept in the dataset."""
        return len(self.stat)

    def __getitem__(self, index):
        """Return the feature row and encoded label at `index` as a dict."""
        return {
                "stat": self.stat[index],
                "label": self.label[index]
                }

    def get_num_batches(self, batch_size):
        """Number of full batches of `batch_size` (a trailing partial batch is not counted)."""
        return len(self) // batch_size

In [90]:
# Read the three input tables from `data_path`, in this order:
# pitcher stats, match results, news articles.
stat, match, article = (
    pd.read_csv(os.path.join(data_path, file_name))
    for file_name in ('average.csv', 'match_day.csv', 'article.csv')
)

In [91]:
# fit() returns the encoder itself, so construction and fitting chain into one assignment.
# LabelEncoder orders classes by sorted value: '무승부' -> 0, '승' -> 1, '패' -> 2.
le = LabelEncoder().fit(['승','패','무승부'])

In [92]:
# Shuffle the matches and split them into training and validation frames. No test_size
# is passed, so scikit-learn's default ratio applies; random_state pins the shuffle.
train_match,valid_match = train_test_split(match,shuffle=True,random_state=56)

In [94]:
# Build the training dataset. StatDataset calls predict() once per matching article,
# so this cell can take a while.
train_dataset = StatDataset(train_match,stat,article,le)

In [93]:
# Build the validation dataset the same way as the training one.
valid_dataset = StatDataset(valid_match,stat,article,le)

In [95]:
# Matches kept in each split; StatDataset drops rows lacking pitcher stats or same-day articles.
print(len(train_dataset),len(valid_dataset))

327 100
