# Usage guide: Loss function generation and error correction using Rubick

This notebook demonstrates the usage of the `Rubick` class to automatically generate a PyTorch-compatible loss function for a simple classification task on the MNIST dataset.

Rather than showcasing a full end-to-end training pipeline, the focus here is on evaluating Rubick's ability to:

- Generate an initial loss function based on a user-defined prompt ("The task is to classify images present in MNIST dataset")
- Identify and respond to a runtime error during unit testing
- Correct the loss function automatically in a second generation loop
- Produce a final, valid loss function that passes the unit test

This example highlights Rubick's core capability: automated, iterative debugging and correction of AI-generated code, which is particularly useful for streamlining model prototyping and experiment setup.
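The generate, test, and correct cycle listed above can be sketched in plain Python. This is not Rubick's implementation: the candidate sources `BUGGY` and `FIXED` and the functions `check`, `run_candidate`, `repair_loop`, and `fake_generate` are all hypothetical. A hard-coded generator stands in for the language model so the sketch runs offline, and the captured traceback plays the role of the error feedback.

```python
import traceback
from typing import Callable, List, Optional, Tuple

import torch

# Two hand-written candidates standing in for model output (hypothetical, for illustration).
BUGGY = """
import torch.nn as nn

class AutoLoss(nn.Module):
    def forward(self, input, target):
        return F.cross_entropy(input, target)  # bug: F is never imported
"""

FIXED = """
import torch.nn as nn
import torch.nn.functional as F

class AutoLoss(nn.Module):
    def forward(self, input, target):
        return F.cross_entropy(input, target)
"""


def check(loss_cls: type) -> None:
    """A small unit test: the loss must return a scalar tensor."""
    logits = torch.randn(3, 5)
    target = torch.randint(0, 5, (3,))
    out = loss_cls()(logits, target)
    assert isinstance(out, torch.Tensor) and out.dim() == 0


def run_candidate(source: str, test: Callable[[type], None]) -> Optional[str]:
    """Execute generated source, test its `AutoLoss`; return traceback text on failure, else None."""
    namespace: dict = {}
    try:
        exec(source, namespace)       # defines AutoLoss inside `namespace`
        test(namespace["AutoLoss"])   # runtime errors and failed asserts both land in `except`
    except Exception:
        return traceback.format_exc()
    return None


def repair_loop(
    generate: Callable[[List[str]], str],
    test: Callable[[type], None],
    max_loops: int = 3,
) -> Tuple[Optional[str], int]:
    """Generate, test, and feed errors back until a candidate passes or the loops run out."""
    feedback: List[str] = []
    for loop in range(max_loops):
        source = generate(feedback)
        error = run_candidate(source, test)
        if error is None:
            return source, loop
        feedback.append(error)  # a real tool would put this error into the next prompt
    return None, max_loops


def fake_generate(feedback: List[str]) -> str:
    # Hypothetical stand-in for an LLM call: buggy first, corrected once it has seen an error.
    return FIXED if feedback else BUGGY


source, loop = repair_loop(fake_generate, check)
print(f"passed on loop {loop}")  # passed on loop 1
```

Note that `exec` runs the generated code in the current process; a real tool would typically isolate untrusted generated code (for example in a subprocess) before running it.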

In [1]:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from rubick import Rubick
```

In [2]:
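```python
model_id = "codellama/CodeLlama-7b-Instruct-hf"  # Hugging Face model that writes the loss function
token = "None"  # Hugging Face access token (placeholder value)
prompt = "The task is to classify images present in MNIST dataset"  # task description

generator = Rubick(model_id, token, prompt)
generator.process_start()  # starts generation; progress and generated code are printed below
```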


Starting loss function generation process
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AutoLoss(nn.Module):
    def __init__(self):
        super(AutoLoss, self).__init__()

    def forward(self, input, target):
        return F.cross_entropy(input, target)
```
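The generated `AutoLoss` is a thin wrapper around `F.cross_entropy`, which applies `log_softmax` internally and therefore expects raw logits of shape `(N, C)` plus integer class indices of shape `(N,)`; for MNIST, `C = 10`. The following sketch plugs the same loss into a single optimizer step. The random tensors are a hypothetical stand-in for a real MNIST batch, and the one-layer model is a placeholder, not part of Rubick's output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim


class AutoLoss(nn.Module):
    """Same body as the generated loss: mean cross-entropy over the batch."""

    def forward(self, input, target):
        return F.cross_entropy(input, target)


torch.manual_seed(0)
# Hypothetical stand-in for an MNIST batch: 8 random 1x28x28 "images", labels 0-9.
images = torch.randn(8, 1, 28, 28)
labels = torch.randint(0, 10, (8,))

# A deliberately tiny classifier: flatten 784 pixels -> 10 logits (no softmax layer).
model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
criterion = AutoLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)

logits = model(images)            # shape (8, 10)
loss = criterion(logits, labels)  # 0-dimensional tensor
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(logits.shape, loss.item())
```

Because the softmax lives inside the loss, adding a softmax layer to the model would apply it twice and distort the gradients.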
```python
import unittest
import torch
import torch.nn as nn
import torch.nn.functional as F

class TestAutoLoss(unittest.TestCase):
    def setUp(self):
        self.model = AutoLoss()

    def test_forward(self):
        input = torch.randn(3, 5)
        target = torch.randint(0, 5, (3,))
        output = self.model(input, target)
        self.assertTrue(isinstance(output, torch.Tensor))
        self.assertTrue(output.shape == (3,))

if __name__ == '__main__':
    unittest.main()
```
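The generated test asserts `output.shape == (3,)`, but with its default `reduction="mean"`, `F.cross_entropy` returns a 0-dimensional tensor, so that assertion fails even though the loss computes without error. This mismatch is a likely reason the first attempt does not pass. The snippet reproduces it with the same tensor shapes as the test; `reduction="none"` is the standard PyTorch option for per-sample losses.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(3, 5)          # same shapes as the generated test
target = torch.randint(0, 5, (3,))

mean_loss = F.cross_entropy(logits, target)                     # default reduction="mean"
per_sample = F.cross_entropy(logits, target, reduction="none")  # one value per sample

print(mean_loss.shape)          # torch.Size([])
print(per_sample.shape)         # torch.Size([3])
print(mean_loss.shape == (3,))  # False: the check the generated test performs
```

Either the test should expect a scalar (for example `output.dim() == 0`) or the loss should use `reduction="none"`; both changes make the loss and its test agree.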
Here is initial code generated for loop:  0
Loss function code: import torch
import torch.nn as nn
import torch.nn.functional as F

class AutoLoss(nn.Module):
    def __init__(self):
        super(AutoLoss, self).__init__()

    def for