# Import & Settings

### Import

In [2]:
from dateutil.relativedelta import relativedelta

import numpy as np
from tqdm import tqdm

import pandas as pd; pd.set_option("display.max_columns", None)
from sklearn.preprocessing import MinMaxScaler

import torch
from torchvision import transforms
from PIL import Image

### Settings

In [3]:
batch_size = 128

# Read and Process Data

### Read data

In [4]:
df_train = pd.read_csv("../visuelle/train.csv", parse_dates=["release_date"])
g_trend = pd.read_csv("../visuelle/gtrends.csv", parse_dates=["date"])

In [5]:
df_train.head()

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,external_code,season,category,release_date,day,week,month,year,image_path,color,fabric,extra
0,0.004695,0.073239,0.061972,0.066667,0.046009,0.043192,0.026291,0.019718,0.012207,0.014085,0.010329,0.00939,1,SS17,long sleeve,2016-12-01,0.5,0.923077,1.0,0.998514,PE17/00001.png,yellow,acrylic,hem
1,0.005634,0.109859,0.128638,0.135211,0.082629,0.098592,0.06385,0.052582,0.034742,0.138967,0.159624,0.055399,2,SS17,long sleeve,2016-12-01,0.5,0.923077,1.0,0.998514,PE17/00002.png,brown,acrylic,hem
2,0.002817,0.207512,0.177465,0.095775,0.041315,0.030047,0.015023,0.006573,0.010329,0.005634,0.002817,0.001878,3,SS17,culottes,2016-12-02,0.666667,0.923077,1.0,0.998514,PE17/00003.png,blue,scuba crepe,hem
3,0.000939,0.044131,0.046948,0.041315,0.028169,0.031925,0.031925,0.023474,0.016901,0.028169,0.020657,0.00939,4,SS17,long sleeve,2016-12-02,0.666667,0.923077,1.0,0.998514,PE17/00004.png,yellow,acrylic,sleeveless
4,0.006573,0.098592,0.125822,0.120188,0.068545,0.046948,0.043192,0.034742,0.030047,0.029108,0.033803,0.00939,5,SS17,long sleeve,2016-12-02,0.666667,0.923077,1.0,0.998514,PE17/00005.png,grey,acrylic,hem


### Process data

In [194]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data, g_trend):
        self.data = data.iloc[:, :12].values
        
        self.sales_data = self.get_sales_data(data)
        self.gtrend_data = self.get_gtrend_data(data, g_trend)

    def get_sales_data(self, data):
        return data.iloc[:, :12].values
    
    def get_gtrend_data(self, data, g_trend):
        release_date = data["release_date"]
        gtrend_start_date = release_date.apply(lambda x: x - relativedelta(weeks=52))
        display(g_trend)
        

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)

dataset = Dataset(df_train, g_trend)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False)
next(iter(dataloader))
""

