In [14]:
import sys, os
sys.path.append(os.path.abspath('..'))

In [15]:
import torch
from Data import GraphDataset
from Models import GCNFeatureExtractor, CLWrapper
from utils.train import train

In [16]:
use_pre_paired = True
batch_size = 3
lr = 0.01
epochs = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'
embedding_size = 10
pooling = 'global_avg'

In [17]:
# load data 
from Data import load_graphs

basic_data, _ = load_graphs()

	Collected 1 sample from null_ops.
	Collected 4 samples from commutations.
	Collected 7 samples from equivalences.
	Collected 7 samples from combined.
Loaded 19 samples and 19 quantum circuits from subset.


In [18]:
# import pickle

# def load_graphs(data_dir='../Data/raw/', file_name='handcrafted_dataset.pkl', subset=None):
#     file_path = os.path.join(data_dir, file_name)
#     graphs, qcs = [], []
    
#     with open(file_path, 'rb') as f:
#         dataset = pickle.load(f)

#     if file_name == 'handcrafted_dataset.pkl':
#         if subset is not None:
#             dataset = dataset[subset]
#             print(f"Loaded {len(dataset)} elements from subset {subset}:")
#         # collect all graphs in dataset, which are stored in a nested dictionary
#         dataset = collect_from_dict(dataset) 

#     # extract graphs and qcs separately, if needed
#     for sample in dataset:
#         if isinstance(sample, tuple): # if a tuple (qc, graph)
#             g, qc = sample
#             graphs.append(g)
#             qcs.append(qc)
#         elif all(isinstance(graph, tuple) for graph in sample):# if a list of tuples (qc, graph) 
#             g, qc = zip(*sample)
#             qcs.append(list(qc))
#             graphs.append(list(g))

#     print(f"Loaded {len(graphs)} samples and {len(qcs)} quantum circuits from subset.")
    
#     return graphs, qcs

# def collect_from_dict(dictionary):
#     graphs = []
#     for k, v in dictionary.items():
#         if isinstance(v, dict):
#             graphs.extend(collect_from_dict(v))
#         elif isinstance(v, list):
#             if all(isinstance(item, list) for item in v):  # if list of lists
#                 graphs.extend(v)
#                 print(f"\tCollected {len(v)} items from {k}.")
#             elif all(isinstance(item, tuple) for item in v):  # if list of tuples
#                 graphs.append(v)
#                 print(f"\tCollected 1 sample from {k}.")
                
#     return graphs

## Basic transforms training

In [19]:
dataset = GraphDataset(basic_data[:-1], pre_paired=use_pre_paired)

In [20]:
graphs = basic_data[0]
print(len(graphs))

18


In [21]:
import numpy as np
import networkx as nx
def get_attr_matrix(graph):
    nodes_list = list(graph.nodes)
    nodes_view = graph.nodes(data=True)
    return np.array([nodes_view[node]['feature_vector'] for node in nodes_list])

In [22]:
# Check if graphs in the dataset are correctly converted to PyTorch Geometric Data objects
graphs = [basic_data[0][0:2]]
dataset_1 = GraphDataset(graphs, pre_paired=use_pre_paired)

data_1, data_2 = dataset_1[0]

print("NODE FEATURES:")

print(get_attr_matrix(graphs[0][0]))
print(get_attr_matrix(graphs[0][1]))

print("---")

print(data_1.x)
print(data_2.x)

print("ADJACENCY:")

print(nx.adjacency_matrix(graphs[0][0]).todense())
print(nx.adjacency_matrix(graphs[0][1]).todense())

print("---")

print(data_1.edge_index)
print(data_2.edge_index)

NODE FEATURES:
[[0 1 0 0 0 0]
 [1 0 0 0 1 0]
 [1 0 0 0 0 1]]
[[0 1 0 0 0 0]
 [1 0 0 0 1 0]
 [1 0 0 0 0 1]
 [0 1 0 0 0 0]
 [0 1 0 0 0 0]]
---
tensor([[0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 1.],
        [0., 1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0., 0.]])
tensor([[0., 1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 1.]])
ADJACENCY:
[[0 1 0]
 [0 0 1]
 [0 1 0]]
[[0 1 0 0 0]
 [0 0 1 1 0]
 [0 1 0 0 0]
 [0 0 0 0 1]
 [0 0 0 0 0]]
