#### Main notebook for training and saving siamese network.

In [None]:
from src.sn_dataloader import SNDataloader
from src.sn_dataset import SNDataset
from src.siamese_network import SiameseNetwork
import pytorch_lightning as pl
import functions_manager as fm
from pytorch_lightning import loggers as pl_loggers

# Create Function Manager instance
function_manager = fm.FunctionManager()

# Create an instance of the SiameseNetwork
model = SiameseNetwork()

# Initialize the network
model.init_network()


# Create the dataset
dataset = SNDataset(name_to_reference_map= function_manager.getNameToReference(),
                    positive_negative_function_map= function_manager.getPositiveNegativeFunctionMap())

# Create the dataloader
batch_size = 8
SNDataloader = SNDataloader(dataset, batch_size=batch_size, shuffle=True)
dataLoader = SNDataloader.getDataLoader()

# Create a logger
tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")

# Train the network
# k = 2
# trainer = pl.Trainer(max_epochs=20, log_every_n_steps=k, logger=tb_logger)
# trainer.fit(model, dataLoader)

In [2]:
# Test Similarity check
from src.code_embedding import CodeEmbedding

anchor_fun, positive_fun, negative_fun = dataset.samples[10000]
# print(anchor_fun, positive_fun, negative_fun)

code_embedding = CodeEmbedding()
anchor = code_embedding.getPerfectFunctionEmbedding(anchor_fun)
positive = code_embedding.getPerfectFunctionEmbedding(positive_fun)
negative = code_embedding.getPerfectFunctionEmbedding(negative_fun)
# print(anchor, positive, negative)

print(model.similarity_inference(anchor, anchor))
print(model.similarity_inference(anchor, positive))
print(model.similarity_inference(anchor, negative))



tensor([2.7713e-05], grad_fn=<NormBackward1>)
tensor([4.4640], grad_fn=<NormBackward1>)
tensor([3.5096], grad_fn=<NormBackward1>)


In [3]:
# check training step
loss = model.training_step([anchor, positive, negative], 1)

1 tensor(1.5544, grad_fn=<MeanBackward0>)


  rank_zero_warn(


In [4]:
# Check Triplet loss
anchor_forward = model(anchor)
position_forward = model(positive)
negative_forward = model(negative)
# print(anchor_forward.shape)

triplet_loss = model.triplet_loss(anchor_forward, position_forward, negative_forward, margin=0.6)
print(triplet_loss)

tensor(1.5544, grad_fn=<MeanBackward0>)


In [6]:
# Check final function embedding size
final_func_embedding = code_embedding.getFinalFunctionEmbedding(anchor_forward, position_forward, dim=0)
code_embedding.getShape(final_func_embedding)

TypeError: CodeEmbedding.getFinalFunctionEmbedding() got an unexpected keyword argument 'dim'