In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from transformers import AutoTokenizer, AutoModelForCausalLM

from extractor import addImagePath, textExtraction, imageExtraction, textExtractReverse

# 1. Check the max length of the text data

In [10]:
subUrlList = ['wendys','mcdonalds', 'mcdonalds_switzerland','mcdonaldscanada','sonicdrivein']
max_length = 0
for subUrl in subUrlList:
    dirPath = '../Data/Instagram/Filter_' + subUrl + '.csv'
    data = pd.read_csv(dirPath)
    
    for i in range(len(data['caption'])):
        if len(str(data['caption'][i]).split()) > max_length:
            max_length = len(str(data['caption'][i]).split())
            word = data['caption'][i]
max_length

373

In [74]:
max_length = 0
word = ''
dirPath = '../Data/Oxford_HIC/oxford_hic_data.csv'
data = pd.read_csv(dirPath)
max_length = 0
index =0 
counter  = 5
# find the max word count of the text data['caption']
for i in range(len(data['caption'])):
    if len(str(data['caption'][i]).split()) > max_length:
        max_length = len(str(data['caption'][i]).split())
        word = data['caption'][i]
        index = i       
        #
print(max_length, i)
word

9729 3657846




# 2. Load the data and split the data

In [2]:
# if args.img - dir == 'Oxford_HIC':
#     dirPath = '../Data/Oxford_HIC/oxford_hic_data.csv'
#     imgPath = '../Data/Oxford_HIC/oxford_img/'
# else:
dirPath = '../Data/Instagram/Filter_' + 'wendys' + '.csv'
imgPath = '../Data/Instagram/' + 'wendys' + '_img/'
# load data
data = pd.read_csv(dirPath)
data = addImagePath(data, imgPath)
# split data
train, test = train_test_split(data, test_size=0.2, random_state=42)

train_text = textExtraction(train['caption'])
train_image = imageExtraction(train['image_id'])
train_funny_score = torch.tensor(train['funny_score'].to_numpy())
test_text = textExtraction(test['caption'])
test_image = imageExtraction(test['image_id'])
test_funny_score = torch.tensor(test['funny_score'].to_numpy())

100%|██████████| 293/293 [00:00<00:00, 1097.37it/s]
100%|██████████| 293/293 [00:08<00:00, 36.02it/s]
100%|██████████| 74/74 [00:00<00:00, 1692.97it/s]
100%|██████████| 74/74 [00:02<00:00, 30.50it/s]


In [3]:
train_text.shape, train_image.shape, train_funny_score.shape

(torch.Size([293, 373, 768]), torch.Size([293, 64, 768]), torch.Size([293]))

In [4]:
train_dataset = torch.utils.data.TensorDataset(train_text, train_image, train_funny_score)
test_dataset = torch.utils.data.TensorDataset(test_text, test_image, test_funny_score)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)

# 3. LLM Test

In [4]:
### 不確定是否為官方的 Gemini ############################################################################
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
tokenizer = AutoTokenizer.from_pretrained("describeai/gemini")
gemini = AutoModelForSeq2SeqLM.from_pretrained("describeai/gemini")
#######################################################################################################
gemini



In [13]:
### 官方的Gemma #########################################################################################
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
# gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto", revision="float16")
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto",  torch_dtype=torch.bfloat16)
########################################################################################################



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gemma.to(device)
vocab_size = 256128  # 词汇表大小
embedding_dim = 768  # 嵌入维度，与你的图像嵌入维度相同
text_embedding = nn.Embedding(vocab_size, embedding_dim).to(device)

words = "👻🔥😂😁👍🤦‍♀️🤦‍♂️🤷‍♀️🤷‍♂️✌🤞😉😎🎶😢💖🎉🌹💋👏🐱‍💻🐱‍🐉🐱‍👓✔👀😃✨😆🤔🤢🎁🫢 ha ha"
tokens = tokenizer(words, truncation=True, padding= 'max_length', max_length=100, return_tensors="pt")
output = text_embedding(tokens['input_ids'].to(device))

import torch
import torch.nn.functional as F

