In [11]:
import cv2
import PIL
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import tqdm
import os
import random

from PIL import Image
from torchvision import transforms as T
from torchvision.models.alexnet import AlexNet
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152
from torch.utils.data import IterableDataset, DataLoader, Dataset
from sklearn.metrics import classification_report, accuracy_score, f1_score, average_precision_score
from sklearn.model_selection import train_test_split, KFold

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def linear_combination(x, y, epsilon): 
    return epsilon*x + (1-epsilon)*y


def reduce_loss(loss, reduction='mean'):
    return loss.mean() if reduction=='mean' else loss.sum() if reduction=='sum' else loss


class LabelSmoothingCrossEntropy(nn.Module):
    def __init__(self, epsilon:float=0.1, reduction='mean'):
        super().__init__()
        self.epsilon = epsilon
        self.reduction = reduction
    
    def forward(self, preds, target):
        n = preds.size()[-1]
        log_preds = F.log_softmax(preds, dim=-1)
        loss = reduce_loss(-log_preds.sum(dim=-1), self.reduction)
        nll = F.nll_loss(log_preds, target, reduction=self.reduction)
        return linear_combination(loss/n, nll, self.epsilon)
    
    
class BonzTrainDataset(Dataset):
    def __init__(self, data):
        super(BonzTrainDataset, self).__init__()
        self.image_prefixes = data.image_prefixes.values
        self.features = data.features.values
        self.img_tensors = data.img_tensors.values
        if 'labels' in data:
            self.labels = data.labels.values
            self.one_hot_labels = data.one_hot_labels.values
        else:
            self.labels=None
    
    def __len__(self):
        #return 3000
        return len(self.image_prefixes)**2
    
    def __getitem__(self, idx):
        ''' Full data'''
        x1 = idx // len(self.image_prefixes)
        x2 = idx % len(self.image_prefixes)
        
        
        ''' Random select
        x1 = random.randint(0, len(self.image_prefixes)-1)
        x2 = random.randint(0, len(self.image_prefixes)-1)
        
        if idx%3 == 0:
            while self.labels[x1] != self.labels[x2]:
                x2 = random.randint(0, len(self.image_prefixes)-1)
        else:
            while self.labels[x1] == self.labels[x2]:
                x2 = random.randint(0, len(self.image_prefixes)-1)
        '''
        
        outputs = (self.img_tensors[x1], 
                   self.features[x1], 
                   self.img_tensors[x2], 
                   self.features[x2],)
        if self.labels[x1] == self.labels[x2]:
            outputs += (torch.tensor([1]),)
        else:
            outputs += (torch.tensor([0]),)
        return outputs
    
        
