In [1]:
#################
# Data Preparation
#################
import os
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as transforms

dataset = []
data_root = './data/tmp'
for obj_dir in os.listdir(data_root):
    images = []
    for i in range(12):
        im_path = os.path.join(data_root, obj_dir, f"a_{i:0>3}_depth0001.png")
        images.append(
            np.array(Image.open(im_path).convert('RGB')))
    
    spt_path = os.path.join(data_root, obj_dir, "a_script.spt")
    with open(spt_path) as file:
        spt = file.read()

    dataset.append({
        "images": images,
        "script": spt,
    })
transform = transforms.Compose(
    [transforms.ToTensor(),     
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
     
def load_obj_data(idx):
    im_pth = []
    for im in dataset[idx]['images']:
        im_pth.append(transform(im))
    im_pth = torch.stack(im_pth, dim=0)
    return im_pth, dataset[idx]['script']

In [2]:
#################
# Model
#################
import os
import torch.nn as nn
from transformers import CodeGenTokenizer, CodeGenForCausalLM
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

MODEL_PATH = os.path.expanduser("./codegen-350M-mono/")
tokenizer = CodeGenTokenizer.from_pretrained(MODEL_PATH)
model = CodeGenForCausalLM.from_pretrained(MODEL_PATH).to(device)

class FuseEmb(nn.Module):
    def __init__(self, real_emb) -> None:
        super().__init__()
        self.prompt = True
        self.real_emb = real_emb

        # [batch, 3, H, W] -> [batch, 1024]
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=5, stride=2),
            nn.ReLU(inplace=True),
            nn.Flatten(),
            nn.Linear(119072, 1024),
        )
        self.im_prompt = None

    def set_im_input(self, images):
        h = self.encoder(images)
        self.im_prompt = h.view(1, len(images), 1024)
        self.prompt = True

    def forward(self, input):
        if self.prompt:
            # print("prompt", input.shape)
            self.prompt = False
            e = self.real_emb(input)
            N_prompt = self.im_prompt.shape[1]
            e[:, :N_prompt, :] = self.im_prompt
            return e
        else:
            return self.real_emb(input)

# [key step] replace embedding layer
emb = FuseEmb(model.transformer.wte).to(device)
model.transformer.wte = emb

In [3]:
#################
# Fine-tune
#################
from transformers import get_scheduler
from torch.optim import AdamW
from tqdm import tqdm

optimizer = AdamW(model.parameters(), lr=2e-4)

num_epochs = 100
num_training_steps = num_epochs * len(dataset)
# lr_scheduler = get_scheduler(
#     name="linear", optimizer=optimizer, num_warmup_steps=0, num_training_steps=num_training_steps
# )

pbar = tqdm(range(num_epochs))
for epoch in pbar:
    for idx in range(len(dataset)):
        model.train()

        ims, spt = load_obj_data(idx)
        n_ims = ims.shape[0]
        
        fake_prompt = torch.arange(0, n_ims).view(1, -1).long()
        spt_ids = tokenizer(spt, return_tensors="pt").input_ids

        full_text = torch.cat([fake_prompt, spt_ids], dim=1).to(device)
        attention_mask = torch.tensor([[1]*fake_prompt.shape[1] + [0]*spt_ids.shape[1]]).to(device)

        model.transformer.wte.set_im_input(ims.to(device))
        outputs = model(full_text, 
                        labels=full_text, 
                        attention_mask=attention_mask)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        # lr_scheduler.step()
        optimizer.zero_grad()

        pbar.set_description(f"loss={loss.item():.4f}")

  attn_weights = torch.where(causal_mask, attn_weights, mask_value)
loss=0.0002: 100%|██████████| 100/100 [02:49<00:00,  1.69s/it]


In [9]:
#################
# Test
#################
idx = 0
ims, spt = load_obj_data(idx)
n_ims = ims.shape[0]

fake_prompt = torch.arange(0, n_ims).view(1, -1).long().to(device)
model.transformer.wte.set_im_input(ims.to(device))
generated_ids = model.generate(fake_prompt, max_length=50)

print(f"\n\n[Prediction]\n")
print(tokenizer.decode(generated_ids[0][n_ims:], skip_special_tokens=True))
print(f"\n\n[Ground Truth]\n")
print(spt)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.




[Prediction]

ADD CUBE
SELECT FACE LOCATION LEFT 2
RESIZE ALL -
DELETE
SOLIDIFY 7
MOD_
SOLIDIFY 3


[Ground Truth]

ADD CUBE
SELECT FACE LOCATION LEFT 2
RESIZE ALL - 4
DELETE
MOD_SOLIDIFY 3

