In [1]:
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score
from sklearn.utils import shuffle

from analysis import SST_analyze
from analysis import news as news_analysis
from datasets import SST2
from loss import ClassificationLossCompute
from model_pytorch import LMHead, load_openai_pretrained_model, MLP
from opt import OpenAIAdam
from text_utils import TextEncoder
from utils import (encode_dataset, iter_data,
                   ResultLogger, make_path)

# Pin this process to GPU 0, with devices enumerated in PCI bus order.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",
})


def transform_news(X1):
    """Pack tokenised sentences into fixed-width input and mask arrays.

    Each sentence is truncated, terminated with ``clf_token`` and written
    left-aligned into a zero-initialised row; unused trailing slots stay 0.

    Reads the module-level globals ``n_ctx``, ``max_len``, ``clf_token``,
    ``n_vocab`` and ``n_special``.

    Args:
        X1: sequence of token-id sequences, one per sentence.

    Returns:
        (xmb, mmb):
            xmb -- int32 array of shape (len(X1), n_ctx, 2); channel 0 holds
                   the token ids, channel 1 the position ids.
            mmb -- float32 array of shape (len(X1), n_ctx); 1 for every slot
                   holding a token (classification token included), else 0.
    """
    n_batch = len(X1)
    xmb = np.zeros((n_batch, n_ctx, 2), dtype=np.int32)
    mmb = np.zeros((n_batch, n_ctx), dtype=np.float32)

    # A sentence plus its classification token must fit in one row, even when
    # ``max_len`` was derived from an earlier, larger value of ``n_ctx``.
    keep = max(0, min(max_len, n_ctx - 1))
    for i, x1 in enumerate(X1):
        x12 = list(x1[:keep]) + [clf_token]
        l12 = len(x12)
        xmb[i, :l12, 0] = x12
        mmb[i, :l12] = 1
    # Position ids are numbered after the token and special-token id range.
    xmb[:, :, 1] = np.arange(n_vocab + n_special, n_vocab + n_special + n_ctx)
    return xmb, mmb


def iter_apply(Xs, Ms, Ys):
    """Run ``dh_model`` over a labelled set and collect logits and loss.

    The model is put in eval mode and gradients are disabled. Batches of
    ``n_batch_train`` examples are drawn with ``iter_data`` (no truncation,
    so a final short batch is included).

    Args:
        Xs: inputs; channel 0 of the last axis is fed to the model.
        Ms: masks aligned with ``Xs``, forwarded to ``compute_loss_fct``.
        Ys: labels aligned with ``Xs``.

    Returns:
        (logits, cost): ``logits`` is a numpy array with the classification
        logits of every example concatenated along axis 0; ``cost`` is the
        sum of the per-example classification losses as a Python number.
    """
    logits = []
    cost = 0

    with torch.no_grad():
        dh_model.eval()
        for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
            YMB = torch.tensor(ymb, dtype=torch.long).to(device)
            MMB = torch.tensor(mmb).to(device)
            _, clf_logits = dh_model(XMB[..., 0])
            # only_return_losses=True: ask for the losses without (presumably)
            # triggering an optimisation step -- confirm in ClassificationLossCompute.
            clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
            logits.append(clf_logits.to("cpu").numpy())
            cost += clf_losses.sum().item()
    logits = np.concatenate(logits, 0)
    return logits, cost


def iter_predict(Xs, Ms):
    """Return the classification logits of ``dh_model`` for unlabelled inputs.

    The model is put in eval mode and gradients are disabled. Batches of
    ``n_batch_train`` examples are drawn with ``iter_data`` (no truncation).

    Args:
        Xs: inputs; channel 0 of the last axis is fed to the model.
        Ms: masks aligned with ``Xs``. They are iterated alongside ``Xs`` but
            not needed for prediction.

    Returns:
        numpy array with the logits of every example concatenated on axis 0.
    """
    logits = []
    with torch.no_grad():
        dh_model.eval()
        # The mask batch is ignored: building a tensor for it and moving it to
        # the device would be wasted work.
        for xmb, _ in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
            _, clf_logits = dh_model(XMB[..., 0])
            logits.append(clf_logits.to("cpu").numpy())
    logits = np.concatenate(logits, 0)
    return logits


