In [34]:
# adapted from https://github.com/huggingface/transformers/blob/master/examples/contrib/mm-imdb/utils_mmimdb.py
import json
import os
from collections import Counter

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

# Maps a requested number of image embeddings to the 2-D output size handed
# to the adaptive pooling layer; each pair multiplies out to its key.
POOLING_BREAKDOWN = {
    1: (1, 1),
    2: (2, 1),
    3: (3, 1),
    4: (2, 2),
    5: (5, 1),
    6: (3, 2),
    7: (7, 1),
    8: (4, 2),
    9: (3, 3),
}
class ImageEncoder(nn.Module):
    """Turns an image batch into a short sequence of ResNet-152 feature vectors.

    The backbone is a pretrained torchvision ResNet-152 with its final two
    child modules dropped; its output is pooled to the grid that
    ``POOLING_BREAKDOWN`` assigns to ``args.num_image_embeds``.
    """

    def __init__(self, args):
        """Build the truncated backbone and the pooling layer.

        Args:
            args: object exposing ``num_image_embeds``; the value must be a key
                of ``POOLING_BREAKDOWN``, otherwise a ``KeyError`` is raised.
        """
        super().__init__()
        resnet = torchvision.models.resnet152(pretrained=True)
        # Keep every child module except the last two.
        backbone_layers = list(resnet.children())[:-2]
        self.model = nn.Sequential(*backbone_layers)
        grid_size = POOLING_BREAKDOWN[args.num_image_embeds]
        self.pool = nn.AdaptiveAvgPool2d(grid_size)

    def forward(self, x):
        """Encode ``x`` and return the pooled features with channels last.

        Shape flow: Bx3x224x224 -> Bx2048x7x7 -> Bx2048xN -> BxNx2048
        """
        feature_map = self.model(x)
        pooled = self.pool(feature_map)
        # Merge the two spatial axes into one axis of N positions.
        per_position = pooled.flatten(start_dim=2)
        return per_position.transpose(1, 2).contiguous()  # BxNx2048
    
class JsonlDataset(Dataset):
    """Dataset backed by a JSON-lines file of text/image/label records.

    Each non-blank line of ``data_path`` is parsed as a JSON object; the
    methods below read its ``"text"``, ``"img"`` and ``"label"`` keys. Image
    paths are resolved relative to the directory containing ``data_path``.

    Args:
        data_path: path to the ``.jsonl`` file.
        tokenizer: object with an ``encode(text, add_special_tokens=...,
            max_length=...)`` method returning token ids. May be ``None`` when
            only images are requested.
        transforms: callable applied to each PIL image (unused when
            ``use_transformed_tensors`` is true).
        labels: ordered list of label names; a record's labels are mapped to
            positions in this list.
        max_seq_length: token budget *excluding* the two special tokens the
            tokenizer adds (``load_examples`` passes
            ``args.max_seq_length - 2``).
        image_only: if true, items contain only image, id and label.
        text_only: if true (and ``image_only`` is false), items contain only
            sentence and label.
        use_transformed_tensors: if true, load a saved ``<id>.pt`` tensor
            instead of decoding and transforming the image file.
    """

    def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length,
                 image_only=False, text_only=False, use_transformed_tensors=False):
        # Use a context manager so the handle is closed promptly, and skip
        # blank lines so a stray empty line does not crash json.loads.
        with open(data_path) as jsonl_file:
            self.data = [json.loads(line) for line in jsonl_file if line.strip()]
        self.data_dir = os.path.dirname(data_path)
        self.tokenizer = tokenizer
        self.labels = labels
        self.n_classes = len(labels)
        self.max_seq_length = max_seq_length

        self.transforms = transforms

        self.image_only = image_only
        self.text_only = text_only
        self.use_transformed_tensors = use_transformed_tensors

        # NOTE(review): duplicates the list returned by get_mmimdb_labels() and
        # is not read anywhere in this class; kept for backward compatibility.
        self.classes = [
            "Crime",
            "Drama",
            "Thriller",
            "Action",
            "Comedy",
            "Romance",
            "Documentary",
            "Short",
            "Mystery",
            "History",
            "Family",
            "Adventure",
            "Fantasy",
            "Sci-Fi",
            "Western",
            "Horror",
            "Sport",
            "War",
            "Music",
            "Musical",
            "Animation",
            "Biography",
            "Film-Noir",
        ]

    def __len__(self):
        """Return the number of records read from the file."""
        return len(self.data)

    def get_sentence(self, index):
        """Tokenize the ``"text"`` field of record ``index``.

        Returns:
            A ``torch.LongTensor`` of token ids, special tokens included.
        """
        # The cap used to come from a notebook-level ``args`` global, which
        # made this method fail with NameError unless that variable happened
        # to exist. ``self.max_seq_length`` excludes the two special tokens
        # (assumes a start/end token pair, e.g. BERT -- TODO confirm for other
        # tokenizers), so add them back to keep the same overall length cap.
        token_ids = self.tokenizer.encode(
            self.data[index]["text"],
            add_special_tokens=True,
            max_length=self.max_seq_length + 2,
        )
        return torch.LongTensor(token_ids)

    def get_image(self, index):
        """Return the image for record ``index``.

        With ``use_transformed_tensors`` the tensor saved at
        ``<data_dir>/<id>.pt`` is loaded, where ``<id>`` is the ``"img"`` value
        up to its first ``'.'``. Otherwise the image file is opened, converted
        to RGB and passed through ``self.transforms``.
        """
        if self.use_transformed_tensors:
            id_ = self.data[index]["img"].split('.')[0]
            # NOTE: torch.load unpickles the file; only use with trusted,
            # locally generated tensors.
            return torch.load(os.path.join(self.data_dir, '{}.pt'.format(id_)))

        image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
        image = self.transforms(image)
        return image

    def __getitem__(self, index):
        """Build the sample dict for record ``index``.

        The label is a multi-hot vector of length ``n_classes``. Keys returned:
        ``image``/``id``/``label`` when ``image_only``; ``sentence``/``label``
        when ``text_only``; otherwise ``sentence``/``image``/``label``.

        Raises:
            ValueError: if the record has a label that is not in ``self.labels``.
        """
        label = torch.zeros(self.n_classes)
        label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1

        if self.image_only:
            image = self.get_image(index)
            id_ = self.data[index]["img"].split('.')[0]
            return {"image": image, "id": id_, "label": label}
        elif self.text_only:
            sentence = self.get_sentence(index)
            return {"sentence": sentence, "label": label}

        sentence = self.get_sentence(index)
        image = self.get_image(index)

        return {
            "sentence": sentence,
            "image": image,
            "label": label,
        }

    def get_label_frequencies(self):
        """Count how many records carry each label name.

        Returns:
            A ``collections.Counter`` keyed by label string.
        """
        label_freqs = Counter()
        for row in self.data:
            label_freqs.update(row["label"])
        return label_freqs

