In [1]:
import boto3
s3 = boto3.client('s3')
s3.download_file('sagemaker-us-east-1-953585160895',
                 'data/5eaec695e1970/preprocess_data.txt', 
                 'preprocess_data.txt')

In [3]:
import pandas as pd
import os


column_names = ['_id', 'message', 'image_concept', 'image', 'published', 'disabled']
target_names = ['published', 'disabled']


df = pd.read_csv('preprocess_data.txt', names=column_names, engine='python')
df

Unnamed: 0,_id,message,image_concept,image,published,disabled
0,5e55dca437fa5927dcdf02f3,en route to gbr embrace the elevation in luxur...,nature travel diving water sea underwater ocea...,https://scontent.xx.fbcdn.net/v/t51.2885-15/81...,1,0
1,5e55d69eb9e5b725cd7ba02f,golf course views ⛳ #hamiltonislandgolfcourse ...,outdoors landscape beach sky nature rural no p...,https://scontent.xx.fbcdn.net/v/t51.2885-15/87...,1,0
2,5e55ca8d0f1aeb23b862f240,hamo family vay kay #hamiltonisland,boat water sunglasses leisure recreation one s...,https://scontent.xx.fbcdn.net/v/t51.2885-15/87...,1,0
3,5e55a83c75d14b7257d1aceb,news: entries booming for hamilton island race...,audience sports fan marathon people crowd grou...,https://uploads-cdn.stackla.com/10/hamiltonisl...,1,0
4,5e55a3e657f15e175b8fa58f,my series on the great barrier reef and surrou...,vintage texture no person abstract desktop nat...,https://scontent.xx.fbcdn.net/v/t51.2885-15/87...,1,0
...,...,...,...,...,...,...
1995,5e47bb6b1096abb55d1873d3,what a weekend! this place is amazing 😍 #hamil...,person human vehicle transportation golf cart,https://scontent.xx.fbcdn.net/v/t51.2885-15/85...,0,1
1996,5e47bb6b1096abb55d1873dc,what a weekend! this place is amazing 😍 #hamil...,land nature outdoors shoreline water sea ocean...,https://scontent.xx.fbcdn.net/v/t51.2885-15/84...,0,1
1997,5e47bb691096abb55d1873cb,"this time last year, chilling at the whitsunda...",land nature outdoors water sea ocean shoreline...,https://scontent.xx.fbcdn.net/v/t51.2885-15/84...,0,1
1998,5e47b9e37f9fcbb333a170ab,back when i was live living the hamilton islan...,human person clothing apparel food pork,https://scontent.xx.fbcdn.net/v/t51.2885-15/80...,0,1


In [4]:
published_count = len(df.loc[df['published'] == 1])
disabled_count = len(df.loc[df['disabled'] == 1])
print(f"published {published_count} disabled {disabled_count}") 
      
      

published 500 disabled 1500


In [13]:
import argparse
import logging
import os
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision
import torchvision.models
import torchvision.transforms as transforms
import torch.nn.functional as F
from flair.data_fetcher import NLPTaskDataFetcher
from flair.embeddings import WordEmbeddings, FlairEmbeddings, DocumentRNNEmbeddings, Sentence
from flair.models import TextClassifier
from flair.trainers import ModelTrainer
from pathlib import Path
import numpy as np
import random
from flair.training_utils import store_embeddings


train_on_gpu = torch.cuda.is_available()

def oversample_df(df, classes):
    classes_count = []
    for c in classes:    
        classes_count.append(len(df.loc[df[c] == 1]))
    
    max_count = max(classes_count)
    resample_ratios = [round(max_count/c) for c in classes_count]
            
    resampled = []
    for i in range(len(resample_ratios)):
        c = classes[i]
        ratio = resample_ratios[i]        
        for r in range(ratio):            
            resampled.append(df.loc[df[c] == 1])
            
    resampled_df = pd.concat(resampled, ignore_index=True)
    resampled_df = resampled_df.sample(frac=1)
    resampled_df = resampled_df.reset_index(drop=True)
    return resampled_df


