In [1]:
%load_ext jupyter_black

In [2]:
from src.fite.api import PipelineEngine

# Build the pipeline engine from the project's pyproject.toml; the model names it
# can serve are shown by engine.list_models() further down.
# NOTE(review): the path is relative, so it is presumably resolved against the
# kernel's working directory (repository root) — confirm before moving the notebook.
engine = PipelineEngine.load_from_pyproject(
    "pyproject.toml",
)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
base_line = """\
TAF [TXM01 TNM07] KBLV 071600Z 0716/0822 VRB06KT 8000 BR SCT020 OVC035 QNH3025INS
BECMG 0802/0803 VRB06KT 4800 BR SCT009 OVC020 QNH3026INS
BECMG 0808/0809 VRB06KT 8000 -RA OVC009 QNH3030INS TX13/0721Z TN07/0813Z"""

# Number of space-separated words in the fixed TAF header
# ("TAF [TXM01 TNM07] KBLV 071600Z 0716/0822") that every prompt starts with.
HEADER_WORDS = 6

# Split on a single space (not str.split()) so each "\n" line break stays glued to
# its neighbouring words (e.g. "QNH3025INS\nBECMG"); re-joining with " " then
# reproduces base_line exactly.
taf_words = base_line.split(" ")

# Incremental prompts: the header plus one more word each time, ending with the
# complete forecast. Deriving them from base_line keeps a single copy of the TAF,
# so the prompts can never drift from the reference text encoded later.
prompts: list[str] = [
    " ".join(taf_words[:end]) for end in range(HEADER_WORDS + 1, len(taf_words) + 1)
]
assert prompts[-1] == base_line, "last prompt must be the full reference forecast"


# Model names the engine can build pipelines for (used with get_pipeline below).
available_models = engine.list_models()
available_models

['gpt2-taf-base1', 'gpt2-taf-base1.dev1']

In [4]:
import numpy as np
import pandas as pd

base_pipeline = engine.get_pipeline("gpt2-taf-base1")
dev_pipeline = engine.get_pipeline("gpt2-taf-base1.dev1")

# Both models must tokenize identically; otherwise the token-id statistics
# compared between them further down would not be comparable.
base_ids = base_pipeline.tokenizer.encode(base_line)
dev_ids = dev_pipeline.tokenizer.encode(base_line)
assert base_ids == dev_ids

# Token ids of the reference forecast; make_data() subtracts its sum for "diff".
actual_forecast = np.array(base_ids)
actual_forecast  # type: ignore

array([50257,    58, 29551,    44,   486, 29025,    44,  2998,    60,
         509,  9148,    53,  8753, 36150,    57,  8753,  1433,    14,
        2919,  1828,  6453,    33,  3312, 42176, 38055, 11177,   311,
        4177, 33618,   440, 15922, 44215,  1195, 33863,  1270,  1495,
       20913,   198,    33,  2943, 20474,   657, 30863,    14, 33057,
          18,  6453,    33,  3312, 42176,  4764,   405, 11177,   311,
        4177, 28694,   440, 15922, 33618,  1195, 33863,  1270,  2075,
       20913,   198,    33,  2943, 20474,   657, 28362,    14, 33057,
          24,  6453,    33,  3312, 42176, 38055,   532,  3861,   440,
       15922, 28694,  1195, 33863,  1270,  1270, 20913, 15326,  1485,
          14,  2998,  2481,    57, 29025,  2998,    14,  2919,  1485,
          57])

In [5]:
results = base_pipeline.generate(prompts, strategy="GREEDY")

# Echo each prompt, then "..." followed by only the text the model appended.
# NOTE(review): slicing by len(prompt) assumes the joined output starts with the
# prompt verbatim — confirm the pipeline never normalises the prompt text.
for prompt, output_lines in zip(prompts, results):
    completion = "\n".join(output_lines)[len(prompt) :]
    print(f"{prompt}...{completion}", "\n")

