In [1]:
# Setup the environment
!pip install -q -U immutabledict sentencepiece 
!git clone https://github.com/google/gemma_pytorch.git
!mkdir /kaggle/working/gemma/
!mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/gemma/



Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 133, done.[K
remote: Counting objects: 100% (65/65), done.[K
remote: Compressing objects: 100% (41/41), done.[K
remote: Total 133 (delta 40), reused 35 (delta 23), pack-reused 68[K
Receiving objects: 100% (133/133), 2.15 MiB | 5.89 MiB/s, done.
Resolving deltas: 100% (67/67), done.


In [2]:
import sys 
sys.path.append("/kaggle/working/gemma_pytorch/") 
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch
import json
import random
import pandas as pd

# Load the model
VARIANT = "7b-it-quant" 
MACHINE_TYPE = "cuda" 
weights_dir = '/kaggle/input/gemma/pytorch/7b-it-quant/2' 

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  yield
  torch.set_default_dtype(torch.float)

model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")
model_config.quant = "quant" in VARIANT

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
  model.load_weights(ckpt_path)
  model = model.to(device).eval()



# Use the model

USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"

prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt="What is a good place for travel in the US?"
    )
    + MODEL_CHAT_TEMPLATE.format(prompt="California.")
    + USER_CHAT_TEMPLATE.format(prompt="What can I do in California?")
    + "<start_of_turn>model\n"
)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=100,
)


  return self.fget.__get__(instance, owner)()


"The Golden Gate Bridge:\n- Take a scenic drive across the Golden Gate Bridge or a boat ride underneath it.\nThe Golden Gate Bridge is a must-see attraction in San Francisco's Bay Area.\n\nAltos and Monterey:\n- Visit the charming town of Altos' historical downtown and the world-renowned Monterey Bay Aquarium.\n\nLake Tahoe:\n- Hike or drive to the stunning Lake Tahoe, a jewel of the Sierra Nevada Mountains.\n\nSan Francisco:\n- Explore the Golden"

In [3]:
class PromptGenerator(object):
    def __init__(self):
        """
        Setup generation parameters for Gemma.
        """
        self.output_len = 20
        self.temperature = 0.0
        self.top_p = 1.0

    def create_prompt(self, sample):
        """
        Input sample is a dictionary consisting of following fields
        'table': A dict containing table data and meta-data same as in Assignment 2
        'question': A python string for the question on the table.
        
        The function must return the prompt as a python string.
        """
        # Example prompt
        prompt = """Read the following table and answer the related question.\n\n"""
        prompt += 'TABLE:\n'
        prompt += ','.join(
            [f'"{cc}"' for cc in sample['table']['cols']]
        ) + '\n'
        for row in sample['table']['rows']:
            prompt += ','.join(
                [f'"{rr}"' for rr in row]
            ) + '\n'
        prompt += '\n'
        prompt += 'QUESTION: ' + sample['question'] + '\n'
        prompt += 'Now give the correct column and the correct rows'

        return prompt
    
    def create_prompt_shot(self, sample):
        """
        Input sample is a dictionary consisting of following fields
        'table': A dict containing table data and meta-data same as in Assignment 2
        'question': A python string for the question on the table.
        
        The function must return the prompt as a python string.
        """
        # Example prompt
        prompt = """Read the following table and answer the related question.\n\n"""
        prompt += 'TABLE:\n'
        prompt += ','.join(
            [f'"{cc}"' for cc in sample['table']['cols']]
        ) + '\n'
        for row in sample['table']['rows']:
            prompt += ','.join(
                [f'"{rr}"' for rr in row]
            ) + '\n'
        prompt += '\n'
        prompt += 'QUESTION: ' + sample['question'] + '\n'
#         print(sample['label_col'])
        prompt += 'The correct column is ' + sample['label_col'] + '\n'
        # convert sample['label_rows'] to string

        prompt += 'The correct rows are ' + ', '.join([str(rr) for rr in sample['label_rows']]) + '\n'
        prompt += '\n'

        return prompt
    
    def post_process(self, gen_text):
        """
        Input gen_text is a python string generated by Gemma for the prompt.
        
        The function must return a single python tuple (int, string)
        indicating the row and the column of the answer cell.
        """
#         gen_text = gen_text.split('\n')[0]
#         # now remove the first 4 words from the generated text
#         gen_text = ' '.join(gen_text.split()[4:])
        return gen_text

    

In [4]:
sample_shot = {
    "question":"When was the opponent Poland and the match type EC -qualifier?",
    "table":{
        "cols":["Date","Location","Opponenent","Result","Match type"],
        "rows":[
            ["29 March 2000","Debrecen","Poland","0-0 (draw)","friendly"],
            ["16 August 2000","Budapest","Austria","1-1 (draw)","friendly"],
            ["3 September 2000","Budapest","Italy","2-2 (draw)","WC -qualifier"],
            ["15 August 2001","Budapest","Germany","2-5 (defeat)","friendly"],
            ["1 September 2001","Tbilisi","Georgia","1-3 (defeat)","WC-qualifier"],
            ["5 September 2001","Budapest","Romania","0-2 (defeat)","WC-qualifier"],
            ["6 October 2001","Parma","Italy","0-0 (draw)","WC-qualifier"],
            ["14 November 2001","Budapest","Macedonia","5-0 (win)","friendly"],
            ["12 February 2002","Larnaca","Czech Rep.","0-2 (defeat)","friendly"],
            ["13 February 2002","Limassol","Switzerland","1-2 (defeat)","friendly"],
            ["8 May 2002","P\u00e9cs","Croatia","0-2 (defeat)","friendly"],
            ["21 August 2002","Budapest","Spain","1-1 (draw)","friendly"],
            ["7 September 2002","Reykjav\u00edk","Iceland","2-0 (win)","friendly"],
            ["12 October 2002","Stockholm","Sweden","1-1 (draw)","EC-qualifier"],
            ["16 October 2002","Budapest","San Marino","3-0 (win)","EC -qualifier"],
            ["20 November 2002","Budapest","Moldova","1-1 (draw)","friendly"],
            ["12 February 2003","Larnaca","Bulgaria","0-1 (defeat)","friendly"],
            ["29 March 2003","Chorz\u00f3w","Poland","0-0 (draw)","EC -qualifier"],
            ["2 April 2003","Budapest","Sweden","1-2 (defeat)","EC -qualifier"],
            ["30 April 2003","Budapest","Luxembourg","5-1 (win)","friendly"],
            ["19 February 2004","Limassol","Latvia","2-1 (win)","friendly"],
            ["21 February 2004","Limassol","Romania","0-3 (defeat)","friendly"],
            ["9 February 2011","Dubai","Azerbaijan","2-0 (win)","friendly"],
            ["29 March 2011","Amsterdam","Netherlands","3-5 (defeat)","EC -qualifier"],
            ["3 June 2011","Luxembourg","Luxembourg","1-0 (win)","friendly"]
        ],
        "types":["text","text","text","text","text"],
        "caption":"National team matches",
    },
    "label_col":"Date",
    "label_row":[17]
}
sample_shot

sample_shot2 = {
    "question":"What is the total score when 7 is the average ranking?",
    "table":{
        "cols":["Average Ranking","Competitive Finish","Couple","Number Of Dances","Total Score","Average"],
        "rows":[
            ["1","1","Bridie & Craig","15","509","35.9"],
            ["2","3","David & Karina","12","360","30.0"],
            ["3","4","Patti & Sandro","10","295","29.5"],
            ["4","2","Anh & Luda","15","421","27.0"],
            ["5","9","Corinne & Csaba","3","77","25.7"],
            ["6","5","Mark & Linda","8","204","25.5"],
            ["7","8","Elka & Michael","4","100","25.0"],
            ["8","6","James & Olya","7","169","24.1"],
            ["9","7","Jessica & Serghei","5","120","24.0"]
        ],
        "types":["real","real","text","real","real","text"],
        "caption":"Average Chart"
    },
    "label_col":"Total Score",
    "label_row":[6]
}
sample_shot2

sample_shot3 = {
    "question":"What is the name of the subject who ran in the general election for Queen Anne's County State's Attorney?",
    "table":{
        "cols":["Year","Office","Election","Subject","Party","Votes"],
        "rows":[
            ["2002","Queen Anne's County State's Attorney","General","Frank Kratovil","Democratic","9,169"],
            ["2006","Queen Anne's County State's Attorney","General","Frank Kratovil","Democratic","13,894"],
            ["2008","U.S. House , Maryland's 1st district","Primary","Frank Kratovil","Democratic","28,566"],
            ["2008","U.S. House , Maryland's 1st district","General","Frank Kratovil","Democratic","177,065"],
            ["2010","U.S. House , Maryland's 1st district","General","Andy Harris","Republican","155,118"]],
        "types":["real","text","text","text","text","real"],
        "caption":"Electoral history"
    },
    "label_col":"Subject",
    "label_row":[0,1]
}
sample_shot3

{'question': "What is the name of the subject who ran in the general election for Queen Anne's County State's Attorney?",
 'table': {'cols': ['Year', 'Office', 'Election', 'Subject', 'Party', 'Votes'],
  'rows': [['2002',
    "Queen Anne's County State's Attorney",
    'General',
    'Frank Kratovil',
    'Democratic',
    '9,169'],
   ['2006',
    "Queen Anne's County State's Attorney",
    'General',
    'Frank Kratovil',
    'Democratic',
    '13,894'],
   ['2008',
    "U.S. House , Maryland's 1st district",
    'Primary',
    'Frank Kratovil',
    'Democratic',
    '28,566'],
   ['2008',
    "U.S. House , Maryland's 1st district",
    'General',
    'Frank Kratovil',
    'Democratic',
    '177,065'],
   ['2010',
    "U.S. House , Maryland's 1st district",
    'General',
    'Andy Harris',
    'Republican',
    '155,118']],
  'types': ['real', 'text', 'text', 'text', 'text', 'real'],
  'caption': 'Electoral history'},
 'label_col': 'Subject',
 'label_row': [0, 1]}

In [5]:
# sample = {
#     "question":"What is the sum of week(s) with an attendance of 30,751?",
#     "table":{
#         "cols":["Week","Date","Opponent","Result","Attendance"],
#         "rows":[
#             ["1","August 6, 1973","San Francisco 49ers","L 27\u201316","65,707"],
#             ["2","August 11, 1973","at Los Angeles Rams","T 21\u201321","54,385"],
#             ["3","August 19, 1973","vs. Cincinnati Bengals at Columbus, Ohio","W 24\u20136","73,421"],
#             ["4","August 25, 1973","vs. Atlanta Falcons at Knoxville","W 20\u201317","40,831"],
#             ["5","September 1, 1973","Detroit Lions","L 16\u201313","64,088"],
#             ["6","September 8, 1973","vs. New York Giants at Akron","L 21\u201310","30,751"]
#         ],
#         "types":["real","text","text","text","real"],
#         "caption":"Exhibition schedule"
#     }
# }
# sample
# prompt_gen = PromptGenerator()
# prompt_act = prompt_gen.create_prompt(sample)
# prompt_shot = prompt_gen.create_prompt_shot(sample_shot)
# prompt_shot2 = prompt_gen.create_prompt_shot(sample_shot2)
# prompt = prompt_shot + prompt_shot2 + prompt_act
# print(prompt)

# gen_text = model.generate(
#     USER_CHAT_TEMPLATE.format(prompt=prompt),
#     device=device,
#     output_len=prompt_gen.output_len,
#     temperature=prompt_gen.temperature,
#     top_p=prompt_gen.top_p,
# )

# print(gen_text)

# answer = prompt_gen.post_process(gen_text)
# answer

In [6]:
data = []
with open('/kaggle/input/a2-val/A2_val.jsonl', 'r') as f:
    for line in f:
        data.append(json.loads(line))

# Sample 100 random samples
random.seed(42)
samples = random.sample(data, 20)



In [7]:
# Generate the answers
prompt_generator = PromptGenerator()
prompt_shot = prompt_generator.create_prompt_shot(sample_shot)
prompt_shot2 = prompt_generator.create_prompt_shot(sample_shot2)
prompt_shot3 = prompt_generator.create_prompt_shot(sample_shot3)
answers = []
num=1
maxrow=0
for sample in samples:
    print("Sample number is "+ str(num))
    print("Question is "+ sample['question'])
    ## print the number of rows and columns in the table
#     print("Number of rows in the table is "+ str(len(sample['table']['rows'])))
#     print("Number of columns in the table is "+ str(len(sample['table']['cols'])))
    if len(sample['table']['rows']) > 100:
        sample['table']['rows'] = sample['table']['rows'][:100]
    cur_sample = {'question': sample['question'], 'table': sample['table']}
    maxrow = max(maxrow,len(sample['table']['rows']))
    prompt_act = prompt_generator.create_prompt(cur_sample)
    prompt = prompt_shot + prompt_shot2 + prompt_shot3 + prompt_act
    answer = model.generate(
        USER_CHAT_TEMPLATE.format(prompt=prompt),
        device=device,
        output_len=prompt_generator.output_len,
        temperature=prompt_generator.temperature,
        top_p=prompt_generator.top_p,
    )
#     print("Gen text is "+ answer)
    answer = prompt_generator.post_process(answer)
    print("Answer got is " + answer)
    print("Column should be "+ sample['label_col'][0])
    print("Row should be "+ sample['label_row'])
    print("===================================================================")
#     answers.append(answer)
    num+=1
    del cur_sample
    del prompt_act
    del prompt
    del answer

# Now find the accuracy of the model
# correct = 0
# for sample, answer in zip(samples, answers):
#     if sample['label_col'][0] == answer:
#         correct += 1

# accuracy = correct / len(samples)
# print("accuracy is "+ accuracy)
# print("maxrow is "+ maxrow)
# print("correct ans is "+ correct)

TypeError: can only concatenate str (not "list") to str

In [None]:
print(correct)