def get_mmimdb_labels():
    """Return a new list holding the 23 genre label names, in fixed order."""
    genres = (
        "Crime", "Drama", "Thriller", "Action", "Comedy", "Romance",
        "Documentary", "Short", "Mystery", "History", "Family", "Adventure",
        "Fantasy", "Sci-Fi", "Western", "Horror", "Sport", "War",
        "Music", "Musical", "Animation", "Biography", "Film-Noir",
    )
    # Hand back a fresh list each call so callers can mutate it freely.
    return list(genres)


def get_image_transforms():
    """Build the image preprocessing pipeline.

    Steps, in order: resize to 256, center-crop to 224, convert to a tensor,
    then normalize each channel with the fixed mean/std constants below.
    """
    normalize = transforms.Normalize(
        mean=[0.46777044, 0.44531429, 0.40661017],
        std=[0.12221994, 0.12145835, 0.14380469],
    )
    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ]
    return transforms.Compose(steps)
    
def load_examples(args, tokenizer, split='train'):
    """Create the ``JsonlDataset`` for one data split.

    Args:
        args: configuration object; ``data_dir``, ``max_seq_length``,
            ``image_only``, ``text_only`` and ``use_transformed_tensors`` are read.
        tokenizer: forwarded unchanged to the dataset.
        split: file stem of the split; ``<data_dir>/<split>.jsonl`` is loaded.

    Returns:
        The constructed ``JsonlDataset``.
    """
    jsonl_path = os.path.join(args.data_dir, '{}.jsonl'.format(split))
    image_pipeline = get_image_transforms()
    label_names = get_mmimdb_labels()
    # Two positions are subtracted from the sequence budget before it is
    # handed to the dataset.
    text_budget = args.max_seq_length - 2
    return JsonlDataset(
        jsonl_path,
        tokenizer,
        image_pipeline,
        label_names,
        text_budget,
        image_only=args.image_only,
        text_only=args.text_only,
        use_transformed_tensors=args.use_transformed_tensors,
    )

