In [None]:
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader,random_split
import torchvision
from torchvision import datasets, models, transforms
import numpy as np
import matplotlib.pyplot as plt
import time
import os
import PIL
import pickle

In [None]:
# Use the GPU when torch reports one is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Prefix joined in front of every data path; empty means paths resolve against the current directory.
WORKING_PATH=""
# Fixed number of token positions per sentence: shorter text is padded, longer text is cut off.
TEXT_LENGTH=75
# Presumably the hidden size intended for the text LSTM — verify against the cells that build the model.
TEXT_HIDDEN=256

In [None]:
%cd /content/drive/MyDrive/Colab\ Notebooks/pytorch-multimodal_sarcasm_detection

/content/drive/MyDrive/Colab Notebooks/pytorch-multimodal_sarcasm_detection


## Build the data set of samples whose image file exists

In [None]:
def load_data():
  """Collect every annotated sample whose image file exists on disk.

  Reads ``text_data/train.txt``, ``text_data/test.txt`` and
  ``text_data/valid.txt`` (in that order) below ``WORKING_PATH``. Each
  non-blank line is evaluated as a Python expression and indexed as a
  sequence: element 0 is the image id as a string, element 1 the sentence,
  and the label is taken from index 2 for the train file and index 3 for the
  test/valid files.

  Returns:
    dict mapping ``int(image_id)`` to ``{"text": sentence, "group": label}``.
    Samples without ``image_data/<image_id>.jpg`` are left out; if an id
    occurs more than once, the entry read last wins.
  """
  # (split name, position of the label inside each parsed line)
  splits = (("train", 2), ("test", 3), ("valid", 3))
  data_set = dict()
  for split, group_index in splits:
    text_path = os.path.join(WORKING_PATH,"text_data/",split+".txt")
    # The context manager closes the handle even if a line fails to parse.
    with open(text_path,"rb") as file:
      for line in file:
        if not line.strip():
          # A blank (e.g. trailing) line would make eval() raise SyntaxError.
          continue
        # SECURITY: eval() runs whatever expression the file contains. Only
        # use annotation files from a trusted source (ast.literal_eval is the
        # safe alternative if every line is a plain literal).
        content = eval(line)
        image = content[0]
        sentence = content[1]
        group = content[group_index]
        if os.path.isfile(os.path.join(WORKING_PATH,"image_data/",image+".jpg")):
          data_set[int(image)]={"text":sentence,"group":group}

  return data_set

In [None]:
# Dict keyed by integer image id; each value holds the sample's "text" and "group".
data_set = load_data()

# Load all training data

In [None]:
def load_word_index():
  """Load the text vocabulary stored in ``text_embedding/vocab.pickle``.

  Returns:
    The unpickled object. The rest of this notebook uses it as a
    token -> index mapping that contains "<pad>" and "<unk>" entries.
  """
  vocab_path = os.path.join(WORKING_PATH,"text_embedding/vocab.pickle")
  # SECURITY: unpickling can execute arbitrary code; only load trusted files.
  # The context manager closes the handle, which the bare open() call did not.
  with open(vocab_path,"rb") as vocab_file:
    word2index = pickle.load(vocab_file,encoding="latin1")
  return word2index

In [None]:
def load_image_labels():
  """Load the per-image attribute words and the attribute vocabulary.

  Each non-blank line of ``multilabel_database/img_to_five_words.txt`` is
  evaluated as a Python expression and indexed as a sequence: element 0 is
  the image id and the remaining elements are that image's labels.

  Returns:
    (img2labels, label2index): ``img2labels`` maps ``int(image_id)`` to the
    labels of that image; ``label2index`` is the object unpickled from
    ``multilabel_database_embedding/vocab.pickle``, used elsewhere in this
    notebook as a label -> index mapping.
  """
  img2labels=dict()
  with open(os.path.join(WORKING_PATH,'multilabel_database/','img_to_five_words.txt'),'rb') as file:
    for line in file:
      if not line.strip():
        # A blank (e.g. trailing) line would make eval() raise SyntaxError.
        continue
      # SECURITY: eval() runs whatever expression the file contains; only
      # use label files from a trusted source.
      content=eval(line)
      img2labels[int(content[0])] = content[1:]
  # SECURITY: unpickling can execute arbitrary code; only load trusted files.
  # The context manager closes the handle, which the bare open() call did not.
  with open(os.path.join(WORKING_PATH,'multilabel_database_embedding/vocab.pickle'),'rb') as vocab_file:
    label2index=pickle.load(vocab_file)
  return img2labels, label2index

