<a href="https://colab.research.google.com/github/dotsnangles/Retrieval-Based-Chatbot/blob/main/Poly_encoder_Code_Splits.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q transformers

[K     |████████████████████████████████| 4.4 MB 15.6 MB/s 
[K     |████████████████████████████████| 6.6 MB 49.8 MB/s 
[K     |████████████████████████████████| 596 kB 45.8 MB/s 
[K     |████████████████████████████████| 101 kB 12.2 MB/s 
[?25h

In [None]:
import os
import time
import json
import shutil
import argparse
import numpy as np
from tqdm import tqdm
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from torch.optim import Adam
from torch.utils.data import DataLoader

from transformers import BertModel, BertConfig, BertTokenizer, BertTokenizerFast
from transformers.optimization import AdamW, get_linear_schedule_with_warmup

# from dataset import SelectionDataset
# from transform import SelectionSequentialTransform, SelectionJoinTransform, SelectionConcatTransform
# from encoder import PolyEncoder, BiEncoder, CrossEncoder

from sklearn.metrics import label_ranking_average_precision_score

import logging

In [None]:
# Scratch check: a negative-start slice keeps the trailing items (here the last three).
lst = [1,2,3,4,5]
lst[-3:]

[3, 4, 5]

In [None]:
class SelectionSequentialTransform(object):
    """Encode each text on its own into fixed-length token-id and mask lists."""

    def __init__(self, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __call__(self, texts):
        """Tokenize every entry of ``texts`` to exactly ``max_len`` positions.

        Returns ``(input_ids_list, input_masks_list)``: two parallel lists with
        one entry per input text.
        """
        all_ids, all_masks = [], []
        for text in texts:
            encoded = self.tokenizer.encode_plus(
                text,
                max_length=self.max_len,
                padding='max_length',
                truncation=True,
            )
            ids = encoded['input_ids']
            mask = encoded['attention_mask']
            # the tokenizer is asked to pad/truncate; verify it really did
            assert len(ids) == self.max_len
            assert len(mask) == self.max_len
            all_ids.append(ids)
            all_masks.append(mask)

        return all_ids, all_masks


class SelectionJoinTransform(object):
    """Join dialogue turns with newlines and encode them as one fixed-length sequence."""

    def __init__(self, tokenizer, max_len):
        self.tokenizer = tokenizer
        self.max_len = max_len

        self.cls_id = self.tokenizer.convert_tokens_to_ids('[CLS]')
        self.sep_id = self.tokenizer.convert_tokens_to_ids('[SEP]')
        # the newline turn separator is registered on the (shared) tokenizer
        self.tokenizer.add_tokens(['\n'], special_tokens=True)
        self.pad_id = 0

    def __call__(self, texts):
        """Return ``(input_ids, input_masks)`` for the joined ``texts``.

        Sequences longer than ``max_len`` keep their last ``max_len`` tokens and
        get ``[CLS]`` written into the first position; shorter ones are
        right-padded with ``pad_id`` / mask 0.
        """
        # another option is to use [SEP], but here we follow the discussion at:
        # https://github.com/facebookresearch/ParlAI/issues/2306#issuecomment-599180186
        encoded = self.tokenizer.encode_plus('\n'.join(texts))

        # keep the tail of the sequence
        ids = encoded['input_ids'][-self.max_len:]
        mask = encoded['attention_mask'][-self.max_len:]
        ids[0] = self.cls_id

        ids.extend([self.pad_id] * (self.max_len - len(ids)))
        mask.extend([0] * (self.max_len - len(mask)))
        assert len(ids) == self.max_len
        assert len(mask) == self.max_len

        return ids, mask
    

class SelectionConcatTransform(object):
    """Build cross-encoder inputs: the joined context followed by each candidate response."""

    def __init__(self, tokenizer, max_len):
        self.tokenizer = tokenizer
        # in cross encoder mode, we simply add max_contexts_length and max_response_length together to form max_len
        # this (in almost all cases) ensures all the response tokens are used and as many context tokens are used as possible
        # the intuition is that responses and the last few contexts are the most important
        self.max_len = max_len
        self.cls_id = self.tokenizer.convert_tokens_to_ids('[CLS]')
        self.sep_id = self.tokenizer.convert_tokens_to_ids('[SEP]')
        # the newline turn separator is registered on the (shared) tokenizer
        self.tokenizer.add_tokens(['\n'], special_tokens=True)
        self.pad_id = 0

    def __call__(self, context, responses):
        """Return ``(input_ids, input_masks, segment_ids)``, one row per response.

        Each row is the encoded context plus the encoded response without its
        first token, cut to the last ``max_len`` positions, with ``[CLS]``
        written into position 0 and right-padded to ``max_len``.  Response
        positions get segment id 1; context and padding positions get 0.
        """
        # another option is to use [SEP], but here we follow the discussion at:
        # https://github.com/facebookresearch/ParlAI/issues/2306#issuecomment-599180186
        ctx = self.tokenizer.encode_plus('\n'.join(context))
        ctx_ids = ctx['input_ids']
        ctx_masks = ctx['attention_mask']
        ctx_segments = ctx['token_type_ids']

        all_ids, all_masks, all_segments = [], [], []
        for response in responses:
            enc = self.tokenizer.encode_plus(response)
            resp_ids = enc['input_ids']
            resp_masks = enc['attention_mask']
            # one segment id per response token that survives dropping the leading token
            resp_segments = [1] * (len(enc['token_type_ids']) - 1)

            ids = (ctx_ids + resp_ids[1:])[-self.max_len:]
            masks = (ctx_masks + resp_masks[1:])[-self.max_len:]
            segments = (ctx_segments + resp_segments)[-self.max_len:]
            ids[0] = self.cls_id

            ids.extend([self.pad_id] * (self.max_len - len(ids)))
            masks.extend([0] * (self.max_len - len(masks)))
            segments.extend([0] * (self.max_len - len(segments)))
            assert len(ids) == self.max_len
            assert len(masks) == self.max_len
            assert len(segments) == self.max_len

            all_ids.append(ids)
            all_masks.append(masks)
            all_segments.append(segments)
        return all_ids, all_masks, all_segments

In [None]:
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
import os
import random
import pickle


class SelectionDataset(Dataset):
    """Response-selection dataset read from a tab-separated text file.

    Each non-blank line is ``label<TAB>context turn(s)...<TAB>response``.
    Consecutive lines are collected into one group (one sample); a line with
    label 1 starts a new group once the current group already holds responses.

    Args:
        file_path: path of the UTF-8 data file.
        context_transform: callable applied to a group's context turns
            (used when ``mode`` is not ``'cross'``).
        response_transform: callable applied to a group's list of responses
            (used when ``mode`` is not ``'cross'``).
        concat_transform: callable applied to ``(context, responses)`` when
            ``mode == 'cross'``.
        sample_cnt: if given, stop reading once this many groups are collected.
        mode: ``'cross'`` selects the concatenated representation; any other
            value selects the separate context/response representation.
    """

    def __init__(self, file_path, context_transform, response_transform, concat_transform, sample_cnt=None, mode='poly'):
        self.context_transform = context_transform
        self.response_transform = response_transform
        self.concat_transform = concat_transform
        self.data_source = []
        self.mode = mode
        with open(file_path, encoding='utf-8') as f:
            group = self._new_group()
            for line in f:
                # tolerate blank lines (e.g. a trailing empty line) instead of
                # failing on int('')
                if not line.strip():
                    continue
                split = line.strip('\n').split('\t')
                lbl, context, response = int(split[0]), split[1:-1], split[-1]
                if lbl == 1 and group['responses']:
                    # a positive label marks the start of the next group
                    self.data_source.append(group)
                    group = self._new_group()
                    if sample_cnt is not None and len(self.data_source) >= sample_cnt:
                        break
                group['responses'].append(response)
                group['labels'].append(lbl)
                group['context'] = context
            if group['responses']:
                self.data_source.append(group)

    @staticmethod
    def _new_group():
        """Return an empty group record."""
        return {
            'context': None,
            'responses': [],
            'labels': []
        }

    def __len__(self):
        return len(self.data_source)

    def __getitem__(self, index):
        """Return the transformed sample at ``index``.

        ``(transformed_text, labels)`` in cross mode, otherwise
        ``(transformed_context, transformed_responses, labels)``.
        """
        group = self.data_source[index]
        context, responses, labels = group['context'], group['responses'], group['labels']
        if self.mode == 'cross':
            transformed_text = self.concat_transform(context, responses)
            ret = transformed_text, labels
        else:
            # batchify_join_str unpacks each of these as a (token_ids, masks) pair
            transformed_context = self.context_transform(context)
            transformed_responses = self.response_transform(responses)
            ret = transformed_context, transformed_responses, labels

        return ret

    def batchify_join_str(self, batch):
        """Collate a list of ``__getitem__`` results into ``torch.long`` tensors.

        Cross mode returns ``(token_ids, masks, segment_ids, labels)``; otherwise
        ``(context_ids, context_masks, response_ids, response_masks, labels)``.
        All samples must yield equally sized lists, since they are stacked with
        ``torch.tensor``.
        """
        if self.mode == 'cross':
            text_token_ids_list_batch, text_input_masks_list_batch, text_segment_ids_list_batch = [], [], []
            labels_batch = []
            for sample in batch:
                text_token_ids_list, text_input_masks_list, text_segment_ids_list = sample[0]

                text_token_ids_list_batch.append(text_token_ids_list)
                text_input_masks_list_batch.append(text_input_masks_list)
                text_segment_ids_list_batch.append(text_segment_ids_list)

                labels_batch.append(sample[1])

            long_tensors = [text_token_ids_list_batch, text_input_masks_list_batch, text_segment_ids_list_batch]

            text_token_ids_list_batch, text_input_masks_list_batch, text_segment_ids_list_batch = (
                torch.tensor(t, dtype=torch.long) for t in long_tensors)

            labels_batch = torch.tensor(labels_batch, dtype=torch.long)
            return text_token_ids_list_batch, text_input_masks_list_batch, text_segment_ids_list_batch, labels_batch

        else:
            contexts_token_ids_list_batch, contexts_input_masks_list_batch, \
            responses_token_ids_list_batch, responses_input_masks_list_batch = [], [], [], []
            labels_batch = []
            for sample in batch:
                (contexts_token_ids_list, contexts_input_masks_list), (responses_token_ids_list, responses_input_masks_list) = sample[:2]

                contexts_token_ids_list_batch.append(contexts_token_ids_list)
                contexts_input_masks_list_batch.append(contexts_input_masks_list)

                responses_token_ids_list_batch.append(responses_token_ids_list)
                responses_input_masks_list_batch.append(responses_input_masks_list)

                labels_batch.append(sample[-1])

            long_tensors = [contexts_token_ids_list_batch, contexts_input_masks_list_batch,
                            responses_token_ids_list_batch, responses_input_masks_list_batch]

            contexts_token_ids_list_batch, contexts_input_masks_list_batch, \
            responses_token_ids_list_batch, responses_input_masks_list_batch = (
                torch.tensor(t, dtype=torch.long) for t in long_tensors)

            labels_batch = torch.tensor(labels_batch, dtype=torch.long)
            return contexts_token_ids_list_batch, contexts_input_masks_list_batch, \
                   responses_token_ids_list_batch, responses_input_masks_list_batch, labels_batch

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertPreTrainedModel, BertModel


class BiEncoder(BertPreTrainedModel):
    """Bi-encoder: context and responses are encoded separately by one shared
    BERT and scored by the dot product of their first-token vectors.

    The encoder is supplied through the ``bert`` keyword argument.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = kwargs['bert']

    def forward(self, context_input_ids, context_input_masks,
                            responses_input_ids, responses_input_masks, labels=None):
        """Return a scalar loss when ``labels`` is given, otherwise scores.

        ``responses_input_ids`` / ``responses_input_masks`` are 3-D:
        [bs, res_cnt, seq_length].  With ``labels`` only the first response of
        every sample is used and the other samples in the batch serve as
        negatives; the content of ``labels`` itself is not read.  Without
        ``labels`` the result is a [bs, res_cnt] score matrix.
        """
        ## only select the first response (whose lbl==1)
        if labels is not None:
            responses_input_ids = responses_input_ids[:, 0, :].unsqueeze(1)
            responses_input_masks = responses_input_masks[:, 0, :].unsqueeze(1)

        context_vec = self.bert(context_input_ids, context_input_masks)[0][:,0,:]  # [bs,dim]

        batch_size, res_cnt, seq_length = responses_input_ids.shape
        responses_input_ids = responses_input_ids.view(-1, seq_length)
        responses_input_masks = responses_input_masks.view(-1, seq_length)

        responses_vec = self.bert(responses_input_ids, responses_input_masks)[0][:,0,:]  # [bs*res_cnt,dim]
        responses_vec = responses_vec.view(batch_size, res_cnt, -1)

        if labels is not None:
            responses_vec = responses_vec.squeeze(1)
            dot_product = torch.matmul(context_vec, responses_vec.t())  # [bs, bs]
            # the matching response of sample i sits on the diagonal
            mask = torch.eye(context_input_ids.size(0)).to(context_input_ids.device)
            loss = F.log_softmax(dot_product, dim=-1) * mask
            loss = (-loss.sum(dim=1)).mean()
            return loss
        else:
            context_vec = context_vec.unsqueeze(1)  # [bs, 1, dim]
            # squeeze only the singleton query axis: a bare squeeze() would also
            # drop the batch axis when bs == 1 (or the candidate axis when
            # res_cnt == 1) and hand the caller a tensor of the wrong rank
            dot_product = torch.matmul(context_vec, responses_vec.permute(0, 2, 1)).squeeze(1)  # [bs, res_cnt]
            return dot_product


class CrossEncoder(BertPreTrainedModel):
    """Cross-encoder: every (context, response) pair is encoded jointly by BERT
    and scored by a linear layer on the first-token vector.

    The encoder is supplied through the ``bert`` keyword argument.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = kwargs['bert']
        self.linear = nn.Linear(config.hidden_size, 1)

    def forward(self, text_input_ids, text_input_masks, text_input_segments, labels=None):
        """Return candidate scores, or a scalar loss when ``labels`` is given.

        The three inputs are 3-D (samples, candidates, sequence positions).
        The loss treats candidate 0 of every sample as the correct one; the
        content of ``labels`` is not read.
        """
        batch_size, neg, dim = text_input_ids.shape
        # fold the candidate axis into the batch axis for the encoder
        flat_ids, flat_masks, flat_segments = (
            t.reshape(-1, dim) for t in (text_input_ids, text_input_masks, text_input_segments)
        )
        first_token_vec = self.bert(flat_ids, flat_masks, flat_segments)[0][:, 0, :]
        score = self.linear(first_token_vec).view(-1, neg)
        if labels is None:
            return score
        return -F.log_softmax(score, -1)[:, 0].mean()


class PolyEncoder(BertPreTrainedModel):
    """Poly-encoder: ``poly_m`` learned codes attend over the context tokens,
    and each candidate then attends over the resulting ``poly_m`` vectors
    before the final dot-product score.

    The encoder and the number of codes are supplied through the ``bert`` and
    ``poly_m`` keyword arguments.
    """

    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.bert = kwargs['bert']
        self.poly_m = kwargs['poly_m']
        self.poly_code_embeddings = nn.Embedding(self.poly_m, config.hidden_size)
        # https://github.com/facebookresearch/ParlAI/blob/master/parlai/agents/transformer/polyencoder.py#L355
        # NOTE(review): the second positional argument of normal_ is the mean, so
        # this draws from N(hidden_size ** -0.5, 1). If a standard deviation of
        # hidden_size ** -0.5 was intended, pass it as std= — confirm before changing.
        torch.nn.init.normal_(self.poly_code_embeddings.weight, config.hidden_size ** -0.5)

    def dot_attention(self, q, k, v):
        """Unscaled dot-product attention of ``q`` over ``k``/``v`` (no masking)."""
        # q: [bs, poly_m, dim] or [bs, res_cnt, dim]
        # k=v: [bs, length, dim] or [bs, poly_m, dim]
        attn_weights = torch.matmul(q, k.transpose(2, 1)) # [bs, poly_m, length]
        attn_weights = F.softmax(attn_weights, -1)
        output = torch.matmul(attn_weights, v) # [bs, poly_m, dim]
        return output

    def forward(self, context_input_ids, context_input_masks,
                            responses_input_ids, responses_input_masks, labels=None):
        """Return a scalar loss when ``labels`` is given, otherwise scores.

        ``responses_input_ids`` / ``responses_input_masks`` are 3-D:
        [bs, res_cnt, seq_length].  With ``labels`` only the first response of
        every sample is used and the other samples in the batch serve as
        negatives; the content of ``labels`` itself is not read.  Without
        ``labels`` the result is a [bs, res_cnt] score matrix.
        """
        # during training, only select the first response
        # we are using other instances in a batch as negative examples
        if labels is not None:
            responses_input_ids = responses_input_ids[:, 0, :].unsqueeze(1)
            responses_input_masks = responses_input_masks[:, 0, :].unsqueeze(1)
        batch_size, res_cnt, seq_length = responses_input_ids.shape # res_cnt is 1 during training

        # context encoder
        ctx_out = self.bert(context_input_ids, context_input_masks)[0]  # [bs, length, dim]
        poly_code_ids = torch.arange(self.poly_m, dtype=torch.long).to(context_input_ids.device)
        poly_code_ids = poly_code_ids.unsqueeze(0).expand(batch_size, self.poly_m)
        poly_codes = self.poly_code_embeddings(poly_code_ids) # [bs, poly_m, dim]
        embs = self.dot_attention(poly_codes, ctx_out, ctx_out) # [bs, poly_m, dim]

        # response encoder
        responses_input_ids = responses_input_ids.view(-1, seq_length)
        responses_input_masks = responses_input_masks.view(-1, seq_length)
        cand_emb = self.bert(responses_input_ids, responses_input_masks)[0][:,0,:] # [bs*res_cnt, dim]
        cand_emb = cand_emb.view(batch_size, res_cnt, -1) # [bs, res_cnt, dim]

        # merge
        if labels is not None:
            # we are recycling responses for faster training
            # we repeat responses for batch_size times to simulate test phase
            # so that every context is paired with batch_size responses
            cand_emb = cand_emb.permute(1, 0, 2) # [1, bs, dim]
            cand_emb = cand_emb.expand(batch_size, batch_size, cand_emb.shape[2]) # [bs, bs, dim]
            # no squeeze here: the result is already [bs, bs, dim], and a bare
            # squeeze() would silently change its rank when bs == 1
            ctx_emb = self.dot_attention(cand_emb, embs, embs) # [bs, bs, dim]
            dot_product = (ctx_emb*cand_emb).sum(-1) # [bs, bs]
            mask = torch.eye(batch_size).to(context_input_ids.device) # [bs, bs]
            loss = F.log_softmax(dot_product, dim=-1) * mask
            loss = (-loss.sum(dim=1)).mean()
            return loss
        else:
            ctx_emb = self.dot_attention(cand_emb, embs, embs) # [bs, res_cnt, dim]
            dot_product = (ctx_emb*cand_emb).sum(-1) # [bs, res_cnt]
            return dot_product

In [None]:
# Configure logging so that only records of level ERROR and above are emitted.
logging.basicConfig(level=logging.ERROR)

def set_seed(args):
    """Seed Python's ``random``, NumPy and PyTorch RNGs with ``args.seed``."""
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

def eval_running_model(dataloader, test=False):
    """Evaluate the global ``model`` on ``dataloader`` and return a metrics dict.

    Relies on the module-level ``model`` and ``device``; when ``test`` is False
    it additionally reads ``tr_loss``, ``nb_tr_steps``, ``epoch`` and
    ``global_step`` to report training progress alongside the metrics.

    The loss uses ``argmax(labels)`` as the target class, while the ranking
    metrics (R1/R2/R5/R10, MRR) count candidate index 0 as the correct one.

    Args:
        dataloader: yields batches of tensors whose last element is the labels
            tensor and whose preceding elements are the model's positional inputs.
        test: if True, leave out the training-progress fields.

    Returns:
        dict with 'eval_loss', 'R1', 'R2', 'R5', 'R10', 'MRR' and, unless
        ``test``, also 'train_loss', 'epoch' and 'global_step'.
    """
    model.eval()
    eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0
    r10 = r2 = r1 = r5 = 0
    mrr = []
    for batch in dataloader:
        batch = tuple(t.to(device) for t in batch)
        # cross batches carry 3 input tensors, the others 4; in both cases the
        # model takes them positionally in batch order and the labels come last
        *model_inputs, labels_batch = batch
        with torch.no_grad():
            logits = model(*model_inputs)
            loss = F.cross_entropy(logits, torch.argmax(labels_batch, 1))

        # topk raises if k exceeds the number of candidates, so clamp k
        n_candidates = logits.size(-1)
        r2_indices = torch.topk(logits, min(2, n_candidates))[1] # R 2 @ 100
        r5_indices = torch.topk(logits, min(5, n_candidates))[1] # R 5 @ 100
        r10_indices = torch.topk(logits, min(10, n_candidates))[1] # R 10 @ 100
        r1 += (logits.argmax(-1) == 0).sum().item()
        r2 += ((r2_indices==0).sum(-1)).sum().item()
        r5 += ((r5_indices==0).sum(-1)).sum().item()
        r10 += ((r10_indices==0).sum(-1)).sum().item()
        # mrr
        logits = logits.data.cpu().numpy()
        for logit in logits:
            y_true = np.zeros(len(logit))
            y_true[0] = 1
            mrr.append(label_ranking_average_precision_score([y_true], [logit]))
        eval_loss += loss.item()
        nb_eval_examples += labels_batch.size(0)
        nb_eval_steps += 1
    eval_loss = eval_loss / nb_eval_steps

    metrics = {
        'eval_loss': eval_loss,
        'R1': r1 / nb_eval_examples,
        'R2': r2 / nb_eval_examples,
        'R5': r5 / nb_eval_examples,
        'R10': r10 / nb_eval_examples,
        'MRR': np.mean(mrr),
    }
    if test:
        return metrics

    # same key order as before, so the logged dict text keeps its layout
    return {
        'train_loss': tr_loss / nb_tr_steps,
        **metrics,
        'epoch': epoch,
        'global_step': global_step,
    }

In [None]:
args = {
        "bert_model": 'ckpt/pretrained/bert-small-uncased',
        "eval": "store_true",
        "model_type": 'bert',
        "output_dir": 'output',
        "train_dir": 'data/ubuntu_data',

        "use_pretrain": "store_true",
        "architecture": 'poly',

        "max_contexts_length": 128,
        "max_response_length": 32,
        "train_batch_size": 32,
        "eval_batch_size": 32,
        "print_freq": 100,

        "poly_m": 0,

        "learning_rate": 5e-5,
        "weight_decay": 0.01,
        "warmup_steps": 100,
        "adam_epsilon": 1e-8,
        "max_grad_norm": 1.0,

        "num_train_epochs": 10.0,
        'seed': 12345,
        'gradient_accumulation_steps': 1,
        "fp16": "store_true",
        "fp16_opt_level": "O1",
        'gpu': 0
        }

In [None]:
from easydict import EasyDict as edict
# Wrap the plain dict so that settings can be read as attributes
# (args.seed, args.gpu, ...), which is how the rest of the code accesses them.
args = edict(args)

In [None]:
# Set CUDA_VISIBLE_DEVICES from the configured GPU index and seed the RNGs
# before any data or model object is created.
os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % args.gpu
set_seed(args)

MODEL_CLASSES = {
    'bert': (BertConfig, BertTokenizerFast, BertModel),
}

ConfigClass, TokenizerClass, BertModelClass = MODEL_CLASSES[args.model_type]

## init dataset and bert model
tokenizer = TokenizerClass.from_pretrained(
    args.bert_model,
    vocab_file=os.path.join(args.bert_model, "vocab.txt"),
    do_lower_case=True,
    clean_text=False,
)
context_transform = SelectionJoinTransform(
    tokenizer=tokenizer, max_len=args.max_contexts_length)
response_transform = SelectionSequentialTransform(
    tokenizer=tokenizer, max_len=args.max_response_length)
# the concatenated input gets the combined context + response budget
concat_transform = SelectionConcatTransform(
    tokenizer=tokenizer, max_len=args.max_response_length+args.max_contexts_length)

print('=' * 80)
print('Train dir:', args.train_dir)
print('Output dir:', args.output_dir)
print('=' * 80)

if args.eval:  # test
    val_dataset = SelectionDataset(
        os.path.join(args.train_dir, 'test.txt'),
        context_transform, response_transform, concat_transform,
        sample_cnt=None, mode=args.architecture)
else:
    train_dataset = SelectionDataset(
        os.path.join(args.train_dir, 'train.txt'),
        context_transform, response_transform, concat_transform,
        sample_cnt=None, mode=args.architecture)
    # validation during training is capped at 1000 samples
    val_dataset = SelectionDataset(
        os.path.join(args.train_dir, 'dev.txt'),
        context_transform, response_transform, concat_transform,
        sample_cnt=1000, mode=args.architecture)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        collate_fn=train_dataset.batchify_join_str,
        shuffle=True,
        num_workers=0)
    # total number of optimizer steps, used by the LR schedule
    t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

val_dataloader = DataLoader(
    val_dataset,
    batch_size=args.eval_batch_size,
    collate_fn=val_dataset.batchify_join_str,
    shuffle=False,
    num_workers=0)


# Training bookkeeping shared by the training loop and the evaluator.
epoch_start = 1
global_step = 0
best_eval_loss = float('inf')
best_test_loss = float('inf')

# Create the output directory (no error if it already exists) and keep a copy
# of the vocabulary and model config next to the checkpoints.
os.makedirs(args.output_dir, exist_ok=True)
shutil.copyfile(os.path.join(args.bert_model, 'vocab.txt'), os.path.join(args.output_dir, 'vocab.txt'))
shutil.copyfile(os.path.join(args.bert_model, 'config.json'), os.path.join(args.output_dir, 'config.json'))
# Append-mode log; the run configuration is recorded first.
log_wf = open(os.path.join(args.output_dir, 'log.txt'), 'a', encoding='utf-8')
print (args, file=log_wf)

# Checkpoint file name encodes the architecture and the number of poly codes.
state_save_path = os.path.join(args.output_dir, '{}_{}_pytorch_model.bin'.format(args.architecture, args.poly_m))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

########################################
## build BERT encoder
########################################
bert_config = ConfigClass.from_json_file(os.path.join(args.bert_model, 'config.json'))
if args.use_pretrain and not args.eval:
    # start training from the pretrained weights stored next to the config
    previous_model_file = os.path.join(args.bert_model, "pytorch_model.bin")
    print('Loading parameters from', previous_model_file)
    log_wf.write('Loading parameters from %s' % previous_model_file + '\n')
    model_state_dict = torch.load(previous_model_file, map_location="cpu")
    bert = BertModelClass.from_pretrained(args.bert_model, state_dict=model_state_dict)
    del model_state_dict
else:
    # built from the config only; in eval mode the weights are loaded from
    # state_save_path below
    bert = BertModelClass(bert_config)

if args.architecture == 'poly':
    model = PolyEncoder(bert_config, bert=bert, poly_m=args.poly_m)
elif args.architecture == 'bi':
    model = BiEncoder(bert_config, bert=bert)
elif args.architecture == 'cross':
    model = CrossEncoder(bert_config, bert=bert)
else:
    raise Exception('Unknown architecture.')
# the transforms add a token to the tokenizer, so the embedding table is
# resized to the tokenizer's current length
model.resize_token_embeddings(len(tokenizer))
model.to(device)

if args.eval:
    print('Loading parameters from', state_save_path)
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host
    model.load_state_dict(torch.load(state_save_path, map_location=device))
    test_result = eval_running_model(val_dataloader, test=True)
    print (test_result)
    # nothing more is written in eval mode; release the log file before leaving
    log_wf.close()
    exit()
    
# Parameters whose name contains one of these substrings get no weight decay.
no_decay = ["bias", "LayerNorm.weight"]

optimizer_grouped_parameters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": args.weight_decay,
    },
    {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
    optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
)
if args.fp16:
    try:
        from apex import amp
    except ImportError:
        raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
    model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

# Both frequencies are used as modulo divisors in the training loop, so they
# are clamped to at least 1: a very small dataset or a large accumulation
# factor would otherwise floor them to 0 and raise ZeroDivisionError.
print_freq = max(1, args.print_freq//args.gradient_accumulation_steps)
eval_freq = min(len(train_dataloader) // 2, 1000)
eval_freq = max(1, eval_freq//args.gradient_accumulation_steps)
print('Print freq:', print_freq, "Eval freq:", eval_freq)

# Main training loop: one pass over train_dataloader per epoch, with periodic
# validation every eval_freq optimizer steps and once more after each epoch.
# The checkpoint at state_save_path is overwritten whenever the validation
# loss improves.
for epoch in range(epoch_start, int(args.num_train_epochs) + 1):
    tr_loss = 0
    nb_tr_steps = 0
    with tqdm(total=len(train_dataloader)//args.gradient_accumulation_steps) as bar:
        for step, batch in enumerate(train_dataloader):
            model.train()
            # Gradients are NOT cleared here: clearing them on every micro-step
            # would throw away what earlier micro-batches contributed and defeat
            # gradient accumulation. They are cleared right after each
            # optimizer step below.
            batch = tuple(t.to(device) for t in batch)
            if args.architecture == 'cross':
                text_token_ids_list_batch, text_input_masks_list_batch, text_segment_ids_list_batch, labels_batch = batch
                loss = model(text_token_ids_list_batch, text_input_masks_list_batch, text_segment_ids_list_batch, labels_batch)
            else:
                context_token_ids_list_batch, context_input_masks_list_batch, \
                response_token_ids_list_batch, response_input_masks_list_batch, labels_batch = batch
                loss = model(context_token_ids_list_batch, context_input_masks_list_batch,
                                        response_token_ids_list_batch, response_input_masks_list_batch,
                                        labels_batch)

            # scale so the accumulated gradient is an average over micro-batches
            loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                nb_tr_steps += 1
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if nb_tr_steps and nb_tr_steps % print_freq == 0:
                    bar.update(min(print_freq, nb_tr_steps))
                    time.sleep(0.02)
                    print(global_step, tr_loss / nb_tr_steps)
                    log_wf.write('%d\t%f\n' % (global_step, tr_loss / nb_tr_steps))

                if global_step and global_step % eval_freq == 0:
                    val_result = eval_running_model(val_dataloader)
                    print('Global Step %d VAL res:\n' % global_step, val_result)
                    log_wf.write('Global Step %d VAL res:\n' % global_step)
                    log_wf.write(str(val_result) + '\n')

                    if val_result['eval_loss'] < best_eval_loss:
                        best_eval_loss = val_result['eval_loss']
                        val_result['best_eval_loss'] = best_eval_loss
                        # save model
                        print('[Saving at]', state_save_path)
                        log_wf.write('[Saving at] %s\n' % state_save_path)
                        torch.save(model.state_dict(), state_save_path)
            log_wf.flush()

    # add a eval step after each epoch
    val_result = eval_running_model(val_dataloader)
    print('Epoch %d, Global Step %d VAL res:\n' % (epoch, global_step), val_result)
    log_wf.write('Global Step %d VAL res:\n' % global_step)
    log_wf.write(str(val_result) + '\n')

    if val_result['eval_loss'] < best_eval_loss:
        best_eval_loss = val_result['eval_loss']
        val_result['best_eval_loss'] = best_eval_loss
        # save model
        print('[Saving at]', state_save_path)
        log_wf.write('[Saving at] %s\n' % state_save_path)
        torch.save(model.state_dict(), state_save_path)
    print(global_step, tr_loss / nb_tr_steps)
    log_wf.write('%d\t%f\n' % (global_step, tr_loss / nb_tr_steps))
    # make the end-of-epoch records reach the file even if the run stops here
    log_wf.flush()

### Run.py BK

In [None]:
# Configure logging so that only records of level ERROR and above are emitted.
logging.basicConfig(level=logging.ERROR)

def set_seed(args):
    """Seed Python's ``random``, NumPy and PyTorch RNGs with ``args.seed``."""
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

def eval_running_model(dataloader, test=False):
    """Evaluate the module-level ``model`` on ``dataloader`` and return metrics.

    This function reads module-level state created by the training script:
    ``model`` and ``device`` always, plus ``tr_loss``, ``nb_tr_steps``,
    ``epoch`` and ``global_step`` when ``test`` is false.

    Args:
        dataloader: iterable of batches; each batch is a sequence of tensors
            whose last element is the label tensor and whose preceding
            elements are passed positionally to ``model``.
        test: when true, only evaluation metrics are returned; when false the
            running training loss, epoch and global step are included too.

    Returns:
        dict with keys ``'eval_loss'``, ``'R1'``, ``'R2'``, ``'R5'``,
        ``'R10'`` and ``'MRR'``; when ``test`` is false it additionally holds
        ``'train_loss'`` (first), ``'epoch'`` and ``'global_step'`` (last).

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no batches, or if
            ``test`` is false and ``nb_tr_steps`` is 0.
    """
    model.eval()
    eval_loss = 0
    nb_eval_steps, nb_eval_examples = 0, 0
    r1 = 0
    # Hit counters for recall@k, keyed by k.
    topk_hits = {2: 0, 5: 0, 10: 0}
    mrr = []
    for batch in dataloader:
        batch = tuple(t.to(device) for t in batch)
        # Labels are the last tensor for every architecture; everything before
        # them (3 tensors for 'cross', 4 otherwise) is the model input.
        *model_inputs, labels_batch = batch
        with torch.no_grad():
            logits = model(*model_inputs)
            loss = F.cross_entropy(logits, torch.argmax(labels_batch, 1))

        # NOTE(review): the ranking metrics below treat candidate index 0 as
        # the correct one, while the loss uses argmax(labels). Presumably the
        # dataset always places the positive first -- confirm against the
        # dataset's collate function.
        r1 += (logits.argmax(-1) == 0).sum().item()
        num_candidates = logits.size(-1)
        for k in topk_hits:
            # Clamp k: torch.topk raises when k exceeds the candidate count.
            top_indices = torch.topk(logits, min(k, num_candidates))[1]
            topk_hits[k] += (top_indices == 0).sum().item()

        # With a single relevant label per row, label ranking average
        # precision is used here as the per-example reciprocal-rank score.
        for logit in logits.detach().cpu().numpy():
            y_true = np.zeros(len(logit))
            y_true[0] = 1
            mrr.append(label_ranking_average_precision_score([y_true], [logit]))

        eval_loss += loss.item()
        nb_eval_examples += labels_batch.size(0)
        nb_eval_steps += 1

    eval_loss = eval_loss / nb_eval_steps
    result = {
        'eval_loss': eval_loss,
        'R1': r1 / nb_eval_examples,
        'R2': topk_hits[2] / nb_eval_examples,
        'R5': topk_hits[5] / nb_eval_examples,
        'R10': topk_hits[10] / nb_eval_examples,
        'MRR': np.mean(mrr),
    }
    if not test:
        # Key order matters: the dict is written to the log via str().
        result = {
            'train_loss': tr_loss / nb_tr_steps,
            **result,
            'epoch': epoch,
            'global_step': global_step,
        }

    return result


if __name__ == '__main__':
    # Script entry point: parse options, build tokenizer/datasets/model, then
    # either evaluate a saved checkpoint (--eval) or train with periodic
    # validation, saving the weights whenever the validation loss improves.
    #
    # Everything is kept at module level on purpose: eval_running_model()
    # reads model, device, args, tr_loss, nb_tr_steps, epoch and global_step
    # as module globals.
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--bert_model", default='ckpt/pretrained/bert-small-uncased', type=str)
    parser.add_argument("--eval", action="store_true")
    parser.add_argument("--model_type", default='bert', type=str)
    parser.add_argument("--output_dir", required=True, type=str)
    parser.add_argument("--train_dir", default='data/ubuntu_data', type=str)

    parser.add_argument("--use_pretrain", action="store_true")
    parser.add_argument("--architecture", required=True, type=str, help='[poly, bi, cross]')

    parser.add_argument("--max_contexts_length", default=128, type=int)
    parser.add_argument("--max_response_length", default=32, type=int)
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--eval_batch_size", default=32, type=int, help="Total batch size for eval.")
    parser.add_argument("--print_freq", default=100, type=int, help="Log frequency")

    parser.add_argument("--poly_m", default=0, type=int, help="Number of m of polyencoder")

    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.01, type=float)
    parser.add_argument("--warmup_steps", default=100, type=float)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")

    parser.add_argument("--num_train_epochs", default=10.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument('--seed', type=int, default=12345, help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit",
    )
    parser.add_argument(
        "--fp16_opt_level",
        type=str,
        default="O1",
        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                  "See details at https://nvidia.github.io/apex/amp.html",
    )
    parser.add_argument('--gpu', type=int, default=0)
    args = parser.parse_args()
    print(args)
    # Restrict CUDA to the selected GPU index.
    os.environ["CUDA_VISIBLE_DEVICES"] = "%d" % args.gpu
    set_seed(args)

    MODEL_CLASSES = {
        'bert': (BertConfig, BertTokenizerFast, BertModel),
    }
    ConfigClass, TokenizerClass, BertModelClass = MODEL_CLASSES[args.model_type]

    ## init dataset and bert model
    tokenizer = TokenizerClass.from_pretrained(args.bert_model, vocab_file=os.path.join(args.bert_model, "vocab.txt"), do_lower_case=True, clean_text=False)
    context_transform = SelectionJoinTransform(tokenizer=tokenizer, max_len=args.max_contexts_length)
    response_transform = SelectionSequentialTransform(tokenizer=tokenizer, max_len=args.max_response_length)
    concat_transform = SelectionConcatTransform(tokenizer=tokenizer, max_len=args.max_response_length+args.max_contexts_length)

    print('=' * 80)
    print('Train dir:', args.train_dir)
    print('Output dir:', args.output_dir)
    print('=' * 80)

    if not args.eval:
        train_dataset = SelectionDataset(os.path.join(args.train_dir, 'train.txt'),
                                         context_transform, response_transform, concat_transform, sample_cnt=None, mode=args.architecture)
        val_dataset = SelectionDataset(os.path.join(args.train_dir, 'dev.txt'),
                                       context_transform, response_transform, concat_transform, sample_cnt=1000, mode=args.architecture)
        train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size, collate_fn=train_dataset.batchify_join_str, shuffle=True, num_workers=0)
        # Total number of optimizer updates, used by the LR schedule.
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
    else: # test
        val_dataset = SelectionDataset(os.path.join(args.train_dir, 'test.txt'),
                                       context_transform, response_transform, concat_transform, sample_cnt=None, mode=args.architecture)

    val_dataloader = DataLoader(val_dataset, batch_size=args.eval_batch_size, collate_fn=val_dataset.batchify_join_str, shuffle=False, num_workers=0)


    epoch_start = 1
    global_step = 0
    best_eval_loss = float('inf')
    best_test_loss = float('inf')

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    # Keep the vocabulary and model config next to the saved weights.
    shutil.copyfile(os.path.join(args.bert_model, 'vocab.txt'), os.path.join(args.output_dir, 'vocab.txt'))
    shutil.copyfile(os.path.join(args.bert_model, 'config.json'), os.path.join(args.output_dir, 'config.json'))
    log_wf = open(os.path.join(args.output_dir, 'log.txt'), 'a', encoding='utf-8')
    print (args, file=log_wf)

    state_save_path = os.path.join(args.output_dir, '{}_{}_pytorch_model.bin'.format(args.architecture, args.poly_m))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ########################################
    ## build BERT encoder
    ########################################
    bert_config = ConfigClass.from_json_file(os.path.join(args.bert_model, 'config.json'))
    if args.use_pretrain and not args.eval:
        previous_model_file = os.path.join(args.bert_model, "pytorch_model.bin")
        print('Loading parameters from', previous_model_file)
        log_wf.write('Loading parameters from %s' % previous_model_file + '\n')
        model_state_dict = torch.load(previous_model_file, map_location="cpu")
        bert = BertModelClass.from_pretrained(args.bert_model, state_dict=model_state_dict)
        del model_state_dict
    else:
        # Build the encoder from the config only (no pretrained weights loaded here).
        bert = BertModelClass(bert_config)

    if args.architecture == 'poly':
        model = PolyEncoder(bert_config, bert=bert, poly_m=args.poly_m)
    elif args.architecture == 'bi':
        model = BiEncoder(bert_config, bert=bert)
    elif args.architecture == 'cross':
        model = CrossEncoder(bert_config, bert=bert)
    else:
        raise Exception('Unknown architecture.')
    model.resize_token_embeddings(len(tokenizer))
    model.to(device)

    if args.eval:
        print('Loading parameters from', state_save_path)
        # map_location lets a checkpoint saved on GPU be loaded on a CPU-only host.
        model.load_state_dict(torch.load(state_save_path, map_location=device))
        test_result = eval_running_model(val_dataloader, test=True)
        print (test_result)
        log_wf.close()
        exit()

    # Biases and LayerNorm weights are excluded from weight decay.
    no_decay = ["bias", "LayerNorm.weight"]

    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {"params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
    )
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # Both frequencies are expressed in optimizer updates and used as modulo
    # divisors below, so they are clamped to at least 1 to avoid a
    # ZeroDivisionError with tiny datasets or large accumulation factors.
    print_freq = max(1, args.print_freq // args.gradient_accumulation_steps)
    eval_freq = min(len(train_dataloader) // 2, 1000)
    eval_freq = max(1, eval_freq // args.gradient_accumulation_steps)
    print('Print freq:', print_freq, "Eval freq:", eval_freq)

    for epoch in range(epoch_start, int(args.num_train_epochs) + 1):
        tr_loss = 0
        nb_tr_steps = 0
        # Start every epoch with clean gradients. Gradients must NOT be zeroed
        # on every micro-batch, otherwise accumulation across
        # gradient_accumulation_steps micro-batches is lost.
        optimizer.zero_grad()
        with tqdm(total=len(train_dataloader)//args.gradient_accumulation_steps) as bar:
            for step, batch in enumerate(train_dataloader):
                # eval_running_model() switches to eval mode, so re-enable
                # training mode on every step.
                model.train()
                batch = tuple(t.to(device) for t in batch)
                if args.architecture == 'cross':
                    text_token_ids_list_batch, text_input_masks_list_batch, text_segment_ids_list_batch, labels_batch = batch
                    loss = model(text_token_ids_list_batch, text_input_masks_list_batch, text_segment_ids_list_batch, labels_batch)
                else:
                    context_token_ids_list_batch, context_input_masks_list_batch, \
                    response_token_ids_list_batch, response_input_masks_list_batch, labels_batch = batch
                    loss = model(context_token_ids_list_batch, context_input_masks_list_batch,
                                 response_token_ids_list_batch, response_input_masks_list_batch,
                                 labels_batch)

                # Scale so the accumulated gradient is the mean over micro-batches.
                loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                tr_loss += loss.item()

                # Perform an optimizer update once enough micro-batches have
                # been accumulated.
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    nb_tr_steps += 1
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
                    global_step += 1

                    if nb_tr_steps and nb_tr_steps % print_freq == 0:
                        bar.update(min(print_freq, nb_tr_steps))
                        time.sleep(0.02)
                        print(global_step, tr_loss / nb_tr_steps)
                        log_wf.write('%d\t%f\n' % (global_step, tr_loss / nb_tr_steps))

                    if global_step and global_step % eval_freq == 0:
                        val_result = eval_running_model(val_dataloader)
                        print('Global Step %d VAL res:\n' % global_step, val_result)
                        log_wf.write('Global Step %d VAL res:\n' % global_step)
                        log_wf.write(str(val_result) + '\n')

                        if val_result['eval_loss'] < best_eval_loss:
                            best_eval_loss = val_result['eval_loss']
                            val_result['best_eval_loss'] = best_eval_loss
                            # save model
                            print('[Saving at]', state_save_path)
                            log_wf.write('[Saving at] %s\n' % state_save_path)
                            torch.save(model.state_dict(), state_save_path)
                log_wf.flush()

        # add a eval step after each epoch
        val_result = eval_running_model(val_dataloader)
        print('Epoch %d, Global Step %d VAL res:\n' % (epoch, global_step), val_result)
        log_wf.write('Global Step %d VAL res:\n' % global_step)
        log_wf.write(str(val_result) + '\n')

        if val_result['eval_loss'] < best_eval_loss:
            best_eval_loss = val_result['eval_loss']
            val_result['best_eval_loss'] = best_eval_loss
            # save model
            print('[Saving at]', state_save_path)
            log_wf.write('[Saving at] %s\n' % state_save_path)
            torch.save(model.state_dict(), state_save_path)
        print(global_step, tr_loss / nb_tr_steps)
        log_wf.write('%d\t%f\n' % (global_step, tr_loss / nb_tr_steps))

    # Release the log file handle once training has finished.
    log_wf.close()