In [1]:
# Tell wandb which notebook file this is (wandb reads WANDB_NOTEBOOK_NAME, e.g. for code saving).
%env WANDB_NOTEBOOK_NAME=toy_counter.ipynb
# Expose only GPU 1 to this kernel. This must be set before torch initializes CUDA;
# torch is imported in the next cell. Comments stay on separate lines because
# `%env` would treat trailing text as part of the value.
%env CUDA_VISIBLE_DEVICES=1

env: WANDB_NOTEBOOK_NAME=toy_counter.ipynb
env: CUDA_VISIBLE_DEVICES=1


In [2]:
import torch
# 'high' allows float32 matmuls to use faster reduced-precision internal math
# (e.g. TF32) on hardware that supports it.
torch.set_float32_matmul_precision('high')
import wandb

In [3]:
class CountingTransformer(torch.nn.Module):
    """Transformer encoder that maps token-id sequences to one output vector per sequence.

    Tokens are embedded, passed through a stack of ``TransformerEncoderLayer``s
    (``batch_first=True``), max-pooled over the sequence dimension and projected
    by a linear head.

    Args:
        n_tokens: vocabulary size of the embedding table.
        d_embedding: model width; must be divisible by ``nhead``.
        n_outputs: width of the output head (e.g. number of classes).
        embedding_weights: optional initial embedding matrix of shape
            ``(n_tokens, d_embedding)``, forwarded to ``torch.nn.Embedding`` as ``_weight``.
            ``None`` keeps the default embedding initialization.
        nhead: attention heads per encoder layer (default 4).
        dim_feedforward: hidden width of each layer's feed-forward block (default 16).
        num_layers: number of stacked encoder layers (default 2).
    """

    def __init__(self, n_tokens, d_embedding, n_outputs, embedding_weights=None,
                 nhead=4, dim_feedforward=16, num_layers=2):
        super().__init__()
        self.embedding = torch.nn.Embedding(n_tokens, d_embedding, _weight=embedding_weights)
        encoder_layer = torch.nn.TransformerEncoderLayer(
            d_embedding, nhead=nhead, dim_feedforward=dim_feedforward, batch_first=True)
        # TransformerEncoder clones `encoder_layer` num_layers times; the nested-tensor
        # fast path is explicitly disabled.
        self.encoder = torch.nn.TransformerEncoder(
            encoder_layer, num_layers=num_layers, enable_nested_tensor=False)
        self.head = torch.nn.Linear(d_embedding, n_outputs)

    def forward(self, x):
        """Map token ids ``x`` of shape (batch, seq_len) to outputs of shape (batch, n_outputs).

        No padding mask is applied: every position takes part in attention and pooling.
        """
        x = self.embedding(x)
        x = self.encoder(x)
        # Order-invariant pooling: element-wise max over the sequence dimension (dim 1).
        x = self.head(x.max(dim=1).values)
        return x

In [4]:
def batched_bincount(x, dim, max_value, dtype=None):
    """Count occurrences of each value independently along one dimension of ``x``.

    This is a vectorized ``torch.bincount`` applied to every slice of ``x`` along
    ``dim``. The result has the shape of ``x`` except that dimension ``dim`` is
    replaced by ``max_value`` bins. Bin ``v`` holds how many entries of that slice
    equal ``v``.

    Args:
        x: int64 index tensor; every value must lie in ``[0, max_value)``.
        dim: dimension to count along (negative indices allowed).
        max_value: number of bins (exclusive upper bound on the values in ``x``).
        dtype: dtype of the result; defaults to ``x.dtype``. Pass e.g. ``torch.bool``
            to get a presence mask instead of counts.

    Returns:
        A tensor on ``x.device`` holding the per-slice counts.
    """
    if dtype is None:
        dtype = x.dtype
    # Output matches x except along `dim`, which becomes the bin axis. This works
    # for any rank and any `dim`, not just dim=1 of a 2-D input.
    out_shape = list(x.shape)
    out_shape[dim] = max_value
    counts = torch.zeros(out_shape, dtype=dtype, device=x.device)
    counts.scatter_add_(dim, x, torch.ones_like(x, dtype=dtype))
    return counts

In [5]:
from tqdm.notebook import trange

In [6]:
# Reference signatures for the embedding-initialization experiments in the model setup cell:
#   torch.nn.init.orthogonal_(tensor, gain=1, generator=None)
#   torch.nn.init.sparse_(tensor, sparsity, std=0.01, generator=None)