In [None]:
# Module-level lookups read by my_data_set when it builds attribute indices.
img2labels,label2index=load_image_labels()

In [None]:
# Module-level vocabulary read by my_data_set when it converts text to indices.
word2index = load_word_index()

# Build dataloader

In [None]:
# Wrap data_set in a Dataset so it can be fed to a DataLoader
class my_data_set(Dataset):
  """Dataset over the samples in ``data``, addressed by position in key order.

  ``data`` maps an image id to a dict holding at least "text" and "group".
  The constructor stores that same dict (it is not copied) and adds two
  entries to every sample: "image_path" and "text_index".

  Relies on the module-level names ``WORKING_PATH``, ``TEXT_LENGTH``,
  ``word2index``, ``img2labels`` and ``label2index``.
  """
  def __init__(self, data):
    self.data = data
    self.image_ids = list(data.keys())

    # One pass per sample: attach the image path and the fixed-length index tensor.
    for image_id in self.image_ids:
      entry = self.data[image_id]

      # add image path to data_set
      entry['image_path'] = os.path.join(WORKING_PATH,'image_data/',str(image_id)+'.jpg')

      # add text index to data_set: whitespace tokens, padded with "<pad>" or
      # cut off so that exactly TEXT_LENGTH positions are filled
      text = entry['text'].split()
      text_index = torch.empty(TEXT_LENGTH,dtype=torch.long)
      curr_length=len(text)
      for i in range(TEXT_LENGTH):
        if i >= curr_length:
          text_index[i]=word2index["<pad>"]
        elif text[i] in word2index:
          text_index[i] = word2index[text[i]]
        else:
          # out-of-vocabulary token
          text_index[i] = word2index['<unk>']
      entry["text_index"] = text_index

  # the image feature loader - per the original note, the stored ResNet-50 result
  def __image_feature_loader(self,id):
    """Return the array saved at ``image_feature_data/<id>.npy`` as a tensor."""
    attribute_feature = np.load(os.path.join(WORKING_PATH,"image_feature_data",str(id)+".npy"))
    return torch.from_numpy(attribute_feature)

  # the attribute index loader - the word labels attached to the image
  def __attribute_loader(self,id):
    """Return the image's labels translated to indices via ``label2index``."""
    labels = img2labels[id]
    label_index = [label2index[label] for label in labels]
    return torch.tensor(label_index)

  # the text index loader
  def __text_index_loader(self,id):
    """Return the index tensor built for this sample in ``__init__``."""
    return self.data[id]['text_index']

  def image_loader(self,id):
    """Open the sample's image, resize it to 448x448 and return it as a tensor."""
    path=self.data[id]['image_path']
    # The original bound this pipeline to `transfrm` but called `transform`,
    # so every call raised NameError.
    transform = transforms.Compose([transforms.Resize((448,448)),
                                    transforms.ToTensor()])
    # Close the image file as soon as the tensor has been produced.
    with PIL.Image.open(path) as img_pil:
      img_tensor = transform(img_pil)
    return img_tensor

  def text_loader(self,id):
    """Return the raw sentence of the sample."""
    return self.data[id]['text']
  
  def label_loader(self,id):
    """Return the raw (untranslated) labels of the image."""
    return img2labels[id]

  def __getitem__(self,index):
    """Return ``(text_index, image_feature, attribute_index, group, id)`` for position ``index``."""
    id = self.image_ids[index]
    text_index = self.__text_index_loader(id)
    image_feature = self.__image_feature_loader(id)
    attribute_index = self.__attribute_loader(id)
    group = self.data[id]['group']
    return text_index, image_feature, attribute_index, group, id

  def __len__(self):
    """Number of samples."""
    return len(self.image_ids)

In [None]:
# Wrap every loaded sample; note that my_data_set adds keys to the entries of data_set itself.
all_Data = my_data_set(data_set)

