# Use Remote Granite Code Models (20B) with LangChain for Unit Test Code Generation

## Introduction and Setup
 
This recipe demonstrates how to generate unit tests for Python functions and classes using inference calls against a model hosted remotely on [Replicate](https://replicate.com/). This recipe targets developers who are looking to streamline the process of creating unit tests with minimal manual effort. The user inputs Python code and returns unit test code, incorporating "test doubles" for external dependencies.  The notebook depends on Granite [`Utils`](https://github.com/ibm-granite-community/utils) package for integration with LLMs using the [Langchain](https://www.langchain.com/) framework.

### Pre-requisites

To run this notebook, ensure you have the following:

1. Python version: 3.9 or higher
2. A Replicate API token. See the `../recipes/Getting_Started_with_Granite_Code.ipynb` for details.

### Model Details:

1. Model Platform : Replicate
2. Model : IBM Granite 20b Code Instruct 8k
3. Model Version : ibm-granite/granite-20b-code-instruct-8k:409a0c68b74df416c7ae2a3f1552101123356f5a2c6e46d681629b62904c605b

### Program 

1. Input: Python code/snippets with instructions for test packages that need to utilized and optional type of unit test case scenarios to be covered.
2. Output: Python code with unit test packages and libraries, test doubles, assert implementation for Unit testing of given input.

> **Note:**
>
> Results using the 20b code instruct Granite model are generally better than the outputs when using the 8b code instruct Granite model. Whichever model you use, the code generated may require additional modifications to work, depending on the test libraries requested and other aspects of the user input. 

### Install the required Langchain and Replicate packages

Include a granite-community package with some simple utility functions.

In [1]:
!pip install git+https://github.com/ibm-granite-community/utils \
    "langchain_community<0.3.0" \
    replicate

Collecting git+https://github.com/ibm-granite-community/utils
  Cloning https://github.com/ibm-granite-community/utils to /private/var/folders/8t/m9m188_d0tb8szvfqlc20hfr0000gn/T/pip-req-build-idbityg4
  Running command git clone --filter=blob:none --quiet https://github.com/ibm-granite-community/utils /private/var/folders/8t/m9m188_d0tb8szvfqlc20hfr0000gn/T/pip-req-build-idbityg4
  Resolved https://github.com/ibm-granite-community/utils to commit a5965f40db3950dd2a41f3ca62a2c34adcdc20d7
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting langchain_community<0.3.0
  Using cached langchain_community-0.2.19-py3-none-any.whl.metadata (2.7 kB)
Collecting replicate
  Using cached replicate-1.0.4-py3-none-any.whl.metadata (29 kB)
Collecting python-dotenv (from ibm-granite-community-utils==0.1.dev46)
  Using cached python_dotenv-1.0.1-py3-none-any.whl.metadata (23 

In [2]:
from ibm_granite_community.notebook_utils import set_env_var, get_env_var

### Define a System Prompt

We will pass the following system prompt as part of the inference call.

In [3]:
system_prompt = """
Role: Python Code Generator.
User Input: <Python code>, optional Test libraries, output file locations.
Output: Python code for unit testing success and failure conditions of the given input <python code> leveraging the specified test libraries. 
Validity: Generates error-free unit test code for the input <python code> by importing those libraries.
Test Libraries: User provided test libraries.
"""

## Remote Model using Replicate

We will use Granite code models hosted at [Replicate](https://replicate.com) for inference, hosted in the [ibm-granite](https://replicate.com/ibm-granite) organization.

> **TIP:** If you get an "authentication" or similar error below, see the instructions mentioned above at `../recipes/Getting_Started_with_Granite_Code.ipynb`.

Now, we define the model to use and a dictional of parameters to pass to the `Replicate` constructor.

In [4]:
model_id="ibm-granite/granite-3.1-8b-instruct"
 
input_parameters = {      
        "top_k": 60,
        "top_p": 0.3, 
        "max_tokens": 1000,
        "min_tokens": 0,
        "temperature": 0.3, 
        "presence_penalty": 0,
        "frequency_penalty": 0,
        "system_prompt": system_prompt
        }
from langchain_community.llms import Replicate

granite_via_replicate = Replicate(
            model=model_id,
            model_kwargs=input_parameters,
            replicate_api_token=get_env_var('REPLICATE_API_TOKEN'),
        )
        

### Perform Inference

Next, we invoke the model to generate test cases for application code.

The first example requests generation of unit-test code for the input Python code shown in the prompt. We specifically ask the model to use Python's `unittest` library for the test code.

In [5]:
code1="""
Use Python's "unittest" library to generate unit tests for the following code:

import json

def lambda_handler(event, context):
    if 'queryStringParameters' in event:    # If parameters
        print(event['queryStringParameters']['first_name'])
        print(event['queryStringParameters']['last_name'])
        body = 'Hello {} {}!'.format(event['queryStringParameters']['first_name'], 
                                    event['queryStringParameters']['last_name'])  
    else:    # If no parameters
        print('No parameters!')
        body = 'Who are you?'
        
    return {
        'statusCode': 200,
        'body': json.dumps(body)
    }
"""

replicate_response = granite_via_replicate.invoke(code1)

print(f"Granite response from Replicate: {replicate_response}")

Granite response from Replicate: ```python
import unittest
import json
from main import lambda_handler  # Assuming the code is in a file named 'main.py'

class TestLambdaHandler(unittest.TestCase):

    def test_with_parameters(self):
        event = {
            'queryStringParameters': {
                'first_name': 'John',
                'last_name': 'Doe'
            }
        }
        response = lambda_handler(event, None)
        self.assertEqual(response['statusCode'], 200)
        self.assertEqual(response['body'], json.loads('Hello John Doe!'))

    def test_without_parameters(self):
        event = {}
        response = lambda_handler(event, None)
        self.assertEqual(response['statusCode'], 200)
        self.assertEqual(response['body'], json.loads('Who are you?'))

if __name__ == '__main__':
    unittest.main()
```

This Python script uses the `unittest` library to create a test case for the `lambda_handler` function. It includes two test methods: `test_with_paramet

Here are some steps you can use to try running the generated test code:

1. Save the `lambda_handler` code in the prompt to a Python file. Include the import statements. Let's assume you name this file `lambda_handler.py`.
2. Save the generated test code to a file, for example `test_lambda_handler.py`, in the same directory.

You will most likely need to modify the input statement for importing `lambda_handler` that was generated for the test code. For example, if you followed our example naming convention and both files are in the same directory, then the import statement will be:

```python
from lambda_handler import lambda_handler
```

Now you can run the tests using the following shell command in the same directory with the files:

```shell
python -m unittest
```

Do the tests pass? How good are the tests themselves? Can you modify the prompt with suggestions for improving the quality of the tests. For example, what "corner cases" should the tests cover?

#### Second Example: Generate Tests for Multiple Functions

Try running the output tests the same way as before.

In [None]:
code2="""
Use Python's "unittest" library to generate unit tests for the following code:

import numpy as np
import matplotlib.pyplot as plt
import time

def load_data(fname):
    points = np.loadtxt(fname, delimiter=',') 
    y_ = points[:,1]
    # append '1' to account for the intercept
    x_ = np.ones([len(y_),2]) 
    x_[:,0] = points[:,0]
    # display plot
    #plt.plot(x_[:,0], y_, 'ro')
    #plt.xlabel('x-axis')
    #plt.ylabel('y-axis')
    #plt.show()
    print('data loaded. x:{} y:{}'.format(x_.shape, y_.shape))
    return x_, y_

def evaluate_cost(x_,y_,params):
    tempcost = 0
    for i in range(len(y_)):
        tempcost += (y_[i] - ((params[0] * x_[i,0]) + params[1])) ** 2 
    return tempcost / float(10000)   

def evaluate_gradient(x_,y_,params):
    m_gradient = 0
    b_gradient = 0
    N = float(len(y_))
    for i in range(len(y_)):
        m_gradient += -(2/N) * (x_[i,0] * (y_[i] - ((params[0] * x_[i,0]) + params[1])))
        b_gradient += -(2/N) * (y_[i] - ((params[0] * x_[i,0]) + params[1]))
    return [m_gradient,b_gradient]

"""

replicate_response = granite_via_replicate.invoke(code2)

print(f"Granite response from Replicate: {replicate_response}")

#### Third Example: Generate Tests for "Middleware" Code

We'll also explicit ask for calls to other components to be replaced with "mocks".

In [None]:
code3="""
Use the "pytest" library to generate unit tests for the following code. Use mocks and test data for the calls to Kafka:

if __name__ == "__main__":
    # create Spark session
    spark = SparkSession.builder.appName("TwitterSentimentAnalysis").getOrCreate()
    spark.sparkContext.setLogLevel("ERROR") # Ignore INFO DEBUG output
    df = spark \
        .readStream \
        .format("kafka") \
        .option("kafka.bootstrap.servers", "localhost:9092") \
        .option("subscribe", topic_name) \
        .load()

    df = df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")

    df = df.withColumn("data", from_json(df.value, Sentiment.get_schema())).select("data.*")
    df = df \
        .withColumn("ts", to_timestamp(from_unixtime(expr("timestamp_ms/1000")))) \
        .withWatermark("ts", "1 seconds") # old data will be removed

    # Preprocess the data
    df = Sentiment.preprocessing(df)

    # text classification to define polarity and subjectivity
    df = Sentiment.text_classification(df)

    assert type(df) == pyspark.sql.dataframe.DataFrame

    row_df = df.select(
        to_json(struct("id")).alias('key'),
        to_json(struct('text', 'lang', 'ts', 'polarity_v', 'polarity', 'subjectivity_v')).alias("value")
    )
 

    # Writing to Kafka
    query = row_df\
        .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \
        .writeStream\
        .format("kafka") \
        .option("kafka.bootstrap.servers", "localhost:9092") \
        .option("topic", output_topic) \
        .option("checkpointLocation", "file:/Users/user/tmp") \
        .start()
 
    query.awaitTermination()"""
 
replicate_response = granite_via_replicate.invoke(code3)

print(f"Granite response from Replicate: {replicate_response}")

#### Fourth Example: More Use of "Test Doubles"

Here is a scenario with the user input code contains a class definition, which is used to generate test code using test doubles.

In [None]:
code4="""
Use the "pytest" library to generate unit tests for the following class definition. Use test doubles and test data as appropriate to test this class:

from pyspark.sql.types import *

class Sentiment:
    def get_schema():
        schema = StructType([
            StructField("created_at", StringType()),
            StructField("id", StringType()),
            StructField("text", StringType()),
            StructField("source", StringType()),
            StructField("truncated", StringType()),
            StructField("in_reply_to_status_id", StringType()),
            StructField("in_reply_to_user_id", StringType()),
            StructField("in_reply_to_screen_name", StringType()),
            StructField("user", StringType()),
            StructField("coordinates", StringType()),
            StructField("place", StringType()),
            StructField("quoted_status_id", StringType()),
            StructField("is_quote_status", StringType()),
            StructField("quoted_status", StringType()),
            StructField("retweeted_status", StringType()),
            StructField("quote_count", StringType()),
            StructField("reply_count", StringType()),
            StructField("retweet_count", StringType()),
            StructField("favorite_count", StringType()),
            StructField("entities", StringType()),
            StructField("extended_entities", StringType()),
            StructField("favorited", StringType()),
            StructField("retweeted", StringType()),
            StructField("possibly_sensitive", StringType()),
            StructField("filter_level", StringType()),
            StructField("lang", StringType()),
            StructField("matching_rules", StringType()),
            StructField("name", StringType()),
            StructField("timestamp_ms", StringType())
        ])
        return schema

    @staticmethod
    def preprocessing(df):
        # words = df.select(explode(split(df.text, " ")).alias("word"))
        df = df.filter(col('text').isNotNull())
        df = df.withColumn('text', regexp_replace('text', r'http\S+', ''))
        df = df.withColumn('text', regexp_replace('text', r'[^\x00-\x7F]+', ''))
        df = df.withColumn('text', regexp_replace('text', r'[\n\r]', ' '))
        df = df.withColumn('text', regexp_replace('text', '@\w+', ''))
        df = df.withColumn('text', regexp_replace('text', '#', ''))
        df = df.withColumn('text', regexp_replace('text', 'RT', ''))
        df = df.withColumn('text', regexp_replace('text', ':', ''))
        df = df.withColumn('source', regexp_replace('source', '<a href="' , ''))

        return df

    # text classification
    @staticmethod
    def polarity_detection(text):
        return TextBlob(text).sentiment.polarity

    @staticmethod
    def subjectivity_detection(text):
        return TextBlob(text).sentiment.subjectivity

    @staticmethod
    def text_classification(words):
        # polarity detection
        polarity_detection_udf = udf(Sentiment.polarity_detection, FloatType())
        words = words.withColumn("polarity_v", polarity_detection_udf("text"))
        words = words.withColumn(
            'polarity',
            when(col('polarity_v') > 0, lit('Positive'))
            .when(col('polarity_v') == 0, lit('Neutral'))
            .otherwise(lit('Negative'))
        )
        # subjectivity detection
        subjectivity_detection_udf = udf(Sentiment.subjectivity_detection, FloatType())
        words = words.withColumn("subjectivity_v", subjectivity_detection_udf("text"))
        return words

"""
replicate_response = granite_via_replicate.invoke(code4)

print(f"Granite response from Replicate: {replicate_response}")

#### Fifth Example: Test External API Calls

In [None]:
code5="""
Use the "pytest" library to generate unit tests for the following code. Use test doubles and test data as appropriate to test this class:

import facebook

token = 'your token'

graph = facebook.GraphAPI(token)
profile = graph.get_object("me")
friends = graph.get_connections("me", "friends")
friend_list = [friend['name'] for friend in friends['data']]
print friend_list"""

replicate_response = granite_via_replicate.invoke(code5)

print(f"Granite response from Replicate: {replicate_response}")

#### Sixth Example: Same as the Fifth Example, but Omit Requesting a Specific Test Library

In [None]:
code6="""
Generate unit tests for the following code. Use test doubles and test data as appropriate to test this class:

import facebook

token = 'your token'

graph = facebook.GraphAPI(token)
profile = graph.get_object("me")
friends = graph.get_connections("me", "friends")
friend_list = [friend['name'] for friend in friends['data']]
print friend_list"""


replicate_response = granite_via_replicate.invoke(code6)

print(f"Granite response from Replicate: {replicate_response}")