In [None]:
import torch
from transformers import AutoTokenizer
from tqdm import tqdm
from load import load_dataset
from models.diffpool import DiffPoolModel
from metrics import Metrics

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Hugging Face identifier of the text encoder; it is used both for the
# tokenizer and as the `model_name` argument of the DiffPool model.
model_name = "sentence-transformers/all-MiniLM-L6-v2"
# Passed to the model as `nout`.
# NOTE(review): presumably the size of the shared text/graph embedding space - confirm.
embeddings_dim = 384

tokenizer = AutoTokenizer.from_pretrained(model_name)
train_loader, val_loader = load_dataset(tokenizer)

# Build the model first, then move it to the selected device.
model = DiffPoolModel(model_name=model_name, num_node_features=300, nout=embeddings_dim)
model = model.to(device)

In [None]:
# Restore the trained weights from disk and switch the model to inference mode.
save_path = "outputs/saved/circle_loss/circle70.pt"
print("Loading best model...")
# map_location remaps the stored tensors onto the device selected above, so a
# checkpoint written on a GPU machine still loads on a CPU-only one.
# NOTE: torch.load unpickles the file - only load checkpoints you trust.
checkpoint = torch.load(save_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval()

In [None]:
import numpy as np
import os

# Embed the whole training set with the loaded model and score every text
# embedding against every graph embedding.
#
# The loader is rebuilt with shuffle=False so that row i of the embedding
# matrices follows the order in which the loader yields samples; the export
# cell further down relies on that order when it matches indices.
train_loader, _ = load_dataset(tokenizer, shuffle=False, batch_size=32)

model.eval()

graph_embeddings = []
text_embeddings = []
# NOTE(review): `loss` and `result` are not read anywhere in the visible cells;
# they are kept in case a cell outside this view uses them.
loss = 0
result = {}

with torch.no_grad():
    for batch in tqdm(train_loader):
        # Split the tokenized text off the batch; what remains is handed to
        # the model as the graph input.
        input_ids = batch.input_ids
        batch.pop("input_ids")
        attention_mask = batch.attention_mask
        batch.pop("attention_mask")
        graph_batch = batch

        x_graph, x_text = model(
            graph_batch.to(device),
            input_ids.to(device),
            attention_mask.to(device),
        )

        # Keep one CPU float32 tensor per batch instead of converting every
        # row to a Python list: same values, without the list -> numpy ->
        # tensor round trip.
        graph_embeddings.append(x_graph.detach().cpu().float())
        text_embeddings.append(x_text.detach().cpu().float())

# Stack the per-batch outputs along the sample dimension. An empty loader
# yields an empty tensor, as the list-based version did.
graph_embeddings = (
    torch.cat(graph_embeddings, dim=0) if graph_embeddings else torch.empty(0)
)
text_embeddings = (
    torch.cat(text_embeddings, dim=0) if text_embeddings else torch.empty(0)
)

metric = Metrics(loss="circle")
similarities = metric.similarity(text_embeddings, graph_embeddings)

In [None]:
# Pick the 10% of samples whose diagonal similarity score is lowest.
# Assumes similarities[i, i] scores sample i's text against its own graph.
worst_similarities_indices = torch.argsort(
    similarities.diagonal(), descending=False
)[: int(len(similarities) * 0.1)]

In [None]:
# Show the diagonal similarity score of each selected sample.
print(similarities[worst_similarities_indices, worst_similarities_indices])

# Convert the indices to a plain Python list. The guard keeps the cell
# re-runnable: once converted, the name refers to a list, which has no
# .tolist() method.
if torch.is_tensor(worst_similarities_indices):
    worst_similarities_indices = worst_similarities_indices.tolist()
print(worst_similarities_indices)
print(len(worst_similarities_indices))

In [None]:
# Write every training sample whose index was selected as "worst" to its own
# file under ./data/worst_embeddings/train.
worst_embeddings_dir = "./data/worst_embeddings/train"
# makedirs creates any missing parent directories and tolerates directories
# that already exist, so no exception has to be swallowed here.
os.makedirs(worst_embeddings_dir, exist_ok=True)

# batch_size=1 and shuffle=False make the loop index the position of the
# sample in loader order.
# NOTE(review): assumes load_dataset yields samples in the same order as the
# batch_size=32 pass used to compute the similarities - confirm.
train_loader, _ = load_dataset(tokenizer, shuffle=False, batch_size=1)

# Use a set of plain ints for O(1) membership tests; going through
# torch.as_tensor(...).tolist() accepts both a tensor and a list of indices.
worst_index_set = set(torch.as_tensor(worst_similarities_indices).tolist())

# create a new dataset with the worst embeddings
for i, data in enumerate(tqdm(train_loader)):
    if i in worst_index_set:
        # The id does not correspond to the CIDs but that's not a problem
        torch.save(data, os.path.join(worst_embeddings_dir, "data_{}.pt".format(i)))