In [1]:
from PPairS.datasets import IPCCDataset
from PPairS.trainer import Trainer

import torch as t
from torch import nn, Tensor
import torch.nn.functional as F

from sklearn.linear_model import SGDClassifier

from jaxtyping import Float
from tqdm.notebook import tqdm

In [2]:
# Data pipeline for the probes below. `dataset_kwargs` are forwarded to
# IPCCDataset (here with contrast disabled).
# NOTE(review): `splits=(0.8, 0.9)` looks like cumulative cut points
# (train / val / test) — confirm against Trainer.
trainer_config = dict(
    dataset_class=IPCCDataset,
    dataset_kwargs={"contrast": False},
    splits=(0.8, 0.9),
    batch_size=64,
)
trainer = Trainer(**trainer_config)

caching data: 100%|██████████| 29/29 [01:36<00:00,  3.32s/it]


In [16]:
class FeatureDirection(nn.Module):
    """A single learnable direction in activation space, initialised at unit norm.

    The direction is the length-``d_model`` parameter ``fd``. ``forward`` projects
    activations onto it (one scalar per row, used as a logit by the callers in
    this notebook), ``unit`` rescales it back to unit norm, and ``orthogonalize``
    removes the component along it from a batch of activations.

    Args:
        d_model: dimensionality of the activations the direction lives in.
        device: device on which the parameter is allocated.
    """

    def __init__(
        self,
        d_model: int=4096,
        device: str="cuda"
    ):
        super().__init__()
        self.d_model = d_model
        self.device = device

        # Normalize the random init *before* wrapping it, so the parameter is a
        # plain leaf with no autograd history. Assigning an nn.Parameter to an
        # attribute registers it with the module; no register_parameter call needed.
        init_fd = t.randn(self.d_model, device=self.device)
        self.fd = nn.Parameter(init_fd / t.linalg.vector_norm(init_fd))

    def forward(
            self,
            x: Float[Tensor, "batch d_model"]
    ) -> Float[Tensor, "batch"]:
        """Project each row of ``x`` onto the direction: returns ``x @ fd``."""
        return x @ self.fd

    def unit(self):
        """Rescale ``fd`` to unit L2 norm without recording the op in autograd.

        Intended to be called after each optimizer step, since gradient updates
        and weight decay move ``fd`` off the unit sphere.
        """
        with t.no_grad():
            self.fd.data = self.fd / t.linalg.vector_norm(self.fd)

    def orthogonalize(self, x: Float[Tensor, "batch d_model"]) -> Float[Tensor, "batch d_model"]:
        """Remove from each row of ``x`` its component along ``fd``.

        Divides by ``||fd||^2`` so the result is orthogonal to ``fd`` even when
        ``fd`` is not unit norm; for a unit-norm ``fd`` this equals
        ``x - outer(x @ fd, fd)``. The projection is invariant to rescaling
        ``fd``, so gradients flowing into ``fd`` carry no radial component.
        """
        coef = self.forward(x) / (self.fd @ self.fd)
        return x - t.outer(coef, self.fd)

In [20]:
# Adversarial training of a "belief" direction:
#   * belief_direction is trained to classify the binarized labels while also
#     *maximizing* the loss of a second probe that sees the activations with the
#     belief direction projected out (L_inhibit);
#   * orthog_direction is that second probe, trained normally on the
#     orthogonalized activations.
# Validation accuracy of both is printed after every epoch.
# NOTE(review): L_inhibit = -BCE is unbounded below; if training is unstable,
# consider pushing the probe toward chance (e.g. BCE against 0.5 targets) instead.
N_EPOCH = 10
LR = 0.01
WEIGHT_DECAY = 0.001
SEED = 0


