In [None]:
%load_ext jupyter_black

In [None]:

from fite.main import Engine, HyperParameterStrategy

In [None]:
# Look up the "TAF" entry on the imported Engine object; the result is used
# by later cells as `pipeline`.
pipeline = Engine["TAF"]

In [None]:
# Feed one fixed TAF header through every member of HyperParameterStrategy
# and print the first element returned by generate_forecast for each one.
taf_header = "TAF KDAA 282100Z 2821/3003 32010G15KT 3200"

for strategy in HyperParameterStrategy:
    # Only the first item of the returned iterable is used; the rest is dropped.
    forecast_lines, *_ = pipeline.generate_forecast(taf_header, strategy=strategy)
    results = "\n ".join(forecast_lines)
    print(f"""{strategy.name=} {strategy.value=}\n{results}\n""")



In [62]:
from typing import TypedDict, Any
from pathlib import Path
import dataclasses
import pandas as pd
from datasets.arrow_dataset import Dataset as ArrowDataset
from datasets.dataset_dict import DatasetDict  
from typing import TypeVar, Generic, Literal, Iterable

# Generic type variables (an invariant one and a covariant one). Nothing in
# the visible cells refers to them.
_KT = TypeVar("_KT")
_VT_co = TypeVar("_VT_co", covariant=True)

class JSONLine(TypedDict):
    """Typed mapping for one record: ``metadata``, ``prompt`` and ``completion``.

    Judging by the name, each instance presumably corresponds to a single
    line of a JSON Lines file.
    """

    # Left untyped; the record may carry any metadata value.
    metadata: Any
    prompt: str
    completion: str


class DatasetDictType(TypedDict):
    """Typed mapping with one list of ``JSONLine`` records per split name."""

    train: list[JSONLine]
    validation: list[JSONLine]
    test: list[JSONLine]




class Dataset(ArrowDataset):
    """``ArrowDataset`` subclass with a split-name typed ``__getitem__`` and a
    JSON loader that immediately splits the data into train and test parts."""

    def __getitem__(self, __key: Literal["train", "test", "validate"]) -> ArrowDataset:
        # Pure delegation to the parent; this override only narrows the
        # annotations seen by a type checker.
        item = super().__getitem__(__key)  # type: ignore
        return item

    @classmethod
    def from_json(cls, path: str, split: float = 0.2, shuffle: bool = True) -> "DatasetDict":
        """Load ``path`` with ``ArrowDataset.from_json`` and split the result.

        ``split`` is forwarded as ``test_size`` and ``shuffle`` as ``shuffle``
        to ``train_test_split``. The loader is always called on
        ``ArrowDataset`` itself; ``cls`` is not used.
        """
        loaded = ArrowDataset.from_json(path)
        return loaded.train_test_split(test_size=split, shuffle=shuffle)  # type: ignore

class TAFDataset(DatasetDict):
    """``DatasetDict`` subclass with a helper that extracts temperature fields."""

    def forward_metadata(self, s: pd.Series):
        """Return the ``max_temp`` and ``min_temp`` regex groups for each string.

        The pattern is anchored at the end of the string and expects a
        ``T``/``TX`` group followed by a ``T``/``TN`` group. Rows that do not
        match come back as missing values, as ``Series.str.extract`` does.
        """
        temperature_tail = r"\sTX?(?P<max_temp>M?\d{2})\/\d{4}Z\sTN?(?P<min_temp>M?\d{2})\/\d{4}Z$"
        return s.str.extract(temperature_tail)

