#### Main notebook for training and saving the Siamese network.

In [1]:
from sn_dataloader import SNDataloader
from sn_dataset import SNDataset
from siamese_network import SiameseNetwork
import pytorch_lightning as pl
import functions_manager as fm
from pytorch_lightning import loggers as pl_loggers

# Create Function Manager instance (it supplies the two lookup maps the dataset is built from)
function_manager = fm.FunctionManager()

# Create an instance of the SiameseNetwork
model = SiameseNetwork()

# Initialize the network
model.init_network()


# Create the dataset
dataset = SNDataset(name_to_reference_map=function_manager.get_name_to_reference(),
                    positive_negative_function_map=function_manager.get_positive_negative_function_map())

# Create the dataloader.
# The wrapper instance gets its own name: rebinding `SNDataloader` would shadow the
# imported class, and re-running this cell would then try to call an instance.
batch_size = 8
sn_dataloader = SNDataloader(dataset, batch_size=batch_size, shuffle=True)
dataLoader = sn_dataloader.get_data_loader()

# Create a logger
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")

# Train the network
k = 2  # forwarded as log_every_n_steps
# NOTE(review): per the Lightning Trainer docs, fast_dev_run=True is a debugging mode that
# runs only a single batch, so max_epochs=20 would not take effect here -- confirm and set
# fast_dev_run=False for a real training run.
trainer = pl.Trainer(max_epochs=20, log_every_n_steps=k, logger=tb_logger, fast_dev_run=True)
trainer.fit(model, dataLoader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Missing logger folder: logs/lightning_logs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 532 M 
---------------------------------------
532 M     Trainable params
0         Non-trainable params
532 M     Total params
2,130.457 Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [2]:
# Similarity sanity check on one (anchor, positive, negative) sample from the dataset
from code_embedding import CodeEmbedding

anchor_fun, positive_fun, negative_fun = dataset.samples[10000]
print(anchor_fun, positive_fun, negative_fun)

# Embed the three functions in anchor -> positive -> negative order
code_embedding = CodeEmbedding()
anchor, positive, negative = (
    code_embedding.get_perfect_function_embedding(function)
    for function in (anchor_fun, positive_fun, negative_fun)
)

# Compare the anchor against itself, then the positive, then the negative
for candidate in (anchor, positive, negative):
    print(model.similarity_inference(anchor, candidate))



<function MathFunctions.a_plus_b_times_a_squared_minus_ab_plus_b_squared at 0x12563c940> <function MathFunctions.a_cubed_plus_b_cubed at 0x12563c820> <function MathFunctions.arccosine at 0x125621d80>
tensor([2.7713e-05], grad_fn=<NormBackward1>)
tensor([4.6593], grad_fn=<NormBackward1>)
tensor([4.9510], grad_fn=<NormBackward1>)


In [3]:
# Run a single training step by hand on the triplet embedded above (batch index 1)
triplet_batch = [anchor, positive, negative]
loss = model.training_step(triplet_batch, 1)

1 tensor(0.3083, grad_fn=<MeanBackward0>)


In [4]:
# Triplet-loss check: run each embedding through the model (anchor, positive, negative order)
# and evaluate the loss with an explicit margin of 0.6.
# NOTE(review): `position_forward` presumably means "positive"; the name is kept because
# a later cell reads it.
anchor_forward, position_forward, negative_forward = (
    model(embedding) for embedding in (anchor, positive, negative)
)

triplet_loss = model.triplet_loss(
    anchor_forward,
    position_forward,
    negative_forward,
    margin=0.6,
)
print(triplet_loss)

tensor(0.3083, grad_fn=<MeanBackward0>)


In [5]:
# Build the final function embedding from the anchor and positive forward outputs (dim=0),
# then inspect its shape; the bare last expression is what the cell displays.
final_func_embedding = code_embedding.get_final_function_embedding(
    anchor_forward,
    position_forward,
    dim=0,
)
code_embedding.get_shape(final_func_embedding)

torch.Size([2, 768])
2


(torch.Size([2, 768]), 2)