TAF [TXM01 TNM07] KBLV 071600Z 0716/0822 VRB06KT... 9999 FEW200 QNH3030INS
BECMG 0721/0722 VRB06KT 9999 FEW200 QNH3030INS
BECMG 0814/0815 VRB06KT 9999 FEW200 QNH3030INS TXM01/0721Z TNM07/0812Z 

TAF [TXM01 TNM07] KBLV 071600Z 0716/0822 VRB06KT 8000... -SN OVC010 620108 QNH3003INS
BECMG 0722/0723 VRB06KT 9999 NSW BKN007 OVC015 620079 QNH3006INS
BECMG 0814/0815 VRB06KT 9999 BKN015 620159 QNH3017INS TXM01/0700Z TNM07/0812Z 

TAF [TXM01 TNM07] KBLV 071600Z 0716/0822 VRB06KT 8000 BR... SCT015 BKN020 QNH3021INS
BECMG 0721/0722 VRB06KT 9999 NSW FEW020 BKN040 QNH3021INS
BECMG 0811/0812 VRB06KT 9999 FEW200 QNH3025INS TXM01/0721Z TNM07/0812Z 

TAF [TXM01 TNM07] KBLV 071600Z 0716/0822 VRB06KT 8000 BR SCT020... BKN030 620303 QNH2998INS
BECMG 0721/0722 VRB06KT 9999 NSW FEW030 SCT250 QNH2999INS
BECMG 0814/0815 VRB06KT 9999 FEW200 QNH3000INS TXM01/0721Z TNM07/0812Z 

TAF [TXM01 TNM07] KBLV 071600Z 0716/0822 VRB06KT 8000 BR SCT020 OVC035... 620208 QNH2993INS
BECMG 0721/0722 VRB06KT 9999 NSW SCT030 BKN

In [6]:
def make_data(arr, reference=None):
    """Summarise a sequence of token ids as a dict of scalar statistics.

    Parameters
    ----------
    arr : array-like of int
        Token ids, e.g. the ``input_ids`` of a tokenised model completion.
    reference : array-like of int, optional
        Token ids to compare against in the ``"diff"`` entry. Defaults to the
        notebook-level ``actual_forecast`` (the encoded reference forecast) so
        existing calls keep working; pass it explicitly to avoid hidden state.

    Returns
    -------
    dict
        ``mean``, ``std``, ``min``, ``max``, ``median``, ``var`` and ``sum`` of
        ``arr`` (population std/var, numpy's default ``ddof=0``), plus
        ``diff = sum(arr) - sum(reference)``.

    Raises
    ------
    ValueError
        If ``arr`` is empty (min/max/median are undefined).

    NOTE(review): a sum of token ids is a crude proxy — very different texts can
    share the same sum, so ``diff`` is only a rough signal.
    """
    if reference is None:
        # Fall back to the global created in the tokenizer-check cell.
        reference = actual_forecast
    values = np.asarray(arr)
    if values.size == 0:
        raise ValueError("make_data() needs at least one token id")
    total = np.sum(values)
    return {
        "mean": np.mean(values),
        "std": np.std(values),
        "min": np.min(values),
        "max": np.max(values),
        "median": np.median(values),
        "var": np.var(values),
        "sum": total,
        "diff": total - np.sum(reference),
    }


def make_results(pipeline, prompt_list=None, strategy="GREEDY"):
    """Complete prompts with ``pipeline`` and tabulate token-id statistics.

    Parameters
    ----------
    pipeline
        A pipeline from ``engine.get_pipeline(...)``; it must expose
        ``generate(prompts, strategy=...)`` and a callable ``tokenizer`` whose
        result has an ``"input_ids"`` entry.
    prompt_list : list[str], optional
        Prompts to complete. Defaults to the notebook-level ``prompts`` so
        existing calls keep working; pass it explicitly to avoid hidden state.
    strategy : str, default "GREEDY"
        Decoding strategy forwarded to ``pipeline.generate``.

    Returns
    -------
    pandas.DataFrame
        One row per generated result, with the columns produced by ``make_data``.
    """
    if prompt_list is None:
        prompt_list = prompts
    # Each result is a list of lines; join them back into one text and re-tokenise.
    # NOTE(review): re-encoding decoded text may not reproduce the exact ids the
    # model emitted — confirm if that matters for this comparison.
    rows = [
        make_data(pipeline.tokenizer("\n".join(result))["input_ids"])
        for result in pipeline.generate(prompt_list, strategy=strategy)
    ]
    return pd.DataFrame(rows)


