In [None]:
# New code.

from datetime import datetime
import logging
import math
import os
import pickle
import random
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp


class bcolors:
    """ANSI escape sequences used to colour terminal output.

    Adapted from
    https://stackoverflow.com/questions/287871/how-to-print-colored-text-in-terminal-in-python
    """
    # '\033' is the ESC character; 'ESC[<n>m' selects a display attribute.
    DEFAULT = '\033[0m'  # reset all attributes
    RED = '\033[31m'
    GREEN = '\033[32m'
    YELLOW = '\033[33m'
    CYAN = '\033[36m'

    # Colour used for each log level.
    DEBUG, INFO, WARNING = CYAN, GREEN, YELLOW
    ERROR = CRITICAL = RED


def getLogger(logname):
    """Create a ``PervasiveLogger`` called *logname*.

    Mirrors the ``logging.getLogger`` entry point, but a new logger object is
    constructed on every call (there is no registry).
    """
    logger = PervasiveLogger(logname)
    return logger


class PervasiveLogger(object):
    """
    Minimal logger that prints to the terminal and optionally appends to a file.

    The usual logging library is not able to print to the terminal, because
    it is run in a forked process. This logger fixes that.

    Messages use ``%``-style interpolation like the standard ``logging``
    module, e.g. ``logger.info('loss=%.3f', loss)``.

    TODO: Derive it from a class in the Python logger library.
    """
    def __init__(self, logname):
        self.logname = logname
        self.logfile = None
        # Process rank taken from the RANK environment variable, if set.
        self.rank = os.getenv('RANK', None)

    def write(self, msg, level, color=None, *args):
        """
        Format and emit one log record.

        :param msg: message, optionally a ``%``-format string.
        :param level: short level tag (e.g. ``'INFO'``); currently not part of
            the emitted text.
        :param color: ANSI colour prefix for the terminal, or None.
        :param args: values interpolated into *msg* with ``%``.
        """
        ts = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        # Interpolate only when arguments are given (as ``logging`` does), so a
        # literal '%' in a plain message does not raise.
        if args:
            msg = msg % args
        prank = f'P{self.rank}' if self.rank else ''
        plain_msg = f'{ts} {prank} {msg}'
        if color:
            full_msg = f'{ts} {prank} {color}{msg}{bcolors.DEFAULT}'
        else:
            full_msg = plain_msg
        # Flush immediately: output from forked workers is otherwise buffered.
        print(full_msg, flush=True)
        if self.logfile:
            # One record per line, without terminal escape codes.
            with open(self.logfile, 'a') as f:
                f.write(plain_msg + '\n')

    def set_logfile(self, logfile):
        """Also append every record to *logfile* (None disables file output)."""
        self.logfile = logfile

    def critical(self, msg, *args):
        self.write(msg, 'CRIT', bcolors.CRITICAL, *args)

    def warning(self, msg, *args):
        self.write(msg, 'WARN', bcolors.WARNING, *args)

    def info(self, msg, *args):
        self.write(msg, 'INFO', bcolors.INFO, *args)

    def error(self, msg, *args):
        self.write(msg, 'ERROR', bcolors.ERROR, *args)

    def debug(self, msg, *args):
        self.write(msg, 'DEBUG', bcolors.DEBUG, *args)

In [None]:
# New code.

import fire
import yaml


class PervasiveApp(object):
    """Command-line entry points, exposed as sub-commands through ``fire``."""

    def train(self, config, gpu_ids=None, verbose=True):
        """
        Launch one training process per GPU.

        :param config: path to a YAML config. If it sets ``default_config``,
            that file is loaded first and the values in *config* override it.
        :param gpu_ids: GPU id or list/tuple of ids; falls back to the
            config's ``gpu_ids`` entry when not given.
        :param verbose: currently unused.
        :raises ValueError: if ``gpu_ids`` or ``model_name`` is missing.
        """
        import torch.multiprocessing as mp

        # safe_load: plain ``yaml.load`` can construct arbitrary objects and
        # requires an explicit Loader in PyYAML >= 6.
        with open(config, 'r') as f:
            params = yaml.safe_load(f) or {}

        default_config = params.get('default_config', None)
        if default_config:
            with open(default_config, 'r') as f:
                defaults = yaml.safe_load(f) or {}
            # The specific config must take precedence over shared defaults.
            defaults.update(params)
            params = defaults

        # ``is None`` so that an explicit GPU 0 from the CLI is honoured.
        if gpu_ids is None:
            if 'gpu_ids' not in params:
                raise ValueError('Expected parameter "gpu_ids" not supplied.')
            gpu_ids = params['gpu_ids']
        # fire may hand over an int or a tuple; normalise to a list.
        if isinstance(gpu_ids, (list, tuple)):
            gpu_ids = list(gpu_ids)
        else:
            gpu_ids = [gpu_ids]

        if 'model_name' not in params:
            raise ValueError('Expected parameter "model_name" not supplied.')
        project_dir = os.path.dirname(os.path.abspath(__file__))
        # NOTE(review): these paths are computed but not passed on yet.
        events_path = os.path.join(project_dir, 'events', params['model_name'])
        save_path = os.path.join(project_dir, 'save', params['model_name'])

        logger = PervasiveLogger(params['model_name'])
        logger.info('Distributing to GPUs: %s',
                    ', '.join(str(gpu) for gpu in gpu_ids))

        os.environ['MASTER_ADDR'] = '127.0.0.1'
        os.environ['MASTER_PORT'] = '3892'
        # NOTE(review): train_worker is not defined in the visible code.
        mp.spawn(
            train_worker, args=(gpu_ids, params), nprocs=len(gpu_ids), join=True)
        logger.info(f'All {len(gpu_ids)} training processes joined. Shutting down.')


if __name__ == "__main__":
    # Expose PervasiveApp's methods (e.g. ``train``) as CLI sub-commands.
    fire.Fire(PervasiveApp)

In [None]:
#
# Function from https://github.com/elbayadm/attn2d/blob/master/nmt/models/aggregator.py
#

def truncated_max(tensor, src_lengths):
    """
    Max-pooling up to each sample's effective length.

    input size: N, d, Tt, Ts
    src_lengths : N,
    output size: N, d, Tt -- sample ``n`` is reduced over only the first
    ``src_lengths[n]`` entries of its last axis.
    """
    pooled = []
    for n, sample in enumerate(tensor):
        best = sample[:, :, :src_lengths[n]].max(dim=2).values
        pooled.append(best.unsqueeze(0))
    return torch.cat(pooled, dim=0)


def truncated_mean(tensor, src_lengths):
    """
    Average-pooling up to each sample's effective length.

    input size: N, d, Tt, Ts
    src_lengths : N,
    The mean over the first ``src_lengths[n]`` entries of the last axis is
    rescaled by ``sqrt(src_lengths[n])``.
    """
    pooled = []
    for n, sample in enumerate(tensor):
        length = src_lengths[n]
        scaled_mean = sample[:, :, :length].mean(dim=2) * math.sqrt(length)
        pooled.append(scaled_mean.unsqueeze(0))
    return torch.cat(pooled, dim=0)


def average_code(tensor, src_lengths=None):
    """Mean-pool over dim 3; *src_lengths* is accepted for API parity and ignored."""
    return torch.mean(tensor, dim=3)


def max_code(tensor, src_lengths=None):
    """Max-pool over dim 3; *src_lengths* is accepted for API parity and ignored."""
    values, _ = tensor.max(dim=3)
    return values


class Aggregator(nn.Module):
    """
    Pools features over the last axis, then optionally projects them.

    Implements max pool layer, etc.

    Adapted from https://github.com/elbayadm/attn2d/blob/master/nmt/models/aggregator.py
    """
    def __init__(self, input_channls, mode='max', output_channels=None, params=None):
        """
        :param input_channls: number of channels of the incoming features.
        :param mode: 'mean', 'max', 'truncated-max' or 'truncated-mean'.
            A ``'mode'`` entry in *params* takes precedence.
        :param output_channels: if given, a linear layer maps the pooled
            features to this size (e.g. to tie with the target embeddings).
        :param params: optional options dict; only ``'mode'`` is read.
        :raises ValueError: for an unknown mode.
        """
        nn.Module.__init__(self)
        params = {} if params is None else params
        mode = params.get('mode', mode)
        projections = {
            'mean': average_code,
            'max': max_code,
            'truncated-max': truncated_max,
            'truncated-mean': truncated_mean,
        }
        if mode not in projections:
            raise ValueError('Unknown mode %s' % mode)
        self.project = projections[mode]
        self.output_channels = input_channls
        # Map the final output to the requested dimension
        # for when tying the embeddings with the final projection layer.
        if output_channels is not None:
            self.lin = nn.Linear(input_channls, output_channels)
            self.output_channels = output_channels
        else:
            self.lin = None

    def forward(self, tensor, src_lengths):
        """Pool *tensor*, move channels last, and apply the optional projection."""
        proj = self.project(tensor, src_lengths)
        proj = proj.permute(0, 2, 1)
        if self.lin is not None:
            proj = self.lin(proj)
        return proj