class Bonz(nn.Module):
    def __init__(self, hidden_dim=100, feature_selection=12):
        super(Bonz, self).__init__()
        self.resnet = resnet50(True)
        self.BiLSTM = nn.LSTM(2048, hidden_dim, num_layers=2, batch_first=True, bidirectional=True)
        self.bn = nn.BatchNorm1d(hidden_dim*2 + feature_selection)
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim*2 + feature_selection, hidden_dim*2 + feature_selection),
            nn.LeakyReLU(0.001, inplace=True),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim*2 + feature_selection, 1)
        )
        self.classifier.apply(self._init_weights)
        self.classifier.apply(self._xavier)
        self.BiLSTM.apply(self._xavier)
    
    def forward(self, x1, x1_f, x2, x2_f):
        ''' PROCESS X1'''
        # Generate featuress from each images
        x = []
        for img in x1:
            temp = self.do_resnet(img)
            x.append(temp.unsqueeze(0))
        x = torch.cat(x, 0)
        
        # LSTM step
        x, _ = self.BiLSTM(x)
        lstm_features = x[:,-1,:]
        
        # Concate features
        x1 = torch.cat([lstm_features, x1_f], -1)
        #x1 = self.bn(x1)
        
        ''' PROCESS X2'''
        # Generate featuress from each images
        x = []
        for img in x2:
            temp = self.do_resnet(img)
            x.append(temp.unsqueeze(0))
        x = torch.cat(x, 0)
        
        # LSTM step
        x, _ = self.BiLSTM(x)
        lstm_features = x[:,-1,:]
        
        # Concate features
        x2 = torch.cat([lstm_features, x2_f], -1)
        #x2 = self.bn(x2)
        
        ''' DIFFERENCE BETWEEN X1 and X2'''
        dif = torch.abs(x1-x2)   
        
        predict = self.classifier(dif)

        return (predict, dif,)
    
    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
            
    def _xavier(self, module):
        for name, param in module.named_parameters():
            if 'weight' in name:
                nn.init.xavier_normal_(param)
            elif 'bias' in name:
                param.data.zero_()
    
    def do_resnet(self, img):
        x = self.resnet.conv1(img)
        x = self.resnet.bn1(x)
        x = self.resnet.relu(x)
        x = self.resnet.maxpool(x)

        x = self.resnet.layer1(x)
        x = self.resnet.layer2(x)
        x = self.resnet.layer3(x)
        x = self.resnet.layer4(x)

        x = self.resnet.avgpool(x)
        x = torch.flatten(x, 1)
        
        return x
    
    def freeze_resnet(self):
        for w in self.resnet.parameters():
            w.requires_grad = False
            
    def unfreeze_resnet(self):
        for w in self.resnet.parameters():
            w.requires_grad = True 

In [2]:
def get_data(path):
    data = pd.read_csv(path)
    resnet_idx = list(data.columns).index('data_AUTOGRAPHER_RESNET_mean_tench, Tinca tinca')
    new_data = data.iloc[:,:resnet_idx]
    new_data = new_data.drop([col_len for col_len in new_data.keys() if '_len' in col_len], 1) # Drop columns with _LEN
    new_data['labels'] = [int(i[-2:])-1 for i in new_data.event_id.values]
    new_data['image_prefixes'] = list(map(lambda x,y,z: str(x)+'_'+str(y)+'_'+str(z), 
                                      new_data.sub_id.values, 
                                      new_data.source.values, 
                                      new_data.event_id.values))
    new_data['one_hot_labels'] = list(map(lambda x: nn.functional.one_hot(torch.tensor(x), 20).float(), new_data.labels.values))
    new_data['features'] = [torch.tensor(i).float() for i in new_data.iloc[:, 3:-3].values]
    new_data['img_tensors'] = get_img_tensors(new_data['image_prefixes'].values)
    return new_data.iloc[:, -5:]


def get_test_data(path):
    data = pd.read_csv(path, index_col=0)
    resnet_idx = list(data.columns).index('data_AUTOGRAPHER_RESNET_mean_tench, Tinca tinca')
    new_data = data.iloc[:,:resnet_idx]
    new_data = new_data.drop([col_len for col_len in new_data.keys() if '_len' in col_len], 1) # Drop columns with _LEN
    new_data['image_prefixes'] = list(map(lambda x,y: str(x)+'_pred'+str(y), 
                                      new_data.sub_id.values, 
                                      new_data.event_id.values))
    new_data['features'] = [torch.tensor(i).float() for i in new_data.iloc[:, 3:-1].values]
    new_data['img_tensors'] = get_img_tensors(new_data['image_prefixes'].values)
    return new_data.iloc[:, [0,1,2,-3,-2,-1]]


def get_img_tensors(image_prefixes):
    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    temp = []
    
    for img_prefix in tqdm.notebook.tqdm(image_prefixes):
        img_paths = []
        i = 0
        img_name = img_prefix+'_'+str(i)+'.jpg'
        while img_name in os.listdir('./OUTPUT_MERGED/AUTOGRAPHER/'):
            img_paths.append('./OUTPUT_MERGED/AUTOGRAPHER/'+img_name)
            i += 1
            img_name = img_prefix+'_'+str(i)+'.jpg'

        # Transform images to tensors
        img_tensors = []
        for path in img_paths:
            img = Image.open(path)
            img_tensors.append(transform(img))
            
        # padding img tensors
        if len(img_tensors) < 16:
            dump = torch.zeros((3,224,224)).float()
            dump = [dump] * (16 - len(img_tensors))
            img_tensors.extend(dump)
            
        temp.append(torch.stack(img_tensors,0))
    
    return temp

