In [1]:
import os
from datetime import datetime
from time import time


class MyLog(object):
    """Minimal logger that writes timestamped messages to a file and/or stdout.

    File lines are prefixed with the wall-clock ``datetime.now()``; screen
    lines are prefixed with the seconds elapsed since the logger's start time.
    """

    def __init__(self, log_file='log/log_file', file_output=True, screen_output=True, reset=False):
        """
        :param log_file: path of the log file; its parent directory is created if missing.
        :param file_output: default for whether ``log`` appends to ``log_file``.
        :param screen_output: default for whether ``log`` prints to stdout.
        :param reset: if True, truncate ``log_file`` immediately.
        """
        self._ensure_parent_dir(log_file)
        self.st = self._now()
        self.log_file = log_file
        self.file_output = file_output
        self.screen_output = screen_output
        if reset:
            with open(log_file, "w", encoding='utf-8'):
                pass

    @staticmethod
    def _now():
        """Current epoch time in seconds."""
        # Imported locally on purpose: a later notebook cell runs ``import time``,
        # which rebinds the shared global ``time`` from the function to the module
        # and would make a bare ``time()`` call raise TypeError.
        import time as _time_module
        return _time_module.time()

    @staticmethod
    def _ensure_parent_dir(path):
        """Create the directory that will contain ``path`` (no-op for bare file names)."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    def get_start(self):
        """Return the reference start time (epoch seconds)."""
        return self.st

    def set_start(self, st):
        """Overwrite the reference start time (epoch seconds)."""
        self.st = st

    def reset(self):
        """Restart the elapsed-time clock from now."""
        self.st = self._now()

    def get_time(self):
        """Seconds elapsed since the reference start time."""
        return self._now() - self.st

    def set_file_output(self, file_output):
        self.file_output = file_output

    def set_screen_output(self, screen_output):
        self.screen_output = screen_output

    def set_log_file(self, log_file):
        """Switch to a new log file, creating its parent directory if needed."""
        self._ensure_parent_dir(log_file)
        self.log_file = log_file

    def log(self, msg, file_output=None, screen_output=None):
        """Log ``msg`` (converted with ``str``) to the enabled outputs.

        ``file_output`` / ``screen_output`` override the instance defaults
        for this call when not None.
        """
        if file_output is None:
            file_output = self.file_output
        if screen_output is None:
            screen_output = self.screen_output

        current_time = self.get_time()
        msg = str(msg)

        if file_output:
            # ``with`` guarantees the handle is closed even if the write fails.
            with open(self.log_file, "a+", encoding='utf-8') as f:
                f.write(str(datetime.now()) + ' : ' + msg + '\n')

        if screen_output:
            print("%15.4f" % current_time, ':', msg)

In [2]:
import time

import numpy as np
import torch


def weights(c=0.8, k=64):
    """Linearly decaying weights ``1, 1 - c/k, 1 - 2c/k, ...`` of length ``k``.

    :param c: total decay spread over the ``k`` steps.
    :param k: number of weights (must be non-zero; ``c / k`` is computed first).
    :return: list of ``k`` floats.
    """
    step = c / k
    return [1 - i * step for i in range(k)]


# Multiples of 5 from 5 through 95 (19 values); presumably percentage levels — confirm at use site.
parts = [5 * step for step in range(1, 20)]


def prepare_data_cuda(batch_data, config, func=None):
    """Run ``func`` on a batch and move its non-list outputs to ``config.device``.

    :param batch_data: raw batch as produced by the dataset.
    :param config: must provide ``device`` (passed to ``.cuda``).
    :param func: ``func(batch_data, config)`` returning an iterable of fields;
        when None (the default), ``batch_data`` itself is used as that iterable
        instead of crashing on ``None(...)``.
    :return: tuple with one entry per field — ``list`` fields are passed through
        unchanged, every other field is replaced by ``field.cuda(config.device)``.
    """
    fields = batch_data if func is None else func(batch_data, config)
    return tuple(
        field if isinstance(field, list) else field.cuda(config.device)
        for field in fields
    )


def _encode_offline_chunk(net, chunk, parallel):
    """Encode one slice of chunks without gradients and return it as a NumPy array."""
    with torch.no_grad():
        hiddens = net(chunk)
    if parallel:
        # The parallel wrapper returns one output per device; gather them on cuda:0.
        hiddens = torch.cat([it.to("cuda:0") for it in hiddens], dim=0)
    return hiddens.cpu().numpy()


def offline(config, log):
    """Pre-compute encoder hidden states for the test split and pickle them.

    Each batch's chunks are encoded in slices of at most ``k`` chunks, where
    ``k = config.offline_k // n_devices // bsz`` (at least 1). The per-slice
    outputs are concatenated along the chunk axis; for every document
    ``idx`` the first ``length`` rows are stored in a dict, which is saved to
    ``hidden_v_<kernel_size>_<stride>_<window_type>.pkl``.
    NOTE(review): assumes ``net`` returns (bsz, n_chunks_in_slice, d_model),
    matching the (bsz, 0, d_model) buffer the original pre-allocated — confirm.
    """
    __, valid_set = get_data(config, log, mode="test")

    net = TokenLevelEncoder(config).cuda(config.device)
    if config.parallel:
        net = DataParallelModel(net)
    # Inference only: without eval() dropout stays active and the cached
    # hidden states would be randomly perturbed.
    net.eval()

    h = {}
    st = time.time()
    # Guard against a zero device count (division below).
    n_devices = max(1, torch.cuda.device_count())
    for batch_idx, batch_data in enumerate(valid_set):
        batch_chunks, batch_indices, batch_lengths, __ = \
            prepare_data_cuda(batch_data, config, prepare_offline_data)
        bsz, n, _ = batch_chunks.shape
        # Chunks per forward pass; at least 1 so a large batch cannot make it 0.
        k = max(1, config.offline_k // n_devices // bsz)
        # Collect slices and concatenate once (avoids quadratic re-copying).
        pieces = [
            _encode_offline_chunk(net, batch_chunks[:, start:start + k], config.parallel)
            for start in range(0, n, k)
        ]
        if pieces:
            # float64 matches the dtype the original accumulation buffer produced.
            hiddens_np = np.concatenate(pieces, axis=1).astype(np.float64, copy=False)
        else:
            hiddens_np = np.empty((bsz, 0, config.doc_encoder.d_model), dtype=np.float64)

        for i, (idx, length) in enumerate(zip(batch_indices, batch_lengths)):
            h[idx] = hiddens_np[i][:length, :]

        print(batch_idx + 1, hiddens_np.shape, batch_indices, (time.time() - st))
    name = "hidden_v_"+str(config.kernel_size)+"_"+str(config.stride)+"_"+config.window_type+".pkl"
    save_to_pkl(name, h)

In [3]:
import argparse
import os
from argparse import Namespace

import torch




def _build_arg_parser():
    """Create the argparse parser holding every command-line option and its default."""
    parser = argparse.ArgumentParser()

    # Actions: train/test/debug/offline
    # Mode: ext/abs/swh/ret
    # Mini: for debugging, load less data
    parser.add_argument('--do_train', action='store_true', help="Whether to run training")
    parser.add_argument('--do_test', action='store_true', help="Whether to run test")
    parser.add_argument('--debug', action="store_true")
    parser.add_argument('--offline', action='store_true')
    parser.add_argument('--decode', action='store_true')
    parser.add_argument('--example', action='store_true')
    parser.add_argument('--statistic', action='store_true')
    parser.add_argument('--create_ref', action='store_true')

    parser.add_argument('--mode', type=str, default='all')
    parser.add_argument('--offline_k', type=int, default=256)
    parser.add_argument('--epsilon', type=float, default=1.0)
    parser.add_argument('--ret_loss_per_token', action="store_true")
    parser.add_argument('--ignore_index', type=int, default=-1)
    parser.add_argument('--regularizer', type=int, default=1)
    parser.add_argument('--coef', type=float, default=0.1)

    parser.add_argument('--mini', action="store_true")
    parser.add_argument('--T_p', type=int, default=25)

    parser.add_argument('--Th', type=float, default=0.5)
    parser.add_argument('--cls_dim', type=int, default=256)

    parser.add_argument('--topk', type=int, default=3)
    parser.add_argument('--overlap', action="store_true")

    # Training tricks:
    #   amp: Automatic Mixed Precision
    #   parallel: data-parallel
    #   device: default device
    #   n_workers: number of data-loading workers (currently unused)
    parser.add_argument('--amp', action='store_true', help="Whether or not to use amp during trianing")
    parser.add_argument('--parallel', action='store_true')
    parser.add_argument('--device', type=int, default=0)
    parser.add_argument('--n_workers', type=int, default=2)

    # Data source: dataset/raw
    #   load: load from the cache instead of building from files
    #   train/valid/test: the part of the dataset
    #   raw_train/valid/test: path to the raw file
    #   input/output limits: maximum lengths
    parser.add_argument('--load', action='store_true', help="Whether or not build from raw")

    # From dataset
    parser.add_argument('--from_dataset', action='store_true', help="From dataset")
    parser.add_argument('--dataset', type=str, default='podcasts')
    parser.add_argument('--train', type=str, default='train')
    parser.add_argument('--valid', type=str, default='valid')
    parser.add_argument('--test', type=str, default='test')

    # From raw files
    parser.add_argument('--from_raw', action='store_true', help="From raw text")
    parser.add_argument('--raw_train', type=str, default="raw_train.txt")
    parser.add_argument('--raw_valid', type=str, default="raw_valid.txt")
    parser.add_argument('--raw_test', type=str, default="raw_test.txt")

    # Truncation
    parser.add_argument('--input_limit', type=int, default=65536)
    parser.add_argument('--output_limit', type=int, default=128)
    parser.add_argument('--output_sent_limit', type=int, default=8)

    # Constraint
    parser.add_argument('--n_overlap_constrain', type=int, default=3)

    # Pre-ext parameters
    parser.add_argument('--threshold', type=float, default=0.2)

    # Pretrained model selection:
    #   --model_ext: roberta-base roberta-large
    #   --model_abs: facebook/bart-base facebook/bart-large facebook/bart-large-cnn
    #   --local: load models from ../../pretrained_models (e.g. behind a firewall)
    parser.add_argument('--model_ext', type=str, default='roberta-large')
    parser.add_argument('--model_abs', type=str, default='facebook/bart-large')
    parser.add_argument('--local', action='store_true', help="Whether or not using local models")

    # Save paths
    parser.add_argument('--main_path', type=str, default='./model')
    parser.add_argument('--abs_path', type=str, default='/abs')
    parser.add_argument('--ext_path', type=str, default='/ext')
    parser.add_argument('--swh_path', type=str, default='/swh')
    parser.add_argument('--ret_path', type=str, default='/ret')

    # Sliding window
    parser.add_argument('--window_type', type=str, default="sent")
    parser.add_argument('--kernel_size', type=int, default=20)
    parser.add_argument('--stride', type=int, default=10)

    # Model parameters
    parser.add_argument('--ext_type', type=str, default="TokenLevelEncoder")
    parser.add_argument('--d_query', type=int, default=1024)
    parser.add_argument('--d_key', type=int, default=1024)
    parser.add_argument('--d_att', type=int, default=1024)
    parser.add_argument('--d_inner', type=int, default=512)
    parser.add_argument('--d_model', type=int, default=1024)
    parser.add_argument('--dropout', type=float, default=0.1)
    parser.add_argument('--type_att', type=int, default=3)

    # Transformer encoder
    parser.add_argument('--doc_n_layers', type=int, default=2)
    parser.add_argument('--doc_d_model', type=int, default=1024)
    parser.add_argument('--doc_n_head', type=int, default=16)
    parser.add_argument('--doc_d_ff', type=int, default=3072)
    parser.add_argument('--doc_max_len', type=int, default=1024)
    parser.add_argument('--doc_dropout', type=float, default=0.1)

    # Loss parameters
    parser.add_argument('--label_smoothing', type=float, default=0.1)

    # Optimization:
    #   learning rate: 1e-5 (32), 1.224e-5 (48), 3.4641e-5 (384)
    #   weight decay, adam_epsilon
    #   warmup steps: 5% to 10% of the first epoch
    parser.add_argument('--learning_rate', type=float, default=1e-5)
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--adam_epsilon', type=float, default=1e-7)
    parser.add_argument('--warmup_steps', type=int, default=100)

    # Training:
    #   max epoch: 5, 10, 20
    #   batch_size: 1, 8, 16, 32
    #   checkPoint_Min/Freq: checkpoint parameters
    #   save_each_epoch: for restoring training (usually not needed)
    parser.add_argument('--max_epoch', type=int, default=20)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--checkPoint_Min', type=int, default=0)
    parser.add_argument('--checkPoint_Freq', type=int, default=100)
    parser.add_argument('--save_each_epoch', action="store_true")

    # Testing parameters
    parser.add_argument('--model', type=str, default='model_best.pth.tar')

    # Sentence decoding parameters
    parser.add_argument('--gen_max_len', type=int, default=128)
    parser.add_argument('--gen_min_len', type=int, default=20)
    parser.add_argument('--beam_size', type=int, default=4)
    parser.add_argument('--chunk_beam_size', type=int, default=4)
    parser.add_argument('--answer_size', type=int, default=4)
    parser.add_argument('--length_penalty', type=float, default=1.0)
    parser.add_argument('--no_repeat_ngram_size', type=int, default=3)
    return parser


def arg_loader(argv=None):
    """Parse command-line options and derive the full run configuration.

    :param argv: list of argument strings. ``None`` (default) parses
        ``sys.argv[1:]`` as before; inside Jupyter pass an explicit list
        (e.g. ``[]``) because the kernel's own ``-f <file>`` argument would
        otherwise be rejected.
    :return: ``argparse.Namespace`` with the parsed options plus derived
        fields: ``build``, special-token ids, ``doc_encoder`` sub-config,
        ``dataOptions``, ``strategy``, ``save_path`` and, for ext train/test,
        the threshold ``T``.
    Side effects: creates ``main_path`` and ``save_path`` directories.
    """
    args = _build_arg_parser().parse_args(argv)
    # Build from raw files unless --load asks for the cached version.
    args.build = not args.load

    # Special token ids (hard-coded; presumably the BART/RoBERTa vocabulary — verify against the tokenizer).
    args.pad = 1
    args.UNK = 3
    args.cls = 0
    args.sep = 2
    args.BOS = 50261
    args.EOS = 50260
    args.MASK = 50264
    args.n_vocab = 50265

    # Resolve local model paths BEFORE model_ext is copied into the doc-encoder
    # sub-config; previously doc_encoder.pre_training_model kept the hub name
    # even when --local was given.
    if args.local:
        args.model_ext = "../../pretrained_models/" + args.model_ext
        args.model_abs = "../../pretrained_models/" + args.model_abs

    # Doc encoder sub-config
    args.doc_encoder = Namespace(
        pre_training_model=args.model_ext,
        n_layers=args.doc_n_layers,
        d_model=args.doc_d_model,
        n_head=args.doc_n_head,
        d_ff=args.doc_d_ff,
        max_len=args.doc_max_len,
        dropout=args.doc_dropout,
        pad=args.pad,
        autoregressive=False,
    )

    # Data loading options
    args.dataOptions = load_from_json("settings/dataset/" + str(args.dataset) + ".json")
    args.strategy = None

    # Model save path: one sub-directory per training mode, otherwise a
    # stride-specific retrieval directory.
    mode_paths = {
        "ext": args.ext_path,
        "abs": args.abs_path,
        "swh": args.swh_path,
        "ret": args.ret_path,
    }
    if args.mode in mode_paths:
        args.save_path = args.main_path + mode_paths[args.mode]
    else:
        args.save_path = args.main_path + "/ret-" + str(args.stride) + "-01"
    os.makedirs(args.main_path, exist_ok=True)
    os.makedirs(args.save_path, exist_ok=True)

    # Threshold: pick the T_p-th entry (steps of 5) from the precomputed parts file.
    if args.mode == "ext" and (args.do_train or args.do_test):
        name = "parts_" + str(args.kernel_size) + "_" + str(args.stride) + "_" + args.window_type + ".pt"
        parts = torch.load(name)
        args.T = list(parts.values())[args.T_p // 5 - 1]
        print(args.T)

    return args

In [5]:
import torch



def get_data(config, log, mode='test'):
    """Build the evaluation split named by ``config.test``.

    Examples are measured by the total length of their first field
    (presumably the list of tokenized input segments — see Podcasts.load).

    :return: ``(len(dataset), dataset)``
    """
    print("get_data test")

    def total_input_tokens(example):
        return sum(len(segment) for segment in example[0])

    valid_set = Dataset(
        name=config.test,
        len_func=total_input_tokens,
        config=config,
        tokenizer=MyTokenizer(config),
        log=log,
        mode=mode,
    )
    size = len(valid_set)
    log.log("There are %d batches in valid data" % size)
    return size, valid_set


def get_network(config, log):
    """Build a Retrieval model, load ``<save_path>/<model>`` weights, and return it on GPU in eval mode.

    Checkpoint keys saved from a ``DataParallel``-style wrapper carry a
    ``module.`` prefix; it is stripped so they match the bare model.
    """
    log.log("Building Model")
    net = Retrieval(config)

    log.log("Loading Parameters")
    # map_location="cpu": do not depend on the GPU the checkpoint was saved
    # from; load_state_dict copies the values into the model's own tensors.
    checkpoint = torch.load(config.save_path + "/" + config.model, map_location="cpu")
    prefix = "module."
    state_dict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in checkpoint["state_dict"].items()
    }
    net.load_state_dict(state_dict)
    log.log("Parameters Loaded")

    net = net.cuda(config.device)

    net.eval()
    log.log("Finished Build Model")
    return net


def test(config, log):
    """Decode a summary for every test example and write one per line.

    Output file: ``summary_<stride>_<beam_size>_<length_penalty>.txt``.
    Sentences in a summary are joined with ``" [SSPLIT] "``.
    Side effects: sets ``config.batch_size = 1`` and ``config.mode = "ret"``.
    """
    tokenizer = MyTokenizer(config)
    config.batch_size = 1
    config.mode = "ret"
    net = get_network(config, log)
    _, valid_set = get_data(config, log)

    name = "hidden_v_" + str(config.kernel_size) + "_" + str(config.stride) + "_" + config.window_type + ".pkl"
    h_v = load_from_pkl(name)

    suffix = "_" + str(config.stride) + "_" + str(config.beam_size) + "_" + str(config.length_penalty)
    separator = " [SSPLIT] "

    torch.cuda.empty_cache()

    # ``with`` closes the file even if decoding raises; utf-8 keeps output portable.
    with open("summary" + suffix + ".txt", "w", encoding="utf-8") as f:
        for batch_idx, batch_data in enumerate(valid_set):
            print(batch_idx)
            answer, _ = beam_search(net, batch_data, config, h=h_v, output_others=True)

            # Re-split sentences: every sentence after the first donates its
            # leading token to the end of the previous sentence. Copies are
            # used so beam_search's output is not mutated.
            ans = []
            for sent in answer[2]:
                if not ans:
                    ans.append(list(sent))
                    continue
                if sent:  # an empty sentence has no leading token to move
                    ans[-1].append(sent[0])
                ans.append(list(sent[1:]))
            summary_text = separator.join(tokenizer.decode(sent) for sent in ans)
            # A trailing separator appears when the last sentence decodes to "".
            if summary_text.endswith(separator):
                summary_text = summary_text[:-len(separator)]
            print(summary_text)
            print(summary_text, file=f)

In [10]:
import json
import random

import numpy as np
import torch
import torch.utils.data as Data



class MyDataset(Data.Dataset):
    """Base dataset that keeps examples in memory and serves pre-built batches.

    Subclasses fill ``self.data`` / ``self.n_data`` (``load`` is a no-op
    here) and then call ``after_load``, which optionally sorts and shuffles,
    builds the batches and caches everything with ``save_to_pkl``.
    ``__getitem__`` returns one whole batch (optionally via ``prepare_func``).
    """

    def __init__(self, name, len_func, config, log, mode='train', prepare_func=None):
        """
        :param name: key into ``config.dataOptions['Parts']``; also the cache-file stem.
        :param len_func: maps one example to a sortable length (used by ``sort_by_length``).
        :param config: needs ``build``, ``batch_size`` and ``dataOptions``.
        :param log: object with a ``log(msg)`` method.
        :param mode: sorting and shuffling are skipped when this is 'test'.
        :param prepare_func: optional transform applied to each batch in ``__getitem__``.
        """
        self.name = name
        self.log = log
        self.len_func = len_func
        self.mode = mode
        self.prepare_func = prepare_func
        self.config = config

        self.build = config.build
        self.batch_size = config.batch_size

        part_options = config.dataOptions['Parts'][name]
        self.path = part_options['path']
        self.sorted = part_options['sorted']
        self.shuffled = part_options['shuffled']

        self.n_data = 0
        self.data = []
        self.n_batch = 0
        self.batch = []
        self.batch_idx = []

    def sort_by_length(self):
        """Keep the first ``n_data`` examples, ordered by ascending ``len_func`` (ties keep input order)."""
        self.log.log('Start sorting by length')
        # sorted() is stable, so equal lengths stay in their original order.
        self.data = sorted(self.data[:self.n_data], key=self.len_func)
        self.log.log('Finish sorting by length')

    def shuffle(self):
        """Randomly permute the first ``n_data`` examples using the ``random`` module."""
        self.log.log('Start Shuffling')
        order = list(range(self.n_data))
        random.shuffle(order)
        self.data = [self.data[i] for i in order]
        self.log.log('Finish Shuffling')

    @staticmethod
    def _collate(examples, n_dim):
        """Transpose a list of examples into ``n_dim`` per-field lists."""
        return [[item[j] for item in examples] for j in range(n_dim)]

    def gen_batches(self):
        """Split the examples into batches of ``batch_size``.

        Each batch is field-major: ``batch[j]`` lists field ``j`` of every
        example in it. A trailing partial batch is kept only when it holds at
        least ``torch.cuda.device_count()`` examples. Empty data yields no
        batches instead of raising IndexError.
        """
        batch_size = self.batch_size
        data = self.data
        number = self.n_data

        batches = []
        if number > 0:
            n_dim = len(data[0])
            n_full = number // batch_size
            for bid in range(n_full):
                batches.append(self._collate(data[bid * batch_size:(bid + 1) * batch_size], n_dim))

            remainder = number - n_full * batch_size
            if remainder > 0 and remainder >= torch.cuda.device_count():
                batches.append(self._collate(data[n_full * batch_size:], n_dim))

        self.n_batch = len(batches)
        self.batch = batches
        self.batch_idx = list(range(self.n_batch))

    def load(self):
        """Hook for subclasses; expected to return ``(n_data, data)``."""
        pass

    def after_load(self):
        """Sort/shuffle (outside 'test' mode, per dataset options), batch, and cache to ``<name>.cache``."""
        if (self.mode != "test") and self.sorted:
            self.sort_by_length()
        if (self.mode != "test") and self.shuffled:
            self.shuffle()

        # Generate Batches
        self.log.log('Generating Batches')
        self.gen_batches()

        for_save = {
            "n_data": self.n_data,
            "Data": self.data,
            "n_batch": self.n_batch,
            "Batch": self.batch,
            "Batch_idx": self.batch_idx
        }

        save_to_pkl(self.name + ".cache", for_save)

    def batch_shuffle(self):
        """Shuffle the order in which batches are served."""
        random.shuffle(self.batch_idx)

    def __len__(self):
        # Always the number of batches: __getitem__ indexes batches in every
        # mode, so reporting n_data in 'test' mode overstated the length
        # whenever batch_size > 1 (identical when batch_size == 1).
        return self.n_batch

    def __getitem__(self, index):
        selected = self.batch[self.batch_idx[index]]
        if self.prepare_func is None:
            return selected
        return self.prepare_func(selected)


class Podcasts(MyDataset):
    """Podcast summarization dataset read from a JSON-lines file ``<path>.json``.

    Each line must provide ``episode_uri``, ``input`` (list of text segments)
    and ``output`` (list of summary sentences). Stored examples are
    ``[input_token_lists, output_token_lists, episode_uri]``.
    """

    def __init__(self, name, len_func, tokenizer, config, log, mode='train', prepare_func=None):
        super(Podcasts, self).__init__(name, len_func, config, log, mode, prepare_func)
        self.mini = config.mini
        self.input_limit = config.input_limit
        self.output_limit = config.output_limit
        self.kernel_size = config.kernel_size
        self.stride = config.stride

        self.tokenizer = tokenizer
        self.train_mode = config.mode
        self.threshold = config.threshold

        # Loading Dataset: build from the JSON file, or restore the pickled cache.
        if self.build:
            self.log.log('Building dataset %s from orignial text documents' % self.name)
            self.n_data, self.data = self.load()
            self.after_load()
            self.log.log('Finish Loading dataset %s' % self.name)
        else:
            self.log.log("Loading dataset %s from cached files" % self.name)
            for_load = load_from_pkl(self.name + ".cache")
            self.n_data = for_load["n_data"]
            self.data = for_load["Data"]
            self.n_batch = for_load["n_batch"]
            self.batch = for_load["Batch"]
            self.batch_idx = for_load["Batch_idx"]
            self.log.log('Finish Loading dataset %s' % self.name)

    @staticmethod
    def truncate(inputs, front, rear):
        """Keep the first ``front`` and last ``rear`` items when ``inputs`` is longer than ``front + rear``."""
        if len(inputs) <= front + rear:
            return inputs
        new_inputs = inputs[:front]
        if rear > 0:
            new_inputs += inputs[-rear:]
        return new_inputs

    def load(self):
        """Read, tokenize and filter every line of ``<path>.json``.

        Skips (with a printed message) examples with empty input; outside
        'test' mode also skips empty or very short (< 10 tokens) outputs and
        examples rejected by ``apply_filter``. Outputs are cut to the longest
        sentence prefix whose cumulative length fits ``output_limit`` (at
        least one sentence is kept). With ``mini`` only the first 5000 lines
        are read.

        :return: ``(number_of_examples, examples)``
        """
        data = []
        # ``with`` closes the input file, which the previous version leaked.
        # The "<name>_example.txt" file is (re)created empty as before; nothing is written to it.
        with open(self.path + ".json", "r", encoding='utf-8') as input_file, \
                open(self.name + "_example.txt", "w"):
            for index, line in enumerate(input_file):
                if self.mini and (index >= 5000):
                    break
                data_i = json.loads(line)
                idx = data_i["episode_uri"]
                inputs_ = data_i["input"]
                outputs_ = data_i["output"]

                inputs = [self.tokenizer.encode(seg) for seg in inputs_]
                outputs = [self.tokenizer.encode(seq) for seq in outputs_]

                if len(inputs) == 0:
                    print("Error: Input Empty at", index)
                    continue

                if len(outputs) == 0 and self.mode != "test":
                    print("Error: Output Empty at", index)
                    continue

                lo = sum(len(it) for it in outputs)
                if lo < 10 and self.mode != "test":
                    print("Error: Output is too short (%d < 10)." % lo, index)
                    continue

                # Longest prefix of sentences whose cumulative token count fits the limit.
                output_lengths = np.asarray([len(seq) for seq in outputs]).cumsum()
                t = 0
                while t + 1 < len(outputs) and output_lengths[t + 1] <= self.output_limit:
                    t += 1

                outputs = outputs[:t + 1]

                if self.mode != "test" and apply_filter(inputs, outputs, self.config):
                    print("Filter: Instance has too little overlap.", index)
                    continue

                data.append([inputs, outputs, idx])

        return len(data), data

In [8]:
!pip install packaging==21.3



In [12]:
import torch
import torch.nn as nn


class BinaryClsHead(nn.Module):
    """Two-layer head mapping each input vector to a single logit.

    dropout -> Linear(n_in, n_out) -> tanh -> dropout -> Linear(n_out, 1),
    with the trailing size-1 axis squeezed away.
    """

    def __init__(self, n_in, n_out, dropout):
        super().__init__()
        # Attribute names form the state_dict keys and creation order fixes
        # initialisation; both are kept as-is.
        self.dense = nn.Linear(n_in, n_out)
        self.dropout = nn.Dropout(p=dropout)
        self.out_proj = nn.Linear(n_out, 1)

    def forward(self, x):
        hidden = torch.tanh(self.dense(self.dropout(x)))
        logits = self.out_proj(self.dropout(hidden))
        return logits.squeeze(-1)


class MultiClsHead(nn.Module):
    """Two-layer classification head producing ``n_out`` logits per input vector.

    dropout -> Linear(n_in, n_inner) -> tanh -> dropout -> Linear(n_inner, n_out).
    """

    def __init__(self, n_in, n_inner, n_out, dropout):
        super().__init__()
        # Attribute names form the state_dict keys; keep them stable.
        self.dense = nn.Linear(n_in, n_inner)
        self.dropout = nn.Dropout(p=dropout)
        self.out_proj = nn.Linear(n_inner, n_out)

    def forward(self, x):
        activated = torch.tanh(self.dense(self.dropout(x)))
        return self.out_proj(self.dropout(activated))

In [14]:
import torch


def batch_index_select(x, idx):
    """Gather rows of ``x`` along its second-to-last axis.

    :param x: tensor of shape (*, n, Dim)
    :param idx: integer tensor of shape (*, k) with values in [0, n)
    :return: tensor of shape (*, k, Dim) with out[..., j, :] = x[..., idx[..., j], :]
    """
    feature_dim = x.shape[-1]
    # Repeat each index across the feature axis so gather copies whole rows.
    row_index = idx.unsqueeze(-1).expand(*idx.shape, feature_dim)
    return torch.gather(x, -2, row_index)

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable


class LabelSmoothing(nn.Module):
    """Summed KL divergence against a label-smoothed target distribution.

    Each target row puts ``1 - label_smoothing`` on the gold token, spreads
    ``label_smoothing / (n_vocab - 2)`` over the other entries, zeroes the pad
    column, and is zeroed entirely when the gold token is the pad token.
    """

    def __init__(self, config):
        super(LabelSmoothing, self).__init__()
        # reduction='sum' is the non-deprecated spelling of size_average=False.
        self.crit = nn.KLDivLoss(reduction='sum')
        self.pad_idx = config.pad
        self.confidence = 1.0 - config.label_smoothing
        self.smoothing = config.label_smoothing
        self.size = config.n_vocab

    def forward(self, predicts, target):
        """
        :param predicts: (N, n_vocab) log-probabilities (nn.KLDivLoss expects log-space input).
        :param target: (N,) gold token ids.
        :return: scalar summed KL divergence.
        """
        assert self.size == predicts.size(1)
        with torch.no_grad():
            dist = torch.full_like(predicts, self.smoothing / (self.size - 2))
            dist.scatter_(1, target.unsqueeze(1), self.confidence)
            dist[:, self.pad_idx] = 0
            # Zero every row whose target is padding. The previous
            # nonzero()/dim() check was always true and squeezed a single
            # pad hit into a 0-d index.
            dist.masked_fill_((target == self.pad_idx).unsqueeze(1), 0.0)
        return self.crit(predicts, dist)


class KLDivLoss(nn.Module):
    """Flatten predictions/targets, apply LabelSmoothing, and divide by ``norm``."""

    def __init__(self, config):
        super(KLDivLoss, self).__init__()
        self.crit = LabelSmoothing(config)

    def forward(self, predicts, target, norm=1.0):
        n_classes = predicts.size(-1)
        flat_predicts = predicts.contiguous().view(-1, n_classes)
        flat_target = target.contiguous().view(-1)
        total = self.crit(flat_predicts, flat_target)
        return total / norm


class BCELoss(nn.Module):
    """Summed binary cross-entropy on probabilities, divided by ``norm``."""

    def __init__(self):
        super(BCELoss, self).__init__()
        self.crit = nn.BCELoss(reduction='sum')

    def forward(self, input, target, norm=1.0):
        summed = self.crit(input, target)
        return summed / norm


class BCEWithLogitsLoss(nn.Module):
    """Summed binary cross-entropy on logits, optionally weighted, divided by ``norm``."""

    def __init__(self):
        super(BCEWithLogitsLoss, self).__init__()

    def forward(self, inputs, target, weights=None, norm=1.0):
        """
        :param inputs: logits of any shape (flattened).
        :param target: 0/1 targets with as many elements as ``inputs``.
        :param weights: optional per-element weights; None means unweighted
            (previously None crashed on ``.reshape``).
        :param norm: divisor applied to the summed loss.
        """
        inputs = inputs.reshape(-1)
        target = target.reshape(-1)
        if weights is not None:
            weights = weights.reshape(-1)
        return F.binary_cross_entropy_with_logits(inputs, target, weights, reduction="sum") / norm


class CrossEntropy(nn.Module):
    """Summed cross-entropy over flattened predictions, skipping ``config.ignore_index`` targets."""

    def __init__(self, config):
        super().__init__()
        self.ignore_index = config.ignore_index
        self.crit = nn.CrossEntropyLoss(ignore_index=self.ignore_index, reduction="sum")

    def forward(self, pred, target):
        flat_target = target.reshape(-1)
        # One row of class scores per target element.
        flat_pred = pred.reshape(flat_target.shape[0], -1)
        return self.crit(flat_pred, flat_target)


class AdjacentRegularizer(nn.Module):
    """Penalty comparing cumulative attention mass of adjacent selected target rows.

    For each pair of consecutive rows (t, t+1) picked by ``tgt_indices``, the
    softmax over source positions is accumulated left-to-right and positive
    values of cumsum[t+1] - cumsum[t] are summed (masked by tgt/src masks).
    """
    def __init__(self, config):
        super().__init__()
        # Stored for interface parity; not used inside forward.
        self.ignore_index = config.ignore_index

    def forward(self, pred, tgt_indices, tgt_mask, src_mask):
        """
            pred: bsz, tgt_len, src_len
            tgt_indices: bsz, n
            tgt_mask: bsz, n
            src_mask: bsz, src_len

            Returns a scalar tensor, or the int 0 when fewer than two target
            indices are given.

            NOTE(review): ``diff`` has n - 1 rows while tgt_mask is documented
            with n columns; ``tgt_mask.unsqueeze(2)`` only broadcasts if the
            caller actually passes a (bsz, n - 1) mask. Confirm at call site.
        """
        if tgt_indices.shape[1] < 2:
            return 0
        # Rows of pred named by tgt_indices: (bsz, n, src_len).
        pred_ = batch_index_select(pred, tgt_indices)
        # Masked source positions get a -1e6 offset; assumes src_mask is a
        # numeric 0/1 tensor (subtraction is not supported on bool tensors).
        pred_ += (src_mask.unsqueeze(1) - 1) * 1e6
        prob = torch.softmax(pred_, dim=-1)
        prob_sum = torch.cumsum(prob, dim=-1)
        # Positive where row t+1 has more cumulative mass than row t at a source position.
        diff = prob_sum[:, 1:, :] - prob_sum[:, :-1, :]
        loss = torch.relu(diff) * tgt_mask.unsqueeze(2) * src_mask.unsqueeze(1)
        return loss.sum()


class CrossEntropyWithRegularizer(nn.Module):
    """Summed cross-entropy plus ``config.coef`` times an optional regularizer.

    ``config.regularizer`` selects the regularizer: 0 = none,
    1 = AdjacentRegularizer. Any other value raises ValueError at construction.
    """
    def __init__(self, config):
        super().__init__()
        # Validate up front: previously an unknown value left ``self.regularizer``
        # unset and only failed later, inside ``regu``, with AttributeError.
        if config.regularizer not in (0, 1):
            raise ValueError(
                "Unknown regularizer type: %r (expected 0 or 1)" % (config.regularizer,)
            )
        self.coef = config.coef
        self.crit = CrossEntropy(config)
        if config.regularizer == 0:
            self.regularizer = None
        else:
            self.regularizer = AdjacentRegularizer(config)

    def regu(self, *args):
        """Regularizer value, or 0 when no regularizer is configured."""
        if self.regularizer is None:
            return 0
        return self.regularizer(*args)

    def loss(self, *args):
        """Plain (unregularized) cross-entropy."""
        return self.crit(*args)

    def forward(self, pred, target, tgt_indices, tgt_mask, src_mask, norm_loss=1.0, norm_regu=1.0):
        """Return ``ce / norm_loss + coef * regu / norm_regu``; prints a warning on negative terms."""
        loss = self.crit(pred, target) / norm_loss
        regu = self.regu(pred, tgt_indices, tgt_mask, src_mask) / norm_regu

        if loss < 0:
            print("Negative Loss Error", norm_loss)

        if regu < 0:
            print("Negative Regu Error", norm_regu)
        return loss + self.coef * regu

In [17]:
import math
import torch
import torch.nn as nn



class CrossAttention(nn.Module):
    """
        Attention Score = q^T W k + v^T tanh(W'[q; k] + b)
                       ~= (q^T U) (V^T k) + v^T tanh(W'[q; k] + b)

        ``type_att`` is a bit mask selecting the terms that are summed:
            bit 1 -> linear (additive) term
            bit 2 -> low-rank bilinear term
        so 1 = linear, 2 = bilinear, 3 = linear + bilinear.
        If neither bit is set, forward returns the int 0.
    """

    def __init__(self, config):
        super().__init__()
        self.type_att = config.type_att
        d_q, d_k, d_out = config.d_query, config.d_key, config.d_att

        # Creation order fixes parameter initialisation and state_dict order.
        self.bi_linear_query = nn.Linear(d_q, d_out)
        self.bi_linear_key = nn.Linear(d_k, d_out)

        self.linear_query = nn.Linear(d_q, d_out)
        self.linear_key = nn.Linear(d_k, d_out)
        self.att = nn.Linear(d_out, 1)

    def low_rank_bi_linear(self, query, key):
        """
            Relevancy score via low-rank bilinear attention:
                (q^T U) (V^T k) / sqrt(d_att)

            query: bsz, num_q, dim_q
            key: bsz, num_k, dim_k
            returns: bsz, num_q, num_k
        """
        projected_q = self.bi_linear_query(query)   # bsz, num_q, d_att
        projected_k = self.bi_linear_key(key)       # bsz, num_k, d_att
        scale = math.sqrt(projected_q.shape[-1])
        return torch.matmul(projected_q, projected_k.transpose(-2, -1)) / scale

    def linear(self, query, key):
        """
            Saliency score via additive attention:
                v^T tanh(W_q q + W_k k + b)

            query: bsz, num_q, dim_q
            key: bsz, num_k, dim_k
            returns: bsz, num_q, num_k
        """
        q_cols = self.linear_query(query).permute(0, 2, 1).unsqueeze(-1)   # bsz, d_att, num_q, 1
        k_rows = self.linear_key(key).permute(0, 2, 1).unsqueeze(-2)       # bsz, d_att, 1, num_k
        pairwise = (q_cols + k_rows).permute(0, 2, 3, 1)                   # bsz, num_q, num_k, d_att
        return self.att(torch.tanh(pairwise)).squeeze(-1)

    def forward(self, query, key):
        scores = 0
        if self.type_att & 1:
            scores += self.linear(query, key)
        if self.type_att & 2:
            scores += self.low_rank_bi_linear(query, key)
        return scores


class Retrieval(nn.Module):
    """Extractor + abstracter + switch + cross-attention retrieval model.

    Sub-modules (created in this order):
        extractor: ``Extractor(config)``; produces the chunk salience used by ``retrieve``.
        abstracter: ``Abstracter.from_pretrained(config.model_abs)``; its decoder hidden
            states are the queries of the retrieval attention.
        switch: ``BinaryClsHead`` applied to ``[hidden; next-token embedding]``.
        retrieval: ``CrossAttention(config)``; relevance of each decoder state to each chunk.

    ``self.mode`` (initially ``config.mode``, changed with ``set_mode``) selects which
    sub-model ``forward`` runs.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.mode = config.mode
        self.extractor = Extractor(config)
        self.abstracter = Abstracter.from_pretrained(config.model_abs)
        self.switch = BinaryClsHead(config.d_model * 2, config.d_inner, config.dropout)
        self.retrieval = CrossAttention(config)
        self.epsilon = config.epsilon

    def set_mode(self, mode):
        """
            Freeze every parameter, then unfreeze only the sub-model selected by mode.

            mode: str
                ext: only extractor parameters
                abs: only abstractor parameters
                ret: retrieval parameters
                swh: switch parameters
            Any other value leaves all parameters frozen.
        """
        self.mode = mode
        for para in self.parameters():
            para.requires_grad = False

        if self.mode == "ext":
            for para in self.extractor.parameters():
                para.requires_grad = True
        elif self.mode == "abs":
            for para in self.abstracter.parameters():
                para.requires_grad = True
        elif self.mode == "ret":
            for para in self.retrieval.parameters():
                para.requires_grad = True
        elif self.mode == "swh":
            for para in self.switch.parameters():
                para.requires_grad = True

    def get_mode(self):
        """Return the current mode string (see ``set_mode``)."""
        return self.mode

    def retrieve(
            self,
            # Extractor
            chunk_input_ids=None,
            chunk_hidden=None,
            chunk_attention_mask=None,
            salience=None,
            # Abstracter
            hidden=None,
            src_mask=None,
    ):
        """Log-attention over chunks for every decoder position.

        Adds the ``CrossAttention`` relevance of ``hidden`` (queries) to
        ``chunk_hidden`` (keys) and ``self.epsilon * salience``, pushes masked chunks
        to about -1e6 and applies ``log_softmax`` over the last axis.

        Args:
            chunk_input_ids, chunk_hidden, chunk_attention_mask: extractor inputs, used
                to compute ``salience`` (without gradients) when it is ``None``;
                ``chunk_hidden`` is also the key of the cross-attention.
            salience: precomputed salience, presumably (bsz, n_key) -- TODO confirm.
            hidden: decoder hidden states used as attention queries.
            src_mask: 1 for valid chunks, 0 for masked ones, same shape as ``salience``;
                defaults to all ones.

        Returns:
            log-probabilities over chunks (last axis).
        """
        if salience is None:
            with torch.no_grad():
                # bsz, n_key
                salience = self.extractor(
                    chunk_input_ids=chunk_input_ids,
                    chunk_hidden=chunk_hidden,
                    chunk_attention_mask=chunk_attention_mask
                )
        if src_mask is None:
            src_mask = torch.ones_like(salience)

        # 0 for valid chunks, -1e6 for masked ones; unsqueeze broadcasts over queries.
        src_mask_ = (src_mask.unsqueeze(1)-1) * 1e6
        relevance = self.retrieval(hidden, chunk_hidden)
        attention = torch.log_softmax(relevance + self.epsilon * salience.unsqueeze(1) + src_mask_, dim=-1)
        return attention

    def forward(
            self,
            # Extractor
            chunk_input_ids=None,
            chunk_hidden=None,
            chunk_attention_mask=None,
            # Abstracter
            input_ids=None,
            decoder_input_ids=None,
            encoder_attention_mask=None,
            # Decoding
            encoder_outputs=None,
            past_key_values=None,
            inputs_embeds=None,
            decoder_inputs_embeds=None,
            use_cache=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            # Switch
            decoder_labels=None,
            # Retrieval
            src_mask=None,
            salience=None,
    ):
        """Dispatch on ``self.mode``.

        * ``"ext"``: return the extractor output.
        * ``"abs"``: return the full abstracter output (``hidden_only=False``).
        * ``"ret"``: compute decoder hidden states without gradients, then ``retrieve``.
        * ``"swh"``: compute decoder hidden states without gradients, then run the switch
          on ``[hidden; scaled embedding of decoder_labels]``.
        * any other mode: hidden states are computed and ``None`` is returned.
        """
        if self.mode == "ext":
            return self.extractor(
                chunk_input_ids=chunk_input_ids,
                chunk_hidden=chunk_hidden,
                chunk_attention_mask=chunk_attention_mask
            )
        elif self.mode == "abs":
            return self.abstracter(
                input_ids=input_ids,
                decoder_input_ids=decoder_input_ids,
                encoder_attention_mask=encoder_attention_mask,
                encoder_outputs=encoder_outputs,
                past_key_values=past_key_values,
                inputs_embeds=inputs_embeds,
                decoder_inputs_embeds=decoder_inputs_embeds,
                use_cache=use_cache,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
                return_dict=return_dict,
                hidden_only=False
            )
        else:
            # The abstracter is frozen in these modes: only its hidden states are needed.
            with torch.no_grad():
                hidden = self.abstracter(
                    input_ids=input_ids,
                    decoder_input_ids=decoder_input_ids,
                    encoder_attention_mask=encoder_attention_mask,
                    encoder_outputs=encoder_outputs,
                    past_key_values=past_key_values,
                    inputs_embeds=inputs_embeds,
                    decoder_inputs_embeds=decoder_inputs_embeds,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                    output_hidden_states=output_hidden_states,
                    return_dict=return_dict,
                    hidden_only=True
                )
            if self.mode == 'ret':
                return self.retrieve(
                    chunk_input_ids=chunk_input_ids,
                    chunk_hidden=chunk_hidden,
                    chunk_attention_mask=chunk_attention_mask,
                    salience=salience,
                    hidden=hidden,
                    src_mask=src_mask,
                )
            elif self.mode == "swh":
                next_embedding = self.abstracter.model.shared(decoder_labels) * self.abstracter.model.embed_scale
                return self.switch(torch.cat([hidden, next_embedding], dim=-1))

    def adjust_logits_during_generation(self, logits, cur_len, max_length):
        """Force ``config.cls`` at step 1 and ``config.sep`` (if set) at the last step.

        ``logits`` is modified in place and also returned.
        """
        if cur_len == 1:
            self._force_token_ids_generation(logits, self.config.cls)
        if cur_len == max_length - 1 and self.config.sep is not None:
            self._force_token_ids_generation(logits, self.config.sep)
        return logits

    def _force_token_ids_generation(self, scores, token_ids) -> None:
        """Force one of ``token_ids`` to be generated.

        Sets the score of every other token id in ``[0, config.n_vocab)`` to ``-inf``
        in place.

        Args:
            scores: rank-2 tensor ``[batch_size, vocab_size]``; modified in place.
            token_ids: a single id or an iterable of ids to keep; ids outside
                ``[0, config.n_vocab)`` are ignored.

        Raises:
            AssertionError: if ``scores`` is not rank 2.
        """
        if isinstance(token_ids, int):
            token_ids = [token_ids]
        assert len(scores.shape) == 2, "scores should be of rank 2 with shape: [batch_size, vocab_size]"
        n_vocab = self.config.n_vocab
        # Build the "all ids except token_ids" index with a boolean mask on the scores'
        # own device, instead of a Python scan over the whole vocabulary on every step.
        keep = torch.zeros(n_vocab, dtype=torch.bool, device=scores.device)
        valid_ids = [t for t in token_ids if 0 <= t < n_vocab]
        if valid_ids:
            keep[torch.tensor(valid_ids, dtype=torch.long, device=scores.device)] = True
        all_but_token_ids_mask = torch.arange(n_vocab, device=scores.device)[~keep]
        scores[:, all_but_token_ids_mask] = -float("inf")

In [18]:
import math
from copy import deepcopy as cp

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F


def clones(module, n):
    """Return an ``nn.ModuleList`` of ``n`` independent deep copies of ``module``.

    The copies do not share parameters, but they all start from the same values
    that ``module`` had when this was called.
    """
    layers = nn.ModuleList()
    for _ in range(n):
        layers.append(cp(module))
    return layers


def attention(query, key, value, mask=None, dropout=None):
    """Scaled dot-product attention.

    Args:
        query, key, value: tensors whose last two axes are (length, features);
            leading axes must broadcast for ``torch.matmul``.
        mask: optional tensor broadcastable to the score shape; positions where
            ``mask == 0`` get a score of -6e4 before the softmax.
        dropout: optional callable applied to the attention weights.

    Returns:
        ``(weights @ value, weights)``.
    """
    scale = math.sqrt(query.size(-1))
    scores = torch.matmul(query, key.transpose(-2, -1)) / scale
    if mask is not None:
        # A large finite negative rather than -inf; presumably chosen so it stays
        # representable in fp16 (max ~65504) -- confirm before changing.
        scores = scores.masked_fill(mask == 0, -6e4)
    weights = F.softmax(scores, dim=-1)
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, value), weights


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    ``config`` must provide ``d_model``, ``n_head`` and ``dropout``. ``d_model`` has to be
    divisible by ``n_head``; each head attends over ``d_model // n_head`` features.

    Raises:
        ValueError: at construction if ``d_model`` is not divisible by ``n_head``.
    """

    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()

        # Fail early with a clear message instead of an opaque reshape/Linear shape
        # error on the first forward pass.
        if config.d_model % config.n_head != 0:
            raise ValueError(
                "d_model ({}) must be divisible by n_head ({})".format(config.d_model, config.n_head)
            )
        self.d_model = config.d_model
        self.n_head = config.n_head
        self.d_k = config.d_model // config.n_head

        # Query, key, value and output projections. `clones` deep-copies one layer,
        # so the four projections start from identical weights.
        self.linears = clones(nn.Linear(self.d_model, self.d_model), 4)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query, key, value: (batch, length, d_model) tensors.
            mask: optional mask; a head axis is inserted at dim 1 so one mask is
                shared by every head.

        Returns:
            ``(output, attn)``: output is (batch, query_length, d_model); attn holds
            the per-head attention weights.
        """
        if mask is not None:
            mask = mask.unsqueeze(1)
        batch_size = query.size(0)

        # Project, then split the feature axis into heads: (batch, n_head, length, d_k).
        query, key, value = [
            proj(x).reshape(batch_size, -1, self.n_head, self.d_k).transpose(1, 2)
            for proj, x in zip(self.linears, (query, key, value))
        ]
        x, attn = attention(query, key, value, mask=mask, dropout=self.dropout)
        # Merge the heads back into a single d_model-sized feature axis.
        x = x.transpose(1, 2).contiguous().reshape(batch_size, -1, self.n_head * self.d_k)
        return self.linears[3](x), attn


class PositionwiseFeedForward(nn.Module):
    """Position-wise two-layer MLP: ``w_2(dropout(relu(w_1(x))))``.

    ``config`` must provide ``d_model``, ``d_ff`` and ``dropout``.
    """

    def __init__(self, config):
        super(PositionwiseFeedForward, self).__init__()

        d_model, d_ff = config.d_model, config.d_ff
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x):
        hidden = F.relu(self.w_1(x))
        hidden = self.dropout(hidden)
        return self.w_2(hidden)


