In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import random, os
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import pandas as pd
import easydict
import math
from itertools import chain

from sklearn.model_selection import train_test_split

In [None]:
# Seed every random number generator the notebook relies on so that a fresh
# "Restart Kernel -> Run All" reproduces the same samples and initial weights.
seed = 42
random.seed(seed)
# NOTE: Python reads PYTHONHASHSEED at interpreter start-up, so setting it here
# cannot change hashing in the already-running kernel; it is only inherited by
# child processes started afterwards.
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Seed every visible GPU, not only the current one (silently ignored without CUDA).
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
# cuDNN auto-tuning can select different kernels from run to run; disable it so
# the deterministic flag above is not undermined.
torch.backends.cudnn.benchmark = False

In [None]:
# ---- Run configuration -------------------------------------------------------

# Hardware and output location
device = "cpu" if not torch.cuda.is_available() else "cuda:0"
save_dir_path = "./model_saved"

# Optimisation
lr = 0.001
epochs = 1000        # upper bound on the number of training epochs
batch_size = 256

# Model size and regularisation (used as default arguments by the model classes)
h_dim = 64
num_heads = 4
num_layers = 1
dropout = 0.2

In [None]:
class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch-first input.

    Args:
        dim: Size of the last axis of the tensors passed to ``forward``.
        max_len: Number of positions to precompute.
        device: Device on which the table is initially created.

    The table is registered as a non-persistent buffer: it follows the module
    through ``.to()`` / ``.cuda()`` / ``.double()`` but is not written to
    ``state_dict``, so the set of checkpoint keys is unaffected.
    """

    def __init__(self, dim, max_len, device):
        super(PositionalEncoding, self).__init__()
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1).to(device)
        div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)).to(device)
        angles = pos * div_term  # (max_len, ceil(dim / 2))

        pe = torch.zeros(max_len, dim).to(device)
        pe[:, 0::2] = torch.sin(angles)
        # With an odd `dim` there is one cosine column fewer than sine columns,
        # so trim the angle table instead of failing on a shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : dim // 2])
        # Buffers never require gradients, so no explicit requires_grad toggle is needed.
        self.register_buffer("pe", pe, persistent=False)

    def forward(self, x):
        """Add the first ``x.size(1)`` rows of the table to ``x`` (broadcast over axis 0)."""
        return x + self.pe[: x.size(1), :]

class FFN(nn.Module):
    """Feed-forward block: Linear -> ReLU -> Linear -> Dropout.

    Both linear layers map ``dim`` features to ``dim`` features, so the output
    has the same shape as the input.

    Args:
        dim: Feature width of the input and output.
        dropout: Dropout probability applied after the second linear layer.
    """

    def __init__(self, dim=h_dim, dropout=dropout):
        super().__init__()
        # Creation order fixes the order in which seeded initial weights are drawn.
        self.linear1 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply the block to ``x`` along its last axis."""
        hidden = self.relu(self.linear1(x))
        return self.dropout(self.linear2(hidden))

class ChampRec(nn.Module):
    """Transformer-encoder recommender that scores every champion id.

    Args:
        n_ch: Size of the champion vocabulary (id 0 is treated as padding).
        seq_len: Length of the input sequences.
        dim: Embedding / model width.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability for the encoder and the feed-forward head.
    """

    def __init__(self, n_ch=164, seq_len=9, dim=h_dim, num_heads=num_heads, num_layers=num_layers, dropout=dropout):
        super().__init__()
        # Sub-modules are created in a fixed order so seeded initialisation is stable.
        # `pos_emb` is constructed but not used by `forward`.
        self.pos_emb = nn.Embedding(seq_len, dim)
        self.champ_emb = nn.Embedding(n_ch, dim)
        self.pe = PositionalEncoding(dim, seq_len, device)
        self.emb_linear1 = nn.Linear(dim, dim)

        # The feed-forward width inside each encoder layer is `dim` as well.
        self.transformer_encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(dim, num_heads, dim, dropout),
            num_layers,
        )
        self.layer_norm = nn.LayerNorm(dim)
        self.FFN = FFN(dim, dropout)
        self.final_layer = nn.Linear(dim, n_ch)

    def forward(self, champ_seq, pos):
        """Return sigmoid scores of shape ``(batch, n_ch)``.

        Args:
            champ_seq: Integer ids of shape ``(batch, seq)``; entries equal to
                0 are excluded from attention as padding.
            pos: Accepted for interface compatibility; not read.
        """
        padding = champ_seq == 0
        tokens = self.pe(self.emb_linear1(self.champ_emb(champ_seq)))

        # The encoder is sequence-first, so swap the batch and sequence axes
        # on the way in and back again on the way out.
        encoded = self.transformer_encoder(tokens.permute(1, 0, 2), src_key_padding_mask=padding)
        encoded = encoded.permute(1, 0, 2)

        # Only the final sequence position feeds the prediction head. The same
        # LayerNorm instance is applied before and after the residual FFN.
        last = self.layer_norm(encoded[:, -1])
        hidden = self.layer_norm(self.FFN(last) + last)

        logits = self.final_layer(hidden).squeeze(1)
        return torch.sigmoid(logits)


