forked from Lightning-AI/lightning-thunder
Commit
feat: updated test case for nan_to_num to test nans and floats properly
Showing 4 changed files with 200 additions and 10 deletions.
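The nan_to_num test referenced in the commit title is not among the files shown below. Purely as an illustration of what such a check tends to look like, and not the actual change, here is a minimal sketch, assuming thunder's jit handles torch.nan_to_num; the function name and tensor values are invented for this example:

import torch
import thunder

def fn(x):
    # Replace NaN and infinities with explicit finite floats.
    return torch.nan_to_num(x, nan=0.0, posinf=1.0, neginf=-1.0)

jfn = thunder.jit(fn)
x = torch.tensor([float("nan"), float("inf"), float("-inf"), 0.5])
# The compiled result should match eager PyTorch exactly.
torch.testing.assert_close(jfn(x), fn(x))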
@@ -0,0 +1,53 @@
import torch
import thunder
import numpy as np

reduction = "mean"

def max(input):
    output = torch.finfo(input.dtype).max
    return output

cfn = thunder.jit(max)
input = torch.randn(3, 5, dtype=torch.float64)
output = cfn(input)

def mse(input, target):
    output = torch.nn.functional.mse_loss(input, target, reduction=reduction)
    return output

def mse_thunder(input, target):
    output = thunder.torch.mse_loss(input, target, reduction=reduction)
    return output

# Use dtype= at construction so `input` stays a leaf tensor and receives .grad.
input = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
target = torch.randn(3, 5, dtype=torch.float64)

cfn = thunder.jit(mse)
actual_loss = cfn(input, target)

# Backward through the thunder-compiled loss and stash the input gradient.
actual_loss.sum().backward()
thunder_grad = input.grad
input.grad = None

# Reference: eager PyTorch loss, the aten backward, and autograd.
expected_loss = mse(input, target)
input_grad = torch.ones_like(expected_loss)
answer_grad = torch.ops.aten.mse_loss_backward(input_grad, input, target, 1)  # 1 == mean reduction
expected_loss.sum().backward()
pytorch_grad = input.grad

torch.testing.assert_close(thunder_grad, pytorch_grad)

traces = thunder.last_traces(cfn)

# Compare against thunder's grad transform as well.
grad_jfn = thunder.core.transforms.grad(cfn)
actual_grad, = grad_jfn(input, target)

expected_loss = torch.nn.functional.mse_loss(input, target, reduction=reduction)
go = torch.ones_like(expected_loss)
expected_grad, = torch.autograd.grad(torch.nn.functional.mse_loss(input, target, reduction=reduction), input, go)

print("Max error in loss:", (actual_loss - expected_loss).abs().max().item())
print("Max error in input grad:", (actual_grad - expected_grad).abs().max().item())
@@ -0,0 +1,27 @@
import thunder
import torch

# def foo(x):
#     x = torch.Tensor.where(x)
#     return x

# x = torch.randn(3, device='cuda')
x = torch.tensor([1, 1, 1, 0, 1])
# print(x)
# jit_foo = thunder.jit(foo)
# o = jit_foo(x)

# print(thunder.last_traces(jit_foo)[-1])
# print(f"output: {o}")

x = torch.randn(3, 2)
y = torch.ones(3, 2)

# Single-argument torch.where(x) returns the indices of the nonzero elements,
# like torch.nonzero(x, as_tuple=True).
def foo(x):
    return torch.where(x)

# Three-argument form: take from x where the condition holds, else from y.
def bar(x, y):
    return torch.where(x > 0, x, y)

jit_foo = thunder.jit(foo)
o = jit_foo(x)
@@ -0,0 +1,105 @@

import torch
import thunder.torch as ltorch
import thunder
from functools import partial, wraps
import inspect

'''
@wraps(torch.Tensor.is_cuda) is equivalent to partial(update_wrapper, wrapped=torch.Tensor.is_cuda)
'''
'''
@wraps(torch.Tensor.is_cuda)
def my_func(x: torch.Tensor): ...

means

my_func = wraps(torch.Tensor.is_cuda)(my_func)
i.e. update_wrapper(my_func, wrapped=torch.Tensor.is_cuda)
'''
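# A quick, illustrative check of that equivalence (the names `demo`, `wrapped`,
# and `manual_copy` below are made up for this sketch, not part of the repository):
from functools import update_wrapper

def demo(x):
    """demo's docstring"""
    return x

@wraps(demo)
def wrapped(x):
    return demo(x)

def manual_copy(x):
    return demo(x)
manual_copy = partial(update_wrapper, wrapped=demo)(manual_copy)

# Both routes copy __name__, __doc__, etc. from demo onto the wrapper.
assert wrapped.__name__ == manual_copy.__name__ == "demo"
assert wrapped.__doc__ == demo.__doc__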

# def first_decorator(func):  # same as applying @first_decorator
#     def wrapper(x):
#         print(f"wrapper: {x}")
#         ret = func(x)
#         return ret
#     return wrapper

class myClass:
    def __init__(self):
        self.x = ""

    def add(self, x):
        """
        is this the doc?
        """
        print('myClass')
        return str(x)

    def subtract(self, x):
        return x - 10

@wraps(myClass.add)
def fn_with_wrap(x):
    print(f"fn_with_wrap: {x}")
    return x

def fn_no_wrap(x):
    print(f"fn_no_wrap: {x}")
    return x

def tmp_func(x):
    print(f"tmp function: {x}")
    return x


def my_decorator(func):  # func = torch.Tensor.is_cuda
    @wraps(func)
    def wrapper(x: torch.Tensor):
        print('decorator')
        ret = func(x)
        return ret
    return wrapper


# def bigger_func():
#     def my_decorator(func):  # func = torch.Tensor.is_cuda
#         @wraps(func)
#         def wrapper(x: torch.Tensor):
#             print('decorator')
#             ret = func(x)
#             return ret
#         return wrapper

#     my_decorator(torch.Tensor.is_cuda)

@wraps(torch.Tensor.is_cuda)
def my_func(x: torch.Tensor) -> bool:
    print(x.__dict__)
    return torch.abs(x)

@wraps(ltorch.is_cuda)
def my_func2(x: torch.Tensor) -> bool:
    print(x.__dict__)
    return torch.abs(x)

def my_func3(x: torch.Tensor) -> bool:
    print(x.__dict__)
    return torch.abs(x)

@my_decorator
def my_func4(x: torch.Tensor) -> bool:
    print(x.__dict__)
    return torch.abs(x)

jit_tmp = thunder.functional.jit(my_func)
# @my_decorator
# def my_func2(x):
#     print('hi')
#     return torch.abs(x)

x = torch.tensor([1, 1, 1, 0, 1])
# inspect.signature follows __wrapped__ (set by functools.wraps), so this
# reports the signature of the function that my_func4 wraps.
print(inspect.signature(my_func4))
# my_func2(x)
# print('')

import numpy as np
np.abs