In [2]:
import pandas as pd
import torch

# Prefer the GPU with index 2, but fall back to the CPU when CUDA is missing or
# fewer than three devices are visible, so the notebook still runs elsewhere.
if torch.cuda.is_available() and torch.cuda.device_count() > 2:
    device = torch.device("cuda:2")
else:
    device = torch.device("cpu")

# SECURITY NOTE: unpickling can execute arbitrary code; only load this file if
# it comes from a trusted source.
cat_embeddings = pd.read_pickle('./Datasets/cat_embeddings.pkl')['specter_cat_embeddings']

# Each entry of the column is a tensor; stack them along a new leading
# dimension and move the result to the selected device.
X = torch.stack(list(cat_embeddings.values)).to(device)

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define the neural network architecture for classification
class Classifier(nn.Module):
    """Feed-forward binary classifier with one hidden layer.

    The network applies ``fc1`` -> ReLU -> ``fc2`` -> sigmoid, so every output
    value lies strictly between 0 and 1.

    Args:
        input_dim: Number of features in each input vector.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output units.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_features=input_dim, out_features=hidden_dim)
        self.fc2 = nn.Linear(in_features=hidden_dim, out_features=output_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the sigmoid-activated scores for the batch ``x``."""
        hidden = self.relu(self.fc1(x))
        scores = self.fc2(hidden)
        return self.sigmoid(scores)

# Network dimensions
input_dim = 768  # Embedding size
hidden_dim = 256
output_dim = 1  # Binary classification (positive or not)

# Build the classifier and place its parameters on the selected device
model = Classifier(
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    output_dim=output_dim,
)
model = model.to(device)

# Binary cross-entropy on the sigmoid outputs, optimized with Adam
criterion = nn.BCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Assuming X is a tensor of dimension (313, 768) for positive samples
# and X_unlabeled is a tensor of dimension (N, 768) for unlabelled samples

# Train the model using PU learning algorithm
num_iterations = 10
# Minimum predicted probability for an unlabeled sample to be promoted to the
# positive set; raise it to promote only high-confidence samples.
pseudo_label_threshold = 0.5
for iteration in range(num_iterations):
    # Step 1: one gradient step that separates the positive set (P) from the
    # remaining unlabeled pool (U). U is used as provisional negatives:
    # training on P alone with target 1 is minimized by predicting 1 for every
    # input, which would make Step 2 absorb the whole pool.
    # NOTE(review): a single optimizer step per round leaves the scores close
    # to their initial values; consider several steps per round and a stricter
    # threshold before trusting the pseudo-labels.
    optimizer.zero_grad()
    positive_outputs = model(X)
    loss = criterion(positive_outputs, torch.ones_like(positive_outputs))  # Label positive samples as 1
    if X_unlabeled.shape[0] > 0:
        # Guarded because the mean-reduced loss of an empty batch is NaN.
        provisional_negative_outputs = model(X_unlabeled)
        loss = loss + criterion(
            provisional_negative_outputs,
            torch.zeros_like(provisional_negative_outputs),
        )
    loss.backward()
    optimizer.step()

    # Step 2: Pseudo-label unlabeled samples (U) and update the positive set (P)
    with torch.no_grad():
        unlabeled_outputs = model(X_unlabeled)
        pseudo_labels = (unlabeled_outputs > pseudo_label_threshold).float()
        # Compute the row mask once and reuse it for both halves of the split.
        promote_mask = pseudo_labels.squeeze(1) == 1
        pseudo_positive_samples = X_unlabeled[promote_mask]
        X = torch.cat((X, pseudo_positive_samples), dim=0)

        # Remove pseudo-positive samples from the unlabeled set
        X_unlabeled = X_unlabeled[~promote_mask]

# After training, the model is ready for prediction on new data
