<a href="https://colab.research.google.com/github/cronus6w6/AI-CUP-2020/blob/master/predict.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers

In [2]:
from transformers import *
import torch
from torch import nn
import shutil
import pandas as pd
import numpy as np
from tqdm import tqdm
import random

In [16]:
test_data_path = "testset.csv"  # CSV to predict on; rows need Title and Abstract columns
model_state_path = "model_state"  # saved state_dict consumed (not produced) by this notebook
# Output label order: prediction column i corresponds to LABELS[i].
LABELS = ["THEORETICAL", "ENGINEERING", "EMPIRICAL", "OTHERS"]
# Per-label sigmoid cutoffs, aligned index-by-index with LABELS.
thresholds = [0.35, 0.3, 0.25, 0.35]
batch = 100  # rows per inference batch

# Fail fast on a misconfiguration instead of at the first thresholded batch.
if len(thresholds) != len(LABELS):
  raise ValueError("thresholds must provide exactly one cutoff per label in LABELS")

In [5]:
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")  # first GPU if present, else CPU

In [None]:
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="allenai/scibert_scivocab_uncased")  # same checkpoint name as the encoder loaded for the model

In [7]:
class _DataSet(torch.utils.data.Dataset):
  """Tokenized title/abstract pairs (plus optional multi-hot labels) as a torch Dataset.

  All tensors are built once in ``__init__`` on the global ``device``;
  ``__getitem__`` only indexes into them.

  Args:
    inp_data: DataFrame with ``Title`` and ``Abstract`` columns. If it also has a
      ``Classifications`` column (space-separated label names), a float32 multi-hot
      tensor over ``LABELS`` is built and returned alongside each item's features.
  """

  def __init__(self, inp_data: pd.DataFrame):
    title_indices = []
    abstract_indices = []
    labels = []
    self.return_labels = "Classifications" in inp_data.columns
    for _, row in tqdm(inp_data.iterrows(), total=len(inp_data)):
      # Both fields are padded AND truncated to exactly 512 ids so that all rows
      # stack into one rectangular tensor below.
      title_indices.append(
          tokenizer.encode(row.Title, max_length=512, padding="max_length", truncation=True))
      abstract_indices.append(
          tokenizer.encode(row.Abstract, max_length=512, padding="max_length", truncation=True))
      if self.return_labels:
        row_labels = set(row.Classifications.split(" "))
        labels.append([1 if label in row_labels else 0 for label in LABELS])
    if self.return_labels:
      # Built once after the loop, instead of rebuilding (and re-copying to the
      # device) the whole label tensor on every row.
      self.labels = torch.tensor(labels, dtype=torch.float32, device=device)
    self.title_indices = torch.tensor(title_indices, dtype=torch.long, device=device)
    # Single-segment inputs: every token_type_id is 0.
    self.title_segments = torch.zeros(self.title_indices.size(), dtype=torch.long, device=device)
    self.abstract_indices = torch.tensor(abstract_indices, dtype=torch.long, device=device)
    self.abstract_segments = torch.zeros(self.abstract_indices.size(), dtype=torch.long, device=device)

  def __getitem__(self, index):
    """Return the feature dict for ``index``, paired with its label row when labels exist."""
    features = {
      "title_indices": self.title_indices[index],
      "title_segments": self.title_segments[index],
      "abstract_indices": self.abstract_indices[index],
      "abstract_segments": self.abstract_segments[index]
    }
    if self.return_labels:
      return features, self.labels[index]
    return features

  def __len__(self):
    return len(self.title_indices)

In [None]:
testset = pd.read_csv(test_data_path)
# Abstract segments are delimited by the literal "$$$"; flatten them to spaces.
# regex=False matches the delimiter literally on every pandas version. The old
# pattern "\$\$\$" relied on regex mode, which pandas>=2.0 no longer defaults to.
testset.Abstract = testset.Abstract.str.replace("$$$", " ", regex=False)
test_data = _DataSet(testset)

In [17]:
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=batch, shuffle=False)  # keep row order so predictions align with testset rows

