In [13]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random

In [14]:
class ComparisonDataset(Dataset):
    def __init__(self, data_len):
        self.data_len = data_len

    def __len__(self):
        return self.data_len

    def __getitem__(self, index):
        x = random.random()
        y1 = random.random()
        y2 = random.random()
        if abs(y1 - x) > abs(y2 - x):
            chosen = torch.tensor([x, y2])
            rejected = torch.tensor([x, y1])
        else:
            chosen = torch.tensor([x, y1])
            rejected = torch.tensor([x, y2])
        
        return chosen, rejected

In [15]:
dataset = ComparisonDataset(200000)
data_loader = DataLoader(dataset, batch_size=40)

In [16]:
# MLP input: (x, y), output scalar reward
model = nn.Sequential(
    nn.Linear(2, 16),
    nn.ReLU(),
    nn.Linear(16, 16),
    nn.ReLU(),
    nn.Linear(16, 1)
)

In [17]:
optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

In [18]:
model.train()
for idx, (chosen_batch, rejected_batch) in enumerate(data_loader):
    optim.zero_grad()
    chosen_reward = model(chosen_batch)
    rejected_reward = model(rejected_batch)
    loss = (-torch.log(nn.functional.sigmoid(chosen_reward - rejected_reward))).mean()
    loss.backward()
    optim.step()
    if idx % 100 == 0:
        print(f"{idx} loss: {loss}")

0 loss: 0.6887573003768921
100 loss: 0.6224793195724487
200 loss: 0.5490990877151489
300 loss: 0.393026739358902
400 loss: 0.26217490434646606
500 loss: 0.20685604214668274
600 loss: 0.13266679644584656
700 loss: 0.13813342154026031
800 loss: 0.16788017749786377
900 loss: 0.10095753520727158
1000 loss: 0.10350020229816437
1100 loss: 0.09711901098489761
1200 loss: 0.12045450508594513
1300 loss: 0.06185797601938248
1400 loss: 0.0990157276391983
1500 loss: 0.13050368428230286
1600 loss: 0.11660881340503693
1700 loss: 0.07478947937488556
1800 loss: 0.09369362890720367
1900 loss: 0.04702315106987953
2000 loss: 0.06259207427501678
2100 loss: 0.06624438613653183
2200 loss: 0.04136281460523605
2300 loss: 0.024363122880458832
2400 loss: 0.03977678343653679
2500 loss: 0.085675910115242
2600 loss: 0.013221899047493935
2700 loss: 0.04432803764939308
2800 loss: 0.06046494096517563
2900 loss: 0.06497308611869812
3000 loss: 0.03227042406797409
3100 loss: 0.030278358608484268
3200 loss: 0.033947061747

In [19]:
model.eval()

Sequential(
  (0): Linear(in_features=2, out_features=16, bias=True)
  (1): ReLU()
  (2): Linear(in_features=16, out_features=16, bias=True)
  (3): ReLU()
  (4): Linear(in_features=16, out_features=1, bias=True)
)

In [20]:
demo_input = torch.tensor([0.5, 0.5])
print(model(demo_input))

tensor([76.9150], grad_fn=<AddBackward0>)


In [21]:
demo_input = torch.tensor([0.5, 0.51])
print(model(demo_input))

tensor([76.0057], grad_fn=<AddBackward0>)


In [22]:
demo_input = torch.tensor([0.5, 0.4])
print(model(demo_input))

tensor([68.4210], grad_fn=<AddBackward0>)


In [23]:
demo_input = torch.tensor([0.5, 0.6])
print(model(demo_input))

tensor([67.8225], grad_fn=<AddBackward0>)


In [24]:
demo_input = torch.tensor([0.5, 0.7])
print(model(demo_input))

tensor([58.7301], grad_fn=<AddBackward0>)