def undersample_df(df, classes):
    classes_count = []
    for c in classes:    
        classes_count.append(len(df.loc[df[c] == 1]))
    
    min_count = min(classes_count)
    
    resampled = []
    for c in classes:
        resampled.append(df[df[c] == 1][:min_count])
        
    resampled_df = pd.concat(resampled, ignore_index=True)
    resampled_df = resampled_df.sample(frac=1)
    resampled_df = resampled_df.reset_index(drop=True)
    return resampled_df


def get_batches(df, target_names, mode=None, batch_size=16):
    if mode == 'oversample':
        df = oversample_df(df, target_names)
    elif mode == 'undersample':
        df = undersample_df(df, target_names)
        
    df = df.sample(frac=1).reset_index(drop=True)
    for i in range(0, len(df), batch_size):
        ids = []
        x = []
        y = []
        for _, row in df[i:i+batch_size].iterrows():
            
            image_concept = '' if pd.isna(row['image_concept']) else row['image_concept']
            message = '' if pd.isna(row['message']) else row['message']                        
            
            # shuffle image concepts
            words = image_concept.split()
            random.shuffle(words)
            image_concept = ' '.join(words)
            
            # join message and image_concept together
            txt = ' '.join([message, image_concept])                    
            x.append(Sentence(txt))                        
            y.append([row[t] for t in target_names])
            ids.append(row['_id'])
        
        yield ids, x, torch.FloatTensor(y)

        
def train_model(model, epochs, lr, train_df, val_df, target_names, checkpoint_file, early_stopping=5):        
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    best_loss = np.inf
    no_improvement = 0

    if train_on_gpu:
        model = model.cuda()
        
    for epoch in range(epochs):        
        total_train_loss = 0
        total_val_loss = 0
        train_loss = 0
        val_loss = 0
        
        # Train
        model.train()        
        for i, (ids, sentences, labels) in enumerate(get_batches(train_df, target_names, 'oversample')):
            if train_on_gpu:
                labels = labels.cuda()
            
            
            optimizer.zero_grad()
            
            out = model(sentences)            
                        
            loss = criterion(out, labels)
            loss.backward()
            
            nn.utils.clip_grad_norm_(model.parameters(), 5.0)
            optimizer.step()
            
            total_train_loss += loss.item()
            
            store_embeddings(sentences, 'cpu')
            
            if i % 10 == 0:
                print(f"Epoch {epoch}, batch {i}, train loss {loss.item()/labels.size(0)}")
            
            
        train_loss = total_train_loss/len(train_df)

        
        # Eval
        model.eval()
        all_pred = np.array([])
        all_labels = np.array([])
        for _, sentences, labels in get_batches(val_df, target_names):
            if train_on_gpu:
                labels = labels.cuda()
            

            out = model(sentences)
            loss = criterion(out, labels)
            total_val_loss += loss.item()
            
            store_embeddings(sentences, 'cpu')
            
            pred = torch.round(torch.sigmoid(out))
            pred_np = pred.data.cpu().numpy() if train_on_gpu else pred.data.numpy()
            labels_np = labels.data.cpu().numpy() if train_on_gpu else labels.data.numpy()
            all_pred = np.concatenate([all_pred, pred_np.flatten()])
            all_labels = np.concatenate([all_labels, labels_np.flatten()])

            
        val_loss = total_val_loss / len(val_df)
        f1 = f1_score(all_labels, all_pred, average='weighted')
        acc = accuracy_score(all_labels, all_pred)
        
        print(f"Epoch {epoch}, train loss {train_loss}, val loss {val_loss}, accuracy {acc}, f1 {f1}")
        
        if val_loss < best_loss:
            best_loss = val_loss
            no_improvement = 0
            torch.save(model.state_dict(), checkpoint_file)
            print(f"Save model at Epoch {epoch}, train loss {train_loss}, val loss {val_loss}, accuracy {acc}, f1 {f1}")
        else:
            no_improvement += 1
            print("No improvement.")
            if no_improvement >= early_stopping:
                print(f"Early Stopping at Epoch {epoch}")
                break
                

