In [None]:
from typing import Any
import torch

class KGAT(torch.nn.Module):
    """KGAT-style relational graph network over premise and context nodes.

    Node ids are laid out as ``[premises..., contexts...]``: premises occupy
    ids ``0 .. num_premises - 1`` and have learned embeddings; contexts follow
    and start every forward pass as zero vectors, so their representation
    comes only from the messages they receive.

    Required ``config`` keys: ``num_premises``, ``num_relations``,
    ``embedding_dim``, ``n_layers``, ``message_dropout``, ``node_dropout``.
    """

    def __init__(self, config : dict[str, Any]):
        super().__init__()
        self.config = config

        embedding_dim = config['embedding_dim']
        num_relations = config['num_relations']
        n_layers = config['n_layers']

        self.premise_embeddings = torch.nn.Embedding(config['num_premises'], embedding_dim)
        torch.nn.init.xavier_uniform_(self.premise_embeddings.weight)

        # One (num_relations, embedding_dim, embedding_dim) weight stack per layer.
        # These are raw Parameters, so they must live in a ParameterList:
        # ModuleList only accepts Module instances and rejects a Parameter.
        self.relations_weights = torch.nn.ParameterList()
        for _ in range(n_layers):
            weight = torch.empty(num_relations, embedding_dim, embedding_dim)
            # Initialise each relation matrix on its own. Calling xavier_uniform_
            # on the whole 3-D tensor would treat it like a convolution kernel
            # and derive the fans from the wrong dimensions.
            for rel in range(num_relations):
                torch.nn.init.xavier_uniform_(weight[rel])
            self.relations_weights.append(torch.nn.Parameter(weight))

        # Scores a (premise, context) pair from the concatenation of both
        # multi-layer embeddings, hence the factor (n_layers + 1) * 2.
        self.scorer = torch.nn.Sequential(
            torch.nn.Linear(embedding_dim * (n_layers + 1) * 2, embedding_dim),
            torch.nn.LeakyReLU(),
            torch.nn.Linear(embedding_dim, 1)
        )

    def forward(self, n_contexts : int, edge_index : torch.Tensor, edge_type : torch.Tensor):
        """Propagate embeddings over the graph and return per-node representations.

        Args:
            n_contexts: number of context nodes appended after the premises.
            edge_index: integer tensor whose row 0 holds source node ids and
                row 1 holds destination node ids, one column per edge.
            edge_type: integer relation id of each edge, aligned with the
                columns of ``edge_index``.

        Returns:
            ``(premise_emb, context_emb)`` where each row is the concatenation
            of the node's embedding after every layer (including the input),
            i.e. ``embedding_dim * (n_layers + 1)`` features per node.
        """
        premise_emb = self.premise_embeddings.weight

        # Context nodes start as zeros, matching the embedding dtype and device.
        context_emb = premise_emb.new_zeros((n_contexts, self.config['embedding_dim']))
        X = torch.cat([premise_emb, context_emb], dim=0)

        all_embs = [X]
        for layer in range(self.config['n_layers']):
            neigh_emb = torch.zeros_like(X)
            for rel in range(self.config['num_relations']):
                rel_edges = (edge_type == rel).nonzero(as_tuple=True)[0]
                if rel_edges.numel() == 0:
                    continue
                src_nodes = edge_index[0, rel_edges]
                dst_nodes = edge_index[1, rel_edges]

                # Move both head and tail into the relation subspace.
                rel_weight = self.relations_weights[layer][rel]
                src_transformed = torch.matmul(X[src_nodes], rel_weight)
                dst_transformed = torch.tanh(torch.matmul(X[dst_nodes], rel_weight))

                # Attention scores: one scalar per edge, used raw (no softmax
                # over the neighbourhood).
                scores = (src_transformed * dst_transformed).sum(dim=1)
                messages = src_transformed * scores.unsqueeze(1)

                # Apply message dropout.
                messages = torch.nn.functional.dropout(messages, p=self.config['message_dropout'], training=self.training)

                # Aggregate messages at their destination nodes by summation.
                neigh_emb = neigh_emb.index_add(0, dst_nodes, messages)

            # Bi-interaction aggregation without the linear maps:
            # LeakyReLU(e_h + e_N) + LeakyReLU(e_h ⊙ e_N).
            X = torch.nn.functional.leaky_relu(X + neigh_emb) + torch.nn.functional.leaky_relu(X * neigh_emb)
            # Node dropout, then L2-normalise every node embedding.
            X = torch.nn.functional.dropout(X, p=self.config['node_dropout'], training=self.training)
            X = torch.nn.functional.normalize(X, p=2, dim=1)

            all_embs.append(X)

        all_embs = torch.cat(all_embs, dim=1) # (num_premises + n_contexts, embedding_dim * (n_layers + 1))

        num_premises = self.config['num_premises']
        return all_embs[:num_premises], all_embs[num_premises:] # premise_emb, context_emb

    def score(self, n_contexts : int, edge_index : torch.Tensor, edge_type : torch.Tensor) -> torch.Tensor:
        """Score every (premise, context) pair.

        Runs ``forward`` and feeds the concatenated embeddings of each pair
        through the scorer MLP.

        Returns:
            Tensor of shape ``(n_premises, n_contexts)`` holding the raw
            scorer outputs (no sigmoid applied).
        """
        premise_emb, context_emb = self.forward(n_contexts, edge_index, edge_type)
        n_premises = premise_emb.size(0)

        # Broadcast both sides to the full premise x context grid.
        premise_expanded = premise_emb.unsqueeze(1).expand(-1, n_contexts, -1)
        context_expanded = context_emb.unsqueeze(0).expand(n_premises, -1, -1)

        # Concatenate premise and context embeddings for each pair.
        pair_emb = torch.cat([premise_expanded, context_expanded], dim=-1) # (n_premises, n_contexts, embedding_dim * (n_layers + 1) * 2)

        scores = self.scorer(pair_emb).squeeze(-1) # (n_premises, n_contexts)

        return scores