In [None]:
class SeqDataset(Dataset):
    """Map-style dataset that pairs each input sequence with its target.

    Args:
        seq_sample: Indexable collection of input sequences.
        target: Indexable collection of targets, aligned with ``seq_sample``.
        seq_len: Stored on the instance; not otherwise used by this class.
    """

    def __init__(self, seq_sample, target, seq_len=9):
        self.seq_sample, self.target, self.seq_len = seq_sample, target, seq_len

    def __len__(self):
        """Number of samples, taken from ``seq_sample``."""
        return len(self.seq_sample)

    def __getitem__(self, index):
        """Return ``(sequence_tensor, target_tensor)`` for ``index``."""
        pair = (self.seq_sample[index], self.target[index])
        return tuple(torch.tensor(item) for item in pair)


In [None]:
def _build_sample(ally_picks, enemy_picks, step):
    """Return ``(sequence, target)`` for one draft state.

    The sequence is 5 enemy slots followed by 4 ally slots, each group
    left-padded with 0. At ``step`` the first ``step + 1`` ally picks are
    visible; the enemy side shows ``step + 1`` picks, except at step 3 where
    all 5 are shown. The target is the next ally pick.
    """
    step = int(step)  # may arrive as a NumPy integer from np.random.choice
    n_enemy = 5 if step == 3 else step + 1
    enemy_part = [0] * (5 - n_enemy) + list(enemy_picks[:n_enemy])
    ally_part = [0] * (3 - step) + list(ally_picks[:step + 1])
    return enemy_part + ally_part, ally_picks[step + 1]


def preprocessing(ally, enemy, train = True, pos = None):
    """Turn per-match pick lists into fixed-length model inputs and targets.

    Args:
        ally: Object whose ``.ch`` iterates over per-match ally pick sequences.
        enemy: Object whose ``.ch`` iterates over per-match enemy pick sequences.
        train: If True, emit all four draft states of every match. If False,
            emit one randomly chosen state per match (drawn from NumPy's
            global RNG).
        pos: Optional iterable, aligned with ``ally.ch``, of per-match position
            sequences (e.g. ``pos_test.pos``). Only read when ``train`` is
            False.

    Returns:
        ``(seq_sample, target, test_pos)``. ``test_pos`` holds the position
        entry of each target and stays empty in training mode or when ``pos``
        is not supplied.
    """
    seq_sample, target, test_pos = [], [], []
    if train:
        for x, o in zip(ally.ch, enemy.ch):
            for i in range(4):
                sample, nxt = _build_sample(x, o, i)
                target.append(nxt)
                seq_sample.append(sample)
    else:
        # The evaluation branch used to read an undefined name `p`; the position
        # sequences are now passed in explicitly through `pos`.
        if pos is None:
            rows = ((x, o, None) for x, o in zip(ally.ch, enemy.ch))
        else:
            rows = zip(ally.ch, enemy.ch, pos)
        for x, o, p in rows:
            i = np.random.choice(range(4), 1)[0]
            sample, nxt = _build_sample(x, o, i)
            target.append(nxt)
            seq_sample.append(sample)
            if p is not None:
                test_pos.append(p[i + 1])

    return seq_sample, target, test_pos

In [None]:
def recall_at_k(answer, pred):
    """Count the answers that occur in their matching row of predictions.

    Args:
        answer: Iterable of elements exposing ``.item()`` (one per sample).
        pred: Iterable of per-sample candidate collections, aligned with
            ``answer``.

    Returns:
        The number of hits as a float. This is a sum, not a mean; the caller
        is responsible for dividing by the sample count.
    """
    hits = 0.0
    for truth, candidates in zip(answer, pred):
        if truth.item() in set(candidates):
            hits += 1.0
    return hits

In [None]:
def recall_test(answer, toplist):
    """Return 1.0 if ``answer`` occurs in ``toplist`` and 0.0 otherwise.

    ``toplist`` may be any iterable of hashable items; a one-shot iterator is
    consumed by the call.
    """
    matched = {answer}.intersection(toplist)
    return float(len(matched))

In [None]:
import json
def keystoint(x):
    """Build a dict from ``(key, value)`` pairs, converting every key with ``int``.

    Later pairs overwrite earlier ones that convert to the same key.
    """
    converted = {}
    for key, value in x:
        converted[int(key)] = value
    return converted

# Read the lookup table from disk; the path is relative to the notebook's working directory.
with open('../encdec_pos.json', 'r', encoding='utf-8') as f:
    json_string = f.read()
    