class SelfAttentionBlock(nn.Module):
    """Pre-norm residual self-attention sub-layer.

    Computes ``dropout(MHA(norm(x), norm(x), norm(x))) + x``.
    """

    def __init__(self, config):
        super(SelfAttentionBlock, self).__init__()

        self.norm = nn.LayerNorm(config.d_model)
        self.attn = MultiHeadAttention(config)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x, mask):
        """Return ``(updated x, attention weights)``."""
        normed = self.norm(x)
        attended, attn = self.attn(normed, normed, normed, mask)
        return self.dropout(attended) + x, attn


class SourceAttentionBlock(nn.Module):
    """Pre-norm residual cross-attention sub-layer.

    Computes ``dropout(MHA(norm(x), m, m)) + x``. Only the queries ``x`` are
    normalized here; the memory ``m`` is used as given.
    """

    def __init__(self, config):
        super(SourceAttentionBlock, self).__init__()

        self.norm = nn.LayerNorm(config.d_model)
        self.attn = MultiHeadAttention(config)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x, m, mask):
        """Return ``(updated x, attention weights over m)``."""
        queries = self.norm(x)
        attended, attn = self.attn(queries, m, m, mask)
        return self.dropout(attended) + x, attn


class FeedForwardBlock(nn.Module):
    """Pre-norm residual feed-forward sub-layer: ``dropout(FFN(norm(x))) + x``."""

    def __init__(self, config):
        super(FeedForwardBlock, self).__init__()

        self.norm = nn.LayerNorm(config.d_model)
        self.feed_forward = PositionwiseFeedForward(config)
        self.dropout = nn.Dropout(p=config.dropout)

    def forward(self, x):
        update = self.feed_forward(self.norm(x))
        return self.dropout(update) + x


