jjcmoon/semiring_torch


semiring_torch

Run your PyTorch code on any semiring with a single line of code! semiring_torch is built on top of autoray.

Warning: this is a proof of concept. Expect bugs and missing features.

Installation

pip install semiring-torch

Example

By using the logarithmic semiring, you can easily write numerically stable code. In the following example, we compute a matrix product in log-space.

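Regular torch:

```python
import torch

x1 = torch.tensor([[0.1, 0.6], [0.1, 0.4]])
x2 = torch.tensor([[0.5, 0.3], [0.2, 0.1]])
# Move the inputs into log-space.
x1 = x1.log()
x2 = x2.log()
# Broadcast to shape (2, 2, 2): result[i, j, k] = log x1[i, j] + log x2[j, k].
result = x1[:, :, None] + x2[None, :, :]
# Sum over j in log-space (log-sum-exp), giving log(x1 @ x2).
result = torch.logsumexp(result, dim=1)
# Convert back out of log-space.
result = result.exp()
```

semiring_torch:

```python
from autoray import numpy as torch
from semiring_torch import logarithmic_semiring

with logarithmic_semiring:
    x1 = torch.tensor([[0.1, 0.6], [0.1, 0.4]])
    x2 = torch.tensor([[0.5, 0.3], [0.2, 0.1]])
    # Inside the context, the matrix product is computed in the logarithmic semiring.
    result = x1 @ x2
```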
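To see why log-space matters, here is a small sketch in plain PyTorch (it does not use semiring_torch). The input values are made up for illustration. With very small probabilities, the ordinary matrix product underflows to zero in float32. The log-space recipe from the "Regular torch" example stays finite when the final `exp()` is skipped.

```python
import math

import torch

# Hypothetical tiny probabilities: each pairwise product (1e-60) is below
# float32's smallest representable positive value, so it underflows.
x1 = torch.full((2, 2), 1e-30)
x2 = torch.full((2, 2), 1e-30)

# Direct product: every entry underflows to 0.
direct = x1 @ x2

# Log-space product: same broadcasting + logsumexp recipe, kept in log-space.
log_x1, log_x2 = x1.log(), x2.log()
log_result = torch.logsumexp(log_x1[:, :, None] + log_x2[None, :, :], dim=1)

print(direct)      # all zeros
print(log_result)  # finite, each entry close to log(2e-60) ≈ -137.46
print(math.log(2) - 60 * math.log(10))
```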

Usage

Semiring_torch only supports a subset of torch operations. In general, methods that only rely on additions and multiplications are supported (e.g. torch.trace or torch.dot), as well as methods that only affect the shape of the tensor (e.g. x.reshape or x.transpose). Functionality that goes beyond a semiring (e.g. torch.sin) will raise a runtime error.
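As an illustration of why operations like `torch.dot` and `torch.trace` fit a semiring, here is a sketch in plain PyTorch. `log_dot` and `log_trace` are hypothetical helpers written for this example and are not part of semiring_torch. Each one swaps "+" for log-sum-exp and "*" for "+", which is the logarithmic semiring. Shape-only operations need no such rewrite, because they commute with the elementwise log.

```python
import torch


# Hypothetical helpers (not semiring_torch API): log-semiring versions of
# torch.dot and torch.trace. Semiring "+" is logsumexp, semiring "*" is "+".
def log_dot(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # sum_i a_i * b_i  ->  logsumexp_i (a_i + b_i)
    return torch.logsumexp(a + b, dim=0)


def log_trace(m: torch.Tensor) -> torch.Tensor:
    # sum_i m_ii  ->  logsumexp_i m_ii
    return torch.logsumexp(torch.diagonal(m), dim=0)


a = torch.tensor([0.1, 0.6])
b = torch.tensor([0.5, 0.3])
m = torch.tensor([[0.1, 0.6], [0.1, 0.4]])

# Moving the inputs into log-space and the results back out
# recovers the ordinary operations.
dot_via_log = log_dot(a.log(), b.log()).exp()      # ~0.23, same as torch.dot(a, b)
trace_via_log = log_trace(m.log()).exp()           # ~0.5, same as torch.trace(m)

# A shape-only operation gives the same result before or after the log.
transpose_commutes = torch.equal(m.log().T, m.T.log())

print(dot_via_log, torch.dot(a, b))
print(trace_via_log, torch.trace(m))
print(transpose_commutes)
```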

Currently only the logarithmic semiring is supported, but more semirings can be added easily.
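To make "adding a semiring" concrete, here is a sketch of what a semiring minimally needs for a matrix product: a sum-like reduction and an elementwise product. The `Semiring` dataclass and `semiring_matmul` function are hypothetical and illustrative only. They are not semiring_torch's actual extension mechanism, which this README does not describe. The max-plus (tropical) semiring is included only as an example of another semiring. The README lists the logarithmic semiring as the only one currently supported.

```python
from dataclasses import dataclass
from typing import Callable

import torch


@dataclass(frozen=True)
class Semiring:
    # Hypothetical container, not semiring_torch's API.
    add: Callable[[torch.Tensor, int], torch.Tensor]           # reduction along a dim
    mul: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]  # elementwise product


def semiring_matmul(sr: Semiring, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # out[i, k] = add_j mul(x[i, j], y[j, k]), using the same broadcasting
    # trick as the "Regular torch" example.
    return sr.add(sr.mul(x[:, :, None], y[None, :, :]), 1)


real = Semiring(add=lambda t, d: t.sum(dim=d), mul=torch.mul)
log_sr = Semiring(add=lambda t, d: torch.logsumexp(t, dim=d), mul=torch.add)
tropical = Semiring(add=lambda t, d: t.max(dim=d).values, mul=torch.add)  # max-plus

x1 = torch.tensor([[0.1, 0.6], [0.1, 0.4]])
x2 = torch.tensor([[0.5, 0.3], [0.2, 0.1]])

plain = semiring_matmul(real, x1, x2)                    # equals x1 @ x2
via_log = semiring_matmul(log_sr, x1.log(), x2.log()).exp()  # equals x1 @ x2
# Max-plus on log-probabilities keeps only the best j (Viterbi-style):
best = semiring_matmul(tropical, x1.log(), x2.log()).exp()   # [[0.12, 0.06], [0.08, 0.04]]
print(plain, via_log, best, sep="\n")
```

Swapping `real` for `log_sr` changes the arithmetic without changing the `semiring_matmul` code. That is the idea behind running the same tensor program on different semirings.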

About

Semiring programming in PyTorch
