In [1]:
"""
This file is used to load and run a trained model and test backwards diffusion.
"""
import os
import sys
# adding project directory to path, this is a bit hacky and may not work for all
sys.path.insert(0, os.path.abspath(os.path.dirname(os.path.abspath(''))))

import torch
from kbgen.config import rootdir
from kbgen.data.datasets import GSM
from kbgen.utils.log import RunTracker
import numpy as np
from matplotlib import pyplot as plt
from collections import defaultdict
import warnings
import pandas as pd
plt.style.use("./mine.mplstyle")

In [2]:
# DATA -----------------------------------------------------
# Use the GPU when present, otherwise fall back to CPU instead of crashing on "cuda".
# NOTE(review): the run config has all_on_gpu=True — confirm that the dataset and
# RunTracker honour a CPU device before relying on the fallback.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load Wandb Run -----------------------------------------------------
# Pick a run whose checkpoints are saved locally under <rootdir>/models/<model_name>.
model_name = "07-29-16-36-26queriesnonz-random_rate_L2td4_te4_d512_periodic"
logdir = os.path.join(rootdir, "models", model_name)
run = RunTracker.from_logdir(logdir, force_device=device)
# Rebuild the dataset with the same configuration the model was trained with.
dataset = GSM.from_config(run.config, update=True)
model = run.load_latest_model().to(device)
model.eval()  # inference only: switch the model to evaluation mode
accuracy = model.accuracy
print(run.config)

