In [14]:
# from src.model.data_loader import ParagraphDataset
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
from src.model.loss import ParagraphLoss
from transformers import AdamW
from tqdm import tqdm
import math
from src.model.logger import Logger
import os

In [15]:
class ParagraphDataset(Dataset):
    """Tab-separated paragraph dataset that yields fixed-length token/mask tensors.

    Every line of ``data_file`` (the first line is skipped as a header) is split on
    tabs. A line is kept only if it has exactly 7 fields, field 5 (the paragraph text,
    with ``<o>`` markers removed) is non-empty, and both the trailing ``_``-separated
    part of field 0 and field 4 parse as integers.

    Each item is a tuple ``(pad_output, mask_output, n)``:

    * ``pad_output`` -- LongTensor of length ``(n_ctx - 2) + (n_gen - 1) + 3`` laid out as
      ``[bos, keyword block ..., <zeros>, discourse/cls tag, target tokens ..., eos, <zeros>]``.
      The tag always sits at index ``n_ctx - 1``.
    * ``mask_output`` -- LongTensor of the same length, 1 on the written keyword block
      (only when ``include_kw``) and on tag + target + eos, 0 elsewhere.
    * ``n`` -- flattened neighbour vector loaded from the pickle side-car when
      ``include_neigh`` is set, otherwise a float32 zero vector of length ``dim``.

    Args:
        data_file: Path to the tab-separated input file.
        encoder: Tokenizer-like object exposing ``encode(text, add_special_tokens=...)``,
            ``convert_tokens_to_ids`` and ``bos_token_id`` / ``cls_token_id`` / ``eos_token_id``.
        max_size: If not None, keep only the first ``max_size`` valid rows.
        n_ctx: Size budget for the context part (two slots are reserved).
        n_gen: Size budget for the generated part (one slot is reserved).
        include_neigh: Load per-row neighbour vectors from ``<data_file stem>_gpt2.pkl``
            (``_gpt.pkl`` when ``debug_mode``).
        include_discourse_type: Use ``_i_`` / ``_b_`` / ``_c_`` as the tag token
            (first / middle / last paragraph); otherwise use the cls token.
        include_kw: Write the keywords from field 2 into the context block.
        dim: Length of the zero placeholder returned when ``include_neigh`` is False.
        debug_mode: Selects the ``_gpt.pkl`` side-car instead of ``_gpt2.pkl``.
    """

    def __init__(self, data_file, encoder, max_size=None, n_ctx=102, n_gen=401, include_neigh=False,
                 include_discourse_type=True, include_kw=True, dim=0, debug_mode=False):
        with open(data_file, "rb") as f:
            self.data = f.readlines()

        if include_neigh:
            self.prev = []
            suffix = "_gpt.pkl" if debug_mode else "_gpt2.pkl"
            fn = ".".join(data_file.split(".")[:-1]) + suffix
            # NOTE(security): pickle.load can execute arbitrary code; only point this at
            # side-car files you produced yourself.
            with open(fn, 'rb') as fp:
                for k in range(len(self.data)):
                    temp = pickle.load(fp)
                    # Each record is (line index, last field of that line, vector); make sure
                    # the side-car is aligned with the data file before trusting it.
                    assert temp[0] == k and temp[1] == self.data[k].decode('utf-8', 'ignore').split("\t")[-1].replace(
                        "<o>", "").strip()
                    self.prev.append(temp[2])
        else:
            self.prev = None

        # Indices (into self.data) of the rows that are usable as training examples.
        self.dids = []
        for d in range(1, len(self.data)):  # line 0 is skipped
            t = self.data[d].decode("utf-8", "ignore").strip().split('\t')
            if len(t) == 7 and t[5].replace("<o>", "").strip() != "":
                try:
                    # Both values are parsed again in __getitem__; reject rows where
                    # they are not integers instead of failing later.
                    int(t[0].split("_")[-1])
                    int(t[4])
                except ValueError:
                    continue
                self.dids.append(d)

        if max_size is not None:
            self.dids = self.dids[:max_size]
        self.encoder = encoder
        self.ctx = n_ctx - 2
        self.gen = n_gen - 1
        self.dim = dim
        # Number of raw lines in the file (NOT the number of usable examples; see __len__).
        self.len = len(self.data)
        self.include_neigh = include_neigh
        self.include_discourse_type = include_discourse_type
        self.include_kw = include_kw

    def __getitem__(self, index):
        """Build the ``(pad_output, mask_output, n)`` tuple for the ``index``-th valid row."""
        idx = self.dids[index]
        csv_data = self.data[idx].decode("utf-8", "ignore").strip().split('\t')
        kws = csv_data[2].split("[SEP]")
        # Target text, truncated so that target + eos always fits after the tag slot.
        tgt_phrase = self.encoder.encode(csv_data[5].replace("<o>", ""), add_special_tokens=False)[:self.gen]

        keytok = self.encoder.convert_tokens_to_ids('_kw_')
        endkeytok = self.encoder.convert_tokens_to_ids('_endkw_')

        if self.include_discourse_type:
            paragraph_idx = int(csv_data[0].split("_")[-1])
            if paragraph_idx == 0:
                tag_name = '_i_'  # first paragraph
            elif paragraph_idx == int(csv_data[4]) - 1:
                tag_name = '_c_'  # last paragraph
            else:
                tag_name = '_b_'  # any paragraph in between
            starttyptok = self.encoder.convert_tokens_to_ids(tag_name)
        else:
            starttyptok = self.encoder.cls_token_id

        total_len = self.ctx + self.gen + 3
        pad_output = torch.zeros(total_len).long()
        mask_output = torch.zeros(total_len).long()

        pad_output[0] = self.encoder.bos_token_id
        if self.include_kw:
            i = 1  # next free position in the keyword block
            for k in kws:
                if i - 1 >= self.ctx:
                    break
                # Truncate so the keyword plus its separator stays inside the context block.
                enck = self.encoder.encode(k.strip(), add_special_tokens=False)[:self.ctx - i]
                pad_output[i:i + len(enck)] = torch.LongTensor(enck)
                pad_output[i + len(enck)] = keytok
                i += len(enck) + 1
            # The separator written after the last keyword is replaced by the end marker.
            pad_output[i - 1] = endkeytok
            mask_output[0:i] = 1

        tag_pos = self.ctx + 1
        tgt_start = tag_pos + 1
        tgt_end = tgt_start + len(tgt_phrase)
        pad_output[tag_pos] = starttyptok
        pad_output[tgt_start:tgt_end] = torch.LongTensor(tgt_phrase)
        pad_output[tgt_end] = self.encoder.eos_token_id

        # Mask covers tag + target + eos.
        mask_output[tag_pos:tgt_end + 1] = 1

        if self.include_neigh:
            n = torch.FloatTensor(self.prev[idx].flatten())
        else:
            # float32 so the placeholder has the same dtype as the real neighbour vectors.
            n = torch.zeros(self.dim, dtype=torch.float32)
        return pad_output, mask_output, n

    def __len__(self):
        """Number of usable rows (after filtering and the ``max_size`` cap)."""
        return len(self.dids)