---
tensor([[0, 1, 1, 2, 3],
        [1, 2, 3, 1, 4]])
tensor([[0, 1, 2],
        [1, 2, 1]])


In [23]:
dataset_1

<Data.Dataset.GraphDataset at 0x1caffae2750>

In [24]:
# create a fully connected nn stacking fully connected layers and ReLU activations
out_dim_gnn = 64

fc_model = torch.nn.Sequential(
    torch.nn.Linear(out_dim_gnn, 256),
    torch.nn.Sigmoid(),
    torch.nn.Linear(256, 128),
    torch.nn.Sigmoid(),
    torch.nn.Linear(128, embedding_size)
)



In [25]:
n_features = dataset[0][0].x.shape[1] if use_pre_paired else dataset[0].x.shape[1]
print(f'Number of features in the dataset: {n_features}')

gnn = GCNFeatureExtractor(in_channels=n_features, out_channels=embedding_size, pooling_strategy=pooling)
proj = fc_model

# No projection head
model = CLWrapper(gnn).to(device)

Number of features in the dataset: 6


In [26]:
train(model, dataset, epochs, batch_size, tau= 1.0, device=device)

Epoch 1/100: 100%|██████████| 6/6 [00:00<00:00, 25.36batch/s, Loss=1.57]


Epoch 1/100 completed. Avg Loss: 1.5750


Epoch 2/100: 100%|██████████| 6/6 [00:00<00:00, 47.53batch/s, Loss=1.57]


Epoch 2/100 completed. Avg Loss: 1.5664


Epoch 3/100: 100%|██████████| 6/6 [00:00<00:00, 55.84batch/s, Loss=1.48]


Epoch 3/100 completed. Avg Loss: 1.4847


Epoch 4/100: 100%|██████████| 6/6 [00:00<00:00, 57.05batch/s, Loss=1.47]


Epoch 4/100 completed. Avg Loss: 1.4702


Epoch 5/100: 100%|██████████| 6/6 [00:00<00:00, 58.54batch/s, Loss=1.52]


Epoch 5/100 completed. Avg Loss: 1.5165


Epoch 6/100: 100%|██████████| 6/6 [00:00<00:00, 49.56batch/s, Loss=1.41]


Epoch 6/100 completed. Avg Loss: 1.4068


Epoch 7/100: 100%|██████████| 6/6 [00:00<00:00, 44.91batch/s, Loss=1.51]


Epoch 7/100 completed. Avg Loss: 1.5085


Epoch 8/100: 100%|██████████| 6/6 [00:00<00:00, 46.60batch/s, Loss=1.49]


Epoch 8/100 completed. Avg Loss: 1.4855


Epoch 9/100: 100%|██████████| 6/6 [00:00<00:00, 46.76batch/s, Loss=1.41]


Epoch 9/100 completed. Avg Loss: 1.4090


Epoch 10/100: 100%|██████████| 6/6 [00:00<00:00, 44.06batch/s, Loss=1.29]


Epoch 10/100 completed. Avg Loss: 1.2851


Epoch 11/100: 100%|██████████| 6/6 [00:00<00:00, 43.07batch/s, Loss=1.38]


Epoch 11/100 completed. Avg Loss: 1.3815


Epoch 12/100: 100%|██████████| 6/6 [00:00<00:00, 39.21batch/s, Loss=1.5]  


Epoch 12/100 completed. Avg Loss: 1.4952


Epoch 13/100: 100%|██████████| 6/6 [00:00<00:00, 49.32batch/s, Loss=1.45]


Epoch 13/100 completed. Avg Loss: 1.4477


Epoch 14/100: 100%|██████████| 6/6 [00:00<00:00, 46.13batch/s, Loss=1.33]


Epoch 14/100 completed. Avg Loss: 1.3330


Epoch 15/100: 100%|██████████| 6/6 [00:00<00:00, 47.80batch/s, Loss=1.36]


Epoch 15/100 completed. Avg Loss: 1.3607


Epoch 16/100: 100%|██████████| 6/6 [00:00<00:00, 49.67batch/s, Loss=1.06]


Epoch 16/100 completed. Avg Loss: 1.0648


Epoch 17/100: 100%|██████████| 6/6 [00:00<00:00, 29.03batch/s, Loss=1.56]


Epoch 17/100 completed. Avg Loss: 1.5629


