In [None]:
from typing import List, Union

import numpy as np
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Vietnamese -> English translator (VinAI checkpoint), placed on the GPU.
# The tokenizer is told the source language up front; the target language is
# chosen per call in translate_vi2en via decoder_start_token_id.
_VI2EN_CHECKPOINT = "vinai/vinai-translate-vi2en-v2"
device_vi2en = torch.device("cuda")
tokenizer_vi2en = AutoTokenizer.from_pretrained(_VI2EN_CHECKPOINT, src_lang="vi_VN")
model_vi2en = AutoModelForSeq2SeqLM.from_pretrained(_VI2EN_CHECKPOINT).to(device_vi2en)

def translate_vi2en(
    vi_texts: Union[str, List[str]],
    num_beams: int = 5,
    max_length: int = 1024,
) -> List[str]:
    """Translate Vietnamese text to English with the module-level VinAI model.

    Args:
        vi_texts: One Vietnamese string or a list of them. A list is padded to
            its longest member and translated as a single batch.
        num_beams: Beam-search width used by ``generate`` (default 5).
        max_length: Token limit applied both when truncating the input and as
            the generation length cap (default 1024).

    Returns:
        English translations in input order. A single input string still
        yields a one-element list, because ``batch_decode`` always returns a list.
    """
    if not isinstance(vi_texts, str) and not vi_texts:
        # Nothing to translate; avoid handing an empty batch to the tokenizer/model.
        return []

    encoded = tokenizer_vi2en(
        vi_texts,
        padding=True,
        # The prompts are caption + description + moreInfo concatenated and can be
        # long; truncate rather than exceed the model's input length
        # (assumed to be 1024 tokens for this mBART-based checkpoint — confirm).
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    ).to(device_vi2en)

    output_ids = model_vi2en.generate(
        **encoded,
        # Decoding starts from the English language-code token. convert_tokens_to_ids
        # yields the same id as the older lang_code_to_id["en_XX"] lookup but does
        # not rely on that tokenizer-specific attribute.
        decoder_start_token_id=tokenizer_vi2en.convert_tokens_to_ids("en_XX"),
        num_return_sequences=1,
        num_beams=num_beams,
        early_stopping=True,
        max_length=max_length,
    )
    return tokenizer_vi2en.batch_decode(output_ids, skip_special_tokens=True)

# Load the private test set. For every row, caption, description and moreInfo
# are later joined into one Vietnamese prompt, so missing values become '' and
# everything is coerced to str to keep that concatenation valid.
df_test = pd.read_csv("../private/info.csv")
for _text_col in ('caption', 'description', 'moreInfo'):
    # Each column is cleaned from its own values (caption and description were
    # previously both copied from moreInfo, tripling the moreInfo text in prompts).
    df_test[_text_col] = df_test[_text_col].fillna('').astype(str)

text_cap = list(df_test['caption'])
text_des = list(df_test['description'])
text_more_info = list(df_test['moreInfo'])
img = list(df_test['bannerImage'])  # output file names, one per row

# NOTE(review): `eng` is not used by any visible cell; kept in case other code relies on it.
eng = []

In [None]:
from imagen_pytorch import Unet, ImagenTrainer, ElucidatedImagen
import random

import pandas as pd
from imagen_pytorch import t5

T5_name = "google/t5-v1_1-base"

# Two-stage cascade: unets[i] pairs with image_sizes[i], so unet1 produces the
# 64 px base image and unet2 upsamples it to 256 px. These hyperparameters
# presumably must match those used to train ../saved_model/model.pt for
# trainer.load to succeed — confirm before changing any of them.
unet1 = Unet(
    dim = 128,
    cond_dim = 512,
    dim_mults = (1, 2, 3, 4),
    num_resnet_blocks = 3,
    attn_dim_head = 64,
    attn_heads = 8,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True),
    memory_efficient = False,
)
unet2 = Unet(
    dim = 128,
    cond_dim = 256,
    dim_mults = (1, 2, 3, 4),
    num_resnet_blocks = (2, 4, 8, 8),
    attn_dim_head = 64,
    attn_heads = 8,
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True),
    memory_efficient = True,
)
imagen = ElucidatedImagen(
    unets = [unet1, unet2],
    image_sizes = (64, 256),
    cond_drop_prob = 0.1,
    num_sample_steps = (128, 128),
    sigma_min = 0.002,
    sigma_max = (80, 160),
    sigma_data = 0.5,
    rho = 7,
    P_mean = -1.2,
    P_std = 1.2,
    S_churn = 80,
    S_tmin = 0.05,
    S_tmax = 50,
    S_noise = 1.003,
    text_encoder_name=T5_name
)