def log(save_dir, desc):
    """Evaluate the model, record the metrics and checkpoint on improvement.

    Scores the first ``n_valid`` training examples and the whole validation
    set with ``iter_apply``, writes mean cost and accuracy (in percent) to
    ``logger`` and stdout, and, when ``submit`` is set, saves the model
    weights under ``<save_dir>/<desc>/best_params`` whenever the validation
    accuracy beats the module-level ``best_score``.

    Args:
        save_dir: root directory for checkpoints.
        desc: run name; used as the checkpoint sub-directory.
    """
    global best_score
    print("Logging")

    tr_labels = trY[:n_valid]
    tr_logits, tr_cost = iter_apply(trX[:n_valid], trM[:n_valid], tr_labels)
    va_logits, va_cost = iter_apply(vaX, vaM, vaY)

    # iter_apply returns summed losses; turn them into per-example means.
    tr_cost = tr_cost / len(tr_labels)
    va_cost = va_cost / n_valid
    tr_acc = 100. * accuracy_score(tr_labels, np.argmax(tr_logits, 1))
    va_acc = 100. * accuracy_score(vaY, np.argmax(va_logits, 1))

    metrics = dict(n_epochs=n_epochs, n_updates=n_updates,
                   tr_cost=tr_cost, va_cost=va_cost,
                   tr_acc=tr_acc, va_acc=va_acc)
    logger.log(**metrics)
    print('%d %d %.3f %.3f %.2f %.2f' % (n_epochs, n_updates, tr_cost, va_cost, tr_acc, va_acc))

    if not submit:
        return
    # Keep only the weights of the best validation accuracy seen so far.
    if va_acc > best_score:
        best_score = va_acc
        checkpoint = os.path.join(save_dir, desc, 'best_params')
        torch.save(dh_model.state_dict(), make_path(checkpoint))


