# Load model and dataset
First, we load the model and dataset. The model is pretrained and saved locally under `rootdir` in `config.py`.

In [1]:
import os
import sys
# adding project directory to path, this is a bit hacky and may not work for all
sys.path.insert(0, os.path.abspath(os.path.dirname(os.path.abspath(''))))

from evaluation import Demo
demo = Demo()
model, dataset, run = demo.load_run(model_path="../model_weights/")

Loading model from: ../model_weights/999.pt
{'d_model': 512, 'd_ff_mult': 2, 'nhead': 2, 'num_layers': 4, 'field_encoder_layers': 2, 'field_decoder_layers': 3, 'num_emb': 'periodic', 'tie_numerical_embeddings': False, 'tie_numerical_decoders': False, 'tie_mask_embeddings': True, 'epochs': 1000, 'batch_size': 32, 'lr': 0.0001, 'weight_decay': 0, 'dropout': 0.0, 'mask_rate': [-1, 0.5], 'wandb': True, 'tags': ['MaskOnlyLossAttn', 'grad_works'], 'device': 'cuda:0', 'seed': 42, 'rootdir': '/logdir', 'ckpt': '', 'text_model': 'custom', 'tie_embeddings': True, 'tokenizer': 'gpt2', 'text_decoder_layers': 4, 'text_encoder_layers': 4, 'use_mup': True, 'num_fields': 12, 'vocab_size': 50258, 'fields': Fields([('numerical', ['phone.weight', 'phone.height', 'phone.depth', 'phone.width', 'phone.display_size', 'phone.battery', 'phone.launch.day', 'phone.launch.month', 'phone.launch.year']), ('categorical', ['phone.oem', 'phone.network_edge']), ('text', ['phone.model'])]), 'categorical_num_classes': {'

Let's take a look at the fields we can play around with and the range of values they can take.

In [2]:
demo.print_fields()

* numerical
	- phone.weight        [4.0, 2018.3]
	- phone.height        [23.0, 451.8]
	- phone.depth         [0.0, 75.0]
	- phone.width         [15.7, 283.2]
	- phone.display_size  [2.413, 46.736]
	- phone.battery       [0.0, 13.550867004960905]
	- phone.launch.day    [1.0, 31.0]
	- phone.launch.month  [1.0, 12.0]
	- phone.launch.year   [1994.0, 2020.0]
* categorical
	- phone.oem
	- phone.network_edge
* text
	- phone.model


For a full description of the data we can also directly look at the dataframe saved in the dataset object.

In [3]:
dataset._df.describe()

Unnamed: 0,phone.weight,phone.display_size,phone.height,phone.width,phone.depth,phone.battery,phone.launch.year,phone.launch.day,phone.launch.month
count,9626.0,9268.0,10287.0,10287.0,10287.0,9554.0,10588.0,507.0,9411.0
mean,147.995917,10.821959,129.369272,65.739584,13.01498,10.731901,2011.634964,15.948718,6.130592
std,93.984026,5.059618,34.636882,24.429045,5.249419,0.978098,5.158068,8.376122,3.391471
min,4.0,2.413,23.0,15.7,0.0,0.0,1994.0,1.0,1.0
25%,96.0,6.096,106.0,49.0,8.9,9.967226,2008.0,8.0,3.0
50%,130.0,10.922,123.0,63.0,11.5,10.732167,2012.0,16.0,6.0
75%,165.0,13.97,148.7,73.3,15.9,11.551228,2015.0,24.0,9.0
max,2018.3,46.736,451.8,283.2,75.0,13.550867,2020.0,31.0,12.0


```
        The method below is used to sample data from the dataset.

        Args:
            num (int): The number of samples to be drawn from the dataset.
            input_dict (dict): A dictionary containing the data we want to give as input to the model.
                The keys are the field names and the values are the corresponding data.
            mask_none (bool, optional): If True, the fields with None values in the input_dict will be masked.
                Otherwise, they will be resampled. Defaults to False.
            temp (int, optional): The temperature parameter for the sampling process. 
                Higher values make the sampling more random, lower values make it more deterministic. 
                Defaults to 0.
            resample_given (bool, optional): If True, the fields with given values in the input_dict will
                be resampled. Defaults to False.

        Returns:
            dict: A dictionary containing the sampled data. The keys are the field names and the values are the corresponding sampled data.
```

In [4]:
using_dict = {'phone.weight': None, # Set to None to hide
              'phone.height' : None, # 129
              'phone.depth' : 20, 
              'phone.width' : 60, 
              'phone.display_size' : 150, 
              'phone.battery' : 10, 
              'phone.launch.day' : 5, 
              'phone.launch.month': 4, 
              'phone.launch.year': None, 
              'phone.oem' :None, 
              'phone.network_edge': None, 
              'phone.model':"Galaxy S4"} # Set to None to hide

result = demo.sample(1, using_dict, mask_none=True, temp=0., resample_given=False)
result.result_dict = {k: v for k, v in result.result_dict.items() if using_dict[k] is None}
print(result)

phone.weight         143.78152465820312
phone.height         123.86957550048828
phone.launch.year    2014.651123046875
phone.oem            Samsung
phone.network_edge   Yes