def get_dfs(data_dir, column_names):    
    csv_file = os.path.join(data_dir, 'preprocess_data.txt')
    df = pd.read_csv(csv_file, names=column_names, engine='python')    
    df = df.loc[df['message'].notnull() & df['image_concept'].notnull()]            
    train_df, validation_df = train_test_split(df, test_size=0.3, random_state=42)    
    train_df = train_df.reset_index(drop=True)
    validation_df = validation_df.reset_index(drop=True)    
    return train_df, validation_df


train_df, validation_df = get_dfs(".", column_names)
    

for target in target_names:
    train_target_count = len(train_df.loc[train_df[target] == 1])
    val_target_count = len(validation_df.loc[validation_df[target] == 1])
    print(f"class: ({target}), train: {train_target_count}, val: {val_target_count}")


document_embeddings = DocumentRNNEmbeddings([
        WordEmbeddings('twitter'),
    ], 
    hidden_size=128,
    reproject_words=True,
    reproject_words_dimension=128
)

classifier = TextClassifier(
    document_embeddings, 
    label_dictionary=target_names,
    multi_label=True
)

print(classifier)
    
checkpoint_file = os.path.join('/tmp', 'model.pt')
lr = 0.001       
train_model(
    classifier, 
    10, 
    lr, 
    train_df, 
    validation_df, 
    target_names, 
    checkpoint_file,
    early_stopping=5
)

print("success!")    

