# Use Remote Granite Code Models (20B) with LangChain for Unit Test Code Generation

## Introduction and Setup
 
This notebook demonstrates how to generate unit tests for Python functions and classes using inference calls against a model hosted remotely on [Replicate](https://replicate.com/). It targets developers who want to streamline unit test creation with minimal manual effort. The system takes Python code as input and returns unit test code, incorporating "test doubles" for external dependencies. The notebook depends on the Granite [`utils`](https://github.com/ibm-granite-community/utils) package for integration with LLMs through the LangChain framework.
 
 
#### Pre-requisites

To run this notebook, ensure you have the following:

1. Python version: 3.9 or higher
2. langchain_community<0.3.0
3. langchain_ollama<0.2.0
4. replicate
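
As a quick sanity check before running the cells, the prerequisites above can be verified from Python. The following is a minimal sketch using only the standard library; the helper names `python_ok` and `installed_version` are illustrative and not part of any package used in this notebook, and the distribution names in the loop are assumed spellings.

```python
import sys
from importlib.metadata import PackageNotFoundError, version


def python_ok(minimum=(3, 9)):
    """Return True if the running interpreter meets the minimum version."""
    return sys.version_info[:2] >= minimum


def installed_version(package_name):
    """Return the installed version string of a package, or None if absent."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("Python >= 3.9:", python_ok())
    # Distribution names are assumptions about how the packages are published.
    for name in ("langchain-community", "langchain-ollama", "replicate"):
        print(name, "->", installed_version(name))
```

A `None` result means the package still needs to be installed with the cells below.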

#### Model Details:

1. Model platform: Replicate
2. Model: IBM Granite 20B Code Instruct 8K
3. Model version: `ibm-granite/granite-20b-code-instruct-8k:409a0c68b74df416c7ae2a3f1552101123356f5a2c6e46d681629b62904c605b`

#### Program 
1. Input: Python code or snippets, with instructions naming the test packages to be used and, optionally, the types of unit test scenarios to cover
2. Output: Python unit test code for the given input, including the test packages and libraries, test doubles, and assertions

#### Disclaimer

Results from the Granite 20B Code Instruct model with an 8k context appear more convincing than those from the Granite 8B Code Instruct model with a 128k context. The generated code may need additional modification depending on the libraries and the user input.
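
Because generated code may be incomplete or need changes, it can help to check whether the code portion of a response parses as Python before saving or running it. The sketch below uses the standard-library `ast` module; `is_valid_python` is a hypothetical helper name, and the check covers syntax only, not whether the tests themselves are correct.

```python
import ast


def is_valid_python(source):
    """Return True if `source` parses as Python code, False otherwise."""
    try:
        ast.parse(source)
    except (SyntaxError, ValueError):
        # ValueError covers sources that contain null bytes on some versions.
        return False
    return True


# Illustrative inputs: a complete test and one that stops mid-expression.
complete = "def test_add():\n    assert 1 + 1 == 2\n"
truncated = "def test_add():\n    assert (1 + 1 =="

print(is_valid_python(complete))
print(is_valid_python(truncated))
```

A response that fails this check is a candidate for regeneration or manual repair.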


### Install the required LangChain and Replicate packages

In [None]:
!pip install "langchain_community<0.3.0" replicate

### Install Granite `utils` package

This package is a thin shim that provides helper functions required by the notebooks.

To see the implementation of its functions, see the [utils repo](https://github.com/ibm-granite-community/utils/tree/main).

In [1]:
!python -m pip install git+https://github.com/ibm-granite-community/utils


In [2]:
from ibm_granite_community.langchain_utils import find_langchain_model 

### Define a Prompt

The cells below use a remote option for model inference.

Each inference request is a blocking call that uses the following system prompt:

In [3]:
prompt =   """Role: Python Code Generator.
           User Input: <Python code>, optional Test libraries, output file locations ..
           Output: Python code for unit testing success and failure conditions of the given input <python code> leveraging the test libraries 
           Validity: Generates error-free unit test code for the input <python code> by importing those libraries
           Test Libraries: User provided test libraries."""

## Remote Model using Replicate

### Establish Replicate Account

To use this remote option, create an account at [Replicate](https://replicate.com).

### Provide your API token


Obtain your REPLICATE_API_TOKEN at replicate.com/account/api-tokens

There are three ways to provide this value to the cells below. In order of precedence:

1. As an environment variable
2. As a Google Colab secret
3. Supplied by the user using getpass()

In this notebook, the token is set as the environment variable `REPLICATE_API_TOKEN='xxxxx'` in a `.env` file in the current directory.
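
The order of precedence above could be implemented roughly as follows. This is a simplified sketch, not the code used by the `utils` package: it covers only the environment variable and the `getpass()` prompt, omits the Google Colab secret, and `get_replicate_token` is a hypothetical helper name.

```python
import os
from getpass import getpass


def get_replicate_token(prompt_fn=getpass):
    """Return the Replicate API token from the environment, else prompt for it."""
    token = os.environ.get("REPLICATE_API_TOKEN")
    if not token:
        # Fall back to an interactive prompt that does not echo the input.
        token = prompt_fn("Enter your REPLICATE_API_TOKEN: ")
        # Store it so later code in the same process can find it.
        os.environ["REPLICATE_API_TOKEN"] = token
    return token
```

The `prompt_fn` parameter exists only so the prompt can be replaced in tests; in a notebook, calling `get_replicate_token()` with no arguments is enough.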

### Choose a Model

Two Granite Code models are available in the [`ibm-granite`](https://replicate.com/ibm-granite) org at Replicate.

The `find_langchain_model` function below imports the `replicate` package.

Model arguments are defined in the `input_parameters` dictionary.


In [4]:
# Alternative model with a longer context window:
# model_id = "ibm-granite/granite-8b-code-instruct-128k"

model_id = "ibm-granite/granite-20b-code-instruct-8k"

input_parameters = {
    "top_k": 60,
    "top_p": 0.3,
    "max_tokens": 1000,
    "min_tokens": 0,
    "temperature": 0.3,
    "presence_penalty": 0,
    "frequency_penalty": 0,
    "system_prompt": prompt,
}
granite_via_replicate = find_langchain_model(
    platform="replicate",
    model_id=model_id,
    model_kwargs=input_parameters,
)

### Perform Inference

#### The use case below generates unit test code for the input code using the pytest library

In [5]:

codes="""Use pytest test library to generate unit test cases for the below

import sqlite3

class Database:
    def __init__(self, db_name):
        #Initialize with the database name.
        self.db_name = db_name
        self.conn = None
        self.cursor = None

    def connect(self):
        #Establish a connection to the SQLite database.
        self.conn = sqlite3.connect(self.db_name)
        self.cursor = self.conn.cursor()
        print("Connected to the database.")

    def fetch_data(self, query):
        #Fetch data from the database using the provided SQL query.
        if not self.cursor:
            raise RuntimeError("Database not connected.")
        self.cursor.execute(query)
        return self.cursor.fetchall()

    def close(self):
        #Close the database connection.
        if self.conn:
            self.conn.close()
            print("Database connection closed.")

def main():
    # Initialize the database connection
    db = Database('example.db')
    
    # Connect to the database
    db.connect()
    
    # Define a query to fetch data
    query = "SELECT * FROM users"  # Change this query according to your table structure
    
    # Fetch data
    try:
        data = db.fetch_data(query)
        print("Fetched Data:")
        for row in data:
            print(row)
    except Exception as e:
        print(f'An error occurred: {e}')
    finally:
        # Close the database connection
        db.close()

if __name__ == "__main__":
    main()
"""


replicate_response = granite_via_replicate.invoke(codes)

print(f"Granite response from Replicate: {replicate_response}")

Granite response from Replicate: Here is an example of how you can use pytest to generate unit test cases for the given code:

import pytest
from database import Database

@pytest.fixture
def db():
 # Initialize the database connection
 db = Database('example.db')
 
 # Connect to the database
 db.connect()
 
 yield db
 
 # Close the database connection
 db.close()

def test_fetch_data(db):
 # Define a query to fetch data
 query = "SELECT * FROM users"  # Change this query according to your table structure
 
 # Fetch data
 data = db.fetch_data(query)
 
 # Assert that the fetched data is not empty
 assert data
 
 # Assert that the fetched data has the expected structure
 expected_structure = [(1, 'John', 'Doe'), (2, 'Jane', 'Smith')]  # Change this to match your table structure
 assert data == expected_structure

def test_fetch_data_with_invalid_query(db):
 # Define an invalid query to fetch data
 query = "SELECT * FROM nonexistent_table"
 
 # Fetch data and assert that an exception is r
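
The response above is cut off, and its fixture connects to a real `example.db` file. For comparison, here is a hand-written sketch (not model output) of pytest-style tests that replace `sqlite3.connect` with a test double from `unittest.mock`. The `Database` class is a condensed copy of the input code so the block is self-contained; the test names and the sample rows are illustrative.

```python
import sqlite3
from unittest.mock import MagicMock, patch


class Database:
    """Condensed copy of the class under test from the input above."""

    def __init__(self, db_name):
        self.db_name = db_name
        self.conn = None
        self.cursor = None

    def connect(self):
        self.conn = sqlite3.connect(self.db_name)
        self.cursor = self.conn.cursor()

    def fetch_data(self, query):
        if not self.cursor:
            raise RuntimeError("Database not connected.")
        self.cursor.execute(query)
        return self.cursor.fetchall()

    def close(self):
        if self.conn:
            self.conn.close()


def test_fetch_data_returns_rows_from_cursor():
    # Illustrative rows; no real database is touched.
    rows = [(1, "John"), (2, "Jane")]
    fake_conn = MagicMock()
    fake_conn.cursor.return_value.fetchall.return_value = rows

    # Replace sqlite3.connect so connect() receives the fake connection.
    with patch("sqlite3.connect", return_value=fake_conn) as fake_connect:
        db = Database("example.db")
        db.connect()
        result = db.fetch_data("SELECT * FROM users")
        db.close()

    fake_connect.assert_called_once_with("example.db")
    fake_conn.cursor.return_value.execute.assert_called_once_with(
        "SELECT * FROM users"
    )
    fake_conn.close.assert_called_once()
    assert result == rows


def test_fetch_data_without_connection_raises():
    db = Database("example.db")
    try:
        db.fetch_data("SELECT 1")
    except RuntimeError as exc:
        assert str(exc) == "Database not connected."
    else:
        raise AssertionError("expected RuntimeError")
```

Both functions use plain `assert` statements, so pytest would collect them as written; with pytest available, the second test could use `pytest.raises(RuntimeError)` instead of the `try`/`except` block.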

#### Invoke the model to generate test cases for application code 

In [6]:
codes=""" Use unittest library to generate the unit test cases for the below
import json

def lambda_handler(event, context):
    if 'queryStringParameters' in event:    # If parameters
        print(event['queryStringParameters']['first_name'])
        print(event['queryStringParameters']['last_name'])
        body = 'Hello {} {}!'.format(event['queryStringParameters']['first_name'], 
                                    event['queryStringParameters']['last_name'])  
    else:    # If no parameters
        print('No parameters!')
        body = 'Who are you?'
        
    return {
        'statusCode': 200,
        'body': json.dumps(body)
    }"""



replicate_response = granite_via_replicate.invoke(codes)

print(f"Granite response from Replicate: {replicate_response}")

Granite response from Replicate: Here's an example of how you can generate unit test cases for the given function using the unittest library:

import json
import unittest
from unittest.mock import patch

def lambda_handler(event, context):
 if 'queryStringParameters' in event:    # If parameters
 print(event['queryStringParameters']['first_name'])
 print(event['queryStringParameters']['last_name'])
 body = 'Hello {} {}!'.format(event['queryStringParameters']['first_name'], 
 event['queryStringParameters']['last_name'])  
 else:    # If no parameters
 print('No parameters!')
 body = 'Who are you?'
 
 return {
 'statusCode': 200,
 'body': json.dumps(body)
 }
class LambdaHandlerTestCase(unittest.TestCase):
 @patch('__main__.print')
 def test_with_parameters(self, mock_print):
 event = {
 'queryStringParameters': {
 'first_name': 'John',
 'last_name': 'Doe'
 }
 }
 result = lambda_handler(event, None)
 self.assertEqual(result['statusCode'], 200)
 self.assertEqual(result['body'], '"Hello Joh
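
The response above is truncated partway through an assertion. As a point of comparison, a complete hand-written sketch (not model output) of `unittest` tests for the same handler might look like this. The handler is a condensed copy of the input, and `print` is patched through `builtins.print` so the tests stay quiet; the class, method, and `run_tests` names are illustrative.

```python
import json
import unittest
from unittest.mock import patch


def lambda_handler(event, context):
    """Condensed copy of the handler under test from the input above."""
    if 'queryStringParameters' in event:
        params = event['queryStringParameters']
        print(params['first_name'])
        print(params['last_name'])
        body = 'Hello {} {}!'.format(params['first_name'], params['last_name'])
    else:
        print('No parameters!')
        body = 'Who are you?'
    return {'statusCode': 200, 'body': json.dumps(body)}


class LambdaHandlerTests(unittest.TestCase):
    @patch('builtins.print')
    def test_with_parameters(self, mock_print):
        event = {'queryStringParameters': {'first_name': 'John', 'last_name': 'Doe'}}
        result = lambda_handler(event, None)
        self.assertEqual(result['statusCode'], 200)
        # The body is JSON-encoded, so decode it before comparing.
        self.assertEqual(json.loads(result['body']), 'Hello John Doe!')
        self.assertEqual(mock_print.call_count, 2)

    @patch('builtins.print')
    def test_without_parameters(self, mock_print):
        result = lambda_handler({}, None)
        self.assertEqual(result['statusCode'], 200)
        self.assertEqual(json.loads(result['body']), 'Who are you?')
        mock_print.assert_called_once_with('No parameters!')


def run_tests():
    """Run the test case programmatically and return the unittest result."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(LambdaHandlerTests)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

In a notebook, calling `run_tests()` avoids the argument parsing that `unittest.main()` performs; in a standalone file, `unittest.main()` works as usual.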

#### Another example: providing the code of multiple functions as input

In [7]:
codes="""Here are the function definitions to be unit tested
import numpy as np
import matplotlib.pyplot as plt
import time

def load_data(fname):
    points = np.loadtxt(fname, delimiter=',') 
    y_ = points[:,1]
    # append '1' to account for the intercept
    x_ = np.ones([len(y_),2]) 
    x_[:,0] = points[:,0]
    # display plot
    #plt.plot(x_[:,0], y_, 'ro')
    #plt.xlabel('x-axis')
    #plt.ylabel('y-axis')
    #plt.show()
    print('data loaded. x:{} y:{}'.format(x_.shape, y_.shape))
    return x_, y_

def evaluate_cost(x_,y_,params):
    tempcost = 0
    for i in range(len(y_)):
        tempcost += (y_[i] - ((params[0] * x_[i,0]) + params[1])) ** 2 
    return tempcost / float(10000)   

def evaluate_gradient(x_,y_,params):
    m_gradient = 0
    b_gradient = 0
    N = float(len(y_))
    for i in range(len(y_)):
        m_gradient += -(2/N) * (x_[i,0] * (y_[i] - ((params[0] * x_[i,0]) + params[1])))
        b_gradient += -(2/N) * (y_[i] - ((params[0] * x_[i,0]) + params[1]))
    return [m_gradient,b_gradient]

"""


replicate_response = granite_via_replicate.invoke(codes)

print(f"Granite response from Replicate: {replicate_response}")
 

Granite response from Replicate: To perform unit testing on the given functions, we can use a unit testing framework such as unittest in Python. Here is an example of how we can write unit tests for the given functions:

import unittest
import numpy as np
import matplotlib.pyplot as plt
import time

class TestFunctions(unittest.TestCase):
 def test_load_data(self):
 points = np.loadtxt('data.csv', delimiter=',') 
 y_ = points[:,1]
 # append '1' to account for the intercept
 x_ = np.ones([len(y_),2]) 
 x_[:,0] = points[:,0]
 # display plot
 #plt.plot(x_[:,0], y_, 'ro')
 #plt.xlabel('x-axis')
 #plt.ylabel('y-axis')
 #plt.show()
 self.assertEqual(x_.shape, (10000, 2))
 self.assertEqual(y_.shape, (10000,))
 def test_evaluate_cost(self):
 x_ = np.array([[1, 2], [3, 4], [5, 6]])
 y_ = np.array([1, 2, 3])
 params = np.array([0.5, 0.5])
 cost = evaluate_cost(x_, y_, params)
 self.assertEqual(cost, 0.5)
 def test_evaluate_gradient(self):
 x_ = np.array([[1, 2], [3, 4], [5, 6]])
 y_ = np.array([
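
The response above is cut off. A hand-written sketch (not model output) of tests for `evaluate_cost` and `evaluate_gradient` is shown below. It uses three points that lie exactly on the line y = 2x + 1, so the expected values can be worked out by hand: with parameters `[2, 1]` every residual is zero, and with parameters `[0, 0]` the residuals are simply the y values. Note that `evaluate_cost` divides by a fixed 10000 rather than by the number of samples, so the expected cost follows that definition. The two functions are copied from the input; the test names and data are illustrative.

```python
import numpy as np


def evaluate_cost(x_, y_, params):
    tempcost = 0
    for i in range(len(y_)):
        tempcost += (y_[i] - ((params[0] * x_[i, 0]) + params[1])) ** 2
    return tempcost / float(10000)


def evaluate_gradient(x_, y_, params):
    m_gradient = 0
    b_gradient = 0
    N = float(len(y_))
    for i in range(len(y_)):
        m_gradient += -(2 / N) * (x_[i, 0] * (y_[i] - ((params[0] * x_[i, 0]) + params[1])))
        b_gradient += -(2 / N) * (y_[i] - ((params[0] * x_[i, 0]) + params[1]))
    return [m_gradient, b_gradient]


# Points on y = 2x + 1; the second column is the intercept term.
X = np.array([[1.0, 1.0], [2.0, 1.0], [3.0, 1.0]])
Y = np.array([3.0, 5.0, 7.0])


def test_cost_is_zero_on_perfect_fit():
    assert evaluate_cost(X, Y, [2.0, 1.0]) == 0.0


def test_cost_with_zero_params():
    # Residuals equal Y, so the sum of squares is 9 + 25 + 49 = 83.
    assert np.isclose(evaluate_cost(X, Y, [0.0, 0.0]), 83 / 10000)


def test_gradient_is_zero_on_perfect_fit():
    assert np.allclose(evaluate_gradient(X, Y, [2.0, 1.0]), [0.0, 0.0])


def test_gradient_with_zero_params():
    # m: -(2/3) * (1*3 + 2*5 + 3*7) = -68/3, b: -(2/3) * (3 + 5 + 7) = -10.
    assert np.allclose(evaluate_gradient(X, Y, [0.0, 0.0]), [-68 / 3, -10.0])
```

Choosing data with a known closed-form answer keeps the assertions independent of the implementation being tested.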

#### Example with user input that uses middleware code

In [8]:
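# Raw string: keeps the backslash line continuations in the snippet literal instead of letting Python join the lines
codes = r""" For the below code write unit test code using pytest library package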
if __name__ == "__main__":
    # create Spark session
    spark = SparkSession.builder.appName("TwitterSentimentAnalysis").getOrCreate()
    spark.sparkContext.setLogLevel("ERROR") # Ignore INFO DEBUG output
    df = spark \
        .readStream \
        .format("kafka") \
        .option("kafka.bootstrap.servers", "localhost:9092") \
        .option("subscribe", topic_name) \
        .load()

    df = df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")

    df = df.withColumn("data", from_json(df.value, Sentiment.get_schema())).select("data.*")
    df = df \
        .withColumn("ts", to_timestamp(from_unixtime(expr("timestamp_ms/1000")))) \
        .withWatermark("ts", "1 seconds") # old data will be removed

    # Preprocess the data
    df = Sentiment.preprocessing(df)

    # text classification to define polarity and subjectivity
    df = Sentiment.text_classification(df)

    assert type(df) == pyspark.sql.dataframe.DataFrame

    row_df = df.select(
        to_json(struct("id")).alias('key'),
        to_json(struct('text', 'lang', 'ts', 'polarity_v', 'polarity', 'subjectivity_v')).alias("value")
    )
 

    # Writing to Kafka
    query = row_df\
        .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") \
        .writeStream\
        .format("kafka") \
        .option("kafka.bootstrap.servers", "localhost:9092") \
        .option("topic", output_topic) \
        .option("checkpointLocation", "file:/Users/user/tmp") \
        .start()
 
    query.awaitTermination()"""

 
replicate_response = granite_via_replicate.invoke(codes)

print(f"Granite response from Replicate: {replicate_response}")

Granite response from Replicate: Here is a sample unit test code using pytest library package for the given code:
```
import pytest
from pyspark.sql import SparkSession
from pyspark.sql.functions import from_json, to_json, struct, to_timestamp, from_unixtime
from sentiment import Sentiment
topic_name = "twitter_topic"
output_topic = "output_topic"
@pytest.fixture(scope="module")
def spark_session():
 # create Spark session
 spark = SparkSession.builder.appName("TwitterSentimentAnalysis").getOrCreate()
 spark.sparkContext.setLogLevel("ERROR") # Ignore INFO DEBUG output
 return spark
def test_df_type(spark_session):
 df = spark_session         .readStream         .format("kafka")         .option("kafka.bootstrap.servers", "localhost:9092")         .option("subscribe", topic_name)         .load()
 
 df = df.selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)")
 
 df = df.withColumn("data", from_json(df.value, Sentiment.get_schema())).select("data.*")
 df = df         .withColumn("ts"
```

#### A scenario where the input is a class definition and the generated tests use test doubles

In [9]:
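# Raw string: keeps regex escapes such as \x00 and \n literal instead of letting Python interpret them
codes = r"""Below is the class definition to be tested using pytest library and fixtures. Use below class instances for unit test modules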

class Sentiment:
    def get_schema():
        schema = StructType([
            StructField("created_at", StringType()),
            StructField("id", StringType()),
            StructField("text", StringType()),
            StructField("source", StringType()),
            StructField("truncated", StringType()),
            StructField("in_reply_to_status_id", StringType()),
            StructField("in_reply_to_user_id", StringType()),
            StructField("in_reply_to_screen_name", StringType()),
            StructField("user", StringType()),
            StructField("coordinates", StringType()),
            StructField("place", StringType()),
            StructField("quoted_status_id", StringType()),
            StructField("is_quote_status", StringType()),
            StructField("quoted_status", StringType()),
            StructField("retweeted_status", StringType()),
            StructField("quote_count", StringType()),
            StructField("reply_count", StringType()),
            StructField("retweet_count", StringType()),
            StructField("favorite_count", StringType()),
            StructField("entities", StringType()),
            StructField("extended_entities", StringType()),
            StructField("favorited", StringType()),
            StructField("retweeted", StringType()),
            StructField("possibly_sensitive", StringType()),
            StructField("filter_level", StringType()),
            StructField("lang", StringType()),
            StructField("matching_rules", StringType()),
            StructField("name", StringType()),
            StructField("timestamp_ms", StringType())
        ])
        return schema

    @staticmethod
    def preprocessing(df):
        # words = df.select(explode(split(df.text, " ")).alias("word"))
        df = df.filter(col('text').isNotNull())
        df = df.withColumn('text', regexp_replace('text', r'http\S+', ''))
        df = df.withColumn('text', regexp_replace('text', r'[^\x00-\x7F]+', ''))
        df = df.withColumn('text', regexp_replace('text', r'[\n\r]', ' '))
        df = df.withColumn('text', regexp_replace('text', '@\w+', ''))
        df = df.withColumn('text', regexp_replace('text', '#', ''))
        df = df.withColumn('text', regexp_replace('text', 'RT', ''))
        df = df.withColumn('text', regexp_replace('text', ':', ''))
        df = df.withColumn('source', regexp_replace('source', '<a href="' , ''))

        return df

    # text classification
    @staticmethod
    def polarity_detection(text):
        return TextBlob(text).sentiment.polarity

    @staticmethod
    def subjectivity_detection(text):
        return TextBlob(text).sentiment.subjectivity

    @staticmethod
    def text_classification(words):
        # polarity detection
        polarity_detection_udf = udf(Sentiment.polarity_detection, FloatType())
        words = words.withColumn("polarity_v", polarity_detection_udf("text"))
        words = words.withColumn(
            'polarity',
            when(col('polarity_v') > 0, lit('Positive'))
            .when(col('polarity_v') == 0, lit('Neutral'))
            .otherwise(lit('Negative'))
        )
        # subjectivity detection
        subjectivity_detection_udf = udf(Sentiment.subjectivity_detection, FloatType())
        words = words.withColumn("subjectivity_v", subjectivity_detection_udf("text"))
        return words

"""
replicate_response = granite_via_replicate.invoke(codes)

print(f"Granite response from Replicate: {replicate_response}")

Granite response from Replicate: To test the `Sentiment` class using pytest and fixtures, you can create a test module with the following code:
```python
import pytest
from pyspark.sql.types import *
from pyspark.sql.functions import col, explode, split, udf, when
from textblob import TextBlob
from sentiment import Sentiment

@pytest.fixture
def sentiment():
 return Sentiment()

def test_get_schema(sentiment):
 schema = sentiment.get_schema()
 assert isinstance(schema, StructType)
 assert len(schema.fields) == 23

def test_preprocessing(sentiment):
 df = spark.createDataFrame([("test",)], ["text"])
 df = sentiment.preprocessing(df)
 assert df.count() == 1
 assert df.first().text == "test"

def test_polarity_detection(sentiment):
 assert sentiment.polarity_detection("test") == 0.0

def test_subjectivity_detection(sentiment):
 assert sentiment.subjectivity_detection("test") == 0.0

def test_text_classification(sentiment):
 words = spark.createDataFrame([("test",)], ["text"])
 words = sen
```

#### The example below shows input code that calls an external API

In [10]:
codes=""" Given the below code, provide unit testing code 

import facebook

token = 'your token'

graph = facebook.GraphAPI(token)
profile = graph.get_object("me")
friends = graph.get_connections("me", "friends")
friend_list = [friend['name'] for friend in friends['data']]
print friend_list"""

replicate_response = granite_via_replicate.invoke(codes)

print(f"Granite response from Replicate: {replicate_response}")

Granite response from Replicate: Here is an example of unit testing code for the given code:

import facebook
import unittest
from unittest.mock import patch, MagicMock

class TestFacebookAPI(unittest.TestCase):
 @patch('facebook.GraphAPI')
 def test_get_friend_list(self, mock_graph):
 # Create a mock graph object with a predefined friend list
 mock_graph_instance = MagicMock()
 mock_graph_instance.get_object.return_value = {'name': 'John Doe'}
 mock_graph_instance.get_connections.return_value = {
 'data': [
 {'name': 'Alice'},
 {'name': 'Bob'},
 {'name': 'Charlie'}
 ]
 }
 mock_graph.return_value = mock_graph_instance
 
 # Call the function being tested
 token = 'your token'
 graph = facebook.GraphAPI(token)
 profile = graph.get_object("me")
 friends = graph.get_connections("me", "friends")
 friend_list = [friend['name'] for friend in friends['data']]
 
 # Assert that the friend list is correct
 self.assertEqual(friend_list, ['Alice', 'Bob', 'Charlie'])
 
if __name__ == '__main__':
 un

#### Another example with external API calls, with an explicit instruction to use the pytest library

In [11]:
codes=""" Given the below code, provide unit testing code with pytest

import facebook

token = 'your token'

graph = facebook.GraphAPI(token)
profile = graph.get_object("me")
friends = graph.get_connections("me", "friends")
friend_list = [friend['name'] for friend in friends['data']]
print friend_list"""


replicate_response = granite_via_replicate.invoke(codes)

print(f"Granite response from Replicate: {replicate_response}")

Granite response from Replicate: Here's an example of unit testing code using pytest for the given code:

import pytest
import facebook

@pytest.fixture
def token():
 return 'your token'

def test_get_friend_list(token):
 graph = facebook.GraphAPI(token)
 profile = graph.get_object("me")
 friends = graph.get_connections("me", "friends")
 friend_list = [friend['name'] for friend in friends['data']]
 assert len(friend_list) > 0
 for friend in friend_list:
 assert isinstance(friend, str)

This code defines a fixture called token that returns the same string every time it's called. Then it defines a test function called test_get_friend_list that takes the token fixture as an argument. Inside the test function, it creates a GraphAPI object using the token, gets the user's profile and friends, creates a list of friend names, and checks that the list is not empty and that each friend name is a string. If any of these checks fail, pytest will report an error. 
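
The pytest response above calls the real Graph API, so it needs a valid token and network access. A hand-written sketch (not model output) of the same check with a test double is shown below. It assumes the friend-list logic is wrapped in a function, here called `get_friend_names` (a hypothetical name), that receives the graph client as an argument, so a `MagicMock` can stand in for `facebook.GraphAPI` without the `facebook` package being installed.

```python
from unittest.mock import MagicMock


def get_friend_names(graph):
    """Return the names of the current user's friends from a Graph API client."""
    friends = graph.get_connections("me", "friends")
    return [friend['name'] for friend in friends['data']]


def test_get_friend_names_uses_graph_client():
    # Test double standing in for facebook.GraphAPI(token).
    fake_graph = MagicMock()
    fake_graph.get_connections.return_value = {
        'data': [{'name': 'Alice'}, {'name': 'Bob'}]
    }

    assert get_friend_names(fake_graph) == ['Alice', 'Bob']
    fake_graph.get_connections.assert_called_once_with("me", "friends")


def test_get_friend_names_handles_no_friends():
    fake_graph = MagicMock()
    fake_graph.get_connections.return_value = {'data': []}

    assert get_friend_names(fake_graph) == []
```

Passing the client in as an argument is one design choice among several; patching `facebook.GraphAPI` with `unittest.mock.patch`, as in the earlier unittest response, is another.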