class EncoderBlock(nn.Module):
    """One encoder layer: a self-attention sub-layer followed by a feed-forward sub-layer."""

    def __init__(self, config):
        super(EncoderBlock, self).__init__()
        self.self_attn = SelfAttentionBlock(config)
        self.feed_forward = FeedForwardBlock(config)

    def forward(self, x, mask):
        """Return ``(layer output, self-attention weights)``."""
        attended, attn = self.self_attn(x, mask)
        return self.feed_forward(attended), attn


class DecoderBlock(nn.Module):
    """One decoder layer: self-attention, then attention over the memory, then feed-forward."""

    def __init__(self, config):
        super(DecoderBlock, self).__init__()

        self.self_attn = SelfAttentionBlock(config)
        self.src_attn = SourceAttentionBlock(config)
        self.feed_forward = FeedForwardBlock(config)

    def forward(self, x, m, src_mask, tgt_mask):
        """Return ``(output, source-attention weights, self-attention weights)``."""
        h, attn_tgt = self.self_attn(x, tgt_mask)
        h, attn_src = self.src_attn(h, m, src_mask)
        return self.feed_forward(h), attn_src, attn_tgt


class Encoder(nn.Module):
    """Stack of ``config.n_layers`` encoder layers with a LayerNorm per layer output.

    The un-normalized output of each layer feeds the next layer. The normalized copy
    of every layer's output is collected and returned.
    """

    def __init__(self, config):
        super(Encoder, self).__init__()

        self.layers = clones(EncoderBlock(config), config.n_layers)
        self.norms = clones(nn.LayerNorm(config.d_model), config.n_layers)

    def forward(self, x, mask):
        """Return ``(normalized outputs per layer, attention weights per layer)``."""
        layer_outputs, layer_attns = [], []
        for idx in range(len(self.layers)):
            x, attn = self.layers[idx](x, mask)
            layer_outputs.append(self.norms[idx](x))
            layer_attns.append(attn)
        return layer_outputs, layer_attns