def check_params(model):
    model.freeze_resnet()
    print(sum([i.numel() for i in model.parameters() if i.requires_grad]))
    model.unfreeze_resnet()
    print(sum([i.numel() for i in model.parameters() if i.requires_grad]))
    

        
train_data = get_data('train_min_max.csv')
train_data = train_data.sort_values(by='image_prefixes', ignore_index=True)
test_data = get_test_data('test_min_max.csv')

HBox(children=(FloatProgress(value=0.0, max=280.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, max=140.0), HTML(value='')))




# Init MODEL

In [3]:
model = Bonz(hidden_dim=2048, feature_selection=train_data.features[0].shape[0])
check_params(model)

185377037
210934069


# Create Dataset & Dataloader

In [12]:
dataset = BonzTrainDataset(train_data)
dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

dataset.__len__()

78400

# Find Best Learning Rate

In [10]:
model.freeze_resnet()

model.to(DEVICE)
torch.save(model.state_dict(), 'origin_sd.pt')
model.train()

start_lr = 1e-4
lr_find_epochs = 5

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), start_lr)
criterion = nn.BCEWithLogitsLoss()



# Make lists to capture the logs
lr_find_acc = []
lr_find_loss = []
lr_find_lr = []


for _ in range(2):
    scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, 
                                                  base_lr=0, 
                                                  max_lr=start_lr*10, 
                                                  step_size_up=lr_find_epochs,
                                                  cycle_momentum=False)
    scheduler.step()
    
    for i in tqdm.notebook.trange(lr_find_epochs):
        # Load origin state dict
        model.load_state_dict(torch.load('origin_sd.pt'))

        predicts = []
        y_true = []
        total_loss = 0

        for x1, x1_f, x2, x2_f, label in tqdm.notebook.tqdm(dataloader, desc='Training: '):

            x1 = [ts.to(DEVICE) for ts in x1]
            x1_f = x1_f.to(DEVICE)
            x2 = [ts.to(DEVICE) for ts in x2]
            x2_f = x2_f.to(DEVICE)
            label = label.to(DEVICE)

            optimizer.zero_grad()

            predict = model(x1, x1_f, x2, x2_f)[0]
            loss = criterion(predict, label.float())

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            predict = predict.detach().cpu()
            predict = torch.sigmoid(predict)
            predicts.extend(predict.tolist())
            y_true.extend(label.detach().cpu().tolist())

        train_acc = accuracy_score(np.array(y_true), np.array(predicts)>0.5)
        lr_step = optimizer.state_dict()["param_groups"][0]["lr"]
        print(f'epoch={i}, Acc={train_acc*100:.2f}, Loss={loss:.2f}, LR={lr_step:.2e}')

        lr_find_lr.append(lr_step)
        lr_find_acc.append(train_acc)
        lr_find_loss.append(total_loss)

        scheduler.step()
    
    start_lr *= 10
        
        
for a, b, c in zip(lr_find_lr, lr_find_acc, lr_find_loss):
    print(f'{a}\t{b}\t{c}')



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Training: ', max=47.0, style=ProgressStyle(description_wi…


epoch=0, Acc=93.83, Loss=0.17, LR=2.00e-04


