<a href="https://colab.research.google.com/github/nicolas-dufour/rakuten_colour_extraction/blob/master/rakuten_challenge.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Load Git folder: clone the private project repository into /content.
import os
import subprocess
import urllib.parse  # a bare `import urllib` does not guarantee the `parse` submodule is loaded
from getpass import getpass

repo_user = 'nicolas-dufour'
user = 'nicolas-dufour'
repo_name = 'rakuten_colour_extraction'

# Percent-encode every reserved character (safe='' also encodes '/') so the
# password cannot break the userinfo part of the clone URL.
password = urllib.parse.quote(getpass('Password: '), safe='')
clone_url = 'https://{0}:{1}@github.com/{2}/{3}.git'.format(user, password, repo_user, repo_name)
# An argument list (no shell) keeps the password out of any shell parsing.
subprocess.run(['git', 'clone', clone_url])
clone_url, password = "", ""  # removing the password from the variables
# NOTE(review): git records the clone URL, credentials included, as the `origin`
# remote in the checkout's .git/config; consider a credential helper or token instead.
# Bad password fails silently so make sure the repo was copied
assert os.path.exists(f"/content/{repo_name}"), "Incorrect Password or Repo Not Found, please try again"

Password: ··········


In [5]:
%cd rakuten_colour_extraction/

/content/rakuten_colour_extraction


In [17]:
# Google drive connection
# Mounts Google Drive under /content/drive. Later cells read the challenge cookie
# file from it and write model checkpoints and submissions to it.
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
# Git Ignore setup
# Append each pattern only when it is not already listed (exact, whole-line match),
# so re-running this cell does not pile up duplicate .gitignore entries.
!grep -qxF 'lightning_logs' .gitignore 2>/dev/null || echo 'lightning_logs' >> .gitignore
!grep -qxF 'wandb' .gitignore 2>/dev/null || echo 'wandb' >> .gitignore

In [None]:
!git status

On branch master
Your branch is up to date with 'origin/master'.

nothing to commit, working tree clean


In [24]:
# Save to git: stage everything, commit and push the working tree.
!git config --global user.email "nicolas.dufourn@gmail.com"
!git config --global user.name "Nicolas DUFOUR"
!git add --all
!git commit -m "Added NFNET"
# --force-with-lease still allows overwriting the remote branch, but refuses to do so
# when the remote has commits this clone has never fetched (plain --force would discard them).
!git push --force-with-lease

[master a99c8b5] Added NFNET
 6 files changed, 68 insertions(+)
