In [122]:
import dvc.api
import pandas as pd
import sys 
sys.path.append('../../')
from main_multitask_multimodal import LLM_MultitaskMultimodal

In [123]:
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict
save_path = '/workspaces/multitask-llm-rnd/modelling/models/product_attribute_extraction_text_generation/version_1/epoch=4-step=8094.ckpt'
output_path = f'{save_path}/pytorch_model.bin'


In [124]:
convert_zero_checkpoint_to_fp32_state_dict(save_path, output_path)

Processing zero checkpoint '/workspaces/multitask-llm-rnd/modelling/models/product_attribute_extraction_text_generation/version_1/epoch=4-step=8094.ckpt/checkpoint'
Detected checkpoint of type zero stage 2, world_size: 6
Parsing checkpoint created by deepspeed==0.7.3
Reconstructed fp32 state dict with 283 params 582401281 elements
Saving fp32 state dict to /workspaces/multitask-llm-rnd/modelling/models/product_attribute_extraction_text_generation/version_1/epoch=4-step=8094.ckpt/pytorch_model.bin


In [125]:
model = LLM_MultitaskMultimodal.load_from_checkpoint(output_path)

INFO:root:Unused kwargs when getting google/mt5-base: {'distance_func': 'cosine', 'loss_type': 'cross-entropy', 'margin': None, 'hidden_states_type': 'encoder-last', 'add_simcse': False, 'manual_loss_type': 'manual_mse', 'auto_task_weight': False, 'multitask_specs_dict': {'clm_singlemodal_wishtitledesp2attrkvpair': None}, 'head_dict': {}}


In [126]:
def format_input(title, description, category):
    template = f'[title start] {title} [title end] [description start] {description} [description end] [taxonomy start] {category} [taxonomy end]'
    task_prefix = 'Generate attribute key value pairs for product with description and taxonomy: '
    return task_prefix + template

In [127]:
df = pd.read_json(dvc.api.get_url( 
    'datasets/data/wish_attr_extract_label/processed/appen_020323_030323_delivered_030623_validated_product_attr_textonly_train.json', 
    repo='git@github.com:ContextLogic/multitask-llm-rnd.git'
), lines=True)

In [138]:
rec = df.sample(1).to_dict('records')[0]

In [139]:
input_text = format_input(rec['title'], rec['description'], rec['category'])

In [140]:
rec

{'label_ordering': 44790,
 'sample_method': 'only_text',
 'pid': '60198af7f883788328ddf725',
 'category': 'Home & Garden > Home Decor > Flags, Banners & Accessories',
 'title': 'Spring Dissent 2009 - Orange banner flag 150*90 cm',
 'description': 'Spring Dissent 2009 - Orange flag banner flags 3*5 inch',
 'main_img_url': nan,
 'rater_output_processed': 'Home & Garden > Home Decor > Flags, Banners & Accessories > Alpha Size > 3.5 inch\nHome & Garden > Home Decor > Flags, Banners & Accessories > Item Types > Flag\nHome & Garden > Home Decor > Flags, Banners & Accessories > Item Types > Banner\nHome & Garden > Home Decor > Flags, Banners & Accessories > Primary Color > Orange\nHome & Garden > Home Decor > Flags, Banners & Accessories > Alpha Size > 3x5 inch',
 'attr_name_value_pairs_normalized': [['Primary Color', 'Orange']],
 'attr_name_value_pairs_custom': [['Alpha Size', '3.5 inch'],
  ['Alpha Size', '3x5 inch'],
  ['Item Types', 'Banner'],
  ['Item Types', 'Flag']],
 'attr_name_value_

In [141]:
print(input_text)

Generate attribute key value pairs for product with description and taxonomy: [title start] Spring Dissent 2009 - Orange banner flag 150*90 cm [title end] [description start] Spring Dissent 2009 - Orange flag banner flags 3*5 inch [description end] [taxonomy start] Home & Garden > Home Decor > Flags, Banners & Accessories [taxonomy end]


In [142]:
inputs = model.tokenizer(input_text, return_tensors='pt')

In [143]:
decoder_inputs = model.tokenizer('<pad>Materials|', return_tensors='pt')

In [144]:
model.tokenizer.batch_decode(decoder_inputs['input_ids'][:,:-1])

['<pad> Materials|']

In [145]:
model.eval()
model.tokenizer.batch_decode(model.transformer.generate(
    decoder_input_ids=decoder_inputs['input_ids'][:,:-1], 
    **inputs))

['<pad> Materials|Polyester</s>']

In [146]:
model.eval()
model.tokenizer.batch_decode(model.transformer.generate(**inputs))

['<pad></s>']

In [147]:
model.eval()
model.tokenizer.batch_decode(model.transformer.generate(**inputs, min_length=100, max_length=100))

['<pad> Materials|Polyester [NL] Primary Color|Multicolor [NL] Primary Color|White [NL] Shape|Rectangular [NL] Theme|Flag [NL] Theme|Flag [NL] Theme|Flag [NL] Theme|Flowers [NL] Theme|Flag [NL] Theme|Flag [NL] Theme|Flag [NL] Theme|Flag [NL] Theme|Flag [NL] Theme|Flag [NL] Theme|Flag [NL] Theme']

In [16]:
recs = df.sample(4).to_dict('records')

In [20]:
prompts = []
for rec in recs[:1]:
    prompts.append(format_input(rec['title'], rec['description'], rec['category']) + ' -> ' + \
                  rec['attr_name_value_pairs_normalized_text'])
rec = recs[-1]
prompts.append(format_input(rec['title'], rec['description'], rec['category']) + ' -> ')
print('\n\n'.join(prompts))

Generate attribute key value pairs for product with description and taxonomy: [title start] 316L Solid Stainless Steel  Necklace [title end] [description start] Authenticity Guaranteed
Hypoallergenic Jewelry
Comfort Fit Design
Safe on Skin
Made to Last a Lifetime
Designed in ITALY
Certifed Gemstones
Creation Method: Lab Created 
Guaranteed to Retain its Color and Shine
316L Solid Stainless Steel Necklace [description end] [taxonomy start] Jewelry & Accessories > Necklaces & Pendants > Power Necklaces [taxonomy end] -> Department|Women
Materials|Argentium Plated Stainless Steel
Materials|Stainless Steel

Generate attribute key value pairs for product with description and taxonomy: [title start] Simple Style Rolls Royce Pillowcase Throw Pillows Soft Sofa Cushion Covers Square Pillowcase [title end] [description start] Item: Pillow cases
Material: Cotton （As shown in the figure）
Style: Modern Printed/Personality DIY decoration
Pattern Type: Portrait/photo/Animal/plants/Letter and so on
Pr

In [22]:
input_text = '\n'.join(prompts)

In [23]:
model.eval()
model.tokenizer.batch_decode(model.transformer.generate(**model.tokenizer(input_text, return_tensors='pt'), min_length=10))



['<pad> Primary Color|Black [NL] Primary Color|Green</s>']