class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal position embeddings looked up by integer position id.

    Row ``pos`` of the table holds ``sin(pos * w_i)`` in the even columns and
    ``cos(pos * w_i)`` in the odd columns, with ``w_i = 10000 ** (-2i / d_model)``.
    ``config`` must provide ``max_len`` and ``d_model``; odd ``d_model`` is supported.
    """

    def __init__(self, config):
        super(PositionalEmbedding, self).__init__()

        p2e = torch.zeros(config.max_len, config.d_model)
        position = torch.arange(0.0, config.max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0.0, config.d_model, 2) * (- math.log(10000.0) / config.d_model))
        p2e[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than frequencies, so only
        # the first d_model // 2 frequencies are used for the cosine half.
        p2e[:, 1::2] = torch.cos(position * div_term[:config.d_model // 2])

        # Registered as a buffer: follows .to()/.cuda() and is saved with the module,
        # but is not a trainable parameter.
        self.register_buffer('p2e', p2e)

    def forward(self, x):
        """Map integer position ids ``x`` (any shape) to embeddings of shape ``x.shape + (d_model,)``."""
        flat = torch.index_select(self.p2e, 0, x.reshape(-1))
        # The table is constant: detach so no gradient ever flows into it
        # (replaces the deprecated Variable(..., requires_grad=False) wrapper).
        return flat.detach().reshape(*x.shape, -1)


class Transformer(nn.Module):
    """Sinusoidal position embeddings added to ``input_emb``, followed by an ``Encoder``."""

    def __init__(self, config):
        super(Transformer, self).__init__()
        self.p2e = PositionalEmbedding(config)
        self.encoder = Encoder(config)

    def forward(self, input_emb, position_ids, attention_mask):
        """Return the encoder output: (normalized outputs per layer, attentions per layer)."""
        positional = self.p2e(position_ids)
        return self.encoder(positional + input_emb, attention_mask)

In [20]:
from copy import deepcopy as cp


def mapping_tokenize(s, t):
    """Align original tokens ``s`` with word-piece tokens ``t``.

    For each token of ``s``, the window ``t[st:ed]`` is grown greedily while the
    lower-cased token still starts with the concatenation of the pieces (``##``
    markers stripped). A ``[pieces, original_token]`` pair is recorded when the
    token spans more than one piece, contains more than one upper-case character,
    or is not the first token.

    Args:
        s: sequence of original (cased) token strings.
        t: sequence of word-piece strings; continuation pieces start with ``##``.

    Returns:
        list of ``[list_of_pieces, original_token]`` pairs (the format ``detokenize`` reads).
    """
    def joined(start, stop):
        # Concatenate t[start:stop] with the '##' continuation markers removed.
        return "".join(piece.replace('##', '') for piece in t[start:stop])

    st = 0
    ed = 0
    mapping = []
    for idx, token in enumerate(s):
        token_ = token.lower()
        prefix = joined(st, ed + 1)
        while token_.startswith(prefix):
            ed += 1
            if ed >= len(t):
                break
            prefix = joined(st, ed + 1)
        n_upper = sum(1 for c in token if c.isupper())
        if (ed - st > 1) or (n_upper > 1) or (idx > 0):
            mapping.append([cp(t[st:ed]), token])
        st = ed
    return mapping


def detokenize(text, mapping):
    """Restore original tokens in ``text`` using a ``mapping_tokenize`` mapping.

    Each ``[pieces, original]`` entry rebuilds the word-piece spelling
    (``' ' + piece`` for a word start, ``piece`` without ``##`` for a continuation)
    and replaces it with ``' ' + original``. The first character of the result is
    upper-cased.

    Args:
        text: decoded text.
        mapping: list of ``[pieces, original]`` pairs, or ``None`` to return ``text`` unchanged.

    Returns:
        the restored string (always a ``str``; an empty input yields ``""``).
    """
    if mapping is None:
        return text
    # Leading space so the first word can match the ' '-prefixed keys.
    text = " " + text
    for one_mapping in mapping:
        pieces, original = one_mapping[0], one_mapping[1]
        keys = "".join(piece.replace('##', '') if piece.startswith('##') else ' ' + piece for piece in pieces)
        text = text.replace(keys, ' ' + original)
    text = text[1:]
    # Slicing keeps an empty result as "" (previously an empty list was returned).
    return text[:1].upper() + text[1:]

In [21]:
def get_n_gram(seq, n):
    """Return the contiguous ``n``-grams of ``seq`` as a list of tuples.

    Built by zipping ``n`` progressively shifted slices of ``seq``. Returns ``[]``
    when ``len(seq) < n`` or ``n <= 0``.
    """
    shifted_views = [seq[offset:] for offset in range(n)]
    return list(zip(*shifted_views))


def do_tricks(preds, source, target, config):
    """Apply decoding constraints to the next-token scores ``preds`` in place.

    * ``config.no_repeat_ngram_size == 1``: ban every token already in ``target``.
    * ``config.no_repeat_ngram_size > 1``: ban tokens that would complete an n-gram
      already present in ``target``.
    * ``len(target) < config.gen_min_len``: ban ``config.sep``.

    Banned ids get ``-inf``; NaN and ``+inf`` scores are then replaced by ``-1e9``.
    ``source`` is unused.

    Returns:
        ``preds`` (the same object, modified).
    """
    n = config.no_repeat_ngram_size
    if n == 1:
        banned = list(set(target))
    else:
        banned = []
    if n > 1 and len(target) >= n:
        # Last n-1 tokens; an n-gram whose first n-1 tokens match them would repeat.
        context = target[1 - n:]
        banned.extend(
            ngram[-1]
            for ngram in get_n_gram(target, n)
            if all(t_token == ng_token for t_token, ng_token in zip(context, ngram))
        )

    if len(target) < config.gen_min_len:
        banned.append(config.sep)

    for token_id in banned:
        preds[token_id] = -float("inf")

    # NaN is the only value not equal to itself.
    preds[preds != preds] = -1e9
    preds[preds == float("inf")] = -1e9

    return preds


def trigram_blocking(preds, target, config):
    """Block tokens that would repeat a trigram already present in ``target``.

    Enabled by ``config.triGramTrick`` and only once ``target`` has more than two
    tokens. Every trigram whose first two tokens equal the last two tokens of
    ``target`` gets its third token's score set to ``-1e9`` in ``preds`` (in place).

    Returns:
        ``preds`` (the same object, possibly modified).
    """
    banned = []
    if config.triGramTrick and len(target) > 2:
        banned = [
            tri_gram[2]
            for tri_gram in get_n_gram(target, 3)
            if (target[-2] == tri_gram[0]) and (target[-1] == tri_gram[1])
        ]

    for token_id in banned:
        preds[token_id] = -1e9

    return preds

In [22]:
import json
import pickle
import shutil

import torch


# IO
def load_from_json(filename):
    """Load and return the JSON document stored in ``filename`` (UTF-8).

    ``strict=False`` lets control characters (e.g. raw newlines) appear inside
    JSON strings. The file is closed even if parsing fails.
    """
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f, strict=False)


def save_to_json(filename, data):
    """Write ``data`` to ``filename`` as UTF-8 JSON indented by 4 spaces.

    The file is closed even if serialization fails.

    Returns:
        ``True`` on success.
    """
    with open(filename, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=4)
    return True


def save_to_pkl(filename, data):
    """Pickle ``data`` into ``filename``, overwriting it. Returns ``None``."""
    with open(filename, mode='wb') as handle:
        pickle.dump(data, handle)


def load_from_pkl(filename):
    """Unpickle and return the object stored in ``filename``.

    SECURITY: unpickling can execute arbitrary code; only load files you trust.
    """
    with open(filename, mode='rb') as handle:
        return pickle.load(handle)


def write_file(filename, massage):
    """Write the string ``massage`` (i.e. the message) to ``filename`` as UTF-8.

    The file is overwritten. Returns ``True``.
    """
    with open(filename, mode='w', encoding='utf-8') as handle:
        handle.write(massage)
    return True


def save_check_point(state, is_best, path='.model', file_name='latest.pth.tar', mode=None):
    """Save ``state`` with ``torch.save`` and, if ``is_best``, keep copies of it.

    Files are written to ``path`` and prefixed with ``mode + '_'`` when ``mode`` is
    not ``None``:

    * ``[mode_]file_name``: always written.
    * ``[mode_]model_best.pth.tar`` and ``[mode_]model_best_epoch_<state['epoch']>.pth.tar``:
      written only when ``is_best``.
    """
    prefix = '' if mode is None else mode + '_'
    name = path + '/' + prefix + file_name
    torch.save(state, name)
    if is_best:
        shutil.copyfile(name, path + '/' + prefix + 'model_best.pth.tar')
        shutil.copyfile(name, path + '/' + prefix + 'model_best_epoch_' + str(state['epoch']) + '.pth.tar')

In [23]:
#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# Created by: Hang Zhang, Rutgers University, Email: zhang.hang@rutgers.edu
# Modified by Thomas Wolf, HuggingFace Inc., Email: thomas@huggingface.co
# Copyright (c) 2017-2018
#
# This source code is licensed under the MIT-style license found in the
# LICENSE file in the root directory of this source tree
#+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

"""Encoding Data Parallel"""
import functools
import threading

import torch
import torch.cuda.comm as comm
from torch.autograd import Variable, Function
from torch.nn.parallel._functions import Broadcast
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.nn.parallel.parallel_apply import get_a_var

# First three characters of the version string ("0.3", "1.1", "2.4", ...). It is only
# compared against "0.3" to skip grad-mode propagation on that legacy release; note
# that e.g. "1.10.0" also yields "1.1", which is harmless for that comparison.
torch_ver = torch.__version__[0:3]

# Public names exported by `from ... import *`.
__all__ = [
    'allreduce',
    'DataParallelModel',
    'DataParallelCriterion',
    'patch_replication_callback',
]


def allreduce(*inputs):
    """Cross-GPU all-reduce as an autograd operation (used for the mean and
    variance in SyncBN).

    ``inputs`` is forwarded unchanged to ``AllReduce.apply``: first the number of
    tensors per device, then the tensors of every device laid out one after another.
    """
    reduce_op = AllReduce.apply
    return reduce_op(*inputs)


class AllReduce(Function):
    """Autograd function that sums tensors across GPUs and broadcasts the sums back.

    ``forward(ctx, num_inputs, *inputs)``: ``inputs`` is a flat list made of groups of
    ``num_inputs`` tensors, one group per device. Slot-wise sums are computed with
    ``comm.reduce_add_coalesced`` on the device of the first group and broadcast to
    every group's device; the outputs come back flattened in the same layout.
    The backward pass applies the same reduce-then-broadcast to the gradients.
    """
    @staticmethod
    def forward(ctx, num_inputs, *inputs):
        # Group size, and the device of each group (taken from its first tensor).
        ctx.num_inputs = num_inputs
        ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)]
        # Split the flat argument list into one chunk of `num_inputs` tensors per device.
        inputs = [inputs[i:i + num_inputs] for i in range(0, len(inputs), num_inputs)]
        # sort before reduce sum
        # NOTE(review): presumably for a deterministic summation order; target_gpus
        # keeps the original order, so results are broadcast back in input order.
        inputs = sorted(inputs, key=lambda i: i[0].get_device())
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        return tuple([t for tensors in outputs for t in tensors])

    @staticmethod
    def backward(ctx, *inputs):
        # Same grouping / reduce / broadcast applied to the incoming gradients (no sort).
        inputs = [i.data for i in inputs]
        inputs = [inputs[i:i + ctx.num_inputs]
                 for i in range(0, len(inputs), ctx.num_inputs)]
        results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0])
        outputs = comm.broadcast_coalesced(results, ctx.target_gpus)
        # Leading None: no gradient for the non-tensor `num_inputs` argument.
        return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors])


class Reduce(Function):
    """Autograd function that sums tensors living on several GPUs into one tensor.

    The sum lands on ``comm.reduce_add``'s default destination device. The backward
    pass broadcasts the incoming gradient back to every source device.
    """

    @staticmethod
    def forward(ctx, *inputs):
        # Remember each input's device (in the given order) for the backward broadcast.
        ctx.target_gpus = [tensor.get_device() for tensor in inputs]
        ordered = sorted(inputs, key=lambda tensor: tensor.get_device())
        return comm.reduce_add(ordered)

    @staticmethod
    def backward(ctx, grad_output):
        return Broadcast.apply(ctx.target_gpus, grad_output)


class DistributedDataParallelModel(DistributedDataParallel):
    """``DistributedDataParallel`` variant whose ``gather`` is a no-op.

    The wrapped module is replicated across the given devices and each replica
    processes its chunk of the batch (split along the batch dimension); during the
    backward pass the gradients of all replicas are summed into the original module.
    Unlike the base class, the per-device outputs are returned as-is rather than
    gathered, so pair it with a compatible criterion such as
    :class:`encoding.parallel.DataParallelCriterion`.

    The batch size should be larger than the number of GPUs, and ideally an integer
    multiple of it, so every chunk has the same size.

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::
        >>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """

    def gather(self, outputs, output_device):
        # Deliberately skip gathering: hand back the per-device outputs unchanged.
        return outputs


class DataParallelModel(DataParallel):
    """``DataParallel`` variant that leaves outputs un-gathered and runs replication callbacks.

    The wrapped module is replicated on each device and each replica handles its chunk
    of the batch (split along the batch dimension); during the backward pass the
    gradients of all replicas are summed into the original module. The per-device
    outputs are returned without gathering, so use a compatible criterion such as
    :class:`encoding.parallel.DataParallelCriterion`. After replication, every
    sub-module's ``__data_parallel_replicate__`` callback (if any) is invoked.

    The batch size should be larger than the number of GPUs, and ideally an integer
    multiple of it, so every chunk has the same size.

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::
        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> y = net(x)
    """

    def gather(self, outputs, output_device):
        # Keep the per-device outputs; DataParallelCriterion consumes them directly.
        return outputs

    def replicate(self, module, device_ids):
        replicas = super().replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas


class DataParallelCriterion(DataParallel):
    """
    Compute a loss on multiple GPUs to balance memory usage.

    The model outputs are expected to be already split per device (as returned by
    :class:`encoding.parallel.DataParallelModel`); the targets are split across the
    specified devices along the batch dimension. The per-device losses are gathered
    with ``DataParallel.gather`` and are not averaged here.

    Reference:
        Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
        Amit Agrawal. "Context Encoding for Semantic Segmentation."
        *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*

    Example::
        >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
        >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
        >>> y = net(x)
        >>> loss = criterion(y, target)
    """

    def forward(self, inputs, *targets, **kwargs):
        # Without devices (e.g. CPU-only), just run the wrapped criterion.
        if not self.device_ids:
            return self.module(inputs, *targets, **kwargs)
        # `inputs` is already split per device; only the targets need scattering.
        scattered_targets, scattered_kwargs = self.scatter(targets, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(inputs, *scattered_targets[0], **scattered_kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        per_device_losses = _criterion_parallel_apply(replicas, inputs, scattered_targets, scattered_kwargs)
        # Losses are only gathered onto output_device; reducing them (e.g. .mean())
        # is left to the caller.
        return self.gather(per_device_losses, self.output_device)


def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
    """Apply each criterion replica to its (input, target) pair, one thread per replica.

    Args:
        modules: criterion replicas, one per device.
        inputs: per-device model outputs (a tensor, or a list/tuple of tensors).
        targets: per-device targets (a tensor, or a list/tuple of tensors).
        kwargs_tup: optional per-device keyword-argument dicts.
        devices: optional per-device CUDA ids; when ``None`` the device of the first
            tensor found in the corresponding input is used.

    Returns:
        list of the replica outputs, in replica order.

    Raises:
        AssertionError: if the argument lengths disagree.
        ValueError: raised by a worker when its replica fails (chained to the original
            exception). With several replicas the original exception is re-raised
            after all threads finish; with one replica the ValueError propagates directly.
    """
    assert len(modules) == len(inputs)
    assert len(targets) == len(inputs)
    if kwargs_tup:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    lock = threading.Lock()
    results = {}
    if torch_ver != "0.3":
        # Grad mode is thread-local, so propagate the caller's setting to the workers.
        grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, target, kwargs, device=None):
        if torch_ver != "0.3":
            torch.set_grad_enabled(grad_enabled)
        if device is None:
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                if not isinstance(target, (list, tuple)):
                    target = (target,)
                # tuple() on both sides: a list input plus a tuple target cannot be
                # concatenated directly.
                output = module(*(tuple(input) + tuple(target)), **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            with lock:
                results[i] = e
            raise ValueError('Exception Detected') from e

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, target,
                                          kwargs, device),)
                   for i, (module, input, target, kwargs, device) in
                   enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # Positional order of _worker is (i, module, input, target, kwargs, device);
        # the target must be passed, otherwise kwargs/device shift into wrong slots.
        _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs


###########################################################################
# Adapted from Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
#
class CallbackContext(object):
    """Empty attribute bag shared by all replicas of one sub-module during replication."""


def execute_replication_callbacks(modules):
    """
    Execute the replication callback `__data_parallel_replicate__` on each module
    created by the original replication.

    The callback is invoked as `__data_parallel_replicate__(ctx, copy_id)`.
    Because all copies are isomorphic, the sub-module at position j of every copy
    receives the same context object (shared among the copies on different devices),
    through which the copies can share information.
    The callback on the master copy (the first copy) is called before the callback
    of any other copy.
    """
    master_copy = modules[0]
    shared_contexts = [CallbackContext() for _ in master_copy.modules()]

    for copy_id, replica in enumerate(modules):
        for position, sub_module in enumerate(replica.modules()):
            if not hasattr(sub_module, '__data_parallel_replicate__'):
                continue
            sub_module.__data_parallel_replicate__(shared_contexts[position], copy_id)


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object so its `replicate` also runs
    the replication callbacks. Useful for a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

    Raises:
        AssertionError: if `data_parallel` is not a `DataParallel` instance.
    """
    assert isinstance(data_parallel, DataParallel)

    original_replicate = data_parallel.replicate

    def new_replicate(module, device_ids):
        replicas = original_replicate(module, device_ids)
        execute_replication_callbacks(replicas)
        return replicas

    # Instance attribute shadows the class method; wraps() keeps name/docs and __wrapped__.
    data_parallel.replicate = functools.wraps(original_replicate)(new_replicate)

In [28]:
!pip install packaging==21.3

Collecting packaging==21.3
  Using cached packaging-21.3-py3-none-any.whl.metadata (15 kB)
Using cached packaging-21.3-py3-none-any.whl (40 kB)
Installing collected packages: packaging
  Attempting uninstall: packaging
    Found existing installation: packaging 22.0
    Uninstalling packaging-22.0:
      Successfully uninstalled packaging-22.0
Successfully installed packaging-21.3


In [30]:
!python run.py --offline --batch_size 1 --ext_type TokenLevelEncoder --window_type token --kernel_size 256 --stride 256

Number of available CUDA devices: 1
Using CUDA device 0
get_data test
         3.4396 : Building dataset test from orignial text documents
        58.0773 : Generating Batches
        58.3021 : Finish Loading dataset test
        58.3103 : There are 1027 batches in valid data
1 (1, 54, 1024) ['spotify:episode:74t5WREXUbhEKNI89CNSkL'] 2.1798017024993896
2 (1, 75, 1024) ['spotify:episode:5fG4VlWnWwzAt6mSs0H7lY'] 2.7420132160186768
3 (1, 54, 1024) ['spotify:episode:5hvOWPoB0j6HMrSVAMtJLV'] 3.1403274536132812
4 (1, 43, 1024) ['spotify:episode:7JG3lLnRoDdOxuqjf14ZkM'] 3.4496114253997803
5 (1, 44, 1024) ['spotify:episode:2WQ1GcC6J0k7qsO8Vvf2be'] 3.7607624530792236
6 (1, 53, 1024) ['spotify:episode:3kkhUQJ9DXYs6aSdDmPp2V'] 4.1623430252075195
7 (1, 46, 1024) ['spotify:episode:4fJ6Y6IpljKy8FT8DZHx1L'] 4.534748554229736
8 (1, 59, 1024) ['spotify:episode:5xBPWxqVCocdBgybmHjr5V'] 5.022202491760254
9 (1, 50, 1024) ['spotify:episode:0Sidld7sRx7bCpxEI9bdSs'] 5.452049016952515
10 (1, 25, 1024) ['spoti

In [31]:
!python run.py --do_test --window_type token --kernel_size 256 --stride 256

^C
