# Answering the 200 training problems from the Java250 dataset with Gemma and RAG Gemma (fine-tuned)

In [1]:
import os
import pandas as pd
import numpy as np
import time

from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from tqdm import tqdm


from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_community.llms import Ollama

## Import the models

### Vector store

In [2]:
# Query embeddings come from the same Ollama model tag used below for generation.
# NOTE(review): assumes bp_chroma_db was built with gemma:7b-instruct embeddings — confirm,
# otherwise query and document vectors live in different spaces.
embeddings = OllamaEmbeddings(
    model="gemma:7b-instruct",
    num_gpu=2,
    num_thread=24,
)

# Reopen the Chroma store previously persisted to the bp_chroma_db directory.
db = Chroma(
    persist_directory="bp_chroma_db",
    embedding_function=embeddings,
)
retriever = db.as_retriever()

### LLMs

In [18]:
# Both models run with identical runtime and sampling settings; only the model tag differs.
OLLAMA_SETTINGS = dict(
    num_gpu=2,
    num_thread=24,
    num_ctx=4096,
    top_k=10,
    top_p=0.5,
    temperature=0.6,
    timeout=120,
)
codellama = Ollama(model="codellama:latest", **OLLAMA_SETTINGS)
gemma = Ollama(model="gemma:7b-instruct", **OLLAMA_SETTINGS)

## Prompt templates (RAG and baseline)

In [23]:
# Gemma chat-format prompt for the RAG chain: {context} receives the retrieved
# documents and {question} the problem description.
gemma_rag_template = """ <start_of_turn>user
You are a helpful assistant for Java coding. Use the following pieces of context to answer the question at the end. Only use Java code.
If you don't know the answer, just say that you don't know, don't try to make up an answer.
Use Java best practices to guide your code and fast cpu time code, but do not explain that you are doing it in your answer.
Only show the Java code and a short explanation of what it does.
Context: {context} 
Question: {question}
<end_of_turn>
<start_of_turn>model
"""
# A plain PromptTemplate (not a ChatPromptTemplate) is deliberate: `gemma` is a
# completion-style Ollama LLM, and a chat prompt value is flattened to
# "Human: <text>" before being sent. That would put a role prefix in front of the
# hand-written <start_of_turn> control tokens. A PromptTemplate sends the string verbatim.
java_prompt_gemma_rag = PromptTemplate.from_template(gemma_rag_template)

In [5]:
# Gemma chat-format prompt for the baseline (no retrieval) chain.
gemma_template = """ <start_of_turn>user
Use Java code in your answer.
Question: {question}
<end_of_turn>
<start_of_turn>model
"""

# A plain PromptTemplate (not a ChatPromptTemplate) is deliberate: `gemma` is a
# completion-style Ollama LLM, and a chat prompt value is flattened to
# "Human: <text>" before being sent. That would put a role prefix in front of the
# hand-written <start_of_turn> control tokens. A PromptTemplate sends the string verbatim.
java_prompt_gemma = PromptTemplate.from_template(gemma_template)

In [6]:
# Llama-2-style [INST]/<<SYS>> prompt for the baseline CodeLlama chain.
codellama_template = """[INST]<<SYS>> Use Java code in your answer. <</SYS>>
                Question: {question}
            [/INST]"""
# A plain PromptTemplate (not a ChatPromptTemplate) is deliberate: `codellama` is a
# completion-style Ollama LLM, and a chat prompt value is flattened to
# "Human: <text>" before being sent, which would put a role prefix in front of [INST].
java_prompt_codellama = PromptTemplate.from_template(codellama_template)

### Building the chains

In [24]:
def format_docs(docs):
    """Join the ``page_content`` of each retrieved document into one context string.

    Args:
        docs: Iterable of objects exposing a ``page_content`` string attribute.

    Returns:
        The contents separated by blank lines (empty string for no documents).
    """
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)