Counting objects: 12, done.
Delta compression using up to 2 threads.
Compressing objects: 100% (12/12), done.
Writing objects: 100% (12/12), 3.19 KiB | 3.19 MiB/s, done.
Total 12 (delta 4), reused 0 (delta 0)
remote: Resolving deltas: 100% (4/4), completed with 4 local objects.[K
To https://github.com/nicolas-dufour/rakuten_colour_extraction.git
   2ef5cdb..a99c8b5  master -> master


In [4]:
%%capture
# Install third-party dependencies. %pip (rather than !pip) installs into the
# environment of the running kernel.
%pip install transformers
# The import cell of this notebook uses `pytorch_lightning.metrics`, which was
# removed from pytorch-lightning in release 1.5, so stay below that version.
%pip install "pytorch-lightning<1.5"
%pip install wandb
%pip install git+https://github.com/rwightman/pytorch-image-models

In [7]:
%load_ext autoreload
%autoreload 2

import pandas as pd
from skimage import io
import numpy as np
import ast
from tqdm.notebook import tqdm

import wandb

import timm
from timm.data import create_transform

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.metrics.classification import Accuracy, F1

from transformers import BertTokenizer, BertModel

from data.bert import Bert_dataset
from data.images import ImageDataset, TestImageDataset
from models.bert_model import Bert_classifier, train
from sklearn.preprocessing import MultiLabelBinarizer

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [8]:
data_path = '/content/rakuten_colour_extraction/data_files/'

# Data Loading

In [7]:
# Create the (git-ignored) data directory; -p makes the cell safe to re-run.
!mkdir -p data_files

In [8]:
# Keep downloaded data out of git; only append the pattern if it is not already listed.
!grep -qxF 'data_files' .gitignore 2>/dev/null || echo 'data_files' >> .gitignore

In [9]:
# Download the training inputs to data_files/X_train.csv. The request sends the cookies
# stored on Google Drive (presumably a logged-in challenge session - confirm).
!wget  https://challengedata.ens.fr/participants/challenges/59/download/x-train --load-cookies /content/drive/MyDrive/rakuten_challenge/ens.fr_cookies.txt -O /content/rakuten_colour_extraction/data_files/X_train.csv

--2021-03-05 10:33:09--  https://challengedata.ens.fr/participants/challenges/59/download/x-train
Resolving challengedata.ens.fr (challengedata.ens.fr)... 129.199.99.143
Connecting to challengedata.ens.fr (challengedata.ens.fr)|129.199.99.143|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 329723647 (314M) [application/octet-stream]
Saving to: ‘/content/rakuten_colour_extraction/data_files/X_train.csv’


2021-03-05 10:33:28 (16.6 MB/s) - ‘/content/rakuten_colour_extraction/data_files/X_train.csv’ saved [329723647/329723647]



In [10]:
# Download the training labels to data_files/y_train.csv, using the same Drive-stored cookies.
!wget  https://challengedata.ens.fr/participants/challenges/59/download/y-train --load-cookies /content/drive/MyDrive/rakuten_challenge/ens.fr_cookies.txt -O /content/rakuten_colour_extraction/data_files/y_train.csv

--2021-03-05 10:33:28--  https://challengedata.ens.fr/participants/challenges/59/download/y-train
Resolving challengedata.ens.fr (challengedata.ens.fr)... 129.199.99.143
Connecting to challengedata.ens.fr (challengedata.ens.fr)|129.199.99.143|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 5592658 (5.3M) [application/octet-stream]
Saving to: ‘/content/rakuten_colour_extraction/data_files/y_train.csv’


2021-03-05 10:33:29 (16.5 MB/s) - ‘/content/rakuten_colour_extraction/data_files/y_train.csv’ saved [5592658/5592658]



In [11]:
# Download the test inputs to data_files/X_test.csv, using the same Drive-stored cookies.
!wget  https://challengedata.ens.fr/participants/challenges/59/download/x-test --load-cookies /content/drive/MyDrive/rakuten_challenge/ens.fr_cookies.txt -O /content/rakuten_colour_extraction/data_files/X_test.csv

--2021-03-05 10:33:29--  https://challengedata.ens.fr/participants/challenges/59/download/x-test
Resolving challengedata.ens.fr (challengedata.ens.fr)... 129.199.99.143
Connecting to challengedata.ens.fr (challengedata.ens.fr)|129.199.99.143|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 57856660 (55M) [application/octet-stream]
Saving to: ‘/content/rakuten_colour_extraction/data_files/X_test.csv’


2021-03-05 10:33:32 (15.8 MB/s) - ‘/content/rakuten_colour_extraction/data_files/X_test.csv’ saved [57856660/57856660]



In [12]:
# Download the supplementary archive next to the repository (../supplementary-files);
# it is extracted into data_files and then deleted by the following cells.
!wget  https://challengedata.ens.fr/participants/challenges/59/download/supplementary-files --load-cookies /content/drive/MyDrive/rakuten_challenge/ens.fr_cookies.txt -O ../supplementary-files

--2021-03-05 10:33:32--  https://challengedata.ens.fr/participants/challenges/59/download/supplementary-files
Resolving challengedata.ens.fr (challengedata.ens.fr)... 129.199.99.143
Connecting to challengedata.ens.fr (challengedata.ens.fr)|129.199.99.143|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2924854699 (2.7G) [application/octet-stream]
Saving to: ‘../supplementary-files’


2021-03-05 10:36:05 (18.3 MB/s) - ‘../supplementary-files’ saved [2924854699/2924854699]



In [13]:
# Unpack the gzip-compressed archive into data_files; --checkpoint=.10000 prints
# a dot every 10000 records as a lightweight progress indicator.
!tar -zxf ../supplementary-files -C /content/rakuten_colour_extraction/data_files --checkpoint=.10000

................................

In [14]:
!rm ../supplementary-files

# Data Processing

In [9]:
# Preview the training inputs (first CSV column is used as the index).
# NOTE(review): this parses the whole file only for display, and the next cell parses it again.
pd.read_csv(data_path+'X_train.csv',index_col=0)

Unnamed: 0,image_file_name,item_name,item_caption
0,278003_10389968_1.jpg,三協アルミ M.シェード2 梁置きタイプ 片側支持 5818 H30 ポリカーボネート屋根　...,商品番号19235601メーカー三協アルミサイズ幅 1931.0mm × 奥行き 5853....
1,220810_10010506_1.jpg,【40%OFF SALE/セール】30代〜40代 ファッション コーディネート 太サッシュ ...,太サッシュベルトで存在感アップ 柔軟性に優れた馬革を使用 幅が太めで存在感◎ キレイな形が出...
2,207456_10045549_1.jpg,下駄 桐 日本製 女性用 TONE 鼻緒巾が広め 黒塗り台 適合足サイズ 23〜24.5cm...,項目 桐の下駄 ※特別価格にて浴衣、半幅帯（浴衣帯）、巾着等も同時出品中です！ サイズ 下駄...
3,346541_10000214_1.jpg,＼期間限定【1000円OFF】クーポン 発行中／ シューズボックス 幅60 奥行33 15足...,■商品説明 ルーバーシューズボックス60幅のシングルタイプが登場。お部屋に合わせて色、サイズ...
4,240426_10024071_1.jpg,ポスト 郵便ポスト 郵便受け 集合住宅用ポスト 可変式プッシュ錠集合郵便受箱 PKS-M15...,集合住宅用ポスト 可変式プッシュ錠集合郵便受箱 PKS-M15-3 1列3段 暗証番号を自由...
...,...,...,...
212115,332136_10000371_1.jpg,サボテン おしゃれな寄せ植え アニマルカクタス ジラフ アニマルフィギア付き プレゼントに,
212116,286000_12212768_1.jpg,【代金引換不可】【アンドモア】 二つ折り財布 財布 小銭入れ 札入れ カード入れ ウォレット...,【ご注意】※メーカー直送のため代金引換はお受けできません。※代金引換でのご注文はキャンセルさ...
212117,254241_10307285_1.jpg,Love Sam　コットン　フレアスカート XS オフベージュ,商品名Love Sam　コットン　フレアスカート カラーオフベージュ サイズ ( cm )サ...
212118,259814_10002299_1.jpg,壁面収納 リビング 薄型 【送料無料】『耐震機能付リビング・書斎収納SELECT〔セレクト〕...,【代引不可商品です】 こちらの商品はメーカー直送品のため代金引換はご利用いただけません。 お...


In [10]:
# Image file name of every training item, as a pandas Series indexed by the first CSV column.
image_paths = pd.read_csv(data_path+'X_train.csv',index_col=0)['image_file_name']

In [11]:
# Read the `color_tags` column of y_train.csv and turn each stringified
# list literal into a real Python list of tag names.
labels = (
    pd.read_csv(data_path + 'y_train.csv', index_col=0)['color_tags']
    .map(ast.literal_eval)
)
labels

0                       [Silver, Grey, Black]
1                              [Brown, Black]
2                              [White, Black]
3                       [Beige, Brown, Black]
4                                    [Silver]
                         ...                 
212115                                [Brown]
212116    [Red, Black, Multiple Colors, Navy]
212117                                [Beige]
212118                         [White, Brown]
212119                           [Blue, Navy]
Name: color_tags, Length: 212120, dtype: object

In [12]:
# Encode each list of colour tags as a binary indicator row (one column per tag).
mlb = MultiLabelBinarizer()
onehot_labels = mlb.fit_transform(labels)
# Column j of `onehot_labels` corresponds to `classes_correp[j]`; the inference
# cell uses this array to map predictions back to tag names.
classes_correp = mlb.classes_

In [13]:
classes_correp

array(['Beige', 'Black', 'Blue', 'Brown', 'Burgundy', 'Gold', 'Green',
       'Grey', 'Khaki', 'Multiple Colors', 'Navy', 'Orange', 'Pink',
       'Purple', 'Red', 'Silver', 'Transparent', 'White', 'Yellow'],
      dtype=object)

In [14]:
n_classes = len(classes_correp)
n_classes

19

In [15]:
# Labelled image dataset built from the training file names, the extracted image
# directory and the multi-hot labels. `ImageDataset` comes from the project's
# data/images.py, which is not visible in this notebook.
image_dataset = ImageDataset(image_paths,
                             data_path+'images/',
                             onehot_labels)

# BERT Fine-Tuning

In [None]:
# Fine-tune the project's BERT classifier on the text fields.
# `Bert_dataset`, `Bert_classifier` and `train` come from the project's data/ and models/ packages.
MAX_LEN = 200  # passed to Bert_dataset; presumably the maximum token length - confirm
LEARNING_RATE = 1e-05
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
# NOTE(review): these CSVs live in a different Drive folder than the files
# downloaded into `data_path` by the cells above.
X_path = '/content/drive/MyDrive/rp/X_train_12tkObq.csv'
y_path = '/content/drive/MyDrive/rp/y_train_Q9n2dCu.csv'
train_dataset = Bert_dataset(X_path, y_path, MAX_LEN)
# NOTE(review): `train_loader` and `model` are also assigned by the image-model
# cells further down; running sections out of order will mix them up.
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=0)
# One output per entry of the dataset's colour dictionary.
model = Bert_classifier(len(train_dataset.colors_dict))
model.to(device)
optimizer = torch.optim.Adam(params =  model.parameters(), lr=LEARNING_RATE)
# First argument is presumably the number of epochs - confirm against models/bert_model.py.
train(1, train_loader, device, model, optimizer)

