TODO: answer [my question](https://stats.stackexchange.com/q/623900/337906) about
training contrastive models.

I also have another question: why are there separate input and output embedding
matrices? It's not clear to me that this is a good idea, i.e., that the increase in
complexity is clearly worth it.

In [1]:
import torch

In [2]:
class SkipGram(torch.nn.Module):
    """
    Skip-gram embedding model scored with a full softmax over the vocabulary.

    Currently shares the input and output embedding matrices. And doesn't downsample or
    negative sample. TODO: evaluate them.

    Args:
        vocab_size: number of rows in the embedding matrix (one per token id).
        embedding_dim: size of each embedding vector.
    """

    def __init__(self, vocab_size: int, embedding_dim: int):
        # Zero-argument super() instead of super(self.__class__, self): the latter
        # resolves to the *runtime* class, so any subclass of SkipGram would call
        # SkipGram.__init__ recursively until the stack overflows.
        super().__init__()
        self.embeddings = torch.nn.Embedding(vocab_size, embedding_dim)

    def forward(
        self, center_input_ids: torch.Tensor, neighbor_input_ids: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Embed tokens (inference) or score center/neighbor pairs (training).

        Args:
            center_input_ids: token ids of the center words. Any shape when
                ``neighbor_input_ids`` is omitted; must be 1-D, ``(batch_size,)``,
                otherwise.
            neighbor_input_ids: optional 2-D tensor, ``(batch_size, context_size)``,
                holding the ids of the words surrounding each center word.

        Returns:
            Without ``neighbor_input_ids``: the embeddings of ``center_input_ids``.
            With ``neighbor_input_ids``: a scalar, the mean log-probability of the
            neighbors given their center word. This is a log-likelihood (higher is
            better), so it must be negated before being minimized as a loss.

        Raises:
            ValueError: if ``neighbor_input_ids`` is given and ``center_input_ids`` is
                not 1-D or ``neighbor_input_ids`` is not 2-D.
        """
        if neighbor_input_ids is None:
            # inference mode
            return self.embeddings(center_input_ids)

        # input size checks for training
        if center_input_ids.ndim != 1:
            raise ValueError("center_input_ids must be 1-D.")
        if neighbor_input_ids.ndim != 2:
            raise ValueError("neighbor_input_ids must be 2-D.")

        # (batch_size, embedding_dim) x (embedding_dim, vocab_size)
        # -> (batch_size, vocab_size)
        # The same matrix is used on both sides because input and output embeddings
        # are shared.
        vocab_scores: torch.Tensor = (
            self.embeddings(center_input_ids) @ self.embeddings.weight.T
        )

        # neighbor_input_ids is (batch_size, context_size), so
        # (batch_size, vocab_size).take_along_dim(neighbor_input_ids)
        # -> (batch_size, context_size)
        neighbor_scores = vocab_scores.take_along_dim(neighbor_input_ids, dim=1)

        # (batch_size, vocab_size).logsumexp(dim=1)
        # -> (batch_size, 1)
        # This is the log of the softmax normalizer for each center word.
        log_normalizer = vocab_scores.logsumexp(dim=1, keepdim=True)

        # score - log-normalizer is the log-softmax probability of each neighbor;
        # average over every (center, neighbor) pair in the batch.
        return torch.mean(neighbor_scores - log_normalizer)

In [3]:
# Toy setup: a 10-token vocabulary embedded in 2 dimensions, optimized with
# SGD using Nesterov momentum.
skip_gram = SkipGram(vocab_size=10, embedding_dim=2)
optimizer = torch.optim.SGD(
    params=skip_gram.parameters(),
    lr=1e-3,
    momentum=0.9,
    nesterov=True,
)

In [4]:
# Switch the module to training mode. The trailing semicolon suppresses the
# notebook's echo of the returned module.
skip_gram.train(mode=True);

In [5]:
loss: torch.Tensor = skip_gram(
    center_input_ids=torch.tensor([1, 2]),
    neighbor_input_ids=torch.tensor([[0, 2, 3], [1, 2, 3]]),
)
loss

tensor(-3.1747, grad_fn=<MeanBackward0>)

In [6]:
# One optimization step. Order matters: clear the gradients first, then
# backpropagate, then update.
# Reset the gradients of the parameters the optimizer manages.
optimizer.zero_grad()
# Backpropagate from the scalar `loss` to populate the parameters' .grad.
loss.backward()
# Apply a single SGD update to the parameters using those gradients.
optimizer.step()