HBox(children=(FloatProgress(value=0.0, description='Training: ', max=47.0, style=ProgressStyle(description_wi…


epoch=1, Acc=93.73, Loss=0.09, LR=4.00e-04


HBox(children=(FloatProgress(value=0.0, description='Training: ', max=47.0, style=ProgressStyle(description_wi…


epoch=2, Acc=93.90, Loss=0.12, LR=6.00e-04


HBox(children=(FloatProgress(value=0.0, description='Training: ', max=47.0, style=ProgressStyle(description_wi…


epoch=3, Acc=93.70, Loss=0.17, LR=8.00e-04


HBox(children=(FloatProgress(value=0.0, description='Training: ', max=47.0, style=ProgressStyle(description_wi…


epoch=4, Acc=94.00, Loss=0.04, LR=1.00e-03



HBox(children=(FloatProgress(value=0.0, max=5.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, description='Training: ', max=47.0, style=ProgressStyle(description_wi…


epoch=0, Acc=94.00, Loss=0.25, LR=2.00e-03


HBox(children=(FloatProgress(value=0.0, description='Training: ', max=47.0, style=ProgressStyle(description_wi…


epoch=1, Acc=93.93, Loss=0.09, LR=4.00e-03


HBox(children=(FloatProgress(value=0.0, description='Training: ', max=47.0, style=ProgressStyle(description_wi…


epoch=2, Acc=93.80, Loss=0.13, LR=6.00e-03


HBox(children=(FloatProgress(value=0.0, description='Training: ', max=47.0, style=ProgressStyle(description_wi…


epoch=3, Acc=94.03, Loss=0.21, LR=8.00e-03


HBox(children=(FloatProgress(value=0.0, description='Training: ', max=47.0, style=ProgressStyle(description_wi…


epoch=4, Acc=94.07, Loss=0.09, LR=1.00e-02

0.00020000000000000017	0.9383333333333334	9.58180908113718
0.0003999999999999999	0.9373333333333334	9.230482377111912
0.0006000000000000001	0.939	8.963266298174858
0.0007999999999999998	0.937	8.816364150494337
0.001	0.94	8.94610458984971
0.0020000000000000018	0.94	8.18113087117672
0.003999999999999999	0.9393333333333334	9.424686886370182
0.006000000000000001	0.938	8.614524226635695
0.007999999999999998	0.9403333333333334	8.83814811706543
0.01	0.9406666666666667	9.135346800088882


In [13]:
model.freeze_resnet()

model.to(DEVICE)
model.train()
model.load_state_dict(torch.load('origin_sd.pt'))

start_lr = 1e-3
lr_find_epochs = 20

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), start_lr)
criterion = nn.BCEWithLogitsLoss()


# Make lists to capture the logs
best_metrics = {'train_loss': 1e10, 
                'train_acc': 0}

metrics = {'train_loss': [], 
           'train_acc': []}

    
for e in tqdm.notebook.trange(lr_find_epochs):

    predicts = []
    y_true = []
    total_loss = 0

    for x1, x1_f, x2, x2_f, label in tqdm.notebook.tqdm(dataloader):

        x1 = [ts.to(DEVICE) for ts in x1]
        x1_f = x1_f.to(DEVICE)
        x2 = [ts.to(DEVICE) for ts in x2]
        x2_f = x2_f.to(DEVICE)
        label = label.to(DEVICE)

        optimizer.zero_grad()

        predict = model(x1, x1_f, x2, x2_f)[0]
        loss = criterion(predict, label.float())

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        predict = predict.detach().cpu()
        predict = torch.sigmoid(predict)
        predicts.extend(predict.tolist())
        y_true.extend(label.detach().cpu().tolist())

    train_acc = accuracy_score(np.array(y_true), np.array(predicts)>0.5)
    print(f'epoch={e}, Acc={train_acc*100:.2f}, Loss={loss:.2f}')

    metrics['train_acc'].append(train_acc)
    metrics['train_loss'].append(total_loss)
    
    if total_loss < best_metrics['train_loss']:
        best_metrics['train_acc'] = train_acc
        best_metrics['train_loss'] = total_loss
        
        torch.save(model.state_dict(), 'best_model.pt')
        print(f'BEST MODEL at Epoch = {e}')

print(metrics)

HBox(children=(FloatProgress(value=0.0, max=20.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=0, Acc=95.52, Loss=0.10
BEST MODEL at Epoch = 0


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=1, Acc=96.49, Loss=0.09
BEST MODEL at Epoch = 1


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=2, Acc=97.83, Loss=0.09
BEST MODEL at Epoch = 2


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=3, Acc=98.78, Loss=0.01
BEST MODEL at Epoch = 3


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=4, Acc=99.02, Loss=0.03
BEST MODEL at Epoch = 4


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=5, Acc=99.11, Loss=0.02
BEST MODEL at Epoch = 5


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=6, Acc=99.35, Loss=0.00
BEST MODEL at Epoch = 6


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=7, Acc=99.38, Loss=0.00
BEST MODEL at Epoch = 7


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=8, Acc=99.58, Loss=0.01
BEST MODEL at Epoch = 8


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=9, Acc=99.54, Loss=0.01


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=10, Acc=99.61, Loss=0.00
BEST MODEL at Epoch = 10


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=11, Acc=99.60, Loss=0.01


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=12, Acc=99.74, Loss=0.03
BEST MODEL at Epoch = 12


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=13, Acc=99.69, Loss=0.03


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=14, Acc=99.66, Loss=0.01


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=15, Acc=99.72, Loss=0.05
BEST MODEL at Epoch = 15


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=16, Acc=99.58, Loss=0.00


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=17, Acc=99.74, Loss=0.01
BEST MODEL at Epoch = 17


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=18, Acc=99.77, Loss=0.00
BEST MODEL at Epoch = 18


HBox(children=(FloatProgress(value=0.0, max=1225.0), HTML(value='')))


epoch=19, Acc=99.87, Loss=0.00
BEST MODEL at Epoch = 19

{'train_loss': [163.11059016361833, 112.17023746995255, 67.29989651683718, 41.6530455268221, 33.095375760109164, 29.133335248188814, 22.520687466618256, 22.507519800423324, 14.685416374288252, 15.943816489012534, 13.93064460644473, 14.114390472642754, 10.378330527740502, 11.608633353419918, 12.348099744272758, 9.670739976643745, 14.13293377559603, 9.330683657447025, 8.788927955342388, 5.551301488467743], 'train_acc': [0.9552168367346939, 0.9648724489795918, 0.978265306122449, 0.9878188775510204, 0.9902168367346939, 0.9910586734693878, 0.9934566326530613, 0.99375, 0.9958163265306123, 0.9954464285714286, 0.9961479591836735, 0.9959821428571428, 0.9974489795918368, 0.9968622448979592, 0.996594387755102, 0.9971556122448979, 0.9958418367346938, 0.997359693877551, 0.9976530612244898, 0.9987372448979592]}


In [30]:
accuracy_score(np.array(y_true).squeeze(), (np.array(predicts)>0.5).squeeze())

0.7453333333333333

In [20]:
model.eval()
#model.load_state_dict(torch.load('best_model.pt'))
predict_matrix = []
val_loss = 0
val_data_loader = DataLoader(BonzDataset(train_data), batch_size=32)
for img_tensors, features, one_hot_label, label in val_data_loader:
    img_tensors = [ts.to(DEVICE) for ts in img_tensors]
    features = features.to(DEVICE)
    one_hot_label = one_hot_label.to(DEVICE)
    label = label.to(DEVICE)

    with torch.no_grad():
        predict = model(img_tensors, features)
    predict = predict.detach().cpu()
    predict_matrix.extend(torch.softmax(predict, 1))
    
onehot_true = np.array(list(i.numpy() for i in train_data.one_hot_labels.values))
onehot_predict = torch.stack(predict_matrix).numpy()

average_precision_score(onehot_true, onehot_predict)

0.8264700613043315

In [18]:
np.argmax(onehot_predict, 1)

array([11, 11,  7,  3,  5,  5,  6,  7,  6,  9, 14, 11, 12, 13, 14,  1, 16,
       17, 18, 19, 10,  1,  1,  3,  3,  9,  8, 14,  5, 10, 10,  1, 18, 13,
       13, 10, 16, 17, 18, 19,  2,  1,  0,  2,  4,  5,  6,  7,  8, 14, 15,
       10, 15, 13, 14,  1, 19, 19, 18, 19,  0, 11, 10, 10, 11,  9,  6, 11,
        5,  5,  1, 11, 12, 14, 12, 15, 12, 19, 18, 19,  0, 10,  0,  7, 16,
        7,  6,  7,  8,  9, 10, 10,  6, 13, 14,  1, 16, 17, 18, 19,  0, 15,
        0,  3, 10,  5, 16,  7,  6,  9, 10, 15, 12, 13, 11,  3, 16, 17, 18,
       19,  0,  1,  4, 11,  4,  5,  6,  7,  5,  5,  1,  1, 12, 13, 14, 15,
       17, 19, 18, 19,  2,  1,  1,  3,  5,  7,  6,  7,  8,  9, 10,  5, 12,
       13, 14, 15, 16, 17, 18, 19,  0,  1,  2,  3, 11,  5,  8,  7,  8,  9,
       10,  1, 12, 13, 14, 15, 16, 17, 18, 19,  0,  1,  2,  3,  4,  0,  6,
        7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 11,  1,  1, 15,
        4,  5,  6,  7,  8,  9, 10,  1, 12, 13, 14, 15, 16, 17, 18, 19,  0,
        1,  2,  3, 11,  5

In [21]:
model.eval()
predict_matrix = []
val_data_loader = DataLoader(BonzDataset(test_data), batch_size=32)
for img_tensors, features in val_data_loader:
    img_tensors = [ts.to(DEVICE) for ts in img_tensors]
    features = features.to(DEVICE)
    with torch.no_grad():
        predict = model(img_tensors, features)
    predict = predict.detach().cpu()
    predict_matrix.extend(torch.softmax(predict, 1))

onehot_predict = torch.stack(predict_matrix).numpy()

COLUMN_NAMES = ['act'+str(i)+str(j) for i in range(2) for j in range(10)]
COLUMN_NAMES.pop(0)
COLUMN_NAMES.append('act20')

df_prob = pd.DataFrame(data=onehot_predict, columns=COLUMN_NAMES)
df_prob['event'] = test_data['image_prefixes']

submission = []
for act in list(df_prob.keys())[:-1]:
    ranked_ = df_prob.sort_values(by=[act], ascending=False)['event'].values
    ranked_ = act+' '+ranked_
    submission.extend(ranked_)

pd.DataFrame(submission).to_csv('submission.txt', header=False, index=False)

np.argmax(onehot_predict, 1)

array([10, 10, 17, 17,  3, 14,  7,  3,  7,  3, 11, 14,  3, 14, 14, 14,  3,
       13, 14,  3,  3, 14,  3, 14,  3, 14, 14,  3, 14, 14,  3,  3, 14, 14,
        3,  3, 14, 14,  3,  3,  3,  6, 13, 13, 14, 14,  3,  3,  3, 14,  3,
       14, 13, 14,  3, 14,  3, 14,  3,  3,  3,  3,  3, 14, 14, 14,  3,  7,
        3,  3, 14, 14, 13,  3, 14, 14, 14, 14, 14,  3, 14, 14,  3, 14,  3,
       14,  3, 14, 14,  3, 13, 14, 14,  3,  3, 14, 14, 14, 14,  3,  3,  3,
        3, 14, 14, 14, 14, 14, 14, 13,  3,  3, 14,  3,  6, 14,  3, 16, 14,
        3, 11, 13, 14,  7, 14,  3, 14,  3,  3, 14, 14, 14,  3,  7, 14,  3,
        3,  3, 14, 14], dtype=int64)