class: (published), train: 342, val: 156
class: (disabled), train: 1050, val: 441
TextClassifier(
  (document_embeddings): DocumentRNNEmbeddings(
    (embeddings): StackedEmbeddings(
      (list_embedding_0): WordEmbeddings('twitter')
    )
    (word_reprojection_map): Linear(in_features=100, out_features=128, bias=True)
    (rnn): GRU(128, 128, batch_first=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Linear(in_features=128, out_features=2, bias=True)
  (loss_function): BCEWithLogitsLoss()
  (beta): 1.0
  (weights): None
  (weight_tensor) None
)
Epoch 0, batch 0, train loss 0.04097532480955124
Epoch 0, batch 10, train loss 0.0422392338514328
Epoch 0, batch 20, train loss 0.05442492663860321
Epoch 0, batch 30, train loss 0.04002945125102997
Epoch 0, batch 40, train loss 0.048217594623565674
Epoch 0, batch 50, train loss 0.03839818015694618
Epoch 0, batch 60, train loss 0.034306056797504425
Epoch 0, batch 70, train loss 0.03502165898680687
Epoch 0, batch 80, train 

Epoch 9, batch 20, train loss 0.03278768062591553
Epoch 9, batch 30, train loss 0.015525519847869873
Epoch 9, batch 40, train loss 0.033642061054706573
Epoch 9, batch 50, train loss 0.021566977724432945
Epoch 9, batch 60, train loss 0.03916717320680618
Epoch 9, batch 70, train loss 0.05242852866649628
Epoch 9, batch 80, train loss 0.03453497588634491
Epoch 9, batch 90, train loss 0.035776812583208084
Epoch 9, batch 100, train loss 0.023819224908947945
Epoch 9, batch 110, train loss 0.024403797462582588
Epoch 9, batch 120, train loss 0.03292492777109146
Epoch 9, train loss 0.04635313132927678, val loss 0.03206972024049391, accuracy 0.7554438860971524, f1 0.7554377104377105
Save model at Epoch 9, train loss 0.04635313132927678, val loss 0.03206972024049391, accuracy 0.7554438860971524, f1 0.7554377104377105
success!


In [14]:
model2 = TextClassifier(
    document_embeddings, 
    label_dictionary=target_names,
    multi_label=True
)

print(model2)

model2.load_state_dict(torch.load(checkpoint_file))


TextClassifier(
  (document_embeddings): DocumentRNNEmbeddings(
    (embeddings): StackedEmbeddings(
      (list_embedding_0): WordEmbeddings('twitter')
    )
    (word_reprojection_map): Linear(in_features=100, out_features=128, bias=True)
    (rnn): GRU(128, 128, batch_first=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Linear(in_features=128, out_features=2, bias=True)
  (loss_function): BCEWithLogitsLoss()
  (beta): 1.0
  (weights): None
  (weight_tensor) None
)


<All keys matched successfully>

In [17]:
def eval_model(model, test_df, target_names):
    if train_on_gpu:
        model = model.cuda()
        
    model.eval()
    all_pred = np.array([])
    all_labels = np.array([])
    for _, sentences, labels in get_batches(test_df, target_names):
        if train_on_gpu:
            labels = labels.cuda()

        out = model(sentences)

        store_embeddings(sentences, 'cpu')

        pred = torch.round(torch.sigmoid(out))
        pred_np = pred.data.cpu().numpy() if train_on_gpu else pred.data.numpy()
        labels_np = labels.data.cpu().numpy() if train_on_gpu else labels.data.numpy()
        all_pred = np.concatenate([all_pred, pred_np.flatten()])
        all_labels = np.concatenate([all_labels, labels_np.flatten()])


    f1 = f1_score(all_labels, all_pred, average='weighted')
    acc = accuracy_score(all_labels, all_pred)

    print(f"Accuracy {acc}, F1 {f1}")
    
eval_model(model2, train_df, target_names)

Accuracy 0.8200431034482759, F1 0.8200419657453925


In [24]:
train_df.loc[train_df['published'] == 1]

Unnamed: 0,_id,message,image_concept,image,published,disabled
5,5e2eb28842964194c85a9e79,リーフビューホテルでは16階の部屋でとにかく眺めが良い！ ホテルにはプールもあって賑やかだっ...,railing animal bird cockatoo parrot wildlife m...,https://scontent.xx.fbcdn.net/v/t51.2885-15/83...,1,0
25,5deed4333a42802f61b1e2b8,wild oats xi and ribco ready for the grinders ...,transportation watercraft vessel vehicle boat ...,https://scontent.xx.fbcdn.net/v/t51.2885-15/75...,1,0
27,5e341b6c76f8896f9132ad1a,"you, me & the sea (2/3) . . . #hamiltonisland ...",person human jet ski transportation vehicle ap...,https://scontent.xx.fbcdn.net/v/t51.2885-15/82...,1,0
29,5e133457e066b13fc80f0fa6,코스 ⛳ 클럽하우스 🍽 챌린지 💪 ⠀ 골프매니아들에게 이미 입소문 난 해밀턴아일랜드...,nature outdoors land water shoreline ocean sea...,https://scontent.xx.fbcdn.net/v/t51.2885-15/79...,1,0
31,5e38c6b822dd29f0a15b98f1,cruising around hamilton island. 🛩💥 #onetreehi...,shoreline water nature outdoors land sea ocean...,https://scontent.xx.fbcdn.net/v/t51.2885-15/83...,1,0
...,...,...,...,...,...,...
1372,5e468192d47bfd56c5265b7f,sunset at one tree hill! #onetreehill #sunset ...,nature outdoors red sky dawn sky dusk sunset s...,https://scontent.xx.fbcdn.net/v/t51.2885-15/83...,1,0
1377,5e49b0960bfa51488b937823,sunny days and mondays.,land outdoors nature shoreline water ocean sea...,https://scontent-iad3-1.cdninstagram.com/v/t51...,1,0
1379,5e1d7c54cbd4d346ef3679ea,“the office”,land outdoors nature water shoreline sea ocean...,https://scontent-lga3-1.cdninstagram.com/v/t51...,1,0
1381,5dfa92b197304f0168d622fe,i don’t know if i was holding her or she was h...,animal mammal wildlife bear koala giant panda,https://scontent.xx.fbcdn.net/v/t51.2885-15/78...,1,0


In [25]:
train_df.loc[train_df['disabled'] == 1]

Unnamed: 0,_id,message,image_concept,image,published,disabled
0,5e490801abcf0016a173a576,whitehaven beach 🤍🐚 - definitely one of our hi...,person human accessories accessory sunglasses ...,https://scontent.xx.fbcdn.net/v/t51.2885-15/84...,0,1
1,5e4bb53a41da64df1c418792,amazing dinner with even more amazing ladies a...,food meal dish seasoning bowl restaurant cafet...,https://scontent.xx.fbcdn.net/v/t51.2885-15/83...,0,1
2,5e4cf285843e6c3e41217c99,pulau hamilton adalah salah satu tujuan libura...,tropical no person paper water turquoise deskt...,https://scontent.xx.fbcdn.net/v/t51.2885-15/84...,0,1
3,5e4d0ccef9b64844676addff,※ honeymoon---6days--part2--✈ . . . . #思い出記録 #...,education child young portrait isolated leisur...,https://scontent.xx.fbcdn.net/v/t51.2885-15/85...,0,1
4,5e50e293cb93eb9052d3264a,just livin' doing that whitsundays thing!!😊😊🦄🦄...,couple leisure ocean water vacation girl summe...,https://scontent.xx.fbcdn.net/v/t51.2885-15/85...,0,1
...,...,...,...,...,...,...
1387,5e4ebcf7ab423de8e751bee4,"still can't find nemo, but found all of his co...",marine underwater deep coral fish scuba ocean ...,https://scontent.xx.fbcdn.net/v/t51.2885-15/87...,0,1
1388,5e4d15140f85f846761e7d2c,action shots of wednesday night twilight saili...,water sailboat sea yacht sail crew ship waterc...,https://scontent.xx.fbcdn.net/v/t51.2885-15/87...,0,1
1389,5e51f931c82fabe3073098cd,one of the most amazing places to me #fun #sum...,fashion foot man shoe footwear girl beach two ...,https://scontent.xx.fbcdn.net/v/t51.2885-15/83...,0,1
1390,5e4bb6c45d0cb5e11f457c9e,#australia #hamiltonisland #pioveecésole #bell...,land nature outdoors water shoreline ocean sea...,https://scontent.xx.fbcdn.net/v/t51.2885-15/84...,0,1


In [22]:
train_df.iloc[0].tolist()

['5e490801abcf0016a173a576',
 'whitehaven beach \U0001f90d🐚 - definitely one of our highlights in australia #tb #whitsundays #australia #queensland #qld #travel #visitaustralia #photography #sea #pacific #nationalpark #hamiltonisland #whitehaven #whitehavenbeach #hamiltonislandair #helicopter',
 'person human accessories accessory sunglasses nature land outdoors electronics window cockpit glasses',
 'https://scontent.xx.fbcdn.net/v/t51.2885-15/84490140_567379633988670_6664859505015865217_n.jpg?_nc_cat=110&_nc_ohc=6M2SDnKcNUQAX8EM1hW&_nc_ht=scontent.xx&oh=b7d3790f673706cf9ecf2916b07c0bb2&oe=5EBCA4CE',
 0,
 1]

In [27]:
def predict(model, input_data):
    # message and image concepts
    text = input_data[1] + ' ' + input_data[2]
    sentences = [Sentence(text)]
    out = model([Sentence(text)])
    store_embeddings(sentences, 'cpu')

    pred = torch.round(torch.sigmoid(out) * 100) / 100
    pred_np = pred.data.cpu().numpy() if train_on_gpu else pred.data.numpy()

    out = []
    for i, p in enumerate(pred_np[0]):
        out.append(str(p))

    return out

sample = train_df.iloc[1372].tolist()
print(sample)
print(predict(model2, sample))


['5e468192d47bfd56c5265b7f', 'sunset at one tree hill! #onetreehill #sunset #hamiltonisland', 'nature outdoors red sky dawn sky dusk sunset sunlight sunrise plant tree abies fir light flare', 'https://scontent.xx.fbcdn.net/v/t51.2885-15/83597253_187296752513549_6081794946749982366_n.jpg?_nc_cat=100&_nc_ohc=DOLnGMq13DIAX-AjFn9&_nc_ht=scontent.xx&oh=26c1c71142328ccd84b0955aa22b3a6a&oe=5ECD421B', 1, 0]
['0.7', '0.31']