# Token-id statistics for every prompt: one frame per model, base first.
base_df, dev_df = (
    make_results(pipeline) for pipeline in (base_pipeline, dev_pipeline)
)
base_df

Unnamed: 0,mean,std,min,max,median,var,sum,diff
0,7751.042553,12015.557835,14,50257,2167.0,144373600.0,728598,-353832
1,9154.601852,12554.136162,14,50257,2919.0,157606300.0,988697,-93733
2,9575.217822,13239.365383,14,50257,2919.0,175280800.0,967097,-115333
3,10060.647059,13648.086041,14,50257,2943.0,186270300.0,1026186,-56244
4,10743.834646,14214.383723,14,50257,2943.0,202048700.0,1364467,282037
5,9940.267327,13775.164698,14,50257,2919.0,189755200.0,1003967,-78463
6,12135.621622,14762.537679,14,50257,2998.0,217932500.0,1347054,264624
7,12135.621622,14762.537679,14,50257,2998.0,217932500.0,1347054,264624
8,11134.447619,14593.067942,14,50257,2919.0,212957600.0,1169117,86687
9,11134.447619,14593.067942,14,50257,2919.0,212957600.0,1169117,86687


In [7]:
# Same statistics for the dev model; built from the same prompts as base_df.
dev_df

Unnamed: 0,mean,std,min,max,median,var,sum,diff
0,9292.863014,13380.785223,14,50257,2919.0,179045400.0,678379,-404051
1,9292.863014,13380.785223,14,50257,2919.0,179045400.0,678379,-404051
2,9980.398058,13334.711225,14,50257,2943.0,177814500.0,1027981,-54449
3,10246.402597,14057.510047,14,50257,2998.0,197613600.0,788973,-293457
4,10611.436893,14416.658266,14,50257,2943.0,207840000.0,1092978,10548
5,10451.144231,14205.091277,14,50257,2931.0,201784600.0,1086919,4489
6,11548.06383,14742.100047,14,50257,2970.5,217329500.0,1085518,3088
7,11548.06383,14742.100047,14,50257,2970.5,217329500.0,1085518,3088
8,10327.107843,14201.302197,14,50257,2919.0,201677000.0,1053365,-29065
9,10327.107843,14201.302197,14,50257,2919.0,201677000.0,1053365,-29065


In [8]:
# Element-wise difference, base minus dev. The "diff" column equals the "sum"
# column here, because both frames subtract the same reference sum.
base_df.sub(dev_df)

Unnamed: 0,mean,std,min,max,median,var,sum,diff
0,-1541.820461,-1365.227388,0,0,-752.0,-34671780.0,50219,50219
1,-138.261162,-826.649061,0,0,0.0,-21439080.0,310318,310318
2,-405.180236,-95.345842,0,0,-24.0,-2533728.0,-60884,-60884
3,-185.755539,-409.424006,0,0,-55.0,-11343340.0,237213,237213
4,132.397752,-202.274543,0,0,0.0,-5791331.0,271489,271489
5,-510.876904,-429.926579,0,0,-12.0,-12029460.0,-82952,-82952
6,587.557792,20.437632,0,0,27.5,603004.9,261536,261536
7,587.557792,20.437632,0,0,27.5,603004.9,261536,261536
8,807.339776,391.765745,0,0,0.0,11280650.0,115752,115752
9,807.339776,391.765745,0,0,0.0,11280650.0,115752,115752
