In [1]:
import sys  
sys.path.insert(0, '..')

In [2]:
import jax

In [3]:
# Number of devices visible to JAX; TrainingArgs multiplies batch_size_per_device by this.
jax.device_count()

8

In [4]:
# requirements
from sklearn.model_selection import train_test_split
import gzip
from tqdm import tqdm
import numpy as np
import csv
import wandb
import json
import os

# for training script
from dataclasses import dataclass, field, asdict, replace
from functools import partial
from typing import Callable, List, Union

import jax
import jax.numpy as jnp
import optax
from flax import jax_utils, struct, traverse_util
from flax.training import train_state
from flax.serialization import to_bytes, from_bytes
from flax.training.common_utils import shard
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
from trainer.loss.custom import multiple_negatives_ranking_loss


from datasets import load_dataset, DatasetDict
from transformers import AutoTokenizer, FlaxAutoModel
import datasets


from trainer.utils.ops import normalize_L2, mean_pooling, cos_sim



In [30]:
@dataclass
class TrainingArgs:
    """Hyper-parameters and paths for the CodeSearchNet bi-encoder training run.

    ``batch_size`` is not a dataclass field: ``__post_init__`` derives it as the global
    batch size (``batch_size_per_device`` times the number of JAX devices), so it is
    recomputed by ``dataclasses.replace`` and is absent from ``asdict``.
    """

    # model / optimisation
    model_id: str = "microsoft/codebert-base"
    max_epochs: int = 20
    batch_size_per_device: int = 32
    seed: int = 42
    lr: float = 2e-5  # peak LR, reached at the end of warmup
    init_lr: float = 1e-5  # LR at step 0 of the warmup
    warmup_steps: int = 2000
    weight_decay: float = 1e-3

    # tokenizer max lengths for input1 (docstring) and input2 (code)
    input1_maxlen: int = 200
    input2_maxlen: int = 200

    # logging / checkpointing
    logging_steps: int = 20
    save_dir: str = "checkpoints"
    save_dir_exp: str = "checkpoints"

    # NOTE(review): prepare_dataset does not read these two lists; it loads fixed CSV paths.
    tr_data_files: List[str] = field(default_factory=lambda: ["tr.csv"])
    val_data_files: List[str] = field(default_factory=lambda: ["val.csv"])

    def __post_init__(self):
        n_devices = jax.device_count()
        self.batch_size = n_devices * self.batch_size_per_device