Epoch 18/100: 100%|██████████| 6/6 [00:00<00:00, 49.59batch/s, Loss=1.45]


Epoch 18/100 completed. Avg Loss: 1.4495


Epoch 19/100: 100%|██████████| 6/6 [00:00<00:00, 45.74batch/s, Loss=1.21]


Epoch 19/100 completed. Avg Loss: 1.2142


Epoch 20/100: 100%|██████████| 6/6 [00:00<00:00, 46.29batch/s, Loss=1.14]


Epoch 20/100 completed. Avg Loss: 1.1397


Epoch 21/100: 100%|██████████| 6/6 [00:00<00:00, 41.82batch/s, Loss=1.22]


Epoch 21/100 completed. Avg Loss: 1.2217


Epoch 22/100: 100%|██████████| 6/6 [00:00<00:00, 38.13batch/s, Loss=1.34]


Epoch 22/100 completed. Avg Loss: 1.3423


Epoch 23/100: 100%|██████████| 6/6 [00:00<00:00, 46.12batch/s, Loss=1.32]


Epoch 23/100 completed. Avg Loss: 1.3191


Epoch 24/100: 100%|██████████| 6/6 [00:00<00:00, 44.40batch/s, Loss=1.32]


Epoch 24/100 completed. Avg Loss: 1.3248


Epoch 25/100: 100%|██████████| 6/6 [00:00<00:00, 31.40batch/s, Loss=1.24]


Epoch 25/100 completed. Avg Loss: 1.2361


Epoch 26/100: 100%|██████████| 6/6 [00:00<00:00, 39.64batch/s, Loss=1.2] 


Epoch 26/100 completed. Avg Loss: 1.2047


Epoch 27/100: 100%|██████████| 6/6 [00:00<00:00, 38.10batch/s, Loss=1.32] 


Epoch 27/100 completed. Avg Loss: 1.3185


Epoch 28/100: 100%|██████████| 6/6 [00:00<00:00, 48.33batch/s, Loss=1.21]


Epoch 28/100 completed. Avg Loss: 1.2146


Epoch 29/100: 100%|██████████| 6/6 [00:00<00:00, 50.78batch/s, Loss=1.29]


Epoch 29/100 completed. Avg Loss: 1.2950


Epoch 30/100: 100%|██████████| 6/6 [00:00<00:00, 48.81batch/s, Loss=1.2] 


Epoch 30/100 completed. Avg Loss: 1.1979


Epoch 31/100: 100%|██████████| 6/6 [00:00<00:00, 43.27batch/s, Loss=1.29]


Epoch 31/100 completed. Avg Loss: 1.2911


Epoch 32/100: 100%|██████████| 6/6 [00:00<00:00, 43.36batch/s, Loss=1.31]


Epoch 32/100 completed. Avg Loss: 1.3066


Epoch 33/100: 100%|██████████| 6/6 [00:00<00:00, 44.78batch/s, Loss=1.28]


Epoch 33/100 completed. Avg Loss: 1.2772


Epoch 34/100: 100%|██████████| 6/6 [00:00<00:00, 47.33batch/s, Loss=1.12]


Epoch 34/100 completed. Avg Loss: 1.1178


Epoch 35/100: 100%|██████████| 6/6 [00:00<00:00, 43.63batch/s, Loss=1.32]


Epoch 35/100 completed. Avg Loss: 1.3166


Epoch 36/100: 100%|██████████| 6/6 [00:00<00:00, 40.36batch/s, Loss=1.4] 


Epoch 36/100 completed. Avg Loss: 1.3983


Epoch 37/100: 100%|██████████| 6/6 [00:00<00:00, 38.18batch/s, Loss=1.35]


Epoch 37/100 completed. Avg Loss: 1.3497


Epoch 38/100: 100%|██████████| 6/6 [00:00<00:00, 39.08batch/s, Loss=1.16]


Epoch 38/100 completed. Avg Loss: 1.1622


Epoch 39/100: 100%|██████████| 6/6 [00:00<00:00, 39.35batch/s, Loss=1.4] 


Epoch 39/100 completed. Avg Loss: 1.3968


Epoch 40/100: 100%|██████████| 6/6 [00:00<00:00, 45.91batch/s, Loss=1.28]


Epoch 40/100 completed. Avg Loss: 1.2753