# Image Models

In [22]:
torch.hub.list('facebookresearch/deit:main')

Using cache found in /root/.cache/torch/hub/facebookresearch_deit_main


[]

In [22]:
!rm -r /root/.cache/torch/hub/facebookresearch_deit_main

In [23]:
torch.hub.load('facebookresearch/deit:main', 'deit_base_patch16_224', pretrained=True)

Downloading: "https://github.com/facebookresearch/deit/archive/main.zip" to /root/.cache/torch/hub/main.zip


RuntimeError: ignored

In [11]:
class Deit(pl.LightningModule):
    """Multi-label colour classifier on top of a pretrained DeiT backbone.

    The backbone is loaded through ``torch.hub`` from ``facebookresearch/deit:main``
    and its classification layer(s) are replaced by fresh linear layers with
    ``n_classes`` outputs. Training minimises ``BCEWithLogitsLoss`` on raw logits.
    Accuracy and weighted F1 are accumulated per split and logged once per epoch.

    NOTE(review): the hub entry name ends in ``_384`` while the notebook's transforms
    produce 300x300 images - confirm the backbone accepts that input size.

    Args:
        lr: learning rate passed to AdamW in ``configure_optimizers``.
        n_classes: number of output labels (defaults to the 19 colour tags of this notebook).
    """

    def __init__(self, lr, n_classes=19):
        super().__init__()
        self.save_hyperparameters()
        self.backbone = torch.hub.load('facebookresearch/deit:main', 'deit_base_distilled_patch16_384', pretrained=True)
        self._reset_classifier(n_classes)
        self.sigmoid = nn.Sigmoid()
        self.criterium = nn.BCEWithLogitsLoss()

        self.acc_train = Accuracy()
        self.f1_train = F1(num_classes=n_classes, average='weighted')

        self.acc_val = Accuracy()
        self.f1_val = F1(num_classes=n_classes, average='weighted')

        self.acc_test = Accuracy()
        self.f1_test = F1(num_classes=n_classes, average='weighted')

    def _reset_classifier(self, n_classes):
        """Replace the backbone's classification layer(s) with new ``n_classes``-way linear layers."""
        head = self.backbone.head
        if hasattr(head, 'fc'):
            # Backbones whose head is a wrapper module exposing the classifier as `.fc`.
            head.fc = nn.Linear(head.fc.in_features, n_classes)
        else:
            # The head is itself the linear classifier.
            self.backbone.head = nn.Linear(head.in_features, n_classes)
        # Distilled variants carry a second classifier (`head_dist`) that must match.
        dist_head = getattr(self.backbone, 'head_dist', None)
        if isinstance(dist_head, nn.Linear):
            self.backbone.head_dist = nn.Linear(dist_head.in_features, n_classes)

    def _logits(self, images):
        """Run the backbone and return a single logits tensor for the batch."""
        out = self.backbone(images)
        if isinstance(out, tuple):
            # NOTE(review): the distilled DeiT returns (class logits, distillation logits)
            # in training mode; averaging mirrors what it does itself in eval mode - confirm.
            out = sum(out) / len(out)
        return out

    def _shared_step(self, batch, acc, f1):
        """Compute logits for ``batch`` and update the given metric objects.

        Returns:
            ``(logits, targets)`` so the caller can compute the loss if it needs one.
        """
        images, targets = batch
        logits = self._logits(images)
        probs = torch.sigmoid(logits)
        # The metrics are fed probabilities and integer targets.
        acc(probs, targets.long())
        f1(probs, targets.long())
        return logits, targets

    def forward(self, x):
        """Return per-class probabilities (sigmoid of the backbone logits)."""
        return self.sigmoid(self._logits(x))

    def training_step(self, batch, batch_idx):
        """Compute and log the BCE loss for one training batch; update train metrics."""
        logits, targets = self._shared_step(batch, self.acc_train, self.f1_train)
        loss = self.criterium(logits, targets)
        self.log('train_loss', loss, on_epoch=True,on_step=True)
        return loss

    def training_epoch_end(self, loss):
        """Log the epoch-level training metrics, then reset their state."""
        self.log('train_acc', self.acc_train.compute())
        self.log('train_f1_score', self.f1_train.compute())
        self.acc_train.reset()
        self.f1_train.reset()

    def validation_step(self, batch, batch_idx):
        """Log the BCE loss for one validation batch; update validation metrics."""
        logits, targets = self._shared_step(batch, self.acc_val, self.f1_val)
        loss = self.criterium(logits, targets)
        self.log('valid_loss', loss, on_epoch=True)

    def validation_epoch_end(self, loss):
        """Log the epoch-level validation metrics, then reset their state."""
        self.log('val_acc', self.acc_val.compute())
        self.log('val_f1_score', self.f1_val.compute())
        self.acc_val.reset()
        self.f1_val.reset()

    def test_step(self, batch, batch_idx):
        """Update the test metrics for one batch (no loss is computed)."""
        self._shared_step(batch, self.acc_test, self.f1_test)

    def test_epoch_end(self, loss):
        """Log the epoch-level test metrics, then reset their state."""
        self.log('test_acc', self.acc_test.compute())
        self.log('test_f1_score', self.f1_test.compute())
        self.acc_test.reset()
        self.f1_test.reset()

    def configure_optimizers(self):
        """Optimise all parameters with AdamW at the configured learning rate."""
        return torch.optim.AdamW(self.parameters(), lr=self.hparams.lr)

