In [71]:
import torch.nn as nn
import torch
import torch.nn.functional as F
from contrastive_loss import ContrastiveLoss
import json
import pandas as pd
import numpy as np


In [72]:
class SourceNet(nn.Module):
    def __init__(self):
        super(SourceNet, self).__init__()
        self.net = nn.Sequential(
            nn.LazyLinear(128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        return self.net(x)

class RefNet(nn.Module):
    def __init__(self):
        super(RefNet, self).__init__()
        self.net = nn.Sequential(
            nn.LazyLinear(128),
            nn.LayerNorm(128),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        return self.net(x)


In [73]:

def create_pairs(json_file):
    rows = []
    
    with open(json_file) as f:
        data = json.load(f)
        
    for paper in data:
        paper_id = paper['_id']
        
        # Get the referenced paper IDs from refs_trace
        positive_refs = [ref['_id'] for ref in paper.get('refs_trace', [])]
        
        # Create rows for positive examples
        for ref_id in positive_refs:
            rows.append({
                'source_id': paper_id,
                'ref_id': ref_id,
                'target': 1
            })
        
        # Create rows for negative examples
        for ref_id in paper['references']:
            if ref_id not in positive_refs:
                rows.append({
                    'source_id': paper_id,
                    'ref_id': ref_id,
                    'target': 0
                })
            
                
    return pd.DataFrame(rows)

# Create both positive and negative pairs
df = create_pairs('data/paper_source_trace_train_ans.json')


In [74]:
abs_embeddings = pd.read_csv('data/abstract_embeddings.csv')
title_embeddings = pd.read_csv('data/title_embeddings.csv')

# Get all unique paper IDs from both source and ref columns
all_paper_ids = pd.concat([df['source_id'], df['ref_id']]).unique()

# Get list of paper IDs that exist in abs_embeddings
valid_ids = set(abs_embeddings['id'].values)

# Filter df to only keep rows where both source_id and ref_id are in abs_embeddings
df = df[df['source_id'].isin(valid_ids) & df['ref_id'].isin(valid_ids)].reset_index(drop=True)


In [75]:
# Create a PyTorch dataset class
class PaperDataset(torch.utils.data.Dataset):
    def __init__(self, pairs_df, abs_embeddings_df, title_embeddings_df):
        self.pairs_df = pairs_df
        
        # Create combined embeddings dictionary
        self.embeddings = {}
        for _, row in abs_embeddings_df.iterrows():
            paper_id = row['id']
            abs_emb = row[row.index[1:]].values.astype(np.float32)
            title_emb = title_embeddings_df[title_embeddings_df['id'] == paper_id].iloc[0][title_embeddings_df.columns[1:]].values.astype(np.float32)
            combined_emb = np.concatenate([abs_emb, title_emb])
            self.embeddings[paper_id] = torch.tensor(combined_emb, dtype=torch.float32)
        
    def __len__(self):
        return len(self.pairs_df)
    
    def __getitem__(self, idx):
        row = self.pairs_df.iloc[idx]
        source_emb = self.embeddings[row['source_id']]
        ref_emb = self.embeddings[row['ref_id']]
        label = torch.tensor(row['target'], dtype=torch.float32)
        return source_emb, ref_emb, label

# Create dataset
dataset = PaperDataset(df, abs_embeddings, title_embeddings)

In [76]:
# Create dataloader
batch_size = 32
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True
)

In [77]:

# Initialize networks and optimizer
source_net = SourceNet()
ref_net = RefNet()
optimizer = torch.optim.Adam(list(source_net.parameters()) + list(ref_net.parameters()), lr=1e-4)

# Contrastive loss function
def contrastive_loss(source_out, ref_out, labels, margin=1.0):
    distances = F.pairwise_distance(source_out, ref_out)
    return torch.mean((1-labels) * torch.pow(distances, 2) + 
                     labels * torch.pow(torch.clamp(margin - distances, min=0.0), 2))

# Training loop
num_epochs = 100
for epoch in range(num_epochs):
    total_loss = 0
    for source_emb, ref_emb, labels in dataloader:
        # Forward pass
        source_out = source_net(source_emb)
        ref_out = ref_net(ref_emb)
        
        # Calculate loss
        loss = contrastive_loss(source_out, ref_out, labels)
        
        # Backward pass and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
    avg_loss = total_loss / len(dataloader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}')


Epoch [1/100], Loss: 0.1154
Epoch [2/100], Loss: 0.1103
Epoch [3/100], Loss: 0.1067
Epoch [4/100], Loss: 0.1062
Epoch [5/100], Loss: 0.1056
Epoch [6/100], Loss: 0.1061
Epoch [7/100], Loss: 0.1088
Epoch [8/100], Loss: 0.1051
Epoch [9/100], Loss: 0.1023
Epoch [10/100], Loss: 0.1035
Epoch [11/100], Loss: 0.1038
Epoch [12/100], Loss: 0.1027
Epoch [13/100], Loss: 0.1030
Epoch [14/100], Loss: 0.1020
Epoch [15/100], Loss: 0.1012
Epoch [16/100], Loss: 0.1019
Epoch [17/100], Loss: 0.1030
Epoch [18/100], Loss: 0.1021
Epoch [19/100], Loss: 0.1016
Epoch [20/100], Loss: 0.1014
Epoch [21/100], Loss: 0.1008
Epoch [22/100], Loss: 0.1031
Epoch [23/100], Loss: 0.1030
Epoch [24/100], Loss: 0.1061
Epoch [25/100], Loss: 0.1037
Epoch [26/100], Loss: 0.1037
Epoch [27/100], Loss: 0.1005
Epoch [28/100], Loss: 0.1003
Epoch [29/100], Loss: 0.1002
Epoch [30/100], Loss: 0.0999
Epoch [31/100], Loss: 0.0998
Epoch [32/100], Loss: 0.1008
Epoch [33/100], Loss: 0.0997
Epoch [34/100], Loss: 0.1021
Epoch [35/100], Loss: 0

In [78]:
# Calculate average distances for target and non-target pairs
target_distances = []
nontarget_distances = []

with torch.no_grad():
    for source_emb, ref_emb, labels in dataloader:
        # Get network outputs
        source_out = source_net(source_emb)
        ref_out = ref_net(ref_emb)
        
        # Calculate distances
        distances = F.pairwise_distance(source_out, ref_out)
        
        # Split distances based on labels
        target_distances.extend(distances[labels == 0].tolist())
        nontarget_distances.extend(distances[labels == 1].tolist())

avg_target_dist = sum(target_distances) / len(target_distances)
avg_nontarget_dist = sum(nontarget_distances) / len(nontarget_distances)

print(f"Average distance for target pairs: {avg_target_dist:.4f}")
print(f"Average distance for non-target pairs: {avg_nontarget_dist:.4f}")


Average distance for target pairs: 0.0639
Average distance for non-target pairs: 0.1307


In [79]:
# Load and process validation data
with open('data/paper_source_trace_valid_wo_ans.json', 'r') as f:
    valid_data = json.load(f)

# Load submission example for padding reference
with open('data/submission_example_valid.json', 'r') as f:
    submission_example = json.load(f)

embeddings_dict = dataset.embeddings
paper_scores = {}

with torch.no_grad():
    for paper in valid_data:
        paper_id = paper['_id']
        references = paper['references']
        expected_length = len(submission_example[paper_id])
        
        if paper_id not in embeddings_dict:
            paper_scores[paper_id] = [0.0] * expected_length
            continue
        
        source_out = source_net(embeddings_dict[paper_id].unsqueeze(0))
        
        scores = []
        for ref_id in references[:expected_length]:
            if ref_id not in embeddings_dict:
                scores.append(0.0)
                continue
            
            ref_out = ref_net(embeddings_dict[ref_id].unsqueeze(0))
            distance = F.pairwise_distance(source_out, ref_out).item()
            score = max(0, min(1, 1 - distance))  # Clip score between 0 and 1
            scores.append(score)
        
        scores += [0.0] * (expected_length - len(scores))
        paper_scores[paper_id] = scores

with open('validation_scores.json', 'w') as f:
    json.dump(paper_scores, f)

In [80]:
# Compare lengths with submission example
with open('data/submission_example_valid.json', 'r') as f:
    submission_example = json.load(f)

for paper_id, scores in paper_scores.items():
    example_scores = submission_example[paper_id]
    if len(scores) != len(example_scores):
        print(f"{paper_id}: scores={len(scores)}, example={len(example_scores)}")