Epoch 41/100: 100%|██████████| 6/6 [00:00<00:00, 49.36batch/s, Loss=1.22]


Epoch 41/100 completed. Avg Loss: 1.2229


Epoch 42/100: 100%|██████████| 6/6 [00:00<00:00, 46.32batch/s, Loss=1.38]


Epoch 42/100 completed. Avg Loss: 1.3847


Epoch 43/100: 100%|██████████| 6/6 [00:00<00:00, 44.09batch/s, Loss=1.45]


Epoch 43/100 completed. Avg Loss: 1.4500


Epoch 44/100: 100%|██████████| 6/6 [00:00<00:00, 45.87batch/s, Loss=1.24]


Epoch 44/100 completed. Avg Loss: 1.2403


Epoch 45/100: 100%|██████████| 6/6 [00:00<00:00, 49.22batch/s, Loss=1.23]


Epoch 45/100 completed. Avg Loss: 1.2291


Epoch 46/100: 100%|██████████| 6/6 [00:00<00:00, 34.16batch/s, Loss=1.13]


Epoch 46/100 completed. Avg Loss: 1.1343


Epoch 47/100: 100%|██████████| 6/6 [00:00<00:00, 48.61batch/s, Loss=1.33]


Epoch 47/100 completed. Avg Loss: 1.3265


Epoch 48/100: 100%|██████████| 6/6 [00:00<00:00, 49.10batch/s, Loss=1.3] 


Epoch 48/100 completed. Avg Loss: 1.2972


Epoch 49/100: 100%|██████████| 6/6 [00:00<00:00, 50.33batch/s, Loss=1.2] 


Epoch 49/100 completed. Avg Loss: 1.1981


Epoch 50/100: 100%|██████████| 6/6 [00:00<00:00, 50.87batch/s, Loss=1.22]


Epoch 50/100 completed. Avg Loss: 1.2217


Epoch 51/100: 100%|██████████| 6/6 [00:00<00:00, 47.42batch/s, Loss=1.37]


Epoch 51/100 completed. Avg Loss: 1.3703


Epoch 52/100: 100%|██████████| 6/6 [00:00<00:00, 45.73batch/s, Loss=1.31]


Epoch 52/100 completed. Avg Loss: 1.3075


Epoch 53/100: 100%|██████████| 6/6 [00:00<00:00, 48.05batch/s, Loss=1.44]


Epoch 53/100 completed. Avg Loss: 1.4369


Epoch 54/100: 100%|██████████| 6/6 [00:00<00:00, 47.85batch/s, Loss=1.28]


Epoch 54/100 completed. Avg Loss: 1.2849


Epoch 55/100: 100%|██████████| 6/6 [00:00<00:00, 42.97batch/s, Loss=1.47]


Epoch 55/100 completed. Avg Loss: 1.4666


Epoch 56/100: 100%|██████████| 6/6 [00:00<00:00, 43.27batch/s, Loss=1.4] 


Epoch 56/100 completed. Avg Loss: 1.3979


Epoch 57/100: 100%|██████████| 6/6 [00:00<00:00, 48.46batch/s, Loss=1.23]


Epoch 57/100 completed. Avg Loss: 1.2347


Epoch 58/100: 100%|██████████| 6/6 [00:00<00:00, 45.68batch/s, Loss=1.36]


Epoch 58/100 completed. Avg Loss: 1.3648


Epoch 59/100: 100%|██████████| 6/6 [00:00<00:00, 42.06batch/s, Loss=1.24]


Epoch 59/100 completed. Avg Loss: 1.2447


Epoch 60/100: 100%|██████████| 6/6 [00:00<00:00, 36.89batch/s, Loss=1.14]


Epoch 60/100 completed. Avg Loss: 1.1362


Epoch 61/100: 100%|██████████| 6/6 [00:00<00:00, 35.67batch/s, Loss=1.25] 


Epoch 61/100 completed. Avg Loss: 1.2472


Epoch 62/100: 100%|██████████| 6/6 [00:00<00:00, 39.31batch/s, Loss=1.14] 


Epoch 62/100 completed. Avg Loss: 1.1444


Epoch 63/100: 100%|██████████| 6/6 [00:00<00:00, 43.58batch/s, Loss=1.34]


Epoch 63/100 completed. Avg Loss: 1.3362


Epoch 64/100: 100%|██████████| 6/6 [00:00<00:00, 49.24batch/s, Loss=1.22]


