In [70]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F

# One seed shared by both RNGs so the shuffles and the weight init repeat across runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

In [71]:
N = 59

# Inputs: row k (and row N + k) is the one-hot encoding of entity k in columns 0..N-1,
# followed by the relation id in column N (0 for the first N rows, 1 for the last N).
_entity_one_hot = np.eye(N)
_relation_column = np.repeat([0.0, 1.0], N).reshape(-1, 1)
dataset = np.hstack([
    np.vstack([_entity_one_hot, _entity_one_hot]),
    _relation_column,
])

# Targets: relation 0 maps entity k to (k + 10) % N, relation 1 maps it to (k + 20) % N.
# Rolling the identity matrix along its columns moves each row's 1 by that offset, with wrap-around.
labels = np.vstack([
    np.roll(_entity_one_hot, 10, axis=1),
    np.roll(_entity_one_hot, 20, axis=1),
])
            
# The first half of the dataset uses relation-0 and the second half uses relation-1.
# The last number in each input row is the relation id, which the model converts to an N-dim vector, so that e_h and e_r end up with the same dimensionality.

# Relation-0 rows (the first N): all of them end up in the training set.
first_half_data, first_half_labels = dataset[:N], labels[:N]

# Relation-1 rows (the last N): the only rows validation examples are drawn from.
second_half_data, second_half_labels = dataset[N:], labels[N:]

# How many relation-1 rows to hold out for validation.
val_size_second_half = 10  # pick whatever small number you want

# Shuffle the relation-1 rows only; one shared permutation keeps inputs and labels aligned.
perm = np.random.permutation(N)
second_half_data, second_half_labels = second_half_data[perm], second_half_labels[perm]

# The leading slice of the shuffled relation-1 rows becomes the validation set.
val_data, val_labels = (
    second_half_data[:val_size_second_half],
    second_half_labels[:val_size_second_half],
)

# Training set = every relation-0 row + the relation-1 rows that were not held out.
train_data = np.vstack([first_half_data, second_half_data[val_size_second_half:]])
train_labels = np.vstack([first_half_labels, second_half_labels[val_size_second_half:]])

# Shuffle the training rows so the two relations are interleaved within mini-batches.
perm_train = np.random.permutation(len(train_data))
train_data, train_labels = train_data[perm_train], train_labels[perm_train]

# Convert all four arrays from float64 NumPy arrays to float32 tensors.
train_data, train_labels, val_data, val_labels = (
    torch.from_numpy(array).float()
    for array in (train_data, train_labels, val_data, val_labels)
)

In [72]:
class BilinearMLP(nn.Module):
    """Bilinear model over a learned entity embedding and a learned relation embedding.

    Each input row is expected to hold a one-hot entity encoding in its first
    ``num_entities`` columns, followed by the relation id in column ``num_entities``.
    The entity and relation embeddings are combined by a bias-free bilinear layer
    and projected to ``output_size`` logits by a bias-free linear layer.

    Args:
        input_size: Dimensionality of both the entity and the relation embeddings.
        hidden_size: Output width of the bilinear layer.
        output_size: Number of logits produced per input row.
        num_entities: Number of distinct entities (width of the one-hot part of the
            input). When omitted, falls back to the notebook-level ``N`` so existing
            ``BilinearMLP(input_size, hidden_size, output_size)`` calls keep working.
        num_relations: Number of distinct relation ids (defaults to 2).
    """

    def __init__(self, input_size, hidden_size, output_size, num_entities=None, num_relations=2):
        super().__init__()
        if num_entities is None:
            # Legacy behaviour: take the entity count from the module-level constant.
            num_entities = N
        # Captured at construction time so forward() keeps slicing the input
        # consistently with the embedding table even if the global is rebound later.
        self.num_entities = num_entities

        self.entity_embed = nn.Embedding(num_entities, input_size)
        self.relation_embed = nn.Embedding(num_relations, input_size)

        self.bl = nn.Bilinear(input_size, input_size, hidden_size, bias=False)
        self.lin = nn.Linear(hidden_size, output_size, bias=False)

    def forward(self, x):
        """Return logits of shape ``(batch, output_size)`` for input rows ``x``."""
        n = self.num_entities
        # The first n columns are the one-hot entity encoding; argmax recovers the id.
        entity_idx = torch.argmax(x[:, :n], dim=1)
        # Column n stores the relation id as a float; cast it to an embedding index.
        rel_idx = x[:, n].long()

        e_h = self.entity_embed(entity_idx)       # (batch, input_size)
        e_r = self.relation_embed(rel_idx)        # (batch, input_size)

        h = self.bl(e_h, e_r)
        logits = self.lin(h)
        return logits


In [73]:
# def apply_scaled_default_init(model: nn.Module, scale: float = 5.0):
#     """
#     Apply PyTorch defaults (Kaiming/Xavier where appropriate) then multiply weights by `scale`.
#     This preserves relative structure but makes them larger.
#     """
#     def _init(m):
#         if isinstance(m, nn.Embedding):
#             nn.init.normal_(m.weight, mean=0.0, std=1.0)
#             m.weight.data.mul_(scale)
#         elif isinstance(m, nn.Bilinear):
#             # use kaiming for bilinear weights flattened -> still ok to use normal then scale
#             nn.init.kaiming_normal_(m.weight.view(m.weight.size(0), -1), a=0, mode='fan_in', nonlinearity='linear')
#             m.weight.data.mul_(scale)
#             if m.bias is not None:
#                 nn.init.zeros_(m.bias)
#         elif isinstance(m, nn.Linear):
#             nn.init.kaiming_normal_(m.weight, a=0, mode='fan_in', nonlinearity='linear')
#             m.weight.data.mul_(scale)
#             if m.bias is not None:
#                 nn.init.zeros_(m.bias)
#     model.apply(_init)