In [16]:
# Load the pretrained tokenizer and register the control tokens the dataset relies on:
# sequence start / classify / end markers, the keyword separators, and the discourse tags.
encoder = GPT2Tokenizer.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2", add_prefix_space=True)

special_tokens = {
    'bos_token': '_start_',
    'cls_token': '_classify_',
    'eos_token': '_end_',
    'additional_special_tokens': ['_kw_', '_endkw_', '_t_', '_i_', '_b_', '_c_'],
}

# Last expression of the cell: displays how many tokens were added to the vocabulary.
encoder.add_special_tokens(special_tokens)

9

In [17]:
# Experiment configuration shared by the dataset, loss and optimizer cells below.
config = dict(
    # --- dataset layout ---
    include_kw=True,              # write the keyword block into the context
    include_discourse_type=True,  # use _i_/_b_/_c_ tags instead of the cls token
    max_size=100,                 # cap on the number of examples per split
    n_ctx=102,                    # context budget
    gen_len=1946,                 # generation budget (n_ctx + gen_len = 2048)
    include_neigh=False,          # no neighbour vectors
    dim=768,                      # length of the neighbour placeholder vector
    # --- optimizer ---
    lr=6.25e-5,
    b1=0.9,
    b2=0.999,
    e=1e-8,
    # --- schedule ---
    num_epochs=1,
)

In [18]:
# Both splits are built with the same options; only the input file differs.
dataset_kwargs = dict(
    max_size=config["max_size"],
    n_ctx=config["n_ctx"],
    n_gen=config["gen_len"],
    include_neigh=config["include_neigh"],
    include_discourse_type=config["include_discourse_type"],
    include_kw=config["include_kw"],
    dim=config["dim"],
)