Epoch 64/100 completed. Avg Loss: 1.2226


Epoch 65/100: 100%|██████████| 6/6 [00:00<00:00, 46.80batch/s, Loss=1.4] 


Epoch 65/100 completed. Avg Loss: 1.4030


Epoch 66/100: 100%|██████████| 6/6 [00:00<00:00, 45.47batch/s, Loss=1.4] 


Epoch 66/100 completed. Avg Loss: 1.4038


Epoch 67/100: 100%|██████████| 6/6 [00:00<00:00, 46.04batch/s, Loss=1.33]


Epoch 67/100 completed. Avg Loss: 1.3295


Epoch 68/100: 100%|██████████| 6/6 [00:00<00:00, 44.52batch/s, Loss=1.21]


Epoch 68/100 completed. Avg Loss: 1.2065


Epoch 69/100: 100%|██████████| 6/6 [00:00<00:00, 43.66batch/s, Loss=1.17] 


Epoch 69/100 completed. Avg Loss: 1.1658


Epoch 70/100: 100%|██████████| 6/6 [00:00<00:00, 46.89batch/s, Loss=1.22]


Epoch 70/100 completed. Avg Loss: 1.2238


Epoch 71/100: 100%|██████████| 6/6 [00:00<00:00, 22.69batch/s, Loss=1.09]


Epoch 71/100 completed. Avg Loss: 1.0924


Epoch 72/100: 100%|██████████| 6/6 [00:00<00:00, 39.07batch/s, Loss=1.48]


Epoch 72/100 completed. Avg Loss: 1.4827


Epoch 73/100: 100%|██████████| 6/6 [00:00<00:00, 34.84batch/s, Loss=1.3] 


Epoch 73/100 completed. Avg Loss: 1.2970


Epoch 74/100: 100%|██████████| 6/6 [00:00<00:00, 38.62batch/s, Loss=1.31]


Epoch 74/100 completed. Avg Loss: 1.3118


Epoch 75/100: 100%|██████████| 6/6 [00:00<00:00, 37.71batch/s, Loss=1.4] 


Epoch 75/100 completed. Avg Loss: 1.4030


Epoch 76/100: 100%|██████████| 6/6 [00:00<00:00, 40.26batch/s, Loss=1.19]


Epoch 76/100 completed. Avg Loss: 1.1943


Epoch 77/100: 100%|██████████| 6/6 [00:00<00:00, 34.88batch/s, Loss=1.23] 


Epoch 77/100 completed. Avg Loss: 1.2283


Epoch 78/100: 100%|██████████| 6/6 [00:00<00:00, 38.96batch/s, Loss=1.15] 


Epoch 78/100 completed. Avg Loss: 1.1512


Epoch 79/100: 100%|██████████| 6/6 [00:00<00:00, 35.45batch/s, Loss=1.49]


Epoch 79/100 completed. Avg Loss: 1.4913


Epoch 80/100: 100%|██████████| 6/6 [00:00<00:00, 29.88batch/s, Loss=1.3] 


Epoch 80/100 completed. Avg Loss: 1.3034


Epoch 81/100: 100%|██████████| 6/6 [00:00<00:00, 37.34batch/s, Loss=1.27]


Epoch 81/100 completed. Avg Loss: 1.2731


Epoch 82/100: 100%|██████████| 6/6 [00:00<00:00, 33.56batch/s, Loss=1.44]


Epoch 82/100 completed. Avg Loss: 1.4364


Epoch 83/100: 100%|██████████| 6/6 [00:00<00:00, 35.42batch/s, Loss=1.47]


Epoch 83/100 completed. Avg Loss: 1.4672


Epoch 84/100: 100%|██████████| 6/6 [00:00<00:00, 38.23batch/s, Loss=1.21]


Epoch 84/100 completed. Avg Loss: 1.2109


Epoch 85/100: 100%|██████████| 6/6 [00:00<00:00, 31.28batch/s, Loss=1.16]


Epoch 85/100 completed. Avg Loss: 1.1634


Epoch 86/100: 100%|██████████| 6/6 [00:00<00:00, 37.30batch/s, Loss=1.24]


Epoch 86/100 completed. Avg Loss: 1.2423


Epoch 87/100: 100%|██████████| 6/6 [00:00<00:00, 36.30batch/s, Loss=1.31]


