Commit
Merge pull request #14 from infinitylogesh/main
Training script for multi context training from ConveRT
nreimers committed Jul 3, 2021
2 parents 9d5da09 + bfa9b91 commit 049d30a
Showing 2 changed files with 249 additions and 0 deletions.
248 changes: 248 additions & 0 deletions conversational-model/multi_context_train.py
@@ -0,0 +1,248 @@
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, List, Union

import jax
from jax.config import config
config.update('jax_enable_x64', True)
import jax.numpy as jnp
import optax
from flax import jax_utils, struct, traverse_util
from flax.training import train_state
from flax.training.common_utils import shard
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from transformers import AutoTokenizer, FlaxBertModel
from datasets import load_dataset




@dataclass
class TrainingArgs:
model_id: str = "bert-base-uncased"
max_epochs: int = 5
batch_size: int = 2
seed: int = 42
lr: float = 2e-5
init_lr: float = 1e-5
warmup_steps: int = 2000
weight_decay: float = 1e-3

current_context_maxlen: int = 128
past_context_maxlen: int = 128
    response_maxlen: int = 128
    max_past_contexts: int = 10

tr_data_files: List[str] = field(
default_factory=lambda: [
"data/dummy.jsonl",
]
)


def scheduler_fn(lr, init_lr, warmup_steps, num_train_steps):
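    """Linear warmup from init_lr to lr over warmup_steps, followed by a linear decay towards zero."""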
decay_steps = num_train_steps - warmup_steps
warmup_fn = optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps)
decay_fn = optax.linear_schedule(init_value=lr, end_value=1e-7, transition_steps=decay_steps)
lr = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])
return lr


def build_tx(lr, init_lr, warmup_steps, num_train_steps, weight_decay):
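    """Build an AdamW optimizer with the warmup/decay schedule; bias and LayerNorm scale parameters are masked out of weight decay."""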
def weight_decay_mask(params):
params = traverse_util.flatten_dict(params)
mask = {k: (v[-1] != "bias" and v[-2:] != ("LayerNorm", "scale")) for k, v in params.items()}
return traverse_util.unflatten_dict(mask)
lr = scheduler_fn(lr, init_lr, warmup_steps, num_train_steps)
tx = optax.adamw(learning_rate=lr, weight_decay=weight_decay, mask=weight_decay_mask)
return tx, lr


class TrainState(train_state.TrainState):
loss_fn: Callable = struct.field(pytree_node=False)
scheduler_fn: Callable = struct.field(pytree_node=False)


def multiple_negative_ranking_loss(embedding1, embedding2):
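    """
    Multiple-negatives ranking loss: each example in the batch treats every other
    example as an in-batch negative. The pairwise similarity matrix is scored with
    cross-entropy against the diagonal (the true pairs).
    """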
def _cross_entropy(logits):
bsz = logits.shape[-1]
labels = (jnp.arange(bsz)[..., None] == jnp.arange(bsz)[None]).astype("f4")
logits = jax.nn.log_softmax(logits, axis=-1)
loss = -jnp.sum(labels * logits, axis=-1)
return loss

batch_similarity = jnp.dot(embedding1, jnp.transpose(embedding2))
return _cross_entropy(batch_similarity)


@partial(jax.pmap, axis_name="batch")
def train_step(state, current_context_input, response_input, past_context_input, drp_rng):
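    """
    One pmapped training step: each device embeds its shard of the batch, the
    embeddings are all-gathered, and the ranking losses are computed over the
    global batch before the gradient update.
    """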
train = True
new_drp_rng, drp_rng = jax.random.split(drp_rng, 2)

def loss_fn(params, current_context_input, response_input, past_context_input, drp_rng):
def _forward(model_input):
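            # encode, mean-pool the token embeddings (padding masked out), and L2-normalize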
attention_mask = model_input["attention_mask"][..., None]
embedding = state.apply_fn(**model_input, params=params, train=train, dropout_rng=drp_rng)[0]
attention_mask = jnp.broadcast_to(attention_mask, jnp.shape(embedding))

embedding = embedding * attention_mask
embedding = jnp.mean(embedding, axis=1)

            # L2-normalize so the dot product used in the loss is the cosine similarity
            norm = jnp.sqrt(jnp.sum(jnp.square(embedding), axis=-1, keepdims=True))
            embedding = embedding / jnp.maximum(norm, 1e-12)

            # gather the embeddings from all devices so the loss is computed over the global batch
embedding = jax.lax.all_gather(embedding, axis_name="batch")
embedding = jnp.reshape(embedding, (-1, embedding.shape[-1]))

return embedding

        current_context_emb, response_emb, past_context_emb = _forward(current_context_input), _forward(response_input), _forward(past_context_input)
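        # the full-context embedding is the average of the current-context and past-context embeddings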
        full_context_emb = (current_context_emb + past_context_emb) / 2
        current_context_response_loss = state.loss_fn(current_context_emb, response_emb)
        past_context_response_loss = state.loss_fn(past_context_emb, response_emb)
        full_context_response_loss = state.loss_fn(full_context_emb, response_emb)
# loss considering
# 1) the interaction between the immediate context and its accompanying response,
# 2) the interaction of the response with up to N past contexts from the conversation history,
# as well as 3) the interaction of the full context with the response
loss = (current_context_response_loss + past_context_response_loss + full_context_response_loss) / 3
return jnp.mean(loss)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params, current_context_input, response_input, past_context_input, drp_rng)
    # average the gradients across devices so the replicated parameters stay in sync
    grads = jax.lax.pmean(grads, axis_name="batch")
    state = state.apply_gradients(grads=grads)

step = jax.lax.pmean(state.step, axis_name="batch")
metrics = {"tr_loss": loss, "lr": state.scheduler_fn(step)}

return state, metrics, new_drp_rng


def get_batched_dataset(dataset, batch_size, seed=None):
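    """Yield consecutive batches of batch_size examples, shuffling first when a seed is given; the trailing partial batch is dropped."""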
if seed is not None:
dataset = dataset.shuffle(seed=seed)
for i in range(len(dataset) // batch_size):
batch = dataset[i*batch_size: (i+1)*batch_size]
yield dict(batch)


@dataclass
class DataCollator:
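    """Tokenizes the current context, the response, and the concatenated past contexts of a batch, then shards the tensors across devices."""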
tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer]
current_context_maxlen: int = 128
response_maxlen: int = 128
past_context_maxlen: int = 128
    max_past_contexts: int = 10

    def _prepare_past_context(self, batch):
        """
        Concatenate the past contexts of each example; the contexts are ordered so that the most recent one comes first.
        """
past_contexts = []
keys = list(batch.keys())
keys.sort()
past_context_tuples = zip(*[batch[key] for key in keys if key.startswith("context/")])
for past_context_tuple in past_context_tuples:
past_context_tuple = tuple(p_ctxt for p_ctxt in past_context_tuple if p_ctxt is not None)
past_contexts.append(" ".join(past_context_tuple[:self.max_past_contexts]))
return past_contexts

def __call__(self, batch):
        # Currently uses static padding; TODO: add dynamic padding support below
past_contexts = self._prepare_past_context(batch)
current_context_input = self.tokenizer(batch["context"], return_tensors="jax", max_length=self.current_context_maxlen, truncation=True, padding="max_length")
response_input = self.tokenizer(batch["response"], return_tensors="jax", max_length=self.response_maxlen, truncation=True, padding="max_length")
        past_context_input = self.tokenizer(past_contexts, return_tensors="jax", max_length=self.past_context_maxlen, truncation=True, padding="max_length")
        current_context_input, response_input, past_context_input = dict(current_context_input), dict(response_input), dict(past_context_input)
return shard(current_context_input), shard(response_input), shard(past_context_input)


def main(args):
    # the code is generic and works with any other Flax transformer model as well
model = FlaxBertModel.from_pretrained(args.model_id)
tokenizer = AutoTokenizer.from_pretrained(args.model_id)

data_collator = DataCollator(
tokenizer=tokenizer,
current_context_maxlen=args.current_context_maxlen,
response_maxlen=args.response_maxlen,
past_context_maxlen=args.past_context_maxlen,
        max_past_contexts=args.max_past_contexts,
)

tr_dataset = load_dataset("json", data_files=args.tr_data_files, split="train")
columns_to_remove = ['response_author', 'context_author', 'subreddit', 'thread_id']
tr_dataset = tr_dataset.remove_columns(columns_to_remove)
    # drop leftover samples so the number of examples is a multiple of the batch size
num_tr_samples = len(tr_dataset) - len(tr_dataset) % args.batch_size
tr_dataset = tr_dataset.shuffle(seed=args.seed).select(range(num_tr_samples))
print(tr_dataset)

tx_args = {
"lr": args.lr,
"init_lr": args.init_lr,
"warmup_steps": args.warmup_steps,
"num_train_steps": (len(tr_dataset) // args.batch_size) * args.max_epochs,
"weight_decay": args.weight_decay,
}
tx, lr = build_tx(**tx_args)

state = TrainState.create(
apply_fn=model.__call__,
params=model.params,
tx=tx,
loss_fn=multiple_negative_ranking_loss,
scheduler_fn=lr,
)
state = jax_utils.replicate(state)

rng = jax.random.PRNGKey(args.seed)
drp_rng = jax.random.split(rng, jax.device_count())
for epoch in range(args.max_epochs):
# training step
batch_iterator = get_batched_dataset(tr_dataset, args.batch_size, seed=epoch)
for batch in tqdm(batch_iterator, desc=f"Running epoch-{epoch}"):
            current_context_input, response_input, past_context_input = data_collator(batch)
            state, metrics, drp_rng = train_step(state, current_context_input, response_input, past_context_input, drp_rng)
print(metrics)

# evaluation step
# for batch in get_batched_dataset(val_dataset, args.batch_size, seed=None):
# model_input1, model_input2 = data_collator(batch)
# state, metric = val_step(state, model_input1, model_input2)

if __name__ == "__main__":
    import json, os

jsons = [{'context': 'Taste good though. ',
'context/0': 'Deer are, in fact, massive ****.',
'context/1': "Freaking deer. You can always tell the country people. They can't stand deer. Just giant garden destroying rats. \n\nHell, that's",
'context/3': 'I lived in Germany when I was 5. I have this memory of my dad stopping to pick up a hedgehog that was in the middle of the road',
'context/2': "Kinda like when visitors from more populated areas see deer here. They're not at all rare but people act like they are",
'context_author': 'KakarotMaag',
'response': "Ground venison mixed with beef fat is the best burger I've ever had.",
'response_author': 'SG804',
'subreddit': 'gifs',
'thread_id': '98fur0'},{
'context/1': "Hello, how are you?",
'context/0': "I am fine. And you?",
'context': "Great. What do you think of the weather?",
'response': "It doesn't feel like February.",
'context_author': 'KakarotMaag',
'response_author': 'SG804',
'subreddit': 'gifs',
'thread_id': '98fur0'
}]

os.makedirs("data/",exist_ok=True)
with open('data/dummy.jsonl', 'w') as outfile:
for entry in jsons:
json.dump(entry, outfile)
outfile.write('\n')

args = TrainingArgs()
main(args)

1 change: 1 addition & 0 deletions pyproject.toml
@@ -21,6 +21,7 @@
sentence_transformers = "*"
py7zr = "*"
tqdm = "*"
loguru = "*"
datasets = "^1.8.0"

[tool.poetry.dev-dependencies]
pytest = "^3.4"