# RAG chain: the raw question string goes to the retriever (whose documents
# format_docs flattens into {context}) and is also passed through unchanged as {question}.
rag_chain = (
    {
        "context": retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    | java_prompt_gemma_rag
    | gemma
    | StrOutputParser()
)


def _question_only_chain(prompt, llm):
    """Build a no-retrieval chain: raw question -> prompt -> llm -> plain string."""
    inputs = {"question": RunnablePassthrough()}
    return inputs | prompt | llm | StrOutputParser()


# Baseline chains with no retrieved context.
gemma_chain = _question_only_chain(java_prompt_gemma, gemma)
codellama_chain = _question_only_chain(java_prompt_codellama, codellama)

## QA

Read in the problem descriptions from `Java200.csv`.

In [8]:
# Load the problem set; later cells read its "id" and "desc" columns.
JAVA200_CSV = "Java200.csv"
Java200 = pd.read_csv(JAVA200_CSV)

In [12]:
# Spot-check the CodeLlama chain on a single problem (row label 5).
codellama_chain.invoke(Java200.loc[5, "desc"])

'  public static void main(String[] args) {\n    Scanner sc = new Scanner(System.in);\n    int a = sc.nextInt();\n    int b = sc.nextInt();\n\n    if (a == 1 && b == 2 || a == 2 && b == 3 || a == 3 && b == 1) {\n      System.out.println("The correct choice is: " + (a == 1 ? "A" : (a == 2 ? "B" : "C")));\n    } else {\n      System.out.println("No solution");\n    }\n  }'

In [10]:
# Spot-check the baseline Gemma chain on the first problem (row label 0).
gemma_chain.invoke(Java200.loc[0, "desc"])

'```java\nimport java.util.*;\n\npublic class PairwiseDistinct {\n\n    public static void main(String[] args) {\n        Scanner scanner = new Scanner(System.in);\n\n        int N = scanner.nextInt();\n\n        int[] A = new int[N];\n\n        for (int i = 0; i < N; i++) {\n            A[i] = scanner.nextInt();\n        }\n\n        boolean isPairwiseDistinct = true;\n\n        for (int i = 0; i < N; i++) {\n            for (int j = i + 1; j < N; j++) {\n                if (A[i] == A[j]) {\n                    isPairwiseDistinct = false;\n                    break;\n                }\n            }\n        }\n\n        System.out.println(isPairwiseDistinct ? "YES" : "NO");\n    }\n}\n```'

In [11]:
# Spot-check the RAG chain on a single problem (row label 4).
rag_chain.invoke(Java200.loc[4, "desc"])

'```java\nimport java.util.Scanner;\n\npublic class Keyboard {\n\n    public static void main(String[] args) {\n        Scanner scanner = new Scanner(System.in);\n        String s = scanner.nextLine();\n\n        String result = "";\n        for (char c : s.toCharArray()) {\n            switch (c) {\n                case \'0\':\n                    result += "0";\n                    break;\n                case \'1\':\n                    result += "1";\n                    break;\n                case \'B\':\n                    result = result.substring(0, result.length() - 1);\n                    break;\n            }\n        }\n\n        System.out.println(result);\n    }\n}\n```\n\nThe code reads the string of keystrokes from the input, then iterates over each character in the string. It uses a switch statement to handle each key: \'0\' inserts \'0\', \'1\' inserts \'1\', and \'B\' deletes the last character from the result string. Finally, it prints the resulting string.'

Generate RAG and baseline Gemma answers for every problem.

In [25]:
# Ask both the RAG chain and the baseline Gemma chain every question.
# zip() pairs each problem id with its description positionally. This replaces the
# manual counter plus label lookup, which misaligns if Java200's index is not 0..n-1.
# NOTE: an exception from either chain (e.g. a request timeout) stops the loop.
# Answers finished so far stay in the lists, but re-running this cell starts over.
rag_answers = []
gemma_answers = []
for problem, question in tqdm(
    zip(Java200["id"], Java200["desc"]),
    total=len(Java200),
    desc="RAG + Gemma",
):
    rag_a = rag_chain.invoke(question)
    print(f"RAG Model answer to: {problem}\n\n{rag_a}\n\n")
    gm_a = gemma_chain.invoke(question)
    print(f"Gemma Model answer to: {problem}\n\n{gm_a}\n\n")
    # Append only after both calls succeed so the two lists stay the same length.
    rag_answers.append(rag_a)
    gemma_answers.append(gm_a)

RAG Model answer to: p02779.html

```java
import java.util.Scanner;

public class PairwiseDistinct {

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);

        int n = scanner.nextInt();

        int[] a = new int[n];

        for (int i = 0; i < n; i++) {
            a[i] = scanner.nextInt();
        }

        boolean isPairwiseDistinct = true;

        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                if (a[i] == a[j]) {
                    isPairwiseDistinct = false;
                    break;
                }
            }
        }

        System.out.println(isPairwiseDistinct ? "YES" : "NO");
    }
}
```

This code reads the number of elements in the sequence and then reads the elements themselves. It then iterates over the sequence and checks if any two elements are equal. If it finds any pair of equal elements, it sets the `isPairwiseDistinct` variable to `false` and breaks out of the loo

In [26]:
# One row per problem: its id plus the RAG and baseline Gemma answers, in loop order.
qa_java200 = pd.DataFrame(
    {
        "problem": Java200["id"],
        "rag_answer": rag_answers,
        "gemma_answer": gemma_answers,
    }
)

In [27]:
# Preview the first five rows of the combined answer table.
qa_java200.head(n=5)

Unnamed: 0,problem,rag_answer,gemma_answer
0,p02779.html,```java\nimport java.util.Scanner;\n\npublic c...,```java\nimport java.util.*;\n\npublic class P...
1,p02707.html,```java\nimport java.util.*;\n\npublic class S...,```java\nimport java.util.*;\n\npublic class B...
2,p02771.html,```java\nimport java.util.Scanner;\n\npublic c...,```java\nimport java.util.Scanner;\n\npublic c...
3,p02417.html,```java\nimport java.util.*;\n\npublic class C...,```java\nimport java.util.*;\n\npublic class C...
4,p04030.html,```java\nimport java.util.Scanner;\n\npublic c...,```java\nimport java.util.Scanner;\n\npublic c...


In [28]:
# Persist the answers. The DataFrame index is written as an unnamed first column
# (pandas' default). NOTE(review): pass index=False if downstream readers don't expect it.
qa_java200.to_csv("qa_java200.csv", index=True)

In [None]:
# Ask the baseline CodeLlama chain every question.
# zip() pairs each problem id with its description positionally. This replaces the
# manual counter plus label lookup, which misaligns if Java200's index is not 0..n-1.
# NOTE: an exception (e.g. a request timeout) stops the loop. Answers finished so far
# stay in the list, but re-running this cell starts over.
# TODO(review): these answers are not written to disk in this cell.
codellama_answers = []
for problem, question in tqdm(
    zip(Java200["id"], Java200["desc"]),
    total=len(Java200),
    desc="CodeLlama",
):
    clm_a = codellama_chain.invoke(question)
    print(f"Codellama answer to: {problem}\n\n{clm_a}\n\n")
    codellama_answers.append(clm_a)