In [7]:
# Toy dataset: random token sequences, with two targets per sequence:
# per-token occurrence counts and the number of distinct tokens.
torch.manual_seed(0)  # reproducible data (and downstream model init) on Restart & Run All
# Fall back to CPU when no GPU is visible (e.g. CUDA_VISIBLE_DEVICES names a missing device).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
n_tokens = 16
n_samples, seq_len = 1024, 16
data = torch.randint(0, n_tokens, (n_samples, seq_len), device=device)
# Shape (n_samples, n_tokens): occurrences of each token id in each sequence.
target_counts = batched_bincount(data, 1, n_tokens).float()
# Distinct tokens per sequence = number of non-empty bins (int64, like a bool-mask sum).
target_uniques = (target_counts > 0).sum(dim=1)

In [8]:
# CrossEntropyLoss wants class indices starting at 0, so shift the unique counts
# by their observed minimum; `max_targets` is then the largest class index.
min_target = torch.min(target_uniques)
target_uniques_CE = torch.sub(target_uniques, min_target)
max_targets = torch.max(target_uniques_CE).item()
print(max_targets)

7


In [9]:
d_embedding = n_tokens
# Embedding initialization experiments; None keeps nn.Embedding's default init.
weight = None
# weight = torch.eye(n_tokens)
# weight = torch.empty(n_tokens, d_embedding, dtype=torch.float32)
# torch.nn.init.orthogonal_(weight, gain=1)
# torch.nn.init.sparse_(weight, sparsity=0.5)
# Classification setup: one logit per possible shifted unique-count value.
model = CountingTransformer(n_tokens, d_embedding, n_outputs=max_targets + 1, embedding_weights=weight).to(device)
# Regression alternative (pair with the MSELoss line below):
# model = CountingTransformer(n_tokens, d_embedding, n_outputs=1, embedding_weights=weight).to(device)
optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)
# NOTE(review): an earlier run printed this UserWarning during backward, presumably
# with the ipex path below enabled (TODO confirm):
#   "aten::layer_norm: an autograd kernel was not registered to the Autograd key(s) but we
#    are trying to backprop through it. This may lead to silently incorrect behavior."
# import intel_extension_for_pytorch as ipex
# model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.float32)
# model = torch.compile(model, backend="ipex", fullgraph=True)

# The compiled wrapper shares parameters with the original module, so the optimizer
# created above still updates the weights being trained.
model = torch.compile(model, fullgraph=True)
loss_fn = torch.nn.CrossEntropyLoss(reduction="mean")
# loss_fn = torch.nn.MSELoss(reduction="mean")
# Derive the logged names from the loss actually in use, so that switching the line
# above cannot leave the wandb config / metric name out of sync. This must happen
# before torch.compile wraps the module and changes its type.
training_loss = type(loss_fn).__name__
loss_name = f"loss_{training_loss}"
loss_fn = torch.compile(loss_fn, fullgraph=True)

@torch.compile(fullgraph=True)
def get_mse_losses(scores, target):
    """Evaluate class logits as a regression of the class index onto ``target``.

    Class ``k`` is interpreted as the numeric value ``k``.

    Args:
        scores: logits, assumed shape (batch, n_classes).
        target: per-sample target values of shape (batch,); integer class indices
            are accepted and cast to ``scores.dtype``.

    Returns:
        ``(mse_expected, mse_likelihood)``:
        - ``mse_expected``: MSE between the softmax-weighted mean class index and ``target``.
        - ``mse_likelihood``: MSE between the argmax class index and ``target``.
    """
    # Device and dtype come from `scores` itself, never from a notebook global.
    class_values = torch.arange(scores.size(1), device=scores.device, dtype=scores.dtype)
    # mse_loss expects matching floating dtypes; the targets used here are int64.
    target = target.to(scores.dtype)
    expected = torch.softmax(scores, dim=1).matmul(class_values)
    max_likelihood = scores.argmax(dim=1).to(scores.dtype)
    mse_expected = torch.nn.functional.mse_loss(expected, target)
    mse_likelihood = torch.nn.functional.mse_loss(max_likelihood, target)
    return mse_expected, mse_likelihood

run_config = {"training_loss": training_loss}
with wandb.init(project="toy_counter", config=run_config) as run:
    for i in trange(100_000):
        # One full-batch optimization step over the whole toy dataset.
        optimizer.zero_grad(set_to_none=True)
        output = model(data)
        loss = loss_fn(output, target_uniques_CE)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        # Regression-style diagnostics computed from the same pre-step logits as `loss`.
        # They are staged with commit=False so they share a wandb step with the loss
        # logged at the end of the iteration.
        with torch.no_grad():
            mse, mse_likelihood = get_mse_losses(output, target_uniques_CE)
            aux_metrics = {
                "loss_MSELoss": mse,
                "loss_MSELoss_likelihood": mse_likelihood,
            }
            run.log(aux_metrics, commit=False)
        optimizer.step()
        step_metrics = {loss_name: loss.item()}
        run.log(step_metrics)

[34m[1mwandb[0m: Currently logged in as: [33mkazeev[0m ([33msymmetry-advantage[0m). Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/100000 [00:00<?, ?it/s]