from src.fite.util import SpecialTokens
# Call the from_json that TAFDataset inherits from DatasetDict on the
# training-data file; the return value is not bound to a name.
TAFDataset.from_json("store/gpt2-taf-base1/training-data.jsonl")
import toml
@dataclasses.dataclass
class RawTextFileHandler:
    """Expand a raw text file into prompt/completion/metadata rows.

    On construction the file at ``path`` is read and split into records with
    the ``split_pattern`` regex. Each record is expanded into one row per word
    (see ``_generate_jsonl``), and the de-duplicated result is kept on
    ``self._frame`` with the columns ``prompt``, ``completion`` and
    ``metadata``.

    Subclasses must override ``metadata_handle``; the base implementation
    raises ``NotImplementedError``, so the base class cannot be instantiated.

    Attributes:
        path: File to read.
        split_pattern: Regular expression separating records in the file.
        split: Passed as ``test_size`` to ``train_test_split`` in ``to_dataset``.
        shuffle: Passed as ``shuffle`` to ``train_test_split`` in ``to_dataset``.
    """

    path: Path
    split_pattern: str = r"\n+###+\n+"
    split: float = 0.2
    shuffle: bool = True

    def metadata_handle(self, _: pd.Series) -> pd.DataFrame:
        """Return per-record metadata columns for the series of record texts.

        The returned frame is joined onto the records by index, so it must be
        indexed like the series it receives.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    def __post_init__(self):
        """Read and split the file, then build ``self._frame``."""
        import re

        sep = re.compile(self.split_pattern)
        with self.path.open("r") as f:
            s = pd.Series(sep.split(f.read()), name="text").str.strip()

        records = s.to_frame().join(self.metadata_handle(s))
        # The column labels below are positional, so they must stay in the
        # same order as the tuples yielded by _generate_jsonl.
        df = pd.DataFrame(
            records.pipe(self._generate_jsonl),
            columns=["prompt", "completion", "metadata"],
        ).drop_duplicates(ignore_index=True)

        self._frame = df

    @staticmethod
    def _generate_jsonl(df: pd.DataFrame) -> Iterable[tuple[str, str, str]]:
        """Yield ``(prompt, completion, metadata)`` tuples, one per word per row.

        The ``text`` column of each row is split on whitespace. Every other
        column becomes a ``key = value`` line in the metadata string, with
        underscores in the column name replaced by hyphens. For each word the
        prompt grows by that word plus a trailing space, and the completion is
        the remaining words followed by the end-of-sequence token.

        The tuples are yielded in the column order used by ``__post_init__``.
        Previously they were yielded metadata-first, which stored every field
        under the wrong column label.
        """
        # Work on a copy so the caller's frame is not mutated by the split
        # and the column rename.
        frame = df.assign(text=df["text"].str.strip().str.split())
        frame.columns = frame.columns.str.replace("_", "-")

        for _, row in frame.iterrows():
            words = row.pop("text")
            # Whatever is left in the row after popping the text is metadata.
            metadata = f"{SpecialTokens.metadata}\n" + "\n".join(
                f"{k} = {v}" for k, v in row.items()
            )

            prompt = SpecialTokens.bos_token
            for i, word in enumerate(words):
                prompt += f"{word} "
                # NOTE(review): the completion starts at the word that was just
                # appended to the prompt (words[i:]), so that word appears in
                # both. Confirm the overlap is intended rather than words[i + 1:].
                completion = " ".join(words[i:]) + SpecialTokens.eos_token
                yield prompt, completion, metadata

    def to_dataset(self) -> DatasetDict:
        """Convert ``self._frame`` to an Arrow dataset split into train and test."""
        return ArrowDataset.from_pandas(self._frame).train_test_split(
            test_size=self.split, shuffle=self.shuffle
        )

class TAFTextFile(RawTextFileHandler):
    """``RawTextFileHandler`` whose metadata is a pair of temperature fields."""

    def metadata_handle(self, s: pd.Series) -> pd.DataFrame:
        """Extract ``maximum_temperature`` and ``minimum_temperature`` per record.

        The pattern is anchored at the end of each string; records that do not
        match produce missing values in both columns.
        """
        temperature_tail = r"\sTX?(?P<maximum_temperature>M?\d{2})\/\d{4}Z\sTN?(?P<minimum_temperature>M?\d{2})\/\d{4}Z$"
        return s.str.extract(temperature_tail)




# Build the train/test split from the raw training text and show its summary.
x = TAFTextFile(Path("store/gpt2-taf-base1/training-data.txt")).to_dataset()
print(x)


Using custom data configuration default-7bc700dd09388bc4
Found cached dataset json (/home/leaver2000/.cache/huggingface/datasets/json/default-7bc700dd09388bc4/0.0.0)


DatasetDict({
    train: Dataset({
        features: ['prompt', 'completion', 'metadata'],
        num_rows: 19461
    })
    test: Dataset({
        features: ['prompt', 'completion', 'metadata'],
        num_rows: 4866
    })
})
