In [3]:
import unittest
import torch
import torch.nn as nn

class MSELossWithMask(nn.Module):
    """Mean squared error averaged over observed (non-zero) target entries only.

    Entries where ``targets == 0`` are treated as unobserved: they contribute
    neither to the summed squared error nor to the denominator. When no entry
    is observed, the denominator is clamped to 1 so the loss is 0 instead of
    NaN.
    """

    def __init__(self):
        super().__init__()

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the masked mean squared error.

        Args:
            inputs: Predictions; must be broadcast-compatible with ``targets``.
            targets: Ground-truth values; zeros mark unobserved entries.

        Returns:
            A 0-dim tensor: the sum of squared errors over observed entries
            divided by the number of observed entries (at least 1).
        """
        observed = targets != 0
        # Zero the residual *before* squaring (instead of multiplying the
        # squared error by a 0/1 mask) so inf/NaN predictions at unobserved
        # positions cannot become NaN via 0 * inf in the value or the gradient.
        residual = (targets - inputs).masked_fill(~observed, 0)
        # Accumulate in at least float32 so half-precision inputs do not
        # overflow while summing; wider dtypes (e.g. float64) are kept.
        acc_dtype = torch.promote_types(residual.dtype, torch.float32)
        total = residual.pow(2).sum(dtype=acc_dtype)
        # Integer count stays exact regardless of size (a float32 sum of ones
        # loses exactness past 2**24 observed entries).
        num_observed = observed.sum().clamp(min=1)
        return total / num_observed

class TestMSELossWithMask(unittest.TestCase):
    """Unit tests for ``MSELossWithMask``."""

    def setUp(self):
        self.loss_fn = MSELossWithMask()

    def test_zero_loss(self):
        """Identical inputs and targets give zero loss."""
        inputs = torch.tensor([[1.0, 2.0, 3.0]])
        targets = torch.tensor([[1.0, 2.0, 3.0]])
        loss = self.loss_fn(inputs, targets)
        self.assertAlmostEqual(loss.item(), 0.0, places=6)

    def test_nonzero_loss(self):
        """Every observed entry is off by exactly 1, so the loss is exactly 1."""
        inputs = torch.tensor([[1.0, 2.0, 3.0]])
        targets = torch.tensor([[2.0, 3.0, 4.0]])
        loss = self.loss_fn(inputs, targets)
        self.assertAlmostEqual(loss.item(), 1.0, places=6)

    def test_mask_functionality(self):
        """Zero targets are excluded from both the error sum and the count."""
        inputs = torch.tensor([[1.0, 2.0, 5.0]])
        targets = torch.tensor([[0.0, 2.0, 3.0]])
        loss = self.loss_fn(inputs, targets)
        # Two observed entries with squared errors 0 and 4; the masked entry
        # (error 1) must not count, and the divisor must be 2, not 3.
        expected_loss = ((2.0 - 2.0)**2 + (3.0 - 5.0)**2) / 2
        self.assertAlmostEqual(loss.item(), expected_loss, places=6)

    def test_all_zeros_target(self):
        """With no observed entries the loss is 0 rather than NaN."""
        inputs = torch.tensor([[1.0, 2.0, 3.0]])
        targets = torch.tensor([[0.0, 0.0, 0.0]])
        loss = self.loss_fn(inputs, targets)
        self.assertAlmostEqual(loss.item(), 0.0, places=6)

    def test_multi_dimensional_input(self):
        """Masking and averaging work on tensors with more than two dimensions."""
        inputs = torch.tensor([[[1.0, 2.0], [3.0, 6.0]]])
        targets = torch.tensor([[[1.0, 0.0], [0.0, 4.0]]])
        loss = self.loss_fn(inputs, targets)
        expected_loss = ((1.0 - 1.0)**2 + (4.0 - 6.0)**2) / 2
        self.assertAlmostEqual(loss.item(), expected_loss, places=6)

    def test_large_input(self):
        """With no zero targets the loss equals the plain (unmasked) MSE."""
        # Seeded so the test is reproducible.
        generator = torch.Generator().manual_seed(0)
        inputs = torch.rand(100, 100, generator=generator)
        # Shift targets into [1, 2) so none is zero and nothing gets masked.
        targets = torch.rand(100, 100, generator=generator) + 1.0
        loss = self.loss_fn(inputs, targets)
        expected_loss = nn.functional.mse_loss(inputs, targets)
        self.assertAlmostEqual(loss.item(), expected_loss.item(), places=5)

# Function to run all tests
def run_tests():
    """Collect every test in ``TestMSELossWithMask`` and run it with verbose output."""
    loader = unittest.TestLoader()
    verbose_runner = unittest.TextTestRunner(verbosity=2)
    verbose_runner.run(loader.loadTestsFromTestCase(TestMSELossWithMask))

# Run this cell to execute all TestMSELossWithMask tests with verbose
# (per-test) output.
run_tests()

test_all_zeros_target (__main__.TestMSELossWithMask.test_all_zeros_target) ... ok
test_large_input (__main__.TestMSELossWithMask.test_large_input) ... ok
test_mask_functionality (__main__.TestMSELossWithMask.test_mask_functionality) ... ok
test_multi_dimensional_input (__main__.TestMSELossWithMask.test_multi_dimensional_input) ... ok
test_nonzero_loss (__main__.TestMSELossWithMask.test_nonzero_loss) ... ok
test_zero_loss (__main__.TestMSELossWithMask.test_zero_loss) ... ok

----------------------------------------------------------------------
Ran 6 tests in 0.016s

OK