def evaluate_directions(loader, belief: FeatureDirection, probe: FeatureDirection) -> tuple[Tensor, Tensor]:
    """Accuracy of `belief` on raw activations and of `probe` on
    belief-orthogonalized activations, over every example yielded by `loader`.

    Correct predictions are counted across all batches (instead of averaging
    per-batch means) so a smaller final batch is not over-weighted.

    Returns:
        (accuracy_b, accuracy_o) as 0-dim tensors.

    Raises:
        ValueError: if `loader` yields no examples.
    """
    correct_b = correct_o = 0
    n_seen = 0
    with t.no_grad():
        for batch, labels in loader:
            targets = labels > 0.5
            hits_b = (F.sigmoid(belief(batch)) > 0.5) == targets
            hits_o = (F.sigmoid(probe(belief.orthogonalize(batch))) > 0.5) == targets
            correct_b = correct_b + hits_b.sum()
            correct_o = correct_o + hits_o.sum()
            n_seen += hits_b.numel()
    if n_seen == 0:
        raise ValueError("validation loader yielded no examples")
    return correct_b / n_seen, correct_o / n_seen


t.manual_seed(SEED)  # reproducible init of both directions (and of loader shuffling, if any)
n_epoch = N_EPOCH
belief_direction = FeatureDirection()
orthog_direction = FeatureDirection()
opt_b = t.optim.SGD(belief_direction.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
opt_o = t.optim.SGD(orthog_direction.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
for epoch in range(n_epoch):
    for batch, labels in tqdm(trainer.train_loader, desc=f"epoch {epoch+1}"):
        targets = (labels > 0.5).float()
        # ----------
        # update belief direction
        # ----------
        opt_b.zero_grad()
        logits = belief_direction(batch)
        L_classify = F.binary_cross_entropy_with_logits(logits, targets)
        orthog = belief_direction.orthogonalize(batch)
        logits = orthog_direction(orthog)
        L_inhibit = -F.binary_cross_entropy_with_logits(logits, targets)
        L = L_classify + L_inhibit
        # backward also writes grads into orthog_direction.fd; they are cleared by
        # opt_o.zero_grad() below and never applied.
        L.backward()
        opt_b.step()
        belief_direction.unit()
        # ----------
        # update probe on orthogonalized data
        # ----------
        # NOTE: `orthog` was computed with the belief direction from before opt_b.step().
        opt_o.zero_grad()
        logits = orthog_direction(orthog.detach())
        L = F.binary_cross_entropy_with_logits(logits, targets)
        L.backward()
        opt_o.step()
        orthog_direction.unit()

    # validation accuracy — printed every epoch so the final epoch's result is not lost
    accuracy_b, accuracy_o = evaluate_directions(trainer.val_loader, belief_direction, orthog_direction)
    print(f"epoch {epoch+1}: val accuracy {round(accuracy_b.item(), 3)} (belief) / {round(accuracy_o.item(), 3)} (orthog)")

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

In [7]:
# n_epoch = 10
# belief_direction = FeatureDirection()
# opt = t.optim.SGD(belief_direction.parameters(), lr=0.01, weight_decay=0.001)
# desc = "epoch 1"
# for epoch in range(n_epoch):
#     bar = tqdm(trainer.train_loader)
#     bar.set_description(desc)
#     for batch, labels in bar:
#         logits = belief_direction(batch)
#         loss = F.binary_cross_entropy_with_logits(logits, (labels > 0.5).float())
#         opt.zero_grad()
#         loss.backward()
#         opt.step()
#         belief_direction.unit()

#     # validation loss
#     accuracy = []
#     for batch, labels in trainer.val_loader:
#         with t.no_grad():
#             logits = belief_direction(batch)
#             P = F.sigmoid(logits)
#             accuracy.append(((P>0.5) == (labels>0.5)).float().mean())
#     accuracy = sum(accuracy) / len(accuracy)
#     desc = f"epoch {epoch+2} ({round(accuracy.item(), 3)})"

# # test loss
# accuracy = []
# for batch, labels in trainer.test_loader:
#     with t.no_grad():
#         logits = belief_direction(batch)
#         P = F.sigmoid(logits)
#         accuracy.append(((P>0.5) == (labels>0.5)).float().mean())
# accuracy = sum(accuracy) / len(accuracy)
# print(f"test accuracy: {round(accuracy.item(), 3)}")

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

  0%|          | 0/3376 [00:00<?, ?it/s]

test accuracy: 0.792