In [12]:
import copy

# Reproducible 90/10 train/validation split over the labelled images.
np.random.seed(42)
idx = np.random.permutation(len(image_dataset))
sep = int(len(image_dataset)*0.9)
idx_train, idx_val = idx[:sep], idx[sep:]

# Normalisation statistics shared by the train and validation pipelines, so both
# splits see identically scaled inputs.
NORM_MEAN = (0.4914, 0.4822, 0.4465)
NORM_STD = (0.2023, 0.1994, 0.2010)

# Training pipeline: timm's augmentation recipe (RandAugment policy string below).
train_transform = create_transform(
            input_size=300,
            is_training=True,
            auto_augment='rand-m9-mstd0.5-inc1',
            interpolation='bilinear',
            mean=NORM_MEAN,
            std=NORM_STD,
        )
# Validation pipeline: deterministic resize + normalisation only.
val_transform = transforms.Compose([
    transforms.Resize((300,300)),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Both subsets used to wrap the SAME dataset object, so assigning
# `val_set.dataset.transform` silently replaced the training transform too.
# Give each split its own shallow copy (underlying data is shared, the
# `transform` attribute is not) before attaching the transforms.
train_source = copy.copy(image_dataset)
train_source.transform = train_transform
val_source = copy.copy(image_dataset)
val_source.transform = val_transform

train_set, val_set= torch.utils.data.Subset(train_source, idx_train), torch.utils.data.Subset(val_source, idx_val)

train_loader = DataLoader(train_set,
                          shuffle=True,
                          num_workers=8,
                          batch_size=32)
val_loader = DataLoader(val_set,
                          shuffle=False,
                          num_workers=8,
                          batch_size=32)

In [13]:
wandb.init(project='Rakuten-colour-classification')

[34m[1mwandb[0m: Currently logged in as: [33mnicolas-dufour[0m (use `wandb login --relogin` to force relogin)


In [14]:
# Experiment tracking through Weights & Biases.
logger = WandbLogger()

# Checkpoint selection is driven by the highest `val_f1_score`; files are written to Google Drive.
# NOTE(review): the filename prefix says "resnet18" although the model built below is `NFNet`.
checkpoint_callback = ModelCheckpoint(
     mode ='max',
     monitor='val_f1_score',
     dirpath='/content/drive/MyDrive/rakuten_challenge/models',
    filename='resnet18-all-data-{epoch:02d}-{val_f1_score:.2f}'
)

# Single-GPU trainer wired to the logger and checkpoint callback above.
trainer = pl.Trainer(
    gpus=1,
    logger=logger,
    callbacks = [checkpoint_callback]
)

# NOTE(review): `NFNet` is neither defined nor imported in the visible cells of this
# notebook (only `Deit` is) - a fresh kernel will raise NameError here unless it is added.
model = NFNet(lr=1e-4)


GPU available: True, used: True
TPU available: None, using: 0 TPU cores


In [None]:
# Run Lightning's learning-rate finder for `model` on the training loader.
lr_finder = trainer.tuner.lr_find(model,train_loader)

# Results can be found in
print(lr_finder.results)

# Plot with
fig = lr_finder.plot(suggest=True)
fig.show()

# Learning rate suggested by the finder.
print(lr_finder.suggestion())


In [15]:
trainer.fit(model, train_loader, val_loader)


  | Name      | Type              | Params
------------------------------------------------
0 | backbone  | NormFreeNet       | 68.5 M
1 | sigmoid   | Sigmoid           | 0     
2 | criterium | BCEWithLogitsLoss | 0     
3 | acc_train | Accuracy          | 0     
4 | f1_train  | F1                | 0     
5 | acc_val   | Accuracy          | 0     
6 | f1_val    | F1                | 0     
7 | acc_test  | Accuracy          | 0     
8 | f1_test   | F1                | 0     
------------------------------------------------
68.5 M    Trainable params
0         Non-trainable params
68.5 M    Total params
273.899   Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…






1

In [16]:
del model, trainer
torch.cuda.empty_cache()

In [17]:
# Predict colour tags for the test set with a saved checkpoint.
model = NFNet.load_from_checkpoint('/content/drive/MyDrive/rakuten_challenge/models/resnet18-all-data-epoch=05-val_f1_score=0.65.ckpt').to('cuda')
model.eval()  # switch to inference mode once, not on every batch

# Use a dedicated name so the training `image_paths` Series is not clobbered.
test_image_paths = pd.read_csv(data_path+'X_test.csv',index_col=0)['image_file_name']
test_transform = transforms.Compose([
    transforms.Resize((300,300)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
test_set = TestImageDataset(test_image_paths, data_path+'images/',transform=test_transform)
test_loader = DataLoader(test_set,
                         shuffle=False,
                         batch_size=32,
                         num_workers=8)

# Collect the predictions in a plain list and build the frame once at the end;
# growing a DataFrame with pd.concat on every batch is quadratic in the number of batches.
colors = []
with torch.no_grad():  # no autograd bookkeeping is needed for inference
    # Each batch unpacks into two items; the first one is not used here.
    for _, images in tqdm(test_loader):
        labels = model(images.to('cuda')).cpu().numpy()
        labels_hard = labels>0.5  # a tag is predicted when its probability exceeds 0.5
        colors.extend(list(classes_correp[t.nonzero()[0]]) for t in labels_hard)

# One row per test image in loader order (shuffle=False), with a fresh 0..n-1 index.
output_df = pd.DataFrame({'color_tags': colors})

HBox(children=(FloatProgress(value=0.0, max=1168.0), HTML(value='')))




In [24]:
output_df

Unnamed: 0,color_tags
0,[]
1,[Black]
2,[Khaki]
3,[Navy]
4,[Grey]
...,...
37342,[White]
37343,"[Black, White]"
37344,"[Black, White]"
37345,[Green]


In [18]:
output_df.to_csv('/content/drive/MyDrive/rakuten_challenge/submissions/submission_3.csv')