class Args:
    """Plain configuration holder for the notebook's data/model settings.

    Every constructor argument is stored as an attribute of the same name.
    ``seed``, ``n_gpu``, ``do_lower_case`` and ``dropout_prob`` are fixed
    values, and ``num_labels`` is the length of ``get_mmimdb_labels()``.
    """

    def __init__(self,
                 data_dir='/home/miaortizma/work/datanfs/mmimdb/dataset',
                 # NOTE(review): '/homee/' looks like a typo for '/home/' --
                 # left unchanged because the default is part of the interface.
                 output_dir='/homee/miaortizma/work/datanfs/home/miaortizma/mm_output',
                 model_name_or_path='my_model',
                 tokenizer_name='bert-base-uncased',
                 num_image_embeds=1,
                 max_seq_length=128,
                 image_only=False,
                 text_only=False,
                 per_gpu_train_batch_size=8,
                 num_train_epochs=100,
                 patience=5,
                 gradient_accumulation_steps=20,
                 use_transformed_tensors=False,
                ):

        self.data_dir = data_dir
        self.output_dir = output_dir

        self.model_name_or_path = model_name_or_path
        self.tokenizer_name = tokenizer_name

        self.num_image_embeds = num_image_embeds
        self.max_seq_length = max_seq_length

        # Previously accepted but silently discarded, so reading
        # ``args.per_gpu_train_batch_size`` raised AttributeError.
        self.per_gpu_train_batch_size = per_gpu_train_batch_size
        self.num_train_epochs = num_train_epochs
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.patience = patience

        self.image_only = image_only
        self.text_only = text_only
        self.use_transformed_tensors = use_transformed_tensors

        self.seed = 0
        self.n_gpu = 1
        self.do_lower_case = True
        #self.cache_dir=None
        #self.max_steps=-1
        self.dropout_prob = 0.5
        #self.weight_decay=0.0
        #self.learning_rate=5e-5
        #self.adam_epsilon=1e-8
        #self.max_grad_norm=1.0
        #self.warmup_steps=0
        #self.num_workers=0
        self.num_labels = len(get_mmimdb_labels())


In [None]:
import random
import numpy as np

def set_seed(args):
    """Seed Python's ``random``, NumPy and PyTorch from ``args.seed``.

    When ``args.n_gpu`` is positive, the CUDA generators are seeded as well.
    """
    seed = args.seed
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(seed)

In [None]:
def collate_fn_image_only(batch):
    """Collate image-only samples into ``(images, targets)``.

    Each sample's ``"image"`` and ``"label"`` tensors are stacked along a new
    leading batch dimension.
    """
    images = torch.stack([sample["image"] for sample in batch])
    targets = torch.stack([sample["label"] for sample in batch])
    return images, targets

def collate_fn_text_only(batch):
    """Collate text-only samples into ``((text, mask), targets)``.

    Sentences are right-padded with zeros to the longest one in the batch;
    ``mask`` holds 1 at real token positions and 0 at padding. Both are
    ``torch.long`` tensors of shape ``(batch, longest)``.
    """
    sentences = [sample["sentence"] for sample in batch]
    longest = max([len(tokens) for tokens in sentences])
    padded_shape = (len(batch), longest)

    text_tensor = torch.zeros(padded_shape, dtype=torch.long)
    mask_tensor = torch.zeros(padded_shape, dtype=torch.long)
    for row, tokens in enumerate(sentences):
        n_tokens = len(tokens)
        text_tensor[row, :n_tokens] = tokens
        mask_tensor[row, :n_tokens] = 1

    targets = torch.stack([sample["label"] for sample in batch])
    return (text_tensor, mask_tensor), targets

def collate_fn_modal(batch):
    """Collate text+image samples into ``((text, mask, images), targets)``.

    Text is right-padded with zeros to the longest sentence in the batch and
    ``mask`` marks real tokens with 1; images and labels are stacked along a
    new leading batch dimension.
    """
    token_rows = [item["sentence"] for item in batch]
    lengths = [len(tokens) for tokens in token_rows]
    n_rows, n_cols = len(batch), max(lengths)

    text_tensor = torch.zeros(n_rows, n_cols, dtype=torch.long)
    mask_tensor = torch.zeros(n_rows, n_cols, dtype=torch.long)
    for row in range(n_rows):
        stop = lengths[row]
        text_tensor[row, :stop] = token_rows[row]
        mask_tensor[row, :stop] = 1

    img_tensor = torch.stack([item["image"] for item in batch])
    tgt_tensor = torch.stack([item["label"] for item in batch])

    return (text_tensor, mask_tensor, img_tensor), tgt_tensor


In [42]:
import torch
from tqdm.notebook import tqdm, trange

def transform_tensors(split):
    """Save every image of ``split`` to disk as a ``.pt`` tensor file.

    An image-only dataset is built with the default ``Args``; each sample's
    ``image`` is written to ``<data_dir>/<id>.pt`` via ``torch.save``, with a
    tqdm progress bar over the dataset.
    """
    config = Args(image_only=True)
    # No tokenizer is passed: the image-only dataset is loaded with None.
    image_dataset = load_examples(config, None, split=split)
    for sample in tqdm(image_dataset):
        file_name = '/{}.pt'.format(sample['id'])
        torch.save(sample['image'], config.data_dir + file_name)
        
#transform_tensors('train')

HBox(children=(FloatProgress(value=0.0, max=15552.0), HTML(value='')))




In [23]:
#transform_tensors('val')

HBox(children=(FloatProgress(value=0.0, max=2608.0), HTML(value='')))




In [29]:
#transform_tensors('test')

HBox(children=(FloatProgress(value=0.0, max=7799.0), HTML(value='')))