In [8]:
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with the loss, LR-schedule and accuracy callables.

    The extra fields use ``struct.field(pytree_node=False)``, so they are not pytree
    leaves: they are kept as static metadata rather than replicated/traced as arrays.
    """

    # Called as state.loss_fn(embedding1, embedding2) in train_step and val_step.
    loss_fn: Callable = struct.field(pytree_node=False)
    # LR schedule; train_step calls state.scheduler_fn(step) only to report the LR metric.
    scheduler_fn: Callable = struct.field(pytree_node=False)
    # Called as state.acc_fn(embedding1, embedding2) in val_step.
    acc_fn: Callable = struct.field(pytree_node=False)


def warmup_and_constant(lr, init_lr, warmup_steps):
    """Linear warmup from ``init_lr`` to ``lr`` over ``warmup_steps`` steps, then hold ``lr``.

    Returns an optax schedule: a callable mapping a step count to a learning rate.
    """
    pieces = [
        optax.linear_schedule(init_value=init_lr, end_value=lr, transition_steps=warmup_steps),
        optax.constant_schedule(value=lr),
    ]
    return optax.join_schedules(schedules=pieces, boundaries=[warmup_steps])

def build_tx(lr, init_lr, warmup_steps, weight_decay):
    """Build an AdamW optimizer with a warmup-then-constant learning-rate schedule.

    Weight decay is applied to every parameter except biases and LayerNorm scales.

    Args:
        lr: peak learning rate, reached after warmup and held afterwards.
        init_lr: learning rate at step 0.
        warmup_steps: number of linear-warmup steps.
        weight_decay: AdamW decoupled weight-decay coefficient.

    Returns:
        (tx, lr_schedule): the optax GradientTransformation and the schedule callable
        (step -> learning rate) it was built with.
    """

    def weight_decay_mask(params):
        flat_params = traverse_util.flatten_dict(params)
        # Decide on the parameter *path* (the flattened dict key, a tuple of names),
        # not on the array value: indexing the value (v[-1], v[-2:]) reads tensor
        # elements, so the "bias"/LayerNorm checks never applied and every parameter
        # was decayed.
        mask = {
            path: (path[-1] != "bias" and path[-2:] != ("LayerNorm", "scale"))
            for path in flat_params
        }
        return traverse_util.unflatten_dict(mask)

    lr_schedule = warmup_and_constant(lr, init_lr, warmup_steps)
    tx = optax.adamw(learning_rate=lr_schedule, weight_decay=weight_decay, mask=weight_decay_mask)
    return tx, lr_schedule


In [9]:
@partial(jax.jit, static_argnames="similarity_fct")
def batch_accuracy(embeddings_a: jnp.ndarray, embeddings_b: jnp.ndarray,
                   similarity_fct=cos_sim):
    """Count in-batch retrieval hits.

    Row ``i`` of ``embeddings_a`` is a hit when its highest-scoring row of
    ``embeddings_b`` is row ``i`` (its paired positive).

    :param embeddings_a: query embeddings, one row per example.
    :param embeddings_b: candidate embeddings whose first ``len(embeddings_a)`` rows are the
        positives. If passing additional hard negatives, use
        jnp.concatenate([positives, negatives], axis=0) as input.
    :param similarity_fct: callable returning a ``(len(embeddings_a), len(embeddings_b))``
        score matrix. It is a static argument, so a custom (hashable) function can be
        passed to the jitted function.
    :return: scalar integer array with the number of hits (a count, not a ratio).
    """
    # Shapes are static under jit, so these checks run once at trace time.
    assert len(embeddings_a) <= len(embeddings_b)
    scores = similarity_fct(embeddings_a, embeddings_b)
    assert scores.shape == (len(embeddings_a), len(embeddings_b))

    # Use jnp (not np) so the ops are traced rather than relying on numpy dispatching
    # to tracer methods.
    predicted = jnp.argmax(scores, axis=1)
    labels = jnp.arange(scores.shape[0], dtype=jnp.int32)

    return jnp.sum(predicted == labels)

    

In [10]:

@partial(jax.pmap, axis_name="batch")
def train_step(state, model_input1, model_input2, drp_rng):
    """One data-parallel optimization step, executed on every device under ``pmap``.

    Each device encodes its shard of both inputs. The embeddings are all-gathered so the
    loss is computed over the global batch. Gradients are averaged across devices before
    the update so every replica applies the same step.

    Args:
        state: replicated ``TrainState`` (uses apply_fn, params, loss_fn, scheduler_fn).
        model_input1: sharded tokenizer output for the first input (must contain
            ``attention_mask``).
        model_input2: sharded tokenizer output for the second input.
        drp_rng: per-device PRNG key used for dropout.

    Returns:
        (state, metrics, new_drp_rng); ``metrics`` holds ``train_loss`` and ``lr``
        (the schedule evaluated at the post-update step).
    """
    train = True
    new_drp_rng, drp_rng = jax.random.split(drp_rng, 2)

    def loss_fn(params, model_input1, model_input2, drp_rng):
        def _forward(model_input):
            attention_mask = model_input["attention_mask"]
            model_output = state.apply_fn(**model_input, params=params, train=train, dropout_rng=drp_rng)

            embedding = mean_pooling(model_output, attention_mask)
            embedding = normalize_L2(embedding)

            # gather all the embeddings on same device for calculation loss over global batch
            embedding = jax.lax.all_gather(embedding, axis_name="batch")
            embedding = jnp.reshape(embedding, (-1, embedding.shape[-1]))

            return embedding

        embedding1, embedding2 = _forward(model_input1), _forward(model_input2)
        return state.loss_fn(embedding1, embedding2)

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grads = grad_fn(state.params, model_input1, model_input2, drp_rng)
    # value_and_grad runs independently on each device, and each device's params only see
    # its own input shard, so the gradients differ per replica. Average them across the
    # pmap axis so all replicas apply the same update and stay in sync.
    grads = jax.lax.pmean(grads, axis_name="batch")
    state = state.apply_gradients(grads=grads)

    step = jax.lax.pmean(state.step, axis_name="batch")
    metrics = {"train_loss": loss, "lr": state.scheduler_fn(step)}

    return state, metrics, new_drp_rng

@partial(jax.pmap, axis_name="batch")
def val_step(state, model_inputs1, model_inputs2):
    """Evaluation step: embed both inputs with dropout off and score the global batch.

    Embeddings are pooled and normalized with the same helpers as ``train_step``, so the
    validation metrics measure the same representation the model is trained on.

    Returns:
        (loss, hits): the mean of ``state.loss_fn`` and the summed ``state.acc_fn`` count.
        Both are computed on all-gathered embeddings, so every device returns the same values.
    """
    train = False

    def _forward(model_input):
        attention_mask = model_input["attention_mask"]
        model_output = state.apply_fn(**model_input, params=state.params, train=train)

        # Same pooling/normalization helpers (and call form) as train_step.
        embedding = mean_pooling(model_output, attention_mask)
        embedding = normalize_L2(embedding)

        # gather all the embeddings on same device for calculation loss over global batch
        embedding = jax.lax.all_gather(embedding, axis_name="batch")
        embedding = jnp.reshape(embedding, (-1, embedding.shape[-1]))

        return embedding

    embedding1, embedding2 = _forward(model_inputs1), _forward(model_inputs2)
    loss = state.loss_fn(embedding1, embedding2)
    acc = state.acc_fn(embedding1, embedding2)
    return jnp.mean(loss), jnp.sum(acc)




In [11]:


def get_batched_dataset(dataset, batch_size, seed=None):
    """Lazily yield consecutive full batches of ``dataset`` as plain dicts.

    When ``seed`` is not None the dataset is shuffled with that seed first.
    A trailing partial batch is dropped.
    """
    source = dataset if seed is None else dataset.shuffle(seed=seed)
    stop = (len(source) // batch_size) * batch_size
    for start in range(0, stop, batch_size):
        yield dict(source[start:start + batch_size])


@dataclass
class DataCollator:
    """Tokenize a batch of (docstring, code) pairs and shard each result across devices.

    ``batch["docstring"]`` is encoded with ``input1_maxlen`` and ``batch["code"]`` with
    ``input2_maxlen``; both are truncated and padded to exactly that length.
    """

    tokenizer: Union[PreTrainedTokenizerFast, PreTrainedTokenizer]
    input1_maxlen: int = 128
    input2_maxlen: int = 128

    def __call__(self, batch):
        # Currently only static padding; TODO: change below for adding dynamic padding support
        def encode(texts, max_length):
            encoded = self.tokenizer(
                texts,
                return_tensors="jax",
                max_length=max_length,
                truncation=True,
                padding="max_length",
            )
            return shard(dict(encoded))

        docstring_input = encode(batch["docstring"], self.input1_maxlen)
        code_input = encode(batch["code"], self.input2_maxlen)
        return docstring_input, code_input


def save_checkpoint(save_dir, state, save_fn=None, training_args=None):
    """Write params, optimizer state and (optionally) training args into ``save_dir``.

    ``state`` is un-replicated with ``jax_utils.unreplicate`` before saving. Params go
    through ``save_fn(save_dir, params=...)`` when given (the notebook passes
    ``model.save_pretrained``); otherwise they are written to ``flax_model.msgpack``.
    Optimizer state always goes to ``opt_state.msgpack``, and ``training_args``
    (a dataclass) to ``training_args.json``.
    """
    print(f"saving checkpoint in {save_dir}", end=" ... ")
    os.makedirs(save_dir, exist_ok=True)
    host_state = jax_utils.unreplicate(state)

    def write_msgpack(filename, tree):
        with open(os.path.join(save_dir, filename), "wb") as fh:
            fh.write(to_bytes(tree))

    if save_fn is None:
        write_msgpack("flax_model.msgpack", host_state.params)
    else:
        # saving model in HF fashion
        save_fn(save_dir, params=host_state.params)

    # this will save optimizer states
    write_msgpack("opt_state.msgpack", host_state.opt_state)

    if training_args is not None:
        with open(os.path.join(save_dir, "training_args.json"), "w") as fh:
            json.dump(asdict(training_args), fh)

    print("done!!")


def prepare_dataset(args, data_files=None):
    """Load the CSV splits and trim every split to a multiple of ``args.batch_size``.

    Each split is shuffled with ``args.seed`` before trimming, so the dropped remainder
    is a seeded random subset rather than always the tail of the file.

    Args:
        args: object with ``batch_size`` and ``seed`` attributes (e.g. ``TrainingArgs``).
        data_files: optional mapping of split name -> CSV path(s), passed to
            ``datasets.load_dataset("csv", data_files=...)``. It must contain ``"train"``
            and ``"validation"``. Defaults to the gzipped CodeSearchNet
            train/test/validation files under ``../data``.

    Returns:
        (train_dataset, validation_dataset)
    """
    if data_files is None:
        data_files = {
            "train": "../data/codesearchnet_train.csv.gz",
            "test": "../data/codesearchnet_test.csv.gz",
            "validation": "../data/codesearchnet_validation.csv.gz",
        }
    dataset = datasets.load_dataset("csv", data_files=data_files)

    # drop extra batch from the end
    for split in dataset:
        num_samples = len(dataset[split]) - len(dataset[split]) % args.batch_size
        dataset[split] = dataset[split].shuffle(seed=args.seed).select(range(num_samples))

    return dataset["train"], dataset["validation"]



In [36]:
args = TrainingArgs()
logger = wandb.init(project="code-search-net", config=asdict(args))

# Rebuild args from the wandb run config, and give this run its own checkpoint
# sub-directory named "<save_dir>-<wandb run id>".
logging_dict = dict(logger.config)
logging_dict["save_dir_exp"] = f"{logging_dict['save_dir']}-{logger.id}"
args = replace(args, **logging_dict)

print(args)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


TrainingArgs(model_id='microsoft/codebert-base', max_epochs=20, batch_size_per_device=32, seed=42, lr=2e-05, init_lr=1e-05, warmup_steps=2000, weight_decay=0.001, input1_maxlen=200, input2_maxlen=200, logging_steps=20, save_dir='checkpoints', save_dir_exp='checkpoints-1xeeta0r', tr_data_files=['tr.csv'], val_data_files=['val.csv'])


In [37]:
# Load the pretrained Flax encoder named by args.model_id (default "microsoft/codebert-base").
model = FlaxAutoModel.from_pretrained(args.model_id)

In [38]:
# Tokenizer for the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(args.model_id)

In [39]:
# DataCollator fields in declaration order: tokenizer, input1_maxlen, input2_maxlen.
data_collator = DataCollator(tokenizer, args.input1_maxlen, args.input2_maxlen)


In [40]:
# Optimizer and its LR schedule, built from the hyper-parameters in args.
tx_args = dict(
    lr=args.lr,
    init_lr=args.init_lr,
    warmup_steps=args.warmup_steps,
    weight_decay=args.weight_decay,
)
tx, lr = build_tx(**tx_args)

In [41]:

# Build the train state once, then replicate it across devices for pmap.
state = jax_utils.replicate(
    TrainState.create(
        apply_fn=model.__call__,
        params=model.params,
        tx=tx,
        loss_fn=multiple_negatives_ranking_loss,
        scheduler_fn=lr,
        acc_fn=batch_accuracy,
    )
)

# One dropout PRNG key per device, derived from args.seed.
rng = jax.random.PRNGKey(args.seed)
drp_rng = jax.random.split(rng, jax.device_count())




In [42]:
# Load the train/validation splits, each shuffled with args.seed and trimmed to a multiple of args.batch_size.
tr_dataset, val_dataset = prepare_dataset(args)



In [43]:
def train_epoch(state, dataset, logger, args, drp_rng):
    # training step
    total = len(tr_dataset) // args.batch_size
    batch_iterator = get_batched_dataset(dataset, args.batch_size, seed=epoch)
    for i, batch in tqdm(enumerate(batch_iterator), desc=f"Running epoch-{epoch}", total=total):
        model_input1, model_input2 = data_collator(batch)
        state, metrics, drp_rng = train_step(state, model_input1, model_input2, drp_rng)

        #print("metrics", metrics)
        if (i + 1) % args.logging_steps == 0:
            train_loss = jax_utils.unreplicate(metrics["train_loss"]).item()
            # tqdm.write(str(dict(train_loss=train_loss, step=i+1)))
            logger.log({
                "train_loss": train_loss,
                "step": i + 1,
            }, commit=True)

    return state, drp_rng

def eval_epoch(state, val_dataset, logger, args, best_acc):
    # evaluation
    val_loss  = jnp.array(0.)
    val_acc  = jnp.array(0.)
    total = len(val_dataset) // args.batch_size
    val_batch_iterator = get_batched_dataset(val_dataset, args.batch_size, seed=None)
    for j, batch in tqdm(enumerate(val_batch_iterator), desc=f"Eval after epoch-{epoch}", total=total):
       model_input1, model_input2 = data_collator(batch)
       val_step_loss, val_step_acc = val_step(state, model_input1, model_input2)
       val_loss += jax_utils.unreplicate(val_step_loss)
       val_acc += jax_utils.unreplicate(val_step_acc)

    val_loss = val_loss.item() / (j + 1)
    val_acc = val_acc.item() / len(val_dataset)
    logger.log({"val_loss": val_loss, "val_acc": val_acc}, commit=True)
    
    if val_acc > best_acc:
        save_dir = os.path.join(args.save_dir, args.save_dir_exp, args.save_dir_exp + f"-epoch-{epoch}")
        save_checkpoint(save_dir, state, save_fn=model.save_pretrained, training_args=args)

        return val_acc
    
    return best_acc


In [44]:
best_acc = eval_epoch(state, val_dataset, logger, args, best_acc=jnp.array(0.))
for epoch in range(args.max_epochs):
    state, drp_rng = train_epoch(state, tr_dataset, logger, args, drp_rng)
    best_acc = eval_epoch(state, val_dataset, logger, args, best_acc)


Eval after epoch-0:   0%|          | 0/122 [00:00<?, ?it/s]

saving checkpoint in checkpoints/checkpoints-1xeeta0r/checkpoints-1xeeta0r-epoch-0 ... done!!


Running epoch-0:   0%|          | 0/2664 [00:00<?, ?it/s]



Eval after epoch-0:   0%|          | 0/122 [00:00<?, ?it/s]

saving checkpoint in checkpoints/checkpoints-1xeeta0r/checkpoints-1xeeta0r-epoch-0 ... done!!


Running epoch-1:   0%|          | 0/2664 [00:00<?, ?it/s]



Eval after epoch-1:   0%|          | 0/122 [00:00<?, ?it/s]

Running epoch-2:   0%|          | 0/2664 [00:00<?, ?it/s]



Eval after epoch-2:   0%|          | 0/122 [00:00<?, ?it/s]

Running epoch-3:   0%|          | 0/2664 [00:00<?, ?it/s]



Eval after epoch-3:   0%|          | 0/122 [00:00<?, ?it/s]

Running epoch-4:   0%|          | 0/2664 [00:00<?, ?it/s]



Eval after epoch-4:   0%|          | 0/122 [00:00<?, ?it/s]

Running epoch-5:   0%|          | 0/2664 [00:00<?, ?it/s]



Eval after epoch-5:   0%|          | 0/122 [00:00<?, ?it/s]

Running epoch-6:   0%|          | 0/2664 [00:00<?, ?it/s]



Eval after epoch-6:   0%|          | 0/122 [00:00<?, ?it/s]

Running epoch-7:   0%|          | 0/2664 [00:00<?, ?it/s]



Eval after epoch-7:   0%|          | 0/122 [00:00<?, ?it/s]

Running epoch-8:   0%|          | 0/2664 [00:00<?, ?it/s]



Eval after epoch-8:   0%|          | 0/122 [00:00<?, ?it/s]

Running epoch-9:   0%|          | 0/2664 [00:00<?, ?it/s]



Eval after epoch-9:   0%|          | 0/122 [00:00<?, ?it/s]

Running epoch-10:   0%|          | 0/2664 [00:00<?, ?it/s]



Eval after epoch-10:   0%|          | 0/122 [00:00<?, ?it/s]

Running epoch-11:   0%|          | 0/2664 [00:00<?, ?it/s]



Eval after epoch-11:   0%|          | 0/122 [00:00<?, ?it/s]

Running epoch-12:   0%|          | 0/2664 [00:00<?, ?it/s]



KeyboardInterrupt: 