In [158]:
import torch
import torch.nn as nn
from torchtext.data import Field, BucketIterator
from src.Self_attention_sequence_encoder import SelfAttentionEmbedder
from src.vocab_classes import BPE_Code_vocab
from src.trainers import Model_Trainer 
from src.useful_utils import load_CSN_data
from tokenizers import ByteLevelBPETokenizer
import tqdm.notebook as tqdm 
import numpy as np

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [176]:
# Directory holding the CodeSearchNet Python JSONL files.
# NOTE(review): absolute, machine-specific path — consider making it configurable.
CodeSearchNet_data_path = "/nfs/code_search_net_archive/python/final/jsonl/"
# Project helper; its result is unpacked into the train / validation / test splits.
train_data, valid_data, test_data = load_CSN_data(CodeSearchNet_data_path)

HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))




In [177]:
# Keep only the first 100,000 training samples. This rebinds `train_data`, so the
# full training split is only recoverable by re-running the loading cell.
train_data = train_data[:100000]

In [178]:
# Build (query, document) training pairs: the query is the space-joined docstring
# tokens, the document is the code with the docstring text stripped out of it.
train_pairs = [
    (
        " ".join(sample["docstring_tokens"]),
        sample["code"].replace(sample["docstring"], ""),
    )
    for sample in train_data
]

In [179]:
# Byte-level BPE tokenizer built from a vocabulary JSON file and a merges file
# (both paths are relative to the notebook's working directory).
tokenizer = ByteLevelBPETokenizer("datasets/code_search_net/code_bpe_hugging_32k-vocab.json",
                                  "datasets/code_search_net/code_bpe_hugging_32k-merges.txt", )

In [180]:
# Exclusive upper bound on the number of BPE tokens allowed for either side of a pair.
cutoff = 500


def _is_short_enough(text):
    """Return True when `text` encodes to fewer than `cutoff` BPE token ids."""
    return len(tokenizer.encode(text).ids) < cutoff


# Keep only the pairs whose document AND query are below the cutoff. The document is
# checked first, so the query is tokenized only when the document already passes.
trunk_train_pairs = []
for q, d in tqdm.tqdm(train_pairs):
    if _is_short_enough(d) and _is_short_enough(q):
        trunk_train_pairs.append((q, d))

HBox(children=(FloatProgress(value=0.0, max=100000.0), HTML(value='')))




In [10]:
# Project vocabulary wrapper (src.vocab_classes), constructed with its default arguments.
vocab = BPE_Code_vocab()

In [250]:
# Encoder hyper-parameters, gathered in one place so they are easy to scan and tweak.
model_hparams = dict(
    vocab_size=vocab.vocab_size,
    embed_dim=128,
    att_heads=8,
    layers=3,
    dim_feedforward=512,
    loss_type="softmax",
)
model = SelfAttentionEmbedder(**model_hparams)
# Project-specific training setup; only the learning rate (5e-4) is passed.
model.init_train_params(lr=5e-4)

In [228]:
# Turn the length-filtered (query, doc) pairs into the dataset object that the
# BucketIterator below iterates over, using the model's own conversion helper.
dataset = model.data2dataset(trunk_train_pairs, vocab)

HBox(children=(FloatProgress(value=0.0, max=91196.0), HTML(value='')))




In [251]:
# Batch iterator over `dataset`: batches of 32, shuffled, tensors requested on the GPU.
# NOTE(review): repeat=True presumably makes the iterator cycle over the data
# indefinitely so the step-based training call never runs out — confirm in torchtext docs.
train_iterator = BucketIterator(
    dataset,
    batch_size = 32,
    repeat=True,
    shuffle=True,
    device = "cuda")

In [252]:
# Manual iterator handle so single batches can be pulled with next() for inspection.
it = iter(train_iterator)


In [253]:
%%capture
# Move the model to the GPU; %%capture hides whatever this cell would otherwise display.
model.to("cuda")

In [280]:
# Pull one batch and run a forward pass to inspect the raw query/document scores.
batch = next(it)
q_vec, d_vec = model(batch.query, batch.doc)
# NOTE(review): presumably a (queries x docs) score matrix — confirm in the model code.
# In the recorded output every entry sits around 92, so the scores barely discriminate.
matrix = model.softmax_matrix(q_vec, d_vec)
matrix