SyntaxError: incomplete input (1250243143.py, line 4)

In [None]:
# Made-up training data for checking that the model can overfit a single batch.
torch.manual_seed(0)  # make the synthetic graph and the initial weights reproducible
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

NUM_PREMISES = 10000
NUM_CONTEXTS = 100
NUM_RELATIONS = 2

# Random premise -> premise edges.
PREMISE_EDGE_INDEX = torch.randint(0, NUM_PREMISES, (2, 100000), device=DEVICE)
# Remove edges where src >= dst
PREMISE_EDGE_INDEX = PREMISE_EDGE_INDEX[:, PREMISE_EDGE_INDEX[0] < PREMISE_EDGE_INDEX[1]]
PREMISE_EDGE_TYPE = torch.randint(0, NUM_RELATIONS, (PREMISE_EDGE_INDEX.size(1),), device=DEVICE)

# Random premise -> context edges; context node ids start right after the premises.
PREMISE_TO_CONTEXT_EDGE_INDEX = torch.concat([
    torch.randint(0, NUM_PREMISES, (1, 5000), device=DEVICE),
    torch.randint(NUM_PREMISES, NUM_PREMISES + NUM_CONTEXTS, (1, 5000), device=DEVICE)
], dim=0)
PREMISE_TO_CONTEXT_EDGE_TYPE = torch.randint(0, NUM_RELATIONS, (PREMISE_TO_CONTEXT_EDGE_INDEX.size(1),), device=DEVICE)

# Dense 0/1 target with one entry per (premise, context) pair.
# BCEWithLogitsLoss needs a float target of the same shape as the
# (NUM_PREMISES, NUM_CONTEXTS) score matrix, so the random positive pairs are
# scattered into a matrix instead of being kept as a flat list of node ids.
NUM_POSITIVE_PAIRS = 2 * NUM_CONTEXTS
positive_premises = torch.randint(0, NUM_PREMISES, (NUM_POSITIVE_PAIRS,), device=DEVICE)
positive_contexts = torch.randint(0, NUM_CONTEXTS, (NUM_POSITIVE_PAIRS,), device=DEVICE)  # column index, not global node id
CONTEXT_LABELS = torch.zeros((NUM_PREMISES, NUM_CONTEXTS), device=DEVICE)
CONTEXT_LABELS[positive_premises, positive_contexts] = 1.0

EDGE_INDEX = torch.concat([PREMISE_EDGE_INDEX, PREMISE_TO_CONTEXT_EDGE_INDEX], dim=1)
EDGE_TYPE = torch.concat([PREMISE_EDGE_TYPE, PREMISE_TO_CONTEXT_EDGE_TYPE], dim=0)

config = {
    'num_premises': NUM_PREMISES,
    'num_relations': NUM_RELATIONS,
    'embedding_dim': 64,
    'n_layers': 2,
    'message_dropout': 0.1,
    'node_dropout': 0.1,
}

model = KGAT(config).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.01)
criterion = torch.nn.BCEWithLogitsLoss()
model.train()

def Rat10(scores : torch.Tensor, labels : torch.Tensor, k : int = 10) -> float:
    """Micro-averaged recall@k over the rows of a score matrix.

    For every row the ``k`` highest-scoring columns are taken as the
    recommendations. The result is the total number of recommended columns
    that are labelled positive, divided by the total number of positives that
    could have been retrieved (each row's positive count capped at ``k``).

    Args:
        scores: 2-D tensor of scores, one row per query.
        labels: 2-D 0/1 tensor with the same shape as ``scores``.
        k: number of recommendations per row; reduced to the number of
            columns when ``scores`` has fewer than ``k`` columns.

    Returns:
        Recall as a Python float, or 0.0 when there is nothing to retrieve.
    """
    k = min(k, scores.size(1))
    if k <= 0:
        return 0.0

    recommended_indices = torch.topk(scores, k=k, dim=1).indices
    # Look up the label of each recommended column row by row; plain
    # labels[recommended_indices] would index whole rows instead.
    is_correct = torch.gather(labels, 1, recommended_indices)

    # A row can contribute at most k hits, so cap its positive count at k.
    retrievable = torch.clamp(labels.sum(dim=1), max=k).sum().item()
    if retrievable == 0:
        return 0.0
    return is_correct.sum().item() / retrievable

# Overfit the single synthetic batch and report loss and R@10 periodically.
NUM_EPOCHS = 10000
LOG_EVERY = 100  # log every N epochs so the cell output is not flooded with 10k lines

for epoch in range(NUM_EPOCHS):
    optimizer.zero_grad()
    scores = model.score(NUM_CONTEXTS, EDGE_INDEX, EDGE_TYPE)

    loss = criterion(scores, CONTEXT_LABELS)
    loss.backward()
    optimizer.step()

    if epoch % LOG_EVERY == 0 or epoch == NUM_EPOCHS - 1:
        # The metric is only for reporting, so keep it out of the autograd graph.
        R10 = Rat10(scores.detach(), CONTEXT_LABELS)
        print(f"Epoch {epoch}, Loss: {loss.item()}, R@10={R10}")
    