Unnamed: 0,date,long sleeve,culottes,miniskirt,short sleeves,printed,short cardigan,solid colours,trapeze dress,sleeveless,long cardigan,sheath dress,short coat,medium coat,doll dress,long dress,shorts,long coat,tracksuit,drop sleeve,patterned,kimono dress,medium cardigan,shirt dress,maxi,capris,gitana skirt,long duster,yellow,brown,blue,grey,green,black,red,white,orange,violet,acrylic,scuba crepe,tulle,angora,faux leather,georgette,lurex,nice,crepe,satin cotton,silky satin,fur,matte jersey,plisse,velvet,lace,cotton,piquet,plush,bengaline,jacquard,frise,technical,cady,dark jeans,light jeans,ity,plumetis,polyviscous,dainetto,webbing,foam rubber,chanel,marocain,macrame,embossed,heavy jeans,nylon,tencel,paillettes,chambree,chine crepe,muslin cotton or silk,linen,tactel,viscose twill,cloth,mohair,mutton,scottish,milano stitch,devore,hron,ottoman,fluid,flamed,fluid polyviscous,shiny jersey,goose
0,2015-10-05,55.0,41.0,66.0,58.0,58.0,22.0,24.0,43.0,66.0,74.0,42.0,44.0,10.0,73.0,70.0,38.0,36.0,30.0,11.0,75.0,25.0,0.0,78.6,63.0,37.0,0.0,24.0,83,66,77,18,79,25,61,77,70,56,57,18,81,65,50,25,23,20,35,61,10,54,35,26,54,70,68,33,35,13,73,80,89,29,50,63,62,20,-1.0,0,55,46,73,81,54,56,55,86,36,18,0,20,-1.0,52,33,0,82,62,45,57,0,29,17,57,80,49,-1.0,0,34
1,2015-10-12,55.0,44.0,73.0,71.0,61.0,31.0,37.0,21.2,63.0,59.0,52.0,47.0,29.0,85.0,72.0,37.0,33.0,27.0,33.0,83.0,40.0,0.0,75.0,64.0,32.0,0.0,14.0,80,66,95,18,76,26,62,77,71,57,58,0,86,63,52,31,21,18,37,37,59,58,70,38,54,68,69,23,38,27,76,76,89,26,56,66,72,30,-1.0,0,54,72,72,80,56,74,32,85,39,12,0,27,-1.0,55,48,0,82,66,43,54,0,17,14,61,80,41,-1.0,0,39
2,2015-10-19,60.0,42.0,73.0,62.0,61.0,43.0,37.0,27.4,47.0,71.0,37.0,46.0,29.0,100.0,79.0,36.0,39.0,31.0,55.0,71.0,39.0,0.0,68.8,66.0,23.0,0.0,34.0,83,66,92,18,76,27,62,77,67,58,59,0,81,70,52,27,19,18,43,30,29,59,52,33,53,100,68,26,38,17,86,62,81,27,59,66,63,30,-1.0,0,57,43,69,83,48,66,23,86,32,14,30,47,-1.0,54,35,48,81,63,46,55,0,21,23,63,79,52,-1.0,33,41
3,2015-10-26,57.0,44.0,68.0,51.0,55.0,19.0,37.0,59.8,54.0,64.0,38.0,43.0,43.0,96.0,74.0,36.0,42.0,31.0,22.0,73.0,42.0,17.0,71.8,65.0,23.0,0.0,43.0,82,67,70,17,74,28,62,76,69,52,62,17,95,70,60,32,20,17,44,24,39,63,35,33,54,73,68,30,39,13,73,76,83,33,66,53,72,30,-1.0,0,54,72,69,81,46,63,46,85,38,16,0,27,-1.0,56,29,0,84,61,44,55,0,26,14,58,76,33,-1.0,32,37
4,2015-11-02,49.0,39.0,66.0,39.0,63.0,16.0,37.0,44.0,51.0,67.0,49.0,46.0,0.0,59.0,70.0,34.0,35.0,32.0,22.0,71.0,37.0,34.0,64.2,67.0,21.0,0.0,34.0,76,66,69,17,77,32,60,77,70,56,57,0,66,66,55,33,19,18,35,55,39,63,0,36,58,66,72,27,39,10,78,78,83,32,53,56,72,25,-1.0,0,66,64,69,91,48,55,47,85,34,20,0,27,-1.0,61,33,0,85,61,46,55,0,18,18,60,80,37,-1.0,98,38
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
215,2019-11-18,91.0,53.0,37.0,66.0,72.0,66.0,30.0,10.8,67.0,79.0,40.0,73.0,59.0,67.0,94.0,39.0,89.0,84.0,63.0,86.0,47.0,0.0,79.8,75.0,23.0,0.0,75.0,86,66,79,17,80,47,57,82,48,57,81,43,70,60,89,43,77,18,41,50,24,81,14,62,64,63,88,35,71,71,84,60,77,25,83,63,84,41,-1.0,0,64,48,76,56,89,87,64,76,68,38,0,27,-1.0,80,63,0,94,82,57,63,0,19,19,86,92,49,-1.0,27,60
216,2019-11-25,94.0,45.0,35.0,56.0,67.0,35.0,50.0,25.0,63.0,83.0,46.0,83.0,51.0,78.0,91.0,44.0,92.0,100.0,45.0,79.0,45.0,14.0,81.8,81.0,22.0,0.0,47.0,84,69,80,19,91,98,58,86,50,59,84,100,73,63,93,51,78,18,37,54,32,86,0,72,69,65,93,25,89,95,88,63,65,18,83,93,66,41,-1.0,0,48,56,86,54,77,91,25,75,87,49,49,27,-1.0,79,84,0,98,84,57,63,0,22,25,78,83,30,-1.0,80,61
217,2019-12-02,100.0,47.0,28.0,73.0,71.0,59.0,20.0,9.4,72.0,86.0,43.0,68.0,44.0,61.0,94.0,45.0,100.0,93.0,73.0,89.0,40.0,41.0,78.8,78.0,24.0,0.0,52.0,83,62,79,16,75,33,59,85,50,68,92,29,75,65,100,43,100,18,38,25,24,89,14,63,69,65,94,25,100,72,96,62,72,25,83,100,83,33,-1.0,0,76,69,87,53,74,77,32,80,84,42,25,6,-1.0,75,70,0,95,93,56,66,31,21,15,87,100,34,-1.0,0,61
218,2019-12-09,93.0,47.0,40.0,49.0,72.0,47.0,31.0,10.8,75.0,87.0,43.0,64.0,100.0,52.0,83.0,45.0,90.0,89.0,37.0,76.0,46.0,42.0,81.2,75.0,26.0,0.0,40.0,81,62,80,15,77,25,60,93,48,58,91,29,74,61,88,43,94,17,37,41,33,82,29,71,66,60,95,25,87,92,84,64,72,21,77,79,68,50,-1.0,74,54,59,88,59,71,87,32,73,62,54,50,22,-1.0,69,99,40,93,62,57,83,0,18,19,79,86,22,-1.0,27,59


