# Usage Guide: Automating PyTorch Loss Functions with Rubick on the MNIST Dataset

### Importing necessary libraries

Here we import PyTorch, torchvision, and the `Rubick` class used to generate the loss function.

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from rubick_v6 import Rubick

### Dataset Preparation

In [10]:
# Data configuration: a single source of truth for values used by both splits.
DATA_ROOT = './data'
BATCH_SIZE = 64

# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_dataset = torchvision.datasets.MNIST(root=DATA_ROOT, train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root=DATA_ROOT, train=False, download=True, transform=transform)

# Shuffle only the training split; evaluation order does not affect accuracy.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Model Definition

In [11]:
class SimpleNN(nn.Module):
    """Fully connected classifier; defaults give 784 -> 128 -> 64 -> 10 (MNIST).

    Args:
        in_features: Number of values per sample after flattening (28 * 28 for MNIST).
        hidden1: Width of the first hidden layer.
        hidden2: Width of the second hidden layer.
        num_classes: Number of output logits.

    ``forward`` returns raw logits; no softmax is applied.
    """

    def __init__(self, in_features: int = 28 * 28, hidden1: int = 128,
                 hidden2: int = 64, num_classes: int = 10):
        super().__init__()
        self.in_features = in_features
        self.fc1 = nn.Linear(in_features, hidden1)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden1, hidden2)
        self.fc3 = nn.Linear(hidden2, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Flatten each sample to ``in_features`` values and return class logits.

        ``reshape`` is used instead of ``view`` so non-contiguous inputs
        (e.g. a transposed tensor) are accepted; it copies only when needed.
        """
        x = x.reshape(-1, self.in_features)  # flatten the image
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

# Seed PyTorch's default RNG before building the model so weight initialization
# is reproducible across kernel restarts (GPU kernels may still be nondeterministic).
SEED = 0
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleNN().to(device)

### Generating loss function

We have defined the neural network architecture and prepared the data in the code cells above. Now we need to define the loss function that the model will be optimized against during training.

We choose the `CodeLlama-7b-Instruct-hf` model because it performs well at both code generation and instruction following.

As the output below shows, the model fails to generate a valid loss function in the first loop — the generated loss function fails its unit test on all three attempts.

In the second loop, the model generates a valid loss function on its first attempt.

In [6]:
# Configure Rubick with a code LLM and a task description, then run its
# loss-function generation process (progress is printed in the cell output).
import os

model_id = "codellama/CodeLlama-7b-Instruct-hf"
# Read the Hugging Face access token from the environment rather than hard-coding it
# in the notebook; fall back to the original "NONE" placeholder when HF_TOKEN is unset.
token = os.environ.get("HF_TOKEN", "NONE")
prompt = "The task is to classify images present in MNIST dataset"

generator = Rubick(model_id, token, prompt)
generator.process_start()

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Starting loss function generation process
Here is initial code generated for loop:  0

Loss function code:
  import torch
import torch.nn as nn
import torch.nn.functional as F

class AutoLoss(nn.Module):
    def __init__(self):
        super(AutoLoss, self).__init__()

    def forward(self, y_pred, y_true):
        # Compute the loss
        loss = F.cross_entropy(y_pred, y_true)

        return loss

Test function code:
 from temp_code import AutoLoss

import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F

class AutoLossTest(unittest.TestCase):
    def test_auto_loss_forward(self):
        # Define the input and output tensors
        y_pred = torch.randn(10, 10)
        y_true = torch.randint(0, 10, (10,))

        # Instantiate the loss function
        loss_fn = AutoLoss()

        # Forward pass
        loss = loss_fn(y_pred, y_true)

        # Check if the output is a tensor
        self.assertTrue(torch.is_tensor(loss))

        # Check if the outpu

### Defining the loss function

Here we instantiate the generated loss function `AutoLoss` and assign it to the variable `criterion`, which is used for the rest of the training phase.

In [7]:
# The LLM-generated loss module; the training loop calls it as criterion(outputs, labels).
criterion = generator.AutoLoss()

# Adam over every trainable parameter of the model.
LEARNING_RATE = 1e-3
optimizer = optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

### Training Loop

In [14]:
NUM_EPOCHS = 5  # few epochs for quick test; also used in the progress message

for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()  # drop gradients accumulated by the previous step
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Average of the per-batch loss values for this epoch (train_loader must be non-empty).
    print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Loss: {running_loss / len(train_loader):.4f}")

Epoch [1/5], Loss: 0.3972
Epoch [2/5], Loss: 0.1888
Epoch [3/5], Loss: 0.1356
Epoch [4/5], Loss: 0.1087
Epoch [5/5], Loss: 0.0909


### Model Evaluation

In [15]:
# Top-1 accuracy on the held-out test split.
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        logits = model(images)
        # argmax over the class dimension; avoids the legacy `.data` attribute
        predicted = logits.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

if total == 0:
    raise ValueError("test_loader yielded no samples; cannot compute accuracy")

test_accuracy = 100 * correct / total
print(f"Test Accuracy: {test_accuracy:.2f}%")

Test Accuracy: 96.86%