train_dataset = ParagraphDataset('dataset/plot/train_encoded.csv', encoder, **dataset_kwargs)

# Note: despite its name, this dataset is read from the validation file.
test_dataset = ParagraphDataset('dataset/plot/val_encoded.csv', encoder, **dataset_kwargs)

In [19]:
# Settings common to both loaders; only the training loader is shuffled.
loader_kwargs = dict(batch_size=2, num_workers=1, drop_last=True)

train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [20]:
# Load the pretrained language model (hidden states are returned alongside the logits).
model = GPT2LMHeadModel.from_pretrained("sberbank-ai/rugpt3small_based_on_gpt2", output_hidden_states=True)

# The tokenizer was extended with extra special tokens after loading, so the embedding
# matrix has to grow to match; otherwise the new token ids index past its last row.
model.resize_token_embeddings(len(encoder));

In [21]:
def run_batch(model, args, device, compute_loss_fct):
    """Run one forward pass and return the mean loss for the batch.

    Args:
        model: Callable invoked as ``model(*args)`` with the device-resident batch.
        args: Sequence of batch elements; entries may be ``None``. The first two entries
            are also handed to ``compute_loss_fct``.
        device: Target passed to each element's ``.to(...)``.
        compute_loss_fct: Callable ``(output, args[0], args[1]) -> loss tensor``.

    Returns:
        ``compute_loss_fct(...).mean()``.
    """
    # Build a new list: rebinding the loop variable (``arg = arg.to(device)``) would
    # leave the original elements untouched and the model would see unmoved tensors.
    args = [arg.to(device) if arg is not None else None for arg in args]

    # NOTE(review): every batch element is forwarded positionally; confirm the order
    # matches the forward() signature of the model being trained.
    output = model(*args)

    allloss = compute_loss_fct(output, args[0], args[1])

    return allloss.mean()

In [22]:
def run_epoch(bestloss, start_iter, running_loss, model, compute_loss_fct, model_opt, train_loader, val_loader, train_log_interval, val_log_interval, device, beam, gen_len, k,p, decoding_strategy, accum_iter, desc_str, save_dir, logger, text_encoder, show_progress=False, summary_loss=None, my_local_dir='checkpoints_local'):
    '''
    Run a single training epoch with gradient accumulation and periodic loss logging.

    Iterations are numbered from ``start_iter``. The optimizer steps whenever the
    iteration number is a multiple of ``accum_iter``; the accumulated training loss is
    logged and reset every ``train_log_interval`` optimizer updates.

    Validation and checkpointing are currently disabled (see the commented-out code),
    so ``bestloss`` is returned unchanged and ``val_loader``, ``val_log_interval``,
    ``beam``, ``gen_len``, ``k``, ``p``, ``decoding_strategy``, ``save_dir``,
    ``text_encoder``, ``summary_loss`` and ``my_local_dir`` are not used.

    Returns:
        A 4-tuple ``(next_start_iter, running_loss, bestloss, num_updates)``. For an
        empty ``train_loader`` this is ``(start_iter, running_loss, bestloss,
        max(start_iter - 1, 0) // accum_iter)``.
    '''
    if show_progress:
        train_bar = tqdm(iterable=train_loader, desc=desc_str)
    else:
        train_bar = train_loader

    # Number of batches that make up one logging window.
    log_window = train_log_interval * accum_iter

    # Defaults so the return statement is well-defined when the loader yields nothing.
    i = start_iter - 1
    num_updates = max(i, 0) // accum_iter

    for i, batchargs in enumerate(train_bar, start_iter):

        num_updates = i // accum_iter
        model.train()
        loss = run_batch(model, batchargs, device, compute_loss_fct)

        loss.backward()

        running_loss += float(loss.detach().item())
        if show_progress:
            # running_loss is reset whenever i is a multiple of log_window, so the number
            # of batches it currently covers is i % log_window (a full window when 0).
            batches_in_window = (i % log_window) or log_window
            train_bar.set_postfix(loss=running_loss / batches_in_window)

        if i % accum_iter == 0:
            model_opt.step()
            model_opt.zero_grad()
            torch.cuda.empty_cache()
        if num_updates % train_log_interval == 0 and i % accum_iter == 0:
            logger.scalar_summary("Training", num=running_loss, denom=log_window, step=num_updates)
            print("training loss %.2f" % (running_loss/float(log_window)))
            running_loss = 0

        # if num_updates % 1000 == 0 and i % accum_iter == 0:
        #     val_loss, scores = evaluate(val_loader, train_log_interval, model, text_encoder, device, beam, gen_len, k, p, decoding_strategy, compute_loss_fct, min_len=args.min_len)

        #     logger.scalar_summary("Validation", num=val_loss, denom=len(val_loader), step=num_updates)
        #     # if sum(val_loss) < bestloss or bestloss == -1:
        #     lv = get_loss_value(val_loss, len(val_loader))
        #     if (not math.isnan(lv)) and (bestloss == -1 or lv < bestloss):
        #         bestloss = lv
        #         save_checkpoint(i + 1, running_loss, model.state_dict(), model_opt.state_dict(), save_dir, my_local_dir)


    # val_loss, scores = evaluate(val_loader, train_log_interval, model, text_encoder, device, beam, gen_len, k, p, decoding_strategy, compute_loss_fct, min_len=args.min_len)
    # for key, value in scores.items():
    #     for key2, value2 in value.items():
    #         logger.rouge_summary("{}/{}".format(key, key2), value2, num_updates)
    # print("Validation rouge: " + str(scores.items()))
    # logger.scalar_summary("Validation", num=val_loss, denom=len(val_loader), step=num_updates)
    # lv = get_loss_value(val_loss, len(val_loader))
    # if (not math.isnan(lv)) and (bestloss == -1 or lv < bestloss):
    #     bestloss = lv
    #     save_checkpoint(i + 1, running_loss, model.state_dict(), model_opt.state_dict(), save_dir, my_local_dir)


    torch.cuda.empty_cache()
    return i + 1, running_loss, bestloss, num_updates # , lv

