In [1]:
import os
import random
import re

import cv2
import numpy as np
import pandas as pd
import torch
#import torchvision
from sklearn.model_selection import train_test_split

In [3]:
# Load the MUStARD++ text annotations. All rows are kept on purpose: rows whose
# `Sarcasm` value is not 0/1 are consumed by the preprocessing cell as the dialogue
# context of the next labelled utterance, so they must not be dropped here.
df = pd.read_csv('mustard++_text.csv')

# Preprocess the data

In [4]:
# Group every labelled utterance with the unlabelled context lines that precede it.
# Rows whose `Sarcasm` value is 0.0/1.0 are target utterances; any other value
# (presumably NaN for context lines — TODO confirm against the CSV) marks a context
# line that is buffered until the next target utterance.
# Each record: 'key' = SCENE, 'image' = row KEY (used as the video file name later),
# 'text' = context lines followed by the utterance itself, 'label' = Sarcasm value.
data_dict = []

list_of_text = []
current_scene = None
for _, row in df.iterrows():
    # Start a fresh context buffer whenever the scene changes, so context lines of a
    # scene that has no labelled utterance cannot leak into the following scene.
    # Assumes context rows carry the same SCENE value as their target utterance.
    if row['SCENE'] != current_scene:
        list_of_text = []
        current_scene = row['SCENE']

    # Flatten multi-line sentences; the target utterance ends up as the last entry.
    list_of_text.append(row['SENTENCE'].replace('\n', ' '))

    if row['Sarcasm'] in [0.0, 1.0]:
        data_dict.append({'key': row['SCENE'],
                          'image': row['KEY'],
                          'text': list_of_text,
                          'label': row['Sarcasm']})
        list_of_text = []

In [5]:
# CSV labels arrive as floats (0.0 / 1.0); store them as ints in each record, in place.
for record in data_dict:
    record.update(label=int(record['label']))

In [6]:
def is_valid_frame(frame):
    """Return True if `frame` is a non-empty image array.

    `cv2.VideoCapture.read` yields `None` for a frame it could not decode, so both
    the `None` case and a zero-sized array are rejected.
    """
    return frame is not None and frame.size > 0


def _read_resized_frame(cam, frame_idx, total_frames, size, max_attempts=3):
    """Read frame `frame_idx` from `cam` and resize it to `size`; None if all attempts fail.

    A failed read is retried (at most `max_attempts` reads in total), each time on a
    neighbouring frame: one forward, or one back when already at the last frame.
    """
    for _ in range(max_attempts):
        cam.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cam.read()
        if ret and is_valid_frame(frame):
            return cv2.resize(frame, size, interpolation=cv2.INTER_LINEAR)
        frame_idx = frame_idx + 1 if frame_idx < total_frames - 1 else frame_idx - 1
    return None


def _sample_video_frames(video_path, n_frames, size):
    """Return `n_frames` randomly chosen frames of `video_path` in temporal order.

    The result is an array of shape (n_frames, height, width, channels) with
    (width, height) == `size`, in OpenCV's BGR channel order. Returns None when the
    reported frame count is below `n_frames` (presumably including files OpenCV could
    not open, whose frame count reads as 0) or when any chosen frame is unreadable.
    The capture handle is always released.
    """
    cam = cv2.VideoCapture(video_path)
    try:
        total_frames = int(cam.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames < n_frames:
            return None
        frames = []
        for frame_idx in sorted(random.sample(range(total_frames), n_frames)):
            frame = _read_resized_frame(cam, frame_idx, total_frames, size)
            if frame is None:
                # One corrupted frame disqualifies the whole clip; stop reading early.
                return None
            frames.append(frame)
        return np.array(frames)
    finally:
        cam.release()


failed_data_points = []  # paths of videos skipped (too short or unreadable frames)

videos = []
text = []
labels = []
ids = []

down_width = 384
down_height = 224
down_points = (down_width, down_height)  # cv2.resize expects (width, height)

num_frames = 16
# NOTE(review): random.sample uses the global `random` state, which this notebook
# does not seed, so the sampled frames change on every run.
for data in data_dict:
    video_path = 'videos/final_utterance_videos/' + data['image'] + '.mp4'
    video = _sample_video_frames(video_path, num_frames, down_points)
    if video is None:
        failed_data_points.append(video_path)
        continue

    videos.append(torch.from_numpy(video))
    text.append(data['text'])
    labels.append(data['label'])
    ids.append(data['key'])

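# Sample (show dependent - speaker independent) sets: BBT and SV scenes form the training set; the remaining shows are split 50/50 (stratified by label) into validation and test. Run only one of the three split sections, since each assigns the same train/val/test variables.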

In [None]:
# Scene ids belonging to the training shows (BBT, SV). One entry per CSV row, so a
# scene id is repeated once for every row (context line or utterance) it spans.
train_data = df.loc[df['SHOW'].isin(['BBT', 'SV']), 'SCENE'].tolist()

In [None]:
# Route each sample to the training split when its scene belongs to a training show;
# everything else goes to a pool that is divided into validation and test next.
train_scenes = set(train_data)  # O(1) lookups; train_data holds one entry per CSV row

train_video, train_text, train_label, train_id = [], [], [], []
rest_videos, rest_text, rest_labels, rest_ids = [], [], [], []

for sample_video, sample_text, sample_label, scene_id in zip(videos, text, labels, ids):
    if scene_id in train_scenes:
        train_video.append(sample_video)
        train_text.append(sample_text)
        train_label.append(sample_label)
        train_id.append(scene_id)
    else:
        rest_videos.append(sample_video)
        rest_text.append(sample_text)
        rest_labels.append(sample_label)
        rest_ids.append(scene_id)

In [None]:
# Split the non-training shows 50/50 into validation and test with equal label ratios.
# random_state pins the split so re-running the notebook reproduces it.
val_text, test_text, val_video, test_video, val_label, test_label, val_id, test_id = train_test_split(
    rest_text, rest_videos, rest_labels, rest_ids,
    test_size=0.5, stratify=rest_labels, random_state=42)

# Sample stratified (train, val, test) sets: an 80/10/10 split that preserves the label ratio in each set

In [None]:
# Keep 80% of all samples for training and hold out 20% (stratified by label) for
# validation + test. random_state pins the split so re-runs reproduce it.
train_text, temp_text, train_video, temp_video, train_label, temp_label, train_id, temp_id = train_test_split(
    text, videos, labels, ids,
    test_size=0.2, stratify=labels, random_state=42)

In [None]:
# Halve the held-out 20% into validation and test (10% / 10% overall), stratified by
# label. random_state pins the split so re-runs reproduce it.
val_text, test_text, val_video, test_video, val_label, test_label, val_id, test_id = train_test_split(
    temp_text, temp_video, temp_label, temp_id,
    test_size=0.5, stratify=temp_label, random_state=42)

# Sample random (train, val, test) sets: an unstratified 80/10/10 split

In [7]:
# Shuffled sample positions 0..len(labels)-1. A seeded Generator makes the permutation
# reproducible instead of depending on the unseeded global NumPy RNG.
index = np.random.default_rng(42).permutation(len(labels))

In [8]:
# 80% of positions for training, the remaining 20% halved into validation and test
# (80/10/10 overall). Not stratified by label; random_state pins both splits.
train, val_test = train_test_split(index, test_size=0.2, random_state=42)
val, test = train_test_split(val_test, test_size=0.5, random_state=42)

In [None]:
# Gather every modality (text, video, label, id) of each split from the sample
# positions chosen above; order inside a split follows those positions.
test_text, test_video, test_label, test_id = (
    [seq[pos] for pos in test] for seq in (text, videos, labels, ids)
)
train_text, train_video, train_label, train_id = (
    [seq[pos] for pos in train] for seq in (text, videos, labels, ids)
)
val_text, val_video, val_label, val_id = (
    [seq[pos] for pos in val] for seq in (text, videos, labels, ids)
)

# Save the data

In [10]:
# Persist each split/modality to its own file, e.g. preprocessed/video_train.pt.
# torch.save does not create missing directories, so make sure the folder exists.
os.makedirs("preprocessed", exist_ok=True)

splits = {
    "train": (train_video, train_text, train_label, train_id),
    "val": (val_video, val_text, val_label, val_id),
    "test": (test_video, test_text, test_label, test_id),
}
for split_name, (split_video, split_text, split_label, split_id) in splits.items():
    torch.save(split_video, f"preprocessed/video_{split_name}.pt")
    torch.save(split_text, f"preprocessed/text_{split_name}.pt")
    torch.save(split_label, f"preprocessed/labels_{split_name}.pt")
    torch.save(split_id, f"preprocessed/ids_{split_name}.pt")

In [None]:
# Persist each split/modality with a "_2" suffix, e.g. preprocessed/video_train_2.pt.
# NOTE(review): this saves the same train/val/test variables as the other save cell;
# only the file suffix differs, so the content depends on which split section ran last.
# torch.save does not create missing directories, so make sure the folder exists.
os.makedirs("preprocessed", exist_ok=True)

splits = {
    "train": (train_video, train_text, train_label, train_id),
    "val": (val_video, val_text, val_label, val_id),
    "test": (test_video, test_text, test_label, test_id),
}
for split_name, (split_video, split_text, split_label, split_id) in splits.items():
    torch.save(split_video, f"preprocessed/video_{split_name}_2.pt")
    torch.save(split_text, f"preprocessed/text_{split_name}_2.pt")
    torch.save(split_label, f"preprocessed/labels_{split_name}_2.pt")
    torch.save(split_id, f"preprocessed/ids_{split_name}_2.pt")