In [10]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import time
import tqdm
from tqdm.notebook import tqdm as notebooktqdm
import matplotlib.pyplot as plt
import pickle
from sklearn.model_selection import train_test_split
from PIL import Image
import timm
from timm.layers import BatchNormAct2d
import os
from encoding_module import EcodingModel

In [11]:
# work place
work_dir = './'
os.chdir(work_dir)

In [12]:
class YoutubeDataset(Dataset):
    def __init__(self, data, doc2vec):
        self.ids = list(data['video_id'])
        self.titles = doc2vec # pretrained doc2vec features
        self.data = data # video_id, metadata, views(y) from csv file
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.image_encoder = timm.create_model('efficientnet_b1_pruned', features_only =True, pretrained=True)
        model = timm.create_model('efficientnet_b1_pruned', pretrained=True)
        data_cfg = timm.data.resolve_data_config(model.pretrained_cfg)
        self.transform = timm.data.create_transform(**data_cfg)

    def __len__(self):
        return len(self.ids)

    def __getitem__(self, idx):
        # index order of video_id, meta, y are same
        video_id = self.ids[idx]
        
        image = Image.open( work_dir+'medium_15287/{}.jpg'.format(video_id))
        image = self.transform(image)
        # image = torch.FloatTensor(np.array(image)).permute(2, 0, 1).unsqueeze(dim=0)
        self.image_encoder.eval()
        feature_map = self.image_encoder(torch.unsqueeze(image,0))[-1].squeeze() # (320,6,10)
        
        title = self.titles[video_id] # get video title
        title = torch.FloatTensor(np.array(title, dtype=np.float16))
        
        meta = torch.FloatTensor(self.data[['period_day', 'subscriber_count']].to_numpy()[idx]) # get metadata
        
        y = np.log10(self.data['views'].to_numpy() + 1) # add 1 for zero views
        y = np.expand_dims(y, axis=1) # add batch dimension
        y = torch.FloatTensor(y[idx]) # get log10(views+1) by idx value
        
        return video_id, feature_map, title, meta, y

In [13]:
# add nomarlizing
data = pd.read_csv('./train.csv')
mean_period = data['period_day'].mean()
std_period = data['period_day'].std()
mean_sub = data['subscriber_count'].mean()
std_sub = data['subscriber_count'].std()
print(mean_period, std_period, mean_sub, std_sub)

data['period_day'] = (data['period_day'] - mean_period)/std_period
data['subscriber_count'] = (data['subscriber_count']-mean_sub)/std_sub

train_data, valid_data = train_test_split(data, test_size = 0.1, random_state = 55)
test_data = pd.read_csv('./test.csv')
# train_data = train_data[:1000]
# valid_data = valid_data[:100]
# test_data = test_data[:100]
print('Train Dataset Size : ',len(train_data))
print('Validation Dataset Size : ',len(valid_data))
print('Test Dataset Size : ',len(test_data))

data.head()

335.8148713475796 497.7613157973895 1784323.5617822357 3833786.6144638904
Train Dataset Size :  1000
Validation Dataset Size :  100
Test Dataset Size :  100


Unnamed: 0.1,Unnamed: 0,video_id,publish_time,publish_date,channel_id,title,views,period_day,channel_title,subscriber_count,...,dislikes,comment_count,description,desc_len,len_title,No_tags,video_error_or_removed,trending_date,comments_disabled,ratings_disabled
0,1636,uGOskK94nPU,14:00:35,2023-01-18,UCPqyMgj9n1GxSU-RyCjqLPA,"Küçük, Orta ve Büyük Tabak Meydan Okumasi | Ye...",2005786,-0.457679,Multi DO Turkish,-0.405949,...,False,116.0,"Bir sürü yiyecek elbette harikadır, ancak daha...",318.0,75.0,0.0,False,0.0,False,False
1,4459,6Xu_RjV0Wjo,11:00:03,2023-05-05,UCsRNwIyd1WnjLR89mWtrC-A,일산 분위기 좋은 카페 찾으시나요? 여기어때요!,93,-0.672641,저두영 jodooyoung,-0.465411,...,False,0.0,일산 분위기 좋은 카페 좋아하세요? 저두영 \n분위기가 좋은 일산 대형 베이커리 카...,360.0,26.0,0.0,False,0.0,False,False
2,5168,MPU7iOsUJtI,08:30:04,2023-04-28,UC5GHkCa1PhGycqnEbp50KrQ,[와차밥] 차돌박이 된장국수 오이김치 고추김치 요리 먹방 Soybean Paste ...,87779,-0.658578,버들Buddle,-0.389517,...,False,74.0,🎵Music provided by 브금대통령\n🎵Track : Candy - htt...,70.0,98.0,0.0,False,0.0,False,False
3,14076,GZh1Qkr7uoA,08:12:19,2015-03-11,UCdwLMpZeFXN5ODyCO9warSw,BJ애봉이먹방 깐풍새우 & 고추잡채 & 불짜장 & 볶음짬뽕,120826,5.308137,애봉이,-0.441815,...,False,38.0,(미국NPR 공영라디오방송취재중) USA NPR radio 좋아요 와 구독하기 ...,201.0,72.0,0.0,False,0.0,False,False
4,13448,4omqsDYxQGw,01:00:00,2020-03-13,UC9d1Mz9bzCE-t_t0lPnrjPA,Fettuccine Alfredo Mukbang,233562,1.633685,BenDeen,-0.134156,...,False,580.0,Check out my Instagram: https://www.instagram....,219.0,26.0,0.0,False,0.0,False,False


In [14]:
# open doc2vec data and conver to dict
with open('./title_doc2vec_10', 'rb') as f:
    doc2vec = pickle.load(f)

data_dict=dict()
for row in doc2vec:
    vid=row[0]
    vec=row[1:]
    data_dict[vid]=vec

doc2vec = data_dict
print(len(doc2vec))

15287


In [16]:
train_dataset = YoutubeDataset(train_data, doc2vec)
valid_dataset = YoutubeDataset(valid_data, doc2vec)
test_dataset = YoutubeDataset(test_data, doc2vec)

train_loader = DataLoader(train_dataset, batch_size = batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size = 1)
test_loader = DataLoader(test_dataset, batch_size = 1)