def predict(dataset, submission_dir):
    """Predict the test set and write a tab-separated submission file.

    Runs ``iter_predict`` on the module-level ``teX``/``teM``, converts the
    logits with ``pred_fns[dataset]``, optionally maps the results through
    ``label_decoders[dataset]``, and writes ``index<TAB>prediction`` rows
    (after a header line) to ``<submission_dir>/<filenames[dataset]>``.

    Args:
        dataset: key into ``filenames``, ``pred_fns`` and ``label_decoders``.
        submission_dir: output directory; created if missing. An empty string
            writes into the current working directory.

    Raises:
        KeyError: if ``dataset`` is not a known key.
    """
    filename = filenames[dataset]
    pred_fn = pred_fns[dataset]
    label_decoder = label_decoders[dataset]
    predictions = pred_fn(iter_predict(teX, teM))
    if label_decoder is not None:
        predictions = [label_decoder[prediction] for prediction in predictions]
    path = os.path.join(submission_dir, filename)
    # os.makedirs('') raises FileNotFoundError, so only create a directory
    # when the target path actually has one.
    out_dir = os.path.dirname(path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(path, 'w') as f:
        f.write('{}\t{}\n'.format('index', 'prediction'))
        for i, prediction in enumerate(predictions):
            f.write('{}\t{}\n'.format(i, prediction))


def run_epoch():
    """Make one pass over the training data, one loss computation per batch.

    Iterates ``trX``/``trM``/``trYt`` in batches of ``n_batch_train``
    (dropping a short final batch), feeds channel 0 of the inputs to
    ``dh_model`` and hands both heads' logits to ``compute_loss_fct``.
    Increments the module-level ``n_updates`` once per batch. During the
    first epoch (``n_epochs == 0``) it also calls ``log`` at a fixed set of
    update counts.
    """
    global n_updates
    # Update counts at which the first epoch gets an extra evaluation.
    log_checkpoints = (1000, 2000, 4000, 8000, 16000, 32000)

    batches = iter_data(trX, trM, trYt,
                        n_batch=n_batch_train, truncate=True, verbose=True)
    for xmb, mmb, ymb in batches:
        # log() evaluates through iter_apply, which leaves the model in eval
        # mode, so train mode is re-entered on every step.
        dh_model.train()
        XMB = torch.tensor(xmb, dtype=torch.long).to(device)
        YMB = torch.tensor(ymb, dtype=torch.long).to(device)
        MMB = torch.tensor(mmb).to(device)
        lm_logits, clf_logits = dh_model(XMB[..., 0])
        # Presumably this also back-propagates and steps the optimiser it was
        # built with -- confirm in ClassificationLossCompute.
        compute_loss_fct(XMB, YMB, MMB, clf_logits, lm_logits)
        n_updates += 1
        if n_epochs == 0 and n_updates in log_checkpoints:
            log(save_dir, desc)


def argmax(x):
    """Return, for every row of ``x``, the index of its largest entry."""
    return np.argmax(x, 1)


# Per-dataset lookup tables used by predict():
#   pred_fns       -- turns logits into predictions,
#   filenames      -- name of the submission file,
#   label_decoders -- optional mapping from prediction to output label.
pred_fns = dict(sst=argmax)

filenames = dict(sst='sst')

label_decoders = dict(sst=None)

In [2]:
class ClfHead(nn.Module):
    """Classification head for the transformer.

    Selects the hidden state at every position whose input token equals
    ``clf_token``, applies dropout and projects it to ``n_class`` logits.

    Args:
        clf_token: token id marking the position to classify from.
        cfg: configuration exposing ``n_embd`` and ``clf_pdrop``.
        n_class: number of output classes.

    TODO: test this class.
    """

    def __init__(self, clf_token, cfg, n_class):
        super(ClfHead, self).__init__()
        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        self.dropout = nn.Dropout(cfg.clf_pdrop)
        self.linear = nn.Linear(cfg.n_embd, n_class)
        # Not used by forward(); kept so existing checkpoints, which may hold
        # its parameters, still load.
        self.mlp = MLP(4 * cfg.n_embd, cfg)
        nn.init.normal_(self.linear.weight, std=0.02)
        # Start the bias at exactly zero. The earlier
        # nn.init.normal_(bias, 0) only fixed the mean, so the bias was
        # sampled from N(0, 1), far larger than the 0.02-std weights.
        nn.init.constant_(self.linear.bias, 0.)

    def forward(self, h, x):
        """Compute class logits from hidden states.

        Args:
            h: hidden states; the last axis has size ``n_embd``.
            x: token ids with one entry per hidden-state vector in ``h``.

        Returns:
            Tensor with one row of ``n_class`` logits per occurrence of
            ``clf_token`` in ``x`` (callers presumably guarantee exactly one
            occurrence per sequence -- confirm).
        """
        # Flatten both inputs so a boolean mask over the token ids can select
        # rows of the hidden-state matrix.
        clf_h = h.view(-1, self.n_embd)
        flat = x.contiguous().view(-1)
        clf_h = clf_h[flat == self.clf_token, :]
        clf_h = self.dropout(clf_h)
        clf_logits = self.linear(clf_h)

        return clf_logits

In [3]:
from transformers import GPT2Model, GPT2Tokenizer
# Load the pretrained GPT-2 tokenizer and register '[CLS]' as its
# classification token; its id is read later via tokenizer.cls_token_id.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.add_special_tokens({'cls_token': '[CLS]'})
# Tokenizer length, measured after the special token has been added.
n_vocab = len(tokenizer)

In [4]:
import collections

class DoubleHeadModel(nn.Module):
    """Pretrained GPT-2 with a language-model head and a task-specific head.

    The token embeddings are resized to ``len(tokenizer)``, so the
    module-level ``tokenizer`` must exist before the model is instantiated.

    Args:
        cfg: configuration object, passed through to the heads.
        clf_token: id of the classification token, passed to the task head.
        task_head_type: ``'multiple_choice'``, ``'similarity'``,
            ``'inference'``, or a 2-item sequence
            ``('classification', n_class)``.
        vocab: unused; accepted for backward compatibility with callers.
        n_ctx: unused; accepted for backward compatibility with callers.

    Raises:
        ValueError: if ``task_head_type`` is not one of the supported forms.
    """

    def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
        super(DoubleHeadModel, self).__init__()
        self.transformer = GPT2Model.from_pretrained('gpt2')
        self.transformer.resize_token_embeddings(len(tokenizer))
        self.lm_head = LMHead(self.transformer, cfg)
        if isinstance(task_head_type, str):
            # NOTE(review): MultipleChoiceHead and SimilarityHead are not
            # imported in the visible part of this notebook -- confirm they
            # are in scope before selecting these head types.
            if task_head_type == 'multiple_choice':
                self.task_head = MultipleChoiceHead(clf_token, cfg)
            elif task_head_type == 'similarity':
                self.task_head = SimilarityHead(clf_token, cfg)
            elif task_head_type == 'inference':
                # the three classes correspond to entailment, contradiction and neutral.
                self.task_head = ClfHead(clf_token, cfg, 3)
            else:
                raise self._unsupported_head(task_head_type)
        elif isinstance(task_head_type, collections.abc.Sequence) and len(task_head_type) == 2 and \
                task_head_type[0] == 'classification':
            n_class = task_head_type[1]
            self.task_head = ClfHead(clf_token, cfg, n_class)
        else:
            raise self._unsupported_head(task_head_type)

    @staticmethod
    def _unsupported_head(task_head_type):
        """Build the ValueError for an unsupported ``task_head_type``."""
        # The offending value is formatted into the message; previously the
        # literal text '{task_head_type}' was emitted instead.
        return ValueError("task_head_type is expected to be 'multiple_choice' "
                          "'similarity', 'inference' or ('classification', n_class) "
                          "got {!r}.".format(task_head_type))

    def forward(self, x):
        """Run the transformer and both heads.

        Args:
            x: token ids, passed to the transformer and to the task head.

        Returns:
            (lm_logits, task_logits) as produced by the two heads.
        """
        # Element 0 of the transformer output is the hidden-state tensor.
        h = self.transformer(x)[0]
        lm_logits = self.lm_head(h)
        task_logits = self.task_head(h, x)
        return lm_logits, task_logits


In [5]:
# --- Run configuration -------------------------------------------------------
submit = True           # also predict the test split and keep the best checkpoint
dataset = 'sst'
n_ctx = 51              # upper bound on the model input length, in tokens
save_dir = '../models/save_gpt2/'
desc = 'sst'
data_dir = '../data/bot_detection'
log_dir = 'log/'
submission_dir = 'submission_gpt2/'

# --- Reproducibility: seed numpy and torch (CPU and every GPU) ---------------
SEED = 2345
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# --- Device -------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
print("device:", device, "n_gpu:", n_gpu)

# --- Data ---------------------------------------------------------------------
print("Encoding dataset...")

# Load the splits once and reuse them. Calling SST2() separately for every
# field re-reads the data each time and could pair sentences with labels from
# a different load if the loader is not deterministic.
splits = SST2(data_dir)
train_split, valid_split, test_split = splits[0], splits[1], splits[2]

trX1 = [tokenizer.encode(sentence) for sentence in train_split[0]]
trY = list(train_split[1])
vaX1 = [tokenizer.encode(sentence) for sentence in valid_split[0]]
vaY = list(valid_split[1])
teX1 = [tokenizer.encode(sentence) for sentence in test_split[0]]

device: cuda n_gpu: 1
Encoding dataset...


In [None]:
clf_token = tokenizer.cls_token_id
# Only the classification token is appended to each sequence.
n_special = 1

# n_ctx is the maximum number of tokens in an input sequence. Shrink the
# configured cap when every (truncated) test sentence is shorter than it.
# NOTE(review): only the test split is measured here, and 2 slots are added
# although a single special token is appended -- confirm both are intended.
max_len = n_ctx - 1
n_ctx = min(max(len(x1[:max_len]) for x1 in teX1) + 2, n_ctx)
# Re-derive the truncation length from the final n_ctx. Keeping the value
# computed from the old cap lets a train/validation sentence plus its
# classification token overflow the n_ctx-wide arrays in transform_news.
max_len = n_ctx - n_special

vocab = n_vocab + n_special + n_ctx
trX, trM = transform_news(trX1)
vaX, vaM = transform_news(vaX1)
if submit:
    teX, teM = transform_news(teX1)

n_train = len(trY)
n_valid = len(vaY)

In [7]:
# Mini-batch size: 8 examples per visible GPU, and 8 when no GPU is present.
n_batch_train = max(n_gpu, 1) * 8
# Total number of updates given to the optimiser's schedule.
# NOTE(review): 500000 and 3 look like a nominal dataset size and the epoch
# count -- confirm they match n_train and the training loop.
n_updates_total = 3 * (500000 // n_batch_train)

In [8]:
class AttrDict(dict):
    """Dictionary whose items can also be read and written as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so ``d.key`` and
        # ``d['key']`` always refer to the same storage.
        self.__dict__ = self

In [9]:
# Model and run hyper-parameters, readable both as args['key'] and args.key.
args = AttrDict(
    n_embd=768,
    embd_pdrop=0.1,
    n_layer=12,
    n_ctx=51,
    clf_pdrop=0.1,
    n_head=12,
    attn_pdrop=0.1,
    resid_pdrop=0.1,
    afn="gelu",
    submission_dir="submission_gpt2",
)

In [10]:
# Metrics are written to <log_dir>/<desc>.jsonl; the hyper-parameters are
# passed along as keyword arguments.
log_path = os.path.join(log_dir, '{}.jsonl'.format(desc))
logger = ResultLogger(path=log_path, **args.__dict__)

In [11]:
dh_model = DoubleHeadModel(args, clf_token, ('classification', 2), vocab, n_ctx)
# Move the parameters to the target device before building the optimizer, as
# the torch.optim documentation recommends.
dh_model.to(device)

# Per-element (unreduced) losses; reduction='none' replaces the deprecated
# reduce=False argument.
criterion = nn.CrossEntropyLoss(reduction='none')
model_opt = OpenAIAdam(dh_model.parameters(),
                       lr=6.25e-5,
                       schedule='warmup_linear',
                       warmup=0.002,
                       t_total=n_updates_total,
                       b1=0.9,
                       b2=0.999,
                       e=1e-8,  # epsilon
                       l2=0.01,
                       # NOTE(review): 'store_true' looks like a leftover
                       # argparse action; it is only truthy here -- confirm
                       # that True is what OpenAIAdam expects.
                       vector_l2='store_true',
                       max_grad_norm=1)
# The same criterion is passed twice.
# NOTE(review): 0.5 is presumably the weight of the language-model loss --
# confirm against ClassificationLossCompute.
compute_loss_fct = ClassificationLossCompute(criterion,
                                             criterion,
                                             0.5,
                                             model_opt)

dh_model = nn.DataParallel(dh_model)



In [None]:
n_updates = 0
n_epochs = 0
if dataset != 'stsb':
    trYt = trY

# Location of the best checkpoint; log() writes to the same place.
path = os.path.join(save_dir, desc, 'best_params')
if submit:
    # Save the initial weights so a checkpoint exists even if validation
    # accuracy never improves.
    torch.save(dh_model.state_dict(), make_path(path))
best_score = 0

for i in range(3):
    print("running epoch", i)
    run_epoch()
    n_epochs += 1
    log(save_dir, desc)

if submit:
    # Predict with the best weights seen during training.
    dh_model.load_state_dict(torch.load(path))
    predict(dataset, args.submission_dir)
    # 'analysis' is optional in args; reading it as an attribute raised
    # AttributeError after the predictions had already been written.
    if args.get('analysis', False):
        SST_analyze(data_dir, os.path.join(args.submission_dir, filenames[dataset]),
                    os.path.join(log_dir, '{}.jsonl'.format(dataset)))