In [74]:
def train(model, train_data, train_labels, val_data, val_labels, epochs=100, batch_size=32, lr=0.003):
    """Train ``model`` with Adam and cross-entropy, evaluating on the validation set every epoch.

    Args:
        model: Module mapping a batch of input rows to class logits.
        train_data: Training inputs, indexed along dim 0 in mini-batches.
        train_labels: Training targets, one row per example; the class index is
            taken as the argmax of each row.
        val_data: Validation inputs, evaluated in a single forward pass per epoch.
        val_labels: Validation targets in the same format as ``train_labels``.
        epochs: Number of passes over the training data.
        batch_size: Mini-batch size; the last batch of an epoch may be smaller.
        lr: Adam learning rate (weight decay is fixed at 0.01).

    Returns:
        ``(model, train_loss_values, val_loss_values, train_acc_values, val_acc_values)``
        where each list holds one value per epoch. The model is left in eval mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.01)
    loss_fn = torch.nn.CrossEntropyLoss()

    # CrossEntropyLoss is fed class indices here; derive them from the label rows once
    # instead of re-running argmax for every batch of every epoch.
    train_targets = torch.argmax(train_labels, dim=1)
    val_targets = torch.argmax(val_labels, dim=1)

    train_loss_values = []
    val_loss_values = []
    train_acc_values = []
    val_acc_values = []

    print(epochs)
    for epoch in range(epochs):
        model.train()
        summed_train_loss = 0.0
        correct_train_preds = 0
        total_train_preds = 0
        for start in range(0, len(train_data), batch_size):
            inputs = train_data[start:start + batch_size]
            targets = train_targets[start:start + batch_size]

            optimizer.zero_grad()
            output = model(inputs)
            loss = loss_fn(output, targets)

            # loss is the mean over this batch; weight it by the batch size so that a
            # short final batch does not distort the epoch average.
            batch_len = len(targets)
            summed_train_loss += loss.item() * batch_len
            preds = torch.argmax(output, dim=1)
            correct_train_preds += (preds == targets).sum().item()
            total_train_preds += batch_len

            loss.backward()
            optimizer.step()

        model.eval()
        # Evaluation needs no gradients; skip building the autograd graph.
        with torch.no_grad():
            output = model(val_data)
            val_loss = loss_fn(output, val_targets).item()
            val_preds = torch.argmax(output, dim=1)
            correct_val_preds = (val_preds == val_targets).sum().item()
        total_val_preds = len(val_preds)

        # Per-sample mean over the epoch (previously divided by a fractional batch count).
        avg_train_loss = summed_train_loss / total_train_preds
        train_acc = correct_train_preds / total_train_preds
        val_acc = correct_val_preds / total_val_preds
        train_loss_values.append(avg_train_loss)
        val_loss_values.append(val_loss)
        train_acc_values.append(train_acc)
        val_acc_values.append(val_acc)

        print("Epoch: {} | Train loss: {:.2f} | Validation loss: {:.2f} | Train accuracy: {:.2f} | Validation accuracy: {:.2f}".format(epoch, avg_train_loss, val_loss, train_acc, val_acc))

    return model, train_loss_values, val_loss_values, train_acc_values, val_acc_values

In [75]:
# Embedding width and number of output logits are both N; the bilinear layer has 100 units.
model = BilinearMLP(input_size=N, hidden_size=100, output_size=N)
# apply_scaled_default_init(model, scale=20.0)
(
    model,
    train_loss_values,
    val_loss_values,
    train_acc_values,
    val_acc_values,
) = train(
    model,
    train_data,
    train_labels,
    val_data,
    val_labels,
    epochs=1000,
    batch_size=16,
    lr=0.003,
)

1000
Epoch: 0 | Train loss: 6.92 | Validation loss: 9.02 | Train accuracy: 0.03 | Validation accuracy: 0.00
Epoch: 1 | Train loss: 0.06 | Validation loss: 10.24 | Train accuracy: 0.99 | Validation accuracy: 0.00
Epoch: 2 | Train loss: 0.00 | Validation loss: 10.66 | Train accuracy: 1.00 | Validation accuracy: 0.00
Epoch: 3 | Train loss: 0.00 | Validation loss: 10.58 | Train accuracy: 1.00 | Validation accuracy: 0.00
Epoch: 4 | Train loss: 0.00 | Validation loss: 10.25 | Train accuracy: 1.00 | Validation accuracy: 0.00
Epoch: 5 | Train loss: 0.00 | Validation loss: 9.80 | Train accuracy: 1.00 | Validation accuracy: 0.00
Epoch: 6 | Train loss: 0.00 | Validation loss: 9.29 | Train accuracy: 1.00 | Validation accuracy: 0.00
Epoch: 7 | Train loss: 0.00 | Validation loss: 8.77 | Train accuracy: 1.00 | Validation accuracy: 0.00
Epoch: 8 | Train loss: 0.00 | Validation loss: 8.27 | Train accuracy: 1.00 | Validation accuracy: 0.00
Epoch: 9 | Train loss: 0.00 | Validation loss: 7.78 | Train accu

KeyboardInterrupt: 