In [1]:
# Setup the environment
!pip install -q -U immutabledict sentencepiece 
!git clone https://github.com/google/gemma_pytorch.git
!mkdir /kaggle/working/gemma/
!mv /kaggle/working/gemma_pytorch/gemma/* /kaggle/working/gemma/



Cloning into 'gemma_pytorch'...
remote: Enumerating objects: 133, done.[K
remote: Counting objects: 100% (65/65), done.[K
remote: Compressing objects: 100% (41/41), done.[K
remote: Total 133 (delta 40), reused 35 (delta 23), pack-reused 68[K
Receiving objects: 100% (133/133), 2.15 MiB | 20.01 MiB/s, done.
Resolving deltas: 100% (67/67), done.


In [2]:
import sys 
sys.path.append("/kaggle/working/gemma_pytorch/") 
from gemma.config import GemmaConfig, get_config_for_7b, get_config_for_2b
from gemma.model import GemmaForCausalLM
from gemma.tokenizer import Tokenizer
import contextlib
import os
import torch
import json
import random
import pandas as pd
import re

# Load the model
VARIANT = "7b-it-quant" 
MACHINE_TYPE = "cuda" 
weights_dir = '/kaggle/input/gemma/pytorch/7b-it-quant/2' 

@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
  """Sets the default torch dtype to the given dtype."""
  torch.set_default_dtype(dtype)
  yield
  torch.set_default_dtype(torch.float)

model_config = get_config_for_2b() if "2b" in VARIANT else get_config_for_7b()
model_config.tokenizer = os.path.join(weights_dir, "tokenizer.model")
model_config.quant = "quant" in VARIANT

device = torch.device(MACHINE_TYPE)
with _set_default_tensor_type(model_config.get_dtype()):
  model = GemmaForCausalLM(model_config)
  ckpt_path = os.path.join(weights_dir, f'gemma-{VARIANT}.ckpt')
  model.load_weights(ckpt_path)
  model = model.to(device).eval()



# Use the model

USER_CHAT_TEMPLATE = "<start_of_turn>user\n{prompt}<end_of_turn>\n"
MODEL_CHAT_TEMPLATE = "<start_of_turn>model\n{prompt}<end_of_turn>\n"

prompt = (
    USER_CHAT_TEMPLATE.format(
        prompt="What is a good place for travel in the US?"
    )
    + MODEL_CHAT_TEMPLATE.format(prompt="California.")
    + USER_CHAT_TEMPLATE.format(prompt="What can I do in California?")
    + "<start_of_turn>model\n"
)

model.generate(
    USER_CHAT_TEMPLATE.format(prompt=prompt),
    device=device,
    output_len=100,
)


  return self.fget.__get__(instance, owner)()


"There's a vast amount to see and do in California depending on your interests and style of travel. Here's a breakdown to help you decide:\n\n**For the beach lovers:**\n- Malibu, Santa Barbara, Monterey, Monterey Bay, Lake Tahoe\n\n**For the city explorers:**\n- San Francisco, Los Angeles, San Diego, San Jose\n\n**For the outdoors:**\n- Yosemite National Park, Joshua Tree National Park, Death Valley National Park, Lake Tahoe, Redwood National Park"

In [19]:
class PromptGenerator(object):
    def __init__(self):
        """
        Setup generation parameters for Gemma.
        """
        self.output_len = 25
        self.temperature = 0.0
        self.top_p = 1.0

    def create_prompt(self, sample):
        """
        Input sample is a dictionary consisting of following fields
        'table': A dict containing table data and meta-data same as in Assignment 2
        'question': A python string for the question on the table.
        
        The function must return the prompt as a python string.
        """
        # Example prompt
        prompt = """Read the following table and answer the related question based on the table.\n\n"""
        prompt += 'TABLE:\n'
        prompt += 'Column Name: '
        prompt += ','.join(
            [f'"{cc}"' for cc in sample['table']['cols']]
        ) + '\n'
        i=1
        for row in sample['table']['rows']:
            prompt += 'Row '+ str(i)+ ': '
            i+=1
            prompt += ','.join(
                [f'"{rr}"' for rr in row]
            ) + '\n'
        prompt += '\n'
        prompt += 'QUESTION: ' + sample['question'] +'\n'
        prompt += ' Give the column name and row number from the table to answer the question.' + '\n'

        return prompt
    
    def create_prompt_shot(self, sample):
        """
        Input sample is a dictionary consisting of following fields
        'table': A dict containing table data and meta-data same as in Assignment 2
        'question': A python string for the question on the table.
        
        The function must return the prompt as a python string.
        """
        # Example prompt
        prompt = """Read the following table and answer the related question based on the table.\n\n"""
        prompt += 'TABLE:\n'
        prompt += 'Column Name: '
        prompt += ','.join(
            [f'"{cc}"' for cc in sample['table']['cols']]
        ) + '\n'
        i=1
        for row in sample['table']['rows']:
            prompt += 'Row '+ str(i)+ ': '
            i+=1
            prompt += ','.join(
                [f'"{rr}"' for rr in row]
            ) + '\n'
        prompt += '\n'
        prompt += 'QUESTION: ' + sample['question'] + '\n' 
        prompt += ' Give the column name and row number from the table to answer the question.' + '\n'
        prompt += "The column name is '" + sample['label_col'] + "' and the row number is " + sample['label_row'] + '\n'
        prompt += '\n'

        return prompt
    
    def post_process(self, gen_text):
        gen_text = gen_text.split("\n")[0]
        pattern = r"The column name is \'(.+?)\' and the row number is (\d+)"

        match = re.search(pattern, gen_text)

        if match:
                column_name = match.group(1)
                row_number = int(match.group(2))
                column_name = column_name.replace("'", "")
                row_number-=1
                row_list = []
                row_list.append(row_number)
                return column_name, row_list
        else:
                return None, None

    

In [4]:
sample_shot = {
    "question":"When was the opponent Poland and the match type EC -qualifier?",
    "table":{
        "cols":["Date","Location","Opponenent","Result","Match type"],
        "rows":[
            ["29 March 2000","Debrecen","Poland","0-0 (draw)","friendly"],
            ["16 August 2000","Budapest","Austria","1-1 (draw)","friendly"],
            ["3 September 2000","Budapest","Italy","2-2 (draw)","WC -qualifier"],
            ["15 August 2001","Budapest","Germany","2-5 (defeat)","friendly"],
            ["1 September 2001","Tbilisi","Georgia","1-3 (defeat)","WC-qualifier"],
            ["5 September 2001","Budapest","Romania","0-2 (defeat)","WC-qualifier"],
            ["6 October 2001","Parma","Italy","0-0 (draw)","WC-qualifier"],
            ["14 November 2001","Budapest","Macedonia","5-0 (win)","friendly"],
            ["12 February 2002","Larnaca","Czech Rep.","0-2 (defeat)","friendly"],
            ["13 February 2002","Limassol","Switzerland","1-2 (defeat)","friendly"],
            ["8 May 2002","P\u00e9cs","Croatia","0-2 (defeat)","friendly"],
            ["21 August 2002","Budapest","Spain","1-1 (draw)","friendly"],
            ["7 September 2002","Reykjav\u00edk","Iceland","2-0 (win)","friendly"],
            ["12 October 2002","Stockholm","Sweden","1-1 (draw)","EC-qualifier"],
            ["16 October 2002","Budapest","San Marino","3-0 (win)","EC -qualifier"],
            ["20 November 2002","Budapest","Moldova","1-1 (draw)","friendly"],
            ["12 February 2003","Larnaca","Bulgaria","0-1 (defeat)","friendly"],
            ["29 March 2003","Chorz\u00f3w","Poland","0-0 (draw)","EC -qualifier"],
            ["2 April 2003","Budapest","Sweden","1-2 (defeat)","EC -qualifier"],
            ["30 April 2003","Budapest","Luxembourg","5-1 (win)","friendly"],
            ["19 February 2004","Limassol","Latvia","2-1 (win)","friendly"],
            ["21 February 2004","Limassol","Romania","0-3 (defeat)","friendly"],
            ["9 February 2011","Dubai","Azerbaijan","2-0 (win)","friendly"],
            ["29 March 2011","Amsterdam","Netherlands","3-5 (defeat)","EC -qualifier"],
            ["3 June 2011","Luxembourg","Luxembourg","1-0 (win)","friendly"]
        ],
        "types":["text","text","text","text","text"],
        "caption":"National team matches",
    },
    "label_col":"Date",
    "label_row":"17"
}
sample_shot

sample_shot2 = {
    "question":"What is the total score when 7 is the average ranking?",
    "table":{
        "cols":["Average Ranking","Competitive Finish","Couple","Number Of Dances","Total Score","Average"],
        "rows":[
            ["1","1","Bridie & Craig","15","509","35.9"],
            ["2","3","David & Karina","12","360","30.0"],
            ["3","4","Patti & Sandro","10","295","29.5"],
            ["4","2","Anh & Luda","15","421","27.0"],
            ["5","9","Corinne & Csaba","3","77","25.7"],
            ["6","5","Mark & Linda","8","204","25.5"],
            ["7","8","Elka & Michael","4","100","25.0"],
            ["8","6","James & Olya","7","169","24.1"],
            ["9","7","Jessica & Serghei","5","120","24.0"]
        ],
        "types":["real","real","text","real","real","text"],
        "caption":"Average Chart"
    },
    "label_col":"Total Score",
    "label_row":"6"
}
# sample_shot2

# sample_shot3 = {
#     "question":"What was the largest crowd size at a South Melbourne home game?",
#     "table":{
#         "cols":["Home team","Home team score","Away team","Away team score","Venue","Crowd","Date"],
#         "rows":[
#             ["Melbourne","15.15 (105)","St Kilda","5.13 (43)","MCG","11,000","7 June 1947"],
#             ["Hawthorn","10.15 (75)","North Melbourne","10.11 (71)","Glenferrie Oval","12,000","7 June 1947"],
#             ["Carlton","16.21 (117)","Geelong","7.10 (52)","Princes Park","16,000","7 June 1947"],
#             ["South Melbourne","15.9 (99)","Richmond","15.13 (103)","Lake Oval","24,000","7 June 1947"],
#             ["Footscray","12.11 (83)","Fitzroy","7.7 (49)","Western Oval","20,000","7 June 1947"],
#             ["Essendon","9.14 (68)","Collingwood","11.16 (82)","Windy Hill","22,000","7 June 1947"]
#         ],
#         "types":["text","text","text","text","text","real","text"],
#         "caption":"Round 8"
#     },
#     "label_col":"Crowd",
#     "label_row":"3"
# }


In [5]:
# prompt_generator = PromptGenerator()
# prompt_shot = prompt_generator.create_prompt_shot(sample_shot)

In [6]:
# print(prompt_shot)

In [7]:
# samp = {
#     "question":"Which Attendance has an Arena of arrowhead pond of anaheim, and a Loss of giguere (3\u20133)?",
#     "table":{
#         "cols":["Date","Opponent","Score","Loss","Attendance","Series","Arena"],
#         "rows":[
#             ["May 19","Oilers","3\u20131","Bryzgalov (6\u20132)","17,174","0\u20131","Arrowhead Pond of Anaheim"],
#             ["May 21","Oilers","3\u20131","Bryzgalov (6\u20133)","17,264","0\u20132","Arrowhead Pond of Anaheim"],
#             ["May 23","@ Oilers","5\u20134","Bryzgalov (6\u20134)","16,839","0\u20133","Rexall Place"],
#             ["May 25","@ Oilers","6\u20133","Roloson (11\u20135)","16,839","1\u20133","Rexall Place"],
#             ["May 27","Oilers","2\u20131","Giguere (3\u20133)","17,174","1\u20134","Arrowhead Pond of Anaheim"]],
#         "types":["text","text","text","text","real","text","text"],
#         "caption":"Postseason",
#     }    
# }

# samp = {
#     "question":"What was the name for the row with Date From of 2008-02-21?",
#     "table":{
#         "cols":["Date From","Date To","Position","Name","From"],
#         "rows":[["2007-08-08","2007-08-13","GK","Chris Weale","Bristol City"],
#                 ["2007-08-10","End of season","MF","Toumani Diagouraga","Watford"],
#                 ["2007-08-10","End of season","FW","Theo Robinson","Watford"],
#                 ["2007-10-30","End of season","DF","Robbie Threlfall","Liverpool"],
#                 ["2007-11-16","End of season","DF","Lee Collins","Wolverhampton Wanderers"],
#                 ["2008-01-28","2008-04-30","FW","Gary Hooper","Southend United"],
#                 ["2008-02-08","2008-02-10","FW","Sherjill MacDonald","West Bromwich Albion"],
#                 ["2008-02-21","2008-03-24","MF","Stephen Gleeson","Wolverhampton Wanderers"],
#                 ["2008-03-27","2008-04-27","MF","Sammy Igoe","Bristol Rovers"]],
#         "types":["text","text","text","text","text"],
#         "caption":"Loan in"
#     }
# }


In [8]:
# prompt_generator = PromptGenerator()
# prompt_shot = prompt_generator.create_prompt_shot(sample_shot)
# prompt_shot2 = prompt_generator.create_prompt_shot(sample_shot2)
# # prompt_shot3 = prompt_generator.create_prompt_shot(sample_shot3)
# prompt_act = prompt_generator.create_prompt(samp)
# prompt = prompt_shot + prompt_shot2 + prompt_act
# print(prompt)
# answer = model.generate(
#     USER_CHAT_TEMPLATE.format(prompt=prompt),
#     device=device,
#     output_len=prompt_generator.output_len,
#     temperature=prompt_generator.temperature,
#     top_p=prompt_generator.top_p,
# )


In [9]:
# print(answer)

In [10]:
# sample = {
#     "question":"What is the sum of week(s) with an attendance of 30,751?",
#     "table":{
#         "cols":["Week","Date","Opponent","Result","Attendance"],
#         "rows":[
#             ["1","August 6, 1973","San Francisco 49ers","L 27\u201316","65,707"],
#             ["2","August 11, 1973","at Los Angeles Rams","T 21\u201321","54,385"],
#             ["3","August 19, 1973","vs. Cincinnati Bengals at Columbus, Ohio","W 24\u20136","73,421"],
#             ["4","August 25, 1973","vs. Atlanta Falcons at Knoxville","W 20\u201317","40,831"],
#             ["5","September 1, 1973","Detroit Lions","L 16\u201313","64,088"],
#             ["6","September 8, 1973","vs. New York Giants at Akron","L 21\u201310","30,751"]
#         ],
#         "types":["real","text","text","text","real"],
#         "caption":"Exhibition schedule"
#     }
# }
# sample
# prompt_gen = PromptGenerator()
# prompt_act = prompt_gen.create_prompt(sample)
# prompt_shot = prompt_gen.create_prompt_shot(sample_shot)
# prompt_shot2 = prompt_gen.create_prompt_shot(sample_shot2)
# prompt = prompt_shot + prompt_shot2 + prompt_act
# print(prompt)

# gen_text = model.generate(
#     USER_CHAT_TEMPLATE.format(prompt=prompt),
#     device=device,
#     output_len=prompt_gen.output_len,
#     temperature=prompt_gen.temperature,
#     top_p=prompt_gen.top_p,
# )

# print(gen_text)

# answer = prompt_gen.post_process(gen_text)
# answer

In [11]:
data = []
with open('/kaggle/input/a2-val/A2_val.jsonl', 'r') as f:
    for line in f:
        data.append(json.loads(line))

# Sample 100 random samples
random.seed(42)
samples = random.sample(data, 100)



In [None]:
# Generate the answers
prompt_generator = PromptGenerator()
prompt_shot = prompt_generator.create_prompt_shot(sample_shot)
prompt_shot2 = prompt_generator.create_prompt_shot(sample_shot2)
# prompt_shot3 = prompt_generator.create_prompt_shot(sample_shot3)
col_answers = []
row_answers = []
num=1
maxrow=0
for sample in samples:
    print("Sample number is "+ str(num))
    print("Question is "+ sample['question'])
    if len(sample['table']['rows']) > 100:
        sample['table']['rows'] = sample['table']['rows'][:100]
    cur_sample = {'question': sample['question'], 'table': sample['table']}
    maxrow = max(maxrow,len(sample['table']['rows']))
    prompt_act = prompt_generator.create_prompt(cur_sample)
    prompt = prompt_shot + prompt_shot2 + prompt_act
    answer = model.generate(
        USER_CHAT_TEMPLATE.format(prompt=prompt),
        device=device,
        output_len=prompt_generator.output_len,
        temperature=prompt_generator.temperature,
        top_p=prompt_generator.top_p,
    )
    answer_pair = prompt_generator.post_process(answer)
    print(answer)
    print(answer_pair)
    print("Column should be "+ sample['label_col'][0])
    print("Row should be "+ ', '.join([str(rr) for rr in sample['label_row']]))
    print("===================================================================")
    col_answers.append(answer_pair[0])
    row_answers.append(answer_pair[1])
    num+=1
    del cur_sample
    del prompt_act
    del prompt
    del answer
    del answer_pair

correct_col = 0
correct_row = 0
for sample, answer in zip(samples, col_answers):
    if sample['label_col'][0] == answer:
        correct_col += 1
for sample, answer in zip(samples, row_answers):
    if sample['label_row'] == answer:
        correct_row += 1
col_acc = correct_col / len(samples)
row_acc = correct_row / len(samples)

In [22]:
print(col_acc)
print(row_acc)

0.86
0.67