def find_closest_embeddings(x, embedding_matrix, top_k=1):
    # Normalize both the input tensor x and the embedding matrix
    x = F.normalize(x, dim=1)  # Normalize input tensor along feature dimension
    embedding_matrix = F.normalize(embedding_matrix, dim=1)  # Normalize embedding matrix
    
    # Compute cosine similarity between x and embedding matrix
    similarity = torch.matmul(x, embedding_matrix.T)  # Shape: [10, 50265]
    
    # Find top-k closest embeddings for each tensor in x
    top_k_values, top_k_indices = torch.topk(similarity, top_k, dim=1)
    
    return top_k_indices, top_k_values


# print(output.squeeze(0).shape)
top_k_indices, top_k_values = find_closest_embeddings(output.squeeze(0), text_embedding.weight)
# top_k_indices.shape
indices = tokenizer.decode(top_k_indices.squeeze(-1))
print(indices)

torch.Size([100, 768])
<pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad><bos>👻🔥😂😁👍🤦‍♀️🤦‍♂️🤷‍♀️🤷‍♂️✌🤞😉😎🎶😢💖🎉🌹💋👏🐱‍💻🐱‍🐉🐱‍👓✔👀😃✨😆🤔🤢🎁🫢 ha ha


In [171]:
words = "👻🔥😂😁👍🤦‍♀️🤦‍♂️🤷‍♀️🤷‍♂️✌🤞😉😎🎶😢💖🎉🌹💋👏🐱‍💻🐱‍🐉🐱‍👓✔👀😃✨😆🤔🤢🎁🫢 ha ha"
tokens = tokenizer(words, truncation=True, padding= 'max_length', max_length=100, return_tensors="pt")
tokens

{'input_ids': tensor([[     0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      2, 242538, 237638, 236471, 238429, 237019, 240525,  67292,
         240525,  68399, 239921,  67292, 239921,  68399, 239529, 241807, 238309,
         238859, 240438, 240116, 239208, 239548, 240315, 240887, 238499, 242993,
         235879, 242482, 242993, 235879, 245092, 242993, 235879, 246943, 237488,
         239220, 239938, 236309, 239312, 238918, 241769, 241227, 248165,    661,
            661]]), 'attention_mask': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [14]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
input_text = "Give me three best book."
input_ids = tokenizer(input_text, return_tensors="pt").to(device)
input_ids

# outputs = gemma.generate(**input_ids, max_new_tokens=200)
# print(tokenizer.decode(outputs[0]))



RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument index in method wrapper_CUDA__index_select)