In [23]:
# Experiment identifiers and data location.
output_dir = 'savedir'
experiment_name = 'gpt3'
desc = "Desc"
data_dir = 'dataset/plot'

# Derived output locations: checkpoints and logs live under <output_dir>/<experiment_name>,
# plus a separate local checkpoint directory.
experiment_dir = os.path.join(output_dir, experiment_name)
save_dir = os.path.join(experiment_dir, "checkpoints")
save_dir_local = "checkpoints_local"
log_dir = os.path.join(experiment_dir, "logs")

print("Creating directories")
for directory in (output_dir, experiment_dir, log_dir, save_dir, save_dir_local):
    os.makedirs(directory, exist_ok=True)

Creating directories


In [24]:
# Logging cadence, measured in optimizer updates.
train_log_interval = 4
val_log_interval = 4

# Decoding options forwarded to run_epoch (unused there while validation is disabled).
beam = 0
p = 90
k = 0
decoding_strategy = 0

# Number of batches whose gradients are accumulated before each optimizer step.
accum_iter = 4

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Keep the model on the same device the batches are moved to during training; doing it
# here means the parameters are already in place when the optimizer is constructed.
model.to(device)

logger = Logger(log_dir)

In [25]:
# Per-token loss (no reduction); it is wrapped by ParagraphLoss below.
criterion = torch.nn.CrossEntropyLoss(reduction="none")

# Optimize only the parameters that require gradients.
trainable_params = [param for param in model.parameters() if param.requires_grad]
model_opt = AdamW(
    trainable_params,
    lr=config['lr'],
    betas=(config['b1'], config['b2']),
    eps=config['e'],
)

lm_loss = ParagraphLoss(criterion, n_ctx=config["n_ctx"], gen_len=config["gen_len"])

In [26]:
bestloss = -1
start_iter, running_loss = 1, 0
# Kept for when validation is re-enabled in run_epoch; unused while it is disabled.
prevloss = 1000

for epoch in range(config['num_epochs']):
    # run_epoch returns exactly four values; it does not produce a validation loss
    # while its evaluation code is commented out.
    start_iter, running_loss, bestloss, updates = run_epoch(
        bestloss, start_iter, running_loss, model, lm_loss, model_opt, train_loader, val_loader,
        train_log_interval, val_log_interval, device, beam, config['gen_len'], k, p,
        decoding_strategy, accum_iter,
        "FT Training Epoch [{}/{}]".format(epoch + 1, config['num_epochs']),
        save_dir, logger, encoder, show_progress=True, my_local_dir=save_dir_local)
    # Without a validation loss, the only divergence signal available is a NaN in the
    # accumulated training loss.
    if math.isnan(running_loss):
        print("Training loss is NaN, stopping after epoch", epoch + 1)
        break

FT Training Epoch [1/1]:   0%|          | 0/50 [00:00<?, ?it/s]