# Answering 200 training problems from the Java250 dataset with baseline Gemma

In [1]:
import os
import pandas as pd
import numpy as np
import time

from langchain_community.embeddings import OllamaEmbeddings
from langchain_community.vectorstores import Chroma
from tqdm import tqdm


from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.prompts import PromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_community.llms import Ollama

## Load the models

### Vector store

In [2]:
# Ollama-served Gemma embeddings backing the persisted Chroma vector store.
# NOTE(review): `retriever` is not used by gemma_chain in this baseline notebook.
embeddings = OllamaEmbeddings(
    model="gemma:7b-instruct",
    num_gpu=2,
    num_thread=24,
)
# Open the store already persisted in the bp_chroma_db directory.
db = Chroma(embedding_function=embeddings, persist_directory="bp_chroma_db")
retriever = db.as_retriever()

### LLMs

In [3]:
# Gemma 7B instruct served through Ollama; num_ctx sets a 4096-token context window.
gemma = Ollama(
    model="gemma:7b-instruct",
    num_ctx=4096,
    num_gpu=2,
    num_thread=24,
)

## Custom prompt (instructs the model to answer in Java)

In [4]:
# Gemma turn-formatted prompt asking for a Java answer. The text, including its
# leading space, is kept exactly as used to produce the recorded results.
gemma_template = """ <start_of_turn>user
Use Java code in your answer.
Question: {question}
<end_of_turn>
<start_of_turn>model
"""

# Single human message whose only input variable ({question}) is inferred from
# the template.
# NOTE(review): `gemma` is a text LLM, so this chat prompt is rendered to a string
# with a "Human: " role prefix before it is sent, and Ollama may additionally wrap
# it in Gemma's own chat template, duplicating the turn markers. A plain
# PromptTemplate would avoid that, but it would change this baseline's outputs;
# confirm before switching.
java_prompt_gemma = ChatPromptTemplate.from_messages(
    [HumanMessagePromptTemplate.from_template(gemma_template)]
)

### Build the chain

In [5]:
def format_docs(docs):
    """Concatenate the ``page_content`` of each document, separated by a blank line.

    Args:
        docs: Iterable of document objects exposing a ``page_content`` string.

    Returns:
        A single string; empty when ``docs`` is empty.
    """
    contents = [document.page_content for document in docs]
    return "\n\n".join(contents)

# Baseline chain without retrieval: the raw question string fills the prompt's
# {question} slot, the rendered prompt goes to Gemma, and the reply is parsed to str.
gemma_chain = ({"question": RunnablePassthrough()}
               | java_prompt_gemma | gemma | StrOutputParser())


## QA

Read in the questions.

In [6]:
# Problem table; the answer loop reads its "id" and "desc" columns.
Java200 = pd.read_csv(filepath_or_buffer="Java200.csv")

Loop over the questions and collect Gemma's answers.

In [8]:
# Ask the baseline Gemma chain every question and keep the answers in row order.
gemma_answers = []
# Zip the two columns instead of re-indexing Java200["desc"] with the enumerate
# position, which silently depends on Java200 having a default RangeIndex.
# tqdm (imported in the first cell) replaces the manual index / "Complete" prints.
for problem, question in tqdm(zip(Java200["id"], Java200["desc"]),
                              total=len(Java200), desc="Gemma answers"):
    gm_a = gemma_chain.invoke(question)
    # tqdm.write echoes each answer without breaking the progress bar.
    tqdm.write("Gemma Model answer to: " + problem + "\n\n" + gm_a + "\n\n")
    gemma_answers.append(gm_a)

Gemma Model answer to: p02779.html

```java
import java.util.*;

public class PairwiseDistinct {

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);

        int N = scanner.nextInt();

        int[] A = new int[N];

        for (int i = 0; i < N; i++) {
            A[i] = scanner.nextInt();
        }

        boolean isPairwiseDistinct = true;

        for (int i = 0; i < N; i++) {
            for (int j = i + 1; j < N; j++) {
                if (A[i] == A[j]) {
                    isPairwiseDistinct = false;
                    break;
                }
            }
        }

        System.out.println(isPairwiseDistinct ? "YES" : "NO");
    }
}
```


0
Complete 

Gemma Model answer to: p02707.html

```java
import java.util.*;

public class BossSubordinate {

    public static void main(String[] args) {
        Scanner scanner = new Scanner(System.in);

        int N = scanner.nextInt();

        int[] A = new int[N];

        for (int i =

In [9]:
# Guard against an interrupted answer loop: a partial gemma_answers list would
# otherwise fail inside pandas with a less descriptive length-mismatch error.
assert len(gemma_answers) == len(Java200), (
    f"expected {len(Java200)} answers, got {len(gemma_answers)}"
)
# One row per problem: its id alongside Gemma's raw answer text.
qa_basegemma_java200 = pd.DataFrame(data={"problem": Java200["id"],
                                          "gemma_answer": gemma_answers})

In [10]:
# Preview the first five answers.
qa_basegemma_java200.head(n=5)

Unnamed: 0,problem,gemma_answer
0,p02779.html,```java\nimport java.util.*;\n\npublic class P...
1,p02707.html,```java\nimport java.util.*;\n\npublic class B...
2,p02771.html,```java\nimport java.util.*;\n\npublic class P...
3,p02417.html,```java\nimport java.util.*;\n\npublic class C...
4,p04030.html,```java\nimport java.util.Scanner;\n\npublic c...


In [11]:
# Persist the answers for later evaluation.
# NOTE(review): pandas writes the index by default, so the file gains an unnamed
# first column (read back as "Unnamed: 0"); pass index=False if downstream
# readers don't rely on it.
qa_basegemma_java200.to_csv(path_or_buf="qa_java200_basegemma.csv")