Epoch 87/100 completed. Avg Loss: 1.3059


Epoch 88/100: 100%|██████████| 6/6 [00:00<00:00, 36.30batch/s, Loss=1.33]


Epoch 88/100 completed. Avg Loss: 1.3323


Epoch 89/100: 100%|██████████| 6/6 [00:00<00:00, 30.70batch/s, Loss=1.18]


Epoch 89/100 completed. Avg Loss: 1.1804


Epoch 90/100: 100%|██████████| 6/6 [00:00<00:00, 32.44batch/s, Loss=1.32]


Epoch 90/100 completed. Avg Loss: 1.3197


Epoch 91/100: 100%|██████████| 6/6 [00:00<00:00, 17.77batch/s, Loss=1.32]


Epoch 91/100 completed. Avg Loss: 1.3220


Epoch 92/100: 100%|██████████| 6/6 [00:00<00:00, 30.29batch/s, Loss=1.4] 


Epoch 92/100 completed. Avg Loss: 1.3952


Epoch 93/100: 100%|██████████| 6/6 [00:00<00:00, 36.07batch/s, Loss=1.35]


Epoch 93/100 completed. Avg Loss: 1.3548


Epoch 94/100: 100%|██████████| 6/6 [00:00<00:00, 43.51batch/s, Loss=1.32]


Epoch 94/100 completed. Avg Loss: 1.3235


Epoch 95/100: 100%|██████████| 6/6 [00:00<00:00, 43.56batch/s, Loss=1.29]


Epoch 95/100 completed. Avg Loss: 1.2908


Epoch 96/100: 100%|██████████| 6/6 [00:00<00:00, 41.15batch/s, Loss=1.31]


Epoch 96/100 completed. Avg Loss: 1.3050


Epoch 97/100: 100%|██████████| 6/6 [00:00<00:00, 41.49batch/s, Loss=1.26]


Epoch 97/100 completed. Avg Loss: 1.2637


Epoch 98/100: 100%|██████████| 6/6 [00:00<00:00, 39.01batch/s, Loss=1.17] 


Epoch 98/100 completed. Avg Loss: 1.1730


Epoch 99/100: 100%|██████████| 6/6 [00:00<00:00, 36.30batch/s, Loss=1.27]


Epoch 99/100 completed. Avg Loss: 1.2695


Epoch 100/100:   0%|          | 0/6 [00:00<?, ?batch/s]

Input representations:
 

Epoch 100/100:  67%|██████▋   | 4/6 [00:00<00:00, 34.39batch/s, Loss=1.18]

tensor([[ 1.0973e-01, -4.2434e-01,  3.2765e-01, -1.4701e-01, -3.7467e-01,
         -1.5895e-01,  3.2865e-01,  2.8424e-01,  2.6551e-01,  5.0397e-01],
        [ 6.1507e-01,  1.4024e-01, -3.0142e-01,  4.4778e-01, -3.5685e-01,
          2.4893e-01,  1.1806e-01, -1.7668e-01, -2.7579e-01, -1.1973e-02],
        [ 5.1013e-01,  1.9867e-01, -3.5249e-01,  4.4016e-01, -2.5609e-01,
          2.6141e-01,  4.6874e-02, -2.4684e-01, -3.9730e-01, -1.6560e-01],
        [ 2.3907e-01,  1.1930e-01, -4.9789e-01,  5.8901e-02, -1.6954e-01,
          3.0985e-01, -2.0877e-02, -4.4812e-03, -5.8746e-01, -4.5491e-01],
        [ 6.0522e-02, -1.8120e-01,  4.4700e-01,  1.2699e-01, -5.2287e-04,
         -3.2098e-01,  2.4908e-01, -1.9266e-02,  4.6238e-01,  6.0690e-01],
        [ 1.1142e-01,  1.9487e-01, -4.7025e-01,  1.9519e-01, -1.9101e-02,
          2.4697e-01, -1.3743e-01, -1.5969e-01, -5.7597e-01, -5.0289e-01]],
       grad_fn=<CatBackward0>)
Positive pairs scores:
 tensor([[-0.5753],
        [-0.2477],
        [ 0.

Epoch 100/100: 100%|██████████| 6/6 [00:00<00:00, 28.82batch/s, Loss=1.27]

Epoch 100/100 completed. Avg Loss: 1.2717