# object_pairs_hook is invoked for every JSON object that is decoded, so
# keystoint converts the keys of each object to int. ch_pos is later indexed by
# a predicted item id and iterated to obtain position indices.
ch_pos = json.loads(json_string, object_pairs_hook=keystoint)

In [None]:
# Load the pre-split datasets (paths are relative to the working directory).
# The `.ch` column of train/valid/test is zipped with the `.ch` column of the
# matching op_* frame when samples are built, so each pair must be row-aligned.
# NOTE(review): presumably `train` holds the ally picks and `op_*` the opponent
# picks of the same match — confirm against the data export.
train = pd.read_parquet("transformer_train.parquet")
op_train= pd.read_parquet("transformer_op_train.parquet")
valid = pd.read_parquet("transformer_valid.parquet")
op_valid = pd.read_parquet("transformer_op_valid.parquet")
test = pd.read_parquet("transformer_test.parquet")
op_test = pd.read_parquet("transformer_op_test.parquet")
# Test-only side tables: `.pos` is read per row while building test samples and
# `.ch` of ban_test is used to mask predictions during evaluation.
pos_test = pd.read_parquet("transformer_pos_test.parquet")
ban_test = pd.read_parquet("transformer_ban_test.parquet")

In [None]:
def _draft_sample(ally_picks, enemy_picks, step):
    """Return ``(sequence, target)`` for one draft state.

    The sequence is 5 enemy slots followed by 4 ally slots, each group
    left-padded with 0. At ``step`` the first ``step + 1`` ally picks are
    visible; the enemy side shows ``step + 1`` picks, except at step 3 where
    all 5 are shown. The target is the next ally pick.
    """
    step = int(step)  # np.random.choice yields a NumPy integer
    n_enemy = 5 if step == 3 else step + 1
    enemy_part = [0] * (5 - n_enemy) + list(enemy_picks[:n_enemy])
    ally_part = [0] * (3 - step) + list(ally_picks[:step + 1])
    return enemy_part + ally_part, ally_picks[step + 1]


# Training: every match contributes all four draft states.
X_train, y_train = [], []
for x, o in zip(train.ch, op_train.ch):
    for i in range(4):
        seq, nxt = _draft_sample(x, o, i)
        y_train.append(nxt)
        X_train.append(seq)

# Validation: one randomly chosen draft state per match. The random draws use
# NumPy's global RNG, validation first and then test.
X_valid, y_valid = [], []
for x, o in zip(valid.ch, op_valid.ch):
    i = np.random.choice(range(4), 1)[0]
    seq, nxt = _draft_sample(x, o, i)
    y_valid.append(nxt)
    X_valid.append(seq)

# Test: as validation, and also record the position entry of the target pick.
X_test, y_test, test_pos = [], [], []
for x, o, p in zip(test.ch, op_test.ch, pos_test.pos):
    i = np.random.choice(range(4), 1)[0]
    seq, nxt = _draft_sample(x, o, i)
    y_test.append(nxt)
    X_test.append(seq)
    test_pos.append(p[i+1])

