In [1]:
!git clone https://github.com/time-series-foundation-models/lag-llama/

Cloning into 'lag-llama'...
remote: Enumerating objects: 328, done.[K
remote: Counting objects: 100% (166/166), done.[K
remote: Compressing objects: 100% (79/79), done.[K
remote: Total 328 (delta 114), reused 108 (delta 85), pack-reused 162[K
Receiving objects: 100% (328/328), 234.56 KiB | 5.10 MiB/s, done.
Resolving deltas: 100% (155/155), done.


In [2]:
cd /content/lag-llama

/content/lag-llama


In [3]:
!pip install -r requirements.txt --quiet # this could take some time # ignore the errors displayed by colab

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m57.2/57.2 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m23.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.1/7.1 MB[0m [31m29.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.7/67.7 kB[0m [31m5.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m207.3/207.3 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m71.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m778.1/778.1 kB[0m [31m43.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m301.8/301.8 kB[0m [31m22.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [4]:
# download the pretrained model weights from HuggingFace 🤗

!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir /content/lag-llama

Downloading 'lag-llama.ckpt' to '/content/lag-llama/.huggingface/download/lag-llama.ckpt.b5a5c4b8a0cfe9b81bdac35ed5d88b5033cd119b5206c28e9cd67c4b45fb2c96.incomplete'
lag-llama.ckpt: 100% 29.5M/29.5M [00:00<00:00, 43.6MB/s]
Download complete. Moving file to /content/lag-llama/lag-llama.ckpt
/content/lag-llama/lag-llama.ckpt


In [5]:
# import the required packages and the lag llama estimator object

from itertools import islice

from matplotlib import pyplot as plt
import matplotlib.dates as mdates
from tqdm.autonotebook import tqdm

import torch
from gluonts.evaluation import make_evaluation_predictions, Evaluator
from gluonts.dataset.repository.datasets import get_dataset

from gluonts.dataset.pandas import PandasDataset
import pandas as pd

from lag_llama.gluon.estimator import LagLlamaEstimator

  from tqdm.autonotebook import tqdm


# Lag-Llama prediction function

We create a function for Lag-Llama inference that we can reuse for all different types of dataset below. This function returns the predictions for the given prediction horizon. The forecast will be of shape (num_samples, prediction_length), where num_samples is the number of samples sampled from the predicted probability distribution for each timestep.

In [6]:
def get_lag_llama_predictions(dataset, prediction_length, device, context_length=32, use_rope_scaling=False, num_samples=100, nonnegative_pred_samples=True):
    ckpt = torch.load("lag-llama.ckpt", map_location=device) # Uses GPU since in this Colab we use a GPU.
    estimator_args = ckpt["hyper_parameters"]["model_kwargs"]

    rope_scaling_arguments = {
        "type": "linear",
        "factor": max(1.0, (context_length + prediction_length) / estimator_args["context_length"]),
    }

    estimator = LagLlamaEstimator(
        ckpt_path="lag-llama.ckpt",
        prediction_length=prediction_length,
        context_length=context_length, # Lag-Llama was trained with a context length of 32, but can work with any context length

        # estimator args
        input_size=estimator_args["input_size"],
        n_layer=estimator_args["n_layer"],
        n_embd_per_head=estimator_args["n_embd_per_head"],
        n_head=estimator_args["n_head"],
        scaling=estimator_args["scaling"],
        time_feat=estimator_args["time_feat"],
        rope_scaling=rope_scaling_arguments if use_rope_scaling else None,

        nonnegative_pred_samples=True,

        batch_size=8,
        num_parallel_samples=100,
        device=device,
    )

    lightning_module = estimator.create_lightning_module()
    transformation = estimator.create_transformation()
    predictor = estimator.create_predictor(transformation, lightning_module)

    forecast_it, ts_it = make_evaluation_predictions(
        dataset=dataset,
        predictor=predictor,
        num_samples=num_samples
    )
    forecasts = list(forecast_it)
    tss = list(ts_it)

    return forecasts, tss

In [7]:
# Load data

from google.colab import files
uploaded = files.upload()

Saving stock_tsb_zero_remove.csv to stock_tsb_zero_remove.csv


In [8]:
# read data

stock_df = pd.read_csv("stock_tsb_zero_remove.csv")

# make the yearmonth column as date format

stock_df['yearmonth'] = pd.to_datetime(stock_df['yearmonth']) + pd.offsets.MonthEnd(0)

stock_df = stock_df.sort_values(by = 'yearmonth')

stock_df = stock_df.set_index('yearmonth')

stock_df.head()

  stock_df['yearmonth'] = pd.to_datetime(stock_df['yearmonth']) + pd.offsets.MonthEnd(0)


Unnamed: 0_level_0,Unnamed: 0,region,district,site_code,product_type,product_code,stock_distributed,stock_ordered,unique_id,id,zero_per,site_type
yearmonth,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1
2016-01-31,1,ABIDJAN 1-GRANDS PONTS,ADJAME-PLATEAU-ATTECOUBE,C1014,Implant,AS27137,0,0,2016 JanC1014AS27137,C1014AS27137,0.770833,Health Center
2016-01-31,29473,N'ZI-IFOU-MORONOU,BONGOUANOU,C4017,Oral,AS27000,0,0,2016 JanC4017AS27000,C4017AS27000,0.541667,Hospital
2016-01-31,6481,ABIDJAN 2,ABOBO-EST,C1080,Injectable,AS27133,0,0,2016 JanC1080AS27133,C1080AS27133,0.729167,Hospital
2016-01-31,29521,N'ZI-IFOU-MORONOU,BONGOUANOU,C4017,Oral,AS27132,0,0,2016 JanC4017AS27132,C4017AS27132,0.708333,Hospital
2016-01-31,6433,ABIDJAN 2,ABOBO-EST,C1080,Implant,AS27138,0,0,2016 JanC1080AS27138,C1080AS27138,0.791667,Hospital


In [9]:
# create test and train dataset

series = stock_df[['id', 'stock_distributed']]

series_train = series[series.index <= '2019-9-30']
series_test = series[series.index > '2019-9-30']

# Create the Pandas

series = PandasDataset.from_long_dataframe(series, target="stock_distributed", item_id="id", freq = "1M")
#series_train = PandasDataset.from_long_dataframe(series_train, target="stock_distributed", item_id="id", freq = "1M")
#series_test = PandasDataset.from_long_dataframe(series_test, target="stock_distributed", item_id="id", freq = "1M")

In [25]:
# define parameters

prediction_length = 3
context_length = 32
num_samples = 100
device = torch.device("cuda")

In [66]:
# Get predictions with zero-shot inference

forecasts, tss = get_lag_llama_predictions(series, prediction_length, device, num_samples)

In [67]:
# Create an empty DataFrame to store predicted values

prob_pred = pd.DataFrame(columns=['yearmonth', 'id', 'model'] + [f'X{i}' for i in range(num_samples)])
prob_pred

Unnamed: 0,yearmonth,id,model,X0,X1,X2,X3,X4,X5,X6,...,X90,X91,X92,X93,X94,X95,X96,X97,X98,X99


In [68]:
# extract values

n = len(forecasts)

for m in range(n):

  sample_fc = forecasts[m]
  # Extract data
  item_id = sample_fc.item_id
  samples = sample_fc.samples

  # Create DataFrame
  num_columns = samples.shape[0]
  df = pd.DataFrame(samples.T, columns=[f'X{i}' for i in range(num_columns)])
  # Add item_id and yearmonth columns
  df['id'] = item_id
  df['yearmonth'] = [sample_fc.start_date + i for i in range(samples.shape[1])]
  df['model'] = 'lag_llama'

  # Reorder columns
  df = df[['id', 'yearmonth', 'model'] + [f'X{i}' for i in range(num_columns)]]

  # Creating final df
  prob_pred = pd.concat([prob_pred, df])


  prob_pred = pd.concat([prob_pred, df])


In [69]:
prob_pred

Unnamed: 0,yearmonth,id,model,X0,X1,X2,X3,X4,X5,X6,...,X90,X91,X92,X93,X94,X95,X96,X97,X98,X99
0,2019-10,C1004AS27000,lag_llama,2.173428e+01,4.046378e+01,1.477195e+01,0.000000e+00,1.277177e+01,5.177918e+01,3.292208e+01,...,1.993336e+01,1.350238e+01,2.351783e+01,6.076965e+01,1.426944e+01,2.767580e+01,29.182123,7.977084e+01,0.000000e+00,2.978211e+01
1,2019-11,C1004AS27000,lag_llama,1.108250e+01,4.350146e+01,2.786629e+01,1.101834e+00,1.388194e+01,1.117645e+02,3.341034e+01,...,3.285697e+01,2.188196e+01,1.087734e+01,9.335412e+00,1.127616e+01,3.931279e+01,31.682423,9.479538e+01,0.000000e+00,1.152457e+01
2,2019-12,C1004AS27000,lag_llama,1.389767e+01,1.486355e+01,4.402084e+01,5.530740e+00,0.000000e+00,5.546110e+01,3.965796e+01,...,5.988293e+01,2.008001e+01,8.966405e+00,2.793580e+00,3.480209e+02,4.732578e+01,63.575516,5.185420e+01,0.000000e+00,1.753835e+01
0,2019-10,C1004AS27132,lag_llama,0.000000e+00,1.057741e-01,6.646120e-02,0.000000e+00,3.226418e-01,0.000000e+00,2.067505e-01,...,0.000000e+00,0.000000e+00,0.000000e+00,6.929759e-02,0.000000e+00,4.601087e-02,0.000000,0.000000e+00,0.000000e+00,0.000000e+00
1,2019-11,C1004AS27132,lag_llama,7.104607e-02,1.052780e-01,9.073451e-02,9.266825e-02,2.147137e-01,1.578619e-02,1.907760e-01,...,1.631934e-01,4.905418e-02,1.920495e-01,9.299868e-02,5.645051e-02,0.000000e+00,0.067114,1.686305e-01,6.745011e-02,0.000000e+00
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1,2019-11,C5066AS27137,lag_llama,0.000000e+00,2.530593e-11,5.027360e-12,0.000000e+00,0.000000e+00,5.131550e-11,0.000000e+00,...,0.000000e+00,0.000000e+00,0.000000e+00,2.595462e-11,5.684781e-11,1.459599e-13,0.000000,1.087402e-10,1.199558e-10,1.213459e-11
2,2019-12,C5066AS27137,lag_llama,0.000000e+00,0.000000e+00,2.872233e-10,1.208314e-10,2.567598e-10,0.000000e+00,3.859910e-11,...,0.000000e+00,3.952998e-11,3.440024e-10,7.503220e-11,0.000000e+00,0.000000e+00,0.000000,0.000000e+00,9.333727e-12,6.643133e-12
0,2019-10,C5066AS27138,lag_llama,6.378444e-11,2.248599e-10,0.000000e+00,2.759202e-11,0.000000e+00,3.457211e-11,0.000000e+00,...,4.918544e-11,2.438416e-10,6.910883e-11,0.000000e+00,0.000000e+00,8.266671e-11,0.000000,2.351117e-10,0.000000e+00,8.163956e-11
1,2019-11,C5066AS27138,lag_llama,1.134768e-10,0.000000e+00,3.899087e-10,0.000000e+00,1.038429e-12,4.748914e-11,1.205735e-10,...,0.000000e+00,2.098022e-11,0.000000e+00,4.914228e-11,5.582003e-11,1.402541e-10,0.000000,0.000000e+00,0.000000e+00,1.181900e-11


In [None]:
evaluator = Evaluator()
agg_metrics, ts_metrics = evaluator(iter(tss), iter(forecasts))
print("CRPS:", agg_metrics['mean_wQuantileLoss'])

Running evaluation: 755it [00:00, 21962.75it/s]


CRPS: 0.8225779293172094


  return arr.astype(dtype, copy=True)
  return arr.astype(dtype, copy=True)


In [70]:
# save file

prob_pred.to_csv('lag_llama_pred.csv', index = False)
files.download("lag_llama_pred.csv")

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>