tensor([[92.5424, 92.3537, 92.9671,  ..., 92.2242, 92.4776, 92.6549],
        [91.4828, 92.3798, 91.9793,  ..., 92.0625, 90.8034, 91.3626],
        [92.5334, 91.9231, 92.3086,  ..., 92.0379, 92.8420, 92.8789],
        ...,
        [92.7986, 92.9297, 93.1334,  ..., 92.7572, 92.6939, 92.6148],
        [91.7903, 91.4939, 91.4189,  ..., 91.4688, 92.0508, 92.0379],
        [92.7520, 92.2777, 92.7913,  ..., 92.2941, 93.0204, 92.9002]],
       device='cuda:0', grad_fn=<MmBackward>)

In [284]:
# For every column of `matrix`, the row index holding the largest score.
torch.argmax(matrix, dim=0)

tensor([ 8, 29, 25,  4,  4,  4,  4,  4, 29, 25, 29,  4,  4,  5, 29,  4, 25, 25,
        25,  4, 25,  4, 29,  4,  4,  4,  4, 25,  4, 25,  4,  4],
       device='cuda:0')

In [274]:
# Sanity check of model.cross_entropy on a tiny hand-made example: two rows of logits
# with target classes 1 and 0. The recorded output (0.0025, 0.3133) equals
# log(1 + e^-6) and log(1 + e^-1), i.e. unreduced softmax cross-entropy per row.
tmp_input = torch.tensor([[-1.,5],
                          [1,0]])
tmp_labels = torch.tensor([1,0])
model.cross_entropy(tmp_input, tmp_labels)

tensor([0.0025, 0.3133])

In [255]:
# Project trainer built from the model and the vocabulary. No output directory is
# passed; the recorded warning states that training and model outputs won't be saved.
trainer = Model_Trainer(model, vocab)

'output_dir' not defined, training and model outputs won't be saved.


In [279]:
# Run training on the repeating iterator. 100000 is passed positionally and matches
# the length of the recorded progress bar, so it is presumably the number of steps.
# NOTE(review): save_interval (10,000,000) exceeds that count, so periodic saving
# would presumably never trigger — confirm against Model_Trainer.train.
train_logs = trainer.train(model, train_iterator, 100000, save_interval=10000000, log_interval=50)

HBox(children=(FloatProgress(value=0.0, max=100000.0), HTML(value='')))

Finished training


## Writing the same loss function as CodeSearchNet in PyTorch

In [44]:
import tensorflow as tf

In [130]:
# Toy batch: three query embeddings and three code embeddings of dimension two.
# Row i of the queries is treated as the match for row i of the code snippets.
tf_query_representations = tf.constant([[-1.0, 3.0], [3.0, 4.0], [5.0, 6.0]])
tf_code_representations = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
margin = 1.0

# Row-wise L2 norms; the tiny constant keeps the divisions below away from zero.
tf_query_norms = tf.norm(tf_query_representations, axis=-1, keepdims=True) + 1e-10
tf_code_norms = tf.norm(tf_code_representations, axis=-1, keepdims=True) + 1e-10

# Cosine similarity between every query and every code snippet.
tf_unit_queries = tf_query_representations / tf_query_norms
tf_unit_codes = tf_code_representations / tf_code_norms
tf_cosine_similarities = tf.matmul(
    tf_unit_queries,
    tf_unit_codes,
    transpose_b=True,
    name='code_query_cooccurrence_logits',
)  # B x B
tf_similarity_scores = tf_cosine_similarities

# A max-margin-like loss, but do not penalize negative cosine similarities.
# The mask holds -inf on the diagonal (0 elsewhere), so the matching pair can never
# be selected as its own hardest negative.
tf_batch_size = tf.shape(tf_cosine_similarities)[0]
tf_neg_matrix = tf.linalg.diag(tf.fill(dims=[tf_batch_size], value=float('-inf')))

tf_matching_scores = tf.linalg.diag_part(tf_cosine_similarities)
tf_hardest_negatives = tf.reduce_max(
    tf.nn.relu(tf_cosine_similarities + tf_neg_matrix),
    axis=-1,
)
tf_per_sample_loss = tf.maximum(0., margin - tf_matching_scores + tf_hardest_negatives)