# Split data

In [None]:
def train_val_test_split(all_Data, train_fraction, val_fraction):
  """Randomly partition ``all_Data`` into train, validation and test subsets.

  The train and validation sizes are the fractions of ``len(all_Data)``
  truncated to whole samples; the test subset receives everything left over.
  The generator is seeded with 42, so the partition is the same on every call.

  Returns:
    The three subsets produced by ``random_split``, in train/val/test order.
  """
  total = len(all_Data)
  n_train = int(total*train_fraction)
  n_val = int(total*val_fraction)
  n_test = total - (n_train + n_val)
  seeded_generator = torch.Generator().manual_seed(42)
  return random_split(all_Data, [n_train, n_val, n_test], generator=seeded_generator)

# DataLoader Setup

In [None]:
# Shares of all_Data given to the train and validation subsets; the test subset gets the remainder.
train_fraction = 0.8
val_fraction = 0.1
# Number of samples per batch for the loaders built from these subsets.
batch_size = 32
train_set, val_set, test_set = train_val_test_split(all_Data, train_fraction, val_fraction)

In [None]:
# Only the training batches need a fresh random order every epoch.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
# Evaluation loaders keep a fixed order so per-batch results are reproducible
# across runs; shuffling them does not change aggregate metrics.
val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)

In [None]:
# Single-sample loader over the test subset; shuffled, so each draw is a random test example.
play_loader = DataLoader(test_set,batch_size=1,shuffle=True)

# Example of the data

In [120]:
# Print the shape and tensor type of each field in one training batch.
if __name__ == "__main__":
  # NOTE: the loop variable `id` shadows the builtin id() for the rest of the notebook.
  for text_index, image_feature, attribute_index, group, id in train_loader:
    print("text:", text_index.shape, text_index.type())
    print("image feature:",image_feature.shape,image_feature.type())
    print("attribute:", attribute_index.shape, attribute_index.type())
    print("group:", group, group.type())
    print("image id:", id, id.type())
    # One batch is enough for the demonstration.
    break    

text: torch.Size([32, 75]) torch.LongTensor
image feature: torch.Size([32, 196, 2048]) torch.FloatTensor
attribute: torch.Size([32, 5]) torch.LongTensor
group: tensor([0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1, 1,
        0, 0, 1, 0, 1, 0, 1, 0]) torch.LongTensor
