In [None]:
from src.configs.sverl import CharacteristicConfig
from src.sverl import Characteristic
from src.sverl.state_samplers import ProceduralSampler
from src.sverl.masking import ZeroMasker
from src.models import Model
from torch import nn
import copy
from itertools import product
import torch
from torch.types import Tensor

In [3]:
class TestModel(Model):
    def _construct_model(self):
        model = nn.Sequential(
            self._init_layer(nn.Linear(3, 16)),
            nn.ReLU(),
            self._init_layer(nn.Linear(16, 2)),
        )
        return model

In [30]:
# -- 1. Define 8 possible binary states
def generate_all_states():
    return Tensor(list(product([0, 1], repeat=3)))  # shape: (8, 3)

In [48]:
# -- 2. Define a deterministic policy (for example purposes)
class SimpleBinaryPolicy(nn.Module):
    def __init__(self):
        super().__init__()
        # We'll just hardcode a mapping: e.g., sum(state) > 1 → action 1 (yes), else 0 (no)
    
    def forward(self, x):
        # Rule-based: sum the features, output action 1 if sum > 1, else 0
        return torch.nn.functional.one_hot((x.sum(dim=-1) > 1).long(), 2).float()

In [49]:
policy = SimpleBinaryPolicy()
states = generate_all_states()

In [None]:
masker = ZeroMasker()

In [62]:
model = TestModel()
optimiser = torch.optim.Adam(model.parameters(), lr=0.00005)

states = generate_all_states()
targets = policy(states)
for epoch in range(5000):
    i = torch.randint(0, 8, (64, ))
    x = states[i]
    y = targets[i]
    
    mask = torch.rand_like(x) > 0.5
    x[mask] = 0

    logits = model(x)
    loss = torch.square(logits - y).mean()

    optimiser.zero_grad()
    loss.backward()
    optimiser.step()
    if epoch % 100 == 0:
        print(loss.item())

0.4430120885372162
0.424513578414917
0.4155327081680298
0.40287721157073975
0.3839949667453766
0.3637024164199829
0.49410679936408997
0.3590109050273895
0.3487154245376587
0.2918359935283661
0.32505398988723755
0.3515026867389679
0.35201072692871094
0.2721220850944519
0.3118796646595001
0.3076092600822449
0.3098236620426178
0.311847060918808
0.2846294939517975
0.25962144136428833
0.2846491038799286
0.3035305440425873
0.271392285823822
0.2596524953842163
0.308035671710968
0.2783072590827942
0.3013690114021301
0.24796490371227264
0.22790537774562836
0.23861238360404968
0.23913463950157166
0.24817289412021637
0.26149481534957886
0.22604233026504517
0.23798400163650513
0.23919571936130524
0.2218385487794876
0.22781449556350708
0.24834254384040833
0.24642255902290344
0.22528666257858276
0.23255057632923126
0.23461364209651947
0.21395862102508545
0.21034842729568481
0.19343635439872742
0.19429150223731995
0.2148958444595337
0.2197200357913971
0.19877982139587402


In [None]:
states = generate_all_states()
targets = policy(states)
total_loss = 0
for epoch in range(5000):
    with torch.no_grad():
        i = torch.randint(0, 8, (64, ))
        x = states[i]
        y = targets[i]
        
        mask = torch.rand_like(x) > 0.5
        x[mask] = 0

        logits = model(x)
        loss = torch.square(logits - y).mean()
        total_loss += loss
total_loss /= 5000

tensor(0.2067)

In [64]:
s = {'state': {'model': model.state_dict()}}
torch.save(s, 'test.pt')

In [66]:
s2 = torch.load('test.pt', weights_only=False)

In [74]:
model2 = TestModel()
model2.load_state_dict(s2['state']['model'])
print()




In [80]:
states = generate_all_states()
targets = policy(states)
total_loss = 0
for epoch in range(5000):
    with torch.no_grad():
        i = torch.randint(0, 8, (64, ))
        x = states[i]
        y = targets[i]
        
        mask = torch.rand_like(x) > 0.5
        x[mask] = 0

        logits = model2(x)
        loss = torch.square(logits - y).mean()
        total_loss += loss
total_loss /= 5000
print(total_loss)

tensor(0.2071)