In [None]:
# Shuffle the training set only: the samples are appended match by match, so
# without shuffling every epoch replays identical batches made of consecutive
# samples from the same matches. Shuffling draws from the seeded torch RNG.
train_loader = DataLoader(SeqDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
# Evaluation loaders keep their order so results line up with the source rows.
valid_loader = DataLoader(SeqDataset(X_valid, y_valid), batch_size=batch_size)
# The whole test set is served as a single batch here.
test_loader = DataLoader(SeqDataset(X_test, y_test), batch_size=len(X_test))

In [None]:
# BCELoss takes probabilities; ChampRec.forward ends with a sigmoid, and the
# training loop compares its output against a one-hot target row.
loss_func = nn.BCELoss()
# Built with ChampRec's default arguments, then moved to the configured device.
model = ChampRec().to(device)
optimizer = Adam(model.parameters(), lr=lr)

In [None]:
train_losses = []
best_recall = 0
best_state = None   # copy of the weights that produced best_recall
cnt = 0             # epochs since validation recall last improved
# ChampRec.forward accepts `pos` but never reads it, so one tensor serves every batch.
pos = torch.arange(9).unsqueeze(0).to(device)

for epoch in range(epochs):
    # ---- optimisation pass over the training set ----
    model.train()
    train_loss = []
    for x, target in train_loader:
        x = x.to(device).long()
        target = target.to(device).long()

        optimizer.zero_grad()

        preds = model(x, pos)
        # One-hot target sized from the prediction itself instead of a literal
        # class count, so it always matches the model's output width.
        label = torch.zeros_like(preds)
        label.scatter_(1, target[:, None], 1)
        loss = loss_func(preds, label)

        loss.backward()
        optimizer.step()

        train_loss.append(loss.item())
    train_loss = np.mean(train_loss)

    # ---- validation: fraction of targets found in the top-3 predictions ----
    model.eval()
    recall = 0.0
    with torch.no_grad():
        for x, target in valid_loader:
            x = x.to(device).long()
            target = target.to(device).long()

            preds = model(x, pos)
            _, items = torch.topk(preds, 3)
            items = items.to("cpu").detach().numpy()
            recall += recall_at_k(target, items)

        recall /= len(valid_loader.dataset)

        if recall > best_recall:
            best_recall = recall
            cnt = 0
            # Keep an in-memory copy so the best weights can be restored below.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            os.makedirs(save_dir_path, exist_ok=True)
            # NOTE(review): the checkpoint is written to the working directory,
            # not into save_dir_path; the path is left as-is for existing consumers.
            model_scripted = torch.jit.script(model) #Export to TorchScript
            model_scripted.save('model_scripted_fixed_length.pt') # Save
        else:
            cnt+=1

    train_losses.append(train_loss)
    print(f"epoch : {epoch} Train Loss : {train_loss} recall : {recall}")
    # Stop once validation recall has not improved for 50 consecutive epochs.
    if cnt >= 50:
        print('Early Stopping!')
        print(best_recall)
        break

# Without this, `model` would keep the weights of the final epoch, which can be
# up to 50 epochs past the best one; later cells should evaluate the best model.
if best_state is not None:
    model.load_state_dict(best_state)

In [None]:
# Top-3 recall on the test set, evaluated one sample at a time so that each
# sample can be masked with its own bans.
model.eval()
top3recall = 0.0
test_loader = DataLoader(SeqDataset(X_test, y_test), batch_size=1)

with torch.no_grad():
    for (x, target), ban in zip(test_loader, ban_test.ch.values):
        x, target = x.to(device).long(), target.to(device).long()
        ban = torch.from_numpy(ban).to(device).unsqueeze(0).long()
        pos = torch.arange(9).unsqueeze(0).to(device)

        preds = model(x, pos)

        # Flag every id that already appears in the input sequence (padding id 0
        # included) or in the ban list, and push its score to -inf.
        mask = torch.zeros(1, 164).to(device).long()
        mask.scatter_(1, x, 1).scatter_(1, ban, 1)
        preds.masked_fill_(mask == 1, -np.inf)

        # Full ranking of all ids, best first, as a NumPy vector for this sample.
        items = torch.argsort(preds, descending=True, dim=1).to("cpu").detach().numpy()[0]

        top3recall += recall_test(target.item(), items[:3])

    print(top3recall / len(test_loader.dataset))

In [None]:
# Position-aware evaluation: besides plain top-3 recall, build up to three
# suggestions for each of the five position slots and measure recall on the
# target's own slot and on the combined (up to 15) suggestions.
model.eval()
top3recall = 0.0
posrecall = 0.0
top15recall = 0.0
# Rebuilt here so the cell does not depend on an earlier cell having replaced
# the full-batch test loader with a one-sample-per-batch loader.
test_loader = DataLoader(SeqDataset(X_test, y_test), batch_size=1)
# The model does not read `pos`, so a single tensor is reused for every sample.
pos = torch.arange(9).unsqueeze(0).to(device)

with torch.no_grad():
    for (x, target), ban, rp in zip(test_loader, ban_test.ch.values, test_pos):
        x = x.to(device).long()
        target = target.to(device).long()
        ban = torch.from_numpy(ban).to(device).unsqueeze(0).long()

        preds = model(x, pos)

        # Exclude ids already present in the input (padding id 0 included) and banned ids.
        mask = torch.zeros(1, 164).to(device).long()
        mask.scatter_(1, x, 1)
        mask.scatter_(1, ban, 1)
        preds[mask == 1] = -np.inf

        items = torch.argsort(preds, descending=True, dim = 1)
        items = items.to("cpu").detach().numpy()[0]

        # One bucket per position index 0-4, each capped at three suggestions.
        posrec = [[], [], [], [], []]

        total_len = 0
        # Distinct loop names: the previous `x` / `pos` loop variables overwrote
        # the batch tensors of the same names.
        for champ in items:
            if total_len == 15:
                break
            for slot in ch_pos[champ]:
                if len(posrec[slot]) < 3:
                    posrec[slot].append(champ)
                    total_len+=1

        top3recall += recall_test(target.item(), items[:3])
        # `rp` is the position entry recorded for this sample's target pick.
        posrecall += recall_test(target.item(), posrec[rp])
        top15recall += recall_test(target.item(), chain.from_iterable(posrec))

    print(top3recall / len(test_loader.dataset))
    print(posrecall / len(test_loader.dataset))
    print(top15recall / len(test_loader.dataset))