# The seed is still drawn at random per run, but it is now actually applied to
# torch's global RNG (assumed to drive the sampler's noise — confirm) and
# printed, so a run can be reproduced by hard-coding the printed value here.
random_seed = int(random.random() * 1000)
torch.manual_seed(random_seed)
print(f"random_seed = {random_seed}")

trainer = ImagenTrainer(
    imagen = imagen,
).cuda()

trainer.load('../saved_model/model.pt')
batch_size = 16  # prompts per trainer.sample call in the submission cells
stop_unet = 2    # sample through unet2, i.e. the full 64 -> 256 px cascade

In [None]:
# FOR SUBMISSION 1
import os
import time

# Generate one banner per test row and record the average per-image wall time
# of its batch. `batch_size` is only read here, never reassigned, so re-running
# this cell (or the submission-2 cell afterwards) uses the same batching.
submission1_dir = '../results/submission1/'
# NOTE(review): this path is relative to the CWD, unlike the '../results/...'
# image directory above — confirm which location is intended.
time_csv_1 = "results/time_submission1.csv"
banner_size = (1024, 533)  # final (width, height) written to disk

os.makedirs(submission1_dir, exist_ok=True)

all_predicted_time = []
for i in range(0, len(text_cap), batch_size):
    t1 = time.perf_counter()

    # Slicing clamps at the end of the lists, so the last batch may be smaller.
    caption = text_cap[i: i + batch_size]
    des = text_des[i: i + batch_size]
    img_name = img[i: i + batch_size]
    more_info = text_more_info[i: i + batch_size]
    n_batch = len(caption)

    # Prompt = "<caption>. <description>. <moreInfo>", translated vi -> en.
    all_info = [f"{c}. {d}. {m}" for c, d, m in zip(caption, des, more_info)]
    trans_text = translate_vi2en(all_info)

    info_embeds, mask = t5.t5_encode_text(trans_text, return_attn_mask=True, name=T5_name)
    images = trainer.sample(
        batch_size=n_batch,
        text_embeds=info_embeds,
        text_masks=mask,
        return_pil_images=True,
        stop_at_unet_number=stop_unet,
    )  # returns List[Image]
    assert len(images) == n_batch, f"expected {n_batch} images, got {len(images)}"

    for image, name in zip(images, img_name):
        image.resize(banner_size).save(submission1_dir + 'jupyter_' + name)

    predicted_time = time.perf_counter() - t1
    all_predicted_time.extend((name, predicted_time / n_batch) for name in img_name)

df_time = pd.DataFrame(all_predicted_time, columns=['fname', 'time'])
os.makedirs(os.path.dirname(time_csv_1), exist_ok=True)
df_time.to_csv(time_csv_1, index=False)

In [None]:
# FOR SUBMISSION 2
import os
import time

# Same pipeline as submission 1, writing to the submission-2 locations.
# `batch_size` is only read here, never reassigned, so this cell's batching does
# not depend on which cells ran before it or how many times.
submission2_dir = '../results/submission2/'
# NOTE(review): this path is relative to the CWD, unlike the '../results/...'
# image directory above — confirm which location is intended.
time_csv_2 = "results/time_submission2.csv"
banner_size = (1024, 533)  # final (width, height) written to disk

os.makedirs(submission2_dir, exist_ok=True)

all_predicted_time = []
for i in range(0, len(text_cap), batch_size):
    t1 = time.perf_counter()

    # Slicing clamps at the end of the lists, so the last batch may be smaller.
    caption = text_cap[i: i + batch_size]
    des = text_des[i: i + batch_size]
    img_name = img[i: i + batch_size]
    more_info = text_more_info[i: i + batch_size]
    n_batch = len(caption)

    # Prompt = "<caption>. <description>. <moreInfo>", translated vi -> en.
    all_info = [f"{c}. {d}. {m}" for c, d, m in zip(caption, des, more_info)]
    trans_text = translate_vi2en(all_info)

    info_embeds, mask = t5.t5_encode_text(trans_text, return_attn_mask=True, name=T5_name)
    images = trainer.sample(
        batch_size=n_batch,
        text_embeds=info_embeds,
        text_masks=mask,
        return_pil_images=True,
        stop_at_unet_number=stop_unet,
    )  # returns List[Image]
    assert len(images) == n_batch, f"expected {n_batch} images, got {len(images)}"

    for image, name in zip(images, img_name):
        image.resize(banner_size).save(submission2_dir + 'jupyter_' + name)

    predicted_time = time.perf_counter() - t1
    all_predicted_time.extend((name, predicted_time / n_batch) for name in img_name)

df_time = pd.DataFrame(all_predicted_time, columns=['fname', 'time'])
os.makedirs(os.path.dirname(time_csv_2), exist_ok=True)
df_time.to_csv(time_csv_2, index=False)