### PyTorch version

In [151]:
def max_margin_loss(query_representations, code_representations, margin=1.0, eps=1e-10):
    """Per-sample max-margin loss over cosine similarities.

    PyTorch port of the TensorFlow cell above. Row ``i`` of both inputs is the
    matching (query, code) pair; every other row of ``code_representations`` acts
    as a negative for query ``i``.

    Args:
        query_representations: 2-D float tensor, one query embedding per row.
        code_representations: 2-D float tensor with the same shape, one code
            embedding per row.
        margin: Margin required between the matching pair and the hardest negative.
        eps: Small constant added to the norms to avoid division by zero.

    Returns:
        1-D tensor with one non-negative loss value per row.
    """
    query_norm = torch.norm(query_representations, dim=-1, keepdim=True) + eps
    code_norm = torch.norm(code_representations, dim=-1, keepdim=True) + eps

    # Cosine similarity between every query and every code snippet (B x B).
    cosine_similarities = torch.mm(query_representations / query_norm,
                                   (code_representations / code_norm).t())

    # -inf on the diagonal, 0 elsewhere, so the matching pair can never be picked as
    # its own negative. Built with the dtype/device of the similarities so the
    # function also works for GPU or non-default-dtype inputs.
    neg_matrix = torch.diag(torch.full((cosine_similarities.shape[0],),
                                       float('-inf'),
                                       dtype=cosine_similarities.dtype,
                                       device=cosine_similarities.device))

    good_sample_score = torch.diagonal(cosine_similarities)
    # relu mirrors the TF reference: negative cosine similarities are not penalised,
    # i.e. the hardest negative contributes at least 0. Without it the two
    # implementations disagree whenever all negatives of a row are below zero.
    bad_sample_score = torch.max(torch.relu(cosine_similarities + neg_matrix), dim=-1)[0]

    return torch.clamp(margin - good_sample_score + bad_sample_score, min=0.0)


# Same toy batch as the TensorFlow cell, so the two results can be compared directly.
query_representations = torch.tensor([[-1,3.],[3,4],[5,6]])
code_representations = torch.tensor([[1,2.],[3,4],[5,6]])
margin = 1.

per_sample_loss = max_margin_loss(query_representations, code_representations, margin)

## TF CodeSearchNet softmax

In [211]:
# Raw dot-product logits between every query and every code embedding
# (unlike the max-margin cell, nothing is normalised here).
tf_logits = tf.matmul(
    tf_query_representations,
    tf_code_representations,
    transpose_b=True,
    name='code_query_cooccurrence_logits',
)  # B x B

similarity_scores = tf_logits

# Query i should select code i, so the target classes are [0, 1, 2, 3, ..., n].
tf_target_indices = tf.range(tf.shape(tf_code_representations)[0])
tf_per_sample_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=tf_target_indices,
    logits=tf_logits,
)

In [214]:
# Display the label vector used by the TF loss: one class index per code representation.
tf.range(tf.shape(tf_code_representations)[0])

<tf.Tensor: id=575, shape=(3,), dtype=int32, numpy=array([0, 1, 2], dtype=int32)>

In [226]:
# Display the TensorFlow per-sample losses, for comparison with the PyTorch result.
tf_per_sample_loss

<tf.Tensor: id=566, shape=(3,), dtype=float32, numpy=array([ 8.018479, 14.000001,  0.      ], dtype=float32)>

In [224]:
# Dot-product logits: entry (i, j) scores query i against code snippet j.
logits = query_representations @ code_representations.t()
# reduction='none' keeps one loss value per query instead of reducing over the batch.
criterion = nn.CrossEntropyLoss(reduction='none')
# The matching code snippet for query i sits at index i.
labels = torch.arange(code_representations.size(0))
per_sample_loss = criterion(logits, labels)

In [225]:
# Display the PyTorch per-sample losses; the recorded values match the TensorFlow output.
per_sample_loss

tensor([ 8.0185, 14.0000, -0.0000])

In [209]:
# Display the raw dot-product logits that the cross-entropy losses were computed from.
logits

tensor([[ 5.,  9., 13.],
        [11., 25., 39.],
        [17., 39., 61.]])