# Mystery box - 4 - Extract box value from images
### Dennis Bakhuis - 10th November 2022
### https://linkedin.com/in/dennisbakhuis/

In [None]:
import json
import datetime
from pathlib import Path
import os

import cv2

import matplotlib.pyplot as plt
import numpy as np
import torch

import transformers
from transformers import (
    DonutProcessor,
    VisionEncoderDecoderModel, 
    VisionEncoderDecoderConfig,
)

from tqdm.auto import tqdm
import pandas as pd

from tqdm_batch import batch_process

In [None]:
# Set TOKENIZERS_PARALLELISM explicitly for this process; per the original
# note this makes a warning disappear (presumably the Hugging Face tokenizers
# parallelism warning -- TODO confirm).
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

# Prepare model for inference

In [None]:
# Turn off the default transformers log handler, then load the processor and
# the model from the same pretrained checkpoint.
transformers.logging.disable_default_handler()

checkpoint = 'bakhuisdennis/donut-base-mysterybox'
processor = DonutProcessor.from_pretrained(checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(checkpoint)

In [None]:
def run_prediction(image, model=model, processor=processor):
    """Run the model on a single image and return the parsed prediction.

    Args:
        image: Image passed unchanged to ``processor`` for preprocessing.
        model: Model whose ``generate`` method produces the output tokens.
        processor: Processor used for preprocessing, tokenizing the prompt,
            decoding the generated tokens and converting them to JSON.

    Returns:
        The result of ``processor.token2json`` applied to the first decoded
        sequence.
    """
    # NOTE: a ``prep_image(image)`` preprocessing step used to run here but
    # is disabled.
    tokenizer = processor.tokenizer

    encoded_image = processor(image, random_padding=False, return_tensors="pt")

    # Generation is seeded with the bare "<s>" prompt.
    prompt_ids = tokenizer(
        "<s>",
        add_special_tokens=False,
        return_tensors="pt",
    ).input_ids

    generation_kwargs = dict(
        decoder_input_ids=prompt_ids,
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        # the unknown token is passed as the only "bad word"
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    generated = model.generate(encoded_image.pixel_values, **generation_kwargs)

    # Only one image goes in, so only the first decoded sequence is used.
    decoded = processor.batch_decode(generated.sequences)
    return processor.token2json(decoded[0])

In [None]:
# Collect all unlabeled PNG files in sorted order and read the integer that
# each file name (without extension) encodes.
image_dir = Path('../data/mystery_box/images_unlabeled')
files = sorted(image_dir.glob('*.png'))
times = [int(png.stem) for png in files]

In [None]:
from tqdm_batch import batch_process

def batch_process_function(file_name, model, processor):
    """Predict the box value for one image file.

    Args:
        file_name: Path of the image; its stem must be an integer.
        model: Model forwarded to ``run_prediction``.
        processor: Processor forwarded to ``run_prediction``.

    Returns:
        A tuple ``(ix, pred)`` where ``ix`` is the integer file stem and
        ``pred`` is whatever ``run_prediction`` returns for the image.

    Raises:
        OSError: If OpenCV cannot read the image.
        ValueError: If the file stem is not an integer.
    """
    image = cv2.imread(str(file_name))
    if image is None:
        # cv2.imread reports an unreadable/missing file by returning None
        # rather than raising, so fail here with a clear message.
        raise OSError(f'cv2.imread could not read image: {file_name}')
    # NOTE(review): cv2.imread returns channels in BGR order -- confirm this
    # matches the channel order the model was trained with.
    pred = run_prediction(image, model=model, processor=processor)
    ix = int(Path(file_name).stem)
    return (ix, pred)

# Run the per-file prediction over every unlabeled image using 4 workers.
# ``model`` and ``processor`` are handed over as keyword arguments
# (presumably forwarded to ``batch_process_function`` -- see tqdm_batch docs).
result = batch_process(
    files, batch_process_function,
    n_workers=4, sep_progress=True,
    model=model, processor=processor,
)

## Combine data with previously labeled

In [None]:
def try_get_int(data: dict):
    """Return ``data['distance']`` as an int, or ``None`` if unavailable.

    Lazy way to check for missing predictions: a missing ``'distance'`` key,
    a prediction that is not subscriptable by key, or a value that cannot be
    converted to an integer all yield ``None``.

    Args:
        data: Prediction, expected to be a mapping with a ``'distance'`` entry.

    Returns:
        The distance as ``int``, or ``None`` when it cannot be extracted.
    """
    try:
        return int(data['distance'])
    except (KeyError, TypeError, ValueError):
        # Only swallow what a bad prediction can cause; a bare ``except``
        # would also hide KeyboardInterrupt and SystemExit.
        return None

# Turn the (run_time, prediction) pairs into rows ordered by run_time;
# predictions without a usable distance end up as None.
ordered = sorted(result, key=lambda pair: pair[0])
proc_results = [
    dict(run_time=run_time, distance=try_get_int(prediction))
    for run_time, prediction in ordered
]

df = pd.DataFrame(proc_results)

Combine the inferred data with our labeled set.

In [None]:
# Load the labeled samples: one JSON object per line, each with a 'file_name'
# and a 'text' field that itself holds JSON containing 'distance'.
# JSON is read as UTF-8 explicitly instead of the platform default encoding.
with open('../data/mystery_box/metadata.jsonl', encoding='utf-8') as f:
    # Iterate the file lazily and skip blank lines, which json.loads rejects.
    train_data = [json.loads(line) for line in f if line.strip()]

# Reduce every record to the same two columns used for the inferred data.
train_data = [
    {
        'run_time': int(Path(x['file_name']).stem),
        'distance': int(json.loads(x['text'])['distance']),
    } for x in train_data
]

train_data = pd.DataFrame(train_data)

In [None]:
# Stack inferred and labeled rows, then order them by run_time and give the
# result a fresh 0..n-1 index.
combined = pd.concat([df, train_data], ignore_index=True)
df = combined.sort_values(by='run_time').reset_index(drop=True)

Make a proper time column.

In [None]:
# The file name carries a timestamp after a 4-character prefix
# (VID_<YYYYmmdd>_<HHMMSS>); parse it into ``start_time``.
image_file = '../data/raw/Session22/VID_20221019_171145.mp4'

video_stem = Path(image_file).stem
start_time = pd.to_datetime(video_stem[4:], format="%Y%m%d_%H%M%S")

In [None]:
# Absolute timestamp = start_time + run_time interpreted as seconds.
elapsed = pd.to_timedelta(df['run_time'], unit='s')
df = df.assign(time=start_time + elapsed)

## Repair the noisy data

In [None]:
# Quick look at the raw distance values before any filtering.
df['distance'].plot()

In [None]:
# Rolling mean over up to 3 samples, used as the reference for outliers.
window = df['distance'].rolling(window=3, min_periods=1).mean()

# Keep samples that lie closer than 5 to the rolling mean; everything else
# (including missing values) becomes NaN and is filled in by interpolation.
deviation = (df['distance'] - window).abs()
df['filtered'] = df['distance'].mask(~(deviation < 5)).interpolate()

# Plot raw and repaired series together for comparison.
plt.rcParams['font.size'] = 22
fig, ax = plt.subplots(figsize=(12, 6))
_ = df[['distance', 'filtered']].plot(ax=ax)

In [None]:
# Persist the current frame to parquet, leaving the index out of the file.
df.to_parquet('../data/image_values.parquet', index=False)

## Combine with GPS data

In [None]:
# Load the image-derived values from parquet (the same path the frame is
# written to earlier in this notebook).
df = pd.read_parquet('../data/image_values.parquet')

Make time the index.

In [None]:
# Keep rows that have a filtered value, keep one row per timestamp, and use
# the timestamp as index.
has_value = df.dropna(subset=['filtered'])
df = has_value.drop_duplicates(subset=['time']).set_index('time')

In [None]:
# Load the GPS data; the cells below use its 'time' and 'distance' columns.
gps = pd.read_parquet('../data/mystery_box/gps_data.parquet')

Make sure that the GPS data is within the time limits of our image data.

In [None]:
# Keep only GPS samples strictly inside the time span covered by the image
# data. Both bounds go into one mask: chaining a second ``.loc`` with a mask
# built from the unfiltered frame only works through index alignment.
first_time, last_time = df.index.min(), df.index.max()
in_range = (gps.time > first_time) & (gps.time < last_time)

# Copy so that adding columns to ``gps`` afterwards writes to an independent
# frame instead of a slice of the original (avoids SettingWithCopyWarning).
gps = gps.loc[in_range].copy()

Add the A value to the GPS dataset.

In [None]:
def get_distance(time):
    """Return the ``filtered`` value of the ``df`` row nearest to ``time``.

    ``df`` is the module-level frame indexed by time; the lookup uses the
    index position closest to ``time``.
    """
    nearest_pos = df.index.get_indexer([time], method='nearest')[0]
    return df.iloc[nearest_pos]['filtered']

# Look up the nearest filtered value for every GPS timestamp and store it as
# an integer column (astype(int) drops any fractional part).
gps['A'] = gps['time'].apply(get_distance).astype(int)

In [None]:
# Quick visual check of the GPS 'distance' column.
gps['distance'].plot()

In [None]:
# Write the GPS frame (now including the 'A' column) to parquet without the
# index.
gps.to_parquet('../data/mystery_box_dataset.parquet', index=False)