feat: updated test case for nan_to_num to test nans and floats properly
k223kim committed Apr 22, 2024
1 parent 236dd78 commit cc0ef39
Showing 4 changed files with 200 additions and 10 deletions.
53 changes: 53 additions & 0 deletions mse_loss_kaeun.py
@@ -0,0 +1,53 @@
import torch
import thunder
import numpy as np

reduction = "mean"
def max(input):
output = torch.finfo(input.dtype).max
return output

cfn = thunder.jit(max)
input =torch.randn(3, 5).type(torch.float64)
output = cfn(input)


def mse(input, target):
    output = torch.nn.functional.mse_loss(input, target, reduction=reduction)
    return output

def mse_thunder(input, target):
    output = thunder.torch.mse_loss(input, target, reduction=reduction)
    return output

input = torch.randn(3, 5, requires_grad=True).type(torch.float64)
target = torch.randn(3, 5).type(torch.float64)

cfn = thunder.jit(mse)
actual_loss = cfn(input, target)


# actual_loss = cfn(input, target)
actual_loss.sum().backward()
# input is a leaf tensor, so backward() stores its gradient in input.grad.
thunder_grad = input.grad
input.grad = None

expected_loss = mse(input, target)
input_grad = torch.ones_like(expected_loss)
# Reference gradient straight from ATen (reduction enum 1 corresponds to "mean").
answer_grad = torch.ops.aten.mse_loss_backward(input_grad, input, target, 1)
expected_loss.sum().backward()
pytorch_grad = input.grad

torch.testing.assert_close(thunder_grad, pytorch_grad)

traces = thunder.last_traces(cfn)

grad_jfn = thunder.core.transforms.grad(cfn)
actual_grad, = grad_jfn(input, target)

expected_loss = torch.nn.functional.mse_loss(input, target, reduction=reduction)
go = torch.ones_like(expected_loss)
expected_grad, = torch.autograd.grad(torch.nn.functional.mse_loss(input, target, reduction=reduction), input, go)

print("Max error in loss:", (actual_loss - expected_loss).abs().max().item())
print("Max error in logits grad:", (actual_grad - expected_grad).abs().max().item())
25 changes: 15 additions & 10 deletions thunder/tests/opinfos.py
@@ -2496,19 +2496,25 @@ def where_sample_generator(op, device, dtype, requires_grad, **kwargs):
def nan_to_num_sample_generator(op, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)

    # A tensor containing zero, NaN, and +/-inf exercises nan_to_num's replacement logic.
    if dtype.is_complex:
        a = torch.tensor((complex(0, 0), complex(float("nan"), float("nan")), complex(float("inf"), -float("inf"))), device=device, dtype=dtype, requires_grad=requires_grad)
    elif dtype.is_floating_point:
        a = torch.tensor((0.0, float("nan"), float("inf"), -float("inf")), device=device, dtype=dtype, requires_grad=requires_grad)
    else:
        # Integer and bool dtypes cannot represent NaN/inf, so fall back to a random tensor.
        a = make((4, 4))
    # shapes, nan, posinf, neginf
    cases = (
        ((2, 1, 2), None, None, None),
        ((4, 4), None, 1.0, None),
        ((4, 4), None, None, 1.0),
        ((5,), None, 1.0, 0.0),
        ((8, 1, 6), 1, None, None),
        ((8, 7, 5, 1), 1, 1.0, None),
        ((8, 7, 5, 1), 1, None, 0.0),
        ((8, 7, 5, 1), 1, 1.0, 0.0),
        (a, None, None, None),
        (a, None, 1.0, None),
        (a, None, None, 1.0),
        (a, None, 1.0, 0.0),
        (a, 1, None, None),
        (a, 1, 1.0, None),
        (a, 1, None, 0.0),
        (a, 1, 1.0, 0.0),
    )

    for a_shape, nan, posinf, neginf in cases:
        yield SampleInput(make(a_shape, dtype=dtype, requires_grad=requires_grad), nan, posinf, neginf)
        yield SampleInput(a, nan, posinf, neginf)


def nan_to_num_error_generator(op, device, dtype=torch.float32, **kwargs):
@@ -2533,7 +2539,6 @@ def nan_to_num_error_generator(op, device, dtype=torch.float32, **kwargs):

nan_to_num_opinfo = OpInfo(
    ltorch.nan_to_num,
    supports_grad=True,
    sample_input_generator=nan_to_num_sample_generator,
    error_input_generator=nan_to_num_error_generator,
    torch_reference=torch.nan_to_num,
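For reference, a minimal standalone sketch (not part of this commit) of the nan_to_num behavior the new samples exercise, using the same special values as the generator above:

import torch

a = torch.tensor((0.0, float("nan"), float("inf"), -float("inf")))
# Defaults: NaN -> 0.0, +inf/-inf -> the largest/most negative finite value of the dtype.
print(torch.nan_to_num(a))
# Explicit replacements matching one of the generated cases (nan=1, posinf=1.0, neginf=0.0).
print(torch.nan_to_num(a, nan=1, posinf=1.0, neginf=0.0))  # tensor([0., 1., 1., 0.])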
27 changes: 27 additions & 0 deletions where_kaeun.py
@@ -0,0 +1,27 @@
import thunder
import torch

# def foo(x):
# x = torch.Tensor.where(x)
# return x

# x = torch.randn(3, device='cuda')
x = torch.tensor([1, 1, 1, 0, 1])
# print(x)
# jit_foo = thunder.jit(foo)
# o = jit_foo(x)

# print(thunder.last_traces(jit_foo)[-1])
# print(f"output: {o}")

x = torch.randn(3, 2)
y = torch.ones(3, 2)

def foo(x):
    # Single-argument torch.where is equivalent to torch.nonzero(x, as_tuple=True).
    return torch.where(x)

def bar(x, y):
    # Three-argument form: elementwise, pick from x where x > 0, otherwise from y.
    return torch.where(x > 0, x, y)

jit_foo = thunder.jit(foo)
o = jit_foo(x)
105 changes: 105 additions & 0 deletions wrap_kaeun.py
@@ -0,0 +1,105 @@

import torch
import thunder.torch as ltorch
import thunder
from functools import partial, wraps
import inspect
'''
@wraps(torch.Tensor.is_cuda) is equivalent to partial(update_wrapper, wrapped=torch.Tensor.is_cuda)
'''
'''
@wraps(torch.Tensor.is_cuda)
def my_func(x: torch.Tensor):
    ...
means
my_func = wraps(torch.Tensor.is_cuda)(my_func), i.e. update_wrapper(my_func, wrapped=torch.Tensor.is_cuda),
so my_func takes on is_cuda's metadata (__name__, __doc__, __wrapped__, and the reported signature).
'''
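# A minimal sketch (not in the original file) of what @wraps does: it copies the
# wrapped callable's metadata (__name__, __doc__, __qualname__, __dict__) onto the
# wrapper and sets __wrapped__, which inspect.signature follows.
def _original(a, b=1):
    """original docstring"""
    return a + b

@wraps(_original)
def _wrapper(*args, **kwargs):
    return _original(*args, **kwargs)

assert _wrapper.__name__ == "_original"
assert _wrapper.__doc__ == "original docstring"
assert inspect.signature(_wrapper) == inspect.signature(_original)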

# def first_decorator(func): #same as @wrap(first_decorator)
# def wrapper(x):
# print(f"wrapper: {x}")
# ret = func(x)
# return ret
# return wrapper

class myClass():
    def __init__(self):
        self.x = ""

    def add(self, x):
        """
        is this the doc?
        """
        print('myClass')
        return str(x)

    def subtract(self, x):
        return x - 10

@wraps(myClass.add)
def fn_with_wrap(x):
    print(f"fn_with_wrap: {x}")
    return x

def fn_no_wrap(x):
    print(f"fn_no_wrap: {x}")
    return x

def tmp_func(x):
    print(f"tmp function: {x}")
    return x


def my_decorator(func):  # func = torch.Tensor.is_cuda
    @wraps(func)
    def wrapper(x: torch.Tensor):
        print('decorator')
        ret = func(x)
        return ret
    return wrapper


# def bigger_func():
# def my_decorator(func):#func = torch.Tensor.is_cuda
# @wraps(func)
# def wrapper(x: torch.Tensor):
# print('decorator')
# ret = func(x)
# return ret
# return wrapper

# my_decorator(torch.Tensor.is_cuda)

@wraps(torch.Tensor.is_cuda)
def my_func(x: torch.Tensor) -> bool:
    print(x.__dict__)
    return torch.abs(x)

@wraps(ltorch.is_cuda)
def my_func2(x: torch.Tensor) -> bool:
    print(x.__dict__)
    return torch.abs(x)

def my_func3(x: torch.Tensor) -> bool:
    print(x.__dict__)
    return torch.abs(x)

@my_decorator
def my_func4(x: torch.Tensor) -> bool:
    print(x.__dict__)
    return torch.abs(x)

jit_tmp = thunder.functional.jit(my_func)
# @my_decorator
# def my_func2(x):
# print('hi')
# return torch.abs(x)

x = torch.tensor([1, 1, 1, 0, 1])
# The target here is assumed; inspect.signature follows __wrapped__, so a function
# decorated with @wraps reports the wrapped callable's signature.
print(inspect.signature(fn_with_wrap))
# my_func2(x)
# print('')

import numpy as np
np.abs