image id: tensor([822224237905793024, 819329457823555584, 820679152785362945,
        874074139098853376, 822589640704192512, 820412718398210048,
        724359642034393088, 692513161967226880, 925082504532516864,
        822950519182336001, 719365080123580416, 823311358649307137,
        845792552288763905, 823318303179440129, 779248413447774208,
        820050464607784962, 823310708662222850, 819694597718888448,
        820418811283156993, 690633649985851392, 896370448950333442,
        916038986413543424, 839676220862267392, 702671588467077120,
        821504607650193408, 820055019718377472, 880274089881071616,
        822223585456558080, 722388056917999616, 820057662113140739,
        694661584761

# Text feature extraction

In [121]:
class ExtractTextFeature(torch.nn.Module):
    """Encode token-index sequences with pretrained embeddings and a BiLSTM.

    Args:
        text_length: number of token positions per sequence; the embedded
            input is viewed as (-1, text_length, embedding_size).
        hidden_size: hidden size of each LSTM direction.
        dropout_rate: dropout probability applied to the LSTM output.
        guidance_size: width of the optional guidance vector that seeds the
            LSTM state. Defaults to 200, the value that used to be hard-coded.
    """
    def __init__(self,text_length,hidden_size,dropout_rate=0.2,guidance_size=200):
        super(ExtractTextFeature, self).__init__()
        self.hidden_size=hidden_size
        self.text_length=text_length
        embedding_weight=self.getEmbedding()
        self.embedding_size=embedding_weight.shape[1]
        # from_pretrained keeps the embedding weights frozen by default.
        self.embedding=torch.nn.Embedding.from_pretrained(embedding_weight)
        # Take the LSTM input width from the loaded matrix instead of assuming
        # 200, so an embedding file of any width works.
        self.biLSTM=torch.nn.LSTM(input_size=self.embedding_size,hidden_size=hidden_size,bidirectional=True,batch_first=True)

        # early fusion: one projection per (state kind, direction) pair
        self.Linear_1=torch.nn.Linear(guidance_size,hidden_size)
        self.Linear_2=torch.nn.Linear(guidance_size,hidden_size)
        self.Linear_3=torch.nn.Linear(guidance_size,hidden_size)
        self.Linear_4=torch.nn.Linear(guidance_size,hidden_size)

        # dropout
        self.dropout=torch.nn.Dropout(dropout_rate)

    def forward(self, input, guidence):
        """Run the encoder.

        Args:
            input: integer token indices, viewable as (-1, text_length) once embedded.
            guidence: guidance tensor whose last dimension is ``guidance_size``,
                or None to start the LSTM from its default zero state.

        Returns:
            (RNN_state, output): ``output`` is the LSTM output after dropout
            and ``RNN_state`` is its mean over dimension 1.
        """
        embedded=self.embedding(input).view(-1, self.text_length, self.embedding_size)

        if guidence is not None:
            # early fusion: stack the two projections along dim 0, giving the
            # (2, batch, hidden_size) initial states a one-layer BiLSTM takes
            hidden_init=torch.stack([torch.relu(self.Linear_1(guidence)),torch.relu(self.Linear_2(guidence))],dim=0)
            cell_init=torch.stack([torch.relu(self.Linear_3(guidence)),torch.relu(self.Linear_4(guidence))],dim=0)
            output,_=self.biLSTM(embedded,(hidden_init,cell_init))
        else:
            output,_=self.biLSTM(embedded,None)

        # dropout
        output=self.dropout(output)

        RNN_state=torch.mean(output,1)
        return RNN_state,output

    def getEmbedding(self):
        """Load the space-delimited float32 matrix in ``text_embedding/vector.txt`` (relative to the current directory)."""
        return torch.from_numpy(np.loadtxt("text_embedding/vector.txt", delimiter=' ', dtype='float32'))

In [140]:
# Scratch instance: text_length=75, hidden_size=256 (constructing it reads text_embedding/vector.txt).
test = ExtractTextFeature(75,256)

In [142]:
# Shape of the pretrained embedding matrix (re-reads the file from disk).
test.getEmbedding().shape

torch.Size([12280, 200])

In [153]:
# Random stand-in for a batch of 32 guidance vectors of width 200.
x = torch.randn(32,200)
# NOTE(review): y is not used by any cell shown here.
y = torch.randn(32,75)

In [154]:
# Smoke-test ExtractTextFeature on one training batch and print the output shapes.
if __name__ == "__main__":
    # This notebook never imports a `LoadData` module, so the constants and
    # the loader defined in the cells above are used directly.
    test=ExtractTextFeature(TEXT_LENGTH, TEXT_HIDDEN)
    for text_index,image_feature,attribute_index,group,id in train_loader:
        # Random stand-in for the 200-wide guidance vector, sized to this
        # batch so the cell also works when a batch has fewer than 32 rows.
        guidance=torch.randn(text_index.shape[0],200)
        result,seq=test(text_index,guidance)
        print(result.shape)
        print(seq.shape)
        # One batch is enough for the demonstration.
        break

torch.Size([32, 512])
torch.Size([32, 75, 512])


In [129]:
class testnet(torch.nn.Module):
  """Minimal module: one linear projection from 200 features followed by ReLU.

  ``text_length`` and ``dropout_rate`` are accepted but not used.
  """
  def __init__(self,text_length,hidden_size,dropout_rate=0.2):
    super().__init__()
    self.Linear_1 = torch.nn.Linear(in_features=200,out_features=hidden_size)

  def forward(self, guidence):
    """Project ``guidence`` to ``hidden_size`` features and clamp negatives to zero."""
    projected = self.Linear_1(guidence)
    return torch.nn.functional.relu(projected)

In [134]:
# Check that testnet maps a (2, 200) input to (2, 256).
if __name__ == "__main__":
  # NOTE: rebinds the notebook-level names `test` and `x` used by earlier cells.
  test = testnet(75,256)
  x = torch.randn(2,200)
  result = test(x)
  print(result.shape)

torch.Size([2, 256])