In [10]:
class MultiClassificationModel(nn.Module):
  """Multi-label classifier over the first-token embeddings of a title and an abstract.

  One shared ``encoder`` embeds title and abstract separately. The hidden state at
  position 0 of each encoder output is concatenated, and a small MLP maps the result
  to ``class_num`` raw logits (no sigmoid is applied here).

  Args:
    encoder: module called as ``encoder(ids, token_type_ids=segments)``; its first
      output is indexed as (batch, seq, embs_num). Presumably a HF BERT model whose
      first output is the last hidden state — confirm if swapping encoders.
    embs_num: hidden size of ``encoder``.
    class_num: number of output logits.
    hidden_unit: width of the MLP hidden layer.
    encoder_dropout: rate for ``self.encoder_dropout``.
    hidden_dropout: dropout rate between the two MLP linear layers.
  """

  def __init__(self, encoder, embs_num=768, class_num=4, hidden_unit=64, encoder_dropout=0.2, hidden_dropout=0.2):
    super().__init__()
    self.encoder = encoder
    # NOTE(review): defined but never applied in forward().
    self.encoder_dropout = nn.Dropout(encoder_dropout)
    # Attribute name and layer order determine state_dict keys (classifier.0.*,
    # classifier.3.*), so they must stay as-is for saved weights to load.
    self.classifier = nn.Sequential(
        nn.Linear(2 * embs_num, hidden_unit),
        nn.GELU(),
        nn.Dropout(hidden_dropout),
        nn.Linear(hidden_unit, class_num),
    )

  def _first_token_embedding(self, token_ids, segment_ids):
    """Encode a batch and return the hidden state at sequence position 0."""
    # NOTE(review): no attention_mask is passed, so padding positions are not masked
    # out; presumably this matches how the weights were trained — confirm first.
    hidden_states = self.encoder(token_ids, token_type_ids=segment_ids)[0]
    return hidden_states[:, 0, :]

  def forward(self, title_indices, title_segments, abstract_indices, abstract_segments):
    title_vec = self._first_token_embedding(title_indices, title_segments)
    abstract_vec = self._first_token_embedding(abstract_indices, abstract_segments)
    return self.classifier(torch.cat((title_vec, abstract_vec), dim=1))

In [None]:
# When this cell is re-run, move the previously loaded model's weights to the CPU
# (presumably to release GPU memory). On the first run `model` is undefined, so only
# NameError is expected and ignored; any other failure is surfaced.
try:
  model.cpu()
except NameError:
  pass
scibert = AutoModel.from_pretrained('allenai/scibert_scivocab_uncased')
model = MultiClassificationModel(scibert)
model.load_state_dict(torch.load(model_state_path, map_location=device))
model.to(device)

In [18]:
# 1-D per-class cutoffs. Broadcasting compares them against every row of a
# (rows, classes) batch, so the final, possibly smaller batch works too. A
# (batch, 4)-expanded tensor fails whenever len(testset) is not a multiple of `batch`.
thrld = torch.tensor(thresholds, dtype=torch.float, device=device)

In [None]:
ans = []
# The wrapper model is constructed in training mode, so its classifier's nn.Dropout
# would stay active and make predictions random. Switch to inference behaviour first.
model.eval()
with torch.no_grad():
  for data in tqdm(test_dataloader, total=len(test_dataloader)):
    probs = torch.sigmoid(model(**data))
    # Multi-label decision: each class is on (1) when its probability exceeds its cutoff.
    ans.extend((probs > thrld).int().tolist())

In [20]:
out = pd.DataFrame(data=ans)  # one row per test example, one 0/1 column per class

In [21]:
out.columns = pd.Index(LABELS)  # column i holds the decision for LABELS[i]

In [22]:
out.insert(loc=0, column="Id", value=range(1, out.shape[0] + 1))  # NOTE(review): positional Ids 1..N assume submission Ids follow testset row order — confirm against the testset's own Id column

In [23]:
out.to_csv("predict.csv", index=False)  # `index` is a bool flag; omit the DataFrame's row index