In [127]:
gemma

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm((2048,), eps=1e-06)
        (post_attention_layernorm): GemmaRMSNorm((2048,), eps=1e-

# 4. Generator

In [None]:
gemma = AutoModelForCausalLM.from_pretrained("google/gemma-2b", device_map="auto",  torch_dtype=torch.bfloat16)

In [162]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.self_att = nn.MultiheadAttention(768, 1)
        self.multi_att = nn.MultiheadAttention(768, 8)
        self.layer_norm = nn.LayerNorm(768)
        self.linear = nn.Linear(768, 768)
        # 減去最後一層
        self.gemma = nn.Sequential(*list(gemma.children())[:-1])
        self.lm_head = nn.Sequential(*list(gemma.children())[1:])
        self.linearFunnyScore = nn.Linear(2048, 1)
    
    def selfAttention(self, x):
        self_out,_ = self.self_att(x, x, x)
        self_out = self.layer_norm(self_out + x)
        return self_out
    
    def multiheadAttention(self, x):
        multi_out,_ = self.multi_att(x, x, x)
        multi_out = self.linear(multi_out)
        multi_out = self.layer_norm(multi_out + x)
        return multi_out
    
    def coAttention(self, x, y):
        # x: self, y: another
        co_out,_ = self.multi_att(x, y, y)
        co_out = self.linear(co_out)
        co_out = self.layer_norm(co_out + y)
        return co_out
    
    def feedForward(self, x):
        ff_out = self.linear(x)
        ff_out = self.layer_norm(ff_out + x)
        return ff_out
    
    def gemmaGenerate(self, x):
        
        def find_closest_embeddings(x, embedding_matrix, top_k=1):
            # Normalize both the input tensor x and the embedding matrix
            x = nn.functional.normalize(x, dim=2)  # Normalize input tensor along feature dimension
            embedding_matrix = nn.functional.normalize(embedding_matrix, dim=1)  # Normalize embedding matrix
        
            # Compute cosine similarity between x and embedding matrix
            similarity = torch.matmul(x, embedding_matrix.T)  # Shape: [10, 50265]
        
            # Find top-k closest embeddings for each tensor in x
            top_k_values, top_k_indices = torch.topk(similarity, top_k, dim=2)
        
            return top_k_indices, top_k_values
        
        
        with torch.no_grad():
            vocab_size = 50265  # 词汇表大小
            embedding_dim = 768  # 嵌入维度，与你的图像嵌入维度相同
            text_embedding = nn.Embedding(vocab_size, embedding_dim).to(device)
            top_k_indices, top_k_values = find_closest_embeddings(x, text_embedding.weight)
            
            # 直接使用gemma生成
            # input = {'input_ids': top_k_indices.squeeze(-1),'attention_mask': torch.ones(top_k_indices.squeeze(-1).shape)}
            # output = gemma.generate(**input, max_new_tokens=200)
            
            # 使用gemma作為model的一部分
            output = self.gemma(top_k_indices.squeeze(-1))
            # output[0] = last_hidden_state
            # output[1] = past_key_values
        return output[0]
               
    
    def forward(self, text, image):
        max_seq_len = max(text.shape[1], image.shape[1])
        text = nn.functional.pad(text, (0, 0, 0, max_seq_len - text.shape[1]))
        image = nn.functional.pad(image, (0, 0, 0, max_seq_len - image.shape[1]))
        text = text.transpose(0, 1)
        image = image.transpose(0, 1)
        
        text = self.selfAttention(text)
        text = self.feedForward(text)
        
        image = self.multiheadAttention(image)
        
        text = self.coAttention(text, image)
        image = self.coAttention(image, text)
        
        feature_fusion = text + image
        feature_fusion = self.feedForward(feature_fusion)        
        feature_fusion = feature_fusion.squeeze(-1)
        feature_fusion = feature_fusion.transpose(0, 1)
        last_hidden_state = self.gemmaGenerate(feature_fusion)
        output_text = self.lm_head(last_hidden_state)
        print(last_hidden_state.shape)
        output_funny_score = self.linearFunnyScore(last_hidden_state.to(torch.float32))
        
        return output_text, output_funny_score

In [163]:
# empty cuda memory
import gc
torch.cuda.empty_cache()
gc.collect()

model = Generator()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# gemma.to(device)
model

Generator(
  (self_att): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
  )
  (multi_att): MultiheadAttention(
    (out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
  )
  (layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (linear): Linear(in_features=768, out_features=768, bias=True)
  (gemma): Sequential(
    (0): GemmaModel(
      (embed_tokens): Embedding(256000, 2048, padding_idx=0)
      (layers): ModuleList(
        (0-17): 18 x GemmaDecoderLayer(
          (self_attn): GemmaSdpaAttention(
            (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (k_proj): Linear(in_features=2048, out_features=256, bias=False)
            (v_proj): Linear(in_features=2048, out_features=256, bias=False)
            (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
            (rotary_emb): GemmaRotaryEmbedding()
          )
  

In [164]:
epochs = 10
train_losses = []
test_losses = []
for epoch in range(epochs):
    train_loss = 0
    test_loss = 0
    with tqdm(train_loader, unit="batch") as tepoch:
        for text, image, funny_score in tepoch:
            optimizer.zero_grad()
            logits, output_funny_score = model(text.to(device).to(torch.float32), image.to(device).to(torch.float32))
            # output = gemma.generate(**temp_output, max_new_tokens=200)
            # print(output)

    #         loss = criterion(output, funny_score)
    #         loss.backward()
    #         optimizer.step()
    #         train_loss += loss.item()
    #         tepoch.set_postfix(loss=train_loss)
    # train_losses.append(train_loss)
    # with tqdm(test_loader, unit="batch") as tepoch:
    #     for text, image, funny_score in tepoch:
    #         output = model(text, image)
    #         loss = criterion(output, funny_score)
    #         test_loss += loss.item()
    #         tepoch.set_postfix(loss=test_loss)
    # test_losses.append(test_loss)

  0%|          | 0/10 [00:00<?, ?batch/s]

torch.Size([32, 373, 2048])


 10%|█         | 1/10 [00:06<00:59,  6.64s/batch]


KeyboardInterrupt: 

In [165]:
logits.shape, output_funny_score.shape

(torch.Size([32, 373, 256000]), torch.Size([32, 373, 1]))

# 5. 直接生成時，將其變成可閱讀的文字

In [41]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
# Assuming `output` is your tensor with shape [batch_size, seq_len, embedding_dim]
filteroutput = output[:, 373:]  # Keep only the last 373 tokens

# 有時後空格會失效，所以手動插入空格
def insert_zeros(tensor):
    batch_size, seq_len = tensor.shape
    zeros = torch.zeros(batch_size, 2 * seq_len - 1, dtype=torch.long)
    # zeros = zeros * 6
    zeros[:, ::2] = tensor
    return zeros

# 處理後的 Tensor
result = insert_zeros(filteroutput)
# print(result)
tryoo = tokenizer.batch_decode(result, skip_special_tokens=False)
print(tryoo[10])
tryoo = tokenizer.batch_decode(filteroutput, skip_special_tokens=False)
print(tryoo[10])

# 5. Loss Function

In [97]:
# ##################Loss for G and Ds##############################
def discriminator_loss(netD, real_imgs, fake_imgs, conditions,
                       real_labels, fake_labels):
    # Forward
    real_features = netD(real_imgs)
    fake_features = netD(fake_imgs.detach())
    # loss
    #
    cond_real_logits = netD.COND_DNET(real_features, conditions)
    cond_real_errD = nn.BCELoss()(cond_real_logits, real_labels)
    cond_fake_logits = netD.COND_DNET(fake_features, conditions)
    cond_fake_errD = nn.BCELoss()(cond_fake_logits, fake_labels)
    #
    batch_size = real_features.size(0)
    cond_wrong_logits = netD.COND_DNET(real_features[:(batch_size - 1)], conditions[1:batch_size])
    cond_wrong_errD = nn.BCELoss()(cond_wrong_logits, fake_labels[1:batch_size])

    if netD.UNCOND_DNET is not None:
        real_logits = netD.UNCOND_DNET(real_features)
        fake_logits = netD.UNCOND_DNET(fake_features)
        real_errD = nn.BCELoss()(real_logits, real_labels)
        fake_errD = nn.BCELoss()(fake_logits, fake_labels)
        errD = ((real_errD + cond_real_errD) / 2. +
                (fake_errD + cond_fake_errD + cond_wrong_errD) / 3.)
    else:
        errD = cond_real_errD + (cond_fake_errD + cond_wrong_errD) / 2.
    return errD

# text, image, funny_score
def generator_loss(netsD, image_encoder, fake_imgs, real_labels,
                   words_embs, sent_emb, match_labels,
                   cap_lens, class_ids):
    # numDs = len(netsD)
    batch_size = real_labels.size(0)
    logs = ''
    # Forward
    errG_total = 0
    features = netsD[i](fake_imgs[i])
    cond_logits = netsD[i].COND_DNET(features, sent_emb)
    cond_errG = nn.BCELoss()(cond_logits, real_labels)
    if netsD[i].UNCOND_DNET is  not None:
        logits = netsD[i].UNCOND_DNET(features)
        errG = nn.BCELoss()(logits, real_labels)
        g_loss = errG + cond_errG
    else:
        g_loss = cond_errG
    errG_total += g_loss
    # err_img = errG_total.data[0]
    logs += 'g_loss%d: %.2f ' % (i, g_loss.data[0])

    # Ranking loss
    if i == (numDs - 1):
        # words_features: batch_size x nef x 17 x 17
        # sent_code: batch_size x nef
        region_features, cnn_code = image_encoder(fake_imgs[i])
        w_loss0, w_loss1, _ = words_loss(region_features, words_embs,
                                         match_labels, cap_lens,
                                         class_ids, batch_size)
        w_loss = (w_loss0 + w_loss1) * \
                 cfg.TRAIN.SMOOTH.LAMBDA
        # err_words = err_words + w_loss.data[0]

        s_loss0, s_loss1 = sent_loss(cnn_code, sent_emb,
                                     match_labels, class_ids, batch_size)
        s_loss = (s_loss0 + s_loss1) * \
                 cfg.TRAIN.SMOOTH.LAMBDA
        # err_sent = err_sent + s_loss.data[0]

        errG_total += w_loss + s_loss
        logs += 'w_loss: %.2f s_loss: %.2f ' % (w_loss.data[0], s_loss.data[0])
    return errG_total, logs