Loading model from: /logdir/models/07-29-16-36-26queriesnonz-random_rate_L2td4_te4_d512_periodic/250.pt
{'d_model': 512, 'd_ff_mult': 2, 'nhead': 2, 'num_layers': 2, 'field_encoder_layers': 2, 'field_decoder_layers': 3, 'text_decoder_layers': 4, 'text_encoder_layers': 4, 'num_emb': 'periodic', 'tie_numerical_embeddings': False, 'tie_numerical_decoders': False, 'tie_mask_embeddings': False, 'pretrain_epochs': 0, 'epochs': 1000, 'batch_size': 32, 'lr': 0.0001, 'weight_decay': 0, 'dropout': 0.0, 'mask_rate': [-1, 0.5], 'wandb': True, 'tags': ['queriesnonz', 'random_rate'], 'device': 'cuda', 'seed': 42, 'rootdir': '/logdir', 'ckpt': None, 'tie_embeddings': True, 'text_model': 'custom', 'tokenizer': 'gpt2', 'all_on_gpu': True, 'use_mup': True, 'num_fields': 12, 'vocab_size': 50258, 'fields': Fields([('numerical', ['phone.weight', 'phone.height', 'phone.depth', 'phone.width', 'phone.display_size', 'phone.battery', 'phone.launch.day', 'phone.launch.month', 'phone.launch.year']), ('categorical

In [4]:
# Full validation split (inputs and padding masks), moved to the model's device.
input_dict = dataset.input_dict.iloc[dataset.val_idx].to(device)
attention_mask = dataset.pad_mask_dict.iloc[dataset.val_idx].to(device)

In [5]:
from kbgen.diffusion import HybridDiffusion
class EncoderDecoder:
    """Turn model inputs/outputs back into human-readable values.

    Args:
        fields: field container. It is indexed with ``"numerical"``,
            ``"categorical"`` and ``"text"`` (each giving a list of field
            names) and exposes ``all_fields``, the fields emitted by
            :meth:`decode`.
        numerical_decoder: callable ``(field, tensor)`` used to decode
            numerical fields; its result must support ``.tolist()``.
        categorical_decoder: callable ``(field, ids)`` applied to categorical
            class ids when decoding from logits; its output is then passed
            to ``detokonize``.
        detokonize: callable turning a batch of ids into text (e.g. a
            tokenizer's ``batch_decode``). The misspelled parameter name is
            kept so that keyword callers keep working.
    """

    def __init__(self, fields, numerical_decoder, categorical_decoder, detokonize) -> None:
        self.fields = fields
        self.numerical_decoder = numerical_decoder
        self.categorical_decoder = categorical_decoder
        self.detokenize = detokonize

    def decode(self, x, from_logits=False):
        """Decode every field listed in ``self.fields.all_fields``.

        Args:
            x: mapping from field name to tensor. Categorical and text fields
                hold ids, or logits over the last dimension when
                ``from_logits`` is True. ``x`` is NOT modified.
            from_logits: if True, take the argmax over the last dimension of
                categorical/text fields first, and map categorical class ids
                through ``categorical_decoder``.

        Returns:
            dict mapping each field name to a list of numbers (numerical
            fields) or to the ``detokenize`` output (all other fields).
        """
        numerical = set(self.fields["numerical"])
        categorical = set(self.fields["categorical"])
        text = set(self.fields["text"])

        decoded = {}
        for field in self.fields.all_fields:
            # Work on a local value so the caller's dict keeps its logits/ids.
            value = x[field]
            if field in numerical:
                decoded[field] = self.numerical_decoder(field, value).tolist()
                continue
            if from_logits:
                if field in categorical or field in text:
                    value = value.argmax(-1)
                if field in categorical:
                    value = self.categorical_decoder(field, value)
            decoded[field] = self.detokenize(value)
        return decoded
    

# Use the first 10 validation rows as a small batch for the experiments below.
sample_ids = dataset.val_idx[:10]
sample_input_dict = dataset.input_dict.iloc[sample_ids].to(device)
sample_attention_mask = dataset.pad_mask_dict.iloc[sample_ids].to(device)

# Per-sample shape of every input field (leading batch dimension dropped).
input_shapes = {field: tensor.shape[1:] for field, tensor in sample_input_dict.items()}
print(input_shapes)

hmd = HybridDiffusion(model, dataset.fields, dataset.categorical_id_to_token, input_shapes)
encdec = EncoderDecoder(
    dataset.fields,
    dataset.numerical_decode,
    dataset.categorical_id_to_token,
    dataset.tokenizer.batch_decode,
)

{'phone.weight': torch.Size([]), 'phone.height': torch.Size([]), 'phone.depth': torch.Size([]), 'phone.width': torch.Size([]), 'phone.display_size': torch.Size([]), 'phone.battery': torch.Size([]), 'phone.launch.day': torch.Size([]), 'phone.launch.month': torch.Size([]), 'phone.launch.year': torch.Size([]), 'phone.oem': torch.Size([7]), 'phone.network_edge': torch.Size([18]), 'phone.model': torch.Size([20]), 'phone.oem_idx': torch.Size([]), 'phone.network_edge_idx': torch.Size([])}


In [6]:
# Boolean mask of shape (len(sample_ids), hmd.numel), initialised to all False.
# NOTE(review): whether True or False means "masked" is defined by
# HybridDiffusion, which is not visible here — confirm.
sample_mask = torch.zeros(
    len(sample_ids), hmd.numel, dtype=torch.bool, device=device
)

In [7]:
sample_input_dict["phone.model"]

tensor([[   55,   417,   571,   380,   718, 50256,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [14662,  3582, 16904, 50256,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [   45,  4304, 50256,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [   57,  1120, 50256,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [38956, 50256,     0,     0,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [   52,    23,  3829, 50256,     0,     0,     0,     0,     0,     0,
             0,     0,     0,     0,     0,     0,     0,     0,     0,     0],
        [ 7575,   634, 21355, 50256,     0,   

In [8]:
# Single forward pass of the model (no iterative diffusion) under sample_mask.
# no_grad: pure inference, so don't build an autograd graph.
with torch.no_grad():
    single_step_preds = model(
        sample_input_dict,
        key_padding_mask=sample_attention_mask,
        property_mask=hmd._bool_mask_to_float(sample_mask),
    )
# NOTE(review): in the saved output every row decodes to the same numerical
# values, i.e. predictions do not depend on the inputs under this mask —
# confirm this matches the intended mask convention.
encdec.decode(single_step_preds, from_logits=True)

{'phone.weight': [[136.8832550048828],
  [136.8832550048828],
  [136.8832550048828],
  [136.8832550048828],
  [136.8832550048828],
  [136.8832550048828],
  [136.8832550048828],
  [136.8832550048828],
  [136.8832550048828],
  [136.8832550048828]],
 'phone.height': [[126.19513702392578],
  [126.19513702392578],
  [126.19513702392578],
  [126.19513702392578],
  [126.19513702392578],
  [126.19513702392578],
  [126.19513702392578],
  [126.19513702392578],
  [126.19513702392578],
  [126.19513702392578]],
 'phone.depth': [[13.292008399963379],
  [13.292008399963379],
  [13.292008399963379],
  [13.292008399963379],
  [13.292008399963379],
  [13.292008399963379],
  [13.292008399963379],
  [13.292008399963379],
  [13.292008399963379],
  [13.292008399963379]],
 'phone.width': [[62.78573989868164],
  [62.78573989868164],
  [62.78573989868164],
  [62.78573989868164],
  [62.78573989868164],
  [62.78573989868164],
  [62.78573989868164],
  [62.78573989868164],
  [62.78573989868164],
  [62.785739898681

In [9]:
# Generate samples with HybridDiffusion from sample_input_dict under sample_mask,
# then show the decoded fields as a table.
# To experiment with other masks, swap sample_mask for a boolean mask of the
# same shape, e.g. torch.rand(sample_mask.shape, device=device) < 0.5.
# NOTE(review): the saved output contains large negative values (e.g. -30029.0
# for every phone.launch.day, -44364.91 in phone.display_size) that look like
# decoded missing/padding entries — confirm how missing numericals are handled.
generated_sample = hmd.generate_sample(
    1,
    sample_input_dict,
    mask=sample_mask,
    leaps=20,
    undo=False,
    temperature=0,
)
decoded = encdec.decode(generated_sample)
pd.DataFrame(decoded)

Unnamed: 0,phone.weight,phone.height,phone.depth,phone.width,phone.display_size,phone.battery,phone.launch.day,phone.launch.month,phone.launch.year,phone.oem,phone.network_edge,phone.model
0,89.000038,83.999992,26.000002,74.999985,-44364.910156,-13564.417969,-30029.0,11.0,2003.0,Siemens,No,Xelibri 6
1,404.0,204.999985,9.0,138.199997,20.320002,12.13603,-30029.0,1.0,2014.0,BLU,Yes,Life View Tab
2,115.000107,106.499985,13.7,52.000004,6.096002,9.453271,-30029.0,1.0,2007.0,Nokia,"Class 32, 296 / 177.6 kbits",N76
3,119.999954,100.900024,21.499998,51.799999,6.096002,9.893301,-30029.0,11.0,2014.0,Yezz,No,Z50
4,127.999901,95.0,25.000004,48.699989,-44364.910156,10.502831,-30029.0,7.0,2005.0,Philips,No,968
5,98.999969,98.799988,18.200005,48.999996,5.587999,-13564.417969,-30029.0,2.0,2006.0,LG,,U890
6,108.999908,124.999969,24.000004,48.999996,-44364.910156,9.645658,-30029.0,-11010.0,2001.0,Motorola,No,Timeport 280
7,95.0,106.000008,12.400004,44.000011,4.572,9.645658,-30029.0,11.0,2005.0,NEC,No,e121
8,167.0,136.999985,9.799996,70.999992,12.7,11.241387,-30029.0,4.0,2014.0,Nokia,Up to 236.8 kbps,Lumia 930
9,160.199951,143.399979,8.899999,72.399986,12.7,10.966506,-30029.0,-11010.0,2013.0,Archos,Yes,50 Platinum
