# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

# Suppose we have some data.
class KGDataset(Dataset):
    """Map-style dataset that serves stored knowledge-graph triples by index.

    Args:
        triples: Any sized, indexable collection of triples, e.g. a list of
            id lists or a 2-D tensor with one triple per row.
    """

    def __init__(self, triples):
        # Kept exactly as given: no copy, no conversion.
        self.triples = triples

    def __getitem__(self, idx):
        """Return the triple stored at position ``idx``."""
        item = self.triples[idx]
        return item

    def __len__(self):
        """Return the number of stored triples."""
        count = len(self.triples)
        return count

# A simple example encoder.
class Encoder(nn.Module):
    """Embed integer ids, then apply one Linear layer followed by ReLU.

    Args:
        embedding_dim: Size of each embedding vector and of the output.
        num_embeddings: Size of the id vocabulary; valid ids are
            ``0 .. num_embeddings - 1``. Defaults to 1000, the value that was
            previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, embedding_dim, num_embeddings=1000):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.fc = nn.Linear(embedding_dim, embedding_dim)

    def forward(self, x):
        """Encode an integer id tensor ``x``.

        Returns:
            A tensor of shape ``(*x.shape, embedding_dim)``; every entry is
            non-negative because of the final ReLU.
        """
        x = self.embedding(x)
        return torch.relu(self.fc(x))

# CAGED model.
class CAGED(nn.Module):
    """Score (head, relation, tail) id triples and train them as binary classification.

    Every id is mapped through the shared ``Encoder``. A triple's score is the
    sum over the last dimension of the element-wise product ``h * r * t``
    (a trilinear, DistMult-like score). Positive triples are labelled 1,
    negative triples 0, and the loss is binary cross-entropy on the raw scores.

    NOTE(review): heads, relations and tails all go through the SAME encoder,
    so entity id k and relation id k share one embedding. Confirm this is
    intended; a separate relation embedding table is the more common design.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.encoder = Encoder(embedding_dim)
        # Each triple gets one independent logit, so this is a binary task.
        # CrossEntropyLoss on a 1-D score vector would treat all scores as a
        # single distribution. With an integer target of the same shape it
        # interprets the target as class probabilities and fails because the
        # target is not floating point.
        self.criterion = nn.BCEWithLogitsLoss()

    def _score(self, triples):
        """Return one trilinear score per row of a 2-D id tensor (columns h, r, t)."""
        h = self.encoder(triples[:, 0])
        r = self.encoder(triples[:, 1])
        t = self.encoder(triples[:, 2])
        return torch.sum(h * r * t, dim=-1)

    def forward(self, pos_triples, neg_triples):
        """Score positive and negative triples.

        Args:
            pos_triples: Integer tensor with one (h, r, t) triple per row.
            neg_triples: Integer tensor with one (h, r, t) triple per row.

        Returns:
            ``(scores, labels)``: the positive scores followed by the negative
            scores, and float labels (1.0 for positives, 0.0 for negatives)
            with the same dtype and device as ``scores``.
        """
        pos_score = self._score(pos_triples)
        neg_score = self._score(neg_triples)

        scores = torch.cat([pos_score, neg_score], dim=0)
        # *_like keeps labels on the scores' device/dtype (works on GPU too).
        labels = torch.cat(
            [torch.ones_like(pos_score), torch.zeros_like(neg_score)], dim=0
        )
        return scores, labels

    def compute_loss(self, scores, labels):
        """Return the mean binary cross-entropy between logits ``scores`` and 0/1 ``labels``."""
        # Integer 0/1 labels are also accepted; BCE needs a float target.
        return self.criterion(scores, labels.to(scores.dtype))

# Data and training parameters.
# Each row is one (head, relation, tail) id triple; ids must be valid
# indices into the encoder's embedding table.
triples = torch.tensor(
    [
        [1, 2, 3],
        [4, 5, 6],
        # Append more triples here.
    ]
)

dataset = KGDataset(triples)
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)

# Build the model and its optimizer.
embedding_dim = 50
model = CAGED(embedding_dim=embedding_dim)
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

# Training loop.
def corrupt_triples(triples, num_entities):
    """Build negatives by replacing the head or tail of each triple with a random entity.

    For each row, a fair coin decides whether the head (column 0) or the tail
    (column 2) is replaced by an id drawn uniformly from ``[0, num_entities)``.
    The relation (column 1) is kept. A corrupted triple can occasionally
    coincide with a true triple; no filtering is done here.
    """
    negatives = triples.clone()
    batch = triples.size(0)
    random_ids = torch.randint(0, num_entities, (batch,), device=triples.device)
    replace_head = torch.rand(batch, device=triples.device) < 0.5
    negatives[:, 0] = torch.where(replace_head, random_ids, triples[:, 0])
    negatives[:, 2] = torch.where(replace_head, triples[:, 2], random_ids)
    return negatives


# Random entity ids must stay inside the encoder's embedding table.
num_entities = model.encoder.embedding.num_embeddings

model.train()
for epoch in range(10):
    epoch_loss, num_batches = 0.0, 0
    for pos_triples in dataloader:
        # Shuffling whole rows (the previous approach) only reorders the
        # positives, so every "negative" was itself a true triple labelled 0.
        neg_triples = corrupt_triples(pos_triples, num_entities)

        optimizer.zero_grad()
        scores, labels = model(pos_triples, neg_triples)
        loss = model.compute_loss(scores, labels)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the epoch mean rather than the last batch. This also avoids a
    # NameError on `loss` when the loader yields no batches.
    mean_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch {epoch}, Loss: {mean_loss}")