In [None]:
class MaskedConv2d(nn.Conv2d):
    """
    Masked (autoregressive) conv2d.

    Kernel rows after the centre row (``kH // 2``) are zeroed, so for odd
    kernel sizes an output never sees input rows after its own along dim 2.

    From https://github.com/elbayadm/attn2d/blob/master/nmt/models/conv2d.py
    """
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, dilation=1,
                 groups=1, bias=False):
        super().__init__(
            in_channels, out_channels, kernel_size,
            padding=(dilation * (kernel_size - 1)) // 2,
            groups=groups, dilation=dilation, bias=bias)
        # Binary mask with the weight's shape, registered as a buffer.
        self.register_buffer('mask', torch.ones_like(self.weight.data))
        kernel_rows = self.weight.size(2)
        if kernel_rows > 1:
            self.mask[:, :, kernel_rows // 2 + 1:, :] = 0
        # Window of recent inputs used by ``update``.
        self.incremental_state = torch.zeros(1, 1, 1, 1)

    def forward(self, x, *args):
        # Re-apply the mask each call so masked taps stay zero after updates.
        self.weight.data.mul_(self.mask)
        return super().forward(x)

    def update(self, x):
        """
        Incremental forward: keep a window of at least ``kH // 2 + 1`` rows.

        Until the cached state is that tall, run on the full *x*; afterwards
        shift the cache by one row, append x's last row, and run on the cache.
        """
        needed_rows = self.weight.size(2) // 2 + 1
        state = self.incremental_state
        if state.size(2) >= needed_rows:
            state[:, :, :-1, :] = state[:, :, 1:, :].clone()
            state[:, :, -1:, :] = x[:, :, -1:, :]
            result = self.forward(state)
            self.incremental_state = state.clone()
        else:
            result = self.forward(x)
            self.incremental_state = x.clone()
        return result

In [None]:
def _bn_function_factory(norm, relu, conv):
    """
    From https://github.com/elbayadm/attn2d/blob/master/nmt/models/efficient_densenet.py
    """
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output

    return bn_function


def init_weights(model: nn.Module):
    """
    Re-initialize every Conv2d, Linear and BatchNorm2d inside *model*.

    Conv/Linear weights get Kaiming-normal (fan_out, ReLU) and zero bias;
    BatchNorm2d gets weight 1 and bias 0.

    Cf. https://github.com/pytorch/vision/blob/master/torchvision/models/vgg.py#L46-L59
    """
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
            continue
        if isinstance(module, nn.BatchNorm2d):
            nn.init.constant_(module.weight, 1)
            nn.init.constant_(module.bias, 0)


class DenseLayer(nn.Module):
    """
    Layer of a DenseNet.

    BN > ReLU > 1x1 conv (bottleneck of ``bn_size * growth_rate`` channels)
    > BN > ReLU > MaskedConv2d producing ``growth_rate`` new channels.

    Adapted from https://github.com/elbayadm/attn2d/blob/master/nmt/models/efficient_densenet.py
    """
    def __init__(self,
                 input_size,
                 growth_rate,
                 kernel_size=3,
                 bn_size=4,
                 dropout=0,
                 efficient=False,
                 bias=False):
        """
        :param input_size: channels of the concatenated incoming features.
        :param growth_rate: channels produced by this layer.
        :param kernel_size: kernel size of the masked convolution.
        :param bn_size: bottleneck width multiplier.
        :param dropout: dropout probability applied to the new features.
        :param efficient: checkpoint the bottleneck (recompute in backward).
        :param bias: whether the masked convolution has a bias term.
        """
        super(DenseLayer, self).__init__()
        self.kernel_size = kernel_size
        self.dropout = dropout
        self.efficient = efficient

        self.add_module('norm1', nn.BatchNorm2d(input_size))
        self.add_module('relu1', nn.ReLU(inplace=True))
        self.add_module('conv1', nn.Conv2d(
            input_size,
            bn_size * growth_rate,
            kernel_size=1,
            bias=False))
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module('relu2', nn.ReLU(inplace=True))
        self.add_module('conv2', MaskedConv2d(
            bn_size * growth_rate,
            growth_rate,
            kernel_size=kernel_size,
            bias=bias))

    def forward(self, *prev_features):
        """Return only the new features computed from all *prev_features*."""
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        if self.efficient and any(prev_feature.requires_grad
                                  for prev_feature in prev_features):
            # Wins decreased memory at cost of extra computation.
            # Does not compute intermediate values, but recomputes them in backward pass.
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
        if self.dropout > 0:
            new_features = F.dropout(new_features, p=self.dropout,
                                     training=self.training)
        return new_features

    def reset_buffers(self):
        """Clear the masked convolution's incremental state."""
        self.conv2.incremental_state = torch.zeros(1, 1, 1, 1)

    def update(self, x):
        """
        Incremental step: keep only the last ``kernel_size // 2 + 1`` rows
        of *x* and return them concatenated (dim 1) with the new features.
        """
        maxh = self.kernel_size // 2 + 1
        if x.size(2) > maxh:
            x = x[:, :, -maxh:, :].contiguous()
        res = x
        x = self.conv1(self.relu1(self.norm1(x)))
        x = self.conv2.update(self.relu2(self.norm2(x)))
        return torch.cat([res, x], 1)


class DenseBlock(nn.Module):
    """
    Block of layers in a DenseNet.

    Layer ``i`` receives the block input concatenated with the outputs of all
    previous layers (``input_size + i * growth_rate`` channels); the block
    returns ``input_size + num_layers * growth_rate`` channels.

    Adapted from https://github.com/elbayadm/attn2d/blob/master/nmt/models/efficient_densenet.py
    """
    def __init__(self, num_layers, input_size, bn_size, growth_rate,
                 dropout, efficient=False, kernel_size=3, kernels=None):
        """
        :param kernel_size: kernel size used by every layer.
        :param kernels: optional per-layer kernel sizes; overrides
            *kernel_size* when given.
        :raises ValueError: if *kernels* has fewer than *num_layers* entries.
        """
        super(DenseBlock, self).__init__()
        if kernels is None:
            kernels = [kernel_size] * num_layers
        elif len(kernels) < num_layers:
            raise ValueError(
                f'Expected {num_layers} kernel sizes, got {len(kernels)}.')
        for i in range(num_layers):
            layer = DenseLayer(
                input_size + i * growth_rate,
                growth_rate,
                kernels[i],
                bn_size,
                dropout,
                efficient=efficient
            )
            self.add_module(f'denselayer{i + 1}', layer)

    def forward(self, init_features):
        features = [init_features]
        for layer in self.children():
            features.append(layer(*features))
        return torch.cat(features, 1)

    def update(self, x):
        # Each DenseLayer.update already returns its input concatenated with
        # its new features, so the result is threaded through directly.
        for layer in self.children():
            x = layer.update(x)
        return x

    def reset_buffers(self):
        for layer in self.children():
            layer.reset_buffers()


class Transition(nn.Sequential):
    """
    Transition layer between dense blocks to reduce the number of channels.

    BN > ReLU > Conv(k=1)

    From https://github.com/elbayadm/attn2d/blob/master/nmt/models/efficient_densenet.py
    """
    def __init__(self, input_size, output_size):
        super(Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(input_size))
        self.add_module('relu', nn.ReLU(inplace=True))
        conv = nn.Conv2d(
                input_size,
                output_size,
                kernel_size=1,
                bias=False)
        self.add_module('conv', conv)

    def forward(self, x, *args):
        # Extra arguments are accepted and ignored.
        return super(Transition, self).forward(x)


class DenseNet(nn.Module):
    """
    Masked DenseNet over 2D feature maps.

    Optional 1x1 conv dividing the input channels by ``divide_channels``;
    then, per entry of ``block_layers``, a DenseBlock followed by a Transition
    halving the channels; finally BatchNorm + ReLU. With ``efficient=True``
    the layer bottlenecks are checkpointed: much more memory efficient, but
    slower.

    Adapted from https://github.com/elbayadm/attn2d/blob/master/nmt/models/efficient_densenet.py
    """
    def __init__(self, input_size, block_layers, bn_size=4, kernel_size=3,
                 dropout=0.2, growth_rate=32, divide_channels=2,
                 efficient=False):
        super(DenseNet, self).__init__()

        self.efficient = efficient

        self.model = nn.Sequential()
        num_features = input_size
        if divide_channels > 1:
            trans = nn.Conv2d(num_features, num_features // divide_channels, 1)
            torch.nn.init.xavier_normal_(trans.weight)
            self.model.add_module('initial_transition', trans)
            num_features = num_features // divide_channels

        for i, num_layers in enumerate(block_layers):
            block = DenseBlock(
                num_layers=num_layers,
                input_size=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                dropout=dropout,
                efficient=efficient,
                kernel_size=kernel_size,
            )
            self.model.add_module(f'denseblock{i + 1}', block)
            num_features = num_features + num_layers * growth_rate
            trans = Transition(
                input_size=num_features,
                output_size=num_features // 2)
            self.model.add_module(f'transition{i + 1}', trans)
            num_features = num_features // 2

        self.output_channels = num_features
        self.model.add_module('norm_final', nn.BatchNorm2d(num_features))
        self.model.add_module('relu_last', nn.ReLU(inplace=True))

    def forward(self, x):
        return self.model(x.contiguous())

    def update(self, x):
        """Incremental forward: DenseBlocks use ``update``, other layers run normally."""
        x = x.contiguous()
        for layer in self.model.children():
            if isinstance(layer, DenseBlock):
                x = layer.update(x)
            else:
                x = layer(x)
        return x

    def reset_buffers(self):
        for layer in self.model.children():
            if isinstance(layer, DenseBlock):
                layer.reset_buffers()

In [None]:
import math
import torch
import torch.nn.functional as F


class Pervasive(nn.Module):
    """
    Pervasive Attention Network.

    Every target position is paired with every source position: the two
    embeddings are concatenated into an (N, C, Tt, Ts) grid, processed by a
    masked DenseNet, pooled over the source axis by an Aggregator and
    projected onto the target vocabulary.

    Based on https://github.com/elbayadm/attn2d/blob/master/nmt/models/pervasive.py
    """

    PAD = 0

    def __init__(self, name, net_type, Ts, Tt, special_tokens,
                 enc_input_dim=128, enc_dropout=0.2,
                 dec_input_dim=128, dec_dropout=0.2,
                 prediction_dropout=0.2,
                 src_vocab_size=None, tgt_vocab_size=None,
                 net_params=None, aggregator_params=None):
        """
        :param name: logger name.
        :param net_type: 'densenet' or 'efficient-densenet'.
        :param Ts, Tt: currently unused.
        :param special_tokens: dict with at least 'PAD', 'BOS' and 'EOS' ids.
        :param src_vocab_size, tgt_vocab_size: vocabulary sizes (required).
        :param net_params: keyword arguments for DenseNet; must include
            'block_layers'.
        :param aggregator_params: options dict for the Aggregator (e.g. 'mode').
        :raises ValueError: missing vocabulary sizes or unknown net_type.
        :raises NotImplementedError: for 'log-densenet'.
        """
        nn.Module.__init__(self)
        if src_vocab_size is None or tgt_vocab_size is None:
            raise ValueError('src_vocab_size and tgt_vocab_size are required.')
        self.logger = logging.getLogger(name)
        self.src_vocab_size = src_vocab_size
        self.tgt_vocab_size = tgt_vocab_size
        self.enc_dropout = enc_dropout
        self.dec_dropout = dec_dropout
        self.padding_idx = special_tokens['PAD']
        self.bos_token = special_tokens['BOS']
        self.eos_token = special_tokens['EOS']
        self.src_embedding = nn.Embedding(
            self.src_vocab_size, enc_input_dim, padding_idx=self.padding_idx)
        self.tgt_embedding = nn.Embedding(
            self.tgt_vocab_size, dec_input_dim, padding_idx=self.padding_idx)

        self.input_channels = enc_input_dim + dec_input_dim
        self.logger.info('Model input channels: %d', self.input_channels)

        net_kwargs = dict(net_params or {})
        if net_type == "densenet":
            pass
        elif net_type == "efficient-densenet":
            net_kwargs['efficient'] = True
        elif net_type == "log-densenet":
            raise NotImplementedError('Log DenseNet not implemented.')
        else:
            raise ValueError(f'Unknown network type {net_type}.')
        # Fallbacks mirror DenseNet's own defaults (kernel_size=3, divide_channels=2).
        self.kernel_size = net_kwargs.get('kernel_size', 3)
        division_factor = net_kwargs.get('divide_channels', 2)
        if division_factor > 1:
            self.logger.info('Reducing the input channels by a factor of %d.',
                             division_factor)
        self.net = DenseNet(self.input_channels, **net_kwargs)

        # Project the pooled features to dec_input_dim so that the prediction
        # layer can share its weight with the target embedding.
        self.aggregator = Aggregator(self.net.output_channels,
                                     mode='max',
                                     output_channels=dec_input_dim,
                                     params=aggregator_params or {})
        self.final_output_channels = self.aggregator.output_channels

        self.prediction_dropout = nn.Dropout(prediction_dropout)
        self.logger.info('Output channels: %d', self.final_output_channels)
        self.prediction = nn.Linear(self.final_output_channels,
                                    self.tgt_vocab_size)
        self.logger.warning('Tying the decoder weights.')
        self.prediction.weight = self.tgt_embedding.weight

    def init_weights(self):
        """Re-initialize all parameters (embeddings last, see below)."""
        init_weights(self.net)
        init_weights(self.aggregator)
        init_weights(self.prediction)
        # Embeddings are initialized last because the prediction weight is
        # tied to the target embedding. Tensorflow default embedding
        # initialization (except they resample instead of clipping).
        with torch.no_grad():
            for embedding, vocab_size in ((self.src_embedding, self.src_vocab_size),
                                          (self.tgt_embedding, self.tgt_vocab_size)):
                std = 1 / math.sqrt(vocab_size)
                nn.init.normal_(embedding.weight, 0, std)
                embedding.weight.clamp_(min=-2 * std, max=2 * std)
                if embedding.padding_idx is not None:
                    embedding.weight[embedding.padding_idx].fill_(0)

    @staticmethod
    def _unpack(data):
        """Return (token ids, lengths or None) from a tensor or a {'labels', 'lengths'} dict."""
        if isinstance(data, dict):
            return data['labels'], data.get('lengths')
        return data, None

    def forward(self, src_data, tgt_data, src_lengths=None):
        """
        :param src_data: source token ids (N, Ts), or a dict with 'labels'
            and optionally 'lengths'.
        :param tgt_data: target token ids (N, Tt), or a dict with 'labels'.
        :param src_lengths: per-sample source lengths; taken from *src_data*
            or, failing that, counted from non-PAD tokens.
        :returns: log-probabilities of shape (N, Tt, tgt_vocab_size).
        """
        src_labels, data_lengths = self._unpack(src_data)
        tgt_labels, _ = self._unpack(tgt_data)
        if src_lengths is None:
            src_lengths = data_lengths
        if src_lengths is None:
            # Assumes PAD only appears as right-padding -- TODO confirm.
            src_lengths = (src_labels != self.padding_idx).sum(dim=1)

        src_emb = F.dropout(self.src_embedding(src_labels),
                            p=self.enc_dropout, training=self.training)
        tgt_emb = F.dropout(self.tgt_embedding(tgt_labels),
                            p=self.dec_dropout, training=self.training)
        Ts = src_emb.size(1)  # Source sequence length.
        Tt = tgt_emb.size(1)  # Target sequence length.
        src_grid = src_emb.unsqueeze(1).expand(-1, Tt, -1, -1)
        tgt_grid = tgt_emb.unsqueeze(2).expand(-1, -1, Ts, -1)

        X = torch.cat((src_grid, tgt_grid), dim=3)  # N, Tt, Ts, C
        X = X.permute(0, 3, 1, 2)  # N, C, Tt, Ts
        X = self.net(X)
        X = self.aggregator(X, src_lengths)

        logits = F.log_softmax(
            self.prediction(self.prediction_dropout(X)), dim=2)
        return logits

    def update(self, X, src_lengths=None):
        """Incremental pass on an already-built (N, Tt, Ts, C) grid; returns pooled features."""
        X = X.permute(0, 3, 1, 2)
        X = self.net.update(X)
        X = self.aggregator(X, src_lengths)
        return X

In [None]:
#
# From https://github.com/elbayadm/attn2d/blob/master/nmt/utils/utils.py
#

def pload(path):
    """
    Pickle load: return the object stored at *path*.

    ``encoding='iso-8859-1'`` only affects 8-bit strings pickled by Python 2.
    WARNING: unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f, encoding='iso-8859-1')


def pdump(obj, path):
    """
    Pickle dump: write *obj* to *path* (overwriting it) with the highest
    pickle protocol available.
    """
    with open(path, 'wb') as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)


def set_seed(seed):
    """
    Set seed for reproducibility.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).
    It also makes cuDNN pick deterministic algorithms and disables its
    benchmark autotuner, whose kernel choice can differ between runs.
    """
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.cuda.manual_seed_all(seed)
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)

In [None]:
def evaluate_val_loss(job_name, trainer, src_loader, trg_loader, eval_kwargs):
    """
    Evaluate model: average validation losses of ``trainer.model``.

    From https://github.com/elbayadm/attn2d/blob/master/nmt/models/evaluate.py

    :param job_name: name of the logger to report to.
    :param trainer: object exposing ``model`` and ``criterion``.
    :param src_loader, trg_loader: source / target data loaders.
    :param eval_kwargs: options ('batch_size', 'max_samples', 'split'); the
        target loader's BOS/EOS/PAD/UNK ids are written back into it.
    :returns: ``(mean ml loss, mean final loss)`` over the evaluated batches.
    """
    batch_size = eval_kwargs.get('batch_size', 1)
    max_samples = eval_kwargs.get('max_samples', -1)
    split = eval_kwargs.get('split', 'val')
    eval_kwargs['BOS'] = trg_loader.bos
    eval_kwargs['EOS'] = trg_loader.eos
    eval_kwargs['PAD'] = trg_loader.pad
    eval_kwargs['UNK'] = trg_loader.unk
    logger = logging.getLogger(job_name)

    # Switch to evaluation mode
    model = trainer.model
    crit = trainer.criterion
    model.eval()
    src_loader.reset_iterator(split)
    trg_loader.reset_iterator(split)
    n = 0
    loss_sum = 0
    ml_loss_sum = 0
    loss_evals = 0
    start = time.time()
    # Evaluation only: skip building autograd graphs (saves memory and time).
    with torch.no_grad():
        while True:
            data_src, order = src_loader.get_src_batch(split, batch_size)
            data_trg = trg_loader.get_trg_batch(split, order, batch_size)
            n += batch_size
            if model.version == 'seq2seq':
                source = model.map(model.encoder(data_src))
                if crit.version == "seq":
                    losses, _ = crit(model, source, data_trg)
                else:  # ML & Token-level: init and forward the decoder combined.
                    decoder_logit = model.decoder(source, data_trg)
                    losses, _ = crit(decoder_logit, data_trg['out_labels'])
            else:
                losses, _ = crit(model(data_src, data_trg), data_trg['out_labels'])

            loss_sum += losses['final'].item()
            ml_loss_sum += losses['ml'].item()
            loss_evals += 1
            ix1 = data_src['bounds']['it_max'] if max_samples == -1 else max_samples
            if data_src['bounds']['wrapped'] or n >= ix1:
                break
    logger.warning('Evaluated %d samples in %.2f s', n, time.time() - start)
    return ml_loss_sum / loss_evals, loss_sum / loss_evals


def evaluate_model(job_name, trainer, src_loader, trg_loader, eval_kwargs):
    """
    Evaluate model: losses, decoded predictions and corpus BLEU.

    From https://github.com/elbayadm/attn2d/blob/master/nmt/models/evaluate.py

    :param job_name: name of the logger to report to.
    :param trainer: object exposing ``model`` and ``criterion``.
    :param src_loader, trg_loader: source / target data loaders.
    :param eval_kwargs: options ('batch_size', 'max_samples', 'split',
        'verbose'); the target loader's BOS/EOS/PAD/UNK ids are written back
        into it before it is forwarded to ``model.sample``.
    :returns: ``(predictions, mean ml loss, mean final loss, bleu_moses)``.
    """
    preds = []
    ground_truths = []
    batch_size = eval_kwargs.get('batch_size', 1)
    max_samples = eval_kwargs.get('max_samples', -1)
    split = eval_kwargs.get('split', 'val')
    verbose = eval_kwargs.get('verbose', 0)
    eval_kwargs['BOS'] = trg_loader.bos
    eval_kwargs['EOS'] = trg_loader.eos
    eval_kwargs['PAD'] = trg_loader.pad
    eval_kwargs['UNK'] = trg_loader.unk
    logger = logging.getLogger(job_name)

    # Make sure to be in evaluation mode
    model = trainer.model
    crit = trainer.criterion
    model.eval()
    src_loader.reset_iterator(split)
    trg_loader.reset_iterator(split)
    n = 0
    loss_sum = 0
    ml_loss_sum = 0
    loss_evals = 0
    start = time.time()
    i = 0
    # Evaluation only: skip building autograd graphs (saves memory and time).
    with torch.no_grad():
        while True:
            i += 1
            batch_start = time.time()
            data_src, order = src_loader.get_src_batch(split, batch_size)
            data_trg = trg_loader.get_trg_batch(split, order, batch_size)
            n += batch_size
            if model.version == 'seq2seq':
                source = model.map(model.encoder(data_src))
                if crit.version == "seq":
                    losses, _ = crit(model, source, data_trg)
                else:  # ML & Token-level: init and forward the decoder combined.
                    decoder_logit = model.decoder(source, data_trg)
                    losses, _ = crit(decoder_logit, data_trg['out_labels'])
                batch_preds, _ = model.sample(source, eval_kwargs)
            else:
                losses, _ = crit(model(data_src, data_trg), data_trg['out_labels'])
                batch_preds, _ = model.sample(data_src, eval_kwargs)

            loss_sum += losses['final'].item()
            ml_loss_sum += losses['ml'].item()
            loss_evals += 1
            if isinstance(batch_preds, list):
                # With beam search the predictions are unpadded, one per sentence.
                sent_preds = [decode_sequence(trg_loader.get_vocab(),
                                              np.array(pred).reshape(1, -1),
                                              eos=trg_loader.eos,
                                              bos=trg_loader.bos)[0]
                              for pred in batch_preds]
            else:
                sent_preds = decode_sequence(trg_loader.get_vocab(), batch_preds,
                                             eos=trg_loader.eos,
                                             bos=trg_loader.bos)
            # Do the same for gold sentences
            sent_source = decode_sequence(src_loader.get_vocab(),
                                          data_src['labels'],
                                          eos=src_loader.eos,
                                          bos=src_loader.bos)
            sent_gold = decode_sequence(trg_loader.get_vocab(),
                                        data_trg['out_labels'],
                                        eos=trg_loader.eos,
                                        bos=trg_loader.bos)
            # Without explicit verbosity, print samples whenever the sample
            # counter n is a multiple of 1000.
            verb = verbose if verbose else not (n % 1000)
            for sl, l, gl in zip(sent_source, sent_preds, sent_gold):
                preds.append(l)
                ground_truths.append(gl)
                if verb:
                    lg.print_sampled(sl, gl, l)
            # One sample per batch, in the same (source, gold, prediction)
            # order as the call above.
            lg.print_sampled(sent_source[0], sent_gold[0], sent_preds[0])
            logger.info('Batch %d done in %.2f s', i, time.time() - batch_start)
            ix1 = data_src['bounds']['it_max'] if max_samples == -1 else max_samples
            if data_src['bounds']['wrapped'] or n >= ix1:
                break
    logger.warning('Evaluated %d samples in %.2f s', len(preds), time.time() - start)
    bleu_moses, _ = corpus_bleu(preds, ground_truths)
    return preds, ml_loss_sum / loss_evals, loss_sum / loss_evals, bleu_moses

In [None]:
import sys
import os.path as osp
import logging
import json
import time
import numpy as np

from math import sqrt, pi, cos, ceil
from tensorboardX import SummaryWriter
import torch
from torch import optim
from torch.optim import lr_scheduler


# From https://github.com/elbayadm/attn2d/blob/master/nmt/optimizer.py
TRACKERS = {
    # An independent empty list for each metric name.
    **{key: [] for key in (
        'train/loss',
        'train/ml_loss',
        'val/loss',
        'val/ml_loss',
        'val/perf/bleu',
        'optim/lr',
        'optim/grad_norm',
        'optim/scheduled_sampling',
        'optim/ntokens',
        'optim/batch_size',
    )},
    # Counters and other bookkeeping, with their starting values.
    'iteration': 0,
    'epoch': 1,
    'batch_offset': 0,
    'update': set(),
    'devices': [],
    'time': [],
}


class InverseSquareRoot(lr_scheduler._LRScheduler):
    """
    Follow the schedule of Vaswani et al. 2017.

    At (1-based) step ``t`` every base learning rate is multiplied by
    ``min(1 / sqrt(t), t / warmup ** 1.5)``: linear warm-up for ``warmup``
    steps, then inverse-square-root decay.

    From https://github.com/elbayadm/attn2d/blob/master/nmt/optimizer.py
    """

    def __init__(self, optimizer, warmup=4000, last_epoch=-1):
        self.warmup = warmup
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch + 1
        decay = 1 / sqrt(step)
        ramp = step / self.warmup ** 1.5
        factor = min(decay, ramp)
        return [lr * factor for lr in self.base_lrs]


class ShiftedCosine(lr_scheduler._LRScheduler):
    """
    Cosine decay with periodic restarts.

    ``T_max`` steps are split into ``cycles`` cycles of
    ``ceil(T_max / cycles)`` steps. Within a cycle each lr follows
    ``base_lr * (cos(pi * phase) + 1) / 2`` with ``phase`` going from 0
    towards 1, then it jumps back to ``base_lr``.

    From https://github.com/elbayadm/attn2d/blob/master/nmt/optimizer.py
    """

    def __init__(self, optimizer, cycles, T_max, last_epoch=-1):
        self.T_max = T_max
        self.cycle_duration = ceil(T_max / cycles)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        position = self.last_epoch % self.cycle_duration
        factor = 0.5 * (cos(pi * position / self.cycle_duration) + 1)
        return [lr * factor for lr in self.base_lrs]


class PlateauCosine(lr_scheduler._LRScheduler):
    """
    Steep decrease then ~ plateau.

    While ``t = last_epoch <= T1`` each group uses
    ``eta1 + (base_lr - eta1) * (1 + cos(pi * t / T1)) / 2``, falling from
    ``base_lr`` at ``t = 0`` to ``eta1`` at ``t = T1``.
    Afterwards every group uses
    ``eta2 + (eta1 - eta2) * (1 + cos(pi * t / T2)) / 2``, a cosine with the
    wider period ``2 * T2`` that does not depend on ``base_lr``.

    From https://github.com/elbayadm/attn2d/blob/master/nmt/optimizer.py
    """

    def __init__(self, optimizer, T1, T2, eta1, eta2, last_epoch=-1):
        self.T1, self.T2 = int(T1), int(T2)
        self.eta1, self.eta2 = float(eta1), float(eta2)
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        t = self.last_epoch
        if t <= self.T1:
            # Very steep phase.
            return [self.eta1 + (base_lr - self.eta1) *
                    (1 + cos(pi * t / self.T1)) / 2
                    for base_lr in self.base_lrs]
        # Wider period: one shared value for every parameter group.
        plateau_lr = (self.eta2 + (self.eta1 - self.eta2) *
                      (1 + cos(pi * t / self.T2)) / 2)
        return [plateau_lr] * len(self.base_lrs)


class _EarlyStoppingPlateau(lr_scheduler.ReduceLROnPlateau):
    """
    ReduceLROnPlateau that tolerates ``mode`` being overwritten.

    ``LRScheduler`` stores the schedule name in ``scheduler.mode``, but
    ``ReduceLROnPlateau.is_better`` reads ``self.mode`` to choose between
    'min' and 'max'. The real direction is kept in ``compare_mode`` and
    swapped in just for the comparison.
    """

    def __init__(self, optimizer, **kwargs):
        super().__init__(optimizer, **kwargs)
        self.compare_mode = self.mode

    def is_better(self, a, best):
        schedule_name = self.mode
        self.mode = self.compare_mode
        try:
            return super().is_better(a, best)
        finally:
            self.mode = schedule_name


def LRScheduler(opt, optimizer, last_epoch=-1):
    """
    Learning rate scheduler.

    Builds the scheduler selected by ``opt['schedule']``; the other keys read
    from *opt* depend on that choice.

    From https://github.com/elbayadm/attn2d/blob/master/nmt/optimizer.py

    :param opt: scheduler options.
    :param optimizer: the ``torch.optim`` optimizer to schedule.
    :param last_epoch: last step index when resuming (used by 'step',
        'step-iter', 'inverse-square' and 'multi-step').
    :returns: the scheduler, with ``mode`` set to the schedule name.
    :raises ValueError: unknown schedule or early-stopping criterion.
    """
    ref = opt['schedule']
    if ref == "early-stopping":
        criterion = opt['criterion']
        if criterion == "loss":
            scheduler = _EarlyStoppingPlateau(optimizer,
                                              patience=opt['patience'],
                                              factor=opt['decay_rate'],
                                              threshold=0.01,
                                              min_lr=1e-5)
        elif criterion == "perf":
            scheduler = _EarlyStoppingPlateau(optimizer,
                                              mode="max",
                                              patience=opt['patience'],
                                              factor=opt['decay_rate'],
                                              threshold=0.05)
        else:
            raise ValueError('Unknown early-stopping criterion %s' % criterion)

    elif ref == "cosine-ep":
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=opt['max_epochs'],
                                                   eta_min=opt.get('min_lr', 0))

    elif ref == "cosine":
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer,
                                                   T_max=opt['max_updates'],
                                                   eta_min=opt.get('min_lr', 0))
    elif ref == "shifted-cosine":
        scheduler = ShiftedCosine(optimizer,
                                  T_max=opt['max_updates'],
                                  cycles=opt['cycles'])

    elif ref == "plateau-cosine":
        scheduler = PlateauCosine(optimizer,
                                  T1=opt['T1'],
                                  T2=opt['T2'],
                                  eta1=opt['eta1'],
                                  eta2=opt['eta2'])

    elif ref in ("step", "step-iter"):
        scheduler = lr_scheduler.StepLR(optimizer,
                                        step_size=opt['decay_every'],
                                        gamma=opt['decay_rate'],
                                        last_epoch=last_epoch)

    elif ref == "inverse-square":
        scheduler = InverseSquareRoot(optimizer,
                                      warmup=opt['warmup'],
                                      last_epoch=last_epoch)

    elif ref == 'multi-step':
        milestones = opt['milestones']
        if not isinstance(milestones, (list, tuple)):
            milestones = str(milestones).split(',')
        # MultiStepLR compares milestones with the integer step counter,
        # so string milestones would never trigger a decay.
        milestones = [int(m) for m in milestones]
        scheduler = lr_scheduler.MultiStepLR(optimizer,
                                             milestones,
                                             gamma=opt['decay_rate'],
                                             last_epoch=last_epoch)
    else:
        raise ValueError('Unknown scheduler %s' % ref)
    scheduler.mode = ref
    return scheduler


class Optimizer(object):
    """
    Wrapper for the optimizer (fairseq style).

    Builds a torch optimizer from a config dict and exposes a small interface
    around it (learning-rate access, state dict, step, zero_grad).

    From https://github.com/elbayadm/attn2d/blob/master/nmt/optimizer.py
    """
    def __init__(self, opt, model):
        """
        Build the optimizer selected by ``opt['solver']``.

        :param opt: optimizer config. 'solver' picks the algorithm
            (adam | sgd | rmsprop | adagrad | nag, case-insensitive) and
            'LR' -> 'base' gives the initial learning rate; the other keys
            read ('weight_decay', 'alpha', 'beta', 'epsilon', 'momentum', ...)
            depend on the solver.
        :param model: an nn.Module, or a list of modules (one parameter group
            per module, each starting at the base learning rate).
        :raises ValueError: for an unknown solver name.
        """
        super().__init__()
        ref = opt['solver'].lower()
        # PyYAML loads exponent literals without a decimal point (e.g. 1e-3)
        # as strings, so the hyper-parameters usually written that way are
        # coerced with float() (torch rejects string values).
        lr = float(opt['LR']['base'])
        if isinstance(model, list):
            params = [{'params': m.parameters(), 'lr': lr} for m in model]
        else:
            params = [{'params': model.parameters(), 'lr': lr}]

        if ref == 'adam':
            optimizer = optim.Adam(params,
                                   lr=lr,
                                   betas=(opt['alpha'], opt['beta']),
                                   weight_decay=float(opt['weight_decay']),
                                   eps=float(opt['epsilon']),
                                   amsgrad=bool(opt.get('amsgrad', 0)))
        elif ref == 'sgd':
            optimizer = optim.SGD(params,
                                  lr=lr,
                                  momentum=opt.get('momentum', 0),
                                  dampening=opt.get('dampening', 0),
                                  weight_decay=float(opt['weight_decay']),
                                  nesterov=bool(opt.get('nesterov', 0)))
        elif ref == 'rmsprop':
            optimizer = optim.RMSprop(params,
                                      lr=lr,
                                      alpha=opt['alpha'],
                                      eps=float(opt['epsilon']),
                                      weight_decay=float(opt['weight_decay']),
                                      momentum=opt.get('momentum', 0),
                                      centered=False)
        elif ref == 'adagrad':
            optimizer = optim.Adagrad(params,
                                      lr=lr,
                                      lr_decay=opt.get('lr_decay', 0),
                                      weight_decay=float(opt['weight_decay']),
                                      initial_accumulator_value=0)
        elif ref == 'nag':
            optimizer = NAG(params,
                            lr=lr,
                            momentum=opt['momentum'],
                            weight_decay=float(opt['weight_decay']))
        else:
            raise ValueError('Unknown optimizer % s' % ref)

        self.optimizer = optimizer

    def get_lr(self):
        """Return the current learning rate (of the first parameter group)."""
        return self.optimizer.param_groups[0]['lr']

    def set_lr(self, lr):
        """Set the learning rate of every parameter group."""
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def state_dict(self):
        """Return the optimizer's state dict."""
        return self.optimizer.state_dict()

    def load(self, state_dict):
        """Load an optimizer state dict."""
        self.optimizer.load_state_dict(state_dict)

    def step(self, closure=None):
        """Perform a single optimization step."""
        return self.optimizer.step(closure)

    def zero_grad(self):
        """Clear the gradients of all optimized parameters."""
        return self.optimizer.zero_grad()

    def require_grad(self):
        """Set requires_grad to True for all optimized parameters."""
        for p in self.optimizer.param_groups:
            if isinstance(p, dict):
                for pp in p['params']:
                    pp.requires_grad = True

                    
class Trainer(object):
    """
    Training a model with a given criterion.

    Holds the model, criterion, optimizer, LR scheduler, tensorboard writer
    and metric trackers, and checkpoints them under ``params['modelname']``.

    From https://github.com/elbayadm/attn2d/blob/master/nmt/trainer.py
    """

    def __init__(self, jobname, params, model, criterion):
        """
        :param jobname: name used to fetch the logger.
        :param params: nested config dict (reads the 'optim', 'track',
            'eventname' and 'modelname' entries).
        :param model: the network to train; moved to CUDA.
        :param criterion: the loss module; moved to CUDA.
        :raises NotImplementedError: when CUDA is not available.
        """
        if not torch.cuda.is_available():
            raise NotImplementedError('Training on CPU is not supported')

        self.params = params
        self.jobname = jobname

        self.logger = logging.getLogger(jobname)
        # reproducibility:
        set_seed(params['optim']['seed'])

        self.clip_norm = params['optim']['grad_clip']
        self.num_batches = params['optim']['num_batches']

        # Move to GPU
        self.model = model.cuda()
        self.criterion = criterion.cuda()

        # Initialize optimizer and LR scheduler.
        self.optimizer = Optimizer(params['optim'], model)
        # False for schedules stepped by iteration/epoch; for "early-stopping"
        # it holds the validation criterion driving the scheduler
        # (validate() reacts to "loss" and "perf").
        self.lr_patient = params['optim']['LR']['schedule'] == "early-stopping"
        if self.lr_patient:
            self.lr_patient = params['optim']['LR']['criterion']
            self.logger.info('updating the lr wrt %s', self.lr_patient)
        self.lr_scheduler = LRScheduler(params['optim']['LR'],
                                        self.optimizer.optimizer)

        self.tb_writer = SummaryWriter(params['eventname'])
        self.log_every = params['track']['log_every']
        self.checkpoint = params['track']['checkpoint']
        self.evaluate = False
        self.done = False
        # NOTE(review): this aliases the module-level TRACKERS object (no
        # copy), so track() mutates it in place.
        self.trackers = TRACKERS
        self.iteration = 0
        self.epoch = 0
        self.batch_offset = 0
        # Dump the model params; the context manager closes the file handle.
        with open('%s/params.json' % params['modelname'], 'w') as f:
            json.dump(params, f)

    def update_params(self, val_loss=None):
        """
        Step the time-based LR scheduler and track the current lr.

        Iteration-based schedules are stepped with the iteration count, the
        others with ``epoch - 1``. Nothing is stepped when the lr is driven
        by validation (``lr_patient``). ``val_loss`` is unused.
        """
        epoch = self.epoch
        iteration = self.iteration
        if not self.lr_patient:
            if self.lr_scheduler.mode in ["step-iter", "inverse-square",
                                          "cosine", 'shifted-cosine',
                                          'plateau-cosine']:
                self.lr_scheduler.step(iteration)
            else:
                self.lr_scheduler.step(epoch - 1)
        self.track('optim/lr', self.optimizer.get_lr())

    def step(self, data_src, data_trg, ntokens=0):
        """
        A single forward step: clear the grads, run the model and criterion.

        :returns: (losses dict from the criterion, batch size, ntokens).
            When ``ntokens`` is falsy it is computed as
            ``sum(src lengths * trg lengths)``.
        """
        # Clear the grads
        self.optimizer.zero_grad()
        batch_size = data_src['labels'].size(0)
        # evaluate the loss
        decoder_logit = self.model(data_src, data_trg)
        losses, _ = self.criterion(decoder_logit, data_trg['out_labels'])
        if not ntokens:
            ntokens = torch.sum(data_src['lengths'] *
                                data_trg['lengths']).data.item()

        return losses, batch_size, ntokens

    def backward_step(self, loss, ml_loss, ntokens, nseqs, start, wrapped):
        """
        A single backward step: backprop, optional gradient clipping,
        optimizer update and bookkeeping (logging, evaluate/done flags).

        :param loss: total loss to backpropagate.
        :param ml_loss: ML part of the loss (only logged).
        :param ntokens: token count of the batch (tracked/logged).
        :param nseqs: number of sequences in the batch (tracked/logged).
        :param start: ``time.time()`` at the start of the step.
        :param wrapped: True when an epoch has just been completed.
        """
        loss.backward()
        if self.clip_norm > 0:
            grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                       self.clip_norm)
            # The gradient norm is only computed (hence tracked) when
            # clipping is enabled.
            self.track('optim/grad_norm', grad_norm)
        self.track('optim/ntokens', ntokens)
        self.track('optim/batch_size', nseqs)

        self.optimizer.step()
        # torch.cuda.empty_cache()  # FIXME
        if np.isnan(loss.data.item()):
            sys.exit('Loss is nan')
        torch.cuda.synchronize()
        self.iteration += 1
        if wrapped:
            self.epoch += 1
        # Log
        if (self.iteration % self.log_every == 0):
            self.track('train/loss', loss.data.item())
            self.track('train/ml_loss', ml_loss.data.item())
            self.to_stderr(nseqs, ntokens, time.time()-start)
            self.tensorboard()

        self.evaluate = (self.iteration % self.checkpoint == 0)
        self.done = (self.epoch > self.params['optim']['max_epochs'])

    def validate(self, src_loader=None, trg_loader=None):
        """
        Evaluate on the dev set.

        Computes the validation losses (and BLEU unless ``params['eval_bleu']``
        is falsy), tracks them, and steps the LR scheduler when it is driven
        by validation ("loss" or "perf"). The model is put back in training
        mode before returning.
        """
        params = self.params
        self.log('Evaluating the model on the validation set..')
        self.model.eval()
        if params.get('eval_bleu', 1):
            _, val_ml_loss, val_loss, bleu = evaluate_model(params['modelname'],
                                                            self,
                                                            src_loader,
                                                            trg_loader,
                                                            params['track'])
            self.log('BLEU: %.5f ' % bleu)
            self.track('val/perf/bleu', bleu)
            self.log('logged bleu log')
            save_best = (self.trackers['val/perf/bleu'][-1] ==
                         max(self.trackers['val/perf/bleu']))
            save_every = 0

        else:
            val_ml_loss, val_loss = evaluate_val_loss(params['modelname'],
                                                      self,
                                                      src_loader,
                                                      trg_loader,
                                                      params['track'])
            save_every = 1
            save_best = 0

        self.track('val/loss', val_loss)
        self.track('val/ml_loss', val_ml_loss)
        self.tensorboard()
        # Save model if still improving on the dev set
        # self.save_model(src_loader, trg_loader, save_best, save_every)
        # eval() was set above; restore training mode so the training steps
        # that follow do not run in eval mode.
        self.model.train()
        if self.lr_patient == "loss":
            self.log('Updating the learning rate - LOSS')
            self.lr_scheduler.step(val_loss)
            self.track('optim/lr', self.optimizer.get_lr())
        elif self.lr_patient == "perf":
            # BLEU is only available when eval_bleu is enabled.
            assert not save_every
            self.log('Updating the learning rate - PERF')
            self.lr_scheduler.step(bleu)
            self.track('optim/lr', self.optimizer.get_lr())

    def save_model(self, src_loader, trg_loader, save_best, save_every):
        """
        Checkpoint the model, optimizer and trackers under params['modelname'].

        Always writes model.pth, optimizer.pth and trackers.pkl (the trackers
        include the loaders' ``iterators``, iteration and epoch). Also writes
        the '-best' copies when ``save_best``, and a model-<iteration>.pth
        snapshot when ``save_every``.
        """
        params = self.params
        modelname = params['modelname']
        checkpoint_path = osp.join(modelname, 'model.pth')
        torch.save(self.model.state_dict(), checkpoint_path)
        self.log("model saved to {}".format(checkpoint_path))
        optimizer_path = osp.join(modelname, 'optimizer.pth')
        torch.save(self.optimizer.state_dict(), optimizer_path)
        self.log("optimizer saved to {}".format(optimizer_path))
        self.trackers['src_iterators'] = src_loader.iterators
        self.trackers['trg_iterators'] = trg_loader.iterators
        self.trackers['iteration'] = self.iteration
        self.trackers['epoch'] = self.epoch
        pdump(self.trackers, osp.join(modelname, 'trackers.pkl'))

        if save_best:
            checkpoint_path = osp.join(modelname, 'model-best.pth')
            torch.save(self.model.state_dict(), checkpoint_path)
            self.log("model saved to {}".format(checkpoint_path))
            optimizer_path = osp.join(modelname, 'optimizer-best.pth')
            torch.save(self.optimizer.state_dict(), optimizer_path)
            self.log("optimizer saved to {}".format(optimizer_path))
            pdump(self.trackers, osp.join(modelname, 'trackers-best.pkl'))

        if save_every:
            checkpoint_path = osp.join(modelname, 'model-%d.pth' % self.iteration)
            torch.save(self.model.state_dict(), checkpoint_path)
            self.log("model saved to {}".format(checkpoint_path))

    def load_checkpoint(self):
        """
        Load last saved params:
        for use with oar's idempotent jobs

        Resumes from ``<modelname>/model.pth`` when it exists, otherwise
        starts from ``params['start_from']`` when set. Restores the trackers
        (hence epoch and iteration) and returns the saved data-iterator
        state (empty unless resuming from ``modelname``).
        """
        params = self.params
        modelname = params['modelname']
        iterators_state = {}
        history = {}
        if osp.exists(osp.join(modelname, 'model.pth')):
            self.warn('Picking up where we left')
            # load model's weights
            saved_state = torch.load(osp.join(modelname, 'model.pth'))
            saved = list(saved_state)
            required = list(self.model.state_dict())
            # The model's keys carry a "module." prefix (e.g. wrapped in
            # DistributedDataParallel) but the checkpoint's do not: add it.
            if "module" in required[0] and "module" not in saved[0]:
                for k in saved:
                    saved_state["module.%s" % k] = saved_state.pop(k)
                # Refresh the key list so the clean-up below sees the
                # renamed keys instead of the deleted ones.
                saved = list(saved_state)

            # Drop "increment" entries and rename "transiton" -> "transition".
            for k in saved:
                if "increment" in k:
                    del saved_state[k]
                elif "transiton" in k:
                    kk = k.replace("transiton", "transition")
                    saved_state[kk] = saved_state.pop(k)
            self.model.load_state_dict(saved_state)
            # load the optimizer's last state:
            self.optimizer.load(
                torch.load(osp.join(modelname, 'optimizer.pth')))
            history = pload(osp.join(modelname, 'trackers.pkl'))
            iterators_state = {'src_iterators': history['src_iterators'],
                               'trg_iterators': history['trg_iterators']}

        elif params['start_from']:
            start_from = params['start_from']
            # Start from a pre-trained model:
            self.warn('Starting from %s' % start_from)
            if params['start_from_best']:
                flag = '-best'
                self.warn('Starting from the best saved model')
            else:
                flag = ''
            # load model's weights
            self.model.load_state_dict(
                torch.load(osp.join(start_from, 'model%s.pth' % flag)))
            # load the optimizer's last state:
            if not params['optim']['reset']:
                self.optimizer.load(
                    torch.load(osp.join(start_from, 'optimizer%s.pth' % flag)))
            history = pload(osp.join(start_from, 'trackers%s.pkl' % flag))
        self.trackers.update(history)
        self.epoch = self.trackers['epoch']
        self.iteration = self.trackers['iteration']
        return iterators_state

    def log(self, message):
        """Log ``message`` at INFO level."""
        self.logger.info(message)

    def warn(self, message):
        """Log ``message`` at WARNING level."""
        self.logger.warning(message)

    def debug(self, message):
        """Log ``message`` at DEBUG level."""
        self.logger.debug(message)

    def set_devices(self, devices):
        """Record a new device set and start a new timing slot for it."""
        self.trackers['devices'].append(devices)
        self.trackers['time'].append(0)

    def increment_time(self, t):
        """Add ``t`` to the current timing slot."""
        self.trackers['time'][-1] += t

    def track(self, k, v):
        """
        Track key metrics.

        Appends ``v`` when the tracker ``k`` is a list, otherwise overwrites
        it; marks ``k`` as updated for the next tensorboard() call.

        :raises ValueError: when ``k`` is not a known tracker.
        """
        if k not in self.trackers:
            raise ValueError('Tracking unknown entity %s' % k)
        if isinstance(self.trackers[k], list):
            self.trackers[k].append(v)
        else:
            self.trackers[k] = v
        self.trackers['update'].add(k)

    def tensorboard(self):
        """
        Write the latest value of every updated tracker to tensorboard.
        """
        for k in self.trackers['update']:
            self.tb_writer.add_scalar(k, self.trackers[k][-1], self.iteration)
        self.tb_writer.file_writer.flush()
        self.trackers['update'] = set()

    def to_stderr(self, batch_size, ntokens, timing):
        """
        Log a one-line training progress summary through self.log.
        """
        self.log('| epoch {:2d} '
                 '| iteration {:5d} '
                 '| lr {:02.2e} '
                 '| seq {:3d} '
                 '| sXt {:5d} '
                 '| ms/batch {:6.3f} '
                 '| total time {:6.2f} s '
                 '| loss {:6.3f} '
                 '| ml {:6.3f}'
                 .format(self.epoch,
                         self.iteration,
                         self.optimizer.get_lr(),
                         batch_size,
                         ntokens,
                         timing * 1000,
                         sum(self.trackers['time']),
                         self.trackers['train/loss'][-1],
                         self.trackers['train/ml_loss'][-1]))

In [None]:
class MLCriterion(nn.Module):
    """
    The default cross entropy (negative log-likelihood) loss.

    From https://github.com/elbayadm/attn2d/blob/master/nmt/loss/cross_entropy.py
    """
    def __init__(self, job_name, params):
        """
        :param job_name: name used to fetch the logger.
        :param params: dict; reads 'mask_threshold' (default 1: target
            indices <= this value are ignored, meant to cover pad and unk)
            and 'normalize' (default 'ntokens'; one of 'ntokens', 'seqlen',
            'batch').
        """
        super().__init__()
        self.logger = logging.getLogger(job_name)
        self.th_mask = params.get('mask_threshold', 1)  # both pad and unk
        self.normalize = params.get('normalize', 'ntokens')
        self.version = 'ml'

    def log(self):
        """Log which loss is in use."""
        self.logger.info('Default ML loss')

    def forward(self, logp, target):
        """
        logp : the decoder output (N, seq_length, V), used as
            log-probabilities (the loss is -logp[target])
        target : the ground truth labels (N, >= seq_length), integer indices

        Returns ({"final": loss, "ml": loss}, {}) -- the same tensor under
        both keys and an empty stats dict.
        """
        output = self.get_ml_loss(logp, target)
        return {"final": output, "ml": output}, {}

    def get_ml_loss(self, logp, target):
        """
        Compute the usual ML loss: the sum of -logp at the target indices
        that pass the mask, divided according to ``self.normalize``.

        :raises ValueError: for an unknown normalizing scheme.
        """
        batch_size, seq_length, vocab = logp.size()
        # Targets may be longer than the decoded sequence; truncate them.
        target = target[:, :seq_length]
        # reshape (unlike view) also accepts non-contiguous tensors.
        logp = logp.reshape(-1, vocab)
        target = target.reshape(-1, 1)
        mask = target.gt(self.th_mask)
        ml_output = torch.sum(-logp.gather(1, target)[mask])

        if self.normalize == 'ntokens':
            # Clamp so that a batch where every target is masked yields 0
            # instead of 0/0 = NaN.
            norm = torch.sum(mask).clamp(min=1).float()
        elif self.normalize == 'seqlen':
            norm = seq_length
        elif self.normalize == 'batch':
            norm = batch_size
        else:
            raise ValueError('Unknown normalizing scheme')
        return ml_output / norm

In [None]:
# New code.

import logging
import json
import h5py
import numpy as np
import torch


class VocabData(object):
    """
    Vocabulary loaded (via ``pload``) from an ``.infos`` file.

    Exposes the index-to-word table (the file's 'itow' entry), its size, and
    the indices of the special tokens <PAD>, <UNK>, <EOS> and <BOS>.
    """

    def __init__(self, infos_fn):
        self.infos_filename = infos_fn
        self.ix_to_word = pload(self.infos_filename)['itow']
        self.vocab_size = len(self.ix_to_word)

        # Invert the table to find the special tokens by their surface form.
        lookup = dict((word, index) for index, word in self.ix_to_word.items())
        self.pad, self.unk, self.eos, self.bos = (
            lookup[token] for token in ('<PAD>', '<UNK>', '<EOS>', '<BOS>'))


class TextDataLoader(object):
    """
    Paired source/target data loaders for the train, val and test splits.

    Adapted from https://github.com/elbayadm/attn2d/blob/master/nmt/loader/dataloader.py
    """

    def __init__(self, src_infos, src_h5, tgt_infos, tgt_h5, batch_size,
                 max_length, model_name):
        """
        :param src_infos: source vocabulary file (read with VocabData).
        :param src_h5: source HDF5 file (read with TextDataset per split).
        :param tgt_infos: target vocabulary file.
        :param tgt_h5: target HDF5 file.
        :param batch_size: batch size of every split's DataLoader.
        :param max_length: stored as ``seq_length`` and reported in the log.
        :param model_name: name used to fetch the logger.
        """
        self.model_name = model_name
        self.logger = logging.getLogger(model_name)

        self.src_vocab = VocabData(src_infos)
        self.tgt_vocab = VocabData(tgt_infos)
        self.batch_size = batch_size
        self.seq_length = max_length
        self.logger.info(f'Loading h5 files: {src_h5}, {tgt_h5}')

        # Load HDF5 data file.
        self.datasets = {}
        self.loaders = {}
        self.max_indices = {'train': 0, 'val': 0, 'test': 0}
        for split in ['train', 'val', 'test']:
            self.datasets[split] = ZippedDataset(TextDataset(src_h5, split),
                                                 TextDataset(tgt_h5, split))
            # Shards the split across processes; without explicit
            # num_replicas/rank this requires an initialized process group.
            sampler = torch.utils.data.distributed.DistributedSampler(
                self.datasets[split])
            # shuffle must stay False when a sampler is supplied.
            self.loaders[split] = torch.utils.data.DataLoader(
                self.datasets[split], batch_size=self.batch_size,
                shuffle=False, num_workers=1, pin_memory=True, sampler=sampler)
        self.logger.info(
            'Train:  {} | Dev: {} | Test: {}'.format(
                len(self.datasets['train']),
                len(self.datasets['val']),
                len(self.datasets['test'])))
        self.logger.warning(f'Reading sequences up to {self.seq_length}')

    def _vocab(self, side):
        """Return the VocabData for ``side`` ('src' or 'tgt')."""
        if side == 'src':
            return self.src_vocab
        if side == 'tgt':
            return self.tgt_vocab
        raise ValueError(f'Unknown vocabulary side {side!r}')

    def get_vocab_size(self, side='tgt'):
        """Return the vocabulary size of ``side`` ('tgt' by default)."""
        return self._vocab(side).vocab_size

    def get_vocab(self, side='tgt'):
        """Return the index-to-word table of ``side`` ('tgt' by default)."""
        return self._vocab(side).ix_to_word

    def get_seq_length(self):
        """Return the maximum sequence length given at construction."""
        return self.seq_length


class H5Dataset(torch.utils.data.Dataset):
    """
    Map-style dataset backed by one HDF5 dataset held in memory.

    The dataset named ``dsname`` inside ``h5_filename`` is read once, at
    construction, into a numpy array; items are then indexed from it.
    """
    def __init__(self, h5_filename, dsname):
        super().__init__()
        self.data_info = {}

        self.h5_filename, self.dsname = h5_filename, dsname
        with h5py.File(h5_filename, 'r', libver='latest', swmr=True) as handle:
            cached = np.array(handle[dsname])
        self.data_cache = cached

    def __getitem__(self, index):
        cache = self.data_cache
        return cache[index]

    def __len__(self):
        cache = self.data_cache
        return len(cache)


class TextDataset(torch.utils.data.Dataset):
    """
    One split of a text HDF5 file: token labels and sequence lengths.

    Reads the ``labels_<split>`` and ``lengths_<split>`` datasets of
    ``h5_filename`` through H5Dataset.
    """
    def __init__(self, h5_filename, split):
        super().__init__()
        self.h5_filename = h5_filename

        self.label_ds = H5Dataset(self.h5_filename, f'labels_{split}')
        self.length_ds = H5Dataset(self.h5_filename, f'lengths_{split}')

    def __getitem__(self, index):
        """
        Return ``{'labels': ..., 'lengths': ...}`` for one sequence.

        The keys match what Trainer.step reads from each batch; the default
        DataLoader collate function batches a dict sample key by key.
        NOTE(review): Trainer.step also reads data_trg['out_labels'], which
        is not produced here -- confirm where it should come from.
        """
        return {'labels': self.label_ds[index],
                'lengths': self.length_ds[index]}

    def __len__(self):
        return len(self.length_ds)


class ZippedDataset(torch.utils.data.Dataset):
    """
    Pair two equally long datasets: item ``i`` is ``(dataset1[i], dataset2[i])``.

    :raises ValueError: when the two datasets differ in length.
    """
    def __init__(self, dataset1, dataset2):
        super().__init__()
        # An explicit check (not assert) so it still runs under `python -O`.
        if len(dataset1) != len(dataset2):
            raise ValueError(
                'Zipped datasets must have the same length, got %d and %d'
                % (len(dataset1), len(dataset2)))
        self.dataset1 = dataset1
        self.dataset2 = dataset2

    def __getitem__(self, index):
        return (self.dataset1[index], self.dataset2[index])

    def __len__(self):
        return len(self.dataset1)

In [None]:
import time
import os
from nmt.params import parse_params, set_env
import torch


def train_worker(pindex, gpu_ids, params):
    """
    Adapted from https://github.com/elbayadm/attn2d/blob/master/train.py
    """
    model_name = params['model_name']
    logger = PervasiveLogger(model_name)

    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '3892'
    if not os.getenv('RANK', None):
        os.environ['RANK'] = str(pindex)

    device_id = gpu_ids[pindex]
    device_name = torch.cuda.get_device_name(device_id)
    torch.cuda.set_device(device_id)
    logger.info(f'Device {pindex}: cuda({device_id}), {device_name}')

    project_dir = os.path.dirname(os.path.abspath(__file__))
    comm_file = f'{project_dir}/{model_name}/pgroup_shared'
    if os.path.isfile(comm_file):
        os.remove(comm_file)
    logger.info(f'Process communication file: {comm_file}')
    torch.distributed.init_process_group(
            backend='nccl',
            world_size=len(gpu_ids),
            rank=pindex,
            init_method=f'file://{comm_file}')

    if 'network' in params and 'type' in params['network']:
        net_type = params['network']['type']
    else:
        net_type = 'efficient-densenet'
    if 'encoder' in params:
        enc_in_dim = (
            params['encoder']['input'] if 'input' in params['encoder'] else 128)
        enc_dropout = (
            params['encoder']['dropout'] if 'dropout' in params['encoder'] else 0.2)
    if 'decoder' in params:
        dec_in_dim = (
            params['decoder']['input'] if 'input' in params['decoder'] else 128)
        dec_dropout = (
            params['decoder']['dropout'] if 'dropout' in params['decoder'] else 0.2)
        pred_dropout = (
            params['decoder']['prediction_dropout'] if 'prediction_dropout' in params['decoder'] else 0.2)
    model = Pervasive(model_name, net_type, Ts, Tt, special_tokens,
                      enc_in_dim, enc_dropout, dec_in_dim, dec_dropout,
                      pred_dropout)
    model.init_weights()
    model.cuda(device_id)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device_id])
    batch_size = params['data']['batch_size'] / len(gpu_ids)

    criterion = loss.MLCriterion(model_name, params)
    criterion.log()
    trainer = Trainer(model_name, params, model, criterion)

    src_infos = os.path.join(params['data']['dir'], f'{params["src"]}.infos')
    tgt_infos = os.path.join(params['data']['dir'], f'{params["tgt"]}.infos')
    src_h5 = os.path.join(params['data']['dir'], f'{params["src"]}.h5')
    tgt_h5 = os.path.join(params['data']['dir'], f'{params["tgt"]}.h5')
    train_loader = TextDataLoader(src_infos, src_h5, tgt_infos, tgt_h5, batch_size,
                                  params['data']['max_length'], model_name):

    iters = trainer.load_checkpoint()
    if trainer.lr_patient:
        trainer.update_params()

    for epoch in range(params['optim']['max_epochs']):
        # Update learning rate.
        if not trainer.lr_patient:
            trainer.update_params()
        torch.cuda.synchronize()
        avg_loss = torch.zeros(1).cuda()
        avg_ml_loss = torch.zeros(1).cuda()
        total_ntokens = 0
        total_nseqs = 0
        start = time.time()
        for src_data, tgt_data in loader.loaders['train']:
            losses, batch_size, ntokens = trainer.step(src_data, tgt_data)
            avg_loss += ntokens * losses['final']
            avg_ml_loss += ntokens * losses['ml']
            total_nseqs += batch_size
            total_ntokens += ntokens
            avg_loss /= total_ntokens
            avg_ml_loss /= total_ntokens

        trainer.backward_step(avg_loss, avg_ml_loss,
                              total_ntokens, total_nseqs,
                              start, False)
        trainer.increment_time(time.time()-start)