#### Main notebook for training and saving the Siamese network.

In [1]:
from sn_dataloader import SNDataloader
from sn_dataset import SNDataset
from siamese_network import SiameseNetwork
import pytorch_lightning as pl
import functions_manager as fm
from pytorch_lightning import loggers as pl_loggers

# Create Function Manager instance
function_manager = fm.FunctionManager()

# Create an instance of the SiameseNetwork
model = SiameseNetwork()

# Initialize the network
model.init_network()


# Create the dataset from the function manager's lookup tables
dataset = SNDataset(
    name_to_reference_map=function_manager.getNameToReference(),
    positive_negative_function_map=function_manager.getPositiveNegativeFunctionMap(),
)

# Create the dataloader.
# The wrapper instance gets its own name: binding it to `SNDataloader` would
# shadow the imported class, so re-running this cell (without restarting the
# kernel) would try to call the instance and fail.
batch_size = 8
sn_dataloader = SNDataloader(dataset, batch_size=batch_size, shuffle=True)
dataLoader = sn_dataloader.getDataLoader()

# Create a logger (TensorBoard event files are written under logs/)
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")

# Train the network
# NOTE(review): the saved output of this cell shows a loss of exactly 0 on
# every printed step before the run was interrupted — worth confirming the
# loss/margin setup inside SiameseNetwork.
k = 2  # log metrics every k training steps
trainer = pl.Trainer(max_epochs=20, log_every_n_steps=k, logger=tb_logger)
trainer.fit(model, dataLoader)

Downloading (…)okenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/539 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/499M [00:00<?, ?B/s]

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: logs/lightning_logs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 532 M 
---------------------------------------
532 M     Trainable params
0         Non-trainable params
532 M     Total params
2,130.457 Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

0 tensor(0., grad_fn=<MeanBackward0>)
1 tensor(0., grad_fn=<MeanBackward0>)
2 tensor(0., grad_fn=<MeanBackward0>)
3 tensor(0., grad_fn=<MeanBackward0>)
4 tensor(0., grad_fn=<MeanBackward0>)
5 tensor(0., grad_fn=<MeanBackward0>)
6 tensor(0., grad_fn=<MeanBackward0>)
7 tensor(0., grad_fn=<MeanBackward0>)
8 tensor(0., grad_fn=<MeanBackward0>)
9 tensor(0., grad_fn=<MeanBackward0>)
10 tensor(0., grad_fn=<MeanBackward0>)
11 tensor(0., grad_fn=<MeanBackward0>)
12 tensor(0., grad_fn=<MeanBackward0>)
13 tensor(0., grad_fn=<MeanBackward0>)
14 tensor(0., grad_fn=<MeanBackward0>)
15 tensor(0., grad_fn=<MeanBackward0>)
16 tensor(0., grad_fn=<MeanBackward0>)
17 tensor(0., grad_fn=<MeanBackward0>)
18 tensor(0., grad_fn=<MeanBackward0>)
19 tensor(0., grad_fn=<MeanBackward0>)
20 tensor(0., grad_fn=<MeanBackward0>)
21 tensor(0., grad_fn=<MeanBackward0>)
22 tensor(0., grad_fn=<MeanBackward0>)
23 tensor(0., grad_fn=<MeanBackward0>)
24 tensor(0., grad_fn=<MeanBackward0>)
25 tensor(0., grad_fn=<MeanBackward

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [2]:
# Sanity-check similarity_inference on a single dataset triplet
from code_embedding import CodeEmbedding

# Pick one (anchor, positive, negative) triplet of functions and show it
anchor_fun, positive_fun, negative_fun = dataset.samples[10000]
print(anchor_fun, positive_fun, negative_fun)

# Embed the three functions, in anchor / positive / negative order
code_embedding = CodeEmbedding()
anchor, positive, negative = (
    code_embedding.getPerfectFunctionEmbedding(fun)
    for fun in (anchor_fun, positive_fun, negative_fun)
)

# Compare the anchor against itself, its positive and its negative
for candidate in (anchor, positive, negative):
    print(model.similarity_inference(anchor, candidate))



<function MathFunctions.a_cubed_minus_b_cubed at 0x15a9f5e10> <function MathFunctions.a_minus_b_whole_cubed_minus_3ab_times_a_minus_b at 0x15a9f5ea0> <function MathFunctions.a_plus_b_whole_square at 0x15a9f5630>
tensor([2.7713e-05], grad_fn=<NormBackward1>)
tensor([4.5159], grad_fn=<NormBackward1>)
tensor([3.7603], grad_fn=<NormBackward1>)


In [3]:
# Run one training step by hand on the triplet embedded above (batch index 1)
triplet_batch = [anchor, positive, negative]
loss = model.training_step(triplet_batch, 1)

1 tensor(1.3556, grad_fn=<MeanBackward0>)


  rank_zero_warn(


In [4]:
# Check the triplet loss on the encoded anchor / positive / negative embeddings.
# Each embedding goes through the model's forward pass, in that order.
anchor_forward, position_forward, negative_forward = map(
    model, (anchor, positive, negative)
)

triplet_loss = model.triplet_loss(
    anchor_forward,
    position_forward,
    negative_forward,
    margin=0.6,
)
print(triplet_loss)

tensor(1.4122, grad_fn=<MeanBackward0>)


In [5]:
# Combine two forward-pass outputs along dim 0 and inspect the resulting shape
final_func_embedding = code_embedding.getFinalFunctionEmbedding(
    anchor_forward,
    position_forward,
    dim=0,
)
# Left as the last bare expression so the notebook displays the returned value
code_embedding.getShape(final_func_embedding)

torch.Size([2, 768])
2


(torch.Size([2, 768]), 2)