''

In [190]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, data):
        gtrend, image_feature = [], []
        image_transformer = transforms.Compose([
                    transforms.Resize((224,224)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
                ]) # Transform image based on ImageNet standard
        
        for (idx, prod) in tqdm(data.iterrows(), total=data.shape[0]):
            # Deal with categorical variables
            prod_release_date, prod_category, prod_color, prod_fabric = prod["release_date"], prod["category"], prod["color"], prod["fabric"]    
            prod_gtrend_start = prod_release_date - relativedelta(weeks=52) # Gtrend data from 1 year(52 weeks) ago
            prod_gtrend = g_trend[(g_trend["date"] > prod_gtrend_start)
                                &(g_trend["date"] <= prod_release_date)]

            # Obtain muti_gtrends
            category_gtrend = prod_gtrend[prod_category].values
            color_gtrend = prod_gtrend[prod_color].values
            fabric_gtrend = prod_gtrend[prod_fabric].values

            # Input does not have to be inverse transformed nor fitted for validation set
            category_gtrend = MinMaxScaler().fit_transform(category_gtrend.reshape(-1, 1)).flatten()
            color_gtrend = MinMaxScaler().fit_transform(color_gtrend.reshape(-1, 1)).flatten()
            fabric_gtrend = MinMaxScaler().fit_transform(fabric_gtrend.reshape(-1, 1)).flatten()
            
            multi_gtrends = np.stack([category_gtrend, color_gtrend, fabric_gtrend], axis=-1)
            gtrend.append(multi_gtrends)

            # Read image
            prod_img_path = prod["image_path"]
            img = Image.open(f"../visuelle/images/{prod_img_path}").convert("RGB")
            image_feature.append(image_transformer(img))

            if idx == 10:
                break

        gtrend = np.array(gtrend)

        
    def __getitem__(self, idx):
        return self.data[idx]
    
    def __len__(self):
        return len(self.data)
    
train_dataset = Dataset(df_train)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

  0%|          | 10/5080 [00:00<02:51, 29.49it/s]


AttributeError: 'Dataset' object has no attribute 'data'